}
 
 #define MLX5_PF_FLAGS_DOWNGRADE BIT(1)
-static int pagefault_mr(struct mlx5_ib_mr *mr, u64 io_virt, size_t bcnt,
-                       u32 *bytes_mapped, u32 flags)
+static int pagefault_real_mr(struct mlx5_ib_mr *mr, struct ib_umem_odp *odp,
+                            u64 user_va, size_t bcnt, u32 *bytes_mapped,
+                            u32 flags)
 {
-       int npages = 0, current_seq, page_shift, ret, np;
-       struct ib_umem_odp *odp_mr = to_ib_umem_odp(mr->umem);
+       int current_seq, page_shift, ret, np;
        bool downgrade = flags & MLX5_PF_FLAGS_DOWNGRADE;
        u64 access_mask;
        u64 start_idx, page_mask;
-       struct ib_umem_odp *odp;
-       size_t size;
-
-       if (odp_mr->is_implicit_odp) {
-               odp = implicit_mr_get_data(mr, io_virt, bcnt);
-
-               if (IS_ERR(odp))
-                       return PTR_ERR(odp);
-               mr = odp->private;
-       } else {
-               odp = odp_mr;
-       }
-
-next_mr:
-       size = min_t(size_t, bcnt, ib_umem_end(odp) - io_virt);
 
        page_shift = odp->page_shift;
        page_mask = ~(BIT(page_shift) - 1);
-       start_idx = (io_virt - (mr->mmkey.iova & page_mask)) >> page_shift;
+       start_idx = (user_va - (mr->mmkey.iova & page_mask)) >> page_shift;
        access_mask = ODP_READ_ALLOWED_BIT;
 
        if (odp->umem.writable && !downgrade)
         */
        smp_rmb();
 
-       ret = ib_umem_odp_map_dma_pages(odp, io_virt, size, access_mask,
-                                       current_seq);
-
-       if (ret < 0)
-               goto out;
-
-       np = ret;
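+       /*
+        * Fault in and DMA map the requested range; np is the number of
+        * pages (of size 1 << page_shift) that were mapped.
+        */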
+       np = ib_umem_odp_map_dma_pages(odp, user_va, bcnt, access_mask,
+                                      current_seq);
+       if (np < 0)
+               return np;
 
        mutex_lock(&odp->umem_mutex);
        if (!ib_umem_mmu_notifier_retry(odp, current_seq)) {
 
        if (bytes_mapped) {
                u32 new_mappings = (np << page_shift) -
-                       (io_virt - round_down(io_virt, 1 << page_shift));
-               *bytes_mapped += min_t(u32, new_mappings, size);
-       }
-
-       npages += np << (page_shift - PAGE_SHIFT);
-       bcnt -= size;
+                       (user_va - round_down(user_va, 1 << page_shift));
 
-       if (unlikely(bcnt)) {
-               struct ib_umem_odp *next;
-
-               io_virt += size;
-               next = odp_next(odp);
-               if (unlikely(!next || ib_umem_start(next) != io_virt)) {
-                       mlx5_ib_dbg(
-                               mr->dev,
-                               "next implicit leaf removed at 0x%llx. got %p\n",
-                               io_virt, next);
-                       return -EAGAIN;
-               }
-               odp = next;
-               mr = odp->private;
-               goto next_mr;
+               *bytes_mapped += min_t(u32, new_mappings, bcnt);
        }
 
-       return npages;
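+       /* np counts pages of size 1 << page_shift; return PAGE_SIZE units */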
+       return np << (page_shift - PAGE_SHIFT);
 
 out:
        if (ret == -EAGAIN) {
        return ret;
 }
 
+/*
+ * Returns:
+ *  -EFAULT: The io_virt->bcnt range is not within the MR, it covers pages
+ *           that are not accessible, or the MR is no longer valid.
+ *  -EAGAIN/-ENOMEM: The operation should be retried
+ *
+ *  -EINVAL/others: General internal malfunction
+ *  >0: Number of pages mapped
+ */
+static int pagefault_mr(struct mlx5_ib_mr *mr, u64 io_virt, size_t bcnt,
+                       u32 *bytes_mapped, u32 flags)
+{
+       struct ib_umem_odp *odp = to_ib_umem_odp(mr->umem);
+       struct ib_umem_odp *child;
+       int npages = 0;
+
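+       /*
+        * A non-implicit MR is backed by a single umem: range check the
+        * fault against it and handle it directly.
+        */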
+       if (!odp->is_implicit_odp) {
+               if (unlikely(io_virt < ib_umem_start(odp) ||
+                            ib_umem_end(odp) - io_virt < bcnt))
+                       return -EFAULT;
+               return pagefault_real_mr(mr, odp, io_virt, bcnt, bytes_mapped,
+                                        flags);
+       }
+
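+       /*
+        * The implicit MR covers mlx5_imr_ksm_entries * MLX5_IMR_MTT_SIZE
+        * bytes of virtual address space; faults outside that span cannot
+        * be satisfied.
+        */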
+       if (unlikely(io_virt >= mlx5_imr_ksm_entries * MLX5_IMR_MTT_SIZE ||
+                    mlx5_imr_ksm_entries * MLX5_IMR_MTT_SIZE - io_virt < bcnt))
+               return -EFAULT;
+
+       child = implicit_mr_get_data(mr, io_virt, bcnt);
+       if (IS_ERR(child))
+               return PTR_ERR(child);
+
+       /* Fault each child mr that intersects with our interval. */
+       while (bcnt) {
+               u64 end = min_t(u64, io_virt + bcnt, ib_umem_end(child));
+               u64 len = end - io_virt;
+               int ret;
+
+               ret = pagefault_real_mr(child->private, child, io_virt, len,
+                                       bytes_mapped, flags);
+               if (ret < 0)
+                       return ret;
+               io_virt += len;
+               bcnt -= len;
+               npages += ret;
+
+               if (unlikely(bcnt)) {
+                       child = odp_next(child);
+                       /*
+                        * implicit_mr_get_data sets up all the leaves, so a
+                        * missing one was invalidated before we got to it.
+                        */
+                       if (!child || ib_umem_start(child) != io_virt) {
+                               mlx5_ib_dbg(
+                                       mr->dev,
+                                       "next implicit leaf removed at 0x%llx.\n",
+                                       io_virt);
+                               return -EAGAIN;
+                       }
+               }
+       }
+       return npages;
+}
+
 struct pf_frame {
        struct pf_frame *next;
        u32 key;