*/
 
        for (addr = start; addr < end; addr += BIT(umem->page_shift)) {
-               idx = (addr - ib_umem_start(umem)) / PAGE_SIZE;
+               idx = (addr - ib_umem_start(umem)) >> umem->page_shift;
                /*
                 * Strive to write the MTTs in chunks, but avoid overwriting
                 * non-existing MTTs. The huristic here can be improved to
 
                        if (in_block && umr_offset == 0) {
                                mlx5_ib_update_xlt(mr, blk_start_idx,
-                                                  idx - blk_start_idx,
-                                                  PAGE_SHIFT,
+                                                  idx - blk_start_idx, 0,
                                                   MLX5_IB_UPD_XLT_ZAP |
                                                   MLX5_IB_UPD_XLT_ATOMIC);
                                in_block = 0;
        }
        if (in_block)
                mlx5_ib_update_xlt(mr, blk_start_idx,
-                                  idx - blk_start_idx + 1,
-                                  PAGE_SHIFT,
+                                  idx - blk_start_idx + 1, 0,
                                   MLX5_IB_UPD_XLT_ZAP |
                                   MLX5_IB_UPD_XLT_ATOMIC);
        /*
 /*
  * Handle a single data segment in a page-fault WQE or RDMA region.
  *
- * Returns number of pages retrieved on success. The caller may continue to
+ * Returns number of OS pages retrieved on success. The caller may continue to
  * the next data segment.
  * Can return the following error codes:
  * -EAGAIN to designate a temporary error. The caller will abort handling the
 {
        int srcu_key;
        unsigned int current_seq = 0;
-       u64 start_idx;
+       u64 start_idx, page_mask;
        int npages = 0, ret = 0;
        struct mlx5_ib_mr *mr;
        u64 access_mask = ODP_READ_ALLOWED_BIT;
        struct ib_umem_odp *odp;
        int implicit = 0;
        size_t size;
+       int page_shift;
 
        srcu_key = srcu_read_lock(&dev->mr_srcu);
        mr = mlx5_ib_odp_find_mr_lkey(dev, key);
                odp = mr->umem->odp_data;
        }
 
+       page_shift = mr->umem->page_shift;
+       page_mask = ~(BIT(page_shift) - 1);
+
 next_mr:
        current_seq = READ_ONCE(odp->notifiers_seq);
        /*
        smp_rmb();
 
        size = min_t(size_t, bcnt, ib_umem_end(odp->umem) - io_virt);
-       start_idx = (io_virt - (mr->mmkey.iova & PAGE_MASK)) >> PAGE_SHIFT;
+       start_idx = (io_virt - (mr->mmkey.iova & page_mask)) >> page_shift;
 
        if (mr->umem->writable)
                access_mask |= ODP_WRITE_ALLOWED_BIT;
                         * checks this.
                         */
                        ret = mlx5_ib_update_xlt(mr, start_idx, np,
-                                                PAGE_SHIFT,
+                                                page_shift,
                                                 MLX5_IB_UPD_XLT_ATOMIC);
                } else {
                        ret = -EAGAIN;
                                mlx5_ib_err(dev, "Failed to update mkey page tables\n");
                        goto srcu_unlock;
                }
-
                if (bytes_mapped) {
-                       u32 new_mappings = np * PAGE_SIZE -
-                               (io_virt - round_down(io_virt, PAGE_SIZE));
+                       u32 new_mappings = (np << page_shift) -
+                               (io_virt - round_down(io_virt,
+                                                     1 << page_shift));
                        *bytes_mapped += min_t(u32, new_mappings, size);
                }
 
-               npages += np;
+               npages += np << (page_shift - PAGE_SHIFT);
        }
 
        bcnt -= size;