struct ib_udata *udata,
                                             int access_flags);
 void mlx5_ib_free_implicit_mr(struct mlx5_ib_mr *mr);
+void mlx5_ib_fence_odp_mr(struct mlx5_ib_mr *mr);
 int mlx5_ib_rereg_user_mr(struct ib_mr *ib_mr, int flags, u64 start,
                          u64 length, u64 virt_addr, int access_flags,
                          struct ib_pd *pd, struct ib_udata *udata);
 
 struct mlx5_ib_mr *mlx5_mr_cache_alloc(struct mlx5_ib_dev *dev, int entry);
 void mlx5_mr_cache_free(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr);
+int mlx5_mr_cache_invalidate(struct mlx5_ib_mr *mr);
+
 int mlx5_ib_check_mr_status(struct ib_mr *ibmr, u32 check_mask,
                            struct ib_mr_status *mr_status);
 struct ib_wq *mlx5_ib_create_wq(struct ib_pd *pd,
 
 static void clean_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr);
 static void dereg_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr);
 static int mr_cache_max_order(struct mlx5_ib_dev *dev);
-static int unreg_umr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr);
 
 static bool umr_can_use_indirect_mkey(struct mlx5_ib_dev *dev)
 {
        c = order2idx(dev, mr->order);
        WARN_ON(c < 0 || c >= MAX_MR_CACHE_ENTRIES);
 
-       if (unreg_umr(dev, mr)) {
+       if (mlx5_mr_cache_invalidate(mr)) {
                mr->allocated_from_cache = false;
                destroy_mkey(dev, mr);
                ent = &cache->ent[c];
        return ERR_PTR(err);
 }
 
-static int unreg_umr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr)
+/**
+ * mlx5_mr_cache_invalidate - Fence all DMA on the MR
+ * @mr: The MR to fence
+ *
+ * Upon return the NIC will not be doing any DMA to the pages under the MR,
+ * and any DMA inprogress will be completed. Failure of this function
+ * indicates the HW has failed catastrophically.
+ */
+int mlx5_mr_cache_invalidate(struct mlx5_ib_mr *mr)
 {
-       struct mlx5_core_dev *mdev = dev->mdev;
        struct mlx5_umr_wr umrwr = {};
 
-       if (mdev->state == MLX5_DEVICE_STATE_INTERNAL_ERROR)
+       if (mr->dev->mdev->state == MLX5_DEVICE_STATE_INTERNAL_ERROR)
                return 0;
 
        umrwr.wr.send_flags = MLX5_IB_SEND_UMR_DISABLE_MR |
                              MLX5_IB_SEND_UMR_UPDATE_PD_ACCESS;
        umrwr.wr.opcode = MLX5_IB_WR_UMR;
-       umrwr.pd = dev->umrc.pd;
+       umrwr.pd = mr->dev->umrc.pd;
        umrwr.mkey = mr->mmkey.key;
        umrwr.ignore_free_state = 1;
 
-       return mlx5_ib_post_send_wait(dev, &umrwr);
+       return mlx5_ib_post_send_wait(mr->dev, &umrwr);
 }
 
 static int rereg_umr(struct ib_pd *pd, struct mlx5_ib_mr *mr,
                 * UMR can't be used - MKey needs to be replaced.
                 */
                if (mr->allocated_from_cache)
-                       err = unreg_umr(dev, mr);
+                       err = mlx5_mr_cache_invalidate(mr);
                else
                        err = destroy_mkey(dev, mr);
                if (err)
        int npages = mr->npages;
        struct ib_umem *umem = mr->umem;
 
-       if (is_odp_mr(mr)) {
-               struct ib_umem_odp *umem_odp = to_ib_umem_odp(umem);
-
-               /* Prevent new page faults and
-                * prefetch requests from succeeding
-                */
-               xa_erase(&dev->odp_mkeys, mlx5_base_mkey(mr->mmkey.key));
-
-               /* Wait for all running page-fault handlers to finish. */
-               synchronize_srcu(&dev->odp_srcu);
-
-               /* dequeue pending prefetch requests for the mr */
-               if (atomic_read(&mr->num_deferred_work)) {
-                       flush_workqueue(system_unbound_wq);
-                       WARN_ON(atomic_read(&mr->num_deferred_work));
-               }
-
-               /* Destroy all page mappings */
-               mlx5_ib_invalidate_range(umem_odp, ib_umem_start(umem_odp),
-                                        ib_umem_end(umem_odp));
-
-               /*
-                * We kill the umem before the MR for ODP,
-                * so that there will not be any invalidations in
-                * flight, looking at the *mr struct.
-                */
-               ib_umem_odp_release(umem_odp);
-               atomic_sub(npages, &dev->mdev->priv.reg_pages);
-
-               /* Avoid double-freeing the umem. */
-               umem = NULL;
-       }
+       /* Stop all DMA */
+       if (is_odp_mr(mr))
+               mlx5_ib_fence_odp_mr(mr);
+       else
+               clean_mr(dev, mr);
 
-       clean_mr(dev, mr);
+       if (mr->allocated_from_cache)
+               mlx5_mr_cache_free(dev, mr);
+       else
+               kfree(mr);
 
-       /*
-        * We should unregister the DMA address from the HCA before
-        * remove the DMA mapping.
-        */
-       mlx5_mr_cache_free(dev, mr);
        ib_umem_release(umem);
-       if (umem)
-               atomic_sub(npages, &dev->mdev->priv.reg_pages);
+       atomic_sub(npages, &dev->mdev->priv.reg_pages);
 
-       if (!mr->allocated_from_cache)
-               kfree(mr);
 }
 
 int mlx5_ib_dereg_mr(struct ib_mr *ibmr, struct ib_udata *udata)
 
        }
 }
 
+static void dma_fence_odp_mr(struct mlx5_ib_mr *mr)
+{
+       struct ib_umem_odp *odp = to_ib_umem_odp(mr->umem);
+
+       /* Ensure mlx5_ib_invalidate_range() will not touch the MR any more */
+       mutex_lock(&odp->umem_mutex);
+       if (odp->npages) {
+               mlx5_mr_cache_invalidate(mr);
+               ib_umem_odp_unmap_dma_pages(odp, ib_umem_start(odp),
+                                           ib_umem_end(odp));
+               WARN_ON(odp->npages);
+       }
+       odp->private = NULL;
+       mutex_unlock(&odp->umem_mutex);
+
+       if (!mr->allocated_from_cache) {
+               mlx5_core_destroy_mkey(mr->dev->mdev, &mr->mmkey);
+               WARN_ON(mr->descs);
+       }
+}
+
 /*
  * This must be called after the mr has been removed from implicit_children
  * and the SRCU synchronized.  NOTE: The MR does not necessarily have to be
                srcu_read_unlock(&mr->dev->odp_srcu, srcu_key);
        }
 
+       dma_fence_odp_mr(mr);
+
        mr->parent = NULL;
        mlx5_mr_cache_free(mr->dev, mr);
        ib_umem_odp_release(odp);
        int in_block = 0;
        u64 addr;
 
-       if (!umem_odp) {
-               pr_err("invalidation called on NULL umem or non-ODP umem\n");
-               return;
-       }
-
+       mutex_lock(&umem_odp->umem_mutex);
+       /*
+        * If npages is zero then umem_odp->private may not be setup yet. This
+        * does not complete until after the first page is mapped for DMA.
+        */
+       if (!umem_odp->npages)
+               goto out;
        mr = umem_odp->private;
 
-       if (!mr || !mr->ibmr.pd)
-               return;
-
        start = max_t(u64, ib_umem_start(umem_odp), start);
        end = min_t(u64, ib_umem_end(umem_odp), end);
 
         * overwrite the same MTTs.  Concurent invalidations might race us,
         * but they will write 0s as well, so no difference in the end result.
         */
-       mutex_lock(&umem_odp->umem_mutex);
        for (addr = start; addr < end; addr += BIT(umem_odp->page_shift)) {
                idx = (addr - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
                /*
 
        if (unlikely(!umem_odp->npages && mr->parent))
                destroy_unused_implicit_child_mr(mr);
+out:
        mutex_unlock(&umem_odp->umem_mutex);
 }
 
                WARN_ON(atomic_read(&imr->num_deferred_work));
        }
 
+       /*
+        * Fence the imr before we destroy the children. This allows us to
+        * skip updating the XLT of the imr during destroy of the child mkey
+        * the imr points to.
+        */
+       mlx5_mr_cache_invalidate(imr);
+
        list_for_each_entry_safe (mtt, tmp, &destroy_list, odp_destroy.elm)
                free_implicit_child_mr(mtt, false);
 
        ib_umem_odp_release(odp_imr);
 }
 
+/**
+ * mlx5_ib_fence_odp_mr - Stop all access to the ODP MR
+ * @mr: to fence
+ *
+ * On return no parallel threads will be touching this MR and no DMA will be
+ * active.
+ */
+void mlx5_ib_fence_odp_mr(struct mlx5_ib_mr *mr)
+{
+       /* Prevent new page faults and prefetch requests from succeeding */
+       xa_erase(&mr->dev->odp_mkeys, mlx5_base_mkey(mr->mmkey.key));
+
+       /* Wait for all running page-fault handlers to finish. */
+       synchronize_srcu(&mr->dev->odp_srcu);
+
+       if (atomic_read(&mr->num_deferred_work)) {
+               flush_workqueue(system_unbound_wq);
+               WARN_ON(atomic_read(&mr->num_deferred_work));
+       }
+
+       dma_fence_odp_mr(mr);
+}
+
 #define MLX5_PF_FLAGS_DOWNGRADE BIT(1)
 static int pagefault_real_mr(struct mlx5_ib_mr *mr, struct ib_umem_odp *odp,
                             u64 user_va, size_t bcnt, u32 *bytes_mapped,