www.infradead.org Git - users/jedix/linux-maple.git/commitdiff
mm: WIP linked list process
author Liam R. Howlett <Liam.Howlett@Oracle.com>
Tue, 15 Dec 2020 18:57:11 +0000 (13:57 -0500)
committer Liam R. Howlett <Liam.Howlett@Oracle.com>
Tue, 5 Jan 2021 17:33:39 +0000 (12:33 -0500)
Signed-off-by: Liam R. Howlett <Liam.Howlett@Oracle.com>
include/linux/mm.h
mm/internal.h
mm/memory.c
mm/mmap.c

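The series converts VMA walks from the mm->mmap linked list (vm_next/vm_prev) over to maple tree iteration through a struct ma_state. A minimal sketch of that conversion, not taken from this commit (count_vmas() is a hypothetical helper; the maple tree calls are the ones the patch itself uses):

#include <linux/mm.h>

/* Hypothetical helper, illustration only: count the VMAs of an mm by
 * walking the maple tree instead of the old vm_next list.  The caller
 * is assumed to hold mmap_lock.
 */
static unsigned long count_vmas(struct mm_struct *mm)
{
	struct vm_area_struct *vma;
	unsigned long count = 0;
	MA_STATE(mas, &mm->mm_mt, 0, 0);

	/* Old style: for (vma = mm->mmap; vma; vma = vma->vm_next) count++; */
	mas_for_each(&mas, vma, ULONG_MAX)
		count++;

	return count;
}
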
index 6a5e4fcbadec36714ac9954c30009ed7e1fbb147..68eb6204fa38bb3c4b1e0dca7789543209a6467d 100644 (file)
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1648,6 +1648,8 @@ void zap_page_range(struct vm_area_struct *vma, unsigned long address,
                    unsigned long size);
 void unmap_vmas(struct mmu_gather *tlb, struct vm_area_struct *start_vma,
                unsigned long start, unsigned long end);
+void unmap_vmas_mt(struct mmu_gather *tlb, struct vm_area_struct *start_vma,
+               struct ma_state *mas, unsigned long start, unsigned long end);
 
 struct mmu_notifier_range;
 
index c43ccdddb0f6e92e712dfd2dafe144018e7a4440..4cfb9b23ddef8372515eefb9335d17e6ee8a32c3 100644 (file)
--- a/mm/internal.h
+++ b/mm/internal.h
@@ -39,6 +39,10 @@ vm_fault_t do_swap_page(struct vm_fault *vmf);
 void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *start_vma,
                unsigned long floor, unsigned long ceiling);
 
+void free_mt_pgtables(struct mmu_gather *tlb, struct vm_area_struct *start_vma,
+                     struct ma_state *mas,
+               unsigned long floor, unsigned long ceiling);
+
 static inline bool can_madv_lru_vma(struct vm_area_struct *vma)
 {
        return !(vma->vm_flags & (VM_LOCKED|VM_HUGETLB|VM_PFNMAP));
index 3b0fe38f967da7832aa9affaa1ce7e907351c270..4316611a6681018f2b054fdac4f8ebc975b60312 100644 (file)
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -387,6 +387,43 @@ void free_pgd_range(struct mmu_gather *tlb,
        } while (pgd++, addr = next, addr != end);
 }
 
+void free_mt_pgtables(struct mmu_gather *tlb, struct vm_area_struct *vma,
+       struct ma_state *mas, unsigned long floor, unsigned long ceiling)
+{
+       struct ma_state ma_next = *mas;
+
+       mas_find(&ma_next, ceiling - 1);
+       mas_for_each(mas, vma, ceiling - 1) {
+               struct vm_area_struct *next = mas_find(&ma_next, ceiling - 1);
+               unsigned long addr = vma->vm_start;
+
+               /*
+                * Hide vma from rmap and truncate_pagecache before freeing
+                * pgtables
+                */
+               unlink_anon_vmas(vma);
+               unlink_file_vma(vma);
+
+               if (is_vm_hugetlb_page(vma)) {
+                       hugetlb_free_pgd_range(tlb, addr, vma->vm_end,
+                               floor, next ? next->vm_start : ceiling);
+               } else {
+                       /*
+                        * Optimization: gather nearby vmas into one call down
+                        */
+                       while (next && next->vm_start <= vma->vm_end + PMD_SIZE
+                              && !is_vm_hugetlb_page(next)) {
+                               next = mas_find(&ma_next, ceiling - 1);
+                               vma = mas_find(mas, ceiling - 1);
+                               unlink_anon_vmas(vma);
+                               unlink_file_vma(vma);
+                       }
+                       free_pgd_range(tlb, addr, vma->vm_end,
+                               floor, next ? next->vm_start : ceiling);
+               }
+       }
+}
+
 void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *vma,
                unsigned long floor, unsigned long ceiling)
 {
@@ -1474,6 +1511,19 @@ static void unmap_single_vma(struct mmu_gather *tlb,
        }
 }
 
+void unmap_vmas_mt(struct mmu_gather *tlb,
+               struct vm_area_struct *vma, struct ma_state *mas,
+               unsigned long start_addr, unsigned long end_addr)
+{
+       struct mmu_notifier_range range;
+
+       mmu_notifier_range_init(&range, MMU_NOTIFY_UNMAP, 0, vma, vma->vm_mm,
+                               start_addr, end_addr);
+       mmu_notifier_invalidate_range_start(&range);
+       mas_for_each(mas, vma, end_addr - 1)
+               unmap_single_vma(tlb, vma, start_addr, end_addr, NULL);
+       mmu_notifier_invalidate_range_end(&range);
+}
 /**
  * unmap_vmas - unmap a range of memory covered by a list of vma's
  * @tlb: address of the caller's struct mmu_gather
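The two helpers added to mm/memory.c above take a struct ma_state so the caller decides where the walk starts and stops. Below is a hedged sketch of driving them together; example_unmap() is a made-up name, and the reworked unmap_region() in mm/mmap.c further down still pairs unmap_vmas_mt() with the old free_pgtables(), so this only shows how free_mt_pgtables() could slot in. @mas must walk a tree that still contains the VMAs in [start, end); in do_mas_align_munmap() that is the on-stack mt_detach tree rather than mm->mm_mt.

#include <linux/mm.h>
#include "internal.h"		/* free_mt_pgtables() */

static void example_unmap(struct mm_struct *mm, struct vm_area_struct *vma,
			  struct ma_state *mas, unsigned long start,
			  unsigned long end, struct vm_area_struct *prev,
			  unsigned long max)
{
	struct mmu_gather tlb;

	lru_add_drain();
	tlb_gather_mmu(&tlb, mm, start, end);
	update_hiwater_rss(mm);
	unmap_vmas_mt(&tlb, vma, mas, start, end);	/* zap the pages */
	mas_set(mas, start);		/* rewind for the page-table pass */
	free_mt_pgtables(&tlb, vma, mas,
			 prev ? prev->vm_end : FIRST_USER_ADDRESS, max);
	tlb_finish_mmu(&tlb, start, end);
}
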
index 0b4680ef19c57f7e1be4dab1130a3a6db9457f37..b17ad5b6581864d9e597f44fa140a6de5b6c70fe 100644 (file)
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -73,10 +73,6 @@ int mmap_rnd_compat_bits __read_mostly = CONFIG_ARCH_MMAP_RND_COMPAT_BITS;
 static bool ignore_rlimit_data;
 core_param(ignore_rlimit_data, ignore_rlimit_data, bool, 0644);
 
-static void unmap_region(struct mm_struct *mm,
-               struct vm_area_struct *vma, struct vm_area_struct *prev,
-               unsigned long start, unsigned long end);
-
 /* description of effects of mapping type and prot in current implementation.
  * this is due to the limited x86 page protection hardware.  The expected
  * behavior is in parens:
@@ -168,10 +164,8 @@ void unlink_file_vma(struct vm_area_struct *vma)
 /*
  * Close a vm structure and free it, returning the next.
  */
-static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
+static void remove_vma(struct vm_area_struct *vma)
 {
-       struct vm_area_struct *next = vma->vm_next;
-
        might_sleep();
        if (vma->vm_ops && vma->vm_ops->close)
                vma->vm_ops->close(vma);
@@ -179,12 +173,11 @@ static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
                fput(vma->vm_file);
        mpol_put(vma_policy(vma));
        vm_area_free(vma);
-       return next;
 }
 
 static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
                         unsigned long newbrk, unsigned long oldbrk,
-                        struct list_head *uf);
+                        struct list_head *uf, unsigned long max);
 static int do_brk_flags(struct ma_state *mas, struct vm_area_struct **brkvma,
                        unsigned long addr, unsigned long request,
                        unsigned long flags);
@@ -198,6 +191,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
        bool downgraded = false;
        LIST_HEAD(uf);
        MA_STATE(mas, &mm->mm_mt, 0, 0);
+       struct ma_state ma_next;
 
        if (mmap_write_lock_killable(mm))
                return -EINTR;
@@ -239,6 +233,8 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
 
        mas_set(&mas, newbrk);
        brkvma = mas_walk(&mas);
+       ma_next = mas;
+       next = mas_next(&ma_next, newbrk + PAGE_SIZE + stack_guard_gap);
        if (brkvma) { // munmap necessary, there is something at newbrk.
                /*
                 * Always allow shrinking brk.
@@ -255,7 +251,8 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
                 */
                mm->brk = brk;
                mas.last = oldbrk - 1;
-               ret = do_brk_munmap(&mas, brkvma, newbrk, oldbrk, &uf);
+               ret = do_brk_munmap(&mas, brkvma, newbrk, oldbrk, &uf,
+                           next ? next->vm_start : USER_PGTABLES_CEILING);
                if (ret == 1)  {
                        downgraded = true;
                        goto success;
@@ -267,7 +264,6 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
        }
        /* Only check if the next VMA is within the stack_guard_gap of the
         * expansion area */
-       next = mas_next(&mas, newbrk + PAGE_SIZE + stack_guard_gap);
        /* Check against existing mmap mappings. */
        if (next && newbrk + PAGE_SIZE > vm_start_gap(next))
                goto out;
@@ -303,83 +299,16 @@ extern void mt_validate(struct maple_tree *mt);
 extern void mt_dump(const struct maple_tree *mt);
 
 /* Validate the maple tree */
-static void validate_mm_mt(struct mm_struct *mm)
-{
-       struct maple_tree *mt = &mm->mm_mt;
-       struct vm_area_struct *vma_mt, *vma = mm->mmap;
-
-       MA_STATE(mas, mt, 0, 0);
-       rcu_read_lock();
-       mas_for_each(&mas, vma_mt, ULONG_MAX) {
-               if (xa_is_zero(vma_mt))
-                       continue;
-
-               if (!vma)
-                       break;
-
-               if ((vma != vma_mt) ||
-                   (vma->vm_start != vma_mt->vm_start) ||
-                   (vma->vm_end != vma_mt->vm_end) ||
-                   (vma->vm_start != mas.index) ||
-                   (vma->vm_end - 1 != mas.last)) {
-                       pr_emerg("issue in %s\n", current->comm);
-                       dump_stack();
-#ifdef CONFIG_DEBUG_VM
-                       dump_vma(vma_mt);
-                       pr_emerg("and vm_next\n");
-                       dump_vma(vma->vm_next);
-#endif // CONFIG_DEBUG_VM
-                       pr_emerg("mt piv: %px %lu - %lu\n", vma_mt,
-                                mas.index, mas.last);
-                       pr_emerg("mt vma: %px %lu - %lu\n", vma_mt,
-                                vma_mt->vm_start, vma_mt->vm_end);
-                       if (vma->vm_prev) {
-                               pr_emerg("ll prev: %px %lu - %lu\n",
-                                        vma->vm_prev, vma->vm_prev->vm_start,
-                                        vma->vm_prev->vm_end);
-                       }
-                       pr_emerg("ll vma: %px %lu - %lu\n", vma,
-                                vma->vm_start, vma->vm_end);
-                       if (vma->vm_next) {
-                               pr_emerg("ll next: %px %lu - %lu\n",
-                                        vma->vm_next, vma->vm_next->vm_start,
-                                        vma->vm_next->vm_end);
-                       }
-
-                       mt_dump(mas.tree);
-                       if (vma_mt->vm_end != mas.last + 1) {
-                               pr_err("vma: %px vma_mt %lu-%lu\tmt %lu-%lu\n",
-                                               mm, vma_mt->vm_start, vma_mt->vm_end,
-                                               mas.index, mas.last);
-                               mt_dump(mas.tree);
-                       }
-                       VM_BUG_ON_MM(vma_mt->vm_end != mas.last + 1, mm);
-                       if (vma_mt->vm_start != mas.index) {
-                               pr_err("vma: %px vma_mt %px %lu - %lu doesn't match\n",
-                                               mm, vma_mt, vma_mt->vm_start, vma_mt->vm_end);
-                               mt_dump(mas.tree);
-                       }
-                       VM_BUG_ON_MM(vma_mt->vm_start != mas.index, mm);
-               }
-               VM_BUG_ON(vma != vma_mt);
-               vma = vma->vm_next;
-
-       }
-       VM_BUG_ON(vma);
-
-       rcu_read_unlock();
-       mt_validate(&mm->mm_mt);
-}
 static void validate_mm(struct mm_struct *mm)
 {
        int bug = 0;
        int i = 0;
        unsigned long highest_address = 0;
-       struct vm_area_struct *vma = mm->mmap;
+       struct vm_area_struct *vma;
+       MA_STATE(mas, &mm->mm_mt, 0, 0);
 
-       validate_mm_mt(mm);
 
-       while (vma) {
+       mas_for_each(&mas, vma, ULONG_MAX) {
 #ifdef CONFIG_DEBUG_VM_RB
                struct anon_vma *anon_vma = vma->anon_vma;
                struct anon_vma_chain *avc;
@@ -391,11 +320,10 @@ static void validate_mm(struct mm_struct *mm)
                }
 #endif
                highest_address = vm_end_gap(vma);
-               vma = vma->vm_next;
                i++;
        }
        if (i != mm->map_count) {
-               pr_emerg("map_count %d vm_next %d\n", mm->map_count, i);
+               pr_emerg("map_count %d mas_for_each %d\n", mm->map_count, i);
                bug = 1;
        }
        if (highest_address != mm->highest_vm_end) {
@@ -406,7 +334,6 @@ static void validate_mm(struct mm_struct *mm)
        VM_BUG_ON_MM(bug, mm);
 }
 #else // !CONFIG_DEBUG_MAPLE_TREE
-#define validate_mm_mt(root) do { } while (0)
 #define validate_mm(mm) do { } while (0)
 #endif // CONFIG_DEBUG_MAPLE_TREE
 
@@ -724,7 +651,7 @@ int __vma_adjust(struct vm_area_struct *vma, unsigned long start,
        struct vm_area_struct *expand)
 {
        struct mm_struct *mm = vma->vm_mm;
-       struct vm_area_struct *next = vma->vm_next, *orig_vma = vma;
+       struct vm_area_struct *next = vma_next(mm, vma), *orig_vma = vma;
        struct address_space *mapping = NULL;
        struct rb_root_cached *root = NULL;
        struct anon_vma *anon_vma = NULL;
@@ -765,7 +692,7 @@ int __vma_adjust(struct vm_area_struct *vma, unsigned long start,
                                 */
                                remove_next = 1 + (end > next->vm_end);
                                VM_WARN_ON(remove_next == 2 &&
-                                          end != next->vm_next->vm_end);
+                                          end != vma_next(mm, next)->vm_end);
                                /* trim end to next, for case 6 first pass */
                                end = next->vm_end;
                        }
@@ -778,7 +705,7 @@ int __vma_adjust(struct vm_area_struct *vma, unsigned long start,
                         * next, if the vma overlaps with it.
                         */
                        if (remove_next == 2 && !next->anon_vma)
-                               exporter = next->vm_next;
+                               exporter = vma_next(mm, next);
 
                } else if (end > next->vm_start) {
                        /*
@@ -941,7 +868,7 @@ again:
                         * "next->vm_prev->vm_end" changed and the
                         * "vma->vm_next" gap must be updated.
                         */
-                       next = vma->vm_next;
+                       next = vma_next(mm, vma);
                } else {
                        /*
                         * For the scope of the comment "next" and
@@ -1160,7 +1087,7 @@ struct vm_area_struct *vma_merge(struct mm_struct *mm,
        next = vma_next_wrap(mm, prev);
        area = next;
        if (area && area->vm_end == end)                /* cases 6, 7, 8 */
-               next = next->vm_next;
+               next = vma_next(mm, next);
 
        /* verify some invariant that must be enforced by the caller */
        VM_WARN_ON(prev && addr <= prev->vm_start);
@@ -1295,17 +1222,20 @@ static struct anon_vma *reusable_anon_vma(struct vm_area_struct *old, struct vm_
 struct anon_vma *find_mergeable_anon_vma(struct vm_area_struct *vma)
 {
        struct anon_vma *anon_vma = NULL;
+       struct vm_area_struct *next, *prev;
 
        /* Try next first. */
-       if (vma->vm_next) {
-               anon_vma = reusable_anon_vma(vma->vm_next, vma, vma->vm_next);
+       next = vma_next(vma->vm_mm, vma);
+       if (next) {
+               anon_vma = reusable_anon_vma(next, vma, next);
                if (anon_vma)
                        return anon_vma;
        }
 
        /* Try prev next. */
-       if (vma->vm_prev)
-               anon_vma = reusable_anon_vma(vma->vm_prev, vma->vm_prev, vma);
+       prev = vma_prev(vma->vm_mm, vma);
+       if (prev)
+               anon_vma = reusable_anon_vma(prev, prev, vma);
 
        /*
         * We might reach here with anon_vma == NULL if we can't find
@@ -2076,7 +2006,7 @@ int expand_upwards(struct vm_area_struct *vma, unsigned long address)
        if (gap_addr < address || gap_addr > TASK_SIZE)
                gap_addr = TASK_SIZE;
 
-       next = vma->vm_next;
+       next = vma_next(mm, vma);
        if (next && next->vm_start < gap_addr && vma_is_accessible(next)) {
                if (!(next->vm_flags & VM_GROWSUP))
                        return -ENOMEM;
@@ -2122,7 +2052,7 @@ int expand_upwards(struct vm_area_struct *vma, unsigned long address)
                                vma->vm_end = address;
                                vma_mt_store(mm, vma);
                                anon_vma_interval_tree_post_update_vma(vma);
-                               if (!vma->vm_next)
+                               if (!vma_next(mm, vma))
                                        mm->highest_vm_end = vm_end_gap(vma);
                                spin_unlock(&mm->page_table_lock);
 
@@ -2288,20 +2218,20 @@ EXPORT_SYMBOL_GPL(find_extend_vma);
  *
  * Called with the mm semaphore held.
  */
-static inline void remove_vma_list(struct mm_struct *mm,
-                                  struct vm_area_struct *vma)
+static inline void remove_mt(struct mm_struct *mm, struct ma_state *mas)
 {
+       struct vm_area_struct *vma;
        unsigned long nr_accounted = 0;
 
        /* Update high watermark before we lower total_vm */
        update_hiwater_vm(mm);
-       do {
+       mas_for_each(mas, vma, ULONG_MAX) {
                long nrpages = vma_pages(vma);
 
                if (vma->vm_flags & VM_ACCOUNT)
                        nr_accounted += nrpages;
                vm_stat_account(mm, vma->vm_flags, -nrpages);
-               vma = remove_vma(vma);
+               remove_vma(vma);
-       } while (vma);
+       }
        vm_unacct_memory(nr_accounted);
        validate_mm(mm);
@@ -2313,21 +2243,19 @@ static inline void remove_vma_list(struct mm_struct *mm,
  * Called with the mm semaphore held.
  */
 static void unmap_region(struct mm_struct *mm,
-               struct vm_area_struct *vma, struct vm_area_struct *prev,
-               unsigned long start, unsigned long end)
+                    struct vm_area_struct *vma, struct ma_state *mas,
+                    unsigned long start, unsigned long end,
+                    struct vm_area_struct *prev, unsigned long max)
 {
-       struct vm_area_struct *next = vma_next(mm, prev);
        struct mmu_gather tlb;
 
        lru_add_drain();
        tlb_gather_mmu(&tlb, mm, start, end);
        update_hiwater_rss(mm);
-       unmap_vmas(&tlb, vma, start, end);
-       free_pgtables(&tlb, vma, prev ? prev->vm_end : FIRST_USER_ADDRESS,
-                                next ? next->vm_start : USER_PGTABLES_CEILING);
+       unmap_vmas_mt(&tlb, vma, mas, start, end);
+       free_pgtables(&tlb, vma, prev ? prev->vm_end : FIRST_USER_ADDRESS, max);
        tlb_finish_mmu(&tlb, start, end);
 }
-
 /*
  * __split_vma() bypasses sysctl_max_map_count checking.  We use this where it
  * has already been checked or doesn't make sense to fail.
@@ -2337,7 +2265,6 @@ int __split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
 {
        struct vm_area_struct *new;
        int err;
-       validate_mm_mt(mm);
 
        if (vma->vm_ops && vma->vm_ops->split) {
                err = vma->vm_ops->split(vma, addr);
@@ -2390,7 +2317,6 @@ int __split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
        mpol_put(vma_policy(new));
  out_free_vma:
        vm_area_free(new);
-       validate_mm_mt(mm);
        return err;
 }
 
@@ -2407,25 +2333,44 @@ int split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
        return __split_vma(mm, vma, addr, new_below);
 }
 
-static inline int unlock_range(struct vm_area_struct *start,
-                              struct vm_area_struct **tail, unsigned long limit)
+static inline unsigned long detach_range(struct mm_struct *mm,
+                struct ma_state *src, struct ma_state *dst,
+                struct vm_area_struct *prev, struct vm_area_struct **last)
 {
-       struct mm_struct *mm = start->vm_mm;
-       struct vm_area_struct *tmp = start;
+       struct vm_area_struct *tmp;
        int count = 0;
+       struct ma_state mas;
 
-       while (tmp && tmp->vm_start < limit) {
-               *tail = tmp;
+       /*
+        * unlock any mlock()ed ranges before detaching vmas, count the number
+        * of VMAs to be dropped, and return the tail entry of the affected
+        * area.
+        */
+       mas = *src;
+       mas_set(&mas, src->index);
+       mas_for_each(&mas, tmp, src->last) {
+               *last = tmp;
                count++;
                if (tmp->vm_flags & VM_LOCKED) {
                        mm->locked_vm -= vma_pages(tmp);
                        munlock_vma_pages_all(tmp);
                }
+               vma_mas_store(tmp, dst);
+       }
 
-               tmp = tmp->vm_next;
+       /* Decrement map_count */
+       mm->map_count -= count;
+       /* Find the one after the series before overwrite */
+       tmp = mas_find(&mas, ULONG_MAX);
+       /* Drop removed area from the tree */
+       mas_store_gfp(src, NULL, GFP_KERNEL);
+       /* Set the upper limit */
+       if (!tmp) {
+               mm->highest_vm_end = prev ? vm_end_gap(prev) : 0;
+               return mm->highest_vm_end;
        }
 
-       return count;
+       return tmp->vm_start;
 }
 
 /* do_mas_align_munmap() - munmap the aligned region from @start to @end.
@@ -2445,6 +2390,9 @@ int do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
                        unsigned long end, struct list_head *uf, bool downgrade)
 {
        struct vm_area_struct *prev, *last;
+       struct maple_tree mt_detach = MTREE_INIT(mt_detach, MAPLE_ALLOC_RANGE);
+       unsigned long max;
+       MA_STATE(dst, &mt_detach, start, start);
        /* we have start < vma->vm_end  */
 
        /*
@@ -2506,15 +2454,10 @@ int do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
                        return error;
        }
 
-       /*
-        * unlock any mlock()ed ranges before detaching vmas, count the number
-        * of VMAs to be dropped, and return the tail entry of the affected
-        * area.
-        */
-       mm->map_count -= unlock_range(vma, &last, end);
-       /* Drop removed area from the tree */
-       mas_store_gfp(mas, NULL, GFP_KERNEL);
+       /* Point of no return */
+       max = detach_range(mm, mas, &dst, prev, &last);
 
+#if 1
        /* Detach vmas from the MM linked list */
        vma->vm_prev = NULL;
        if (prev)
@@ -2525,9 +2468,8 @@ int do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
        if (last->vm_next) {
                last->vm_next->vm_prev = prev;
                last->vm_next = NULL;
-       } else
-               mm->highest_vm_end = prev ? vm_end_gap(prev) : 0;
-
+       }
+#endif
        /*
         * Do not downgrade mmap_lock if we are next to VM_GROWSDOWN or
         * VM_GROWSUP VMA. Such VMAs can change their size under
@@ -2542,10 +2484,14 @@ int do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
                        mmap_write_downgrade(mm);
        }
 
-       unmap_region(mm, vma, prev, start, end);
+       mas_reset(&dst);
+       mas_set(&dst, start);
+       unmap_region(mm, vma, &dst, start, end, prev, max);
 
        /* Fix up all other VM information */
-       remove_vma_list(mm, vma);
+       mas_reset(&dst);
+       mas_set(&dst, start);
+       remove_mt(mm, &dst);
 
        return downgrade ? 1 : 0;
 }
@@ -2613,6 +2559,7 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
        unsigned long charged = 0;
        unsigned long end = addr + len;
        unsigned long merge_start = addr, merge_end = end;
+       unsigned long max = USER_PGTABLES_CEILING;
        pgoff_t vm_pgoff;
        int error;
        MA_STATE(mas, &mm->mm_mt, addr, end - 1);
@@ -2656,12 +2603,16 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
 
        /* Check next */
        next = mas_next(&mas, ULONG_MAX);
-       if (next && next->vm_start == end && vma_policy(next) &&
-           can_vma_merge_before(next, vm_flags, NULL, file, pgoff + pglen,
-                                NULL_VM_UFFD_CTX)) {
-               merge_end = next->vm_end;
-               vma = next;
-               vm_pgoff = next->vm_pgoff - pglen;
+       if (next) {
+               max = next->vm_start;
+
+               if (next->vm_start == end && vma_policy(next) &&
+                   can_vma_merge_before(next, vm_flags, NULL, file,
+                                        pgoff + pglen, NULL_VM_UFFD_CTX)) {
+                       merge_end = next->vm_end;
+                       vma = next;
+                       vm_pgoff = next->vm_pgoff - pglen;
+               }
        }
 
        /* Check prev */
@@ -2827,8 +2778,10 @@ unmap_and_free_vma:
        vma->vm_file = NULL;
        fput(file);
 
+       mas.index = mas.last = addr;
+       mas_walk(&mas);
        /* Undo any partial mapping done by a device driver. */
-       unmap_region(mm, vma, prev, vma->vm_start, vma->vm_end);
+       unmap_region(mm, vma, &mas, vma->vm_start, vma->vm_end, prev, max);
        charged = 0;
        if (vm_flags & VM_SHARED)
                mapping_unmap_writable(file->f_mapping);
@@ -2894,15 +2847,17 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
        unsigned long populate = 0;
        unsigned long ret = -EINVAL;
        struct file *file;
+       struct ma_state ma_lock;
+       MA_STATE(mas, &mm->mm_mt, start, start);
 
        pr_warn_once("%s (%d) uses deprecated remap_file_pages() syscall. See Documentation/vm/remap_file_pages.rst.\n",
                     current->comm, current->pid);
 
        if (prot)
                return ret;
+
        start = start & PAGE_MASK;
        size = size & PAGE_MASK;
-
        if (start + size <= start)
                return ret;
 
@@ -2913,20 +2868,23 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
        if (mmap_write_lock_killable(mm))
                return -EINTR;
 
-       vma = find_vma(mm, start);
+       mas_set(&mas, start);
+       vma = mas_walk(&mas);
+       ma_lock = mas;
 
        if (!vma || !(vma->vm_flags & VM_SHARED))
                goto out;
 
-       if (start < vma->vm_start)
+       if (!vma->vm_file)
                goto out;
 
        if (start + size > vma->vm_end) {
-               struct vm_area_struct *next;
+               struct vm_area_struct *prev, *next;
 
-               for (next = vma->vm_next; next; next = next->vm_next) {
+               prev = vma;
+               mas_for_each(&mas, next, start + size) {
                        /* hole between vmas ? */
-                       if (next->vm_start != next->vm_prev->vm_end)
+                       if (next->vm_start != prev->vm_end)
                                goto out;
 
                        if (next->vm_file != vma->vm_file)
@@ -2937,6 +2895,8 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
 
                        if (start + size <= next->vm_end)
                                break;
+
+                       prev = next;
                }
 
                if (!next)
@@ -2949,24 +2909,6 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
 
        flags &= MAP_NONBLOCK;
        flags |= MAP_SHARED | MAP_FIXED | MAP_POPULATE;
-       if (vma->vm_flags & VM_LOCKED) {
-               struct vm_area_struct *tmp;
-               flags |= MAP_LOCKED;
-
-               /* drop PG_Mlocked flag for over-mapped range */
-               for (tmp = vma; tmp->vm_start >= start + size;
-                               tmp = tmp->vm_next) {
-                       /*
-                        * Split pmd and munlock page on the border
-                        * of the range.
-                        */
-                       vma_adjust_trans_huge(tmp, start, start + size, 0);
-
-                       munlock_vma_pages_range(tmp,
-                                       max(tmp->vm_start, start),
-                                       min(tmp->vm_end, start + size));
-               }
-       }
 
        file = get_file(vma->vm_file);
        ret = do_mmap(vma->vm_file, start, size,
@@ -2994,7 +2936,7 @@ out:
  */
 static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
                         unsigned long newbrk, unsigned long oldbrk,
-                        struct list_head *uf)
+                        struct list_head *uf, unsigned long max)
 {
        struct mm_struct *mm = vma->vm_mm;
        struct vm_area_struct unmap;
@@ -3039,14 +2981,13 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
        }
 
        mmap_write_downgrade(mm);
-       unmap_region(mm, &unmap, vma, newbrk, oldbrk);
+       unmap_region(mm, &unmap, mas, newbrk, oldbrk, vma, max);
        /* Statistics */
        vm_stat_account(mm, unmap.vm_flags, -unmap_pages);
        if (unmap.vm_flags & VM_ACCOUNT)
                vm_unacct_memory(unmap_pages);
 
 munmap_full_vma:
-       validate_mm_mt(mm);
        return ret;
 
 mas_store_fail:
@@ -3076,7 +3017,6 @@ static int do_brk_flags(struct ma_state *mas, struct vm_area_struct **brkvma,
        struct vm_area_struct *prev = NULL, *vma;
        int error;
        unsigned long mapped_addr;
-       validate_mm_mt(mm);
 
        /* Until we need other flags, refuse anything except VM_EXEC. */
        if ((flags & (~VM_EXEC)) != 0)
@@ -3160,7 +3100,6 @@ out:
        if (flags & VM_LOCKED)
                mm->locked_vm += (len >> PAGE_SHIFT);
        vma->vm_flags |= VM_SOFTDIRTY;
-       validate_mm_mt(mm);
        return 0;
 
 mas_store_fail:
@@ -3218,6 +3157,7 @@ void exit_mmap(struct mm_struct *mm)
        struct mmu_gather tlb;
        struct vm_area_struct *vma;
        unsigned long nr_accounted = 0;
+       MA_STATE(mas, &mm->mm_mt, 0, 0);
 
        /* mm's last user has gone, and its about to be pulled down */
        mmu_notifier_release(mm);
@@ -3246,8 +3186,16 @@ void exit_mmap(struct mm_struct *mm)
                mmap_write_unlock(mm);
        }
 
-       if (mm->locked_vm)
-               unlock_range(mm->mmap, &vma, ULONG_MAX);
+       if (mm->locked_vm) {
+               mas_for_each(&mas, vma, ULONG_MAX) {
+                       if (vma->vm_flags & VM_LOCKED) {
+                               mm->locked_vm -= vma_pages(vma);
+                               munlock_vma_pages_all(vma);
+                       }
+               }
+               mas_reset(&mas);
+               mas_set(&mas, 0);
+       }
 
        arch_exit_mmap(mm);
 
@@ -3268,10 +3216,12 @@ void exit_mmap(struct mm_struct *mm)
         * Walk the list again, actually closing and freeing it,
         * with preemption enabled, without holding any MM locks.
         */
-       while (vma) {
+       mas_reset(&mas);
+       mas_set(&mas, 0);
+       mas_for_each(&mas, vma, ULONG_MAX) {
                if (vma->vm_flags & VM_ACCOUNT)
                        nr_accounted += vma_pages(vma);
-               vma = remove_vma(vma);
+               remove_vma(vma);
                cond_resched();
        }
 
@@ -3330,7 +3280,6 @@ struct vm_area_struct *copy_vma(struct vm_area_struct **vmap,
        struct vm_area_struct *new_vma, *prev;
        bool faulted_in_anon_vma = true;
 
-       validate_mm_mt(mm);
        /*
         * If anonymous vma has not yet been faulted, update new pgoff
         * to match new location, to increase its chance of merging.
@@ -3386,7 +3335,6 @@ struct vm_area_struct *copy_vma(struct vm_area_struct **vmap,
                vma_link(mm, new_vma, prev);
                *need_rmap_locks = false;
        }
-       validate_mm_mt(mm);
        return new_vma;
 
 out_free_mempol:
@@ -3394,7 +3342,6 @@ out_free_mempol:
 out_free_vma:
        vm_area_free(new_vma);
 out:
-       validate_mm_mt(mm);
        return NULL;
 }
 
@@ -3519,7 +3466,6 @@ static struct vm_area_struct *__install_special_mapping(
        int ret;
        struct vm_area_struct *vma;
 
-       validate_mm_mt(mm);
        vma = vm_area_alloc(mm);
        if (unlikely(vma == NULL))
                return ERR_PTR(-ENOMEM);
@@ -3541,12 +3487,10 @@ static struct vm_area_struct *__install_special_mapping(
 
        perf_event_mmap(vma);
 
-       validate_mm_mt(mm);
        return vma;
 
 out:
        vm_area_free(vma);
-       validate_mm_mt(mm);
        return ERR_PTR(ret);
 }
 
@@ -3671,12 +3615,13 @@ int mm_take_all_locks(struct mm_struct *mm)
 {
        struct vm_area_struct *vma;
        struct anon_vma_chain *avc;
+       MA_STATE(mas, &mm->mm_mt, 0, 0);
 
        BUG_ON(mmap_read_trylock(mm));
 
        mutex_lock(&mm_all_locks_mutex);
 
-       for (vma = mm->mmap; vma; vma = vma->vm_next) {
+       mas_for_each(&mas, vma, ULONG_MAX) {
                if (signal_pending(current))
                        goto out_unlock;
                if (vma->vm_file && vma->vm_file->f_mapping &&
@@ -3684,7 +3629,8 @@ int mm_take_all_locks(struct mm_struct *mm)
                        vm_lock_mapping(mm, vma->vm_file->f_mapping);
        }
 
-       for (vma = mm->mmap; vma; vma = vma->vm_next) {
+       mas_reset(&mas);
+       mas_for_each(&mas, vma, ULONG_MAX) {
                if (signal_pending(current))
                        goto out_unlock;
                if (vma->vm_file && vma->vm_file->f_mapping &&
@@ -3692,7 +3638,8 @@ int mm_take_all_locks(struct mm_struct *mm)
                        vm_lock_mapping(mm, vma->vm_file->f_mapping);
        }
 
-       for (vma = mm->mmap; vma; vma = vma->vm_next) {
+       mas_reset(&mas);
+       mas_for_each(&mas, vma, ULONG_MAX) {
                if (signal_pending(current))
                        goto out_unlock;
                if (vma->anon_vma)
@@ -3751,11 +3698,12 @@ void mm_drop_all_locks(struct mm_struct *mm)
 {
        struct vm_area_struct *vma;
        struct anon_vma_chain *avc;
+       MA_STATE(mas, &mm->mm_mt, 0, 0);
 
        BUG_ON(mmap_read_trylock(mm));
        BUG_ON(!mutex_is_locked(&mm_all_locks_mutex));
 
-       for (vma = mm->mmap; vma; vma = vma->vm_next) {
+       mas_for_each(&mas, vma, ULONG_MAX) {
                if (vma->anon_vma)
                        list_for_each_entry(avc, &vma->anon_vma_chain, same_vma)
                                vm_unlock_anon_vma(avc->anon_vma);