From bb2f53af9c9b4e049922e4795f4aedc3df83ac9c Mon Sep 17 00:00:00 2001
From: "Liam R. Howlett" <Liam.Howlett@Oracle.com>
Date: Mon, 14 Dec 2020 13:14:06 -0500
Subject: [PATCH] maple_tree: Stop using linked list in most cases
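
Convert most remaining users of the VMA linked list (mm->mmap and
vma->vm_next/vm_prev walks) to the maple tree interfaces.  Full
iterations over the VMAs of an mm become mas_for_each() loops, single
lookups use mas_walk()/mas_find()/find_vma_intersection(), and code
that only needs a neighbouring VMA uses the vma_next()/vma_prev()
helpers added to include/linux/mm.h.

The typical conversion of a full walk looks like:

	struct vm_area_struct *vma;
	MA_STATE(mas, &mm->mm_mt, 0, 0);

	mas_for_each(&mas, vma, ULONG_MAX) {
		/* use vma as before */
	}

in place of:

	for (vma = mm->mmap; vma; vma = vma->vm_next) {
		...
	}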

Signed-off-by: Liam R. Howlett <Liam.Howlett@Oracle.com>
---
 arch/arm64/kernel/vdso.c                   |  5 ++-
 arch/parisc/kernel/cache.c                 |  8 +++--
 arch/powerpc/mm/book3s32/tlb.c             |  3 +-
 arch/powerpc/mm/book3s64/subpage_prot.c    | 13 ++------
 arch/powerpc/oprofile/cell/spu_task_sync.c | 23 +++++++------
 arch/s390/mm/gmap.c                        |  6 ++--
 arch/um/kernel/tlb.c                       | 14 ++++----
 arch/x86/entry/vdso/vma.c                  |  7 ++--
 arch/xtensa/kernel/syscall.c               |  3 +-
 drivers/misc/cxl/fault.c                   |  3 +-
 drivers/oprofile/buffer_sync.c             | 14 +++-----
 drivers/tee/optee/call.c                   | 13 +++++---
 fs/binfmt_elf.c                            |  3 +-
 fs/coredump.c                              |  2 +-
 fs/exec.c                                  |  4 +--
 fs/proc/base.c                             |  5 ++-
 fs/proc/task_mmu.c                         | 22 +++++++------
 fs/userfaultfd.c                           | 24 ++++++++++----
 include/linux/mm.h                         | 33 +++++++++++++++++++
 ipc/shm.c                                  | 12 +++----
 kernel/acct.c                              |  2 +-
 kernel/events/core.c                       |  3 +-
 kernel/events/uprobes.c                    |  9 +++--
 kernel/sched/fair.c                        | 11 +++++--
 kernel/sys.c                               |  3 +-
 mm/gup.c                                   |  7 ++--
 mm/huge_memory.c                           |  6 ++--
 mm/khugepaged.c                            |  6 ++--
 mm/ksm.c                                   | 12 ++++---
 mm/madvise.c                               |  2 +-
 mm/memory.c                                |  4 +--
 mm/mempolicy.c                             | 28 +++++++++-------
 mm/mlock.c                                 | 16 +++++----
 mm/mmap.c                                  | 38 +++++++++++-----------
 mm/mprotect.c                              |  7 ++--
 mm/mremap.c                                |  6 ++--
 mm/msync.c                                 |  2 +-
 mm/nommu.c                                 | 10 +++---
 mm/oom_kill.c                              |  3 +-
 mm/pagewalk.c                              |  2 +-
 mm/swapfile.c                              |  3 +-
 41 files changed, 238 insertions(+), 159 deletions(-)

diff --git a/arch/arm64/kernel/vdso.c b/arch/arm64/kernel/vdso.c
index debb8995d57f..caa7d8f6b99d 100644
--- a/arch/arm64/kernel/vdso.c
+++ b/arch/arm64/kernel/vdso.c
@@ -144,10 +144,12 @@ int vdso_join_timens(struct task_struct *task, struct time_namespace *ns)
 {
 	struct mm_struct *mm = task->mm;
 	struct vm_area_struct *vma;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	mmap_read_lock(mm);
 
-	for (vma = mm->mmap; vma; vma = vma->vm_next) {
+	rcu_read_lock();
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		unsigned long size = vma->vm_end - vma->vm_start;
 
 		if (vma_is_special_mapping(vma, vdso_info[VDSO_ABI_AA64].dm))
@@ -157,6 +159,7 @@ int vdso_join_timens(struct task_struct *task, struct time_namespace *ns)
 			zap_page_range(vma, vma->vm_start, size);
 #endif
 	}
+	rcu_read_unlock();
 
 	mmap_read_unlock(mm);
 	return 0;
diff --git a/arch/parisc/kernel/cache.c b/arch/parisc/kernel/cache.c
index 86a1a63563fd..288a25e1b1c2 100644
--- a/arch/parisc/kernel/cache.c
+++ b/arch/parisc/kernel/cache.c
@@ -520,8 +520,9 @@ static inline unsigned long mm_total_size(struct mm_struct *mm)
 {
 	struct vm_area_struct *vma;
 	unsigned long usize = 0;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
-	for (vma = mm->mmap; vma; vma = vma->vm_next)
+	mas_for_each(&mas, vma, ULONG_MAX)
 		usize += vma->vm_end - vma->vm_start;
 	return usize;
 }
@@ -548,6 +549,7 @@ void flush_cache_mm(struct mm_struct *mm)
 {
 	struct vm_area_struct *vma;
 	pgd_t *pgd;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	/* Flushing the whole cache on each cpu takes forever on
 	   rp3440, etc.  So, avoid it if the mm isn't too big.  */
@@ -560,7 +562,7 @@ void flush_cache_mm(struct mm_struct *mm)
 	}
 
 	if (mm->context == mfsp(3)) {
-		for (vma = mm->mmap; vma; vma = vma->vm_next) {
+		mas_for_each(&mas, vma, ULONG_MAX) {
 			flush_user_dcache_range_asm(vma->vm_start, vma->vm_end);
 			if (vma->vm_flags & VM_EXEC)
 				flush_user_icache_range_asm(vma->vm_start, vma->vm_end);
@@ -570,7 +572,7 @@ void flush_cache_mm(struct mm_struct *mm)
 	}
 
 	pgd = mm->pgd;
-	for (vma = mm->mmap; vma; vma = vma->vm_next) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		unsigned long addr;
 
 		for (addr = vma->vm_start; addr < vma->vm_end;
diff --git a/arch/powerpc/mm/book3s32/tlb.c b/arch/powerpc/mm/book3s32/tlb.c
index b6c7427daa6f..be595b36dc4c 100644
--- a/arch/powerpc/mm/book3s32/tlb.c
+++ b/arch/powerpc/mm/book3s32/tlb.c
@@ -121,6 +121,7 @@ EXPORT_SYMBOL(flush_tlb_kernel_range);
 void flush_tlb_mm(struct mm_struct *mm)
 {
 	struct vm_area_struct *mp;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	if (!Hash) {
 		_tlbia();
@@ -133,7 +134,7 @@ void flush_tlb_mm(struct mm_struct *mm)
 	 * unmap_region or exit_mmap, but not from vmtruncate on SMP -
 	 * but it seems dup_mmap is the only SMP case which gets here.
 	 */
-	for (mp = mm->mmap; mp != NULL; mp = mp->vm_next)
+	mas_for_each(&mas, mp, ULONG_MAX)
 		flush_range(mp->vm_mm, mp->vm_start, mp->vm_end);
 }
 EXPORT_SYMBOL(flush_tlb_mm);
diff --git a/arch/powerpc/mm/book3s64/subpage_prot.c b/arch/powerpc/mm/book3s64/subpage_prot.c
index 60c6ea16a972..ada1531f81be 100644
--- a/arch/powerpc/mm/book3s64/subpage_prot.c
+++ b/arch/powerpc/mm/book3s64/subpage_prot.c
@@ -149,24 +149,15 @@ static void subpage_mark_vma_nohuge(struct mm_struct *mm, unsigned long addr,
 				    unsigned long len)
 {
 	struct vm_area_struct *vma;
+	MA_STATE(mas, &mm->mm_mt, addr, addr);
 
 	/*
 	 * We don't try too hard, we just mark all the vma in that range
 	 * VM_NOHUGEPAGE and split them.
 	 */
-	vma = find_vma(mm, addr);
-	/*
-	 * If the range is in unmapped range, just return
-	 */
-	if (vma && ((addr + len) <= vma->vm_start))
-		return;
-
-	while (vma) {
-		if (vma->vm_start >= (addr + len))
-			break;
+	mas_for_each(&mas, vma, addr + len) {
 		vma->vm_flags |= VM_NOHUGEPAGE;
 		walk_page_vma(vma, &subpage_walk_ops, NULL);
-		vma = vma->vm_next;
 	}
 }
 #else
diff --git a/arch/powerpc/oprofile/cell/spu_task_sync.c b/arch/powerpc/oprofile/cell/spu_task_sync.c
index 489f993100d5..c1d59a255060 100644
--- a/arch/powerpc/oprofile/cell/spu_task_sync.c
+++ b/arch/powerpc/oprofile/cell/spu_task_sync.c
@@ -321,6 +321,7 @@ get_exec_dcookie_and_offset(struct spu *spu, unsigned int *offsetp,
 	struct vm_area_struct *vma;
 	struct file *exe_file;
 	struct mm_struct *mm = spu->mm;
+	MA_STATE(mas, &mm->mm_mt, spu_ref, spu_ref);
 
 	if (!mm)
 		goto out;
@@ -333,19 +334,17 @@ get_exec_dcookie_and_offset(struct spu *spu, unsigned int *offsetp,
 	}
 
 	mmap_read_lock(mm);
-	for (vma = mm->mmap; vma; vma = vma->vm_next) {
-		if (vma->vm_start > spu_ref || vma->vm_end <= spu_ref)
-			continue;
-		my_offset = spu_ref - vma->vm_start;
-		if (!vma->vm_file)
-			goto fail_no_image_cookie;
-
-		pr_debug("Found spu ELF at %X(object-id:%lx) for file %pD\n",
-			 my_offset, spu_ref, vma->vm_file);
-		*offsetp = my_offset;
-		break;
-	}
+	vma = mas_walk(&mas);
+	if (!vma)
+		goto fail_no_image_cookie;
+
+	my_offset = spu_ref - vma->vm_start;
+	if (!vma->vm_file)
+		goto fail_no_image_cookie;
 
+	pr_debug("Found spu ELF at %X(object-id:%lx) for file %pD\n",
+		 my_offset, spu_ref, vma->vm_file);
+	*offsetp = my_offset;
 	*spu_bin_dcookie = fast_get_dcookie(&vma->vm_file->f_path);
 	pr_debug("got dcookie for %pD\n", vma->vm_file);
 
diff --git a/arch/s390/mm/gmap.c b/arch/s390/mm/gmap.c
index 64795d034926..8f7ed79583b2 100644
--- a/arch/s390/mm/gmap.c
+++ b/arch/s390/mm/gmap.c
@@ -2502,8 +2502,9 @@ static const struct mm_walk_ops thp_split_walk_ops = {
 static inline void thp_split_mm(struct mm_struct *mm)
 {
 	struct vm_area_struct *vma;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
-	for (vma = mm->mmap; vma != NULL; vma = vma->vm_next) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		vma->vm_flags &= ~VM_HUGEPAGE;
 		vma->vm_flags |= VM_NOHUGEPAGE;
 		walk_page_vma(vma, &thp_split_walk_ops, NULL);
@@ -2571,8 +2572,9 @@ int gmap_mark_unmergeable(void)
 	struct mm_struct *mm = current->mm;
 	struct vm_area_struct *vma;
 	int ret;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
-	for (vma = mm->mmap; vma; vma = vma->vm_next) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		ret = ksm_madvise(vma, vma->vm_start, vma->vm_end,
 				  MADV_UNMERGEABLE, &vma->vm_flags);
 		if (ret)
diff --git a/arch/um/kernel/tlb.c b/arch/um/kernel/tlb.c
index 61776790cd67..e40dd6deb1d2 100644
--- a/arch/um/kernel/tlb.c
+++ b/arch/um/kernel/tlb.c
@@ -590,21 +590,19 @@ void flush_tlb_mm_range(struct mm_struct *mm, unsigned long start,
 
 void flush_tlb_mm(struct mm_struct *mm)
 {
-	struct vm_area_struct *vma = mm->mmap;
+	struct vm_area_struct *vma;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
-	while (vma != NULL) {
+	mas_for_each(&mas, vma, ULONG_MAX)
 		fix_range(mm, vma->vm_start, vma->vm_end, 0);
-		vma = vma->vm_next;
-	}
 }
 
 void force_flush_all(void)
 {
 	struct mm_struct *mm = current->mm;
-	struct vm_area_struct *vma = mm->mmap;
+	struct vm_area_struct *vma;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
-	while (vma != NULL) {
+	mas_for_each(&mas, vma, ULONG_MAX)
 		fix_range(mm, vma->vm_start, vma->vm_end, 1);
-		vma = vma->vm_next;
-	}
 }
diff --git a/arch/x86/entry/vdso/vma.c b/arch/x86/entry/vdso/vma.c
index 9185cb1d13b9..f531efb00ba3 100644
--- a/arch/x86/entry/vdso/vma.c
+++ b/arch/x86/entry/vdso/vma.c
@@ -144,9 +144,11 @@ int vdso_join_timens(struct task_struct *task, struct time_namespace *ns)
 	struct mm_struct *mm = task->mm;
 	struct vm_area_struct *vma;
 
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
+
 	mmap_read_lock(mm);
 
-	for (vma = mm->mmap; vma; vma = vma->vm_next) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		unsigned long size = vma->vm_end - vma->vm_start;
 
 		if (vma_is_special_mapping(vma, &vvar_mapping))
@@ -371,6 +373,7 @@ int map_vdso_once(const struct vdso_image *image, unsigned long addr)
 {
 	struct mm_struct *mm = current->mm;
 	struct vm_area_struct *vma;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	mmap_write_lock(mm);
 	/*
@@ -380,7 +383,7 @@ int map_vdso_once(const struct vdso_image *image, unsigned long addr)
 	 * We could search vma near context.vdso, but it's a slowpath,
 	 * so let's explicitly check all VMAs to be completely sure.
 	 */
-	for (vma = mm->mmap; vma; vma = vma->vm_next) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		if (vma_is_special_mapping(vma, &vdso_mapping) ||
 				vma_is_special_mapping(vma, &vvar_mapping)) {
 			mmap_write_unlock(mm);
diff --git a/arch/xtensa/kernel/syscall.c b/arch/xtensa/kernel/syscall.c
index 2c415fce6801..26ec2e67879a 100644
--- a/arch/xtensa/kernel/syscall.c
+++ b/arch/xtensa/kernel/syscall.c
@@ -62,6 +62,7 @@ unsigned long arch_get_unmapped_area(struct file *filp, unsigned long addr,
 		unsigned long len, unsigned long pgoff, unsigned long flags)
 {
 	struct vm_area_struct *vmm;
+	MA_STATE(mas, &current->mm->mm_mt, addr, addr);
 
 	if (flags & MAP_FIXED) {
 		/* We do not accept a shared mapping if it would violate
@@ -83,7 +84,8 @@ unsigned long arch_get_unmapped_area(struct file *filp, unsigned long addr,
 	else
 		addr = PAGE_ALIGN(addr);
 
-	for (vmm = find_vma(current->mm, addr); ; vmm = vmm->vm_next) {
+	mas_set(&mas, addr);
+	for (vmm = mas_find(&mas, ULONG_MAX); ; vmm = mas_next(&mas, ULONG_MAX)) {
 		/* At this point:  (!vmm || addr < vmm->vm_end). */
 		if (TASK_SIZE - len < addr)
 			return -ENOMEM;
diff --git a/drivers/misc/cxl/fault.c b/drivers/misc/cxl/fault.c
index 01153b74334a..47951b84f2cd 100644
--- a/drivers/misc/cxl/fault.c
+++ b/drivers/misc/cxl/fault.c
@@ -313,6 +313,7 @@ static void cxl_prefault_vma(struct cxl_context *ctx)
 	struct vm_area_struct *vma;
 	int rc;
 	struct mm_struct *mm;
+	MA_STATE(mas, NULL, 0, 0);
 
 	mm = get_mem_context(ctx);
 	if (mm == NULL) {
@@ -322,7 +323,8 @@ static void cxl_prefault_vma(struct cxl_context *ctx)
 	}
 
 	mmap_read_lock(mm);
-	for (vma = mm->mmap; vma; vma = vma->vm_next) {
+	mas.tree = &mm->mm_mt;
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		for (ea = vma->vm_start; ea < vma->vm_end;
 				ea = next_segment(ea, slb.vsid)) {
 			rc = copro_calculate_slb(mm, ea, &slb);
diff --git a/drivers/oprofile/buffer_sync.c b/drivers/oprofile/buffer_sync.c
index cc917865f13a..fbb43edad41b 100644
--- a/drivers/oprofile/buffer_sync.c
+++ b/drivers/oprofile/buffer_sync.c
@@ -257,11 +257,8 @@ lookup_dcookie(struct mm_struct *mm, unsigned long addr, off_t *offset)
 	struct vm_area_struct *vma;
 
 	mmap_read_lock(mm);
-	for (vma = find_vma(mm, addr); vma; vma = vma->vm_next) {
-
-		if (addr < vma->vm_start || addr >= vma->vm_end)
-			continue;
-
+	vma = find_vma_intersection(mm, addr, addr + 1);
+	if (vma) {
 		if (vma->vm_file) {
 			cookie = fast_get_dcookie(&vma->vm_file->f_path);
 			*offset = (vma->vm_pgoff << PAGE_SHIFT) + addr -
@@ -270,12 +267,9 @@ lookup_dcookie(struct mm_struct *mm, unsigned long addr, off_t *offset)
 			/* must be an anonymous map */
 			*offset = addr;
 		}
-
-		break;
-	}
-
-	if (!vma)
+	} else
 		cookie = INVALID_COOKIE;
+
 	mmap_read_unlock(mm);
 
 	return cookie;
diff --git a/drivers/tee/optee/call.c b/drivers/tee/optee/call.c
index c981757ba0d4..94acf379eaee 100644
--- a/drivers/tee/optee/call.c
+++ b/drivers/tee/optee/call.c
@@ -545,12 +545,16 @@ static bool is_normal_memory(pgprot_t p)
 
 static int __check_mem_type(struct vm_area_struct *vma, unsigned long end)
 {
-	while (vma && is_normal_memory(vma->vm_page_prot)) {
-		if (vma->vm_end >= end)
-			return 0;
-		vma = vma->vm_next;
+	MA_STATE(mas, &vma->vm_mm->mm_mt, vma->vm_start, vma->vm_start);
+
+	mas_for_each(&mas, vma, end) {
+		if (!is_normal_memory(vma->vm_page_prot))
+			break;
 	}
 
+	if (!vma)
+		return 0;
+
 	return -EINVAL;
 }
 
diff --git a/fs/binfmt_elf.c b/fs/binfmt_elf.c
index fa50e8936f5f..92b67b914bf9 100644
--- a/fs/binfmt_elf.c
+++ b/fs/binfmt_elf.c
@@ -1609,6 +1609,7 @@ static int fill_files_note(struct memelfnote *note)
 	user_long_t *data;
 	user_long_t *start_end_ofs;
 	char *name_base, *name_curpos;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	/* *Estimated* file count and total data size needed */
 	count = mm->map_count;
@@ -1633,7 +1634,7 @@ static int fill_files_note(struct memelfnote *note)
 	name_base = name_curpos = ((char *)data) + names_ofs;
 	remaining = size - names_ofs;
 	count = 0;
-	for (vma = mm->mmap; vma != NULL; vma = vma->vm_next) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		struct file *file;
 		const char *filename;
 
diff --git a/fs/coredump.c b/fs/coredump.c
index c6acfc694f65..0699ce6f6cc1 100644
--- a/fs/coredump.c
+++ b/fs/coredump.c
@@ -1059,7 +1059,7 @@ static struct vm_area_struct *next_vma(struct vm_area_struct *this_vma,
 {
 	struct vm_area_struct *ret;
 
-	ret = this_vma->vm_next;
+	ret = vma_next(this_vma->vm_mm, this_vma);
 	if (ret)
 		return ret;
 	if (this_vma == gate_vma)
diff --git a/fs/exec.c b/fs/exec.c
index aa466921d6a9..6eb585ea1b84 100644
--- a/fs/exec.c
+++ b/fs/exec.c
@@ -712,7 +712,7 @@ static int shift_arg_pages(struct vm_area_struct *vma, unsigned long shift)
 		 * when the old and new regions overlap clear from new_end.
 		 */
 		free_pgd_range(&tlb, new_end, old_end, new_end,
-			vma->vm_next ? vma->vm_next->vm_start : USER_PGTABLES_CEILING);
+			vma_next(mm, vma) ? vma_next(mm, vma)->vm_start : USER_PGTABLES_CEILING);
 	} else {
 		/*
 		 * otherwise, clean from old_start; this is done to not touch
@@ -721,7 +721,7 @@ static int shift_arg_pages(struct vm_area_struct *vma, unsigned long shift)
 		 * for the others its just a little faster.
 		 */
 		free_pgd_range(&tlb, old_start, old_end, new_end,
-			vma->vm_next ? vma->vm_next->vm_start : USER_PGTABLES_CEILING);
+			vma_next(mm, vma) ? vma_next(mm, vma)->vm_start : USER_PGTABLES_CEILING);
 	}
 	tlb_finish_mmu(&tlb, old_start, old_end);
 
diff --git a/fs/proc/base.c b/fs/proc/base.c
index b362523a9829..0a7d4d3d6ae9 100644
--- a/fs/proc/base.c
+++ b/fs/proc/base.c
@@ -2316,6 +2316,7 @@ proc_map_files_readdir(struct file *file, struct dir_context *ctx)
 	GENRADIX(struct map_files_info) fa;
 	struct map_files_info *p;
 	int ret;
+	MA_STATE(mas, NULL, 0, 0);
 
 	genradix_init(&fa);
 
@@ -2343,6 +2344,7 @@ proc_map_files_readdir(struct file *file, struct dir_context *ctx)
 	}
 
 	nr_files = 0;
+	mas.tree = &mm->mm_mt;
 
 	/*
 	 * We need two passes here:
@@ -2354,7 +2356,8 @@ proc_map_files_readdir(struct file *file, struct dir_context *ctx)
 	 * routine might require mmap_lock taken in might_fault().
 	 */
 
-	for (vma = mm->mmap, pos = 2; vma; vma = vma->vm_next) {
+	pos = 2;
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		if (!vma->vm_file)
 			continue;
 		if (++pos <= ctx->pos)
diff --git a/fs/proc/task_mmu.c b/fs/proc/task_mmu.c
index 1ec2dd34ebff..20a0fb414c4f 100644
--- a/fs/proc/task_mmu.c
+++ b/fs/proc/task_mmu.c
@@ -164,14 +164,13 @@ static void *m_start(struct seq_file *m, loff_t *ppos)
 static void *m_next(struct seq_file *m, void *v, loff_t *ppos)
 {
 	struct proc_maps_private *priv = m->private;
-	struct vm_area_struct *next, *vma = v;
+	struct vm_area_struct *next = NULL, *vma = v;
 
-	if (vma == priv->tail_vma)
-		next = NULL;
-	else if (vma->vm_next)
-		next = vma->vm_next;
-	else
-		next = priv->tail_vma;
+	if (vma != priv->tail_vma) {
+		next = vma_next(vma->vm_mm, vma);
+		if (!next)
+			next = priv->tail_vma;
+	}
 
 	*ppos = next ? next->vm_start : -1UL;
 
@@ -930,7 +929,7 @@ static int show_smaps_rollup(struct seq_file *m, void *v)
 				smap_gather_stats(vma, &mss, last_vma_end);
 		}
 		/* Case 2 above */
-		vma = vma->vm_next;
+		vma = vma_next(mm, vma);
 	}
 
 	show_vma_header_prefix(m, priv->mm->mmap->vm_start,
@@ -1209,6 +1208,7 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf,
 		return -ESRCH;
 	mm = get_task_mm(task);
 	if (mm) {
+		MA_STATE(mas, &mm->mm_mt, 0, 0);
 		struct mmu_notifier_range range;
 		struct clear_refs_private cp = {
 			.type = type,
@@ -1235,7 +1235,7 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf,
 		}
 		tlb_gather_mmu(&tlb, mm, 0, -1);
 		if (type == CLEAR_REFS_SOFT_DIRTY) {
-			for (vma = mm->mmap; vma; vma = vma->vm_next) {
+			mas_for_each(&mas, vma, ULONG_MAX) {
 				if (!(vma->vm_flags & VM_SOFTDIRTY))
 					continue;
 				mmap_read_unlock(mm);
@@ -1243,7 +1243,9 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf,
 					count = -EINTR;
 					goto out_mm;
 				}
-				for (vma = mm->mmap; vma; vma = vma->vm_next) {
+				mas_reset(&mas);
+				mas_set(&mas, 0);
+				mas_for_each(&mas, vma, ULONG_MAX) {
 					vma->vm_flags &= ~VM_SOFTDIRTY;
 					vma_set_page_prot(vma);
 				}
diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index 000b457ad087..060e35f4bde9 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -598,14 +598,16 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx,
 	if (release_new_ctx) {
 		struct vm_area_struct *vma;
 		struct mm_struct *mm = release_new_ctx->mm;
+		MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 		/* the various vma->vm_userfaultfd_ctx still points to it */
 		mmap_write_lock(mm);
-		for (vma = mm->mmap; vma; vma = vma->vm_next)
+		mas_for_each(&mas, vma, ULONG_MAX) {
 			if (vma->vm_userfaultfd_ctx.ctx == release_new_ctx) {
 				vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
 				vma->vm_flags &= ~(VM_UFFD_WP | VM_UFFD_MISSING);
 			}
+		}
 		mmap_write_unlock(mm);
 
 		userfaultfd_ctx_put(release_new_ctx);
@@ -790,7 +792,9 @@ int userfaultfd_unmap_prep(struct vm_area_struct *vma,
 			   unsigned long start, unsigned long end,
 			   struct list_head *unmaps)
 {
-	for ( ; vma && vma->vm_start < end; vma = vma->vm_next) {
+	MA_STATE(mas, &vma->vm_mm->mm_mt, vma->vm_start, vma->vm_start);
+
+	mas_for_each(&mas, vma, end) {
 		struct userfaultfd_unmap_ctx *unmap_ctx;
 		struct userfaultfd_ctx *ctx = vma->vm_userfaultfd_ctx.ctx;
 
@@ -840,6 +844,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
 	/* len == 0 means wake all */
 	struct userfaultfd_wake_range range = { .len = 0, };
 	unsigned long new_flags;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	WRITE_ONCE(ctx->released, true);
 
@@ -856,7 +861,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
 	 */
 	mmap_write_lock(mm);
 	prev = NULL;
-	for (vma = mm->mmap; vma; vma = vma->vm_next) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		cond_resched();
 		BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^
 		       !!(vma->vm_flags & (VM_UFFD_MISSING | VM_UFFD_WP)));
@@ -1270,6 +1275,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
 	bool found;
 	bool basic_ioctls;
 	unsigned long start, end, vma_end;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	user_uffdio_register = (struct uffdio_register __user *) arg;
 
@@ -1328,7 +1334,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
 	 */
 	found = false;
 	basic_ioctls = false;
-	for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) {
+	mas_set(&mas, vma->vm_start);
+	mas_for_each(&mas, cur, end) {
 		cond_resched();
 
 		BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
@@ -1444,7 +1451,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
 	skip:
 		prev = vma;
 		start = vma->vm_end;
-		vma = vma->vm_next;
+		vma = vma_next(mm, vma);
 	} while (vma && vma->vm_start < end);
 out_unlock:
 	mmap_write_unlock(mm);
@@ -1485,6 +1492,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
 	bool found;
 	unsigned long start, end, vma_end;
 	const void __user *buf = (void __user *)arg;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	ret = -EFAULT;
 	if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
@@ -1528,7 +1536,9 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
 	 */
 	found = false;
 	ret = -EINVAL;
-	for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) {
+	mas_set(&mas, vma->vm_start);
+
+	mas_for_each(&mas, cur, end) {
 		cond_resched();
 
 		BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
@@ -1614,7 +1624,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
 	skip:
 		prev = vma;
 		start = vma->vm_end;
-		vma = vma->vm_next;
+		vma = vma_next(mm, vma);
 	} while (vma && vma->vm_start < end);
 out_unlock:
 	mmap_write_unlock(mm);
diff --git a/include/linux/mm.h b/include/linux/mm.h
index f039fefa3acd..6a5e4fcbadec 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -2654,6 +2654,24 @@ extern struct vm_area_struct * find_vma_prev(struct mm_struct * mm, unsigned lon
 extern struct vm_area_struct *find_vma_intersection(struct mm_struct *mm,
 		     unsigned long start_addr, unsigned long end_addr);
 
+static inline struct vm_area_struct *vma_next(struct mm_struct *mm,
+			struct vm_area_struct *vma)
+{
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
+
+	mas_set(&mas, vma->vm_end);
+	return mas_find(&mas, ULONG_MAX);
+}
+
+static inline struct vm_area_struct *vma_prev(struct mm_struct *mm,
+			struct vm_area_struct *vma)
+{
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
+
+	mas_set(&mas, vma->vm_start);
+	return mas_prev(&mas, 0);
+}
+
 static inline unsigned long vm_start_gap(struct vm_area_struct *vma)
 {
 	unsigned long vm_start = vma->vm_start;
@@ -2695,6 +2713,22 @@ static inline struct vm_area_struct *find_exact_vma(struct mm_struct *mm,
 	return vma;
 }
 
+static inline struct vm_area_struct *vma_mas_next(struct ma_state *mas)
+{
+	struct ma_state tmp;
+
+	memcpy(&tmp, mas, sizeof(tmp));
+	return mas_next(&tmp, ULONG_MAX);
+}
+
+static inline struct vm_area_struct *vma_mas_prev(struct ma_state *mas)
+{
+	struct ma_state tmp;
+
+	memcpy(&tmp, mas, sizeof(tmp));
+	return mas_prev(&tmp, 0);
+}
+
 static inline bool range_in_vma(struct vm_area_struct *vma,
 				unsigned long start, unsigned long end)
 {
diff --git a/ipc/shm.c b/ipc/shm.c
index e25c7c6106bc..02f4c9413690 100644
--- a/ipc/shm.c
+++ b/ipc/shm.c
@@ -1631,6 +1631,7 @@ long ksys_shmdt(char __user *shmaddr)
 	loff_t size = 0;
 	struct file *file;
 	struct vm_area_struct *next;
 #endif
+	MA_STATE(mas, &mm->mm_mt, addr, addr);
 
 	if (addr & ~PAGE_MASK)
@@ -1660,11 +1661,11 @@ long ksys_shmdt(char __user *shmaddr)
 	 * match the usual checks anyway. So assume all vma's are
 	 * above the starting address given.
 	 */
-	vma = find_vma(mm, addr);
 
 #ifdef CONFIG_MMU
+	vma = mas_find(&mas, ULONG_MAX);
 	while (vma) {
-		next = vma->vm_next;
+		next = mas_find(&mas, ULONG_MAX);
 
 		/*
 		 * Check if the starting address would match, i.e. it's
@@ -1702,22 +1703,21 @@ long ksys_shmdt(char __user *shmaddr)
 	 * prevent overflows and make comparisons vs. equal-width types.
 	 */
 	size = PAGE_ALIGN(size);
-	while (vma && (loff_t)(vma->vm_end - addr) <= size) {
-		next = vma->vm_next;
 
+	mas_for_each(&mas, vma, addr + size) {
 		/* finding a matching vma now does not alter retval */
 		if ((vma->vm_ops == &shm_vm_ops) &&
 		    ((vma->vm_start - addr)/PAGE_SIZE == vma->vm_pgoff) &&
 		    (vma->vm_file == file))
 			do_munmap(mm, vma->vm_start, vma->vm_end - vma->vm_start, NULL);
-		vma = next;
 	}
 
 #else	/* CONFIG_MMU */
+	vma = mas_walk(&mas);
 	/* under NOMMU conditions, the exact address to be destroyed must be
 	 * given
 	 */
 	if (vma && vma->vm_start == addr && vma->vm_ops == &shm_vm_ops) {
 		do_munmap(mm, vma->vm_start, vma->vm_end - vma->vm_start, NULL);
 		retval = 0;
 	}
diff --git a/kernel/acct.c b/kernel/acct.c
index f175df8f6aa4..c1cb3d68948e 100644
--- a/kernel/acct.c
+++ b/kernel/acct.c
@@ -545,7 +545,7 @@ void acct_collect(long exitcode, int group_dead)
 		vma = current->mm->mmap;
 		while (vma) {
 			vsize += vma->vm_end - vma->vm_start;
-			vma = vma->vm_next;
+			vma = vma_next(vma->vm_mm, vma);
 		}
 		mmap_read_unlock(current->mm);
 	}
diff --git a/kernel/events/core.c b/kernel/events/core.c
index dc568ca295bd..767bec30c395 100644
--- a/kernel/events/core.c
+++ b/kernel/events/core.c
@@ -9888,8 +9888,9 @@ static void perf_addr_filter_apply(struct perf_addr_filter *filter,
 				   struct perf_addr_filter_range *fr)
 {
 	struct vm_area_struct *vma;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
-	for (vma = mm->mmap; vma; vma = vma->vm_next) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		if (!vma->vm_file)
 			continue;
 
diff --git a/kernel/events/uprobes.c b/kernel/events/uprobes.c
index 00b0358739ab..b0aa05be2a61 100644
--- a/kernel/events/uprobes.c
+++ b/kernel/events/uprobes.c
@@ -356,8 +356,9 @@ static struct vm_area_struct *
 find_ref_ctr_vma(struct uprobe *uprobe, struct mm_struct *mm)
 {
 	struct vm_area_struct *tmp;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
-	for (tmp = mm->mmap; tmp; tmp = tmp->vm_next)
+	mas_for_each(&mas, tmp, ULONG_MAX)
 		if (valid_ref_ctr_vma(uprobe, tmp))
 			return tmp;
 
@@ -1239,9 +1240,10 @@ static int unapply_uprobe(struct uprobe *uprobe, struct mm_struct *mm)
 {
 	struct vm_area_struct *vma;
 	int err = 0;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	mmap_read_lock(mm);
-	for (vma = mm->mmap; vma; vma = vma->vm_next) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		unsigned long vaddr;
 		loff_t offset;
 
@@ -1990,8 +1992,9 @@ bool uprobe_deny_signal(void)
 static void mmf_recalc_uprobes(struct mm_struct *mm)
 {
 	struct vm_area_struct *vma;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
-	for (vma = mm->mmap; vma; vma = vma->vm_next) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		if (!valid_vma(vma, false))
 			continue;
 		/*
diff --git a/kernel/sched/fair.c b/kernel/sched/fair.c
index ae7ceba8fd4f..710727741538 100644
--- a/kernel/sched/fair.c
+++ b/kernel/sched/fair.c
@@ -2723,6 +2723,7 @@ static void task_numa_work(struct callback_head *work)
 	unsigned long start, end;
 	unsigned long nr_pte_updates = 0;
 	long pages, virtpages;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	SCHED_WARN_ON(p != container_of(work, struct task_struct, numa_work));
 
@@ -2775,13 +2776,17 @@ static void task_numa_work(struct callback_head *work)
 
 	if (!mmap_read_trylock(mm))
 		return;
-	vma = find_vma(mm, start);
+
+	mas_set(&mas, start);
+	vma = mas_find(&mas, ULONG_MAX);
 	if (!vma) {
 		reset_ptenuma_scan(p);
 		start = 0;
-		vma = mm->mmap;
+		mas_set(&mas, start);
+		vma = mas_find(&mas, ULONG_MAX);
 	}
-	for (; vma; vma = vma->vm_next) {
+
+	for (; vma; vma = mas_find(&mas, ULONG_MAX)) {
 		if (!vma_migratable(vma) || !vma_policy_mof(vma) ||
 			is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_MIXEDMAP)) {
 			continue;
diff --git a/kernel/sys.c b/kernel/sys.c
index a730c03ee607..fe061c54a5a4 100644
--- a/kernel/sys.c
+++ b/kernel/sys.c
@@ -1858,9 +1858,10 @@ static int prctl_set_mm_exe_file(struct mm_struct *mm, unsigned int fd)
 	err = -EBUSY;
 	if (exe_file) {
 		struct vm_area_struct *vma;
+		MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 		mmap_read_lock(mm);
-		for (vma = mm->mmap; vma; vma = vma->vm_next) {
+		mas_for_each(&mas, vma, ULONG_MAX) {
 			if (!vma->vm_file)
 				continue;
 			if (path_equal(&vma->vm_file->f_path,
diff --git a/mm/gup.c b/mm/gup.c
index b06e2a6d9018..20392efb2d9b 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -1465,6 +1465,7 @@ int __mm_populate(unsigned long start, unsigned long len, int ignore_errors)
 	struct vm_area_struct *vma = NULL;
 	int locked = 0;
 	long ret = 0;
+	MA_STATE(mas, &mm->mm_mt, start, start);
 
 	end = start + len;
 
@@ -1476,10 +1477,10 @@ int __mm_populate(unsigned long start, unsigned long len, int ignore_errors)
 		if (!locked) {
 			locked = 1;
 			mmap_read_lock(mm);
-			vma = find_vma(mm, nstart);
+			vma = mas_find(&mas, end);
 		} else if (nstart >= vma->vm_end)
-			vma = vma->vm_next;
-		if (!vma || vma->vm_start >= end)
+			vma = mas_next(&mas, end);
+		if (!vma)
 			break;
 		/*
 		 * Set [nstart; nend) to intersection of desired address
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index ec2bb93f7431..c2627619d622 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -2304,12 +2304,12 @@ void vma_adjust_trans_huge(struct vm_area_struct *vma,
 		split_huge_pmd_address(vma, end, false, NULL);
 
 	/*
-	 * If we're also updating the vma->vm_next->vm_start, if the new
-	 * vm_next->vm_start isn't hpage aligned and it could previously
+	 * If we're also updating the next VMA's vm_start, if the new
+	 * vm_start isn't hpage aligned and it could previously
 	 * contain an hugepage: check if we need to split an huge pmd.
 	 */
 	if (adjust_next > 0) {
-		struct vm_area_struct *next = vma->vm_next;
+		struct vm_area_struct *next = vma_next(vma->vm_mm, vma);
 		unsigned long nstart = next->vm_start;
 		nstart += adjust_next;
 		if (nstart & ~HPAGE_PMD_MASK &&
diff --git a/mm/khugepaged.c b/mm/khugepaged.c
index 4e3dff13eb70..b20228f10725 100644
--- a/mm/khugepaged.c
+++ b/mm/khugepaged.c
@@ -2049,6 +2049,7 @@ static unsigned int khugepaged_scan_mm_slot(unsigned int pages,
 	struct mm_struct *mm;
 	struct vm_area_struct *vma;
 	int progress = 0;
+	MA_STATE(mas, NULL, 0, 0);
 
 	VM_BUG_ON(!pages);
 	lockdep_assert_held(&khugepaged_mm_lock);
@@ -2070,13 +2071,15 @@ static unsigned int khugepaged_scan_mm_slot(unsigned int pages,
 	 * the next mm on the list.
 	 */
 	vma = NULL;
+	mas.tree = &mm->mm_mt;
+	mas_set(&mas, khugepaged_scan.address);
 	if (unlikely(!mmap_read_trylock(mm)))
 		goto breakouterloop_mmap_lock;
 	if (likely(!khugepaged_test_exit(mm)))
-		vma = find_vma(mm, khugepaged_scan.address);
+		vma = mas_find(&mas, ULONG_MAX);
 
 	progress++;
-	for (; vma; vma = vma->vm_next) {
+	for (; vma; vma = mas_find(&mas, ULONG_MAX)) {
 		unsigned long hstart, hend;
 
 		cond_resched();
diff --git a/mm/ksm.c b/mm/ksm.c
index 0960750bb316..2e9e685c588f 100644
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -968,6 +968,7 @@ static int unmerge_and_remove_all_rmap_items(void)
 	struct mm_struct *mm;
 	struct vm_area_struct *vma;
 	int err = 0;
+	MA_STATE(mas, NULL, 0, 0);
 
 	spin_lock(&ksm_mmlist_lock);
 	ksm_scan.mm_slot = list_entry(ksm_mm_head.mm_list.next,
@@ -978,7 +979,9 @@ static int unmerge_and_remove_all_rmap_items(void)
 			mm_slot != &ksm_mm_head; mm_slot = ksm_scan.mm_slot) {
 		mm = mm_slot->mm;
 		mmap_read_lock(mm);
-		for (vma = mm->mmap; vma; vma = vma->vm_next) {
+		mas.tree = &mm->mm_mt;
+		mas_set(&mas, 0);
+		mas_for_each(&mas, vma, ULONG_MAX) {
 			if (ksm_test_exit(mm))
 				break;
 			if (!(vma->vm_flags & VM_MERGEABLE) || !vma->anon_vma)
@@ -2229,6 +2230,7 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page)
 	struct vm_area_struct *vma;
 	struct rmap_item *rmap_item;
 	int nid;
+	MA_STATE(mas, NULL, 0, 0);
 
 	if (list_empty(&ksm_mm_head.mm_list))
 		return NULL;
@@ -2289,10 +2291,13 @@ next_mm:
 	mmap_read_lock(mm);
 	if (ksm_test_exit(mm))
 		vma = NULL;
-	else
-		vma = find_vma(mm, ksm_scan.address);
+	else {
+		mas.tree = &mm->mm_mt;
+		mas_set(&mas, ksm_scan.address);
+		vma = mas_find(&mas, ULONG_MAX);
+	}
 
-	for (; vma; vma = vma->vm_next) {
+	for (; vma; vma = mas_find(&mas, ULONG_MAX)) {
 		if (!(vma->vm_flags & VM_MERGEABLE))
 			continue;
 		if (ksm_scan.address < vma->vm_start)
diff --git a/mm/madvise.c b/mm/madvise.c
index 13f5677b9322..0da3c9a5f5b8 100644
--- a/mm/madvise.c
+++ b/mm/madvise.c
@@ -1151,7 +1151,7 @@ int do_madvise(struct mm_struct *mm, unsigned long start, size_t len_in, int beh
 		if (start >= end)
 			goto out;
 		if (prev)
-			vma = prev->vm_next;
+			vma = vma_next(mm, prev);
 		else	/* madvise_remove dropped mmap_lock */
 			vma = find_vma(mm, start);
 	}
diff --git a/mm/memory.c b/mm/memory.c
index c48f8df6e502..3b0fe38f967d 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -4903,8 +4903,8 @@ int __access_remote_vm(struct task_struct *tsk, struct mm_struct *mm,
 			 * Check if this is a VM_IO | VM_PFNMAP VMA, which
 			 * we can access using slightly different code.
 			 */
-			vma = find_vma(mm, addr);
-			if (!vma || vma->vm_start > addr)
+			vma = find_vma_intersection(mm, addr, addr + 1);
+			if (!vma)
 				break;
 			if (vma->vm_ops && vma->vm_ops->access)
 				ret = vma->vm_ops->access(vma, addr, buf,
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 3ca4898f3f24..f16ff55b10c8 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -404,9 +404,10 @@ void mpol_rebind_task(struct task_struct *tsk, const nodemask_t *new)
 void mpol_rebind_mm(struct mm_struct *mm, nodemask_t *new)
 {
 	struct vm_area_struct *vma;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	mmap_write_lock(mm);
-	for (vma = mm->mmap; vma; vma = vma->vm_next)
+	mas_for_each(&mas, vma, ULONG_MAX)
 		mpol_rebind_policy(vma->vm_policy, new);
 	mmap_write_unlock(mm);
 }
@@ -671,7 +672,7 @@ static unsigned long change_prot_numa(struct vm_area_struct *vma,
 static int queue_pages_test_walk(unsigned long start, unsigned long end,
 				struct mm_walk *walk)
 {
-	struct vm_area_struct *vma = walk->vma;
+	struct vm_area_struct *next, *vma = walk->vma;
 	struct queue_pages *qp = walk->private;
 	unsigned long endvma = vma->vm_end;
 	unsigned long flags = qp->flags;
@@ -686,9 +687,10 @@ static int queue_pages_test_walk(unsigned long start, unsigned long end,
 			/* hole at head side of range */
 			return -EFAULT;
 	}
+	next = vma_next(vma->vm_mm, vma);
 	if (!(flags & MPOL_MF_DISCONTIG_OK) &&
 		((vma->vm_end < qp->end) &&
-		(!vma->vm_next || vma->vm_end < vma->vm_next->vm_start)))
+		(!next || vma->vm_end < next->vm_start)))
 		/* hole at middle or tail of range */
 		return -EFAULT;
 
@@ -809,21 +811,22 @@ static int mbind_range(struct mm_struct *mm, unsigned long start,
 	pgoff_t pgoff;
 	unsigned long vmstart;
 	unsigned long vmend;
+	MA_STATE(mas, &mm->mm_mt, start, start);
 
-	vma = find_vma(mm, start);
+	vma = mas_find(&mas, ULONG_MAX);
 	VM_BUG_ON(!vma);
 
-	prev = vma->vm_prev;
+	prev = vma_mas_prev(&mas);
 	if (start > vma->vm_start)
 		prev = vma;
 
-	for (; vma && vma->vm_start < end; prev = vma, vma = next) {
-		next = vma->vm_next;
+	mas_for_each(&mas, vma, end) {
+		next = vma_next(mm, vma);
 		vmstart = max(start, vma->vm_start);
 		vmend   = min(end, vma->vm_end);
 
 		if (mpol_equal(vma_policy(vma), new_pol))
-			continue;
+			goto next;
 
 		pgoff = vma->vm_pgoff +
 			((vmstart - vma->vm_start) >> PAGE_SHIFT);
@@ -832,7 +835,7 @@ static int mbind_range(struct mm_struct *mm, unsigned long start,
 				 new_pol, vma->vm_userfaultfd_ctx);
 		if (prev) {
 			vma = prev;
-			next = vma->vm_next;
+			next = vma_next(mm, vma);
 			if (mpol_equal(vma_policy(vma), new_pol))
 				continue;
 			/* vma_merge() joined vma && vma->next, case 8 */
@@ -852,6 +855,8 @@ static int mbind_range(struct mm_struct *mm, unsigned long start,
 		err = vma_replace_policy(vma, new_pol);
 		if (err)
 			goto out;
+next:
+		prev = vma;
 	}
 
  out:
@@ -1217,13 +1222,12 @@ static struct page *new_page(struct page *page, unsigned long start)
 {
 	struct vm_area_struct *vma;
 	unsigned long address;
+	MA_STATE(mas, &current->mm->mm_mt, start, start);
 
-	vma = find_vma(current->mm, start);
-	while (vma) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		address = page_address_in_vma(page, vma);
 		if (address != -EFAULT)
 			break;
-		vma = vma->vm_next;
 	}
 
 	if (PageHuge(page)) {
diff --git a/mm/mlock.c b/mm/mlock.c
index 884b1216da6a..c5337fbf7139 100644
--- a/mm/mlock.c
+++ b/mm/mlock.c
@@ -591,6 +591,7 @@ static int apply_vma_lock_flags(unsigned long start, size_t len,
 	unsigned long nstart, end, tmp;
 	struct vm_area_struct * vma, * prev;
 	int error;
+	MA_STATE(mas, &current->mm->mm_mt, start, start);
 
 	VM_BUG_ON(offset_in_page(start));
 	VM_BUG_ON(len != PAGE_ALIGN(len));
@@ -599,11 +600,11 @@ static int apply_vma_lock_flags(unsigned long start, size_t len,
 		return -EINVAL;
 	if (end == start)
 		return 0;
-	vma = find_vma(current->mm, start);
-	if (!vma || vma->vm_start > start)
+	vma = mas_walk(&mas);
+	if (!vma)
 		return -ENOMEM;
 
-	prev = vma->vm_prev;
+	prev = mas_prev(&mas, 0);
 	if (start > vma->vm_start)
 		prev = vma;
 
@@ -625,7 +626,7 @@ static int apply_vma_lock_flags(unsigned long start, size_t len,
 		if (nstart >= end)
 			break;
 
-		vma = prev->vm_next;
+		vma = vma_next(prev->vm_mm, prev);
 		if (!vma || vma->vm_start != nstart) {
 			error = -ENOMEM;
 			break;
@@ -646,15 +647,17 @@ static unsigned long count_mm_mlocked_page_nr(struct mm_struct *mm,
 {
 	struct vm_area_struct *vma;
 	unsigned long count = 0;
+	MA_STATE(mas, NULL, start, start);
 
 	if (mm == NULL)
 		mm = current->mm;
 
-	vma = find_vma(mm, start);
+	mas.tree = &mm->mm_mt;
+	vma = mas_find(&mas, ULONG_MAX);
 	if (vma == NULL)
 		vma = mm->mmap;
 
-	for (; vma ; vma = vma->vm_next) {
+	for (; vma; vma = mas_find(&mas, ULONG_MAX)) {
 		if (start >= vma->vm_end)
 			continue;
 		if (start + len <=  vma->vm_start)
@@ -787,7 +789,7 @@ static int apply_mlockall_flags(int flags)
 			to_add |= VM_LOCKONFAULT;
 	}
 
-	for (vma = current->mm->mmap; vma ; vma = prev->vm_next) {
+	for (vma = current->mm->mmap; vma ; vma = vma_next(prev->vm_mm, prev)) {
 		vm_flags_t newflags;
 
 		newflags = vma->vm_flags & VM_LOCKED_CLEAR_MASK;
diff --git a/mm/mmap.c b/mm/mmap.c
index 3f7479e888e6..0b4680ef19c5 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -464,24 +464,6 @@ static bool range_has_overlap(struct mm_struct *mm, unsigned long start,
 	return existing ? true : false;
 }
 
-/*
- * vma_next() - Get the next VMA.
- * @mm: The mm_struct.
- * @vma: The current vma.
- *
- * If @vma is NULL, return the first vma in the mm.
- *
- * Returns: The next VMA after @vma.
- */
-static inline struct vm_area_struct *vma_next(struct mm_struct *mm,
-					 struct vm_area_struct *vma)
-{
-	if (!vma)
-		return mm->mmap;
-
-	return vma->vm_next;
-}
-
 static unsigned long count_vma_pages_range(struct mm_struct *mm,
 		unsigned long addr, unsigned long end)
 {
@@ -1096,6 +1078,24 @@ can_vma_merge_after(struct vm_area_struct *vma, unsigned long vm_flags,
 	return 0;
 }
 
+/*
+ * vma_next_wrap() - Get the next VMA, or the first VMA if @vma is NULL.
+ * @mm: The mm_struct.
+ * @vma: The current vma.
+ *
+ * If @vma is NULL, return the first vma in the mm.
+ *
+ * Returns: The next VMA after @vma.
+ */
+static inline struct vm_area_struct *vma_next_wrap(struct mm_struct *mm,
+					struct vm_area_struct *vma)
+{
+	if (!vma)
+		return mm->mmap;
+
+	return vma_next(mm, vma);
+}
+
 /*
  * Given a mapping request (addr,end,vm_flags,file,pgoff), figure out
  * whether that can be merged with its predecessor or its successor.
@@ -1157,7 +1157,7 @@ struct vm_area_struct *vma_merge(struct mm_struct *mm,
 	if (vm_flags & VM_SPECIAL)
 		return NULL;
 
-	next = vma_next(mm, prev);
+	next = vma_next_wrap(mm, prev);
 	area = next;
 	if (area && area->vm_end == end)		/* cases 6, 7, 8 */
 		next = next->vm_next;
diff --git a/mm/mprotect.c b/mm/mprotect.c
index 56c02beb6041..5492833af73b 100644
--- a/mm/mprotect.c
+++ b/mm/mprotect.c
@@ -518,6 +518,7 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
 	const int grows = prot & (PROT_GROWSDOWN|PROT_GROWSUP);
 	const bool rier = (current->personality & READ_IMPLIES_EXEC) &&
 				(prot & PROT_READ);
+	MA_STATE(mas, &current->mm->mm_mt, start, start);
 
 	start = untagged_addr(start);
 
@@ -549,11 +550,12 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
 	if ((pkey != -1) && !mm_pkey_is_allocated(current->mm, pkey))
 		goto out;
 
-	vma = find_vma(current->mm, start);
+	mas_set(&mas, start);
+	vma = mas_walk(&mas);
 	error = -ENOMEM;
 	if (!vma)
 		goto out;
-	prev = vma->vm_prev;
+	prev = vma_prev(vma->vm_mm, vma);
 	if (unlikely(grows & PROT_GROWSDOWN)) {
 		if (vma->vm_start >= end)
 			goto out;
@@ -626,7 +627,7 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
 		if (nstart >= end)
 			goto out;
 
-		vma = prev->vm_next;
+		vma = mas_next(&mas, ULONG_MAX);
 		if (!vma || vma->vm_start != nstart) {
 			error = -ENOMEM;
 			goto out;
diff --git a/mm/mremap.c b/mm/mremap.c
index a7526a8c1fe5..3b6e7f032463 100644
--- a/mm/mremap.c
+++ b/mm/mremap.c
@@ -465,7 +465,7 @@ out:
 	if (excess) {
 		vma->vm_flags |= VM_ACCOUNT;
 		if (split)
-			vma->vm_next->vm_flags |= VM_ACCOUNT;
+			vma_next(mm, vma)->vm_flags |= VM_ACCOUNT;
 	}
 
 	return new_addr;
@@ -638,9 +638,11 @@ out:
 static int vma_expandable(struct vm_area_struct *vma, unsigned long delta)
 {
 	unsigned long end = vma->vm_end + delta;
+	struct vm_area_struct *next;
 	if (end < vma->vm_end) /* overflow */
 		return 0;
-	if (vma->vm_next && vma->vm_next->vm_start < end) /* intersection */
+	next = vma_next(vma->vm_mm, vma);
+	if (next && next->vm_start < end) /* intersection */
 		return 0;
 	if (get_unmapped_area(NULL, vma->vm_start, end - vma->vm_start,
 			      0, MAP_FIXED) & ~PAGE_MASK)
diff --git a/mm/msync.c b/mm/msync.c
index 69c6d2029531..8100ad5b12eb 100644
--- a/mm/msync.c
+++ b/mm/msync.c
@@ -100,7 +100,7 @@ SYSCALL_DEFINE3(msync, unsigned long, start, size_t, len, int, flags)
 				error = 0;
 				goto out_unlock;
 			}
-			vma = vma->vm_next;
+			vma = vma_next(mm, vma);
 		}
 	}
 out_unlock:
diff --git a/mm/nommu.c b/mm/nommu.c
index 0faf39b32cdb..4e5cc63728b6 100644
--- a/mm/nommu.c
+++ b/mm/nommu.c
@@ -677,6 +677,7 @@ static void delete_vma(struct mm_struct *mm, struct vm_area_struct *vma)
 struct vm_area_struct *find_vma(struct mm_struct *mm, unsigned long addr)
 {
 	struct vm_area_struct *vma;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	/* check the cache first */
 	vma = vmacache_find(mm, addr);
@@ -685,7 +686,7 @@ struct vm_area_struct *find_vma(struct mm_struct *mm, unsigned long addr)
 
 	/* trawl the list (there may be multiple mappings in which addr
 	 * resides) */
-	for (vma = mm->mmap; vma; vma = vma->vm_next) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		if (vma->vm_start > addr)
 			return NULL;
 		if (vma->vm_end > addr) {
@@ -726,6 +727,7 @@ static struct vm_area_struct *find_vma_exact(struct mm_struct *mm,
 {
 	struct vm_area_struct *vma;
 	unsigned long end = addr + len;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	/* check the cache first */
 	vma = vmacache_find_exact(mm, addr, end);
@@ -734,7 +736,7 @@ static struct vm_area_struct *find_vma_exact(struct mm_struct *mm,
 
 	/* trawl the list (there may be multiple mappings in which addr
 	 * resides) */
-	for (vma = mm->mmap; vma; vma = vma->vm_next) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		if (vma->vm_start < addr)
 			continue;
 		if (vma->vm_start > addr)
@@ -1485,7 +1487,7 @@ int do_munmap(struct mm_struct *mm, unsigned long start, size_t len, struct list
 				return -EINVAL;
 			if (end == vma->vm_end)
 				goto erase_whole_vma;
-			vma = vma->vm_next;
+			vma = vma_next(mm, vma);
 		} while (vma);
 		return -EINVAL;
 	} else {
@@ -1543,7 +1545,7 @@ void exit_mmap(struct mm_struct *mm)
 	mm->total_vm = 0;
 
 	while ((vma = mm->mmap)) {
-		mm->mmap = vma->vm_next;
+		mm->mmap = vma_next(mm, vma);
 		delete_vma_from_mm(vma);
 		delete_vma(mm, vma);
 		cond_resched();
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index 8b84661a6410..a5e2045ec276 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -514,6 +514,7 @@ bool __oom_reap_task_mm(struct mm_struct *mm)
 {
 	struct vm_area_struct *vma;
 	bool ret = true;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	/*
 	 * Tell all users of get_user/copy_from_user etc... that the content
@@ -523,7 +524,7 @@ bool __oom_reap_task_mm(struct mm_struct *mm)
 	 */
 	set_bit(MMF_UNSTABLE, &mm->flags);
 
-	for (vma = mm->mmap ; vma; vma = vma->vm_next) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		if (!can_madv_lru_vma(vma))
 			continue;
 
diff --git a/mm/pagewalk.c b/mm/pagewalk.c
index e81640d9f177..20bd8d14d042 100644
--- a/mm/pagewalk.c
+++ b/mm/pagewalk.c
@@ -408,7 +408,7 @@ int walk_page_range(struct mm_struct *mm, unsigned long start,
 		} else { /* inside vma */
 			walk.vma = vma;
 			next = min(end, vma->vm_end);
-			vma = vma->vm_next;
+			vma = vma_next(mm, vma);
 
 			err = walk_page_test(start, next, &walk);
 			if (err > 0) {
diff --git a/mm/swapfile.c b/mm/swapfile.c
index d58361109066..10e1d60ff6f8 100644
--- a/mm/swapfile.c
+++ b/mm/swapfile.c
@@ -2104,9 +2104,10 @@ static int unuse_mm(struct mm_struct *mm, unsigned int type,
 {
 	struct vm_area_struct *vma;
 	int ret = 0;
+	MA_STATE(mas, &mm->mm_mt, 0, 0);
 
 	mmap_read_lock(mm);
-	for (vma = mm->mmap; vma; vma = vma->vm_next) {
+	mas_for_each(&mas, vma, ULONG_MAX) {
 		if (vma->anon_vma) {
 			ret = unuse_vma(vma, type, frontswap,
 					fs_pages_to_unuse);
-- 
2.49.0