mm/mmap: Prepare detaching of VMAs for new split
authorLiam R. Howlett <Liam.Howlett@Oracle.com>
Fri, 5 Mar 2021 16:11:31 +0000 (11:11 -0500)
committerLiam R. Howlett <Liam.Howlett@Oracle.com>
Fri, 5 Mar 2021 16:11:31 +0000 (11:11 -0500)
Signed-off-by: Liam R. Howlett <Liam.Howlett@Oracle.com>
mm/mmap.c

index 64a2addb47e2e48b31b1524a13671410ac3fa2ea..b93fa2374ae41dffde97652c5098b16be839b8de 100644 (file)
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -2274,44 +2274,31 @@ int split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
        return __split_vma(mm, vma, addr, new_below);
 }
 
-static inline unsigned long detach_range(struct mm_struct *mm,
-               struct ma_state *src, struct ma_state *dst,
-               struct vm_area_struct **vma, struct vm_area_struct *prev)
+
+static inline void detach_range(struct mm_struct *mm, struct ma_state *mas,
+                       struct ma_state *dst, struct vm_area_struct **vma)
 {
+       unsigned long start = dst->index;
+       unsigned long end = dst->last;
        int count = 0;
-       struct ma_state mas;
 
-       /*
-        * unlock any mlock()ed ranges before detaching vmas, count the number
-        * of VMAs to be dropped, and return the tail entry of the affected
-        * area.
-        */
-       mas = *src;
        do {
-               BUG_ON((*vma)->vm_start < src->index);
-               BUG_ON((*vma)->vm_end > (src->last + 1));
                count++;
+               *vma = mas_prev(mas, start);
+               BUG_ON((*vma)->vm_start < start);
+               BUG_ON((*vma)->vm_end > end + 1);
+               vma_mas_store(*vma, dst);
                if ((*vma)->vm_flags & VM_LOCKED) {
                        mm->locked_vm -= vma_pages(*vma);
                        munlock_vma_pages_all(*vma);
                }
-               vma_mas_store(*vma, dst);
-       } while ((*vma = mas_find(&mas, src->last)) != NULL);
+       } while ((*vma)->vm_start > start);
 
-       mas_set(&mas, src->last + 1);
        /* Drop removed area from the tree */
-       *vma = mas_find(&mas, ULONG_MAX);
-       mas_lock(src);
-       mas_store_gfp(src, NULL, GFP_KERNEL);
-       mas_unlock(src);
+       mas->last = end;
+       mas_store_gfp(mas, NULL, GFP_KERNEL);
        /* Decrement map_count */
        mm->map_count -= count;
-       /* Set the upper limit */
-       if (!(*vma))
-               return USER_PGTABLES_CEILING;
-
-       validate_mm(mm);
-       return (*vma)->vm_start;
 }
 
 /* do_mas_align_munmap() - munmap the aligned region from @start to @end.
@@ -2330,15 +2317,15 @@ static int do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
               struct mm_struct *mm, unsigned long start, unsigned long end,
               struct list_head *uf, bool downgrade)
 {
-       struct vm_area_struct *prev, *last;
+       struct vm_area_struct *prev, *last, *next = NULL;
        struct maple_tree mt_detach;
-       unsigned long max;
-       MA_STATE(dst, NULL, start, start);
+       unsigned long max = USER_PGTABLES_CEILING;
+       MA_STATE(dst, NULL, start, end - 1);
        struct ma_state tmp;
        /* we have start < vma->vm_end  */
 
        validate_mm(mm);
-        /* arch_unmap() might do unmaps itself.  */
+       /* arch_unmap() might do unmaps itself.  */
        arch_unmap(mm, start, end);
 
        /*
@@ -2362,32 +2349,31 @@ static int do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
                if (error)
                        return error;
                prev = vma;
-               // Split invalidated node, reset.
                mas_set_range(mas, start, end - 1);
+               vma = mas_walk(mas);
+
        } else {
                tmp = *mas;
                prev = mas_prev(&tmp, 0);
        }
 
-       if (vma->vm_end >= end)
+       if (end < vma->vm_end) {
                last = vma;
-       else {
-               tmp = *mas;
-               mas_set(&tmp, end - 1);
-               last = mas_walk(&tmp);
+       } else {
+               mas_set(mas, end - 1);
+               last = mas_walk(mas);
        }
 
        /* Does it split the last one? */
        if (last && end < last->vm_end) {
-               int error = __split_vma(mm, last, end, 1);
+               int error;
+               error = __split_vma(mm, last, end, 1);
                if (error)
                        return error;
-               // Split invalidated node, reset.
-               mas_set_range(mas, start, end - 1);
+               mas_set(mas, end - 1);
+               last = mas_walk(mas);
        }
-
-       if (mas->node == MAS_START)
-               vma = mas_find(mas, end - 1);
+       next = mas_next(mas, ULONG_MAX);
 
        if (unlikely(uf)) {
                /*
@@ -2406,10 +2392,14 @@ static int do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
        }
 
        /* Point of no return */
+       mas_lock(mas);
+       if (next)
+               max = next->vm_start;
+
        mtree_init(&mt_detach, MAPLE_ALLOC_RANGE);
        dst.tree = &mt_detach;
-       mas->last = end - 1;
-       max = detach_range(mm, mas, &dst, &vma, prev);
+       detach_range(mm, mas, &dst, &vma);
+       mas_unlock(mas);
 
        /*
         * Do not downgrade mmap_lock if we are next to VM_GROWSDOWN or
@@ -2417,7 +2407,7 @@ static int do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
         * down_read(mmap_lock) and collide with the VMA we are about to unmap.
         */
        if (downgrade) {
-               if (vma && (vma->vm_flags & VM_GROWSDOWN))
+               if (next && (next->vm_flags & VM_GROWSDOWN))
                        downgrade = false;
                else if (prev && (prev->vm_flags & VM_GROWSUP))
                        downgrade = false;
@@ -2426,13 +2416,10 @@ static int do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma,
        }
 
        /* Unmap the region */
-       mas_set(&dst, start);
-       tmp = dst;
-       vma = mas_find(&dst, end - 1); // head of list.
        unmap_region(mm, vma, &dst, start, end, prev, max);
 
        /* Statistics and freeing VMAs */
-       dst = tmp;
+       mas_set(&dst, start);
        remove_mt(mm, &dst);
 
        mtree_destroy(&mt_detach);
@@ -2488,8 +2475,11 @@ int do_mas_munmap(struct ma_state *mas, struct mm_struct *mm,
 int do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
              struct list_head *uf)
 {
+       int ret;
        MA_STATE(mas, &mm->mm_mt, start, start);
-       return do_mas_munmap(&mas, mm, start, len, uf, false);
+
+       ret = do_mas_munmap(&mas, mm, start, len, uf, false);
+       return ret;
 }
 
 unsigned long mmap_region(struct file *file, unsigned long addr,
@@ -2509,7 +2499,6 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
        struct ma_state ma_prev, tmp;
        MA_STATE(mas, &mm->mm_mt, addr, end - 1);
 
-       validate_mm(mm);
 
        /* Check against address space limit. */
        if (!may_expand_vm(mm, vm_flags, len >> PAGE_SHIFT)) {
@@ -2526,17 +2515,20 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
                        return -ENOMEM;
        }
 
+       validate_mm(mm);
        /* Unmap any existing mapping in the area */
-       if (do_mas_munmap(&mas, mm, addr, len, uf, false))
+       if (do_mas_munmap(&mas, mm, addr, len, uf, false)) {
                return -ENOMEM;
+       }
 
        /*
         * Private writable mapping: check memory availability
         */
        if (accountable_mapping(file, vm_flags)) {
                charged = len >> PAGE_SHIFT;
-               if (security_vm_enough_memory_mm(mm, charged))
+               if (security_vm_enough_memory_mm(mm, charged)) {
                        return -ENOMEM;
+               }
                vm_flags |= VM_ACCOUNT;
        }
 
@@ -2628,8 +2620,8 @@ cannot_expand:
                 * Answer: Yes, several device drivers can do it in their
                 *         f_op->mmap method. -DaveM
                 */
-               WARN_ON_ONCE(addr != vma->vm_start);
                if (addr != vma->vm_start) {
+                       WARN_ON_ONCE(addr != vma->vm_start);
                        addr = vma->vm_start;
                        mas_set_range(&mas, addr, end - 1);
                }
@@ -2718,8 +2710,8 @@ expanded:
        vma->vm_flags |= VM_SOFTDIRTY;
 
        vma_set_page_prot(vma);
-
        validate_mm(mm);
+
        return addr;
 
 unmap_and_free_vma:
@@ -2752,7 +2744,6 @@ static int __vm_munmap(unsigned long start, size_t len, bool downgrade)
 
        if (mmap_write_lock_killable(mm))
                return -EINTR;
-
        ret = do_mas_munmap(&mas, mm, start, len, &uf, downgrade);
        /*
         * Returning 1 indicates mmap_lock is downgraded.
@@ -3020,9 +3011,9 @@ static int do_brk_flags(struct ma_state *mas, struct ma_state *ma_prev,
                                anon_vma_lock_write(vma->anon_vma);
                                anon_vma_interval_tree_pre_update_vma(vma);
                        }
+                       mas_lock(ma_prev);
                        vma->vm_end = addr + len;
                        vma->vm_flags |= VM_SOFTDIRTY;
-                       mas_lock(ma_prev);
                        if (mas_store_gfp(ma_prev, vma, GFP_KERNEL)) {
                                mas_unlock(ma_prev);
                                goto mas_mod_fail;
@@ -3163,8 +3154,10 @@ void exit_mmap(struct mm_struct *mm)
        arch_exit_mmap(mm);
 
        vma = mas_find(&mas, ULONG_MAX);
-       if (!vma)       /* Can happen if dup_mmap() received an OOM */
+       if (!vma) { /* Can happen if dup_mmap() received an OOM */
+               rcu_read_unlock();
                return;
+       }
 
        lru_add_drain();
        flush_cache_mm(mm);
@@ -3595,6 +3588,7 @@ int mm_take_all_locks(struct mm_struct *mm)
        BUG_ON(mmap_read_trylock(mm));
 
        mutex_lock(&mm_all_locks_mutex);
+       rcu_read_lock();
 
        mas_for_each(&mas, vma, ULONG_MAX) {
                if (signal_pending(current))
@@ -3622,9 +3616,11 @@ int mm_take_all_locks(struct mm_struct *mm)
                                vm_lock_anon_vma(mm, avc->anon_vma);
        }
 
+       rcu_read_unlock();
        return 0;
 
 out_unlock:
+       rcu_read_unlock();
        mm_drop_all_locks(mm);
        return -EINTR;
 }
@@ -3675,6 +3671,8 @@ void mm_drop_all_locks(struct mm_struct *mm)
        BUG_ON(mmap_read_trylock(mm));
        BUG_ON(!mutex_is_locked(&mm_all_locks_mutex));
 
+
+       rcu_read_lock();
        mas_for_each(&mas, vma, ULONG_MAX) {
                if (vma->anon_vma)
                        list_for_each_entry(avc, &vma->anon_vma_chain, same_vma)
@@ -3682,6 +3680,7 @@ void mm_drop_all_locks(struct mm_struct *mm)
                if (vma->vm_file && vma->vm_file->f_mapping)
                        vm_unlock_mapping(vma->vm_file->f_mapping);
        }
+       rcu_read_unlock();
 
        mutex_unlock(&mm_all_locks_mutex);
 }