Support for Augmented rbtrees
 -----------------------------
 
-Augmented rbtree is an rbtree with "some" additional data stored in each node.
-This data can be used to augment some new functionality to rbtree.
-Augmented rbtree is an optional feature built on top of basic rbtree
-infrastructure. An rbtree user who wants this feature will have to call the
-augmentation functions with the user provided augmentation callback
-when inserting and erasing nodes.
+Augmented rbtree is an rbtree with "some" additional data stored in
+each node, where the additional data for node N must be a function of
+the contents of all nodes in the subtree rooted at N. This data can
+be used to augment some new functionality to rbtree. Augmented rbtree
+is an optional feature built on top of basic rbtree infrastructure.
+An rbtree user who wants this feature will have to call the augmentation
+functions with the user provided augmentation callback when inserting
+and erasing nodes.
 
-On insertion, the user must call rb_augment_insert() once the new node is in
-place. This will cause the augmentation function callback to be called for
-each node between the new node and the root which has been affected by the
-insertion.
+On insertion, the user must update the augmented information on the path
+leading to the inserted node, then call rb_link_node() as usual and
+rb_augment_inserted() instead of the usual rb_insert_color() call.
+If rb_augment_inserted() rebalances the rbtree, it will callback into
+a user provided function to update the augmented information on the
+affected subtrees.
 
-When erasing a node, the user must call rb_augment_erase_begin() first to
-retrieve the deepest node on the rebalance path. Then, after erasing the
-original node, the user must call rb_augment_erase_end() with the deepest
-node found earlier. This will cause the augmentation function to be called
-for each affected node between the deepest node and the root.
+When erasing a node, the user must call rb_erase_augmented() instead of
+rb_erase(). rb_erase_augmented() calls back into user provided functions
+to updated the augmented information on affected subtrees.
 
+In both cases, the callbacks are provided through struct rb_augment_callbacks.
+3 callbacks must be defined:
+
+- A propagation callback, which updates the augmented value for a given
+  node and its ancestors, up to a given stop point (or NULL to update
+  all the way to the root).
+
+- A copy callback, which copies the augmented value for a given subtree
+  to a newly assigned subtree root.
+
+- A tree rotation callback, which copies the augmented value for a given
+  subtree to a newly assigned subtree root AND recomputes the augmented
+  information for the former subtree root.
+
+
+Sample usage:
 
 Interval tree is an example of augmented rb tree. Reference -
 "Introduction to Algorithms" by Cormen, Leiserson, Rivest and Stein.
 for lowest match (lowest start address among all possible matches)
 with something like:
 
-find_lowest_match(lo, hi, node)
+struct interval_tree_node *
+interval_tree_first_match(struct rb_root *root,
+                         unsigned long start, unsigned long last)
 {
-       lowest_match = NULL;
-       while (node) {
-               if (max_hi(node->left) > lo) {
-                       // Lowest overlap if any must be on left side
-                       node = node->left;
-               } else if (overlap(lo, hi, node)) {
-                       lowest_match = node;
-                       break;
-               } else if (lo > node->lo) {
-                       // Lowest overlap if any must be on right side
-                       node = node->right;
-               } else {
-                       break;
+       struct interval_tree_node *node;
+
+       if (!root->rb_node)
+               return NULL;
+       node = rb_entry(root->rb_node, struct interval_tree_node, rb);
+
+       while (true) {
+               if (node->rb.rb_left) {
+                       struct interval_tree_node *left =
+                               rb_entry(node->rb.rb_left,
+                                        struct interval_tree_node, rb);
+                       if (left->__subtree_last >= start) {
+                               /*
+                                * Some nodes in left subtree satisfy Cond2.
+                                * Iterate to find the leftmost such node N.
+                                * If it also satisfies Cond1, that's the match
+                                * we are looking for. Otherwise, there is no
+                                * matching interval as nodes to the right of N
+                                * can't satisfy Cond1 either.
+                                */
+                               node = left;
+                               continue;
+                       }
                }
+               if (node->start <= last) {              /* Cond1 */
+                       if (node->last >= start)        /* Cond2 */
+                               return node;    /* node is leftmost match */
+                       if (node->rb.rb_right) {
+                               node = rb_entry(node->rb.rb_right,
+                                       struct interval_tree_node, rb);
+                               if (node->__subtree_last >= start)
+                                       continue;
+                       }
+               }
+               return NULL;    /* No match */
+       }
+}
+
+Insertion/removal are defined using the following augmented callbacks:
+
+static inline unsigned long
+compute_subtree_last(struct interval_tree_node *node)
+{
+       unsigned long max = node->last, subtree_last;
+       if (node->rb.rb_left) {
+               subtree_last = rb_entry(node->rb.rb_left,
+                       struct interval_tree_node, rb)->__subtree_last;
+               if (max < subtree_last)
+                       max = subtree_last;
+       }
+       if (node->rb.rb_right) {
+               subtree_last = rb_entry(node->rb.rb_right,
+                       struct interval_tree_node, rb)->__subtree_last;
+               if (max < subtree_last)
+                       max = subtree_last;
+       }
+       return max;
+}
+
+static void augment_propagate(struct rb_node *rb, struct rb_node *stop)
+{
+       while (rb != stop) {
+               struct interval_tree_node *node =
+                       rb_entry(rb, struct interval_tree_node, rb);
+               unsigned long subtree_last = compute_subtree_last(node);
+               if (node->__subtree_last == subtree_last)
+                       break;
+               node->__subtree_last = subtree_last;
+               rb = rb_parent(&node->rb);
+       }
+}
+
+static void augment_copy(struct rb_node *rb_old, struct rb_node *rb_new)
+{
+       struct interval_tree_node *old =
+               rb_entry(rb_old, struct interval_tree_node, rb);
+       struct interval_tree_node *new =
+               rb_entry(rb_new, struct interval_tree_node, rb);
+
+       new->__subtree_last = old->__subtree_last;
+}
+
+static void augment_rotate(struct rb_node *rb_old, struct rb_node *rb_new)
+{
+       struct interval_tree_node *old =
+               rb_entry(rb_old, struct interval_tree_node, rb);
+       struct interval_tree_node *new =
+               rb_entry(rb_new, struct interval_tree_node, rb);
+
+       new->__subtree_last = old->__subtree_last;
+       old->__subtree_last = compute_subtree_last(old);
+}
+
+static const struct rb_augment_callbacks augment_callbacks = {
+       augment_propagate, augment_copy, augment_rotate
+};
+
+void interval_tree_insert(struct interval_tree_node *node,
+                         struct rb_root *root)
+{
+       struct rb_node **link = &root->rb_node, *rb_parent = NULL;
+       unsigned long start = node->start, last = node->last;
+       struct interval_tree_node *parent;
+
+       while (*link) {
+               rb_parent = *link;
+               parent = rb_entry(rb_parent, struct interval_tree_node, rb);
+               if (parent->__subtree_last < last)
+                       parent->__subtree_last = last;
+               if (start < parent->start)
+                       link = &parent->rb.rb_left;
+               else
+                       link = &parent->rb.rb_right;
        }
-       return lowest_match;
+
+       node->__subtree_last = last;
+       rb_link_node(&node->rb, rb_parent, link);
+       rb_insert_augmented(&node->rb, root, &augment_callbacks);
 }
 
-Finding exact match will be to first find lowest match and then to follow
-successor nodes looking for exact match, until the start of a node is beyond
-the hi value we are looking for.
+void interval_tree_remove(struct interval_tree_node *node,
+                         struct rb_root *root)
+{
+       rb_erase_augmented(&node->rb, root, &augment_callbacks);
+}
 
 extern void rb_insert_color(struct rb_node *, struct rb_root *);
 extern void rb_erase(struct rb_node *, struct rb_root *);
 
+
+struct rb_augment_callbacks {
+       void (*propagate)(struct rb_node *node, struct rb_node *stop);
+       void (*copy)(struct rb_node *old, struct rb_node *new);
+       void (*rotate)(struct rb_node *old, struct rb_node *new);
+};
+
+extern void __rb_insert_augmented(struct rb_node *node, struct rb_root *root,
+       void (*augment_rotate)(struct rb_node *old, struct rb_node *new));
+extern void rb_erase_augmented(struct rb_node *node, struct rb_root *root,
+                              const struct rb_augment_callbacks *augment);
+static inline void
+rb_insert_augmented(struct rb_node *node, struct rb_root *root,
+                   const struct rb_augment_callbacks *augment)
+{
+       __rb_insert_augmented(node, root, augment->rotate);
+}
+
+
 typedef void (*rb_augment_f)(struct rb_node *node, void *data);
 
 extern void rb_augment_insert(struct rb_node *node,
 
        __rb_change_child(old, new, parent, root);
 }
 
-void rb_insert_color(struct rb_node *node, struct rb_root *root)
+static __always_inline void
+__rb_insert(struct rb_node *node, struct rb_root *root,
+           void (*augment_rotate)(struct rb_node *old, struct rb_node *new))
 {
        struct rb_node *parent = rb_red_parent(node), *gparent, *tmp;
 
                                        rb_set_parent_color(tmp, parent,
                                                            RB_BLACK);
                                rb_set_parent_color(parent, node, RB_RED);
+                               augment_rotate(parent, node);
                                parent = node;
                                tmp = node->rb_right;
                        }
                        if (tmp)
                                rb_set_parent_color(tmp, gparent, RB_BLACK);
                        __rb_rotate_set_parents(gparent, parent, root, RB_RED);
+                       augment_rotate(gparent, parent);
                        break;
                } else {
                        tmp = gparent->rb_left;
                                        rb_set_parent_color(tmp, parent,
                                                            RB_BLACK);
                                rb_set_parent_color(parent, node, RB_RED);
+                               augment_rotate(parent, node);
                                parent = node;
                                tmp = node->rb_left;
                        }
                        if (tmp)
                                rb_set_parent_color(tmp, gparent, RB_BLACK);
                        __rb_rotate_set_parents(gparent, parent, root, RB_RED);
+                       augment_rotate(gparent, parent);
                        break;
                }
        }
 }
-EXPORT_SYMBOL(rb_insert_color);
 
-static void __rb_erase_color(struct rb_node *parent, struct rb_root *root)
+static __always_inline void
+__rb_erase_color(struct rb_node *parent, struct rb_root *root,
+                const struct rb_augment_callbacks *augment)
 {
        struct rb_node *node = NULL, *sibling, *tmp1, *tmp2;
 
                                rb_set_parent_color(tmp1, parent, RB_BLACK);
                                __rb_rotate_set_parents(parent, sibling, root,
                                                        RB_RED);
+                               augment->rotate(parent, sibling);
                                sibling = tmp1;
                        }
                        tmp1 = sibling->rb_right;
                                if (tmp1)
                                        rb_set_parent_color(tmp1, sibling,
                                                            RB_BLACK);
+                               augment->rotate(sibling, tmp2);
                                tmp1 = sibling;
                                sibling = tmp2;
                        }
                                rb_set_parent(tmp2, parent);
                        __rb_rotate_set_parents(parent, sibling, root,
                                                RB_BLACK);
+                       augment->rotate(parent, sibling);
                        break;
                } else {
                        sibling = parent->rb_left;
                                rb_set_parent_color(tmp1, parent, RB_BLACK);
                                __rb_rotate_set_parents(parent, sibling, root,
                                                        RB_RED);
+                               augment->rotate(parent, sibling);
                                sibling = tmp1;
                        }
                        tmp1 = sibling->rb_left;
                                if (tmp1)
                                        rb_set_parent_color(tmp1, sibling,
                                                            RB_BLACK);
+                               augment->rotate(sibling, tmp2);
                                tmp1 = sibling;
                                sibling = tmp2;
                        }
                                rb_set_parent(tmp2, parent);
                        __rb_rotate_set_parents(parent, sibling, root,
                                                RB_BLACK);
+                       augment->rotate(parent, sibling);
                        break;
                }
        }
 }
 
-void rb_erase(struct rb_node *node, struct rb_root *root)
+static __always_inline void
+__rb_erase(struct rb_node *node, struct rb_root *root,
+          const struct rb_augment_callbacks *augment)
 {
        struct rb_node *child = node->rb_right, *tmp = node->rb_left;
        struct rb_node *parent, *rebalance;
                        rebalance = NULL;
                } else
                        rebalance = __rb_is_black(pc) ? parent : NULL;
+               tmp = parent;
        } else if (!child) {
                /* Still case 1, but this time the child is node->rb_left */
                tmp->__rb_parent_color = pc = node->__rb_parent_color;
                parent = __rb_parent(pc);
                __rb_change_child(node, tmp, parent, root);
                rebalance = NULL;
+               tmp = parent;
        } else {
                struct rb_node *successor = child, *child2;
                tmp = child->rb_left;
                         *        \
                         *        (c)
                         */
-                       parent = child;
-                       child2 = child->rb_right;
+                       parent = successor;
+                       child2 = successor->rb_right;
+                       augment->copy(node, successor);
                } else {
                        /*
                         * Case 3: node's successor is leftmost under
                        parent->rb_left = child2 = successor->rb_right;
                        successor->rb_right = child;
                        rb_set_parent(child, successor);
+                       augment->copy(node, successor);
+                       augment->propagate(parent, successor);
                }
 
                successor->rb_left = tmp = node->rb_left;
                        successor->__rb_parent_color = pc;
                        rebalance = __rb_is_black(pc2) ? parent : NULL;
                }
+               tmp = successor;
        }
 
+       augment->propagate(tmp, NULL);
        if (rebalance)
-               __rb_erase_color(rebalance, root);
+               __rb_erase_color(rebalance, root, augment);
+}
+
+/*
+ * Non-augmented rbtree manipulation functions.
+ *
+ * We use dummy augmented callbacks here, and have the compiler optimize them
+ * out of the rb_insert_color() and rb_erase() function definitions.
+ */
+
+static inline void dummy_propagate(struct rb_node *node, struct rb_node *stop) {}
+static inline void dummy_copy(struct rb_node *old, struct rb_node *new) {}
+static inline void dummy_rotate(struct rb_node *old, struct rb_node *new) {}
+
+static const struct rb_augment_callbacks dummy_callbacks = {
+       dummy_propagate, dummy_copy, dummy_rotate
+};
+
+void rb_insert_color(struct rb_node *node, struct rb_root *root)
+{
+       __rb_insert(node, root, dummy_rotate);
+}
+EXPORT_SYMBOL(rb_insert_color);
+
+void rb_erase(struct rb_node *node, struct rb_root *root)
+{
+       __rb_erase(node, root, &dummy_callbacks);
 }
 EXPORT_SYMBOL(rb_erase);
 
+/*
+ * Augmented rbtree manipulation functions.
+ *
+ * This instantiates the same __always_inline functions as in the non-augmented
+ * case, but this time with user-defined callbacks.
+ */
+
+void __rb_insert_augmented(struct rb_node *node, struct rb_root *root,
+       void (*augment_rotate)(struct rb_node *old, struct rb_node *new))
+{
+       __rb_insert(node, root, augment_rotate);
+}
+EXPORT_SYMBOL(__rb_insert_augmented);
+
+void rb_erase_augmented(struct rb_node *node, struct rb_root *root,
+                       const struct rb_augment_callbacks *augment)
+{
+       __rb_erase(node, root, augment);
+}
+EXPORT_SYMBOL(rb_erase_augmented);
+
 static void rb_augment_path(struct rb_node *node, rb_augment_f func, void *data)
 {
        struct rb_node *parent;
 
        return max;
 }
 
-static void augment_callback(struct rb_node *rb, void *unused)
+static void augment_propagate(struct rb_node *rb, struct rb_node *stop)
 {
-       struct test_node *node = rb_entry(rb, struct test_node, rb);
-       node->augmented = augment_recompute(node);
+       while (rb != stop) {
+               struct test_node *node = rb_entry(rb, struct test_node, rb);
+               u32 augmented = augment_recompute(node);
+               if (node->augmented == augmented)
+                       break;
+               node->augmented = augmented;
+               rb = rb_parent(&node->rb);
+       }
+}
+
+static void augment_copy(struct rb_node *rb_old, struct rb_node *rb_new)
+{
+       struct test_node *old = rb_entry(rb_old, struct test_node, rb);
+       struct test_node *new = rb_entry(rb_new, struct test_node, rb);
+       new->augmented = old->augmented;
 }
 
+static void augment_rotate(struct rb_node *rb_old, struct rb_node *rb_new)
+{
+       struct test_node *old = rb_entry(rb_old, struct test_node, rb);
+       struct test_node *new = rb_entry(rb_new, struct test_node, rb);
+
+       /* Rotation doesn't change subtree's augmented value */
+       new->augmented = old->augmented;
+       old->augmented = augment_recompute(old);
+}
+
+static const struct rb_augment_callbacks augment_callbacks = {
+       augment_propagate, augment_copy, augment_rotate
+};
+
 static void insert_augmented(struct test_node *node, struct rb_root *root)
 {
-       struct rb_node **new = &root->rb_node, *parent = NULL;
+       struct rb_node **new = &root->rb_node, *rb_parent = NULL;
        u32 key = node->key;
+       u32 val = node->val;
+       struct test_node *parent;
 
        while (*new) {
-               parent = *new;
-               if (key < rb_entry(parent, struct test_node, rb)->key)
-                       new = &parent->rb_left;
+               rb_parent = *new;
+               parent = rb_entry(rb_parent, struct test_node, rb);
+               if (parent->augmented < val)
+                       parent->augmented = val;
+               if (key < parent->key)
+                       new = &parent->rb.rb_left;
                else
-                       new = &parent->rb_right;
+                       new = &parent->rb.rb_right;
        }
 
-       rb_link_node(&node->rb, parent, new);
-       rb_insert_color(&node->rb, root);
-       rb_augment_insert(&node->rb, augment_callback, NULL);
+       node->augmented = val;
+       rb_link_node(&node->rb, rb_parent, new);
+       rb_insert_augmented(&node->rb, root, &augment_callbacks);
 }
 
 static void erase_augmented(struct test_node *node, struct rb_root *root)
 {
-       struct rb_node *deepest = rb_augment_erase_begin(&node->rb);
-       rb_erase(&node->rb, root);
-       rb_augment_erase_end(deepest, augment_callback, NULL);
+       rb_erase_augmented(&node->rb, root, &augment_callbacks);
 }
 
 static void init(void)