mm: WIP, trying to avoid splitting is a mess. maple3_do_munmap
authorLiam R. Howlett <Liam.Howlett@Oracle.com>
Tue, 24 Nov 2020 00:03:58 +0000 (19:03 -0500)
committerLiam R. Howlett <Liam.Howlett@Oracle.com>
Tue, 24 Nov 2020 00:03:58 +0000 (19:03 -0500)
If you split this VMA, then it is inserted into the tree.  Instead, we
try to create a dummy VMA and pass that along to handle freeing.  Unfortunately,
the hugetlb cases seem to require the file handle to be correctly processed
in the unmap_single_vma() call path.  There may be memory leaks in the cases
where the VMA is not backed by hugetlb pages, but I did not observe any.

The unfortunate part is that we have to insert ourselves into the interval tree
when we try to do more with the VMA.

For now, let's just go back to something that is less of a mess

Signed-off-by: Liam R. Howlett <Liam.Howlett@Oracle.com>
include/linux/maple_tree.h
mm/memory.c
mm/mmap.c

index 06506c529fd71f38bd2ce8de42ce6aa7f4875f87..54be2dc2392ba149ff49533bcc132b85356ece4c 100644 (file)
@@ -12,7 +12,7 @@
 #include <linux/rcupdate.h>
 #include <linux/spinlock.h>
 #define CONFIG_MAPLE_RCU_DISABLED
-//#define CONFIG_DEBUG_MAPLE_TREE
+#define CONFIG_DEBUG_MAPLE_TREE
 //#define CONFIG_DEBUG_MAPLE_TREE_VERBOSE
 
 /*
index c48f8df6e502683432097861bc31b010e08f263d..2591713feb0419f38830e7c6734a1932c0e38b41 100644 (file)
@@ -361,18 +361,24 @@ void free_pgd_range(struct mmu_gather *tlb,
        addr &= PMD_MASK;
        if (addr < floor) {
                addr += PMD_SIZE;
-               if (!addr)
+               if (!addr) {
+                       printk("%s: %d\n", __func__, __LINE__);
                        return;
+               }
        }
        if (ceiling) {
                ceiling &= PMD_MASK;
-               if (!ceiling)
+               if (!ceiling) {
+                       printk("%s: %d\n", __func__, __LINE__);
                        return;
+               }
        }
        if (end - 1 > ceiling - 1)
                end -= PMD_SIZE;
-       if (addr > end - 1)
+       if (addr > end - 1) {
+                       printk("%s: %d\n", __func__, __LINE__);
                return;
+       }
        /*
         * We add page table cache pages with PAGE_SIZE,
         * (see pte_free_tlb()), flush the tlb if we need
@@ -401,7 +407,10 @@ void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *vma,
                unlink_anon_vmas(vma);
                unlink_file_vma(vma);
 
+               printk("vma %px\n", vma);
                if (is_vm_hugetlb_page(vma)) {
+                       printk("free %lu, end %lu, floor %lu to %lu\n",
+                              addr, vma->vm_end, floor, next ? next->vm_start : ceiling);
                        hugetlb_free_pgd_range(tlb, addr, vma->vm_end,
                                floor, next ? next->vm_start : ceiling);
                } else {
@@ -415,6 +424,8 @@ void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *vma,
                                unlink_anon_vmas(vma);
                                unlink_file_vma(vma);
                        }
+                       printk("free %lu, end %lu, floor %lu to %lu\n",
+                              addr, vma->vm_end, floor, next ? next->vm_start : ceiling);
                        free_pgd_range(tlb, addr, vma->vm_end,
                                floor, next ? next->vm_start : ceiling);
                }
index ba1c78bdc519e27c432acc52b0677687c84f1449..d781c3de27ebf6ad6c483f6973d210ff3dcbf565 100644 (file)
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -2553,10 +2553,19 @@ static void unmap_region(struct mm_struct *mm,
        struct vm_area_struct *next = vma_next(mm, prev);
        struct mmu_gather tlb;
 
+
+       if (prev)
+               printk("prev %px %lu-%lu\n", prev, prev->vm_start, prev->vm_end);
+       if (next)
+               printk("next %px %lu-%lu\n", next, next->vm_start, next->vm_end);
+
        lru_add_drain();
        tlb_gather_mmu(&tlb, mm, start, end);
        update_hiwater_rss(mm);
        unmap_vmas(&tlb, vma, start, end);
+       printk("free pgtables %px %lu-%lu\n",
+               vma, prev ? prev->vm_end : FIRST_USER_ADDRESS,
+                                next ? next->vm_start : USER_PGTABLES_CEILING);
        free_pgtables(&tlb, vma, prev ? prev->vm_end : FIRST_USER_ADDRESS,
                                 next ? next->vm_start : USER_PGTABLES_CEILING);
        tlb_finish_mmu(&tlb, start, end);
@@ -2656,38 +2665,62 @@ static inline void unlock_range(struct vm_area_struct *start, unsigned long limi
        }
 }
 
-void vma_shorten(struct vm_area_struct *vma, unsigned long start,
-                       unsigned long end, struct vm_area_struct *unmap)
+void vma_shorten(struct vm_area_struct *vma, unsigned long split,
+                       bool keep_front, struct vm_area_struct *unmap)
 {
        struct mm_struct *mm = vma->vm_mm;
-       unsigned long old_start = vma->vm_start;
-       unsigned long old_end = vma->vm_end;
        struct vm_area_struct *next = vma->vm_next;
        struct address_space *mapping = NULL;
        struct rb_root_cached *root = NULL;
        struct anon_vma *anon_vma = NULL;
        struct file *file = vma->vm_file;
+       unsigned long addr;
 
 
        vma_init(unmap, mm);
-       unmap->vm_pgoff = vma->vm_pgoff;
        unmap->vm_flags = vma->vm_flags;
-       if (end == old_end) {
-               /* Changing the start of the VMA. */
-               unmap->vm_start = old_start;
-               unmap->vm_end = start;
-               unmap->vm_next = vma;
-               unmap->vm_prev = vma->vm_prev;
-       } else {
-               /* unmap will contain the end section of the VMA to be removed  */
-               unmap->vm_start = end;
-               unmap->vm_end = old_end;
+       unmap->vm_file = vma->vm_file;
+       if (is_vm_hugetlb_page(vma) && !is_vm_hugetlb_page(unmap))
+               printk("Flags don't match: %lu %lu  !!!!!!!!\n", vma->vm_flags, unmap->vm_flags);
+#if 0
+       ASSERT_EXCLUSIVE_WRITER(vma->vm_flags);
+       ASSERT_EXCLUSIVE_WRITER(vma->vm_file);
+       *unmap = data_race(*vma);
+       INIT_LIST_HEAD(&unmap->anon_vma_chain);
+#endif
+       if (keep_front) {
+               unmap->vm_start = split;
+               unmap->vm_end = vma->vm_end;
                unmap->vm_next = vma->vm_next;
                unmap->vm_prev = vma;
-               unmap->vm_pgoff += ((end - old_start) >> PAGE_SHIFT);
+               unmap->vm_pgoff = 0;
+       } else {
+               unmap->vm_start = vma->vm_start;
+               unmap->vm_end = split;
+               unmap->vm_next = vma;
+               unmap->vm_prev = vma->vm_prev;
+               unmap->vm_pgoff = vma->vm_pgoff;
        }
 
-       vma_adjust_trans_huge(vma, start, end, 0);
+       if (vma->vm_ops && vma->vm_ops->split) {
+               int err = vma->vm_ops->split(vma, addr);
+               if (err)
+                       printk("error on split %d\n", err);
+       }
+#if 1
+       vma_dup_policy(vma, unmap);
+       anon_vma_clone(unmap, vma);
+       if (unmap->vm_file)
+               get_file(unmap->vm_file);
+
+       if (unmap->vm_ops && unmap->vm_ops->open)
+               unmap->vm_ops->open(unmap);
+#endif
+
+       if (keep_front)
+               vma_adjust_trans_huge(vma, vma->vm_start, split, 0);
+       else
+               vma_adjust_trans_huge(vma, split, vma->vm_end, 0);
 
        if (file) {
                mapping = file->f_mapping;
@@ -2695,6 +2728,7 @@ void vma_shorten(struct vm_area_struct *vma, unsigned long start,
                uprobe_munmap(vma, vma->vm_start, vma->vm_end);
 
                i_mmap_lock_write(mapping);
+               __vma_link_file(unmap);
        }
 
        anon_vma = vma->anon_vma;
@@ -2708,11 +2742,11 @@ void vma_shorten(struct vm_area_struct *vma, unsigned long start,
                vma_interval_tree_remove(vma, root);
        }
 
-       if (end == old_end) {
-               vma->vm_start = start;
-               vma->vm_pgoff += (start - old_start) >> PAGE_SHIFT;
+       if (keep_front) {
+               vma->vm_end = split;
        } else {
-               vma->vm_end = end;
+               vma->vm_pgoff += (split - vma->vm_start) >> PAGE_SHIFT;
+               vma->vm_start = split;
        }
 
        if (file) {
@@ -2731,6 +2765,7 @@ void vma_shorten(struct vm_area_struct *vma, unsigned long start,
        if (file) {
                i_mmap_unlock_write(mapping);
                uprobe_mmap(vma);
+               uprobe_mmap(unmap);
        }
 }
 /* Munmap is split into 2 main parts -- this part which finds
@@ -2747,7 +2782,6 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
        int map_count = 0;
        MA_STATE(mas, &mm->mm_mt, start, start);
 
-
        if ((offset_in_page(start)) || start > TASK_SIZE || len > TASK_SIZE-start)
                return -EINVAL;
 
@@ -2764,14 +2798,9 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
        if (!vma)
                return 0;
 
-       /* Check for userfaultfd now before altering the mm */
-       if (unlikely(uf)) {
-               int error = userfaultfd_unmap_prep(vma, start, end, uf);
-
-               if (error)
-                       return error;
-       }
 
+       printk("%s: %lu - %lu\n", __func__, start, end);
+       printk("  found vma %px %lu-%lu\n", vma, vma->vm_start, vma->vm_end);
        if (start > vma->vm_start) {
                if (unlikely(vma->vm_end > end)) {
                        /* Essentially, this means there is a hole punched in
@@ -2794,29 +2823,59 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
                        mas_set(&mas, start);
                        vma = mas_walk(&mas);
                } else {
-                       vma_shorten(vma, vma->vm_start, start, &start_split);
+                       printk("Shorten %lu-%lu to %lu - %lu\n",
+                              vma->vm_start, vma->vm_end, start, vma->vm_end);
+                       vma_shorten(vma, start, true, &start_split);
                        prev = vma;
                        vma = &start_split;
+                       printk("start_Split %lu-%lu\n", start_split.vm_start,
+                              start_split.vm_end);
                        map_count--;
                }
        } else {
                prev = vma->vm_prev;
        }
 
-       if (vma->vm_end >= end) // almost always the case
+       if (vma->vm_end >= end) { // almost always the case
                last = vma;
-       else
+       } else {
                last = find_vma_intersection(mm, end - 1, end);
+       }
 
        /* Does it split the last one? */
        if (last && end < last->vm_end) {
-               vma_shorten(last, end, last->vm_end, &end_split);
+               printk("%d %lu - %lu to %lu - %lu\n", __LINE__,
+                      last->vm_start, last->vm_end, end, last->vm_end);
+#if 1
+               vma_shorten(last, end, false, &end_split);
+               printk("vma is %lu - %lu\n", vma->vm_start, vma->vm_end);
                if (last == vma)
                        vma = &end_split;
 
                // map_count will count the existing vma in this case
                map_count--;
                last = &end_split;
+#else
+               int error = __split_vma(mm, last, end, 1);
+               if (error)
+                       return error;
+               mas_reset(&mas);
+               mas_set(&mas, start);
+               vma = mas_walk(&mas);
+               last = find_vma_intersection(mm, end - 1, end);
+#endif
+               printk("last is %lu - %lu\n", last->vm_start, last->vm_end);
+               printk("next is %px and %px\n", last->vm_next, vma->vm_next);
+       }
+
+       /* Check for userfaultfd now before altering the mm */
+       if (unlikely(uf)) {
+               int error = userfaultfd_unmap_prep(vma, start, end, uf);
+
+               if (error) {
+                       printk("%d\n", __LINE__);
+                       return error;
+               }
        }
 
        /* unlock any mlock()ed VMAs and count the number of VMAs to be
@@ -2851,14 +2910,25 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
        last->vm_next = NULL;
 
        /* Detach VMAs from the maple tree */
+//     if (end - start >= 0x400000)
+//             mt_dump(mas.tree);
+//
+       /* Reset the maple range to write as the found range may be different */
        mas.index = start;
        mas.last = end - 1;
-       mas_store_gfp(&mas, NULL, GFP_KERNEL);
+//     printk("Store %lu - %lu\n", start, end - 1);
+       if(mas_store_gfp(&mas, NULL, GFP_KERNEL)) {
+               printk("UNDO!! UNDO!!!\n");
+       }
+//     if (end - start >= 0x400000)
+//             mt_dump(mas.tree);
 
        /* Update map_count */
+       printk("%d - %d\n", mm->map_count, map_count);
        mm->map_count -= map_count;
 
        /* Downgrade the lock, if possible */
+
        if (next && (next->vm_flags & VM_GROWSDOWN))
                downgrade = false;
        else if (prev && (prev->vm_flags & VM_GROWSUP))
@@ -2867,9 +2937,15 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
        if (downgrade)
                mmap_write_downgrade(mm);
 
+       printk("Start at %px (%lu) %lu to %lu\n", vma, vma->vm_start, start, end);
+       printk(" last at %px (%lu)\n", last, last->vm_end);
+       if (vma->vm_next)
+       printk(" next is %px (%lu - %lu)\n", vma->vm_next,
+              vma->vm_next->vm_start, vma->vm_next->vm_end);
        /* Actual unmap the region */
        unmap_region(mm, vma, prev, start, end);
 
+
        /* Take care of accounting for orphan VMAs, and remove from the list. */
        if (vma == &start_split) {
                if (vma->vm_flags & VM_ACCOUNT) {
@@ -2878,6 +2954,11 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
                        vm_stat_account(mm, vma->vm_flags, -nrpages);
                        vm_unacct_memory(nrpages);
                }
+               if (vma->vm_ops && vma->vm_ops->close)
+                       vma->vm_ops->close(vma);
+               if (vma->vm_file)
+                       fput(vma->vm_file);
+               mpol_put(vma_policy(vma));
                vma = vma->vm_next;
        }
 
@@ -2889,6 +2970,12 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
                        vm_unacct_memory(nrpages);
                }
 
+               if (last->vm_ops && last->vm_ops->close)
+                       last->vm_ops->close(last);
+               if (last->vm_file)
+                       fput(last->vm_file);
+               mpol_put(vma_policy(last));
+
                if (last->vm_prev)
                        last->vm_prev->vm_next = NULL;
                if (vma == last)