]> www.infradead.org Git - users/jedix/linux-maple.git/commitdiff
maple_tree: Protect ma_pivots() from using dead nodes
authorLiam R. Howlett <Liam.Howlett@Oracle.com>
Thu, 24 Nov 2022 15:46:45 +0000 (10:46 -0500)
committerLiam R. Howlett <Liam.Howlett@oracle.com>
Tue, 13 Dec 2022 21:03:40 +0000 (16:03 -0500)
Ensure the node type being used isn't going to return an invalid
pointer.

Signed-off-by: Liam R. Howlett <Liam.Howlett@oracle.com>
lib/maple_tree.c

index 27913ffd379540e6e69418f247ff52f34ad9d097..6e441b2eec8ab3427c6cafc4d72f15ece6c07bb9 100644 (file)
@@ -527,6 +527,12 @@ static inline bool ma_dead_node(const struct maple_node *node)
 
        return (parent == node);
 }
+
+static inline bool ma_dead_pivots(const unsigned long *piv)
+{
+       return (piv == NULL);
+}
+
 /*
  * mte_dead_node() - check if the @enode is dead.
  * @enode: The encoded maple node
@@ -613,6 +619,9 @@ static inline unsigned int mas_alloc_req(const struct ma_state *mas)
 static inline unsigned long *ma_pivots(struct maple_node *node,
                                           enum maple_type type)
 {
+       if (unlikely(ma_dead_node(node)))
+               return NULL;
+
        switch (type) {
        case maple_arange_64:
                return node->ma64.pivot;
@@ -1079,8 +1088,11 @@ static int mas_ascend(struct ma_state *mas)
                a_type = mas_parent_enum(mas, p_enode);
                a_node = mte_parent(p_enode);
                a_slot = mte_parent_slot(p_enode);
-               pivots = ma_pivots(a_node, a_type);
                a_enode = mt_mk_node(a_node, a_type);
+               pivots = ma_pivots(a_node, a_type);
+
+               if (unlikely(ma_dead_pivots(pivots)))
+                       return 1;
 
                if (!set_min && a_slot) {
                        set_min = true;
@@ -1421,6 +1433,9 @@ static inline unsigned char mas_data_end(struct ma_state *mas)
                return ma_meta_end(node, type);
 
        pivots = ma_pivots(node, type);
+       if (unlikely(ma_dead_pivots(pivots)))
+               return 0;
+
        offset = mt_pivots[type] - 1;
        if (likely(!pivots[offset]))
                return ma_meta_end(node, type);
@@ -4490,6 +4505,9 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min)
        node = mas_mn(mas);
        slots = ma_slots(node, mt);
        pivots = ma_pivots(node, mt);
+       if (unlikely(ma_dead_pivots(pivots)))
+               return 1;
+
        mas->max = pivots[offset];
        if (offset)
                mas->min = pivots[offset - 1] + 1;
@@ -4511,6 +4529,9 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min)
                slots = ma_slots(node, mt);
                pivots = ma_pivots(node, mt);
                offset = ma_data_end(node, mt, pivots, mas->max);
+               if (unlikely(ma_dead_node(node)))
+                       return 1;
+
                if (offset)
                        mas->min = pivots[offset - 1] + 1;
 
@@ -4582,6 +4603,8 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
                node = mas_mn(mas);
                mt = mte_node_type(mas->node);
                pivots = ma_pivots(node, mt);
+               if (unlikely(ma_dead_pivots(pivots)))
+                   return 1;
        } while (unlikely(offset == ma_data_end(node, mt, pivots, mas->max)));
 
        slots = ma_slots(node, mt);
@@ -4598,6 +4621,9 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
                mt = mte_node_type(mas->node);
                slots = ma_slots(node, mt);
                pivots = ma_pivots(node, mt);
+               if (unlikely(ma_dead_pivots(pivots)))
+                   return 1;
+
                offset = 0;
                pivot = pivots[0];
        }
@@ -4644,13 +4670,15 @@ static inline void *mas_next_nentry(struct ma_state *mas,
                return NULL;
        }
 
-       pivots = ma_pivots(node, type);
        slots = ma_slots(node, type);
-       mas->index = mas_safe_min(mas, pivots, mas->offset);
+       pivots = ma_pivots(node, type);
        count = ma_data_end(node, type, pivots, mas->max);
-       if (ma_dead_node(node))
+       if (unlikely(ma_dead_node(node)))
                return NULL;
 
+       mas->index = mas_safe_min(mas, pivots, mas->offset);
+       if (unlikely(ma_dead_node(node)))
+               return NULL;
        if (mas->index > max)
                return NULL;
 
@@ -4806,6 +4834,11 @@ retry:
 
        slots = ma_slots(mn, mt);
        pivots = ma_pivots(mn, mt);
+       if (unlikely(ma_dead_pivots(pivots))) {
+               mas_rewalk(mas, index);
+               goto retry;
+       }
+
        if (offset == mt_pivots[mt])
                pivot = mas->max;
        else
@@ -6609,11 +6642,11 @@ static inline void *mas_first_entry(struct ma_state *mas, struct maple_node *mn,
        while (likely(!ma_is_leaf(mt))) {
                MT_BUG_ON(mas->tree, mte_dead_node(mas->node));
                slots = ma_slots(mn, mt);
-               pivots = ma_pivots(mn, mt);
-               max = pivots[0];
                entry = mas_slot(mas, slots, 0);
-               if (unlikely(ma_dead_node(mn)))
+               pivots = ma_pivots(mn, mt);
+               if (unlikely(ma_dead_pivots(pivots)))
                        return NULL;
+               max = pivots[0];
                mas->node = entry;
                mn = mas_mn(mas);
                mt = mte_node_type(mas->node);
@@ -6633,13 +6666,13 @@ static inline void *mas_first_entry(struct ma_state *mas, struct maple_node *mn,
        if (likely(entry))
                return entry;
 
-       pivots = ma_pivots(mn, mt);
-       mas->index = pivots[0] + 1;
        mas->offset = 1;
        entry = mas_slot(mas, slots, 1);
-       if (unlikely(ma_dead_node(mn)))
+       pivots = ma_pivots(mn, mt);
+       if (unlikely(ma_dead_pivots(pivots)))
                return NULL;
 
+       mas->index = pivots[0] + 1;
        if (mas->index > limit)
                goto none;