#include "i915_drv.h"
 
-#define EXPECTED_FLAGS (VM_PFNMAP | VM_DONTEXPAND | VM_DONTDUMP)
+struct remap_pfn {
+       struct mm_struct *mm;
+       unsigned long pfn;
+       pgprot_t prot;
+
+       struct sgt_iter sgt;
+       resource_size_t iobase;
+};
 
 #define use_dma(io) ((io) != -1)
 
+static inline unsigned long sgt_pfn(const struct remap_pfn *r)
+{
+       if (use_dma(r->iobase))
+               return (r->sgt.dma + r->sgt.curr + r->iobase) >> PAGE_SHIFT;
+       else
+               return r->sgt.pfn + (r->sgt.curr >> PAGE_SHIFT);
+}
+
+static int remap_sg(pte_t *pte, unsigned long addr, void *data)
+{
+       struct remap_pfn *r = data;
+
+       if (GEM_WARN_ON(!r->sgt.sgp))
+               return -EINVAL;
+
+       /* Special PTE are not associated with any struct page */
+       set_pte_at(r->mm, addr, pte,
+                  pte_mkspecial(pfn_pte(sgt_pfn(r), r->prot)));
+       r->pfn++; /* track insertions in case we need to unwind later */
+
+       r->sgt.curr += PAGE_SIZE;
+       if (r->sgt.curr >= r->sgt.max)
+               r->sgt = __sgt_iter(__sg_next(r->sgt.sgp), use_dma(r->iobase));
+
+       return 0;
+}
+
+#define EXPECTED_FLAGS (VM_PFNMAP | VM_DONTEXPAND | VM_DONTDUMP)
+
 /**
  * remap_io_sg - remap an IO mapping to userspace
  * @vma: user vma to map to
                unsigned long addr, unsigned long size,
                struct scatterlist *sgl, resource_size_t iobase)
 {
-       unsigned long pfn, len, remapped = 0;
+       struct remap_pfn r = {
+               .mm = vma->vm_mm,
+               .prot = vma->vm_page_prot,
+               .sgt = __sgt_iter(sgl, use_dma(iobase)),
+               .iobase = iobase,
+       };
        int err;
 
        /* We rely on prevalidation of the io-mapping to skip track_pfn(). */
        if (!use_dma(iobase))
                flush_cache_range(vma, addr, size);
 
-       do {
-               if (use_dma(iobase)) {
-                       if (!sg_dma_len(sgl))
-                               break;
-                       pfn = (sg_dma_address(sgl) + iobase) >> PAGE_SHIFT;
-                       len = sg_dma_len(sgl);
-               } else {
-                       pfn = page_to_pfn(sg_page(sgl));
-                       len = sgl->length;
-               }
-
-               err = remap_pfn_range(vma, addr + remapped, pfn, len,
-                                     vma->vm_page_prot);
-               if (err)
-                       break;
-               remapped += len;
-       } while ((sgl = __sg_next(sgl)));
-
-       if (err)
-               zap_vma_ptes(vma, addr, remapped);
-       return err;
+       err = apply_to_page_range(r.mm, addr, size, remap_sg, &r);
+       if (unlikely(err)) {
+               zap_vma_ptes(vma, addr, r.pfn << PAGE_SHIFT);
+               return err;
+       }
+
+       return 0;
 }