#include <linux/sizes.h>
 #include <linux/mmu_notifier.h>
 #include <linux/iomap.h>
+#include <linux/rmap.h>
 #include <asm/pgalloc.h>
 
 #define CREATE_TRACE_POINTS
        return entry;
 }
 
-static inline
-unsigned long pgoff_address(pgoff_t pgoff, struct vm_area_struct *vma)
-{
-       unsigned long address;
-
-       address = vma->vm_start + ((pgoff - vma->vm_pgoff) << PAGE_SHIFT);
-       VM_BUG_ON_VMA(address < vma->vm_start || address >= vma->vm_end, vma);
-       return address;
-}
-
-/* Walk all mappings of a given index of a file and writeprotect them */
-static void dax_entry_mkclean(struct address_space *mapping, pgoff_t index,
-               unsigned long pfn)
-{
-       struct vm_area_struct *vma;
-       pte_t pte, *ptep = NULL;
-       pmd_t *pmdp = NULL;
-       spinlock_t *ptl;
-
-       i_mmap_lock_read(mapping);
-       vma_interval_tree_foreach(vma, &mapping->i_mmap, index, index) {
-               struct mmu_notifier_range range;
-               unsigned long address;
-
-               cond_resched();
-
-               if (!(vma->vm_flags & VM_SHARED))
-                       continue;
-
-               address = pgoff_address(index, vma);
-
-               /*
-                * follow_invalidate_pte() will use the range to call
-                * mmu_notifier_invalidate_range_start() on our behalf before
-                * taking any lock.
-                */
-               if (follow_invalidate_pte(vma->vm_mm, address, &range, &ptep,
-                                         &pmdp, &ptl))
-                       continue;
-
-               /*
-                * No need to call mmu_notifier_invalidate_range() as we are
-                * downgrading page table protection not changing it to point
-                * to a new page.
-                *
-                * See Documentation/vm/mmu_notifier.rst
-                */
-               if (pmdp) {
-#ifdef CONFIG_FS_DAX_PMD
-                       pmd_t pmd;
-
-                       if (pfn != pmd_pfn(*pmdp))
-                               goto unlock_pmd;
-                       if (!pmd_dirty(*pmdp) && !pmd_write(*pmdp))
-                               goto unlock_pmd;
-
-                       flush_cache_range(vma, address,
-                                         address + HPAGE_PMD_SIZE);
-                       pmd = pmdp_invalidate(vma, address, pmdp);
-                       pmd = pmd_wrprotect(pmd);
-                       pmd = pmd_mkclean(pmd);
-                       set_pmd_at(vma->vm_mm, address, pmdp, pmd);
-unlock_pmd:
-#endif
-                       spin_unlock(ptl);
-               } else {
-                       if (pfn != pte_pfn(*ptep))
-                               goto unlock_pte;
-                       if (!pte_dirty(*ptep) && !pte_write(*ptep))
-                               goto unlock_pte;
-
-                       flush_cache_page(vma, address, pfn);
-                       pte = ptep_clear_flush(vma, address, ptep);
-                       pte = pte_wrprotect(pte);
-                       pte = pte_mkclean(pte);
-                       set_pte_at(vma->vm_mm, address, ptep, pte);
-unlock_pte:
-                       pte_unmap_unlock(ptep, ptl);
-               }
-
-               mmu_notifier_invalidate_range_end(&range);
-       }
-       i_mmap_unlock_read(mapping);
-}
-
 static int dax_writeback_one(struct xa_state *xas, struct dax_device *dax_dev,
                struct address_space *mapping, void *entry)
 {
-       unsigned long pfn, index, count;
+       unsigned long pfn, index, count, end;
        long ret = 0;
+       struct vm_area_struct *vma;
 
        /*
         * A page got tagged dirty in DAX mapping? Something is seriously
        pfn = dax_to_pfn(entry);
        count = 1UL << dax_entry_order(entry);
        index = xas->xa_index & ~(count - 1);
+       end = index + count - 1;
+
+       /* Walk all mappings of a given index of a file and writeprotect them */
+       i_mmap_lock_read(mapping);
+       vma_interval_tree_foreach(vma, &mapping->i_mmap, index, end) {
+               pfn_mkclean_range(pfn, count, index, vma);
+               cond_resched();
+       }
+       i_mmap_unlock_read(mapping);
 
-       dax_entry_mkclean(mapping, index, pfn);
        dax_flush(dax_dev, page_address(pfn_to_page(pfn)), count * PAGE_SIZE);
        /*
         * After we have flushed the cache, we can clear the dirty tag. There