mm: Replace vma on split_vma() calls. maple_v5.18-rc2_lowmem
authorLiam R. Howlett <Liam.Howlett@oracle.com>
Tue, 12 Apr 2022 15:12:06 +0000 (11:12 -0400)
committerLiam R. Howlett <Liam.Howlett@oracle.com>
Thu, 28 Apr 2022 16:12:27 +0000 (12:12 -0400)
When splitting a VMA, create two new VMAs to replace both parts of the
VMA.  Change the callers to pass in a pointer and update the pointer to
the new VMA based on the value of new_below.

do_mas_align_munmap() needed to update a local variable in the case of
splitting the end and only having one VMA to split.

mprotect_fixup() needed to change where it set the previous pointer to
after the split to avoid a use-after-free scenario.

Signed-off-by: Liam R. Howlett <Liam.Howlett@oracle.com>
fs/userfaultfd.c
include/linux/mm.h
mm/madvise.c
mm/mempolicy.c
mm/mlock.c
mm/mmap.c
mm/mprotect.c

index af29e5885ed2fe274a933cec163c7cb6b2a7736f..0aa9435cac91f730e575b0882fb365bcca0ba23c 100644 (file)
@@ -1458,14 +1458,14 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
                        goto next;
                }
                if (vma->vm_start < start) {
-                       ret = split_vma(mm, vma, start, 1);
+                       ret = split_vma(mm, &vma, start, 1);
                        if (ret)
                                break;
                        /* split_vma() invalidated the mas */
                        mas_pause(&mas);
                }
                if (vma->vm_end > end) {
-                       ret = split_vma(mm, vma, end, 0);
+                       ret = split_vma(mm, &vma, end, 0);
                        if (ret)
                                break;
                        /* split_vma() invalidated the mas */
@@ -1643,12 +1643,12 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
                        goto next;
                }
                if (vma->vm_start < start) {
-                       ret = split_vma(mm, vma, start, 1);
+                       ret = split_vma(mm, &vma, start, 1);
                        if (ret)
                                break;
                }
                if (vma->vm_end > end) {
-                       ret = split_vma(mm, vma, end, 0);
+                       ret = split_vma(mm, &vma, end, 0);
                        if (ret)
                                break;
                }
index 3200c954c385c47abdab51849a8a0d98c2d09bda..b28a2aaf410d9998ce624d2589d9e0b55f0182ff 100644 (file)
@@ -2661,9 +2661,9 @@ extern struct vm_area_struct *vma_merge(struct mm_struct *,
        unsigned long vm_flags, struct anon_vma *, struct file *, pgoff_t,
        struct mempolicy *, struct vm_userfaultfd_ctx, struct anon_vma_name *);
 extern struct anon_vma *find_mergeable_anon_vma(struct vm_area_struct *);
-extern int __split_vma(struct mm_struct *, struct vm_area_struct *,
+extern int __split_vma(struct mm_struct *, struct vm_area_struct **,
        unsigned long addr, int new_below);
-extern int split_vma(struct mm_struct *, struct vm_area_struct *,
+extern int split_vma(struct mm_struct *, struct vm_area_struct **,
        unsigned long addr, int new_below);
 extern int insert_vm_struct(struct mm_struct *, struct vm_area_struct *);
 extern void unlink_file_vma(struct vm_area_struct *);
index 1c943a7ddcb29fc1e8622e92d1658b5eee9eedfc..3e0b69071694abb32212ca725e98c01b310eb9e9 100644 (file)
@@ -164,7 +164,7 @@ static int madvise_update_vma(struct vm_area_struct *vma,
        if (start != vma->vm_start) {
                if (unlikely(mm->map_count >= sysctl_max_map_count))
                        return -ENOMEM;
-               error = __split_vma(mm, vma, start, 1);
+               error = __split_vma(mm, &vma, start, 1);
                if (error)
                        return error;
        }
@@ -172,7 +172,7 @@ static int madvise_update_vma(struct vm_area_struct *vma,
        if (end != vma->vm_end) {
                if (unlikely(mm->map_count >= sysctl_max_map_count))
                        return -ENOMEM;
-               error = __split_vma(mm, vma, end, 0);
+               error = __split_vma(mm, &vma, end, 0);
                if (error)
                        return error;
        }
index 78d706196d4a387c1e1f35b73a9556fb6c5017a6..2c50bfbe3c14f26c20978d3294eae186bf42c83e 100644 (file)
@@ -827,14 +827,14 @@ static int mbind_range(struct mm_struct *mm, unsigned long start,
                        goto replace;
                }
                if (vma->vm_start != vmstart) {
-                       err = split_vma(vma->vm_mm, vma, vmstart, 1);
+                       err = split_vma(vma->vm_mm, &vma, vmstart, 1);
                        if (err)
                                goto out;
                        /* split_vma() invalidated the mas */
                        mas_pause(&mas);
                }
                if (vma->vm_end != vmend) {
-                       err = split_vma(vma->vm_mm, vma, vmend, 0);
+                       err = split_vma(vma->vm_mm, &vma, vmend, 0);
                        if (err)
                                goto out;
                        /* split_vma() invalidated the mas */
index c41604ba5197d80c15716d8cd53d608d611359bc..2beb2b4213ddd33d78b14e3646eabc3cbb395ee3 100644 (file)
@@ -426,13 +426,13 @@ static int mlock_fixup(struct vm_area_struct *vma, struct vm_area_struct **prev,
        }
 
        if (start != vma->vm_start) {
-               ret = split_vma(mm, vma, start, 1);
+               ret = split_vma(mm, &vma, start, 1);
                if (ret)
                        goto out;
        }
 
        if (end != vma->vm_end) {
-               ret = split_vma(mm, vma, end, 0);
+               ret = split_vma(mm, &vma, end, 0);
                if (ret)
                        goto out;
        }
index b39f1b4f79db976df0e450c0bea8906b198938aa..9114c839797b323c176fa92d29907dee84d7628a 100644 (file)
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -2226,76 +2226,172 @@ static void unmap_region(struct mm_struct *mm, struct maple_tree *mt,
 }
 
 /*
- * __split_vma() bypasses sysctl_max_map_count checking.  We use this where it
- * has already been checked or doesn't make sense to fail.
+ * vma_replace() -  Replace one vma with two new VMAs.
+ * @vma: The vma to be replaced
+ * @start: The lower address VMA
+ * @end: The higher address VMA
+ *
+ * Currently does not support @start and @end leaving a portion of @vma.
+ */
+static inline int vma_replace(struct vm_area_struct *vma,
+                       struct vm_area_struct *start, struct vm_area_struct *end)
+{
+       struct mm_struct *mm = vma->vm_mm;
+       struct address_space *mapping = NULL;
+       struct anon_vma *anon_vma = vma->anon_vma;
+       struct file *file = vma->vm_file;
+       MA_STATE(mas, &vma->vm_mm->mm_mt, start->vm_start, start->vm_end - 1);
+
+       if (mas_preallocate(&mas, vma, GFP_KERNEL))
+               return -ENOMEM;
+
+       vma_adjust_trans_huge(vma, vma->vm_start, end->vm_start, 0);
+       if (file) {
+               mapping = file->f_mapping;
+               uprobe_munmap(vma, vma->vm_start, vma->vm_end);
+
+               i_mmap_lock_write(mapping);
+               /*
+                * Put into interval tree now, so instantiated pages are visible
+                * to arm/parisc __flush_dcache_page throughout; but we cannot
+                * insert into address space until vma vm_start or vm_end is
+                * updated.
+                */
+               __vma_link_file(start, start->vm_file->f_mapping);
+               __vma_link_file(end, end->vm_file->f_mapping);
+       }
+
+       if (anon_vma)
+               unlink_anon_vmas(vma);
+
+       mas_store(&mas, start);
+       mm->map_count++;
+       mas_set_range(&mas, end->vm_start, end->vm_end - 1);
+       mas_store_prealloc(&mas, end);
+       /* mmap_count is fine here since one vma was just overwritten */
+       BUG_ON(start->vm_end != end->vm_start);
+
+       if (file) {
+               __remove_shared_vm_struct(vma, file, mapping);
+               i_mmap_unlock_write(mapping);
+               uprobe_mmap(start);
+               uprobe_mmap(end);
+       }
+
+       remove_vma(vma);
+       validate_mm(mm);
+       return 0;
+}
+
+/*
+ * __split_vma() - Split one VMA into two new VMAs.
+ * @mm: The mm_struct
+ * @vma: Pointer to the vma to be split
+ * @addr: The address to split the vma
+ * @new_below: Put the new one at a lower address, sets the pointer @vma to the
+ * other VMA.
+ *
+ * Note: __split_vma() bypasses sysctl_max_map_count checking.  We use this
+ * where it has already been checked or doesn't make sense to fail.
  */
-int __split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
+int __split_vma(struct mm_struct *mm, struct vm_area_struct **vma,
                unsigned long addr, int new_below)
 {
-       struct vm_area_struct *new;
+       struct vm_area_struct *start, *end;
        int err;
        validate_mm_mt(mm);
 
-       if (vma->vm_ops && vma->vm_ops->may_split) {
-               err = vma->vm_ops->may_split(vma, addr);
+       if ((*vma)->vm_ops && (*vma)->vm_ops->may_split) {
+               err = (*vma)->vm_ops->may_split(*vma, addr);
                if (err)
                        return err;
        }
 
-       new = vm_area_dup(vma);
-       if (!new)
-               return -ENOMEM;
+       err = -ENOMEM;
+       start = vm_area_dup(*vma);
+       if (!start)
+               goto no_start;
 
-       if (new_below)
-               new->vm_end = addr;
-       else {
-               new->vm_start = addr;
-               new->vm_pgoff += ((addr - vma->vm_start) >> PAGE_SHIFT);
-       }
+       end = vm_area_dup(*vma);
+       if (!end)
+               goto no_end;
+
+       start->vm_end = addr;
+       end->vm_start = addr;
+       end->vm_pgoff += ((addr - start->vm_start) >> PAGE_SHIFT);
 
-       err = vma_dup_policy(vma, new);
+       err = vma_dup_policy(*vma, start);
        if (err)
-               goto out_free_vma;
+               goto no_start_policy;
 
-       err = anon_vma_clone(new, vma);
+       err = vma_dup_policy(*vma, end);
        if (err)
-               goto out_free_mpol;
+               goto no_end_policy;
 
-       if (new->vm_file)
-               get_file(new->vm_file);
+       err = anon_vma_clone(start, *vma);
+       if (err)
+               goto no_start_anon;
 
-       if (new->vm_ops && new->vm_ops->open)
-               new->vm_ops->open(new);
+       err = anon_vma_clone(end, *vma);
+       if (err)
+               goto no_end_anon;
+
+       if (start->vm_file) {
+               get_file(start->vm_file);
+               get_file(end->vm_file);
+       }
+
+       if (start->vm_ops && start->vm_ops->open) {
+               start->vm_ops->open(start);
+               end->vm_ops->open(end);
+       }
+
+
+       if (vma_replace(*vma, start, end))
+               goto no_replace;
 
        if (new_below)
-               err = vma_adjust(vma, addr, vma->vm_end, vma->vm_pgoff +
-                       ((addr - new->vm_start) >> PAGE_SHIFT), new);
+               *vma = end;
        else
-               err = vma_adjust(vma, vma->vm_start, addr, vma->vm_pgoff, new);
-
-       /* Success. */
-       if (!err)
-               return 0;
+               *vma = start;
+       return 0;
 
-       /* Clean everything up if vma_adjust failed. */
-       if (new->vm_ops && new->vm_ops->close)
-               new->vm_ops->close(new);
-       if (new->vm_file)
-               fput(new->vm_file);
-       unlink_anon_vmas(new);
- out_free_mpol:
-       mpol_put(vma_policy(new));
- out_free_vma:
-       vm_area_free(new);
+no_replace:
+       /* Clean everything up if vma_replace failed. */
+       if (start->vm_ops && start->vm_ops->close) {
+               start->vm_ops->close(start);
+               end->vm_ops->close(end);
+       }
+       if (start->vm_file) {
+               fput(start->vm_file);
+               fput(end->vm_file);
+       }
+       unlink_anon_vmas(end);
+no_end_anon:
+       unlink_anon_vmas(start);
+no_start_anon:
+       mpol_put(vma_policy(end));
+no_end_policy:
+       mpol_put(vma_policy(start));
+no_start_policy:
+       vm_area_free(end);
+no_end:
+       vm_area_free(start);
+no_start:
        validate_mm_mt(mm);
        return err;
 }
 
-/*
- * Split a vma into two pieces at address 'addr', a new vma is allocated
- * either for the first part or the tail.
+/**
+ * split_vma() - Split one VMA and replace it by two new VMAs
+ * @mm: The mm_struct
+ * @vma: Pointer to the vma
+ * @addr: The address to split the VMA
+ * @new_below: Put the new one at a lower address, sets the pointer @vma to the
+ * other VMA.
+ * Return: 0 on success, errno otherwise.
  */
-int split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
+int split_vma(struct mm_struct *mm, struct vm_area_struct **vma,
              unsigned long addr, int new_below)
 {
        if (mm->map_count >= sysctl_max_map_count)
@@ -2355,7 +2451,7 @@ do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
                 * mas_pause() is not needed since mas->index needs to be set
                 * differently than vma->vm_end anyways.
                 */
-               error = __split_vma(mm, vma, start, 1);
+               error = __split_vma(mm, &vma, start, 1);
                if (error)
                        return error;
 
@@ -2376,10 +2472,12 @@ do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
                if (next->vm_end > end) {
                        int error;
 
-                       error = __split_vma(mm, next, end, 0);
+                       error = __split_vma(mm, &next, end, 0);
                        if (error)
                                return error;
 
+                       if (next->vm_start == start)
+                               vma = next;
                        mas_set(mas, end);
                }
                count++;
index fbb248caf8aa0682e5f18152f3e870b2f423502a..01965ed8c12040df588765d9c9494742ca294fb6 100644 (file)
@@ -492,20 +492,20 @@ mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma,
                goto success;
        }
 
-       *pprev = vma;
 
        if (start != vma->vm_start) {
-               error = split_vma(mm, vma, start, 1);
+               error = split_vma(mm, &vma, start, 1);
                if (error)
                        goto fail;
        }
 
        if (end != vma->vm_end) {
-               error = split_vma(mm, vma, end, 0);
+               error = split_vma(mm, &vma, end, 0);
                if (error)
                        goto fail;
        }
 
+       *pprev = vma;
 success:
        /*
         * vm_flags and vm_page_prot are protected by the mmap_lock