return TNUM(v & ~mu, mu);
 }
 
-/* half-multiply add: acc += (unknown * mask * value).
- * An intermediate step in the multiply algorithm.
+/* Generate partial products by multiplying each bit in the multiplier (tnum a)
+ * with the multiplicand (tnum b), and add the partial products after
+ * appropriately bit-shifting them. Instead of directly performing tnum addition
+ * on the generated partial products, equivalenty, decompose each partial
+ * product into two tnums, consisting of the value-sum (acc_v) and the
+ * mask-sum (acc_m) and then perform tnum addition on them. The following paper
+ * explains the algorithm in more detail: https://arxiv.org/abs/2105.05398.
  */
-static struct tnum hma(struct tnum acc, u64 value, u64 mask)
-{
-       while (mask) {
-               if (mask & 1)
-                       acc = tnum_add(acc, TNUM(0, value));
-               mask >>= 1;
-               value <<= 1;
-       }
-       return acc;
-}
-
 struct tnum tnum_mul(struct tnum a, struct tnum b)
 {
-       struct tnum acc;
-       u64 pi;
-
-       pi = a.value * b.value;
-       acc = hma(TNUM(pi, 0), a.mask, b.mask | b.value);
-       return hma(acc, b.mask, a.value);
+       u64 acc_v = a.value * b.value;
+       struct tnum acc_m = TNUM(0, 0);
+
+       while (a.value || a.mask) {
+               /* LSB of tnum a is a certain 1 */
+               if (a.value & 1)
+                       acc_m = tnum_add(acc_m, TNUM(0, b.mask));
+               /* LSB of tnum a is uncertain */
+               else if (a.mask & 1)
+                       acc_m = tnum_add(acc_m, TNUM(0, b.value | b.mask));
+               /* Note: no case for LSB is certain 0 */
+               a = tnum_rshift(a, 1);
+               b = tnum_lshift(b, 1);
+       }
+       return tnum_add(TNUM(acc_v, 0), acc_m);
 }
 
 /* Note that if a and b disagree - i.e. one has a 'known 1' where the other has