*     Extend a radix tree so it can store key @index.
  */
 static int radix_tree_extend(struct radix_tree_root *root,
-                               unsigned long index, unsigned order)
+                               unsigned long index)
 {
        struct radix_tree_node *node;
        struct radix_tree_node *slot;
        while (index > radix_tree_maxindex(height))
                height++;
 
-       if ((root->rnode == NULL) && (order == 0)) {
+       if (root->rnode == NULL) {
                root->height = height;
                goto out;
        }
                node->count = 1;
                node->parent = NULL;
                slot = root->rnode;
-               if (radix_tree_is_indirect_ptr(slot) && newheight > 1) {
+               if (radix_tree_is_indirect_ptr(slot)) {
                        slot = indirect_to_ptr(slot);
                        slot->parent = node;
                        slot = ptr_to_indirect(slot);
                root->height = newheight;
        } while (height > root->height);
 out:
-       return 0;
+       return height * RADIX_TREE_MAP_SHIFT;
 }
 
 /**
                        void ***slotp)
 {
        struct radix_tree_node *node = NULL, *slot;
+       unsigned long maxindex;
        unsigned int height, shift, offset;
-       int error;
+       unsigned long max = index | ((1UL << order) - 1);
+
+       shift = radix_tree_load_root(root, &slot, &maxindex);
 
        /* Make sure the tree is high enough.  */
-       if (index > radix_tree_maxindex(root->height)) {
-               error = radix_tree_extend(root, index, order);
-               if (error)
+       if (max > maxindex) {
+               int error = radix_tree_extend(root, max);
+               if (error < 0)
                        return error;
+               shift = error;
+               slot = root->rnode;
+               if (order == shift) {
+                       shift += RADIX_TREE_MAP_SHIFT;
+                       root->height++;
+               }
        }
 
-       slot = root->rnode;
-
        height = root->height;
-       shift = height * RADIX_TREE_MAP_SHIFT;
 
        offset = 0;                     /* uninitialised var warning */
        while (shift > order) {