return -EFAULT;
 }
 
-static int hmm_pfns_bad(unsigned long addr,
-                       unsigned long end,
-                       struct mm_walk *walk)
+static int hmm_pfns_fill(unsigned long addr, unsigned long end,
+               struct hmm_range *range, enum hmm_pfn_value_e value)
 {
-       struct hmm_vma_walk *hmm_vma_walk = walk->private;
-       struct hmm_range *range = hmm_vma_walk->range;
        uint64_t *pfns = range->pfns;
        unsigned long i;
 
        i = (addr - range->start) >> PAGE_SHIFT;
        for (; addr < end; addr += PAGE_SIZE, i++)
-               pfns[i] = range->values[HMM_PFN_ERROR];
+               pfns[i] = range->values[value];
 
        return 0;
 }
                }
                return 0;
        } else if (!pmd_present(pmd))
-               return hmm_pfns_bad(start, end, walk);
+               return hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
 
        if (pmd_devmap(pmd) || pmd_trans_huge(pmd)) {
                /*
         * recover.
         */
        if (pmd_bad(pmd))
-               return hmm_pfns_bad(start, end, walk);
+               return hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
 
        ptep = pte_offset_map(pmdp, addr);
        i = (addr - range->start) >> PAGE_SHIFT;
 #define hmm_vma_walk_hugetlb_entry NULL
 #endif /* CONFIG_HUGETLB_PAGE */
 
-static void hmm_pfns_clear(struct hmm_range *range,
-                          uint64_t *pfns,
-                          unsigned long addr,
-                          unsigned long end)
+static int hmm_vma_walk_test(unsigned long start, unsigned long end,
+                            struct mm_walk *walk)
 {
-       for (; addr < end; addr += PAGE_SIZE, pfns++)
-               *pfns = range->values[HMM_PFN_NONE];
+       struct hmm_vma_walk *hmm_vma_walk = walk->private;
+       struct hmm_range *range = hmm_vma_walk->range;
+       struct vm_area_struct *vma = walk->vma;
+
+       /*
+        * Skip vma ranges that don't have struct page backing them or
+        * map I/O devices directly.
+        */
+       if (vma->vm_flags & (VM_IO | VM_PFNMAP | VM_MIXEDMAP))
+               return -EFAULT;
+
+       /*
+        * If the vma does not allow read access, then assume that it does not
+        * allow write access either. HMM does not support architectures
+        * that allow write without read.
+        */
+       if (!(vma->vm_flags & VM_READ)) {
+               bool fault, write_fault;
+
+               /*
+                * Check to see if a fault is requested for any page in the
+                * range.
+                */
+               hmm_range_need_fault(hmm_vma_walk, range->pfns +
+                                       ((start - range->start) >> PAGE_SHIFT),
+                                       (end - start) >> PAGE_SHIFT,
+                                       0, &fault, &write_fault);
+               if (fault || write_fault)
+                       return -EFAULT;
+
+               hmm_pfns_fill(start, end, range, HMM_PFN_NONE);
+               hmm_vma_walk->last = end;
+
+               /* Skip this vma and continue processing the next vma. */
+               return 1;
+       }
+
+       return 0;
 }
 
 static const struct mm_walk_ops hmm_walk_ops = {
        .pmd_entry      = hmm_vma_walk_pmd,
        .pte_hole       = hmm_vma_walk_hole,
        .hugetlb_entry  = hmm_vma_walk_hugetlb_entry,
+       .test_walk      = hmm_vma_walk_test,
 };
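
For context, walk_page_range() consults the new .test_walk callback once per vma before descending into that vma's page tables. A rough sketch of how the return value is interpreted (simplified from mm/pagewalk.c, not the exact code):

	/* for each vma overlapping [start, end): */
	err = ops->test_walk ? ops->test_walk(start, next, &walk) : 0;
	if (err > 0)
		continue;	/* positive: skip this vma, keep walking the rest */
	if (err < 0)
		break;		/* negative: abort the walk and return the error */
	/* zero: walk this vma via pmd_entry/pte_hole/hugetlb_entry as usual */

This is why hmm_vma_walk_test() returns 1 after filling the range with HMM_PFN_NONE for unreadable vmas, and -EFAULT for VM_IO/VM_PFNMAP/VM_MIXEDMAP vmas.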
 
 /**
  */
 long hmm_range_fault(struct hmm_range *range, unsigned int flags)
 {
-       const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
-       unsigned long start = range->start, end;
-       struct hmm_vma_walk hmm_vma_walk;
+       struct hmm_vma_walk hmm_vma_walk = {
+               .range = range,
+               .last = range->start,
+               .flags = flags,
+       };
        struct mm_struct *mm = range->notifier->mm;
-       struct vm_area_struct *vma;
        int ret;
 
        lockdep_assert_held(&mm->mmap_sem);
                if (mmu_interval_check_retry(range->notifier,
                                             range->notifier_seq))
                        return -EBUSY;
+               ret = walk_page_range(mm, hmm_vma_walk.last, range->end,
+                                     &hmm_walk_ops, &hmm_vma_walk);
+       } while (ret == -EBUSY);
 
-               vma = find_vma(mm, start);
-               if (vma == NULL || (vma->vm_flags & device_vma))
-                       return -EFAULT;
-
-               if (!(vma->vm_flags & VM_READ)) {
-                       /*
-                        * If vma do not allow read access, then assume that it
-                        * does not allow write access, either. HMM does not
-                        * support architecture that allow write without read.
-                        */
-                       hmm_pfns_clear(range, range->pfns,
-                               range->start, range->end);
-                       return -EPERM;
-               }
-
-               hmm_vma_walk.pgmap = NULL;
-               hmm_vma_walk.last = start;
-               hmm_vma_walk.flags = flags;
-               hmm_vma_walk.range = range;
-               end = min(range->end, vma->vm_end);
-
-               walk_page_range(vma->vm_mm, start, end, &hmm_walk_ops,
-                               &hmm_vma_walk);
-
-               do {
-                       ret = walk_page_range(vma->vm_mm, start, end,
-                                       &hmm_walk_ops, &hmm_vma_walk);
-                       start = hmm_vma_walk.last;
-
-                       /* Keep trying while the range is valid. */
-               } while (ret == -EBUSY &&
-                        !mmu_interval_check_retry(range->notifier,
-                                                  range->notifier_seq));
-
-               if (ret) {
-                       unsigned long i;
-
-                       i = (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
-                       hmm_pfns_clear(range, &range->pfns[i],
-                               hmm_vma_walk.last, range->end);
-                       return ret;
-               }
-               start = end;
-
-       } while (start < range->end);
-
+       if (ret)
+               return ret;
        return (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
 }
 EXPORT_SYMBOL(hmm_range_fault);
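
For reference, the calling convention this leaves drivers with looks roughly like the sketch below. driver_populate_range(), driver_lock and dev_pagetable_update() are illustrative names rather than kernel APIs; the retry pattern follows Documentation/vm/hmm.rst for this kernel version, where hmm_range_fault() returns the number of valid pages on success and -EBUSY when the interval notifier sequence went stale.

	static DEFINE_MUTEX(driver_lock);	/* protects the device page table */

	static long driver_populate_range(struct hmm_range *range,
					  struct mm_struct *mm)
	{
		long ret;

	again:
		range->notifier_seq = mmu_interval_read_begin(range->notifier);

		down_read(&mm->mmap_sem);
		ret = hmm_range_fault(range, 0);	/* 0: fault missing pages in */
		up_read(&mm->mmap_sem);
		if (ret < 0) {
			if (ret == -EBUSY)
				goto again;	/* invalidated while faulting, retry */
			return ret;
		}

		mutex_lock(&driver_lock);
		if (mmu_interval_read_retry(range->notifier, range->notifier_seq)) {
			mutex_unlock(&driver_lock);
			goto again;	/* raced with an invalidation, retry */
		}
		/* range->pfns is now stable; mirror it into the device page table. */
		dev_pagetable_update(range);
		mutex_unlock(&driver_lock);
		return 0;
	}

With the vma loop gone, a single walk_page_range() call covers the whole range; unreadable vmas are reported as HMM_PFN_NONE instead of aborting with -EPERM, and only vmas that can never be mirrored (VM_IO/VM_PFNMAP/VM_MIXEDMAP, or ones the caller explicitly asked to fault) fail the call.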