From: Liam R. Howlett Date: Fri, 2 Sep 2022 15:23:21 +0000 (-0400) Subject: maple_tree: Return error on mte_pivots() out of range X-Git-Url: https://www.infradead.org/git/?a=commitdiff_plain;h=c5d00729c58cfdfe52b2915b15727a054cfbbdb7;p=users%2Fjedix%2Flinux-maple.git maple_tree: Return error on mte_pivots() out of range Rename mte_pivots() to mas_pivots() and pass through the ma_state to set the error code to -EIO when the offset is out of range for the node type. Signed-off-by: Liam R. Howlett --- diff --git a/lib/maple_tree.c b/lib/maple_tree.c index 9401136d77dc..8645b9b86719 100644 --- a/lib/maple_tree.c +++ b/lib/maple_tree.c @@ -662,22 +662,21 @@ static inline unsigned long *ma_gaps(struct maple_node *node, } /* - * mte_pivot() - Get the pivot at @piv of the maple encoded node. - * @mn: The maple encoded node. + * mas_pivot() - Get the pivot at @piv of the maple encoded node. + * @mas: The maple state. * @piv: The pivot. * * Return: the pivot at @piv of @mn. */ -static inline unsigned long mte_pivot(const struct maple_enode *mn, - unsigned char piv) +static inline unsigned long mas_pivot(struct ma_state *mas, unsigned char piv) { - struct maple_node *node = mte_to_node(mn); + struct maple_node *node = mas_mn(mas); - if (piv >= mt_pivots[piv]) { - WARN_ON(1); + if (WARN_ON(piv >= mt_pivots[piv])) { + mas_set_err(mas, -EIO); return 0; } - switch (mte_node_type(mn)) { + switch (mte_node_type(mas->node)) { case maple_arange_64: return node->ma64.pivot[piv]; case maple_range_64: @@ -5357,8 +5356,8 @@ static inline int mas_alloc(struct ma_state *mas, void *entry, return xa_err(mas->node); if (!mas->index) - return mte_pivot(mas->node, 0); - return mte_pivot(mas->node, 1); + return mas_pivot(mas, 0); + return mas_pivot(mas, 1); } /* Must be walking a tree. */ @@ -5375,7 +5374,10 @@ static inline int mas_alloc(struct ma_state *mas, void *entry, */ min = mas->min; if (mas->offset) - min = mte_pivot(mas->node, mas->offset - 1) + 1; + min = mas_pivot(mas, mas->offset - 1) + 1; + + if (mas_is_err(mas)) + return xa_err(mas->node); if (mas->index < min) mas->index = min;