* xarray would be protected by the umem_mutex, however that is not
         * possible. Instead this uses a weaker update-then-lock pattern:
         *
-        *  srcu_read_lock()
         *    xa_store()
         *    mutex_lock(umem_mutex)
         *     mlx5_ib_update_xlt()
         * before destroying.
         *
         * The umem_mutex provides the acquire/release semantic needed to make
-        * the xa_store() visible to a racing thread. While SRCU is not
-        * technically required, using it gives consistent use of the SRCU
-        * locking around the xarray.
+        * the xa_store() visible to a racing thread.
         */
        lockdep_assert_held(&to_ib_umem_odp(imr->umem)->umem_mutex);
-       lockdep_assert_held(&mr_to_mdev(imr)->odp_srcu);
 
        for (; pklm != end; pklm++, idx++) {
                struct mlx5_ib_mr *mtt = xa_load(&imr->implicit_children, idx);
 }
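For reference, the update-then-lock pattern described in the comment above has a writer side in implicit_get_child_mr()/pagefault_implicit_mr() and a reader side in this XLT-population path. A condensed sketch (not part of this patch; locking only, error handling elided):

	/* writer: publish the child, then do the locked XLT update of the imr */
	xa_store(&imr->implicit_children, idx, mtt, GFP_KERNEL);
	mutex_lock(&odp_imr->umem_mutex);
	mlx5_ib_update_xlt(imr, idx, 1, 0,
			   MLX5_IB_UPD_XLT_INDIRECT | MLX5_IB_UPD_XLT_ATOMIC);
	mutex_unlock(&odp_imr->umem_mutex);

	/* reader: holding umem_mutex makes the xa_store() above visible */
	lockdep_assert_held(&odp_imr->umem_mutex);
	mtt = xa_load(&imr->implicit_children, idx);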
 
 /*
- * This must be called after the mr has been removed from implicit_children
- * and the SRCU synchronized.  NOTE: The MR does not necessarily have to be
+ * This must be called after the mr has been removed from implicit_children.
+ * NOTE: The MR does not necessarily have to be
  * empty here, parallel page faults could have raced with the free process and
  * added pages to it.
  */
        struct ib_umem_odp *odp_imr = to_ib_umem_odp(imr->umem);
        struct ib_umem_odp *odp = to_ib_umem_odp(mr->umem);
        unsigned long idx = ib_umem_start(odp) >> MLX5_IMR_MTT_SHIFT;
-       int srcu_key;
 
-       /* implicit_child_mr's are not allowed to have deferred work */
-       WARN_ON(atomic_read(&mr->num_deferred_work));
+       mlx5r_deref_wait_odp_mkey(&mr->mmkey);
 
        if (need_imr_xlt) {
-               srcu_key = srcu_read_lock(&mr_to_mdev(mr)->odp_srcu);
                mutex_lock(&odp_imr->umem_mutex);
                mlx5_ib_update_xlt(mr->parent, idx, 1, 0,
                                   MLX5_IB_UPD_XLT_INDIRECT |
                                   MLX5_IB_UPD_XLT_ATOMIC);
                mutex_unlock(&odp_imr->umem_mutex);
-               srcu_read_unlock(&mr_to_mdev(mr)->odp_srcu, srcu_key);
        }
 
        dma_fence_odp_mr(mr);
        mr->parent = NULL;
        mlx5_mr_cache_free(mr_to_mdev(mr), mr);
        ib_umem_odp_release(odp);
-       if (atomic_dec_and_test(&imr->num_deferred_work))
-               wake_up(&imr->q_deferred_work);
 }
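The mlx5r_*_odp_mkey() helpers used above and throughout this patch are not shown in this hunk. The sketch below gives their approximate shape, pairing the usecount refcount visible in this diff with a wait queue on the mkey; the wait-queue member name is an assumption here, see mlx5_ib.h for the real definitions:

static inline int mlx5r_store_odp_mkey(struct mlx5_ib_dev *dev,
				       struct mlx5_core_mkey *mmkey)
{
	/* publish the mkey with one reference owned by the xarray */
	refcount_set(&mmkey->usecount, 1);
	return xa_err(xa_store(&dev->odp_mkeys, mlx5_base_mkey(mmkey->key),
			       mmkey, GFP_KERNEL));
}

static inline void mlx5r_deref_odp_mkey(struct mlx5_core_mkey *mmkey)
{
	/* drop a page-fault/prefetch reference, wake a waiter draining it */
	if (refcount_dec_and_test(&mmkey->usecount))
		wake_up(&mmkey->wait);
}

static inline void mlx5r_deref_wait_odp_mkey(struct mlx5_core_mkey *mmkey)
{
	/* drop our reference and block until no user of the mkey remains */
	mlx5r_deref_odp_mkey(mmkey);
	wait_event(mmkey->wait, refcount_read(&mmkey->usecount) == 0);
}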
 
 static void free_implicit_child_mr_work(struct work_struct *work)
 {
        struct mlx5_ib_mr *mr =
                container_of(work, struct mlx5_ib_mr, odp_destroy.work);
+       struct mlx5_ib_mr *imr = mr->parent;
 
        free_implicit_child_mr(mr, true);
-}
-
-static void free_implicit_child_mr_rcu(struct rcu_head *head)
-{
-       struct mlx5_ib_mr *mr =
-               container_of(head, struct mlx5_ib_mr, odp_destroy.rcu);
-
-       /* Freeing a MR is a sleeping operation, so bounce to a work queue */
-       INIT_WORK(&mr->odp_destroy.work, free_implicit_child_mr_work);
-       queue_work(system_unbound_wq, &mr->odp_destroy.work);
+       mlx5r_deref_odp_mkey(&imr->mmkey);
 }
 
 static void destroy_unused_implicit_child_mr(struct mlx5_ib_mr *mr)
        unsigned long idx = ib_umem_start(odp) >> MLX5_IMR_MTT_SHIFT;
        struct mlx5_ib_mr *imr = mr->parent;
 
-       xa_lock(&imr->implicit_children);
-       /*
-        * This can race with mlx5_ib_free_implicit_mr(), the first one to
-        * reach the xa lock wins the race and destroys the MR.
-        */
-       if (__xa_cmpxchg(&imr->implicit_children, idx, mr, NULL, GFP_ATOMIC) !=
-           mr)
-               goto out_unlock;
+       if (!refcount_inc_not_zero(&imr->mmkey.usecount))
+               return;
 
-       atomic_inc(&imr->num_deferred_work);
-       call_srcu(&mr_to_mdev(mr)->odp_srcu, &mr->odp_destroy.rcu,
-                 free_implicit_child_mr_rcu);
+       xa_erase(&imr->implicit_children, idx);
 
-out_unlock:
-       xa_unlock(&imr->implicit_children);
+       /* Freeing a MR is a sleeping operation, so bounce to a work queue */
+       INIT_WORK(&mr->odp_destroy.work, free_implicit_child_mr_work);
+       queue_work(system_unbound_wq, &mr->odp_destroy.work);
 }
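Note the pairing: the reference taken with refcount_inc_not_zero() keeps the parent imr alive while the child destroy is bounced to the work queue, and free_implicit_child_mr_work() drops it via mlx5r_deref_odp_mkey(&imr->mmkey) after the child is freed. If the increment fails, the imr is already being drained by mlx5_ib_free_implicit_mr(), which will find the still-present child in implicit_children and free it itself.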
 
 static bool mlx5_ib_invalidate_range(struct mmu_interval_notifier *mni,
        mr->parent = imr;
        odp->private = mr;
 
+       /*
+        * First refcount is owned by the xarray and second refcount
+        * is returned to the caller.
+        */
+       refcount_set(&mr->mmkey.usecount, 2);
+
        err = mlx5_ib_update_xlt(mr, 0,
                                 MLX5_IMR_MTT_ENTRIES,
                                 PAGE_SHIFT,
                goto out_mr;
        }
 
-       /*
-        * Once the store to either xarray completes any error unwind has to
-        * use synchronize_srcu(). Avoid this with xa_reserve()
-        */
-       ret = xa_cmpxchg(&imr->implicit_children, idx, NULL, mr,
-                        GFP_KERNEL);
+       xa_lock(&imr->implicit_children);
+       ret = __xa_cmpxchg(&imr->implicit_children, idx, NULL, mr,
+                          GFP_KERNEL);
        if (unlikely(ret)) {
                if (xa_is_err(ret)) {
                        ret = ERR_PTR(xa_err(ret));
-                       goto out_mr;
+                       goto out_lock;
                }
                /*
                 * Another thread beat us to creating the child mr, use
                 * theirs.
                 */
-               goto out_mr;
+               refcount_inc(&ret->mmkey.usecount);
+               goto out_lock;
        }
+       xa_unlock(&imr->implicit_children);
 
        mlx5_ib_dbg(mr_to_mdev(imr), "key %x mr %p\n", mr->mmkey.key, mr);
        return mr;
 
+out_lock:
+       xa_unlock(&imr->implicit_children);
 out_mr:
        mlx5_mr_cache_free(mr_to_mdev(imr), mr);
 out_umem:
        imr->ibmr.device = &dev->ib_dev;
        imr->umem = &umem_odp->umem;
        imr->is_odp_implicit = true;
-       atomic_set(&imr->num_deferred_work, 0);
-       init_waitqueue_head(&imr->q_deferred_work);
        xa_init(&imr->implicit_children);
 
        err = mlx5_ib_update_xlt(imr, 0,
        if (err)
                goto out_mr;
 
-       err = xa_err(xa_store(&dev->odp_mkeys, mlx5_base_mkey(imr->mmkey.key),
-                             &imr->mmkey, GFP_KERNEL));
+       err = mlx5r_store_odp_mkey(dev, &imr->mmkey);
        if (err)
                goto out_mr;
 
 {
        struct ib_umem_odp *odp_imr = to_ib_umem_odp(imr->umem);
        struct mlx5_ib_dev *dev = mr_to_mdev(imr);
-       struct list_head destroy_list;
        struct mlx5_ib_mr *mtt;
-       struct mlx5_ib_mr *tmp;
        unsigned long idx;
 
-       INIT_LIST_HEAD(&destroy_list);
-
        xa_erase(&dev->odp_mkeys, mlx5_base_mkey(imr->mmkey.key));
-       /*
-        * This stops the SRCU protected page fault path from touching either
-        * the imr or any children. The page fault path can only reach the
-        * children xarray via the imr.
-        */
-       synchronize_srcu(&dev->odp_srcu);
-
        /*
         * All work on the prefetch list must be completed, xa_erase() prevented
         * new work from being created.
         */
-       wait_event(imr->q_deferred_work, !atomic_read(&imr->num_deferred_work));
-
+       mlx5r_deref_wait_odp_mkey(&imr->mmkey);
        /*
         * At this point it is forbidden for any other thread to enter
         * pagefault_mr() on this imr. It is already forbidden to call
         * pagefault_mr() on an implicit child. Due to this additions to
         * implicit_children are prevented.
+        * In addition, any new call to destroy_unused_implicit_child_mr()
+        * will return immediately.
         */
 
-       /*
-        * Block destroy_unused_implicit_child_mr() from incrementing
-        * num_deferred_work.
-        */
-       xa_lock(&imr->implicit_children);
-       xa_for_each (&imr->implicit_children, idx, mtt) {
-               __xa_erase(&imr->implicit_children, idx);
-               list_add(&mtt->odp_destroy.elm, &destroy_list);
-       }
-       xa_unlock(&imr->implicit_children);
-
-       /*
-        * Wait for any concurrent destroy_unused_implicit_child_mr() to
-        * complete.
-        */
-       wait_event(imr->q_deferred_work, !atomic_read(&imr->num_deferred_work));
-
        /*
         * Fence the imr before we destroy the children. This allows us to
         * skip updating the XLT of the imr during destroy of the child mkey
         */
        mlx5_mr_cache_invalidate(imr);
 
-       list_for_each_entry_safe (mtt, tmp, &destroy_list, odp_destroy.elm)
+       xa_for_each(&imr->implicit_children, idx, mtt) {
+               xa_erase(&imr->implicit_children, idx);
                free_implicit_child_mr(mtt, false);
+       }
 
        mlx5_mr_cache_free(dev, imr);
        ib_umem_odp_release(odp_imr);
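The resulting teardown order in mlx5_ib_free_implicit_mr() is: erase the imr from odp_mkeys so no new lookup can pin it, mlx5r_deref_wait_odp_mkey() to drop the imr's own reference and wait out every in-flight page fault and prefetch, mlx5_mr_cache_invalidate() to fence the imr so destroying the children needs no further XLT updates, and finally a plain xa_for_each()/xa_erase() pass that frees the children synchronously, with no SRCU grace period in between.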
        xa_erase(&mr_to_mdev(mr)->odp_mkeys, mlx5_base_mkey(mr->mmkey.key));
 
        /* Wait for all running page-fault handlers to finish. */
-       synchronize_srcu(&mr_to_mdev(mr)->odp_srcu);
-
-       wait_event(mr->q_deferred_work, !atomic_read(&mr->num_deferred_work));
+       mlx5r_deref_wait_odp_mkey(&mr->mmkey);
 
        dma_fence_odp_mr(mr);
 }
        /* Prevent new page faults and prefetch requests from succeeding */
        xa_erase(&mr_to_mdev(mr)->odp_mkeys, mlx5_base_mkey(mr->mmkey.key));
 
-       /* Wait for all running page-fault handlers to finish. */
-       synchronize_srcu(&mr_to_mdev(mr)->odp_srcu);
-
-       wait_event(mr->q_deferred_work, !atomic_read(&mr->num_deferred_work));
+       mlx5r_deref_wait_odp_mkey(&mr->mmkey);
 
        dma_resv_lock(umem_dmabuf->attach->dmabuf->resv, NULL);
        mlx5_mr_cache_invalidate(mr);
                struct mlx5_ib_mr *mtt;
                u64 len;
 
+               xa_lock(&imr->implicit_children);
                mtt = xa_load(&imr->implicit_children, idx);
                if (unlikely(!mtt)) {
+                       xa_unlock(&imr->implicit_children);
                        mtt = implicit_get_child_mr(imr, idx);
                        if (IS_ERR(mtt)) {
                                ret = PTR_ERR(mtt);
                        }
                        upd_start_idx = min(upd_start_idx, idx);
                        upd_len = idx - upd_start_idx + 1;
+               } else {
+                       refcount_inc(&mtt->mmkey.usecount);
+                       xa_unlock(&imr->implicit_children);
                }
 
                umem_odp = to_ib_umem_odp(mtt->umem);
 
                ret = pagefault_real_mr(mtt, umem_odp, user_va, len,
                                        bytes_mapped, flags);
+
+               mlx5r_deref_odp_mkey(&mtt->mmkey);
+
                if (ret < 0)
                        goto out;
                user_va += len;
 {
        struct ib_umem_odp *odp = to_ib_umem_odp(mr->umem);
 
-       lockdep_assert_held(&mr_to_mdev(mr)->odp_srcu);
        if (unlikely(io_virt < mr->mmkey.iova))
                return -EFAULT;
 
                                         u32 *bytes_committed,
                                         u32 *bytes_mapped)
 {
-       int npages = 0, srcu_key, ret, i, outlen, cur_outlen = 0, depth = 0;
+       int npages = 0, ret, i, outlen, cur_outlen = 0, depth = 0;
        struct pf_frame *head = NULL, *frame;
        struct mlx5_core_mkey *mmkey;
        struct mlx5_ib_mr *mr;
        size_t offset;
        int ndescs;
 
-       srcu_key = srcu_read_lock(&dev->odp_srcu);
-
        io_virt += *bytes_committed;
        bcnt -= *bytes_committed;
 
 next_mr:
+       xa_lock(&dev->odp_mkeys);
        mmkey = xa_load(&dev->odp_mkeys, mlx5_base_mkey(key));
        if (!mmkey) {
+               xa_unlock(&dev->odp_mkeys);
                mlx5_ib_dbg(
                        dev,
                        "skipping non ODP MR (lkey=0x%06x) in page fault handler.\n",
                 * faulted.
                 */
                ret = 0;
-               goto srcu_unlock;
+               goto end;
        }
+       refcount_inc(&mmkey->usecount);
+       xa_unlock(&dev->odp_mkeys);
+
        if (!mkey_is_eq(mmkey, key)) {
                mlx5_ib_dbg(dev, "failed to find mkey %x\n", key);
                ret = -EFAULT;
-               goto srcu_unlock;
+               goto end;
        }
 
        switch (mmkey->type) {
 
                ret = pagefault_mr(mr, io_virt, bcnt, bytes_mapped, 0);
                if (ret < 0)
-                       goto srcu_unlock;
+                       goto end;
 
                mlx5_update_odp_stats(mr, faults, ret);
 
                if (depth >= MLX5_CAP_GEN(dev->mdev, max_indirection)) {
                        mlx5_ib_dbg(dev, "indirection level exceeded\n");
                        ret = -EFAULT;
-                       goto srcu_unlock;
+                       goto end;
                }
 
                outlen = MLX5_ST_SZ_BYTES(query_mkey_out) +
                        out = kzalloc(outlen, GFP_KERNEL);
                        if (!out) {
                                ret = -ENOMEM;
-                               goto srcu_unlock;
+                               goto end;
                        }
                        cur_outlen = outlen;
                }
 
                ret = mlx5_core_query_mkey(dev->mdev, mmkey, out, outlen);
                if (ret)
-                       goto srcu_unlock;
+                       goto end;
 
                offset = io_virt - MLX5_GET64(query_mkey_out, out,
                                              memory_key_mkey_entry.start_addr);
                        frame = kzalloc(sizeof(*frame), GFP_KERNEL);
                        if (!frame) {
                                ret = -ENOMEM;
-                               goto srcu_unlock;
+                               goto end;
                        }
 
                        frame->key = be32_to_cpu(pklm->key);
        default:
                mlx5_ib_dbg(dev, "wrong mkey type %d\n", mmkey->type);
                ret = -EFAULT;
-               goto srcu_unlock;
+               goto end;
        }
 
        if (head) {
                depth = frame->depth;
                kfree(frame);
 
+               mlx5r_deref_odp_mkey(mmkey);
                goto next_mr;
        }
 
-srcu_unlock:
+end:
+       if (mmkey)
+               mlx5r_deref_odp_mkey(mmkey);
        while (head) {
                frame = head;
                head = frame->next;
        }
        kfree(out);
 
-       srcu_read_unlock(&dev->odp_srcu, srcu_key);
        *bytes_committed = 0;
        return ret ? ret : npages;
 }
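The read side that srcu_read_lock()/srcu_read_unlock() used to bracket is now the lookup-and-pin shape used above and in get_prefetchable_mr() below; condensed (not part of this patch):

	xa_lock(&dev->odp_mkeys);
	mmkey = xa_load(&dev->odp_mkeys, mlx5_base_mkey(key));
	if (mmkey)
		refcount_inc(&mmkey->usecount);	/* pin the mkey across the fault */
	xa_unlock(&dev->odp_mkeys);

	if (mmkey) {
		/* ... resolve the fault with no lock held, may sleep ... */
		mlx5r_deref_odp_mkey(mmkey);	/* balance the refcount_inc() */
	}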
        u32 i;
 
        for (i = 0; i < work->num_sge; ++i)
-               if (atomic_dec_and_test(&work->frags[i].mr->num_deferred_work))
-                       wake_up(&work->frags[i].mr->q_deferred_work);
+               mlx5r_deref_odp_mkey(&work->frags[i].mr->mmkey);
+
        kvfree(work);
 }
 
 {
        struct mlx5_ib_dev *dev = to_mdev(pd->device);
        struct mlx5_core_mkey *mmkey;
-       struct mlx5_ib_mr *mr;
-
-       lockdep_assert_held(&dev->odp_srcu);
+       struct mlx5_ib_mr *mr = NULL;
 
+       xa_lock(&dev->odp_mkeys);
        mmkey = xa_load(&dev->odp_mkeys, mlx5_base_mkey(lkey));
        if (!mmkey || mmkey->key != lkey || mmkey->type != MLX5_MKEY_MR)
-               return NULL;
+               goto end;
 
        mr = container_of(mmkey, struct mlx5_ib_mr, mmkey);
 
-       if (mr->ibmr.pd != pd)
-               return NULL;
+       if (mr->ibmr.pd != pd) {
+               mr = NULL;
+               goto end;
+       }
 
        /* prefetch with write-access must be supported by the MR */
        if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_WRITE &&
-           !mr->umem->writable)
-               return NULL;
+           !mr->umem->writable) {
+               mr = NULL;
+               goto end;
+       }
 
+       refcount_inc(&mmkey->usecount);
+end:
+       xa_unlock(&dev->odp_mkeys);
        return mr;
 }
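get_prefetchable_mr() now returns the MR with mmkey.usecount already elevated, so every caller owns a reference: the synchronous path in mlx5_ib_prefetch_sg_list() drops it right after pagefault_mr(), while the asynchronous path hands it to the queued work and drops it in destroy_prefetch_work(). This is what lets init_prefetch_work() stop taking num_deferred_work and the work handler run without an SRCU read lock.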
 
 {
        struct prefetch_mr_work *work =
                container_of(w, struct prefetch_mr_work, work);
-       struct mlx5_ib_dev *dev;
        u32 bytes_mapped = 0;
-       int srcu_key;
        int ret;
        u32 i;
 
        /* We rely on IB/core that work is executed if we have num_sge != 0 only. */
        WARN_ON(!work->num_sge);
-       dev = mr_to_mdev(work->frags[0].mr);
-       /* SRCU should be held when calling to mlx5_odp_populate_xlt() */
-       srcu_key = srcu_read_lock(&dev->odp_srcu);
        for (i = 0; i < work->num_sge; ++i) {
                ret = pagefault_mr(work->frags[i].mr, work->frags[i].io_virt,
                                   work->frags[i].length, &bytes_mapped,
                        continue;
                mlx5_update_odp_stats(work->frags[i].mr, prefetch, ret);
        }
-       srcu_read_unlock(&dev->odp_srcu, srcu_key);
 
        destroy_prefetch_work(work);
 }
                        work->num_sge = i;
                        return false;
                }
-
-               /* Keep the MR pointer will valid outside the SRCU */
-               atomic_inc(&work->frags[i].mr->num_deferred_work);
        }
        work->num_sge = num_sge;
        return true;
                                    u32 pf_flags, struct ib_sge *sg_list,
                                    u32 num_sge)
 {
-       struct mlx5_ib_dev *dev = to_mdev(pd->device);
        u32 bytes_mapped = 0;
-       int srcu_key;
        int ret = 0;
        u32 i;
 
-       srcu_key = srcu_read_lock(&dev->odp_srcu);
        for (i = 0; i < num_sge; ++i) {
                struct mlx5_ib_mr *mr;
 
                mr = get_prefetchable_mr(pd, advice, sg_list[i].lkey);
-               if (!mr) {
-                       ret = -ENOENT;
-                       goto out;
-               }
+               if (!mr)
+                       return -ENOENT;
                ret = pagefault_mr(mr, sg_list[i].addr, sg_list[i].length,
                                   &bytes_mapped, pf_flags);
-               if (ret < 0)
-                       goto out;
+               if (ret < 0) {
+                       mlx5r_deref_odp_mkey(&mr->mmkey);
+                       return ret;
+               }
                mlx5_update_odp_stats(mr, prefetch, ret);
+               mlx5r_deref_odp_mkey(&mr->mmkey);
        }
-       ret = 0;
 
-out:
-       srcu_read_unlock(&dev->odp_srcu, srcu_key);
-       return ret;
+       return 0;
 }
 
 int mlx5_ib_advise_mr_prefetch(struct ib_pd *pd,
                               enum ib_uverbs_advise_mr_advice advice,
                               u32 flags, struct ib_sge *sg_list, u32 num_sge)
 {
-       struct mlx5_ib_dev *dev = to_mdev(pd->device);
        u32 pf_flags = 0;
        struct prefetch_mr_work *work;
-       int srcu_key;
 
        if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH)
                pf_flags |= MLX5_PF_FLAGS_DOWNGRADE;
        if (!work)
                return -ENOMEM;
 
-       srcu_key = srcu_read_lock(&dev->odp_srcu);
        if (!init_prefetch_work(pd, advice, pf_flags, work, sg_list, num_sge)) {
-               srcu_read_unlock(&dev->odp_srcu, srcu_key);
                destroy_prefetch_work(work);
                return -EINVAL;
        }
        queue_work(system_unbound_wq, &work->work);
-       srcu_read_unlock(&dev->odp_srcu, srcu_key);
        return 0;
 }