From: Liam R. Howlett Date: Thu, 24 Nov 2022 15:46:45 +0000 (-0500) Subject: maple_tree: Protect ma_pivots() from using dead nodes X-Git-Url: https://www.infradead.org/git/?a=commitdiff_plain;h=07ea44554c32fcc94fa95b3c339b8707a63822f7;p=users%2Fjedix%2Flinux-maple.git maple_tree: Protect ma_pivots() from using dead nodes Ensure the node type being used isn't going to return an invalid pointer. Signed-off-by: Liam R. Howlett --- diff --git a/lib/maple_tree.c b/lib/maple_tree.c index 27913ffd3795..6e441b2eec8a 100644 --- a/lib/maple_tree.c +++ b/lib/maple_tree.c @@ -527,6 +527,12 @@ static inline bool ma_dead_node(const struct maple_node *node) return (parent == node); } + +static inline bool ma_dead_pivots(const unsigned long *piv) +{ + return (piv == NULL); +} + /* * mte_dead_node() - check if the @enode is dead. * @enode: The encoded maple node @@ -613,6 +619,9 @@ static inline unsigned int mas_alloc_req(const struct ma_state *mas) static inline unsigned long *ma_pivots(struct maple_node *node, enum maple_type type) { + if (unlikely(ma_dead_node(node))) + return NULL; + switch (type) { case maple_arange_64: return node->ma64.pivot; @@ -1079,8 +1088,11 @@ static int mas_ascend(struct ma_state *mas) a_type = mas_parent_enum(mas, p_enode); a_node = mte_parent(p_enode); a_slot = mte_parent_slot(p_enode); - pivots = ma_pivots(a_node, a_type); a_enode = mt_mk_node(a_node, a_type); + pivots = ma_pivots(a_node, a_type); + + if (unlikely(ma_dead_pivots(pivots))) + return 1; if (!set_min && a_slot) { set_min = true; @@ -1421,6 +1433,9 @@ static inline unsigned char mas_data_end(struct ma_state *mas) return ma_meta_end(node, type); pivots = ma_pivots(node, type); + if (unlikely(ma_dead_pivots(pivots))) + return 0; + offset = mt_pivots[type] - 1; if (likely(!pivots[offset])) return ma_meta_end(node, type); @@ -4490,6 +4505,9 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min) node = mas_mn(mas); slots = ma_slots(node, mt); pivots = ma_pivots(node, mt); + if (unlikely(ma_dead_pivots(pivots))) + return 1; + mas->max = pivots[offset]; if (offset) mas->min = pivots[offset - 1] + 1; @@ -4511,6 +4529,9 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min) slots = ma_slots(node, mt); pivots = ma_pivots(node, mt); offset = ma_data_end(node, mt, pivots, mas->max); + if (unlikely(ma_dead_node(node))) + return 1; + if (offset) mas->min = pivots[offset - 1] + 1; @@ -4582,6 +4603,8 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node, node = mas_mn(mas); mt = mte_node_type(mas->node); pivots = ma_pivots(node, mt); + if (unlikely(ma_dead_pivots(pivots))) + return 1; } while (unlikely(offset == ma_data_end(node, mt, pivots, mas->max))); slots = ma_slots(node, mt); @@ -4598,6 +4621,9 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node, mt = mte_node_type(mas->node); slots = ma_slots(node, mt); pivots = ma_pivots(node, mt); + if (unlikely(ma_dead_pivots(pivots))) + return 1; + offset = 0; pivot = pivots[0]; } @@ -4644,13 +4670,15 @@ static inline void *mas_next_nentry(struct ma_state *mas, return NULL; } - pivots = ma_pivots(node, type); slots = ma_slots(node, type); - mas->index = mas_safe_min(mas, pivots, mas->offset); + pivots = ma_pivots(node, type); count = ma_data_end(node, type, pivots, mas->max); - if (ma_dead_node(node)) + if (unlikely(ma_dead_node(node))) return NULL; + mas->index = mas_safe_min(mas, pivots, mas->offset); + if (unlikely(ma_dead_node(node))) + return NULL; if (mas->index > max) return NULL; @@ -4806,6 +4834,11 @@ retry: slots = ma_slots(mn, mt); pivots = ma_pivots(mn, mt); + if (unlikely(ma_dead_pivots(pivots))) { + mas_rewalk(mas, index); + goto retry; + } + if (offset == mt_pivots[mt]) pivot = mas->max; else @@ -6609,11 +6642,11 @@ static inline void *mas_first_entry(struct ma_state *mas, struct maple_node *mn, while (likely(!ma_is_leaf(mt))) { MT_BUG_ON(mas->tree, mte_dead_node(mas->node)); slots = ma_slots(mn, mt); - pivots = ma_pivots(mn, mt); - max = pivots[0]; entry = mas_slot(mas, slots, 0); - if (unlikely(ma_dead_node(mn))) + pivots = ma_pivots(mn, mt); + if (unlikely(ma_dead_pivots(pivots))) return NULL; + max = pivots[0]; mas->node = entry; mn = mas_mn(mas); mt = mte_node_type(mas->node); @@ -6633,13 +6666,13 @@ static inline void *mas_first_entry(struct ma_state *mas, struct maple_node *mn, if (likely(entry)) return entry; - pivots = ma_pivots(mn, mt); - mas->index = pivots[0] + 1; mas->offset = 1; entry = mas_slot(mas, slots, 1); - if (unlikely(ma_dead_node(mn))) + pivots = ma_pivots(mn, mt); + if (unlikely(ma_dead_pivots(pivots))) return NULL; + mas->index = pivots[0] + 1; if (mas->index > limit) goto none;