From: Nandakumar Edamana Date: Tue, 26 Aug 2025 03:45:23 +0000 (+0530) Subject: bpf: Improve the general precision of tnum_mul X-Git-Url: https://www.infradead.org/git/?a=commitdiff_plain;h=1df7dad4d5c49335b72e26d833def960b2de76e3;p=users%2Fhch%2Fmisc.git bpf: Improve the general precision of tnum_mul Drop the value-mask decomposition technique and adopt straightforward long-multiplication with a twist: when LSB(a) is uncertain, find the two partial products (for LSB(a) = known 0 and LSB(a) = known 1) and take a union. Experiment shows that applying this technique in long multiplication improves the precision in a significant number of cases (at the cost of losing precision in a relatively lower number of cases). Signed-off-by: Nandakumar Edamana Signed-off-by: Andrii Nakryiko Tested-by: Harishankar Vishwanathan Reviewed-by: Harishankar Vishwanathan Acked-by: Eduard Zingerman Link: https://lore.kernel.org/bpf/20250826034524.2159515-1-nandakumar@nandakumar.co.in --- diff --git a/include/linux/tnum.h b/include/linux/tnum.h index 0ffb77ffe0e8..c52b862dad45 100644 --- a/include/linux/tnum.h +++ b/include/linux/tnum.h @@ -57,6 +57,9 @@ bool tnum_overlap(struct tnum a, struct tnum b); /* Return a tnum representing numbers satisfying both @a and @b */ struct tnum tnum_intersect(struct tnum a, struct tnum b); +/* Returns a tnum representing numbers satisfying either @a or @b */ +struct tnum tnum_union(struct tnum t1, struct tnum t2); + /* Return @a with all but the lowest @size bytes cleared */ struct tnum tnum_cast(struct tnum a, u8 size); diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c index d9328bbb3680..f8e70e9c3998 100644 --- a/kernel/bpf/tnum.c +++ b/kernel/bpf/tnum.c @@ -116,31 +116,47 @@ struct tnum tnum_xor(struct tnum a, struct tnum b) return TNUM(v & ~mu, mu); } -/* Generate partial products by multiplying each bit in the multiplier (tnum a) - * with the multiplicand (tnum b), and add the partial products after - * appropriately bit-shifting them. Instead of directly performing tnum addition - * on the generated partial products, equivalenty, decompose each partial - * product into two tnums, consisting of the value-sum (acc_v) and the - * mask-sum (acc_m) and then perform tnum addition on them. The following paper - * explains the algorithm in more detail: https://arxiv.org/abs/2105.05398. +/* Perform long multiplication, iterating through the bits in a using rshift: + * - if LSB(a) is a known 0, keep current accumulator + * - if LSB(a) is a known 1, add b to current accumulator + * - if LSB(a) is unknown, take a union of the above cases. + * + * For example: + * + * acc_0: acc_1: + * + * 11 * -> 11 * -> 11 * -> union(0011, 1001) == x0x1 + * x1 01 11 + * ------ ------ ------ + * 11 11 11 + * xx 00 11 + * ------ ------ ------ + * ???? 0011 1001 */ struct tnum tnum_mul(struct tnum a, struct tnum b) { - u64 acc_v = a.value * b.value; - struct tnum acc_m = TNUM(0, 0); + struct tnum acc = TNUM(0, 0); while (a.value || a.mask) { /* LSB of tnum a is a certain 1 */ if (a.value & 1) - acc_m = tnum_add(acc_m, TNUM(0, b.mask)); + acc = tnum_add(acc, b); /* LSB of tnum a is uncertain */ - else if (a.mask & 1) - acc_m = tnum_add(acc_m, TNUM(0, b.value | b.mask)); + else if (a.mask & 1) { + /* acc = tnum_union(acc_0, acc_1), where acc_0 and + * acc_1 are partial accumulators for cases + * LSB(a) = certain 0 and LSB(a) = certain 1. + * acc_0 = acc + 0 * b = acc. + * acc_1 = acc + 1 * b = tnum_add(acc, b). + */ + + acc = tnum_union(acc, tnum_add(acc, b)); + } /* Note: no case for LSB is certain 0 */ a = tnum_rshift(a, 1); b = tnum_lshift(b, 1); } - return tnum_add(TNUM(acc_v, 0), acc_m); + return acc; } bool tnum_overlap(struct tnum a, struct tnum b) @@ -163,6 +179,19 @@ struct tnum tnum_intersect(struct tnum a, struct tnum b) return TNUM(v & ~mu, mu); } +/* Returns a tnum with the uncertainty from both a and b, and in addition, new + * uncertainty at any position that a and b disagree. This represents a + * superset of the union of the concrete sets of both a and b. Despite the + * overapproximation, it is optimal. + */ +struct tnum tnum_union(struct tnum a, struct tnum b) +{ + u64 v = a.value & b.value; + u64 mu = (a.value ^ b.value) | a.mask | b.mask; + + return TNUM(v & ~mu, mu); +} + struct tnum tnum_cast(struct tnum a, u8 size) { a.value &= (1ULL << (size * 8)) - 1;