MLX5_IMR_MTT_ENTRIES,
                                 PAGE_SHIFT,
                                 MLX5_IB_UPD_XLT_ZAP |
-                                MLX5_IB_UPD_XLT_ENABLE |
-                                MLX5_IB_UPD_XLT_ATOMIC);
+                                MLX5_IB_UPD_XLT_ENABLE);
        if (err) {
                ret = ERR_PTR(err);
                goto out_release;
         * Once the store to either xarray completes any error unwind has to
         * use synchronize_srcu(). Avoid this with xa_reserve()
         */
-       err = xa_err(xa_store(&imr->implicit_children, idx, mr, GFP_KERNEL));
-       if (err) {
-               ret = ERR_PTR(err);
+       ret = xa_cmpxchg(&imr->implicit_children, idx, NULL, mr, GFP_KERNEL);
+       if (unlikely(ret)) {
+               if (xa_is_err(ret)) {
+                       ret = ERR_PTR(xa_err(ret));
+                       goto out_release;
+               }
+               /*
+                * Another thread beat us to creating the child mr, use
+                * theirs.
+                */
                goto out_release;
        }
 
        struct mlx5_ib_mr *result = NULL;
        int ret;
 
-       mutex_lock(&odp_imr->umem_mutex);
+       lockdep_assert_held(&imr->dev->odp_srcu);
+
        for (idx = idx; idx <= end_idx; idx++) {
                struct mlx5_ib_mr *mtt = xa_load(&imr->implicit_children, idx);
 
         */
 out:
        if (likely(!inv_len))
-               goto out_unlock;
+               return result;
 
+       /*
+        * Notice this is not strictly ordered right, the KSM is updated after
+        * the implicit_leaves is updated, so a parallel page fault could see
+        * a MR that is not yet visible in the KSM.  This is similar to a
+        * parallel page fault seeing a MR that is being concurrently removed
+        * from the KSM. Both of these improbable situations are resolved
+        * safely by resuming the HW and then taking another page fault. The
+        * next pagefault handler will see the new information.
+        */
+       mutex_lock(&odp_imr->umem_mutex);
        ret = mlx5_ib_update_xlt(imr, inv_start_idx, inv_len, 0,
                                 MLX5_IB_UPD_XLT_INDIRECT |
                                         MLX5_IB_UPD_XLT_ATOMIC);
+       mutex_unlock(&odp_imr->umem_mutex);
        if (ret) {
                mlx5_ib_err(to_mdev(imr->ibmr.pd->device),
                            "Failed to update PAS\n");
-               result = ERR_PTR(ret);
-               goto out_unlock;
+               return ERR_PTR(ret);
        }
-
-out_unlock:
-       mutex_unlock(&odp_imr->umem_mutex);
        return result;
 }