/* objects protected by lock */
        struct mutex            lock;
-       struct rb_root          objects;
+       struct rb_root_cached   objects;
 };
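/*
 * For reference: the cached-root type these hunks switch to is declared in
 * <linux/rbtree.h> along the following lines. It is reproduced here for
 * context only and is not part of the diff; the tree proper lives in
 * ->rb_root while ->rb_leftmost caches the smallest node for O(1) access.
 */
struct rb_root_cached {
        struct rb_root rb_root;
        struct rb_node *rb_leftmost;
};

#define RB_ROOT_CACHED (struct rb_root_cached) { {NULL, }, NULL }

/* O(1) "first" lookup, replacing a walk down the left spine of the tree: */
#define rb_first_cached(root) (root)->rb_leftmost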
 
 struct amdgpu_mn_node {
        mutex_lock(&adev->mn_lock);
        mutex_lock(&rmn->lock);
        hash_del(&rmn->node);
-       rbtree_postorder_for_each_entry_safe(node, next_node, &rmn->objects,
-                                            it.rb) {
+       rbtree_postorder_for_each_entry_safe(node, next_node,
+                                            &rmn->objects.rb_root, it.rb) {
                list_for_each_entry_safe(bo, next_bo, &node->bos, mn_list) {
                        bo->mn = NULL;
                        list_del_init(&bo->mn_list);
        rmn->mm = mm;
        rmn->mn.ops = &amdgpu_mn_ops;
        mutex_init(&rmn->lock);
-       rmn->objects = RB_ROOT;
+       rmn->objects = RB_ROOT_CACHED;
 
        r = __mmu_notifier_register(&rmn->mn, mm);
        if (r)
 
        u64 flags;
        uint64_t init_pde_value = 0;
 
-       vm->va = RB_ROOT;
+       vm->va = RB_ROOT_CACHED;
        vm->client_id = atomic64_inc_return(&adev->vm_manager.client_counter);
        for (i = 0; i < AMDGPU_MAX_VMHUBS; i++)
                vm->reserved_vmid[i] = NULL;
 
        amd_sched_entity_fini(vm->entity.sched, &vm->entity);
 
-       if (!RB_EMPTY_ROOT(&vm->va)) {
+       if (!RB_EMPTY_ROOT(&vm->va.rb_root)) {
                dev_err(adev->dev, "still active bo inside vm\n");
        }
-       rbtree_postorder_for_each_entry_safe(mapping, tmp, &vm->va, rb) {
+       rbtree_postorder_for_each_entry_safe(mapping, tmp,
+                                            &vm->va.rb_root, rb) {
                list_del(&mapping->list);
                amdgpu_vm_it_remove(mapping, &vm->va);
                kfree(mapping);
 
 
 struct amdgpu_vm {
        /* tree of virtual addresses mapped */
-       struct rb_root          va;
+       struct rb_root_cached   va;
 
        /* protecting invalidated */
        spinlock_t              status_lock;
 
 struct drm_mm_node *
 __drm_mm_interval_first(const struct drm_mm *mm, u64 start, u64 last)
 {
-       return drm_mm_interval_tree_iter_first((struct rb_root *)&mm->interval_tree,
+       return drm_mm_interval_tree_iter_first((struct rb_root_cached *)&mm->interval_tree,
                                               start, last) ?: (struct drm_mm_node *)&mm->head_node;
 }
 EXPORT_SYMBOL(__drm_mm_interval_first);
        struct drm_mm *mm = hole_node->mm;
        struct rb_node **link, *rb;
        struct drm_mm_node *parent;
+       bool leftmost = true;
 
        node->__subtree_last = LAST(node);
 
 
                rb = &hole_node->rb;
                link = &hole_node->rb.rb_right;
+               leftmost = false;
        } else {
                rb = NULL;
-               link = &mm->interval_tree.rb_node;
+               link = &mm->interval_tree.rb_root.rb_node;
        }
 
        while (*link) {
                        parent->__subtree_last = node->__subtree_last;
-               if (node->start < parent->start)
+               if (node->start < parent->start) {
                        link = &parent->rb.rb_left;
-               else
+               } else {
                        link = &parent->rb.rb_right;
+                       leftmost = false;
+               }
        }
 
        rb_link_node(&node->rb, rb, link);
-       rb_insert_augmented(&node->rb,
-                           &mm->interval_tree,
-                           &drm_mm_interval_tree_augment);
+       rb_insert_augmented_cached(&node->rb, &mm->interval_tree, leftmost,
+                                  &drm_mm_interval_tree_augment);
 }
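/*
 * The hunk above shows the general recipe for keeping the leftmost cache
 * coherent on insertion: note whether the descent ever goes right, then pass
 * that to the cached insert helper. A minimal non-augmented sketch of the
 * same idea (struct my_node and my_key_insert are made-up names; the helpers
 * are the stock <linux/rbtree.h> ones):
 */
struct my_node {
        struct rb_node rb;
        unsigned long key;
};

static void my_key_insert(struct my_node *new, struct rb_root_cached *root)
{
        struct rb_node **link = &root->rb_root.rb_node, *parent = NULL;
        bool leftmost = true;

        while (*link) {
                struct my_node *cur = rb_entry(*link, struct my_node, rb);

                parent = *link;
                if (new->key < cur->key) {
                        link = &parent->rb_left;
                } else {
                        link = &parent->rb_right;
                        leftmost = false;       /* went right at least once */
                }
        }

        rb_link_node(&new->rb, parent, link);
        rb_insert_color_cached(&new->rb, root, leftmost);
}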
 
 #define RB_INSERT(root, member, expr) do { \
        *new = *old;
 
        list_replace(&old->node_list, &new->node_list);
-       rb_replace_node(&old->rb, &new->rb, &old->mm->interval_tree);
+       rb_replace_node(&old->rb, &new->rb, &old->mm->interval_tree.rb_root);
 
        if (drm_mm_hole_follows(old)) {
                list_replace(&old->hole_stack, &new->hole_stack);
        mm->color_adjust = NULL;
 
        INIT_LIST_HEAD(&mm->hole_stack);
-       mm->interval_tree = RB_ROOT;
+       mm->interval_tree = RB_ROOT_CACHED;
        mm->holes_size = RB_ROOT;
        mm->holes_addr = RB_ROOT;
 
 
        struct rb_node *iter;
        unsigned long offset;
 
-       iter = mgr->vm_addr_space_mm.interval_tree.rb_node;
+       iter = mgr->vm_addr_space_mm.interval_tree.rb_root.rb_node;
        best = NULL;
 
        while (likely(iter)) {
 
        spinlock_t lock;
        struct hlist_node node;
        struct mmu_notifier mn;
-       struct rb_root objects;
+       struct rb_root_cached objects;
        struct workqueue_struct *wq;
 };
 
        struct interval_tree_node *it;
        LIST_HEAD(cancelled);
 
-       if (RB_EMPTY_ROOT(&mn->objects))
+       if (RB_EMPTY_ROOT(&mn->objects.rb_root))
                return;
 
        /* interval ranges are inclusive, but invalidate range is exclusive */
 
        spin_lock_init(&mn->lock);
        mn->mn.ops = &i915_gem_userptr_notifier;
-       mn->objects = RB_ROOT;
+       mn->objects = RB_ROOT_CACHED;
        mn->wq = alloc_workqueue("i915-userptr-release", WQ_UNBOUND, 0);
        if (mn->wq == NULL) {
                kfree(mn);
 
 struct radeon_vm {
        struct mutex            mutex;
 
-       struct rb_root          va;
+       struct rb_root_cached   va;
 
        /* protecting invalidated and freed */
        spinlock_t              status_lock;
 
 
        /* objects protected by lock */
        struct mutex            lock;
-       struct rb_root          objects;
+       struct rb_root_cached   objects;
 };
 
 struct radeon_mn_node {
        mutex_lock(&rdev->mn_lock);
        mutex_lock(&rmn->lock);
        hash_del(&rmn->node);
-       rbtree_postorder_for_each_entry_safe(node, next_node, &rmn->objects,
-                                            it.rb) {
+       rbtree_postorder_for_each_entry_safe(node, next_node,
+                                            &rmn->objects.rb_root, it.rb) {
 
                interval_tree_remove(&node->it, &rmn->objects);
                list_for_each_entry_safe(bo, next_bo, &node->bos, mn_list) {
        rmn->mm = mm;
        rmn->mn.ops = &radeon_mn_ops;
        mutex_init(&rmn->lock);
-       rmn->objects = RB_ROOT;
+       rmn->objects = RB_ROOT_CACHED;
        
        r = __mmu_notifier_register(&rmn->mn, mm);
        if (r)
 
                vm->ids[i].last_id_use = NULL;
        }
        mutex_init(&vm->mutex);
-       vm->va = RB_ROOT;
+       vm->va = RB_ROOT_CACHED;
        spin_lock_init(&vm->status_lock);
        INIT_LIST_HEAD(&vm->invalidated);
        INIT_LIST_HEAD(&vm->freed);
        struct radeon_bo_va *bo_va, *tmp;
        int i, r;
 
-       if (!RB_EMPTY_ROOT(&vm->va)) {
+       if (!RB_EMPTY_ROOT(&vm->va.rb_root)) {
                dev_err(rdev->dev, "still active bo inside vm\n");
        }
-       rbtree_postorder_for_each_entry_safe(bo_va, tmp, &vm->va, it.rb) {
+       rbtree_postorder_for_each_entry_safe(bo_va, tmp,
+                                            &vm->va.rb_root, it.rb) {
                interval_tree_remove(&bo_va->it, &vm->va);
                r = radeon_bo_reserve(bo_va->bo, false);
                if (!r) {
 
 /* @last is not a part of the interval. See comment for function
  * node_last.
  */
-int rbt_ib_umem_for_each_in_range(struct rb_root *root,
+int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
                                  u64 start, u64 last,
                                  umem_call_back cb,
                                  void *cookie)
 }
 EXPORT_SYMBOL(rbt_ib_umem_for_each_in_range);
 
-struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root *root,
+struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root,
                                       u64 addr, u64 length)
 {
        struct umem_odp_node *node;
 
        ucontext->closing = 0;
 
 #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-       ucontext->umem_tree = RB_ROOT;
+       ucontext->umem_tree = RB_ROOT_CACHED;
        init_rwsem(&ucontext->umem_rwsem);
        ucontext->odp_mrs_count = 0;
        INIT_LIST_HEAD(&ucontext->no_private_counters);
 
 
 struct mmu_rb_handler {
        struct mmu_notifier mn;
-       struct rb_root root;
+       struct rb_root_cached root;
        void *ops_arg;
        spinlock_t lock;        /* protect the RB tree */
        struct mmu_rb_ops *ops;
        if (!handlr)
                return -ENOMEM;
 
-       handlr->root = RB_ROOT;
+       handlr->root = RB_ROOT_CACHED;
        handlr->ops = ops;
        handlr->ops_arg = ops_arg;
        INIT_HLIST_NODE(&handlr->mn.hlist);
        INIT_LIST_HEAD(&del_list);
 
        spin_lock_irqsave(&handler->lock, flags);
-       while ((node = rb_first(&handler->root))) {
+       while ((node = rb_first_cached(&handler->root))) {
                rbnode = rb_entry(node, struct mmu_rb_node, node);
-               rb_erase(node, &handler->root);
+               rb_erase_cached(node, &handler->root);
                /* move from LRU list to delete list */
                list_move(&rbnode->list, &del_list);
        }
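/*
 * Teardown loops like the one above drain the tree from its cached leftmost
 * node: rb_first_cached() is an O(1) read of ->rb_leftmost, and
 * rb_erase_cached() keeps the cache valid as nodes disappear. A minimal
 * sketch of the pattern (my_drain and my_free are made-up names):
 */
static void my_drain(struct rb_root_cached *root,
                     void (*my_free)(struct rb_node *))
{
        struct rb_node *rb;

        while ((rb = rb_first_cached(root))) {
                rb_erase_cached(rb, root);
                my_free(rb);
        }
}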
 {
        struct mmu_rb_handler *handler =
                container_of(mn, struct mmu_rb_handler, mn);
-       struct rb_root *root = &handler->root;
+       struct rb_root_cached *root = &handler->root;
        struct mmu_rb_node *node, *ptr = NULL;
        unsigned long flags;
        bool added = false;
 
        vpn_last = vpn_start + npages - 1;
 
        spin_lock(&pd->lock);
-       usnic_uiom_remove_interval(&pd->rb_root, vpn_start,
+       usnic_uiom_remove_interval(&pd->root, vpn_start,
                                        vpn_last, &rm_intervals);
        usnic_uiom_unmap_sorted_intervals(&rm_intervals, pd);
 
        err = usnic_uiom_get_intervals_diff(vpn_start, vpn_last,
                                                (writable) ? IOMMU_WRITE : 0,
                                                IOMMU_WRITE,
-                                               &pd->rb_root,
+                                               &pd->root,
                                                &sorted_diff_intervals);
        if (err) {
                usnic_err("Failed disjoint interval vpn [0x%lx,0x%lx] err %d\n",
 
        }
 
-       err = usnic_uiom_insert_interval(&pd->rb_root, vpn_start, vpn_last,
+       err = usnic_uiom_insert_interval(&pd->root, vpn_start, vpn_last,
                                        (writable) ? IOMMU_WRITE : 0);
        if (err) {
                usnic_err("Failed insert interval vpn [0x%lx,0x%lx] err %d\n",
 
 struct usnic_uiom_pd {
        struct iommu_domain             *domain;
        spinlock_t                      lock;
-       struct rb_root                  rb_root;
+       struct rb_root_cached           root;
        struct list_head                devs;
        int                             dev_cnt;
 };
 
 }
 
 static void
-find_intervals_intersection_sorted(struct rb_root *root, unsigned long start,
-                                       unsigned long last,
-                                       struct list_head *list)
+find_intervals_intersection_sorted(struct rb_root_cached *root,
+                                  unsigned long start, unsigned long last,
+                                  struct list_head *list)
 {
        struct usnic_uiom_interval_node *node;
 
 
 int usnic_uiom_get_intervals_diff(unsigned long start, unsigned long last,
                                        int flags, int flag_mask,
-                                       struct rb_root *root,
+                                       struct rb_root_cached *root,
                                        struct list_head *diff_set)
 {
        struct usnic_uiom_interval_node *interval, *tmp;
                kfree(interval);
 }
 
-int usnic_uiom_insert_interval(struct rb_root *root, unsigned long start,
+int usnic_uiom_insert_interval(struct rb_root_cached *root, unsigned long start,
                                unsigned long last, int flags)
 {
        struct usnic_uiom_interval_node *interval, *tmp;
        return err;
 }
 
-void usnic_uiom_remove_interval(struct rb_root *root, unsigned long start,
-                               unsigned long last, struct list_head *removed)
+void usnic_uiom_remove_interval(struct rb_root_cached *root,
+                               unsigned long start, unsigned long last,
+                               struct list_head *removed)
 {
        struct usnic_uiom_interval_node *interval;
 
 
 
 extern void
 usnic_uiom_interval_tree_insert(struct usnic_uiom_interval_node *node,
-                                       struct rb_root *root);
+                                       struct rb_root_cached *root);
 extern void
 usnic_uiom_interval_tree_remove(struct usnic_uiom_interval_node *node,
-                                       struct rb_root *root);
+                                       struct rb_root_cached *root);
 extern struct usnic_uiom_interval_node *
-usnic_uiom_interval_tree_iter_first(struct rb_root *root,
+usnic_uiom_interval_tree_iter_first(struct rb_root_cached *root,
                                        unsigned long start,
                                        unsigned long last);
 extern struct usnic_uiom_interval_node *
  * Inserts {start...last} into {root}.  If there are overlaps,
  * nodes will be broken up and merged
  */
-int usnic_uiom_insert_interval(struct rb_root *root,
+int usnic_uiom_insert_interval(struct rb_root_cached *root,
                                unsigned long start, unsigned long last,
                                int flags);
 /*
  * 'removed.' The caller is responsible for freeing memory of nodes in
  * 'removed.'
  */
-void usnic_uiom_remove_interval(struct rb_root *root,
+void usnic_uiom_remove_interval(struct rb_root_cached *root,
                                unsigned long start, unsigned long last,
                                struct list_head *removed);
 /*
 int usnic_uiom_get_intervals_diff(unsigned long start,
                                        unsigned long last, int flags,
                                        int flag_mask,
-                                       struct rb_root *root,
+                                       struct rb_root_cached *root,
                                        struct list_head *diff_set);
 /* Call this to free diff_set returned by usnic_uiom_get_intervals_diff */
 void usnic_uiom_put_interval_set(struct list_head *intervals);
 
        if (!umem)
                return NULL;
 
-       umem->umem_tree = RB_ROOT;
+       umem->umem_tree = RB_ROOT_CACHED;
        umem->numem = 0;
        INIT_LIST_HEAD(&umem->umem_list);
 
 
 };
 
 struct vhost_umem {
-       struct rb_root umem_tree;
+       struct rb_root_cached umem_tree;
        struct list_head umem_list;
        int numem;
 };
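/*
 * Many of the trees converted by this series are instances of the
 * INTERVAL_TREE_DEFINE() template that appears further below; after the
 * conversion the generated helpers take a struct rb_root_cached *. A sketch
 * of what an instantiation looks like (struct foo_range, FOO_START/FOO_LAST
 * and the foo_it prefix are made-up names used only for illustration):
 */
#include <linux/interval_tree_generic.h>

struct foo_range {
        struct rb_node rb;
        u64 start, last;        /* inclusive interval endpoints */
        u64 __subtree_last;     /* maintained by the template */
};

#define FOO_START(n)    ((n)->start)
#define FOO_LAST(n)     ((n)->last)

INTERVAL_TREE_DEFINE(struct foo_range, rb, u64, __subtree_last,
                     FOO_START, FOO_LAST, static inline, foo_it)

static struct rb_root_cached foo_tree = RB_ROOT_CACHED;
/* usage: foo_it_insert(r, &foo_tree); r = foo_it_iter_first(&foo_tree, a, b); */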
 
 }
 
 static void
-hugetlb_vmdelete_list(struct rb_root *root, pgoff_t start, pgoff_t end)
+hugetlb_vmdelete_list(struct rb_root_cached *root, pgoff_t start, pgoff_t end)
 {
        struct vm_area_struct *vma;
 
 
        i_size_write(inode, offset);
        i_mmap_lock_write(mapping);
-       if (!RB_EMPTY_ROOT(&mapping->i_mmap))
+       if (!RB_EMPTY_ROOT(&mapping->i_mmap.rb_root))
                hugetlb_vmdelete_list(&mapping->i_mmap, pgoff, 0);
        i_mmap_unlock_write(mapping);
        remove_inode_hugepages(inode, offset, LLONG_MAX);
 
                inode_lock(inode);
                i_mmap_lock_write(mapping);
-               if (!RB_EMPTY_ROOT(&mapping->i_mmap))
+               if (!RB_EMPTY_ROOT(&mapping->i_mmap.rb_root))
                        hugetlb_vmdelete_list(&mapping->i_mmap,
                                                hole_start >> PAGE_SHIFT,
                                                hole_end  >> PAGE_SHIFT);
 
        init_rwsem(&mapping->i_mmap_rwsem);
        INIT_LIST_HEAD(&mapping->private_list);
        spin_lock_init(&mapping->private_lock);
-       mapping->i_mmap = RB_ROOT;
+       mapping->i_mmap = RB_ROOT_CACHED;
 }
 EXPORT_SYMBOL(address_space_init_once);
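/*
 * As the hugetlbfs and address_space hunks show, callers change in only two
 * small ways: emptiness tests look through the embedded .rb_root, while the
 * interval-tree iterators simply take the cached root. A sketch of the
 * resulting walk (walk_shared_mappings and my_visit are made-up names;
 * vma_interval_tree_foreach() is the existing helper from <linux/mm.h>, and
 * the caller is assumed to hold i_mmap_lock):
 */
static void walk_shared_mappings(struct address_space *mapping,
                                 pgoff_t start, pgoff_t last,
                                 void (*my_visit)(struct vm_area_struct *))
{
        struct vm_area_struct *vma;

        if (RB_EMPTY_ROOT(&mapping->i_mmap.rb_root))
                return;

        vma_interval_tree_foreach(vma, &mapping->i_mmap, start, last)
                my_visit(vma);
}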
 
 
         * according to the (increasing) start address of the memory node. */
        struct drm_mm_node head_node;
        /* Keep an interval_tree for fast lookup of drm_mm_nodes by address. */
-       struct rb_root interval_tree;
+       struct rb_root_cached interval_tree;
        struct rb_root holes_size;
        struct rb_root holes_addr;
 
 
        struct radix_tree_root  page_tree;      /* radix tree of all pages */
        spinlock_t              tree_lock;      /* and lock protecting it */
        atomic_t                i_mmap_writable;/* count VM_SHARED mappings */
-       struct rb_root          i_mmap;         /* tree of private and shared mappings */
+       struct rb_root_cached   i_mmap;         /* tree of private and shared mappings */
        struct rw_semaphore     i_mmap_rwsem;   /* protect tree, count, list */
        /* Protected by tree_lock together with the radix tree */
        unsigned long           nrpages;        /* number of total pages */
  */
 static inline int mapping_mapped(struct address_space *mapping)
 {
-       return  !RB_EMPTY_ROOT(&mapping->i_mmap);
+       return  !RB_EMPTY_ROOT(&mapping->i_mmap.rb_root);
 }
 
 /*
 
 };
 
 extern void
-interval_tree_insert(struct interval_tree_node *node, struct rb_root *root);
+interval_tree_insert(struct interval_tree_node *node,
+                    struct rb_root_cached *root);
 
 extern void
-interval_tree_remove(struct interval_tree_node *node, struct rb_root *root);
+interval_tree_remove(struct interval_tree_node *node,
+                    struct rb_root_cached *root);
 
 extern struct interval_tree_node *
-interval_tree_iter_first(struct rb_root *root,
+interval_tree_iter_first(struct rb_root_cached *root,
                         unsigned long start, unsigned long last);
 
 extern struct interval_tree_node *
 
                                                                              \
 /* Insert / remove interval nodes from the tree */                           \
                                                                              \
-ITSTATIC void ITPREFIX ## _insert(ITSTRUCT *node, struct rb_root *root)              \
+ITSTATIC void ITPREFIX ## _insert(ITSTRUCT *node,                            \
+                                 struct rb_root_cached *root)                \
 {                                                                            \
-       struct rb_node **link = &root->rb_node, *rb_parent = NULL;            \
+       struct rb_node **link = &root->rb_root.rb_node, *rb_parent = NULL;    \
        ITTYPE start = ITSTART(node), last = ITLAST(node);                    \
        ITSTRUCT *parent;                                                     \
+       bool leftmost = true;                                                 \
                                                                              \
        while (*link) {                                                       \
                rb_parent = *link;                                            \
                        parent->ITSUBTREE = last;                             \
-               if (start < ITSTART(parent))                                  \
+               if (start < ITSTART(parent)) {                                \
                        link = &parent->ITRB.rb_left;                         \
-               else                                                          \
+               } else {                                                      \
                        link = &parent->ITRB.rb_right;                        \
+                       leftmost = false;                                     \
+               }                                                             \
        }                                                                     \
                                                                              \
        node->ITSUBTREE = last;                                               \
        rb_link_node(&node->ITRB, rb_parent, link);                           \
-       rb_insert_augmented(&node->ITRB, root, &ITPREFIX ## _augment);        \
+       rb_insert_augmented_cached(&node->ITRB, root,                         \
+                                  leftmost, &ITPREFIX ## _augment);          \
 }                                                                            \
                                                                              \
-ITSTATIC void ITPREFIX ## _remove(ITSTRUCT *node, struct rb_root *root)              \
+ITSTATIC void ITPREFIX ## _remove(ITSTRUCT *node,                            \
+                                 struct rb_root_cached *root)                \
 {                                                                            \
-       rb_erase_augmented(&node->ITRB, root, &ITPREFIX ## _augment);         \
+       rb_erase_augmented_cached(&node->ITRB, root, &ITPREFIX ## _augment);  \
 }                                                                            \
                                                                              \
 /*                                                                           \
 }                                                                            \
                                                                              \
 ITSTATIC ITSTRUCT *                                                          \
-ITPREFIX ## _iter_first(struct rb_root *root, ITTYPE start, ITTYPE last)      \
+ITPREFIX ## _iter_first(struct rb_root_cached *root,                         \
+                       ITTYPE start, ITTYPE last)                            \
 {                                                                            \
-       ITSTRUCT *node;                                                       \
+       ITSTRUCT *node, *leftmost;                                            \
                                                                              \
-       if (!root->rb_node)                                                   \
+       if (!root->rb_root.rb_node)                                           \
                return NULL;                                                  \
-       node = rb_entry(root->rb_node, ITSTRUCT, ITRB);                       \
+                                                                             \
+       /*                                                                    \
+        * Fastpath range intersection/overlap between A: [a0, a1] and        \
+        * B: [b0, b1] is given by:                                           \
+        *                                                                    \
+        *         a0 <= b1 && b0 <= a1                                       \
+        *                                                                    \
+        *  ... where A holds the query range and B holds the smallest        \
+        * 'start' and largest 'last' in the tree. For the latter, we         \
+        * rely on the root node, which by the augmented interval tree        \
+        * property holds the largest value in its last-in-subtree.           \
+        * This mitigates some of the tree walk overhead for                  \
+        * non-intersecting ranges, maintained and consulted in O(1).         \
+        */                                                                   \
+       node = rb_entry(root->rb_root.rb_node, ITSTRUCT, ITRB);               \
        if (node->ITSUBTREE < start)                                          \
                return NULL;                                                  \
+                                                                             \
+       leftmost = rb_entry(root->rb_leftmost, ITSTRUCT, ITRB);               \
+       if (ITSTART(leftmost) > last)                                         \
+               return NULL;                                                  \
+                                                                             \
        return ITPREFIX ## _subtree_search(node, start, last);                \
 }                                                                            \
                                                                              \
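/*
 * The two early returns added above are the payoff from caching the leftmost
 * node: a query [start, last] is rejected in O(1) when the whole tree ends
 * before it (root's ITSUBTREE < start) or starts after it (leftmost's start
 * is > last). The same test written out against the generic
 * struct interval_tree_node, as a sketch (range_may_overlap is a made-up
 * name):
 */
#include <linux/interval_tree.h>

static bool range_may_overlap(struct rb_root_cached *root,
                              unsigned long start, unsigned long last)
{
        struct interval_tree_node *node, *leftmost;

        if (!root->rb_root.rb_node)
                return false;                   /* empty tree */

        node = rb_entry(root->rb_root.rb_node, struct interval_tree_node, rb);
        if (node->__subtree_last < start)
                return false;                   /* tree ends before the query */

        leftmost = rb_entry(root->rb_leftmost, struct interval_tree_node, rb);
        return leftmost->start <= last;         /* otherwise a walk is needed */
}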
 
 
 /* interval_tree.c */
 void vma_interval_tree_insert(struct vm_area_struct *node,
-                             struct rb_root *root);
+                             struct rb_root_cached *root);
 void vma_interval_tree_insert_after(struct vm_area_struct *node,
                                    struct vm_area_struct *prev,
-                                   struct rb_root *root);
+                                   struct rb_root_cached *root);
 void vma_interval_tree_remove(struct vm_area_struct *node,
-                             struct rb_root *root);
-struct vm_area_struct *vma_interval_tree_iter_first(struct rb_root *root,
+                             struct rb_root_cached *root);
+struct vm_area_struct *vma_interval_tree_iter_first(struct rb_root_cached *root,
                                unsigned long start, unsigned long last);
 struct vm_area_struct *vma_interval_tree_iter_next(struct vm_area_struct *node,
                                unsigned long start, unsigned long last);
             vma; vma = vma_interval_tree_iter_next(vma, start, last))
 
 void anon_vma_interval_tree_insert(struct anon_vma_chain *node,
-                                  struct rb_root *root);
+                                  struct rb_root_cached *root);
 void anon_vma_interval_tree_remove(struct anon_vma_chain *node,
-                                  struct rb_root *root);
-struct anon_vma_chain *anon_vma_interval_tree_iter_first(
-       struct rb_root *root, unsigned long start, unsigned long last);
+                                  struct rb_root_cached *root);
+struct anon_vma_chain *
+anon_vma_interval_tree_iter_first(struct rb_root_cached *root,
+                                 unsigned long start, unsigned long last);
 struct anon_vma_chain *anon_vma_interval_tree_iter_next(
        struct anon_vma_chain *node, unsigned long start, unsigned long last);
 #ifdef CONFIG_DEBUG_VM_RB
 
         * is serialized by a system wide lock only visible to
         * mm_take_all_locks() (mm_all_locks_mutex).
         */
-       struct rb_root rb_root; /* Interval tree of private "related" vmas */
+
+       /* Interval tree of private "related" vmas */
+       struct rb_root_cached rb_root;
 };
 
 /*
 
 void ib_umem_odp_unmap_dma_pages(struct ib_umem *umem, u64 start_offset,
                                 u64 bound);
 
-void rbt_ib_umem_insert(struct umem_odp_node *node, struct rb_root *root);
-void rbt_ib_umem_remove(struct umem_odp_node *node, struct rb_root *root);
+void rbt_ib_umem_insert(struct umem_odp_node *node,
+                       struct rb_root_cached *root);
+void rbt_ib_umem_remove(struct umem_odp_node *node,
+                       struct rb_root_cached *root);
 typedef int (*umem_call_back)(struct ib_umem *item, u64 start, u64 end,
                              void *cookie);
 /*
  * Call the callback on each ib_umem in the range. Returns the logical or of
  * the return values of the functions called.
  */
-int rbt_ib_umem_for_each_in_range(struct rb_root *root, u64 start, u64 end,
+int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
+                                 u64 start, u64 end,
                                  umem_call_back cb, void *cookie);
 
 /*
  * Find first region intersecting with address range.
  * Return NULL if not found
  */
-struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root *root,
+struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root,
                                       u64 addr, u64 length);
 
 static inline int ib_umem_mmu_notifier_retry(struct ib_umem *item,
 
 
        struct pid             *tgid;
 #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-       struct rb_root      umem_tree;
+       struct rb_root_cached   umem_tree;
        /*
         * Protects .umem_rbroot and tree, as well as odp_mrs_count and
         * mmu notifiers registration.
 
 
 __param(uint, max_endpoint, ~0, "Largest value for the interval's endpoint");
 
-static struct rb_root root = RB_ROOT;
+static struct rb_root_cached root = RB_ROOT_CACHED;
 static struct interval_tree_node *nodes = NULL;
 static u32 *queries = NULL;
 
 static struct rnd_state rnd;
 
 static inline unsigned long
-search(struct rb_root *root, unsigned long start, unsigned long last)
+search(struct rb_root_cached *root, unsigned long start, unsigned long last)
 {
        struct interval_tree_node *node;
        unsigned long results = 0;
 
 /* Insert node immediately after prev in the interval tree */
 void vma_interval_tree_insert_after(struct vm_area_struct *node,
                                    struct vm_area_struct *prev,
-                                   struct rb_root *root)
+                                   struct rb_root_cached *root)
 {
        struct rb_node **link;
        struct vm_area_struct *parent;
 
        node->shared.rb_subtree_last = last;
        rb_link_node(&node->shared.rb, &parent->shared.rb, link);
-       rb_insert_augmented(&node->shared.rb, root,
+       rb_insert_augmented(&node->shared.rb, &root->rb_root,
                            &vma_interval_tree_augment);
 }
 
                     static inline, __anon_vma_interval_tree)
 
 void anon_vma_interval_tree_insert(struct anon_vma_chain *node,
-                                  struct rb_root *root)
+                                  struct rb_root_cached *root)
 {
 #ifdef CONFIG_DEBUG_VM_RB
        node->cached_vma_start = avc_start_pgoff(node);
 }
 
 void anon_vma_interval_tree_remove(struct anon_vma_chain *node,
-                                  struct rb_root *root)
+                                  struct rb_root_cached *root)
 {
        __anon_vma_interval_tree_remove(node, root);
 }
 
 struct anon_vma_chain *
-anon_vma_interval_tree_iter_first(struct rb_root *root,
+anon_vma_interval_tree_iter_first(struct rb_root_cached *root,
                                  unsigned long first, unsigned long last)
 {
        return __anon_vma_interval_tree_iter_first(root, first, last);
 
        zap_page_range_single(vma, start_addr, end_addr - start_addr, details);
 }
 
-static inline void unmap_mapping_range_tree(struct rb_root *root,
+static inline void unmap_mapping_range_tree(struct rb_root_cached *root,
                                            struct zap_details *details)
 {
        struct vm_area_struct *vma;
                details.last_index = ULONG_MAX;
 
        i_mmap_lock_write(mapping);
-       if (unlikely(!RB_EMPTY_ROOT(&mapping->i_mmap)))
+       if (unlikely(!RB_EMPTY_ROOT(&mapping->i_mmap.rb_root)))
                unmap_mapping_range_tree(&mapping->i_mmap, &details);
        i_mmap_unlock_write(mapping);
 }
 
        struct mm_struct *mm = vma->vm_mm;
        struct vm_area_struct *next = vma->vm_next, *orig_vma = vma;
        struct address_space *mapping = NULL;
-       struct rb_root *root = NULL;
+       struct rb_root_cached *root = NULL;
        struct anon_vma *anon_vma = NULL;
        struct file *file = vma->vm_file;
        bool start_changed = false, end_changed = false;
 
 static void vm_lock_anon_vma(struct mm_struct *mm, struct anon_vma *anon_vma)
 {
-       if (!test_bit(0, (unsigned long *) &anon_vma->root->rb_root.rb_node)) {
+       if (!test_bit(0, (unsigned long *) &anon_vma->root->rb_root.rb_root.rb_node)) {
                /*
                 * The LSB of head.next can't change from under us
                 * because we hold the mm_all_locks_mutex.
                 * anon_vma->root->rwsem.
                 */
                if (__test_and_set_bit(0, (unsigned long *)
-                                      &anon_vma->root->rb_root.rb_node))
+                                      &anon_vma->root->rb_root.rb_root.rb_node))
                        BUG();
        }
 }
 
 static void vm_unlock_anon_vma(struct anon_vma *anon_vma)
 {
-       if (test_bit(0, (unsigned long *) &anon_vma->root->rb_root.rb_node)) {
+       if (test_bit(0, (unsigned long *) &anon_vma->root->rb_root.rb_root.rb_node)) {
                /*
                 * The LSB of head.next can't change to 0 from under
                 * us because we hold the mm_all_locks_mutex.
                 * anon_vma->root->rwsem.
                 */
                if (!__test_and_clear_bit(0, (unsigned long *)
-                                         &anon_vma->root->rb_root.rb_node))
+                                         &anon_vma->root->rb_root.rb_root.rb_node))
                        BUG();
                anon_vma_unlock_write(anon_vma);
        }
 
                 * Leave empty anon_vmas on the list - we'll need
                 * to free them outside the lock.
                 */
-               if (RB_EMPTY_ROOT(&anon_vma->rb_root)) {
+               if (RB_EMPTY_ROOT(&anon_vma->rb_root.rb_root)) {
                        anon_vma->parent->degree--;
                        continue;
                }
 
        init_rwsem(&anon_vma->rwsem);
        atomic_set(&anon_vma->refcount, 0);
-       anon_vma->rb_root = RB_ROOT;
+       anon_vma->rb_root = RB_ROOT_CACHED;
 }
 
 void __init anon_vma_init(void)