]> www.infradead.org Git - users/jedix/linux-maple.git/commitdiff
should be more than one commit.
authorLiam R. Howlett <Liam.Howlett@oracle.com>
Sun, 10 Oct 2021 01:10:24 +0000 (21:10 -0400)
committerLiam R. Howlett <Liam.Howlett@oracle.com>
Wed, 20 Oct 2021 19:23:10 +0000 (15:23 -0400)
Signed-off-by: Liam R. Howlett <Liam.Howlett@oracle.com>
lib/maple_tree.c
mm/mmap.c

index e46970ecd871304734655955e3809843eb9627c1..1592bb2212ab3a9289a2d5105098c5dfc8d8db93 100644 (file)
@@ -462,6 +462,20 @@ static inline struct maple_node *mte_parent(const struct maple_enode *enode)
                        (mte_to_node(enode)->parent) & ~MAPLE_NODE_MASK);
 }
 
+
+/*
+ * ma_dead_node() - check if the @enode is dead.
+ * @enode: The encoded maple node
+ *
+ * Return: true if dead, false otherwise.
+ */
+static inline bool ma_dead_node(const struct maple_node *node)
+{
+       struct maple_node *parent = (void *)((unsigned long)
+                                            node->parent & ~MAPLE_NODE_MASK);
+
+       return (parent == node);
+}
 /*
  * mte_dead_node() - check if the @enode is dead.
  * @enode: The encoded maple node
@@ -1014,7 +1028,7 @@ static int mas_ascend(struct ma_state *mas)
        a_type = mas_parent_enum(mas, mas->node);
        offset = mte_parent_slot(mas->node);
        a_enode = mt_mk_node(p_node, a_type);
-       if (unlikely(mte_dead_node(mas->node)))
+       if (unlikely(mas_mn(mas) == p_node))
                return 1;
 
        /* Check to make sure all parent information is still accurate */
@@ -1032,8 +1046,8 @@ static int mas_ascend(struct ma_state *mas)
 
        min = 0;
        max = ULONG_MAX;
-       p_enode = a_enode;
        do {
+               p_enode = a_enode;
                a_type = mas_parent_enum(mas, p_enode);
                a_node = mte_parent(p_enode);
                a_slot = mte_parent_slot(p_enode);
@@ -1050,13 +1064,12 @@ static int mas_ascend(struct ma_state *mas)
                        max = pivots[a_slot];
                }
 
-               if (unlikely(mte_dead_node(a_enode)))
+               if (unlikely(ma_dead_node(a_node)))
                        return 1;
 
                if (unlikely(ma_is_root(a_node)))
                        break;
 
-               p_enode = a_enode;
        } while (!set_min || !set_max);
 
        mas->max = max;
@@ -2245,7 +2258,8 @@ bool mast_sibling_rebalance_right(struct maple_subtree_state *mast, bool free)
 }
 
 static inline int mas_prev_node(struct ma_state *mas, unsigned long min);
-static inline int mas_next_node(struct ma_state *mas, unsigned long max);
+static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
+                               unsigned long max);
 /*
  * mast_cousin_rebalance_right() - Rebalance from nodes with different parents.
  * Check the right side, then the left.  Data is copied into the @mast->bn.
@@ -2260,7 +2274,7 @@ bool mast_cousin_rebalance_right(struct maple_subtree_state *mast, bool free)
        MA_STATE(tmp, mast->orig_r->tree, mast->orig_r->index, mast->orig_r->last);
 
        tmp = *mast->orig_r;
-       mas_next_node(mast->orig_r, ULONG_MAX);
+       mas_next_node(mast->orig_r, mas_mn(mast->orig_r), ULONG_MAX);
        if (!mas_is_none(mast->orig_r)) {
                mast_rebalance_next(mast, old_r, free);
                return true;
@@ -3494,13 +3508,16 @@ static bool mas_is_span_wr(struct ma_state *mas, unsigned long piv,
 static inline void mas_node_walk(struct ma_state *mas, enum maple_type type,
                unsigned long *range_min, unsigned long *range_max)
 {
-       unsigned long *pivots = ma_pivots(mas_mn(mas), type);
+       struct maple_node *node;
+       unsigned long *pivots;
        unsigned char offset, count;
        unsigned long min, max, index;
 
+       node = mas_mn(mas);
+       pivots = ma_pivots(node, type);
        if (unlikely(ma_is_dense(type))) {
                (*range_max) = (*range_min) = mas->index;
-               if (unlikely(mte_dead_node(mas->node)))
+               if (unlikely(ma_dead_node(node)))
                        return;
 
                mas->offset = mas->index = mas->min;
@@ -3512,7 +3529,7 @@ static inline void mas_node_walk(struct ma_state *mas, enum maple_type type,
        count = mt_pivots[type];
        index = mas->index;
        max = pivots[offset];
-       if (unlikely(mte_dead_node(mas->node)))
+       if (unlikely(ma_dead_node(node)))
                return;
 
        if (unlikely(offset == count))
@@ -3528,7 +3545,7 @@ static inline void mas_node_walk(struct ma_state *mas, enum maple_type type,
        min = max + 1;
        while (offset < count) {
                max = pivots[offset];
-               if (unlikely(mte_dead_node(mas->node)))
+               if (unlikely(ma_dead_node(node)))
                        return;
 
                if (index <= max)
@@ -3653,10 +3670,12 @@ static inline void mas_extend_null(struct ma_state *l_mas, struct ma_state *r_ma
 static inline bool __mas_walk(struct ma_state *mas, unsigned long *range_min,
                unsigned long *range_max)
 {
+       struct maple_node *node;
        struct maple_enode *next;
        enum maple_type type;
 
-       if (unlikely(mte_dead_node(mas->node)))
+       node = mas_mn(mas);
+       if (unlikely(ma_dead_node(node)))
                return false;
 
        while (true) {
@@ -3664,7 +3683,7 @@ static inline bool __mas_walk(struct ma_state *mas, unsigned long *range_min,
                mas->depth++;
                mas_node_walk(mas, type, range_min, range_max);
                next = mas_slot(mas, ma_slots(mas_mn(mas), type), mas->offset);
-               if (unlikely(mte_dead_node(mas->node)))
+               if (unlikely(ma_dead_node(node)))
                        return false;
 
                if (unlikely(ma_is_leaf(type)))
@@ -3675,6 +3694,7 @@ static inline bool __mas_walk(struct ma_state *mas, unsigned long *range_min,
 
                /* Descend. */
                mas->node = next;
+               node = mas_mn(mas);
                mas->max = *range_max;
                mas->min = *range_min;
                mas->offset = 0;
@@ -4212,7 +4232,8 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min)
 
        level = 0;
        do {
-               if (mte_is_root(mas->node))
+               node = mas_mn(mas);
+               if (ma_is_root(node))
                        goto no_entry;
 
                /* Walk up. */
@@ -4230,7 +4251,7 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min)
        mas->max = pivots[offset];
        if (offset)
                mas->min = pivots[offset - 1] + 1;
-       if (unlikely(mte_dead_node(mas->node)))
+       if (unlikely(ma_dead_node(node)))
                return 1;
 
        if (mas->max < min)
@@ -4239,7 +4260,7 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min)
        while (level > 1) {
                level--;
                enode = mas_slot(mas, slots, offset);
-               if (unlikely(mte_dead_node(mas->node)))
+               if (unlikely(ma_dead_node(node)))
                        return 1;
 
                mas->node = enode;
@@ -4261,7 +4282,7 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min)
        enode = mas_slot(mas, slots, offset);
        offset = mas_data_end(mas);
 
-       if (unlikely(mte_dead_node(mas->node)))
+       if (unlikely(ma_dead_node(node)))
                return 1;
 
        mas->node = enode;
@@ -4273,7 +4294,7 @@ no_entry_min:
        if (offset)
                mas->min = pivots[offset - 1] + 1;
 no_entry:
-       if (unlikely(mte_dead_node(mas->node)))
+       if (unlikely(ma_dead_node(node)))
                return 1;
 
        mas->node = MAS_NONE;
@@ -4288,11 +4309,11 @@ no_entry:
  * The next value will be mas->node[mas->offset] or MAS_NONE.
  * Return: 1 on dead node, 0 otherwise.
  */
-static inline int mas_next_node(struct ma_state *mas, unsigned long max)
+static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
+                               unsigned long max)
 {
        unsigned long min, pivot;
        unsigned long *pivots;
-       struct maple_node *node;
        struct maple_enode *enode;
        int level = 0;
        unsigned char offset;
@@ -4304,7 +4325,7 @@ static inline int mas_next_node(struct ma_state *mas, unsigned long max)
 
        level = 0;
        do {
-               if (mte_is_root(mas->node))
+               if (ma_is_root(node))
                        goto no_entry;
 
                min = mas->max + 1;
@@ -4316,9 +4337,9 @@ static inline int mas_next_node(struct ma_state *mas, unsigned long max)
 
                offset = mas->offset;
                level++;
+               node = mas_mn(mas);
        } while (unlikely(offset == mas_data_end(mas)));
 
-       node = mas_mn(mas);
        mt = mte_node_type(mas->node);
        slots = ma_slots(node, mt);
        pivots = ma_pivots(node, mt);
@@ -4326,7 +4347,7 @@ static inline int mas_next_node(struct ma_state *mas, unsigned long max)
        while (unlikely(level > 1)) {
                /* Descend, if necessary */
                enode = mas_slot(mas, slots, offset);
-               if (unlikely(mte_dead_node(mas->node)))
+               if (unlikely(ma_dead_node(node)))
                        return 1;
 
                mas->node = enode;
@@ -4340,7 +4361,7 @@ static inline int mas_next_node(struct ma_state *mas, unsigned long max)
        }
 
        enode = mas_slot(mas, slots, offset);
-       if (unlikely(mte_dead_node(mas->node)))
+       if (unlikely(ma_dead_node(node)))
                return 1;
 
        mas->node = enode;
@@ -4349,7 +4370,7 @@ static inline int mas_next_node(struct ma_state *mas, unsigned long max)
        return 0;
 
 no_entry:
-       if (unlikely(mte_dead_node(mas->node)))
+       if (unlikely(ma_dead_node(node)))
                return 1;
 
        mas->node = MAS_NONE;
@@ -4367,69 +4388,61 @@ no_entry:
  *
  * Return: The next entry, %NULL otherwise
  */
-static inline void *mas_next_nentry(struct ma_state *mas, unsigned long max,
-                                   enum maple_type type)
+static inline void *mas_next_nentry(struct ma_state *mas,
+           struct maple_node *node, unsigned long max, enum maple_type type)
 {
-       struct maple_node *node;
        unsigned long pivot = 0;
        unsigned char count;
        unsigned long *pivots;
        void __rcu **slots;
-       void *entry = NULL;
+       void *entry;
 
        if (mas->last == mas->max) {
                mas->index = mas->max;
                return NULL;
        }
 
-       node = mas_mn(mas);
-       type = mte_node_type(mas->node);
        pivots = ma_pivots(node, type);
        mas->index = mas_safe_min(mas, pivots, mas->offset);
        slots = ma_slots(node, type);
-
-       if (mte_dead_node(mas->node))
+       if (ma_dead_node(node))
                return NULL;
 
        if (mas->index > max)
-               goto no_entry;
+               return NULL;
 
        count = ma_data_end(node, type, pivots, mas->max);
        while (mas->offset < count) {
                pivot = pivots[mas->offset];
                entry = mas_slot(mas, slots, mas->offset);
-               if (mte_dead_node(mas->node))
+               if (ma_dead_node(node))
                        return NULL;
 
                if (entry)
                        goto found;
 
                if (pivot >= max)
-                       goto no_entry;
+                       return NULL;
 
                mas->index = pivot + 1;
                mas->offset++;
        }
 
-
        if (mas->index > mas->max) {
                mas->index = mas->last;
-               goto no_entry;
+               return NULL;
        }
 
        pivot = _mas_safe_pivot(mas, pivots, mas->offset, type);
        entry = mas_slot(mas, slots, mas->offset);
-       if (mte_dead_node(mas->node))
+       if (ma_dead_node(node))
                return NULL;
 
        if (!pivot)
-               goto no_entry;
-
-       if (entry)
-               goto found;
+               return NULL;
 
-no_entry:
-       return NULL;
+       if (!entry)
+               return NULL;
 
 found:
        mas->last = pivot;
@@ -4521,12 +4534,11 @@ static inline int mas_dead_node(struct ma_state *mas, unsigned long index)
  *
  * Return: The first entry or MAS_NONE.
  */
-static inline void *mas_first_entry(struct ma_state *mas,
+static inline void *mas_first_entry(struct ma_state *mas, struct maple_node *mn,
                unsigned long limit, enum maple_type mt)
 {
        unsigned long max;
        unsigned long *pivots;
-       struct maple_node *mn;
        void __rcu **slots;
        void *entry = NULL;
 
@@ -4537,22 +4549,21 @@ static inline void *mas_first_entry(struct ma_state *mas,
        max = mas->max;
        mas->offset = 0;
        while (likely(!ma_is_leaf(mt))) {
-               mn = mas_mn(mas);
                slots = ma_slots(mn, mt);
                pivots = ma_pivots(mn, mt);
                max = pivots[0];
                entry = mas_slot(mas, slots, 0);
-               if (unlikely(mte_dead_node(mas->node)))
+               if (unlikely(ma_dead_node(mn)))
                        return NULL;
                mas->node = entry;
+               mn = mas_mn(mas);
                mt = mte_node_type(mas->node);
        }
 
        mas->max = max;
-       mn = mas_mn(mas);
        slots = ma_slots(mn, mt);
        entry = mas_slot(mas, slots, 0);
-       if (unlikely(mte_dead_node(mas->node)))
+       if (unlikely(ma_dead_node(mn)))
                return NULL;
 
        /* Slot 0 or 1 must be set */
@@ -4566,7 +4577,7 @@ static inline void *mas_first_entry(struct ma_state *mas,
        mas->index = pivots[0] + 1;
        mas->offset = 1;
        entry = mas_slot(mas, slots, 1);
-       if (unlikely(mte_dead_node(mas->node)))
+       if (unlikely(ma_dead_node(mn)))
                return NULL;
 
        if (mas->index > limit)
@@ -4576,7 +4587,7 @@ static inline void *mas_first_entry(struct ma_state *mas,
                return entry;
 
 none:
-       if (likely(!mte_dead_node(mas->node)))
+       if (likely(!ma_dead_node(mn)))
                mas->node = MAS_NONE;
        return NULL;
 }
@@ -4597,6 +4608,7 @@ static inline void *_mas_next(struct ma_state *mas, unsigned long limit)
 {
        void *entry = NULL;
        struct maple_enode *prev_node;
+       struct maple_node *node;
        unsigned char offset;
        unsigned long last;
        enum maple_type mt;
@@ -4605,18 +4617,15 @@ static inline void *_mas_next(struct ma_state *mas, unsigned long limit)
 retry:
        offset = mas->offset;
        prev_node = mas->node;
+       node = mas_mn(mas);
        mt = mte_node_type(mas->node);
        mas->offset++;
        if (unlikely(mas->offset >= mt_slots[mt]))
                goto next_node;
 
        while (!mas_is_none(mas)) {
-               if (likely(ma_is_leaf(mt)))
-                       entry = mas_next_nentry(mas, limit, mt);
-               else
-                       entry = mas_first_entry(mas, limit, mt);
-
-               if (unlikely(mte_dead_node(mas->node))) {
+               entry = mas_next_nentry(mas, node, limit, mt);
+               if (unlikely(ma_dead_node(node))) {
                        mas_rewalk(mas, last);
                        goto retry;
                }
@@ -4630,12 +4639,12 @@ retry:
 next_node:
                prev_node = mas->node;
                offset = mas->offset;
-               if (unlikely(mas_next_node(mas, limit))) {
+               if (unlikely(mas_next_node(mas, node, limit))) {
                        mas_rewalk(mas, last);
                        goto retry;
                }
-
                mas->offset = 0;
+               node = mas_mn(mas);
                mt = mte_node_type(mas->node);
        }
 
@@ -4677,7 +4686,7 @@ retry:
        else
                pivot = pivots[offset];
 
-       if (mte_dead_node(mas->node)) {
+       if (ma_dead_node(mn)) {
                mas_rewalk(mas, index);
                goto retry;
        }
@@ -4687,7 +4696,7 @@ retry:
 
        min = mas_safe_min(mas, pivots, offset);
        entry = mas_slot(mas, slots, offset);
-       if (mte_dead_node(mas->node)) {
+       if (ma_dead_node(mn)) {
                mas_rewalk(mas, index);
                goto retry;
        }
@@ -4952,30 +4961,6 @@ void *mas_walk(struct ma_state *mas)
        return entry;
 }
 
-static inline bool mas_search_cont(struct ma_state *mas, unsigned long index,
-               unsigned long max, void *entry)
-{
-       if (index > max)
-               return false;
-
-       if (mas_is_start(mas))
-               return true;
-
-       if (index == max)
-               return false;
-
-       if (!mas_searchable(mas))
-               return false;
-
-       if (mas_is_err(mas))
-               return false;
-
-       if (entry)
-               return false;
-
-       return true;
-}
-
 static inline bool mas_rewind_node(struct ma_state *mas)
 {
        unsigned char slot;
@@ -5374,18 +5359,23 @@ void *_mt_find(struct maple_tree *mt, unsigned long *index, unsigned long max,
                entry = mas_get_slot(&mas, mas.offset);
 
        mas.last = range_end;
-       if (!entry || xa_is_zero(entry))
-               entry = NULL;
+       if (entry && !xa_is_zero(entry)) {
+               rcu_read_unlock();
+               goto done;
+       }
 
-       while (mas_search_cont(&mas, range_start, max, entry)) {
-               entry = _mas_find(&mas, max);
-               range_start = mas.index;
-               if (!entry || xa_is_zero(entry))
-                       entry = NULL;
+       mas.index = range_start;
+       while (mas_searchable(&mas) && (mas.index < max)) {
+               entry = _mas_next(&mas, max);
+               if (likely(entry && !xa_is_zero(entry)))
+                       break;
        }
        rcu_read_unlock();
 
-       if (entry){
+       if (unlikely(xa_is_zero(entry)))
+               entry = NULL;
+done:
+       if (likely(entry)) {
                *index = mas.last + 1;
 #ifdef CONFIG_DEBUG_MAPLE_TREE
                if ((*index) && (*index) <= copy)
@@ -6341,7 +6331,7 @@ static void mas_dfs_postorder(struct ma_state *mas, unsigned long max)
        struct maple_enode *p = MAS_NONE, *mn = mas->node;
        unsigned long p_min, p_max;
 
-       mas_next_node(mas, max);
+       mas_next_node(mas, mas_mn(mas), max);
        if (!mas_is_none(mas))
                return;
 
@@ -6757,7 +6747,7 @@ void mt_validate_nulls(struct maple_tree *mt)
                MT_BUG_ON(mt, !last && !entry);
                last = entry;
                if (offset == mas_data_end(&mas)) {
-                       mas_next_node(&mas, ULONG_MAX);
+                       mas_next_node(&mas, mas_mn(&mas), ULONG_MAX);
                        if (mas_is_none(&mas))
                                return;
                        offset = 0;
@@ -6783,7 +6773,7 @@ void mt_validate(struct maple_tree *mt)
        if (!mas_searchable(&mas))
                goto done;
 
-       mas_first_entry(&mas, ULONG_MAX, mte_node_type(mas.node));
+       mas_first_entry(&mas, mas_mn(&mas), ULONG_MAX, mte_node_type(mas.node));
        while (!mas_is_none(&mas)) {
                MT_BUG_ON(mas.tree, mte_dead_node(mas.node));
                if (!mte_is_root(mas.node)) {
index 3444625626b83b698a4a038ff773c1ea4ffe62e1..d5fa26a1883d62bf8c4791bd91580f479986d8ac 100644 (file)
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -188,10 +188,9 @@ static void remove_vma(struct vm_area_struct *vma)
 static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
                         unsigned long newbrk, unsigned long oldbrk,
                         struct list_head *uf);
-static int do_brk_flags(struct ma_state *mas, struct ma_state *ma_prev,
-                       struct vm_area_struct **brkvma,
-                       unsigned long addr, unsigned long request,
-                       unsigned long flags);
+static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma,
+                       struct ma_state *ma_prev, unsigned long addr,
+                       unsigned long request, unsigned long flags);
 SYSCALL_DEFINE1(brk, unsigned long, brk)
 {
        unsigned long newbrk, oldbrk, origbrk;
@@ -274,6 +273,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
        ma_neighbour = mas;
        mas_lock(&ma_neighbour);
        next = mas_next(&ma_neighbour, newbrk + PAGE_SIZE + stack_guard_gap);
+       brkvma = mas_prev(&ma_neighbour, mm->start_brk);
        mas_unlock(&ma_neighbour);
        /* Only check if the next VMA is within the stack_guard_gap of the
         * expansion area */
@@ -281,9 +281,6 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
        if (next && newbrk + PAGE_SIZE > vm_start_gap(next))
                goto out;
 
-       mas_lock(&ma_neighbour);
-       brkvma = mas_prev(&ma_neighbour, mm->start_brk);
-       mas_unlock(&ma_neighbour);
        if (brkvma) {
                if (brkvma->vm_start >= oldbrk)
                        goto out; // Trying to map over another vma.
@@ -293,7 +290,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
        }
 
        /* Ok, looks good - let it rip. */
-       if (do_brk_flags(&mas, &ma_neighbour, &brkvma, oldbrk,
+       if (do_brk_flags(&mas, brkvma, &ma_neighbour, oldbrk,
                         newbrk - oldbrk, 0) < 0)
                goto out;
 
@@ -521,35 +518,6 @@ void vma_mt_store(struct mm_struct *mm, struct vm_area_struct *vma)
                GFP_KERNEL);
 }
 
-/*
- * vma_mas_link() - Link a VMA into an mm
- * @mm: The mm struct
- * @vma: The VMA to link in
- * @mas: The maple state
- *
- * Must not hold the maple tree lock.
- */
-static void vma_mas_link(struct mm_struct *mm, struct vm_area_struct *vma,
-                        struct ma_state *mas)
-{
-       struct address_space *mapping = NULL;
-
-       if (vma->vm_file) {
-               mapping = vma->vm_file->f_mapping;
-               i_mmap_lock_write(mapping);
-       }
-
-       mas_lock(mas);
-       vma_mas_store(vma, mas);
-       mas_unlock(mas);
-       __vma_link_file(vma);
-
-       if (mapping)
-               i_mmap_unlock_write(mapping);
-
-       mm->map_count++;
-}
-
 static void vma_link(struct mm_struct *mm, struct vm_area_struct *vma)
 {
        struct address_space *mapping = NULL;
@@ -588,7 +556,7 @@ static void __insert_vm_struct(struct mm_struct *mm, struct vm_area_struct *vma)
  * @start: The start of the vma
  * @end: The exclusive end of the vma
  *
- * @mas must be locked
+ * @mas cannot be locked.
  */
 inline int vma_expand(struct ma_state *mas, struct vm_area_struct *vma,
                      unsigned long start, unsigned long end, pgoff_t pgoff,
@@ -619,16 +587,11 @@ inline int vma_expand(struct ma_state *mas, struct vm_area_struct *vma,
                root = &mapping->i_mmap;
                uprobe_munmap(vma, vma->vm_start, vma->vm_end);
                i_mmap_lock_write(mapping);
-       }
-
-       if (anon_vma) {
-               anon_vma_lock_write(anon_vma);
-               anon_vma_interval_tree_pre_update_vma(vma);
-       }
-
-       if (file) {
                flush_dcache_mmap_lock(mapping);
                vma_interval_tree_remove(vma, root);
+       } else if (anon_vma) {
+               anon_vma_lock_write(anon_vma);
+               anon_vma_interval_tree_pre_update_vma(vma);
        }
 
        vma->vm_start = start;
@@ -642,30 +605,24 @@ inline int vma_expand(struct ma_state *mas, struct vm_area_struct *vma,
        if (file) {
                vma_interval_tree_insert(vma, root);
                flush_dcache_mmap_unlock(mapping);
-       }
+               /* Expanding over the next vma */
+               if (remove_next)
+                       __remove_shared_vm_struct(next, file, mapping);
 
-       /* Expanding over the next vma */
-       if (remove_next && file) {
-               __remove_shared_vm_struct(next, file, mapping);
-       }
-
-       if (anon_vma) {
-               anon_vma_interval_tree_post_update_vma(vma);
-               anon_vma_unlock_write(anon_vma);
-       }
-
-       if (file) {
                i_mmap_unlock_write(mapping);
                uprobe_mmap(vma);
-       }
-
-       if (remove_next) {
-               if (file) {
+               if (remove_next) {
                        uprobe_munmap(next, next->vm_start, next->vm_end);
                        fput(file);
                }
-               if (next->anon_vma)
+       } else if (anon_vma) {
+               anon_vma_interval_tree_post_update_vma(vma);
+               anon_vma_unlock_write(anon_vma);
+               if (remove_next && next->anon_vma)
                        anon_vma_merge(vma, next);
+       }
+
+       if (remove_next) {
                mm->map_count--;
                mpol_put(vma_policy(next));
                vm_area_free(next);
@@ -2577,9 +2534,7 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
                 * requested mapping. Account for the pages it would unmap.
                 */
                nr_pages = count_vma_pages_range(&mas);
-
-               if (!may_expand_vm(mm, vm_flags,
-                                       (len >> PAGE_SHIFT) - nr_pages))
+               if (!may_expand_vm(mm, vm_flags, (len >> PAGE_SHIFT) - nr_pages))
                        return -ENOMEM;
        }
 
@@ -2589,6 +2544,9 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
                mas_unlock(&mas);
                return -ENOMEM;
        }
+       /* Get next and prev */
+       next = mas_next(&mas, ULONG_MAX);
+       prev = mas_prev(&mas, 0);
        mas_unlock(&mas);
 
        /*
@@ -2601,20 +2559,11 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
                vm_flags |= VM_ACCOUNT;
        }
 
-
-       if (vm_flags & VM_SPECIAL) {
-               rcu_read_lock();
-               prev = mas_prev(&mas, 0);
-               rcu_read_unlock();
+       if (vm_flags & VM_SPECIAL)
                goto cannot_expand;
-       }
 
        /* Attempt to expand an old mapping */
-
        /* Check next */
-       rcu_read_lock();
-       next = mas_next(&mas, ULONG_MAX);
-       rcu_read_unlock();
        if (next && next->vm_start == end && vma_policy(next) &&
            can_vma_merge_before(next, vm_flags, NULL, file, pgoff+pglen,
                                 NULL_VM_UFFD_CTX)) {
@@ -2624,9 +2573,6 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
        }
 
        /* Check prev */
-       rcu_read_lock();
-       prev = mas_prev(&mas, 0);
-       rcu_read_unlock();
        if (prev && prev->vm_end == addr && !vma_policy(prev) &&
            can_vma_merge_after(prev, vm_flags, NULL, file, pgoff,
                                NULL_VM_UFFD_CTX)) {
@@ -2635,7 +2581,6 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
                vm_pgoff = prev->vm_pgoff;
        }
 
-
        /* Actually expand, if possible */
        if (vma &&
            !vma_expand(&mas, vma, merge_start, merge_end, vm_pgoff, next)) {
@@ -2735,7 +2680,23 @@ cannot_expand:
                        goto free_vma;
        }
 
-       vma_mas_link(mm, vma, &mas);
+       if (vma->vm_file)
+               i_mmap_lock_write(vma->vm_file->f_mapping);
+
+       mas_lock(&mas);
+       vma_mas_store(vma, &mas);
+       mas_unlock(&mas);
+       mm->map_count++;
+       if (vma->vm_file) {
+               if (vma->vm_flags & VM_SHARED)
+                       mapping_allow_writable(vma->vm_file->f_mapping);
+
+               flush_dcache_mmap_lock(vma->vm_file->f_mapping);
+               vma_interval_tree_insert(vma, &vma->vm_file->f_mapping->i_mmap);
+               flush_dcache_mmap_unlock(vma->vm_file->f_mapping);
+               i_mmap_unlock_write(vma->vm_file->f_mapping);
+       }
+
        /* Once vma denies write, undo our temporary denial count */
 unmap_writable:
        if (file && vm_flags & VM_SHARED)
@@ -2949,7 +2910,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
        unmap.vm_start = newbrk;
        unmap.vm_end = oldbrk;
        unmap.vm_pgoff = newbrk >> PAGE_SHIFT;
-       if (vma->anon_vma)
+       if (likely(vma->anon_vma))
                vma_set_anonymous(&unmap);
 
        ret = userfaultfd_unmap_prep(&unmap, newbrk, oldbrk, uf);
@@ -2960,7 +2921,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
 
        // Change the oldbrk of vma to the newbrk of the munmap area
        vma_adjust_trans_huge(vma, vma->vm_start, newbrk, 0);
-       if (vma->anon_vma) {
+       if (likely(vma->anon_vma)) {
                anon_vma_lock_write(vma->anon_vma);
                anon_vma_interval_tree_pre_update_vma(vma);
        }
@@ -2968,10 +2929,11 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
        mas_lock(mas);
        if (vma_mas_remove(&unmap, mas))
                goto mas_store_fail;
+       next = mas_next(mas, ULONG_MAX);
        mas_unlock(mas);
 
        vma->vm_end = newbrk;
-       if (vma->anon_vma) {
+       if (likely(vma->anon_vma)) {
                anon_vma_interval_tree_post_update_vma(vma);
                anon_vma_unlock_write(vma->anon_vma);
        }
@@ -2982,9 +2944,6 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
                munlock_vma_pages_range(&unmap, newbrk, oldbrk);
        }
 
-       mas_lock(mas);
-       next = mas_next(mas, ULONG_MAX);
-       mas_unlock(mas);
        mmap_write_downgrade(mm);
        unmap_region(mm, mas->tree, &unmap, vma, next, newbrk, oldbrk);
        /* Statistics */
@@ -3018,13 +2977,11 @@ mas_store_fail:
  * do not match then create a new anonymous VMA.  Eventually we may be able to
  * do some brk-specific accounting here.
  */
-static int do_brk_flags(struct ma_state *mas, struct ma_state *ma_prev,
-                       struct vm_area_struct **brkvma,
-                       unsigned long addr, unsigned long len,
-                       unsigned long flags)
+static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma,
+                       struct ma_state *ma_prev, unsigned long addr,
+                       unsigned long len, unsigned long flags)
 {
        struct mm_struct *mm = current->mm;
-       struct vm_area_struct *vma;
        int error;
        unsigned long mapped_addr;
        validate_mm(mm);
@@ -3052,58 +3009,54 @@ static int do_brk_flags(struct ma_state *mas, struct ma_state *ma_prev,
        if (security_vm_enough_memory_mm(mm, len >> PAGE_SHIFT))
                return -ENOMEM;
 
-       if (*brkvma) {
-               vma = *brkvma;
-               /* Expand the existing vma if possible; almost never a singular
-                * list, so this will almost always fail. */
-
-               if ((!vma->anon_vma ||
-                    list_is_singular(&vma->anon_vma_chain)) &&
-                    ((vma->vm_flags & ~VM_SOFTDIRTY) == flags)){
-                       ma_prev->index = vma->vm_start;
-                       ma_prev->last = addr + len - 1;
-
-                       vma_adjust_trans_huge(vma, addr, addr + len, 0);
-                       if (vma->anon_vma) {
-                               anon_vma_lock_write(vma->anon_vma);
-                               anon_vma_interval_tree_pre_update_vma(vma);
-                       }
-                       vma->vm_end = addr + len;
-                       vma->vm_flags |= VM_SOFTDIRTY;
+       /* Expand the existing vma if possible; almost never a singular
+        * list, so this will almost always fail. */
+       if (vma &&
+           (unlikely((!vma->anon_vma ||
+                      list_is_singular(&vma->anon_vma_chain))) &&
+            ((vma->vm_flags & ~VM_SOFTDIRTY) == flags))) {
+               ma_prev->index = vma->vm_start;
+               ma_prev->last = addr + len - 1;
+               vma_adjust_trans_huge(vma, addr, addr + len, 0);
+               if (likely(vma->anon_vma)) {
+                       anon_vma_lock_write(vma->anon_vma);
+                       anon_vma_interval_tree_pre_update_vma(vma);
+               }
 
-                       mas_lock(ma_prev);
+               vma->vm_end = addr + len;
+               vma->vm_flags |= VM_SOFTDIRTY;
+               mas_lock(ma_prev);
 #if defined(CONFIG_DEBUG_MAPLE_TREE)
-                       /* Make sure no VMAs are about to be lost. */
-                       {
-                               MA_STATE(test, ma_prev->tree, vma->vm_start,
-                                        vma->vm_end - 1);
-                               struct vm_area_struct *vma_mas;
-                               int count = 0;
-
-                               mas_for_each(&test, vma_mas, vma->vm_end - 1)
-                                       count++;
-
-                               BUG_ON(count != 1);
-                               mas_set_range(ma_prev, vma->vm_start,
-                                             vma->vm_end - 1);
-                       }
+               /* Make sure no VMAs are about to be lost. */
+               {
+                       MA_STATE(test, ma_prev->tree, vma->vm_start,
+                                vma->vm_end - 1);
+                       struct vm_area_struct *vma_mas;
+                       int count = 0;
+
+                       mas_for_each(&test, vma_mas, vma->vm_end - 1)
+                               count++;
+
+                       BUG_ON(count != 1);
+                       mas_set_range(ma_prev, vma->vm_start,
+                                     vma->vm_end - 1);
+               }
 #endif
-                       if (mas_store_gfp(ma_prev, vma, GFP_KERNEL))
-                               goto mas_mod_fail;
-                       mas_unlock(ma_prev);
+               if (unlikely(mas_store_gfp(ma_prev, vma, GFP_KERNEL)))
+                       goto mas_mod_fail;
+               mas_unlock(ma_prev);
 
-                       if (vma->anon_vma) {
-                               anon_vma_interval_tree_post_update_vma(vma);
-                               anon_vma_unlock_write(vma->anon_vma);
-                       }
-                       khugepaged_enter_vma_merge(vma, flags);
-                       goto out;
+               if (likely(vma->anon_vma)) {
+                       anon_vma_interval_tree_post_update_vma(vma);
+                       anon_vma_unlock_write(vma->anon_vma);
                }
+               khugepaged_enter_vma_merge(vma, flags);
+               goto out;
        }
 
        /* create a vma struct for an anonymous mapping */
        vma = vm_area_alloc(mm);
-       if (!vma)
+       if (unlikely(!vma))
                goto vma_alloc_fail;
 
        vma_set_anonymous(vma);
@@ -3112,8 +3065,10 @@ static int do_brk_flags(struct ma_state *mas, struct ma_state *ma_prev,
        vma->vm_pgoff = addr >> PAGE_SHIFT;
        vma->vm_flags = flags;
        vma->vm_page_prot = vm_get_page_prot(flags);
-       vma_mas_link(mm, vma, mas);
-       *brkvma = vma;
+       mas_lock(mas);
+       vma_mas_store(vma, mas);
+       mas_unlock(mas);
+       mm->map_count++;
 out:
        perf_event_mmap(vma);
        mm->total_vm += len >> PAGE_SHIFT;
@@ -3143,7 +3098,6 @@ mas_mod_fail:
 int vm_brk_flags(unsigned long addr, unsigned long request, unsigned long flags)
 {
        struct mm_struct *mm = current->mm;
-       struct vm_area_struct *vma = NULL;
        unsigned long len;
        int ret;
        bool populate;
@@ -3158,7 +3112,7 @@ int vm_brk_flags(unsigned long addr, unsigned long request, unsigned long flags)
        if (mmap_write_lock_killable(mm))
                return -EINTR;
 
-       ret = do_brk_flags(&mas, &mas, &vma, addr, len, flags);
+       ret = do_brk_flags(&mas, NULL, &mas, addr, len, flags);
        populate = ((mm->def_flags & VM_LOCKED) != 0);
        mmap_write_unlock(mm);
        if (populate && !ret)