#define MAX_STAT_DEPTH 32
 
-#define KEYLENGTH (8*sizeof(t_key))
+#define KEYLENGTH      (8*sizeof(t_key))
+#define KEY_MAX                ((t_key)~0)
 
 typedef unsigned int t_key;
 
        union {
                /* The fields in this struct are valid if bits > 0 (TNODE) */
                struct {
-                       unsigned int full_children;  /* KEYLENGTH bits needed */
-                       unsigned int empty_children; /* KEYLENGTH bits needed */
+                       t_key empty_children; /* KEYLENGTH bits needed */
+                       t_key full_children;  /* KEYLENGTH bits needed */
                        struct tnode __rcu *child[0];
                };
                /* This list pointer if valid if bits == 0 (LEAF) */
                return vzalloc(size);
 }
 
+static inline void empty_child_inc(struct tnode *n)
+{
+       ++n->empty_children ? : ++n->full_children;
+}
+
+static inline void empty_child_dec(struct tnode *n)
+{
+       n->empty_children-- ? : n->full_children--;
+}
+
 static struct tnode *leaf_new(t_key key)
 {
        struct tnode *l = kmem_cache_alloc(trie_leaf_kmem, GFP_KERNEL);
 
 static struct tnode *tnode_new(t_key key, int pos, int bits)
 {
-       size_t sz = offsetof(struct tnode, child[1 << bits]);
+       size_t sz = offsetof(struct tnode, child[1ul << bits]);
        struct tnode *tn = tnode_alloc(sz);
        unsigned int shift = pos + bits;
 
                tn->pos = pos;
                tn->bits = bits;
                tn->key = (shift < KEYLENGTH) ? (key >> shift) << shift : 0;
-               tn->full_children = 0;
-               tn->empty_children = 1<<bits;
+               if (bits == KEYLENGTH)
+                       tn->full_children = 1;
+               else
+                       tn->empty_children = 1ul << bits;
        }
 
        pr_debug("AT %p s=%zu %zu\n", tn, sizeof(struct tnode),
 
        BUG_ON(i >= tnode_child_length(tn));
 
-       /* update emptyChildren */
+       /* update emptyChildren, overflow into fullChildren */
        if (n == NULL && chi != NULL)
-               tn->empty_children++;
-       else if (n != NULL && chi == NULL)
-               tn->empty_children--;
+               empty_child_inc(tn);
+       if (n != NULL && chi == NULL)
+               empty_child_dec(tn);
 
        /* update fullChildren */
        wasfull = tnode_full(tn, chi);
        return 0;
 }
 
+static void collapse(struct trie *t, struct tnode *oldtnode)
+{
+       struct tnode *n, *tp;
+       unsigned long i;
+
+       /* scan the tnode looking for that one child that might still exist */
+       for (n = NULL, i = tnode_child_length(oldtnode); !n && i;)
+               n = tnode_get_child(oldtnode, --i);
+
+       /* compress one level */
+       tp = node_parent(oldtnode);
+       put_child_root(tp, t, oldtnode->key, n);
+       node_set_parent(n, tp);
+
+       /* drop dead node */
+       node_free(oldtnode);
+}
+
 static unsigned char update_suffix(struct tnode *tn)
 {
        unsigned char slen = tn->pos;
 
        /* Keep root node larger */
        threshold *= tp ? inflate_threshold : inflate_threshold_root;
-       used += tn->full_children;
        used -= tn->empty_children;
+       used += tn->full_children;
 
-       return tn->pos && ((50 * used) >= threshold);
+       /* if bits == KEYLENGTH then pos = 0, and will fail below */
+
+       return (used > 1) && tn->pos && ((50 * used) >= threshold);
 }
 
 static bool should_halve(const struct tnode *tp, const struct tnode *tn)
        threshold *= tp ? halve_threshold : halve_threshold_root;
        used -= tn->empty_children;
 
-       return (tn->bits > 1) && ((100 * used) < threshold);
+       /* if bits == KEYLENGTH then used = 100% on wrap, and will fail below */
+
+       return (used > 1) && (tn->bits > 1) && ((100 * used) < threshold);
+}
+
+static bool should_collapse(const struct tnode *tn)
+{
+       unsigned long used = tnode_child_length(tn);
+
+       used -= tn->empty_children;
+
+       /* account for bits == KEYLENGTH case */
+       if ((tn->bits == KEYLENGTH) && tn->full_children)
+               used -= KEY_MAX;
+
+       /* One child or none, time to drop us from the trie */
+       return used < 2;
 }
 
 #define MAX_WORK 10
 static void resize(struct trie *t, struct tnode *tn)
 {
-       struct tnode *tp = node_parent(tn), *n = NULL;
+       struct tnode *tp = node_parent(tn);
        struct tnode __rcu **cptr;
        int max_work = MAX_WORK;
 
        cptr = tp ? &tp->child[get_index(tn->key, tp)] : &t->trie;
        BUG_ON(tn != rtnl_dereference(*cptr));
 
-       /* No children */
-       if (tn->empty_children > (tnode_child_length(tn) - 1))
-               goto no_children;
-
-       /* One child */
-       if (tn->empty_children == (tnode_child_length(tn) - 1))
-               goto one_child;
-
        /* Double as long as the resulting node has a number of
         * nonempty nodes that are above the threshold.
         */
        }
 
        /* Only one child remains */
-       if (tn->empty_children == (tnode_child_length(tn) - 1)) {
-               unsigned long i;
-one_child:
-               for (i = tnode_child_length(tn); !n && i;)
-                       n = tnode_get_child(tn, --i);
-no_children:
-               /* compress one level */
-               put_child_root(tp, t, tn->key, n);
-               node_set_parent(n, tp);
-
-               /* drop dead node */
-               tnode_free_init(tn);
-               tnode_free(tn);
+       if (should_collapse(tn)) {
+               collapse(t, tn);
                return;
        }