 #include <linux/export.h>
 #include <linux/hugetlb.h>
 #include <linux/slab.h>
+#include <linux/pagemap.h>
 #include <rdma/ib_umem_odp.h>
 
 #include "uverbs.h"
 
-
 static void __ib_umem_release(struct ib_device *dev, struct ib_umem *umem, int dirty)
 {
-       struct scatterlist *sg;
+       struct sg_page_iter sg_iter;
        struct page *page;
-       int i;
 
        if (umem->nmap > 0)
-               ib_dma_unmap_sg(dev, umem->sg_head.sgl,
-                               umem->npages,
+               ib_dma_unmap_sg(dev, umem->sg_head.sgl, umem->sg_nents,
                                DMA_BIDIRECTIONAL);
 
-       for_each_sg(umem->sg_head.sgl, sg, umem->npages, i) {
-
-               page = sg_page(sg);
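+       /* An sge may now cover several pages, so walk and release them page by page */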
+       for_each_sg_page(umem->sg_head.sgl, &sg_iter, umem->sg_nents, 0) {
+               page = sg_page_iter_page(&sg_iter);
                if (!PageDirty(page) && umem->writable && dirty)
                        set_page_dirty_lock(page);
                put_page(page);
        }

        sg_free_table(&umem->sg_head);
 }
 
+/* ib_umem_add_sg_table - Add N contiguous pages to scatter table
+ *
+ * sg: current scatterlist entry
+ * page_list: array of npages struct page pointers
+ * npages: number of pages in page_list
+ * max_seg_sz: maximum segment size in bytes
+ * nents: [out] number of entries in the scatterlist
+ *
+ * Return new end of scatterlist
+ */
+static struct scatterlist *ib_umem_add_sg_table(struct scatterlist *sg,
+                                               struct page **page_list,
+                                               unsigned long npages,
+                                               unsigned int max_seg_sz,
+                                               int *nents)
+{
+       unsigned long first_pfn;
+       unsigned long i = 0;
+       bool update_cur_sg = false;
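+       /* An unused first entry has no page set yet; use that to detect an empty table */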
+       bool first = !sg_page(sg);
+
+       /* Check if new page_list is contiguous with end of previous page_list.
+        * sg->length here is a multiple of PAGE_SIZE and sg->offset is 0.
+        */
+       if (!first && (page_to_pfn(sg_page(sg)) + (sg->length >> PAGE_SHIFT) ==
+                      page_to_pfn(page_list[0])))
+               update_cur_sg = true;
+
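+       /* Fold each run of physically contiguous pages into a single sge */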
+       while (i != npages) {
+               unsigned long len;
+               struct page *first_page = page_list[i];
+
+               first_pfn = page_to_pfn(first_page);
+
+               /* Compute the number of contiguous pages we have starting
+                * at i
+                */
+               for (len = 0; i != npages &&
+                             first_pfn + len == page_to_pfn(page_list[i]);
+                    len++)
+                       i++;
+
+               /* Squash N contiguous pages from page_list into current sge */
+               if (update_cur_sg &&
+                   ((max_seg_sz - sg->length) >= (len << PAGE_SHIFT))) {
+                       sg_set_page(sg, sg_page(sg),
+                                   sg->length + (len << PAGE_SHIFT), 0);
+                       update_cur_sg = false;
+                       continue;
+               }
+
+               /* Squash N contiguous pages into next sge or first sge */
+               if (!first)
+                       sg = sg_next(sg);
+
+               (*nents)++;
+               sg_set_page(sg, first_page, len << PAGE_SHIFT, 0);
+               first = false;
+       }
+
+       return sg;
+}
+
 /**
  * ib_umem_get - Pin and DMA map userspace memory.
  *
        int ret;
        int i;
        unsigned long dma_attrs = 0;
-       struct scatterlist *sg, *sg_list_start;
+       struct scatterlist *sg;
        unsigned int gup_flags = FOLL_WRITE;
 
        if (!udata)
        if (!umem->writable)
                gup_flags |= FOLL_FORCE;
 
-       sg_list_start = umem->sg_head.sgl;
+       sg = umem->sg_head.sgl;
 
        while (npages) {
                down_read(&mm->mmap_sem);
                        goto umem_release;
                }
 
-               umem->npages += ret;
                cur_base += ret * PAGE_SIZE;
                npages   -= ret;
 
+               /* Fold this batch of pinned pages into the scatter table,
+                * splitting entries at the device's maximum DMA segment size.
+                */
+               sg = ib_umem_add_sg_table(sg, page_list, ret,
+                       dma_get_max_seg_size(context->device->dma_device),
+                       &umem->sg_nents);
+
                /* Continue to hold the mmap_sem as vma_list access
                 * needs to be protected.
                 */
-               for_each_sg(sg_list_start, sg, ret, i) {
+               for (i = 0; i < ret && umem->hugetlb; i++) {
                        if (vma_list && !is_vm_hugetlb_page(vma_list[i]))
                                umem->hugetlb = 0;
-
-                       sg_set_page(sg, page_list[i], PAGE_SIZE, 0);
                }
-               up_read(&mm->mmap_sem);
 
-               /* preparing for next loop */
-               sg_list_start = sg;
+               up_read(&mm->mmap_sem);
        }
 
+       /* Not every allocated entry may be in use; terminate the list at the last one */
+       sg_mark_end(sg);
+
        umem->nmap = ib_dma_map_sg_attrs(context->device,
                                  umem->sg_head.sgl,
-                                 umem->npages,
+                                 umem->sg_nents,
                                  DMA_BIDIRECTIONAL,
                                  dma_attrs);
 
                return -EINVAL;
        }
 
-       ret = sg_pcopy_to_buffer(umem->sg_head.sgl, umem->npages, dst, length,
-                                offset + ib_umem_offset(umem));
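+       /* The scatterlist is terminated with sg_mark_end(), so the page count
+        * is a safe upper bound on the number of entries to walk.
+        */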
+       ret = sg_pcopy_to_buffer(umem->sg_head.sgl, ib_umem_num_pages(umem),
+                                dst, length, offset + ib_umem_offset(umem));
 
        if (ret < 0)
                return ret;
 
        union pvrdma_cmd_resp rsp;
        struct pvrdma_cmd_create_mr *cmd = &req.create_mr;
        struct pvrdma_cmd_create_mr_resp *resp = &rsp.create_mr_resp;
-       int ret;
+       int ret, npages;
 
        if (length == 0 || length > dev->dsr->caps.max_mr_size) {
                dev_warn(&dev->pdev->dev, "invalid mem region length\n");
                return ERR_CAST(umem);
        }
 
-       if (umem->npages < 0 || umem->npages > PVRDMA_PAGE_DIR_MAX_PAGES) {
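+       /* Scatterlist entries may cover several pages, so take the page count
+        * from the umem itself rather than from the number of entries.
+        */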
+       npages = ib_umem_num_pages(umem);
+       if (npages < 0 || npages > PVRDMA_PAGE_DIR_MAX_PAGES) {
                dev_warn(&dev->pdev->dev, "overflow %d pages in mem region\n",
-                        umem->npages);
+                        npages);
                ret = -EINVAL;
                goto err_umem;
        }
        mr->mmr.size = length;
        mr->umem = umem;
 
-       ret = pvrdma_page_dir_init(dev, &mr->pdir, umem->npages, false);
+       ret = pvrdma_page_dir_init(dev, &mr->pdir, npages, false);
        if (ret) {
                dev_warn(&dev->pdev->dev,
                         "could not allocate page directory\n");
        cmd->length = length;
        cmd->pd_handle = to_vpd(pd)->pd_handle;
        cmd->access_flags = access_flags;
-       cmd->nchunks = umem->npages;
+       cmd->nchunks = npages;
        cmd->pdir_dma = mr->pdir.dir_dma;
 
        ret = pvrdma_cmd_post(dev, &req, &rsp, PVRDMA_CMD_CREATE_MR_RESP);