*/
 
 #include <linux/vmalloc.h>
+#include <linux/count_zeros.h>
 #include <rdma/ib_umem.h>
 #include <linux/math.h>
 #include "hns_roce_device.h"
        buf_attr.user_access = mr->access;
        /* fast MR's buffer is alloced before mapping, not at creation */
        buf_attr.mtt_only = is_fast;
+       buf_attr.iova = mr->iova;
+       /* pagesize and hopnum is fixed for fast MR */
+       buf_attr.adaptive = !is_fast;
+       buf_attr.type = MTR_PBL;
 
        err = hns_roce_mtr_create(hr_dev, &mr->pbl_mtr, &buf_attr,
                                  hr_dev->caps.pbl_ba_pg_sz + PAGE_SHIFT,
        return page_cnt;
 }
 
+static bool need_split_huge_page(struct hns_roce_mtr *mtr)
+{
+       /* When HEM buffer uses 0-level addressing, the page size is
+        * equal to the whole buffer size. If the current MTR has multiple
+        * regions, we split the buffer into small pages(4k, required by hns
+        * ROCEE). These pages will be used in multiple regions.
+        */
+       return mtr->hem_cfg.is_direct && mtr->hem_cfg.region_count > 1;
+}
+
 static int mtr_map_bufs(struct hns_roce_dev *hr_dev, struct hns_roce_mtr *mtr)
 {
        struct ib_device *ibdev = &hr_dev->ib_dev;
        int npage;
        int ret;
 
-       /* When HEM buffer uses 0-level addressing, the page size is
-        * equal to the whole buffer size, and we split the buffer into
-        * small pages which is used to check whether the adjacent
-        * units are in the continuous space and its size is fixed to
-        * 4K based on hns ROCEE's requirement.
-        */
-       page_shift = mtr->hem_cfg.is_direct ? HNS_HW_PAGE_SHIFT :
-                                             mtr->hem_cfg.buf_pg_shift;
+       page_shift = need_split_huge_page(mtr) ? HNS_HW_PAGE_SHIFT :
+                                                mtr->hem_cfg.buf_pg_shift;
        /* alloc a tmp array to store buffer's dma address */
        pages = kvcalloc(page_count, sizeof(dma_addr_t), GFP_KERNEL);
        if (!pages)
                goto err_alloc_list;
        }
 
-       if (mtr->hem_cfg.is_direct && npage > 1) {
+       if (need_split_huge_page(mtr) && npage > 1) {
                ret = mtr_check_direct_pages(pages, npage, page_shift);
                if (ret) {
                        ibdev_err(ibdev, "failed to check %s page: %d / %d.\n",
        return ret;
 }
 
-static int mtr_init_buf_cfg(struct hns_roce_dev *hr_dev,
-                           struct hns_roce_buf_attr *attr,
-                           struct hns_roce_hem_cfg *cfg, u64 unalinged_size)
+static int get_best_page_shift(struct hns_roce_dev *hr_dev,
+                              struct hns_roce_mtr *mtr,
+                              struct hns_roce_buf_attr *buf_attr)
+{
+       unsigned int page_sz;
+
+       if (!buf_attr->adaptive || buf_attr->type != MTR_PBL || !mtr->umem)
+               return 0;
+
+       page_sz = ib_umem_find_best_pgsz(mtr->umem,
+                                        hr_dev->caps.page_size_cap,
+                                        buf_attr->iova);
+       if (!page_sz)
+               return -EINVAL;
+
+       buf_attr->page_shift = order_base_2(page_sz);
+       return 0;
+}
+
+static bool is_buf_attr_valid(struct hns_roce_dev *hr_dev,
+                             struct hns_roce_buf_attr *attr)
 {
        struct ib_device *ibdev = &hr_dev->ib_dev;
+
+       if (attr->region_count > ARRAY_SIZE(attr->region) ||
+           attr->region_count < 1 || attr->page_shift < HNS_HW_PAGE_SHIFT) {
+               ibdev_err(ibdev,
+                         "invalid buf attr, region count %d, page shift %u.\n",
+                         attr->region_count, attr->page_shift);
+               return false;
+       }
+
+       return true;
+}
+
+static int mtr_init_buf_cfg(struct hns_roce_dev *hr_dev,
+                           struct hns_roce_mtr *mtr,
+                           struct hns_roce_buf_attr *attr)
+{
+       struct hns_roce_hem_cfg *cfg = &mtr->hem_cfg;
        struct hns_roce_buf_region *r;
-       u64 first_region_padding;
-       int page_cnt, region_cnt;
        size_t buf_pg_sz;
        size_t buf_size;
+       int page_cnt, i;
+       u64 pgoff = 0;
+
+       if (!is_buf_attr_valid(hr_dev, attr))
+               return -EINVAL;
 
        /* If mtt is disabled, all pages must be within a continuous range */
        cfg->is_direct = !mtr_has_mtt(attr);
+       cfg->region_count = attr->region_count;
        buf_size = mtr_bufs_size(attr);
-       if (cfg->is_direct) {
+       if (need_split_huge_page(mtr)) {
                buf_pg_sz = HNS_HW_PAGE_SIZE;
                cfg->buf_pg_count = 1;
                /* The ROCEE requires the page size to be 4K * 2 ^ N. */
                cfg->buf_pg_shift = HNS_HW_PAGE_SHIFT +
                        order_base_2(DIV_ROUND_UP(buf_size, HNS_HW_PAGE_SIZE));
-               first_region_padding = 0;
        } else {
-               cfg->buf_pg_count = DIV_ROUND_UP(buf_size + unalinged_size,
-                                                1 << attr->page_shift);
+               buf_pg_sz = 1 << attr->page_shift;
+               cfg->buf_pg_count = mtr->umem ?
+                       ib_umem_num_dma_blocks(mtr->umem, buf_pg_sz) :
+                       DIV_ROUND_UP(buf_size, buf_pg_sz);
                cfg->buf_pg_shift = attr->page_shift;
-               buf_pg_sz = 1 << cfg->buf_pg_shift;
-               first_region_padding = unalinged_size;
+               pgoff = mtr->umem ? mtr->umem->address & ~PAGE_MASK : 0;
        }
 
        /* Convert buffer size to page index and page count for each region and
         * the buffer's offset needs to be appended to the first region.
         */
-       for (page_cnt = 0, region_cnt = 0; region_cnt < attr->region_count &&
-            region_cnt < ARRAY_SIZE(cfg->region); region_cnt++) {
-               r = &cfg->region[region_cnt];
+       for (page_cnt = 0, i = 0; i < attr->region_count; i++) {
+               r = &cfg->region[i];
                r->offset = page_cnt;
-               buf_size = hr_hw_page_align(attr->region[region_cnt].size +
-                                           first_region_padding);
-               r->count = DIV_ROUND_UP(buf_size, buf_pg_sz);
-               first_region_padding = 0;
-               page_cnt += r->count;
-               r->hopnum = to_hr_hem_hopnum(attr->region[region_cnt].hopnum,
-                                            r->count);
-       }
+               buf_size = hr_hw_page_align(attr->region[i].size + pgoff);
+               if (attr->type == MTR_PBL && mtr->umem)
+                       r->count = ib_umem_num_dma_blocks(mtr->umem, buf_pg_sz);
+               else
+                       r->count = DIV_ROUND_UP(buf_size, buf_pg_sz);
 
-       cfg->region_count = region_cnt;
-       if (cfg->region_count < 1 || cfg->buf_pg_shift < HNS_HW_PAGE_SHIFT) {
-               ibdev_err(ibdev, "failed to init mtr cfg, count %d shift %u.\n",
-                         cfg->region_count, cfg->buf_pg_shift);
-               return -EINVAL;
+               pgoff = 0;
+               page_cnt += r->count;
+               r->hopnum = to_hr_hem_hopnum(attr->region[i].hopnum, r->count);
        }
 
        return 0;
                        unsigned int ba_page_shift, struct ib_udata *udata,
                        unsigned long user_addr)
 {
-       u64 pgoff = udata ? user_addr & ~PAGE_MASK : 0;
        struct ib_device *ibdev = &hr_dev->ib_dev;
        int ret;
 
                                  "failed to alloc mtr bufs, ret = %d.\n", ret);
                        return ret;
                }
+
+               ret = get_best_page_shift(hr_dev, mtr, buf_attr);
+               if (ret)
+                       goto err_init_buf;
        }
 
-       ret = mtr_init_buf_cfg(hr_dev, buf_attr, &mtr->hem_cfg, pgoff);
+       ret = mtr_init_buf_cfg(hr_dev, mtr, buf_attr);
        if (ret)
                goto err_init_buf;