* memtype_lock protects the rbtree.
  */
 
-static void memtype_rb_augment_cb(struct rb_node *node);
-static struct rb_root memtype_rbroot = RB_AUGMENT_ROOT(&memtype_rb_augment_cb);
+static struct rb_root memtype_rbroot = RB_ROOT;
 
 static int is_node_overlap(struct memtype *node, u64 start, u64 end)
 {
 }
 
 /* Update 'subtree_max_end' for a node, based on node and its children */
-static void update_node_max_end(struct rb_node *node)
+static void memtype_rb_augment_cb(struct rb_node *node, void *__unused)
 {
        struct memtype *data;
        u64 max_end, child_max_end;
        data->subtree_max_end = max_end;
 }
 
-/* Update 'subtree_max_end' for a node and all its ancestors */
-static void update_path_max_end(struct rb_node *node)
-{
-       u64 old_max_end, new_max_end;
-
-       while (node) {
-               struct memtype *data = container_of(node, struct memtype, rb);
-
-               old_max_end = data->subtree_max_end;
-               update_node_max_end(node);
-               new_max_end = data->subtree_max_end;
-
-               if (new_max_end == old_max_end)
-                       break;
-
-               node = rb_parent(node);
-       }
-}
-
 /* Find the first (lowest start addr) overlapping range from rb tree */
 static struct memtype *memtype_rb_lowest_match(struct rb_root *root,
                                u64 start, u64 end)
        return -EBUSY;
 }
 
-static void memtype_rb_augment_cb(struct rb_node *node)
-{
-       if (node)
-               update_path_max_end(node);
-}
-
 static void memtype_rb_insert(struct rb_root *root, struct memtype *newdata)
 {
        struct rb_node **node = &(root->rb_node);
 
        rb_link_node(&newdata->rb, parent, node);
        rb_insert_color(&newdata->rb, root);
+       rb_augment_insert(&newdata->rb, memtype_rb_augment_cb, NULL);
 }
 
 int rbt_memtype_check_insert(struct memtype *new, unsigned long *ret_type)
 
 struct memtype *rbt_memtype_erase(u64 start, u64 end)
 {
+       struct rb_node *deepest;
        struct memtype *data;
 
        data = memtype_rb_exact_match(&memtype_rbroot, start, end);
        if (!data)
                goto out;
 
+       deepest = rb_augment_erase_begin(&data->rb);
        rb_erase(&data->rb, &memtype_rbroot);
+       rb_augment_erase_end(deepest, memtype_rb_augment_cb, NULL);
 out:
        return data;
 }
 
 struct rb_root
 {
        struct rb_node *rb_node;
-       void (*augment_cb)(struct rb_node *node);
 };
 
 
        rb->rb_parent_color = (rb->rb_parent_color & ~1) | color;
 }
 
-#define RB_ROOT        (struct rb_root) { NULL, NULL, }
-#define RB_AUGMENT_ROOT(x)     (struct rb_root) { NULL, x}
-
+#define RB_ROOT        (struct rb_root) { NULL, }
 #define        rb_entry(ptr, type, member) container_of(ptr, type, member)
 
 #define RB_EMPTY_ROOT(root)    ((root)->rb_node == NULL)
 extern void rb_insert_color(struct rb_node *, struct rb_root *);
 extern void rb_erase(struct rb_node *, struct rb_root *);
 
+typedef void (*rb_augment_f)(struct rb_node *node, void *data);
+
+extern void rb_augment_insert(struct rb_node *node,
+                             rb_augment_f func, void *data);
+extern struct rb_node *rb_augment_erase_begin(struct rb_node *node);
+extern void rb_augment_erase_end(struct rb_node *node,
+                                rb_augment_f func, void *data);
+
 /* Find logical next and previous nodes in a tree */
 extern struct rb_node *rb_next(const struct rb_node *);
 extern struct rb_node *rb_prev(const struct rb_node *);
 
        else
                root->rb_node = right;
        rb_set_parent(node, right);
-
-       if (root->augment_cb) {
-               root->augment_cb(node);
-               root->augment_cb(right);
-       }
 }
 
 static void __rb_rotate_right(struct rb_node *node, struct rb_root *root)
        else
                root->rb_node = left;
        rb_set_parent(node, left);
-
-       if (root->augment_cb) {
-               root->augment_cb(node);
-               root->augment_cb(left);
-       }
 }
 
 void rb_insert_color(struct rb_node *node, struct rb_root *root)
 {
        struct rb_node *parent, *gparent;
 
-       if (root->augment_cb)
-               root->augment_cb(node);
-
        while ((parent = rb_parent(node)) && rb_is_red(parent))
        {
                gparent = rb_parent(parent);
        else
        {
                struct rb_node *old = node, *left;
-               int old_parent_cb = 0;
-               int successor_parent_cb = 0;
 
                node = node->rb_right;
                while ((left = node->rb_left) != NULL)
                        node = left;
 
                if (rb_parent(old)) {
-                       old_parent_cb = 1;
                        if (rb_parent(old)->rb_left == old)
                                rb_parent(old)->rb_left = node;
                        else
                if (parent == old) {
                        parent = node;
                } else {
-                       successor_parent_cb = 1;
                        if (child)
                                rb_set_parent(child, parent);
-
                        parent->rb_left = child;
 
                        node->rb_right = old->rb_right;
                node->rb_left = old->rb_left;
                rb_set_parent(old->rb_left, node);
 
-               if (root->augment_cb) {
-                       /*
-                        * Here, three different nodes can have new children.
-                        * The parent of the successor node that was selected
-                        * to replace the node to be erased.
-                        * The node that is getting erased and is now replaced
-                        * by its successor.
-                        * The parent of the node getting erased-replaced.
-                        */
-                       if (successor_parent_cb)
-                               root->augment_cb(parent);
-
-                       root->augment_cb(node);
-
-                       if (old_parent_cb)
-                               root->augment_cb(rb_parent(old));
-               }
-
                goto color;
        }
 
 
        if (child)
                rb_set_parent(child, parent);
-
-       if (parent) {
+       if (parent)
+       {
                if (parent->rb_left == node)
                        parent->rb_left = child;
                else
                        parent->rb_right = child;
-
-               if (root->augment_cb)
-                       root->augment_cb(parent);
-
-       } else {
-               root->rb_node = child;
        }
+       else
+               root->rb_node = child;
 
  color:
        if (color == RB_BLACK)
 }
 EXPORT_SYMBOL(rb_erase);
 
+static void rb_augment_path(struct rb_node *node, rb_augment_f func, void *data)
+{
+       struct rb_node *parent;
+
+up:
+       func(node, data);
+       parent = rb_parent(node);
+       if (!parent)
+               return;
+
+       if (node == parent->rb_left && parent->rb_right)
+               func(parent->rb_right, data);
+       else if (parent->rb_left)
+               func(parent->rb_left, data);
+
+       node = parent;
+       goto up;
+}
+
+/*
+ * after inserting @node into the tree, update the tree to account for
+ * both the new entry and any damage done by rebalance
+ */
+void rb_augment_insert(struct rb_node *node, rb_augment_f func, void *data)
+{
+       if (node->rb_left)
+               node = node->rb_left;
+       else if (node->rb_right)
+               node = node->rb_right;
+
+       rb_augment_path(node, func, data);
+}
+
+/*
+ * before removing the node, find the deepest node on the rebalance path
+ * that will still be there after @node gets removed
+ */
+struct rb_node *rb_augment_erase_begin(struct rb_node *node)
+{
+       struct rb_node *deepest;
+
+       if (!node->rb_right && !node->rb_left)
+               deepest = rb_parent(node);
+       else if (!node->rb_right)
+               deepest = node->rb_left;
+       else if (!node->rb_left)
+               deepest = node->rb_right;
+       else {
+               deepest = rb_next(node);
+               if (deepest->rb_right)
+                       deepest = deepest->rb_right;
+               else if (rb_parent(deepest) != node)
+                       deepest = rb_parent(deepest);
+       }
+
+       return deepest;
+}
+
+/*
+ * after removal, update the tree to account for the removed entry
+ * and any rebalance damage.
+ */
+void rb_augment_erase_end(struct rb_node *node, rb_augment_f func, void *data)
+{
+       if (node)
+               rb_augment_path(node, func, data);
+}
+
 /*
  * This function returns the first node (in sort order) of the tree.
  */