#include <rdma/ib_umem.h>
 #include <rdma/ib_umem_odp.h>
 
+#include "uverbs.h"
+
 static void ib_umem_notifier_start_account(struct ib_umem_odp *umem_odp)
 {
        mutex_lock(&umem_odp->umem_mutex);
        return ret;
 }
 
-struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root,
-                                     unsigned long addr, size_t size)
+/**
+ * ib_umem_odp_alloc_implicit - Allocate a parent implicit ODP umem
+ *
+ * Implicit ODP umems do not have a VA range and do not have any page lists.
+ * They exist only to hold the per_mm reference to help the driver create
+ * children umems.
+ *
+ * @udata: udata from the syscall being used to create the umem
+ * @access: ib_reg_mr access flags
+ */
+struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_udata *udata,
+                                              int access)
+{
+       struct ib_ucontext *context =
+               container_of(udata, struct uverbs_attr_bundle, driver_udata)
+                       ->context;
+       struct ib_umem *umem;
+       struct ib_umem_odp *umem_odp;
+       int ret;
+
+       if (access & IB_ACCESS_HUGETLB)
+               return ERR_PTR(-EINVAL);
+
+       if (!context)
+               return ERR_PTR(-EIO);
+       if (WARN_ON_ONCE(!context->invalidate_range))
+               return ERR_PTR(-EINVAL);
+
+       umem_odp = kzalloc(sizeof(*umem_odp), GFP_KERNEL);
+       if (!umem_odp)
+               return ERR_PTR(-ENOMEM);
+       umem = &umem_odp->umem;
+       umem->context = context;
+       umem->writable = ib_access_writable(access);
+       umem->owning_mm = current->mm;
+       umem_odp->is_implicit_odp = 1;
+       umem_odp->page_shift = PAGE_SHIFT;
+
+       ret = ib_init_umem_odp(umem_odp, NULL);
+       if (ret) {
+               kfree(umem_odp);
+               return ERR_PTR(ret);
+       }
+
+       mmgrab(umem->owning_mm);
+
+       return umem_odp;
+}
+EXPORT_SYMBOL(ib_umem_odp_alloc_implicit);
+
+/**
+ * ib_umem_odp_alloc_child - Allocate a child ODP umem under an implicit
+ *                           parent ODP umem
+ *
+ * @root: The parent umem enclosing the child. This must be allocated using
+ *        ib_alloc_implicit_odp_umem()
+ * @addr: The starting userspace VA
+ * @size: The length of the userspace VA
+ */
+struct ib_umem_odp *ib_umem_odp_alloc_child(struct ib_umem_odp *root,
+                                           unsigned long addr, size_t size)
 {
        /*
         * Caller must ensure that root cannot be freed during the call to
        struct ib_umem *umem;
        int ret;
 
+       if (WARN_ON(!root->is_implicit_odp))
+               return ERR_PTR(-EINVAL);
+
        odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
        if (!odp_data)
                return ERR_PTR(-ENOMEM);
 
        return odp_data;
 }
-EXPORT_SYMBOL(ib_alloc_odp_umem);
+EXPORT_SYMBOL(ib_umem_odp_alloc_child);
 
+/**
+ * ib_umem_odp_get - Complete ib_umem_get()
+ *
+ * @umem_odp: The partially configured umem from ib_umem_get()
+ * @addr: The starting userspace VA
+ * @access: ib_reg_mr access flags
+ */
 int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 {
        /*
         */
        struct mm_struct *mm = umem_odp->umem.owning_mm;
 
-       if (umem_odp->umem.address == 0 && umem_odp->umem.length == 0)
-               umem_odp->is_implicit_odp = 1;
-
        umem_odp->page_shift = PAGE_SHIFT;
        if (access & IB_ACCESS_HUGETLB) {
                struct vm_area_struct *vma;
 
 }
 
 static struct mlx5_ib_mr *implicit_mr_alloc(struct ib_pd *pd,
-                                           struct ib_umem *umem,
+                                           struct ib_umem_odp *umem_odp,
                                            bool ksm, int access_flags)
 {
        struct mlx5_ib_dev *dev = to_mdev(pd->device);
        mr->dev = dev;
        mr->access_flags = access_flags;
        mr->mmkey.iova = 0;
-       mr->umem = umem;
+       mr->umem = &umem_odp->umem;
 
        if (ksm) {
                err = mlx5_ib_update_xlt(mr, 0,
                if (nentries)
                        nentries++;
        } else {
-               odp = ib_alloc_odp_umem(odp_mr, addr,
-                                       MLX5_IMR_MTT_SIZE);
+               odp = ib_umem_odp_alloc_child(odp_mr, addr, MLX5_IMR_MTT_SIZE);
                if (IS_ERR(odp)) {
                        mutex_unlock(&odp_mr->umem_mutex);
                        return ERR_CAST(odp);
                }
 
-               mtt = implicit_mr_alloc(mr->ibmr.pd, &odp->umem, 0,
+               mtt = implicit_mr_alloc(mr->ibmr.pd, odp, 0,
                                        mr->access_flags);
                if (IS_ERR(mtt)) {
                        mutex_unlock(&odp_mr->umem_mutex);
                                             int access_flags)
 {
        struct mlx5_ib_mr *imr;
-       struct ib_umem *umem;
+       struct ib_umem_odp *umem_odp;
 
-       umem = ib_umem_get(udata, 0, 0, access_flags, 0);
-       if (IS_ERR(umem))
-               return ERR_CAST(umem);
+       umem_odp = ib_umem_odp_alloc_implicit(udata, access_flags);
+       if (IS_ERR(umem_odp))
+               return ERR_CAST(umem_odp);
 
-       imr = implicit_mr_alloc(&pd->ibpd, umem, 1, access_flags);
+       imr = implicit_mr_alloc(&pd->ibpd, umem_odp, 1, access_flags);
        if (IS_ERR(imr)) {
-               ib_umem_release(umem);
+               ib_umem_release(&umem_odp->umem);
                return ERR_CAST(imr);
        }
 
-       imr->umem = umem;
+       imr->umem = &umem_odp->umem;
        init_waitqueue_head(&imr->q_leaf_free);
        atomic_set(&imr->num_leaf_free, 0);
        atomic_set(&imr->num_pending_prefetch, 0);
 
 };
 
 int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access);
-struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root_umem,
-                                     unsigned long addr, size_t size);
+struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_udata *udata,
+                                              int access);
+struct ib_umem_odp *ib_umem_odp_alloc_child(struct ib_umem_odp *root_umem,
+                                           unsigned long addr, size_t size);
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp);
 
 int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 start_offset,