#endif
 };
 
-static void tnode_put_child_reorg(struct tnode *tn, unsigned long i,
-                                 struct tnode *n, int wasfull);
-static struct tnode *resize(struct trie *t, struct tnode *tn);
+static void resize(struct trie *t, struct tnode *tn);
 /* tnodes to free after resize(); protected by RTNL */
 static struct callback_head *tnode_free_head;
 static size_t tnode_free_size;
        return n && ((n->pos + n->bits) == tn->pos) && IS_TNODE(n);
 }
 
-static inline void put_child(struct tnode *tn, unsigned long i,
-                            struct tnode *n)
-{
-       tnode_put_child_reorg(tn, i, n, -1);
-}
-
- /*
-  * Add a child at position i overwriting the old value.
-  * Update the value of full_children and empty_children.
-  */
-
-static void tnode_put_child_reorg(struct tnode *tn, unsigned long i,
-                                 struct tnode *n, int wasfull)
+/* Add a child at position i overwriting the old value.
+ * Update the value of full_children and empty_children.
+ */
+static void put_child(struct tnode *tn, unsigned long i, struct tnode *n)
 {
        struct tnode *chi = rtnl_dereference(tn->child[i]);
-       int isfull;
+       int isfull, wasfull;
 
        BUG_ON(i >= tnode_child_length(tn));
 
                tn->empty_children--;
 
        /* update fullChildren */
-       if (wasfull == -1)
-               wasfull = tnode_full(tn, chi);
-
+       wasfull = tnode_full(tn, chi);
        isfull = tnode_full(tn, n);
+
        if (wasfull && !isfull)
                tn->full_children--;
        else if (!wasfull && isfull)
        node_free(tn);
 }
 
-static struct tnode *inflate(struct trie *t, struct tnode *oldtnode)
+static int inflate(struct trie *t, struct tnode *oldtnode)
 {
        unsigned long olen = tnode_child_length(oldtnode);
+       struct tnode *tp = node_parent(oldtnode);
        struct tnode *tn;
        unsigned long i;
        t_key m;
        pr_debug("In inflate\n");
 
        tn = tnode_new(oldtnode->key, oldtnode->pos - 1, oldtnode->bits + 1);
-
        if (!tn)
-               return ERR_PTR(-ENOMEM);
+               return -ENOMEM;
 
        /*
         * Preallocate and store tnodes before the actual work so we
                        put_child(left, j, rtnl_dereference(inode->child[j]));
                        put_child(right, j, rtnl_dereference(inode->child[j + size]));
                }
-               put_child(tn, 2*i, resize(t, left));
-               put_child(tn, 2*i+1, resize(t, right));
+
+               put_child(tn, 2 * i, left);
+               put_child(tn, 2 * i + 1, right);
 
                tnode_free_safe(inode);
+
+               resize(t, left);
+               resize(t, right);
        }
+
+       put_child_root(tp, t, tn->key, tn);
        tnode_free_safe(oldtnode);
-       return tn;
+       return 0;
 nomem:
        tnode_clean_free(tn);
-       return ERR_PTR(-ENOMEM);
+       return -ENOMEM;
 }
 
-static struct tnode *halve(struct trie *t, struct tnode *oldtnode)
+static int halve(struct trie *t, struct tnode *oldtnode)
 {
        unsigned long olen = tnode_child_length(oldtnode);
+       struct tnode *tp = node_parent(oldtnode);
        struct tnode *tn, *left, *right;
        int i;
 
        pr_debug("In halve\n");
 
        tn = tnode_new(oldtnode->key, oldtnode->pos + 1, oldtnode->bits - 1);
-
        if (!tn)
-               return ERR_PTR(-ENOMEM);
+               return -ENOMEM;
 
        /*
         * Preallocate and store tnodes before the actual work so we
 
                        newn = tnode_new(left->key, oldtnode->pos, 1);
 
-                       if (!newn)
-                               goto nomem;
+                       if (!newn) {
+                               tnode_clean_free(tn);
+                               return -ENOMEM;
+                       }
 
                        put_child(tn, i/2, newn);
                }
 
                /* Two nonempty children */
                newBinNode = tnode_get_child(tn, i/2);
-               put_child(tn, i/2, NULL);
                put_child(newBinNode, 0, left);
                put_child(newBinNode, 1, right);
-               put_child(tn, i/2, resize(t, newBinNode));
+
+               put_child(tn, i / 2, newBinNode);
+
+               resize(t, newBinNode);
        }
+
+       put_child_root(tp, t, tn->key, tn);
        tnode_free_safe(oldtnode);
-       return tn;
-nomem:
-       tnode_clean_free(tn);
-       return ERR_PTR(-ENOMEM);
+
+       return 0;
 }
 
 /* From "Implementing a dynamic compressed trie" by Stefan Nilsson of
  *    tnode_child_length(tn)
  *
  */
-static bool should_inflate(const struct tnode *tn)
+static bool should_inflate(const struct tnode *tp, const struct tnode *tn)
 {
        unsigned long used = tnode_child_length(tn);
        unsigned long threshold = used;
 
        /* Keep root node larger */
-       threshold *= node_parent(tn) ? inflate_threshold :
-                                      inflate_threshold_root;
+       threshold *= tp ? inflate_threshold : inflate_threshold_root;
        used += tn->full_children;
        used -= tn->empty_children;
 
        return tn->pos && ((50 * used) >= threshold);
 }
 
-static bool should_halve(const struct tnode *tn)
+static bool should_halve(const struct tnode *tp, const struct tnode *tn)
 {
        unsigned long used = tnode_child_length(tn);
        unsigned long threshold = used;
 
        /* Keep root node larger */
-       threshold *= node_parent(tn) ? halve_threshold :
-                                      halve_threshold_root;
+       threshold *= tp ? halve_threshold : halve_threshold_root;
        used -= tn->empty_children;
 
        return (tn->bits > 1) && ((100 * used) < threshold);
 }
 
 #define MAX_WORK 10
-static struct tnode *resize(struct trie *t, struct tnode *tn)
+static void resize(struct trie *t, struct tnode *tn)
 {
-       struct tnode *old_tn, *n = NULL;
+       struct tnode *tp = node_parent(tn), *n = NULL;
+       struct tnode __rcu **cptr;
        int max_work;
 
-       if (!tn)
-               return NULL;
-
        pr_debug("In tnode_resize %p inflate_threshold=%d threshold=%d\n",
                 tn, inflate_threshold, halve_threshold);
 
+       /* track the tnode via the pointer from the parent instead of
+        * doing it ourselves.  This way we can let RCU fully do its
+        * thing without us interfering
+        */
+       cptr = tp ? &tp->child[get_index(tn->key, tp)] : &t->trie;
+       BUG_ON(tn != rtnl_dereference(*cptr));
+
        /* No children */
        if (tn->empty_children > (tnode_child_length(tn) - 1))
                goto no_children;
         * nonempty nodes that are above the threshold.
         */
        max_work = MAX_WORK;
-       while (should_inflate(tn) && max_work--) {
-               old_tn = tn;
-               tn = inflate(t, tn);
-
-               if (IS_ERR(tn)) {
-                       tn = old_tn;
+       while (should_inflate(tp, tn) && max_work--) {
+               if (inflate(t, tn)) {
 #ifdef CONFIG_IP_FIB_TRIE_STATS
                        this_cpu_inc(t->stats->resize_node_skipped);
 #endif
                        break;
                }
+
+               tn = rtnl_dereference(*cptr);
        }
 
        /* Return if at least one inflate is run */
        if (max_work != MAX_WORK)
-               return tn;
+               return;
 
        /* Halve as long as the number of empty children in this
         * node is above threshold.
         */
        max_work = MAX_WORK;
-       while (should_halve(tn) && max_work--) {
-               old_tn = tn;
-               tn = halve(t, tn);
-               if (IS_ERR(tn)) {
-                       tn = old_tn;
+       while (should_halve(tp, tn) && max_work--) {
+               if (halve(t, tn)) {
 #ifdef CONFIG_IP_FIB_TRIE_STATS
                        this_cpu_inc(t->stats->resize_node_skipped);
 #endif
                        break;
                }
-       }
 
+               tn = rtnl_dereference(*cptr);
+       }
 
        /* Only one child remains */
        if (tn->empty_children == (tnode_child_length(tn) - 1)) {
                        n = tnode_get_child(tn, --i);
 no_children:
                /* compress one level */
-               node_set_parent(n, NULL);
+               put_child_root(tp, t, tn->key, n);
+               node_set_parent(n, tp);
+
+               /* drop dead node */
                tnode_free_safe(tn);
-               return n;
        }
-       return tn;
 }
 
 /* readside must use rcu_read_lock currently dump routines
 
 static void trie_rebalance(struct trie *t, struct tnode *tn)
 {
-       int wasfull;
-       t_key cindex, key;
        struct tnode *tp;
 
-       key = tn->key;
-
-       while (tn != NULL && (tp = node_parent(tn)) != NULL) {
-               cindex = get_index(key, tp);
-               wasfull = tnode_full(tp, tnode_get_child(tp, cindex));
-               tn = resize(t, tn);
-
-               tnode_put_child_reorg(tp, cindex, tn, wasfull);
-
-               tp = node_parent(tn);
-               if (!tp)
-                       rcu_assign_pointer(t->trie, tn);
+       while ((tp = node_parent(tn)) != NULL) {
+               resize(t, tn);
 
                tnode_free_flush();
-               if (!tp)
-                       break;
                tn = tp;
        }
 
        /* Handle last (top) tnode */
        if (IS_TNODE(tn))
-               tn = resize(t, tn);
+               resize(t, tn);
 
-       rcu_assign_pointer(t->trie, tn);
        tnode_free_flush();
 }