From 52ae96f5187c437a262e0497efff4b02e1ab0eab Mon Sep 17 00:00:00 2001 From: Alice Ryhl Date: Mon, 10 Feb 2025 09:53:36 +0000 Subject: [PATCH] rust: list: make the cursor point between elements I've been using the linked list cursor for a few different things, and I find it inconvenient to use because all of the functions have signatures along the lines of `Self -> Option`. The root cause of these signatures is that the cursor points *at* an element, rather than *between* two elements. Thus, change the cursor API to point between two elements. This is inspired by the stdlib linked list (well, really by this guy [1]), which also uses cursors that point between elements. The `peek_next` method returns a helper that lets you look at and optionally remove the element, as one common use-case of cursors is to iterate a list to look for an element, then remove that element. For many of the methods, this will reduce how many we need since they now just need a prev/next method, instead of the current state where you may end up needing all of curr/prev/next. Also, if we decide to add a function for splitting a list into two lists at the cursor, then a cursor that points between elements is exactly what makes the most sense. Another advantage is that this means you can now have a cursor into an empty list. Link: https://rust-unofficial.github.io/too-many-lists/sixth-cursors-intro.html [1] Reviewed-by: Andreas Hindborg Reviewed-by: Boqun Feng Signed-off-by: Alice Ryhl Link: https://lore.kernel.org/r/20250210-cursor-between-v7-2-36f0215181ed@google.com Signed-off-by: Miguel Ojeda --- rust/kernel/list.rs | 401 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 347 insertions(+), 54 deletions(-) diff --git a/rust/kernel/list.rs b/rust/kernel/list.rs index 97b3599b7207..c0ed227b8a4f 100644 --- a/rust/kernel/list.rs +++ b/rust/kernel/list.rs @@ -483,17 +483,21 @@ impl, const ID: u64> List { other.first = ptr::null_mut(); } - /// Returns a cursor to the first element of the list. - /// - /// If the list is empty, this returns `None`. - pub fn cursor_front(&mut self) -> Option> { - if self.first.is_null() { - None - } else { - Some(Cursor { - current: self.first, - list: self, - }) + /// Returns a cursor that points before the first element of the list. + pub fn cursor_front(&mut self) -> Cursor<'_, T, ID> { + // INVARIANT: `self.first` is in this list. + Cursor { + next: self.first, + list: self, + } + } + + /// Returns a cursor that points after the last element in the list. + pub fn cursor_back(&mut self) -> Cursor<'_, T, ID> { + // INVARIANT: `next` is allowed to be null. + Cursor { + next: core::ptr::null_mut(), + list: self, } } @@ -573,69 +577,358 @@ impl<'a, T: ?Sized + ListItem, const ID: u64> Iterator for Iter<'a, T, ID> { /// A cursor into a [`List`]. /// +/// A cursor always rests between two elements in the list. This means that a cursor has a previous +/// and next element, but no current element. It also means that it's possible to have a cursor +/// into an empty list. +/// +/// # Examples +/// +/// ``` +/// use kernel::prelude::*; +/// use kernel::list::{List, ListArc, ListLinks}; +/// +/// #[pin_data] +/// struct ListItem { +/// value: u32, +/// #[pin] +/// links: ListLinks, +/// } +/// +/// impl ListItem { +/// fn new(value: u32) -> Result> { +/// ListArc::pin_init(try_pin_init!(Self { +/// value, +/// links <- ListLinks::new(), +/// }), GFP_KERNEL) +/// } +/// } +/// +/// kernel::list::impl_has_list_links! { +/// impl HasListLinks<0> for ListItem { self.links } +/// } +/// kernel::list::impl_list_arc_safe! { +/// impl ListArcSafe<0> for ListItem { untracked; } +/// } +/// kernel::list::impl_list_item! { +/// impl ListItem<0> for ListItem { using ListLinks; } +/// } +/// +/// // Use a cursor to remove the first element with the given value. +/// fn remove_first(list: &mut List, value: u32) -> Option> { +/// let mut cursor = list.cursor_front(); +/// while let Some(next) = cursor.peek_next() { +/// if next.value == value { +/// return Some(next.remove()); +/// } +/// cursor.move_next(); +/// } +/// None +/// } +/// +/// // Use a cursor to remove the last element with the given value. +/// fn remove_last(list: &mut List, value: u32) -> Option> { +/// let mut cursor = list.cursor_back(); +/// while let Some(prev) = cursor.peek_prev() { +/// if prev.value == value { +/// return Some(prev.remove()); +/// } +/// cursor.move_prev(); +/// } +/// None +/// } +/// +/// // Use a cursor to remove all elements with the given value. The removed elements are moved to +/// // a new list. +/// fn remove_all(list: &mut List, value: u32) -> List { +/// let mut out = List::new(); +/// let mut cursor = list.cursor_front(); +/// while let Some(next) = cursor.peek_next() { +/// if next.value == value { +/// out.push_back(next.remove()); +/// } else { +/// cursor.move_next(); +/// } +/// } +/// out +/// } +/// +/// // Use a cursor to insert a value at a specific index. Returns an error if the index is out of +/// // bounds. +/// fn insert_at(list: &mut List, new: ListArc, idx: usize) -> Result { +/// let mut cursor = list.cursor_front(); +/// for _ in 0..idx { +/// if !cursor.move_next() { +/// return Err(EINVAL); +/// } +/// } +/// cursor.insert_next(new); +/// Ok(()) +/// } +/// +/// // Merge two sorted lists into a single sorted list. +/// fn merge_sorted(list: &mut List, merge: List) { +/// let mut cursor = list.cursor_front(); +/// for to_insert in merge { +/// while let Some(next) = cursor.peek_next() { +/// if to_insert.value < next.value { +/// break; +/// } +/// cursor.move_next(); +/// } +/// cursor.insert_prev(to_insert); +/// } +/// } +/// +/// let mut list = List::new(); +/// list.push_back(ListItem::new(14)?); +/// list.push_back(ListItem::new(12)?); +/// list.push_back(ListItem::new(10)?); +/// list.push_back(ListItem::new(12)?); +/// list.push_back(ListItem::new(15)?); +/// list.push_back(ListItem::new(14)?); +/// assert_eq!(remove_all(&mut list, 12).iter().count(), 2); +/// // [14, 10, 15, 14] +/// assert!(remove_first(&mut list, 14).is_some()); +/// // [10, 15, 14] +/// insert_at(&mut list, ListItem::new(12)?, 2)?; +/// // [10, 15, 12, 14] +/// assert!(remove_last(&mut list, 15).is_some()); +/// // [10, 12, 14] +/// +/// let mut list2 = List::new(); +/// list2.push_back(ListItem::new(11)?); +/// list2.push_back(ListItem::new(13)?); +/// merge_sorted(&mut list, list2); +/// +/// let mut items = list.into_iter(); +/// assert_eq!(items.next().unwrap().value, 10); +/// assert_eq!(items.next().unwrap().value, 11); +/// assert_eq!(items.next().unwrap().value, 12); +/// assert_eq!(items.next().unwrap().value, 13); +/// assert_eq!(items.next().unwrap().value, 14); +/// assert!(items.next().is_none()); +/// # Result::<(), Error>::Ok(()) +/// ``` +/// /// # Invariants /// -/// The `current` pointer points a value in `list`. +/// The `next` pointer is null or points a value in `list`. pub struct Cursor<'a, T: ?Sized + ListItem, const ID: u64 = 0> { - current: *mut ListLinksFields, list: &'a mut List, + /// Points at the element after this cursor, or null if the cursor is after the last element. + next: *mut ListLinksFields, } impl<'a, T: ?Sized + ListItem, const ID: u64> Cursor<'a, T, ID> { - /// Access the current element of this cursor. - pub fn current(&self) -> ArcBorrow<'_, T> { - // SAFETY: The `current` pointer points a value in the list. - let me = unsafe { T::view_value(ListLinks::from_fields(self.current)) }; - // SAFETY: - // * All values in a list are stored in an `Arc`. - // * The value cannot be removed from the list for the duration of the lifetime annotated - // on the returned `ArcBorrow`, because removing it from the list would require mutable - // access to the cursor or the list. However, the `ArcBorrow` holds an immutable borrow - // on the cursor, which in turn holds a mutable borrow on the list, so any such - // mutable access requires first releasing the immutable borrow on the cursor. - // * Values in a list never have a `UniqueArc` reference, because the list has a `ListArc` - // reference, and `UniqueArc` references must be unique. - unsafe { ArcBorrow::from_raw(me) } + /// Returns a pointer to the element before the cursor. + /// + /// Returns null if there is no element before the cursor. + fn prev_ptr(&self) -> *mut ListLinksFields { + let mut next = self.next; + let first = self.list.first; + if next == first { + // We are before the first element. + return core::ptr::null_mut(); + } + + if next.is_null() { + // We are after the last element, so we need a pointer to the last element, which is + // the same as `(*first).prev`. + next = first; + } + + // SAFETY: `next` can't be null, because then `first` must also be null, but in that case + // we would have exited at the `next == first` check. Thus, `next` is an element in the + // list, so we can access its `prev` pointer. + unsafe { (*next).prev } } - /// Move the cursor to the next element. - pub fn next(self) -> Option> { - // SAFETY: The `current` field is always in a list. - let next = unsafe { (*self.current).next }; + /// Access the element after this cursor. + pub fn peek_next(&mut self) -> Option> { + if self.next.is_null() { + return None; + } + + // INVARIANT: + // * We just checked that `self.next` is non-null, so it must be in `self.list`. + // * `ptr` is equal to `self.next`. + Some(CursorPeek { + ptr: self.next, + cursor: self, + }) + } + + /// Access the element before this cursor. + pub fn peek_prev(&mut self) -> Option> { + let prev = self.prev_ptr(); + + if prev.is_null() { + return None; + } + + // INVARIANT: + // * We just checked that `prev` is non-null, so it must be in `self.list`. + // * `self.prev_ptr()` never returns `self.next`. + Some(CursorPeek { + ptr: prev, + cursor: self, + }) + } + + /// Move the cursor one element forward. + /// + /// If the cursor is after the last element, then this call does nothing. This call returns + /// `true` if the cursor's position was changed. + pub fn move_next(&mut self) -> bool { + if self.next.is_null() { + return false; + } + + // SAFETY: `self.next` is an element in the list and we borrow the list mutably, so we can + // access the `next` field. + let mut next = unsafe { (*self.next).next }; if next == self.list.first { - None - } else { - // INVARIANT: Since `self.current` is in the `list`, its `next` pointer is also in the - // `list`. - Some(Cursor { - current: next, - list: self.list, - }) + next = core::ptr::null_mut(); } + + // INVARIANT: `next` is either null or the next element after an element in the list. + self.next = next; + true } - /// Move the cursor to the previous element. - pub fn prev(self) -> Option> { - // SAFETY: The `current` field is always in a list. - let prev = unsafe { (*self.current).prev }; + /// Move the cursor one element backwards. + /// + /// If the cursor is before the first element, then this call does nothing. This call returns + /// `true` if the cursor's position was changed. + pub fn move_prev(&mut self) -> bool { + if self.next == self.list.first { + return false; + } + + // INVARIANT: `prev_ptr()` always returns a pointer that is null or in the list. + self.next = self.prev_ptr(); + true + } - if self.current == self.list.first { - None + /// Inserts an element where the cursor is pointing and get a pointer to the new element. + fn insert_inner(&mut self, item: ListArc) -> *mut ListLinksFields { + let ptr = if self.next.is_null() { + self.list.first } else { - // INVARIANT: Since `self.current` is in the `list`, its `prev` pointer is also in the - // `list`. - Some(Cursor { - current: prev, - list: self.list, - }) + self.next + }; + // SAFETY: + // * `ptr` is an element in the list or null. + // * if `ptr` is null, then `self.list.first` is null so the list is empty. + let item = unsafe { self.list.insert_inner(item, ptr) }; + if self.next == self.list.first { + // INVARIANT: We just inserted `item`, so it's a member of list. + self.list.first = item; } + item + } + + /// Insert an element at this cursor's location. + pub fn insert(mut self, item: ListArc) { + // This is identical to `insert_prev`, but consumes the cursor. This is helpful because it + // reduces confusion when the last operation on the cursor is an insertion; in that case, + // you just want to insert the element at the cursor, and it is confusing that the call + // involves the word prev or next. + self.insert_inner(item); + } + + /// Inserts an element after this cursor. + /// + /// After insertion, the new element will be after the cursor. + pub fn insert_next(&mut self, item: ListArc) { + self.next = self.insert_inner(item); } - /// Remove the current element from the list. + /// Inserts an element before this cursor. + /// + /// After insertion, the new element will be before the cursor. + pub fn insert_prev(&mut self, item: ListArc) { + self.insert_inner(item); + } + + /// Remove the next element from the list. + pub fn remove_next(&mut self) -> Option> { + self.peek_next().map(|v| v.remove()) + } + + /// Remove the previous element from the list. + pub fn remove_prev(&mut self) -> Option> { + self.peek_prev().map(|v| v.remove()) + } +} + +/// References the element in the list next to the cursor. +/// +/// # Invariants +/// +/// * `ptr` is an element in `self.cursor.list`. +/// * `ISNEXT == (self.ptr == self.cursor.next)`. +pub struct CursorPeek<'a, 'b, T: ?Sized + ListItem, const ISNEXT: bool, const ID: u64> { + cursor: &'a mut Cursor<'b, T, ID>, + ptr: *mut ListLinksFields, +} + +impl<'a, 'b, T: ?Sized + ListItem, const ISNEXT: bool, const ID: u64> + CursorPeek<'a, 'b, T, ISNEXT, ID> +{ + /// Remove the element from the list. pub fn remove(self) -> ListArc { - // SAFETY: The `current` pointer always points at a member of the list. - unsafe { self.list.remove_internal(self.current) } + if ISNEXT { + self.cursor.move_next(); + } + + // INVARIANT: `self.ptr` is not equal to `self.cursor.next` due to the above `move_next` + // call. + // SAFETY: By the type invariants of `Self`, `next` is not null, so `next` is an element of + // `self.cursor.list` by the type invariants of `Cursor`. + unsafe { self.cursor.list.remove_internal(self.ptr) } + } + + /// Access this value as an [`ArcBorrow`]. + pub fn arc(&self) -> ArcBorrow<'_, T> { + // SAFETY: `self.ptr` points at an element in `self.cursor.list`. + let me = unsafe { T::view_value(ListLinks::from_fields(self.ptr)) }; + // SAFETY: + // * All values in a list are stored in an `Arc`. + // * The value cannot be removed from the list for the duration of the lifetime annotated + // on the returned `ArcBorrow`, because removing it from the list would require mutable + // access to the `CursorPeek`, the `Cursor` or the `List`. However, the `ArcBorrow` holds + // an immutable borrow on the `CursorPeek`, which in turn holds a mutable borrow on the + // `Cursor`, which in turn holds a mutable borrow on the `List`, so any such mutable + // access requires first releasing the immutable borrow on the `CursorPeek`. + // * Values in a list never have a `UniqueArc` reference, because the list has a `ListArc` + // reference, and `UniqueArc` references must be unique. + unsafe { ArcBorrow::from_raw(me) } + } +} + +impl<'a, 'b, T: ?Sized + ListItem, const ISNEXT: bool, const ID: u64> core::ops::Deref + for CursorPeek<'a, 'b, T, ISNEXT, ID> +{ + // If you change the `ptr` field to have type `ArcBorrow<'a, T>`, it might seem like you could + // get rid of the `CursorPeek::arc` method and change the deref target to `ArcBorrow<'a, T>`. + // However, that doesn't work because 'a is too long. You could obtain an `ArcBorrow<'a, T>` + // and then call `CursorPeek::remove` without giving up the `ArcBorrow<'a, T>`, which would be + // unsound. + type Target = T; + + fn deref(&self) -> &T { + // SAFETY: `self.ptr` points at an element in `self.cursor.list`. + let me = unsafe { T::view_value(ListLinks::from_fields(self.ptr)) }; + + // SAFETY: The value cannot be removed from the list for the duration of the lifetime + // annotated on the returned `&T`, because removing it from the list would require mutable + // access to the `CursorPeek`, the `Cursor` or the `List`. However, the `&T` holds an + // immutable borrow on the `CursorPeek`, which in turn holds a mutable borrow on the + // `Cursor`, which in turn holds a mutable borrow on the `List`, so any such mutable access + // requires first releasing the immutable borrow on the `CursorPeek`. + unsafe { &*me } } } -- 2.51.0