static u64 mlx5_imr_ksm_entries;
 
-static int check_parent(struct ib_umem_odp *odp,
-                              struct mlx5_ib_mr *parent)
+void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t idx, size_t nentries,
+                          struct mlx5_ib_mr *imr, int flags)
 {
-       struct mlx5_ib_mr *mr = odp->private;
-
-       return mr && mr->parent == parent && !odp->dying;
-}
-
-static struct ib_ucontext_per_mm *mr_to_per_mm(struct mlx5_ib_mr *mr)
-{
-       if (WARN_ON(!mr || !is_odp_mr(mr)))
-               return NULL;
-
-       return to_ib_umem_odp(mr->umem)->per_mm;
-}
-
-static struct ib_umem_odp *odp_next(struct ib_umem_odp *odp)
-{
-       struct mlx5_ib_mr *mr = odp->private, *parent = mr->parent;
-       struct ib_ucontext_per_mm *per_mm = odp->per_mm;
-       struct rb_node *rb;
-
-       down_read(&per_mm->umem_rwsem);
-       while (1) {
-               rb = rb_next(&odp->interval_tree.rb);
-               if (!rb)
-                       goto not_found;
-               odp = rb_entry(rb, struct ib_umem_odp, interval_tree.rb);
-               if (check_parent(odp, parent))
-                       goto end;
-       }
-not_found:
-       odp = NULL;
-end:
-       up_read(&per_mm->umem_rwsem);
-       return odp;
-}
-
-static struct ib_umem_odp *odp_lookup(u64 start, u64 length,
-                                     struct mlx5_ib_mr *parent)
-{
-       struct ib_ucontext_per_mm *per_mm = mr_to_per_mm(parent);
-       struct ib_umem_odp *odp;
-       struct rb_node *rb;
-
-       down_read(&per_mm->umem_rwsem);
-       odp = rbt_ib_umem_lookup(&per_mm->umem_tree, start, length);
-       if (!odp)
-               goto end;
-
-       while (1) {
-               if (check_parent(odp, parent))
-                       goto end;
-               rb = rb_next(&odp->interval_tree.rb);
-               if (!rb)
-                       goto not_found;
-               odp = rb_entry(rb, struct ib_umem_odp, interval_tree.rb);
-               if (ib_umem_start(odp) > start + length)
-                       goto not_found;
-       }
-not_found:
-       odp = NULL;
-end:
-       up_read(&per_mm->umem_rwsem);
-       return odp;
-}
-
-void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset,
-                          size_t nentries, struct mlx5_ib_mr *mr, int flags)
-{
-       struct ib_pd *pd = mr->ibmr.pd;
-       struct mlx5_ib_dev *dev = to_mdev(pd->device);
-       struct ib_umem_odp *odp;
-       unsigned long va;
-       int i;
+       struct mlx5_klm *end = pklm + nentries;
 
        if (flags & MLX5_IB_UPD_XLT_ZAP) {
-               for (i = 0; i < nentries; i++, pklm++) {
+               for (; pklm != end; pklm++, idx++) {
                        pklm->bcount = cpu_to_be32(MLX5_IMR_MTT_SIZE);
-                       pklm->key = cpu_to_be32(dev->null_mkey);
+                       pklm->key = cpu_to_be32(imr->dev->null_mkey);
                        pklm->va = 0;
                }
                return;
        }
 
        /*
-        * The locking here is pretty subtle. Ideally the implicit children
-        * list would be protected by the umem_mutex, however that is not
+        * The locking here is pretty subtle. Ideally the implicit_children
+        * xarray would be protected by the umem_mutex, however that is not
         * possible. Instead this uses a weaker update-then-lock pattern:
         *
         *  srcu_read_lock()
-        *    <change children list>
+        *    xa_store()
         *    mutex_lock(umem_mutex)
         *     mlx5_ib_update_xlt()
         *    mutex_unlock(umem_mutex)
         *    destroy lkey
         *
-        * ie any change the children list must be followed by the locked
-        * update_xlt before destroying.
+        * i.e. any change to the xarray must be followed by the locked
+        * update_xlt before destroying.
         *
         * The umem_mutex provides the acquire/release semantic needed to make
-        * the children list visible to a racing thread. While SRCU is not
+        * the xa_store() visible to a racing thread. While SRCU is not
         * technically required, using it gives consistent use of the SRCU
-        * locking around the children list.
+        * locking around the xarray.
         */
-       lockdep_assert_held(&to_ib_umem_odp(mr->umem)->umem_mutex);
-       lockdep_assert_held(&mr->dev->odp_srcu);
+       lockdep_assert_held(&to_ib_umem_odp(imr->umem)->umem_mutex);
+       lockdep_assert_held(&imr->dev->odp_srcu);
 
-       odp = odp_lookup(offset * MLX5_IMR_MTT_SIZE,
-                        nentries * MLX5_IMR_MTT_SIZE, mr);
+       for (; pklm != end; pklm++, idx++) {
+               struct mlx5_ib_mr *mtt = xa_load(&imr->implicit_children, idx);
 
-       for (i = 0; i < nentries; i++, pklm++) {
                pklm->bcount = cpu_to_be32(MLX5_IMR_MTT_SIZE);
-               va = (offset + i) * MLX5_IMR_MTT_SIZE;
-               if (odp && ib_umem_start(odp) == va) {
-                       struct mlx5_ib_mr *mtt = odp->private;
-
+               if (mtt) {
                        pklm->key = cpu_to_be32(mtt->ibmr.lkey);
-                       pklm->va = cpu_to_be64(va);
-                       odp = odp_next(odp);
+                       pklm->va = cpu_to_be64(idx * MLX5_IMR_MTT_SIZE);
                } else {
-                       pklm->key = cpu_to_be32(dev->null_mkey);
+                       pklm->key = cpu_to_be32(imr->dev->null_mkey);
                        pklm->va = 0;
                }
-               mlx5_ib_dbg(dev, "[%d] va %lx key %x\n",
-                           i, va, be32_to_cpu(pklm->key));
        }
 }
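
For reference, a rough writer-side sketch of the update-then-lock pattern the
comment above describes (illustrative only, not part of the patch; variable
names and the exact mlx5_ib_update_xlt() arguments are placeholders):

        int srcu_key, err;

        srcu_key = srcu_read_lock(&dev->odp_srcu);

        /* 1. Publish the child into the xarray. */
        err = xa_err(xa_store(&imr->implicit_children, idx, mtt, GFP_KERNEL));
        if (err)
                goto out_srcu;

        /*
         * 2. Push the new entry to HW under umem_mutex. A concurrent
         *    mlx5_odp_populate_klm() either observes the xa_store() or runs
         *    after this update completes.
         */
        mutex_lock(&odp_imr->umem_mutex);
        err = mlx5_ib_update_xlt(imr, idx, 1, 0,
                                 MLX5_IB_UPD_XLT_INDIRECT |
                                 MLX5_IB_UPD_XLT_ATOMIC);
        mutex_unlock(&odp_imr->umem_mutex);

out_srcu:
        srcu_read_unlock(&dev->odp_srcu, srcu_key);

        /* 3. Only now may a removed or replaced lkey be destroyed. */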
 
 
        if (unlikely(!umem_odp->npages && mr->parent &&
                     !umem_odp->dying)) {
+               xa_erase(&mr->parent->implicit_children,
+                        ib_umem_start(umem_odp) >> MLX5_IMR_MTT_SHIFT);
                xa_erase(&mr->dev->odp_mkeys, mlx5_base_mkey(mr->mmkey.key));
                umem_odp->dying = 1;
                atomic_inc(&mr->parent->num_leaf_free);
                schedule_work(&umem_odp->work);
        }
 
+       /*
+        * Once the store to either xarray completes, any error unwind has to
+        * use synchronize_srcu(). Avoid this with xa_reserve().
+        */
+       err = xa_err(xa_store(&imr->implicit_children, idx, mr, GFP_KERNEL));
+       if (err) {
+               ret = ERR_PTR(err);
+               goto out_release;
+       }
+
        xa_store(&imr->dev->odp_mkeys, mlx5_base_mkey(mr->mmkey.key),
                 &mr->mmkey, GFP_ATOMIC);
 
        return ret;
 }
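
The xa_reserve() mentioned in the comment above would look roughly like the
sketch below (illustrative only, not part of the patch). Reserving the index
up front moves the only failing allocation to before the entry is ever
visible, so no error unwind has to synchronize_srcu():

        /* The only allocation that can fail happens here, pre-publication */
        err = xa_reserve(&imr->implicit_children, idx, GFP_KERNEL);
        if (err)
                return ERR_PTR(err);

        /*
         * ... create and enable the child mkey; a failure here only needs
         *
         *      xa_release(&imr->implicit_children, idx);
         *
         * since nothing was ever visible to readers ...
         */

        /* Storing into a reserved slot does not allocate and cannot fail */
        xa_store(&imr->implicit_children, idx, mr, GFP_KERNEL);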
 
-static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *imr,
+static struct mlx5_ib_mr *implicit_mr_get_data(struct mlx5_ib_mr *imr,
                                                u64 io_virt, size_t bcnt)
 {
        struct ib_umem_odp *odp_imr = to_ib_umem_odp(imr->umem);
        unsigned long end_idx = (io_virt + bcnt - 1) >> MLX5_IMR_MTT_SHIFT;
        unsigned long idx = io_virt >> MLX5_IMR_MTT_SHIFT;
        unsigned long inv_start_idx = end_idx + 1;
        unsigned long inv_len = 0;
-       struct ib_umem_odp *result = NULL;
-       struct ib_umem_odp *odp;
+       struct mlx5_ib_mr *result = NULL;
        int ret;
 
        mutex_lock(&odp_imr->umem_mutex);
-       odp = odp_lookup(idx * MLX5_IMR_MTT_SIZE, 1, imr);
        for (idx = idx; idx <= end_idx; idx++) {
-               if (unlikely(!odp)) {
-                       struct mlx5_ib_mr *mtt;
+               struct mlx5_ib_mr *mtt = xa_load(&imr->implicit_children, idx);
 
+               if (unlikely(!mtt)) {
                        mtt = implicit_get_child_mr(imr, idx);
                        if (IS_ERR(mtt)) {
-                               result = ERR_CAST(mtt);
+                               result = mtt;
                                goto out;
                        }
-                       odp = to_ib_umem_odp(mtt->umem);
                        inv_start_idx = min(inv_start_idx, idx);
                        inv_len = idx - inv_start_idx + 1;
                }
 
-               /* Return first odp if region not covered by single one */
+               /* Return first mtt if the region is not covered by one mtt */
                if (likely(!result))
-                       result = odp;
-
-               odp = odp_next(odp);
-               if (odp && ib_umem_start(odp) != idx * MLX5_IMR_MTT_SIZE)
-                       odp = NULL;
+                       result = mtt;
        }
 
        /*
-        * Any time the children in the interval tree are changed we must
-        * perform an update of the xlt before exiting to ensure the HW and
-        * the tree remains synchronized.
+        * Any time the implicit_children are changed we must perform an
+        * update of the xlt before exiting to ensure the HW and the
+        * implicit_children remain synchronized.
         */
 out:
        if (likely(!inv_len))

        init_waitqueue_head(&imr->q_leaf_free);
        atomic_set(&imr->num_leaf_free, 0);
        atomic_set(&imr->num_pending_prefetch, 0);
+       xa_init(&imr->implicit_children);
 
        err = mlx5_ib_update_xlt(imr, 0,
                                 mlx5_imr_ksm_entries,
 
 void mlx5_ib_free_implicit_mr(struct mlx5_ib_mr *imr)
 {
-       struct ib_ucontext_per_mm *per_mm = mr_to_per_mm(imr);
-       struct rb_node *node;
+       struct ib_umem_odp *odp_imr = to_ib_umem_odp(imr->umem);
+       struct mlx5_ib_mr *mtt;
+       unsigned long idx;
 
-       down_read(&per_mm->umem_rwsem);
-       for (node = rb_first_cached(&per_mm->umem_tree); node;
-            node = rb_next(node)) {
-               struct ib_umem_odp *umem_odp =
-                       rb_entry(node, struct ib_umem_odp, interval_tree.rb);
-               struct mlx5_ib_mr *mr = umem_odp->private;
+       mutex_lock(&odp_imr->umem_mutex);
+       xa_for_each(&imr->implicit_children, idx, mtt) {
+               struct ib_umem_odp *umem_odp = to_ib_umem_odp(mtt->umem);
 
-               if (mr->parent != imr)
-                       continue;
+               xa_erase(&imr->implicit_children, idx);
 
                mutex_lock(&umem_odp->umem_mutex);
                ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
                schedule_work(&umem_odp->work);
                mutex_unlock(&umem_odp->umem_mutex);
        }
-       up_read(&per_mm->umem_rwsem);
+       mutex_unlock(&odp_imr->umem_mutex);
 
        wait_event(imr->q_leaf_free, !atomic_read(&imr->num_leaf_free));
+       WARN_ON(!xa_empty(&imr->implicit_children));
+       /* Remove any leftover reserved elements */
+       xa_destroy(&imr->implicit_children);
 }
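
Erasing entries while iterating, as done above, is safe: xa_for_each() keeps
no cursor into the tree. Per include/linux/xarray.h it expands to roughly the
find loop below, restarting from the next present index on every pass:

        for (mtt = xa_find(&imr->implicit_children, &idx, ULONG_MAX,
                           XA_PRESENT);
             mtt;
             mtt = xa_find_after(&imr->implicit_children, &idx, ULONG_MAX,
                                 XA_PRESENT)) {
                /* the body may freely xa_erase() the current index */
        }

The final xa_destroy() then frees any internal nodes that remain (such as
reserved slots) after the WARN_ON() has verified that no real entries leaked.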
 
 #define MLX5_PF_FLAGS_DOWNGRADE BIT(1)

                        u32 *bytes_mapped, u32 flags)
 {
        struct ib_umem_odp *odp = to_ib_umem_odp(mr->umem);
-       struct ib_umem_odp *child;
+       struct mlx5_ib_mr *mtt;
        int npages = 0;
 
        if (!odp->is_implicit_odp) {

                     mlx5_imr_ksm_entries * MLX5_IMR_MTT_SIZE - io_virt < bcnt))
                return -EFAULT;
 
-       child = implicit_mr_get_data(mr, io_virt, bcnt);
-       if (IS_ERR(child))
-               return PTR_ERR(child);
+       mtt = implicit_mr_get_data(mr, io_virt, bcnt);
+       if (IS_ERR(mtt))
+               return PTR_ERR(mtt);
 
        /* Fault each child mr that intersects with our interval. */
        while (bcnt) {
-               u64 end = min_t(u64, io_virt + bcnt, ib_umem_end(child));
+               struct ib_umem_odp *umem_odp = to_ib_umem_odp(mtt->umem);
+               u64 end = min_t(u64, io_virt + bcnt, ib_umem_end(umem_odp));
                u64 len = end - io_virt;
                int ret;
 
-               ret = pagefault_real_mr(child->private, child, io_virt, len,
+               ret = pagefault_real_mr(mtt, umem_odp, io_virt, len,
                                        bytes_mapped, flags);
                if (ret < 0)
                        return ret;
                io_virt += len;
                bcnt -= len;
                npages += ret;
 
                if (unlikely(bcnt)) {
-                       child = odp_next(child);
+                       mtt = xa_load(&mr->implicit_children,
+                                     io_virt >> MLX5_IMR_MTT_SHIFT);
+
                        /*
                         * implicit_mr_get_data sets up all the leaves, this
                         * means they got invalidated before we got to them.
                         */
-                       if (!child || ib_umem_start(child) != io_virt) {
+                       if (!mtt) {
                                mlx5_ib_dbg(
                                        mr->dev,
                                        "next implicit leaf removed at 0x%llx.\n",