 static const struct mm_walk_ops subpage_walk_ops = {
        .pmd_entry      = subpage_walk_pmd_entry,
+       .walk_lock      = PGWALK_WRLOCK_VERIFY,
 };
 
 static void subpage_mark_vma_nohuge(struct mm_struct *mm, unsigned long addr,
 
        .pmd_entry = pageattr_pmd_entry,
        .pte_entry = pageattr_pte_entry,
        .pte_hole = pageattr_pte_hole,
+       .walk_lock = PGWALK_RDLOCK,
 };
 
 static int __set_memory(unsigned long addr, int numpages, pgprot_t set_mask,
 
 
 static const struct mm_walk_ops thp_split_walk_ops = {
        .pmd_entry      = thp_split_walk_pmd_entry,
+       .walk_lock      = PGWALK_WRLOCK_VERIFY,
 };
 
 static inline void thp_split_mm(struct mm_struct *mm)
 
 static const struct mm_walk_ops zap_zero_walk_ops = {
        .pmd_entry      = __zap_zero_pages,
+       .walk_lock      = PGWALK_WRLOCK,
 };
 
 /*
        .hugetlb_entry          = __s390_enable_skey_hugetlb,
        .pte_entry              = __s390_enable_skey_pte,
        .pmd_entry              = __s390_enable_skey_pmd,
+       .walk_lock              = PGWALK_WRLOCK,
 };
 
 int s390_enable_skey(void)
 
 static const struct mm_walk_ops reset_cmma_walk_ops = {
        .pte_entry              = __s390_reset_cmma,
+       .walk_lock              = PGWALK_WRLOCK,
 };
 
 void s390_reset_cmma(struct mm_struct *mm)
 
 static const struct mm_walk_ops gather_pages_ops = {
        .pte_entry = s390_gather_pages,
+       .walk_lock = PGWALK_RDLOCK,
 };
 
 /*
 
 static const struct mm_walk_ops smaps_walk_ops = {
        .pmd_entry              = smaps_pte_range,
        .hugetlb_entry          = smaps_hugetlb_range,
+       .walk_lock              = PGWALK_RDLOCK,
 };
 
 static const struct mm_walk_ops smaps_shmem_walk_ops = {
        .pmd_entry              = smaps_pte_range,
        .hugetlb_entry          = smaps_hugetlb_range,
        .pte_hole               = smaps_pte_hole,
+       .walk_lock              = PGWALK_RDLOCK,
 };
 
 /*
 static const struct mm_walk_ops clear_refs_walk_ops = {
        .pmd_entry              = clear_refs_pte_range,
        .test_walk              = clear_refs_test_walk,
+       .walk_lock              = PGWALK_WRLOCK,
 };
 
 static ssize_t clear_refs_write(struct file *file, const char __user *buf,
        .pmd_entry      = pagemap_pmd_range,
        .pte_hole       = pagemap_pte_hole,
        .hugetlb_entry  = pagemap_hugetlb_range,
+       .walk_lock      = PGWALK_RDLOCK,
 };
 
 /*
 static const struct mm_walk_ops show_numa_ops = {
        .hugetlb_entry = gather_hugetlb_stats,
        .pmd_entry = gather_pte_stats,
+       .walk_lock = PGWALK_RDLOCK,
 };
 
 /*
 
 
 struct mm_walk;
 
+/* Locking requirement during a page walk. */
+enum page_walk_lock {
+       /* mmap_lock should be locked for read to stabilize the vma tree */
+       PGWALK_RDLOCK = 0,
+       /* vma will be write-locked during the walk */
+       PGWALK_WRLOCK = 1,
+       /* vma is expected to be already write-locked during the walk */
+       PGWALK_WRLOCK_VERIFY = 2,
+};
+
 /**
  * struct mm_walk_ops - callbacks for walk_page_range
  * @pgd_entry:         if set, called for each non-empty PGD (top-level) entry
        int (*pre_vma)(unsigned long start, unsigned long end,
                       struct mm_walk *walk);
        void (*post_vma)(struct mm_walk *walk);
+       enum page_walk_lock walk_lock;
 };
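
For illustration only, not part of the patch: a minimal sketch of how a walker picks a lock mode with the new field. The example_pmd_entry()/example_walk_ops names are hypothetical. PGWALK_RDLOCK keeps today's behaviour (the caller holds mmap_lock at least for read), while PGWALK_WRLOCK would make the walk write-lock each VMA before the entry callbacks run.

        /* Hypothetical walker, not from this patch. */
        static int example_pmd_entry(pmd_t *pmd, unsigned long addr,
                                     unsigned long next, struct mm_walk *walk)
        {
                /* read-only inspection of the range; no VMA write lock needed */
                return 0;
        }

        static const struct mm_walk_ops example_walk_ops = {
                .pmd_entry      = example_pmd_entry,
                .walk_lock      = PGWALK_RDLOCK,
        };

        /* caller holds mmap_read_lock(mm), as PGWALK_RDLOCK requires */
        err = walk_page_range(mm, start, end, &example_walk_ops, NULL);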
 
 /*
 
 static const struct mm_walk_ops damon_mkold_ops = {
        .pmd_entry = damon_mkold_pmd_entry,
        .hugetlb_entry = damon_mkold_hugetlb_entry,
+       .walk_lock = PGWALK_RDLOCK,
 };
 
 static void damon_va_mkold(struct mm_struct *mm, unsigned long addr)
 static const struct mm_walk_ops damon_young_ops = {
        .pmd_entry = damon_young_pmd_entry,
        .hugetlb_entry = damon_young_hugetlb_entry,
+       .walk_lock = PGWALK_RDLOCK,
 };
 
 static bool damon_va_young(struct mm_struct *mm, unsigned long addr,
 
        .pte_hole       = hmm_vma_walk_hole,
        .hugetlb_entry  = hmm_vma_walk_hugetlb_entry,
        .test_walk      = hmm_vma_walk_test,
+       .walk_lock      = PGWALK_RDLOCK,
 };
 
 /**
 
 
 static const struct mm_walk_ops break_ksm_ops = {
        .pmd_entry = break_ksm_pmd_entry,
+       .walk_lock = PGWALK_RDLOCK,
+};
+
+static const struct mm_walk_ops break_ksm_lock_vma_ops = {
+       .pmd_entry = break_ksm_pmd_entry,
+       .walk_lock = PGWALK_WRLOCK,
 };
 
 /*
  * of the process that owns 'vma'.  We also do not want to enforce
  * protection keys here anyway.
  */
-static int break_ksm(struct vm_area_struct *vma, unsigned long addr)
+static int break_ksm(struct vm_area_struct *vma, unsigned long addr, bool lock_vma)
 {
        vm_fault_t ret = 0;
+       const struct mm_walk_ops *ops = lock_vma ?
+                               &break_ksm_lock_vma_ops : &break_ksm_ops;
 
        do {
                int ksm_page;
 
                cond_resched();
-               ksm_page = walk_page_range_vma(vma, addr, addr + 1,
-                                              &break_ksm_ops, NULL);
+               ksm_page = walk_page_range_vma(vma, addr, addr + 1, ops, NULL);
                if (WARN_ON_ONCE(ksm_page < 0))
                        return ksm_page;
                if (!ksm_page)
        mmap_read_lock(mm);
        vma = find_mergeable_vma(mm, addr);
        if (vma)
-               break_ksm(vma, addr);
+               break_ksm(vma, addr, false);
        mmap_read_unlock(mm);
 }
 
  * in cmp_and_merge_page on one of the rmap_items we would be removing.
  */
 static int unmerge_ksm_pages(struct vm_area_struct *vma,
-                            unsigned long start, unsigned long end)
+                            unsigned long start, unsigned long end, bool lock_vma)
 {
        unsigned long addr;
        int err = 0;
                if (signal_pending(current))
                        err = -ERESTARTSYS;
                else
-                       err = break_ksm(vma, addr);
+                       err = break_ksm(vma, addr, lock_vma);
        }
        return err;
 }
                        if (!(vma->vm_flags & VM_MERGEABLE) || !vma->anon_vma)
                                continue;
                        err = unmerge_ksm_pages(vma,
-                                               vma->vm_start, vma->vm_end);
+                                               vma->vm_start, vma->vm_end, false);
                        if (err)
                                goto error;
                }
                return 0;
 
        if (vma->anon_vma) {
-               err = unmerge_ksm_pages(vma, vma->vm_start, vma->vm_end);
+               err = unmerge_ksm_pages(vma, vma->vm_start, vma->vm_end, true);
                if (err)
                        return err;
        }
                        return 0;               /* just ignore the advice */
 
                if (vma->anon_vma) {
-                       err = unmerge_ksm_pages(vma, start, end);
+                       err = unmerge_ksm_pages(vma, start, end, true);
                        if (err)
                                return err;
                }
 
 
 static const struct mm_walk_ops swapin_walk_ops = {
        .pmd_entry              = swapin_walk_pmd_entry,
+       .walk_lock              = PGWALK_RDLOCK,
 };
 
 static void shmem_swapin_range(struct vm_area_struct *vma,
 
 static const struct mm_walk_ops cold_walk_ops = {
        .pmd_entry = madvise_cold_or_pageout_pte_range,
+       .walk_lock = PGWALK_RDLOCK,
 };
 
 static void madvise_cold_page_range(struct mmu_gather *tlb,
 
 static const struct mm_walk_ops madvise_free_walk_ops = {
        .pmd_entry              = madvise_free_pte_range,
+       .walk_lock              = PGWALK_RDLOCK,
 };
 
 static int madvise_free_single_vma(struct vm_area_struct *vma,
 
 
 static const struct mm_walk_ops precharge_walk_ops = {
        .pmd_entry      = mem_cgroup_count_precharge_pte_range,
+       .walk_lock      = PGWALK_RDLOCK,
 };
 
 static unsigned long mem_cgroup_count_precharge(struct mm_struct *mm)
 
 static const struct mm_walk_ops charge_walk_ops = {
        .pmd_entry      = mem_cgroup_move_charge_pte_range,
+       .walk_lock      = PGWALK_RDLOCK,
 };
 
 static void mem_cgroup_move_charge(void)
 
 static const struct mm_walk_ops hwp_walk_ops = {
        .pmd_entry = hwpoison_pte_range,
        .hugetlb_entry = hwpoison_hugetlb_range,
+       .walk_lock = PGWALK_RDLOCK,
 };
 
 /*
 
        .hugetlb_entry          = queue_folios_hugetlb,
        .pmd_entry              = queue_folios_pte_range,
        .test_walk              = queue_pages_test_walk,
+       .walk_lock              = PGWALK_RDLOCK,
+};
+
+static const struct mm_walk_ops queue_pages_lock_vma_walk_ops = {
+       .hugetlb_entry          = queue_folios_hugetlb,
+       .pmd_entry              = queue_folios_pte_range,
+       .test_walk              = queue_pages_test_walk,
+       .walk_lock              = PGWALK_WRLOCK,
 };
 
 /*
 static int
 queue_pages_range(struct mm_struct *mm, unsigned long start, unsigned long end,
                nodemask_t *nodes, unsigned long flags,
-               struct list_head *pagelist)
+               struct list_head *pagelist, bool lock_vma)
 {
        int err;
        struct queue_pages qp = {
                .end = end,
                .first = NULL,
        };
+       const struct mm_walk_ops *ops = lock_vma ?
+                       &queue_pages_lock_vma_walk_ops : &queue_pages_walk_ops;
 
-       err = walk_page_range(mm, start, end, &queue_pages_walk_ops, &qp);
+       err = walk_page_range(mm, start, end, ops, &qp);
 
        if (!qp.first)
                /* whole range in hole */
        vma = find_vma(mm, 0);
        VM_BUG_ON(!(flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)));
        queue_pages_range(mm, vma->vm_start, mm->task_size, &nmask,
-                       flags | MPOL_MF_DISCONTIG_OK, &pagelist);
+                       flags | MPOL_MF_DISCONTIG_OK, &pagelist, false);
 
        if (!list_empty(&pagelist)) {
                err = migrate_pages(&pagelist, alloc_migration_target, NULL,
         * Lock the VMAs before scanning for pages to migrate, to ensure we don't
         * miss a concurrently inserted page.
         */
-       vma_iter_init(&vmi, mm, start);
-       for_each_vma_range(vmi, vma, end)
-               vma_start_write(vma);
-
        ret = queue_pages_range(mm, start, end, nmask,
-                         flags | MPOL_MF_INVERT, &pagelist);
+                         flags | MPOL_MF_INVERT, &pagelist, true);
 
        if (ret < 0) {
                err = ret;
 
 static const struct mm_walk_ops migrate_vma_walk_ops = {
        .pmd_entry              = migrate_vma_collect_pmd,
        .pte_hole               = migrate_vma_collect_hole,
+       .walk_lock              = PGWALK_RDLOCK,
 };
 
 /*
 
        .pmd_entry              = mincore_pte_range,
        .pte_hole               = mincore_unmapped_range,
        .hugetlb_entry          = mincore_hugetlb,
+       .walk_lock              = PGWALK_RDLOCK,
 };
 
 /*
 
 {
        static const struct mm_walk_ops mlock_walk_ops = {
                .pmd_entry = mlock_pte_range,
+               .walk_lock = PGWALK_WRLOCK_VERIFY,
        };
 
        /*
 
        .pte_entry              = prot_none_pte_entry,
        .hugetlb_entry          = prot_none_hugetlb_entry,
        .test_walk              = prot_none_test,
+       .walk_lock              = PGWALK_WRLOCK,
 };
 
 int
 
        return err;
 }
 
+static inline void process_mm_walk_lock(struct mm_struct *mm,
+                                       enum page_walk_lock walk_lock)
+{
+       if (walk_lock == PGWALK_RDLOCK)
+               mmap_assert_locked(mm);
+       else
+               mmap_assert_write_locked(mm);
+}
+
+static inline void process_vma_walk_lock(struct vm_area_struct *vma,
+                                        enum page_walk_lock walk_lock)
+{
+#ifdef CONFIG_PER_VMA_LOCK
+       switch (walk_lock) {
+       case PGWALK_WRLOCK:
+               vma_start_write(vma);
+               break;
+       case PGWALK_WRLOCK_VERIFY:
+               vma_assert_write_locked(vma);
+               break;
+       case PGWALK_RDLOCK:
+               /* PGWALK_RDLOCK is handled by process_mm_walk_lock */
+               break;
+       }
+#endif
+}
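
Again purely illustrative, not part of the patch: what the two helpers above end up doing for a hypothetical single-VMA walk when CONFIG_PER_VMA_LOCK is enabled. example_wrlock_ops is assumed to set .walk_lock = PGWALK_WRLOCK.

        /* Hypothetical caller, not from this patch. */
        mmap_write_lock(mm);    /* both WRLOCK modes assert the mmap write lock */
        err = walk_page_range_vma(vma, vma->vm_start, vma->vm_end,
                                  &example_wrlock_ops, NULL);
        /*
         * process_mm_walk_lock() runs mmap_assert_write_locked(mm), then
         * process_vma_walk_lock() calls vma_start_write(vma) before the
         * page table callbacks.  With PGWALK_WRLOCK_VERIFY the caller must
         * have write-locked the VMA already; only vma_assert_write_locked()
         * is performed.
         */
        mmap_write_unlock(mm);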
+
 /**
  * walk_page_range - walk page table with caller specific callbacks
  * @mm:                mm_struct representing the target process of page table walk
        if (!walk.mm)
                return -EINVAL;
 
-       mmap_assert_locked(walk.mm);
+       process_mm_walk_lock(walk.mm, ops->walk_lock);
 
        vma = find_vma(walk.mm, start);
        do {
                        if (ops->pte_hole)
                                err = ops->pte_hole(start, next, -1, &walk);
                } else { /* inside vma */
+                       process_vma_walk_lock(vma, ops->walk_lock);
                        walk.vma = vma;
                        next = min(end, vma->vm_end);
                        vma = find_vma(mm, vma->vm_end);
        if (start < vma->vm_start || end > vma->vm_end)
                return -EINVAL;
 
-       mmap_assert_locked(walk.mm);
+       process_mm_walk_lock(walk.mm, ops->walk_lock);
+       process_vma_walk_lock(vma, ops->walk_lock);
        return __walk_page_range(start, end, &walk);
 }
 
        if (!walk.mm)
                return -EINVAL;
 
-       mmap_assert_locked(walk.mm);
+       process_mm_walk_lock(walk.mm, ops->walk_lock);
+       process_vma_walk_lock(vma, ops->walk_lock);
        return __walk_page_range(vma->vm_start, vma->vm_end, &walk);
 }
 
 
        static const struct mm_walk_ops mm_walk_ops = {
                .test_walk = should_skip_vma,
                .p4d_entry = walk_pud_range,
+               .walk_lock = PGWALK_RDLOCK,
        };
 
        int err;