From: Haoran Jiang Date: Tue, 5 Aug 2025 11:00:22 +0000 (+0800) Subject: LoongArch: BPF: Fix the tailcall hierarchy X-Git-Url: https://www.infradead.org/git/?a=commitdiff_plain;h=c0fcc955ff827431b541b1aa6bcb82bdce4531f7;p=users%2Fwilly%2Fxarray.git LoongArch: BPF: Fix the tailcall hierarchy In specific use cases combining tailcalls and BPF-to-BPF calls, MAX_TAIL_CALL_CNT won't work because of missing tail_call_cnt back-propagation from callee to caller. This patch fixes this tailcall issue caused by abusing the tailcall in bpf2bpf feature on LoongArch like the way of "bpf, x64: Fix tailcall hierarchy". Push tail_call_cnt_ptr and tail_call_cnt into the stack, tail_call_cnt_ptr is passed between tailcall and bpf2bpf, uses tail_call_cnt_ptr to increment tail_call_cnt. Fixes: bb035ef0cc91 ("LoongArch: BPF: Support mixing bpf2bpf and tailcalls") Reviewed-by: Geliang Tang Reviewed-by: Hengqi Chen Signed-off-by: Haoran Jiang Signed-off-by: Huacai Chen --- diff --git a/arch/loongarch/net/bpf_jit.c b/arch/loongarch/net/bpf_jit.c index f4f12ed16d2f..4ea8ae4cf0ca 100644 --- a/arch/loongarch/net/bpf_jit.c +++ b/arch/loongarch/net/bpf_jit.c @@ -17,10 +17,7 @@ #define LOONGARCH_BPF_FENTRY_NBYTES (LOONGARCH_LONG_JUMP_NINSNS * 4) #define REG_TCC LOONGARCH_GPR_A6 -#define TCC_SAVED LOONGARCH_GPR_S5 - -#define SAVE_RA BIT(0) -#define SAVE_TCC BIT(1) +#define BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack) (round_up(stack, 16) - 80) static const int regmap[] = { /* return value from in-kernel function, and exit value for eBPF program */ @@ -42,32 +39,57 @@ static const int regmap[] = { [BPF_REG_AX] = LOONGARCH_GPR_T0, }; -static void mark_call(struct jit_ctx *ctx) +static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx, int *store_offset) { - ctx->flags |= SAVE_RA; -} + const struct bpf_prog *prog = ctx->prog; + const bool is_main_prog = !bpf_is_subprog(prog); -static void mark_tail_call(struct jit_ctx *ctx) -{ - ctx->flags |= SAVE_TCC; -} + if (is_main_prog) { + /* + * LOONGARCH_GPR_T3 = MAX_TAIL_CALL_CNT + * if (REG_TCC > T3 ) + * std REG_TCC -> LOONGARCH_GPR_SP + store_offset + * else + * std REG_TCC -> LOONGARCH_GPR_SP + store_offset + * REG_TCC = LOONGARCH_GPR_SP + store_offset + * + * std REG_TCC -> LOONGARCH_GPR_SP + store_offset + * + * The purpose of this code is to first push the TCC into stack, + * and then push the address of TCC into stack. + * In cases where bpf2bpf and tailcall are used in combination, + * the value in REG_TCC may be a count or an address, + * these two cases need to be judged and handled separately. + */ + emit_insn(ctx, addid, LOONGARCH_GPR_T3, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT); + *store_offset -= sizeof(long); -static bool seen_call(struct jit_ctx *ctx) -{ - return (ctx->flags & SAVE_RA); -} + emit_cond_jmp(ctx, BPF_JGT, REG_TCC, LOONGARCH_GPR_T3, 4); -static bool seen_tail_call(struct jit_ctx *ctx) -{ - return (ctx->flags & SAVE_TCC); -} + /* + * If REG_TCC < MAX_TAIL_CALL_CNT, the value in REG_TCC is a count, + * push tcc into stack + */ + emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset); -static u8 tail_call_reg(struct jit_ctx *ctx) -{ - if (seen_call(ctx)) - return TCC_SAVED; + /* Push the address of TCC into the REG_TCC */ + emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_SP, *store_offset); - return REG_TCC; + emit_uncond_jmp(ctx, 2); + + /* + * If REG_TCC > MAX_TAIL_CALL_CNT, the value in REG_TCC is an address, + * push tcc_ptr into stack + */ + emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset); + } else { + *store_offset -= sizeof(long); + emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset); + } + + /* Push tcc_ptr into stack */ + *store_offset -= sizeof(long); + emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset); } /* @@ -90,6 +112,10 @@ static u8 tail_call_reg(struct jit_ctx *ctx) * | $s4 | * +-------------------------+ * | $s5 | + * +-------------------------+ + * | tcc | + * +-------------------------+ + * | tcc_ptr | * +-------------------------+ <--BPF_REG_FP * | prog->aux->stack_depth | * | (optional) | @@ -99,12 +125,17 @@ static u8 tail_call_reg(struct jit_ctx *ctx) static void build_prologue(struct jit_ctx *ctx) { int i, stack_adjust = 0, store_offset, bpf_stack_adjust; + const struct bpf_prog *prog = ctx->prog; + const bool is_main_prog = !bpf_is_subprog(prog); bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16); - /* To store ra, fp, s0, s1, s2, s3, s4 and s5. */ + /* To store ra, fp, s0, s1, s2, s3, s4, s5 */ stack_adjust += sizeof(long) * 8; + /* To store tcc and tcc_ptr */ + stack_adjust += sizeof(long) * 2; + stack_adjust = round_up(stack_adjust, 16); stack_adjust += bpf_stack_adjust; @@ -113,11 +144,12 @@ static void build_prologue(struct jit_ctx *ctx) emit_insn(ctx, nop); /* - * First instruction initializes the tail call count (TCC). - * On tail call we skip this instruction, and the TCC is - * passed in REG_TCC from the caller. + * First instruction initializes the tail call count (TCC) + * register to zero. On tail call we skip this instruction, + * and the TCC is passed in REG_TCC from the caller. */ - emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT); + if (is_main_prog) + emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, 0); emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust); @@ -145,20 +177,13 @@ static void build_prologue(struct jit_ctx *ctx) store_offset -= sizeof(long); emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset); + prepare_bpf_tail_call_cnt(ctx, &store_offset); + emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust); if (bpf_stack_adjust) emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust); - /* - * Program contains calls and tail calls, so REG_TCC need - * to be saved across calls. - */ - if (seen_tail_call(ctx) && seen_call(ctx)) - move_reg(ctx, TCC_SAVED, REG_TCC); - else - emit_insn(ctx, nop); - ctx->stack_size = stack_adjust; } @@ -191,6 +216,16 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call) load_offset -= sizeof(long); emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset); + /* + * When push into the stack, follow the order of tcc then tcc_ptr. + * When pop from the stack, first pop tcc_ptr then followed by tcc. + */ + load_offset -= 2 * sizeof(long); + emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset); + + load_offset += sizeof(long); + emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset); + emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust); if (!is_tail_call) { @@ -203,7 +238,7 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call) * Call the next bpf prog and skip the first instruction * of TCC initialization. */ - emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 1); + emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 6); } } @@ -225,7 +260,7 @@ bool bpf_jit_supports_far_kfunc_call(void) static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn) { int off, tc_ninsn = 0; - u8 tcc = tail_call_reg(ctx); + int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size); u8 a1 = LOONGARCH_GPR_A1; u8 a2 = LOONGARCH_GPR_A2; u8 t1 = LOONGARCH_GPR_T1; @@ -252,11 +287,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn) goto toofar; /* - * if (--TCC < 0) - * goto out; + * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT) + * goto out; */ - emit_insn(ctx, addid, REG_TCC, tcc, -1); - if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0) + emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off); + emit_insn(ctx, ldd, t3, REG_TCC, 0); + emit_insn(ctx, addid, t3, t3, 1); + emit_insn(ctx, std, t3, REG_TCC, 0); + emit_insn(ctx, addid, t2, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT); + if (emit_tailcall_jmp(ctx, BPF_JSGT, t3, t2, jmp_offset) < 0) goto toofar; /* @@ -467,7 +506,7 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext u64 func_addr; bool func_addr_fixed, sign_extend; int i = insn - ctx->prog->insnsi; - int ret, jmp_offset; + int ret, jmp_offset, tcc_ptr_off; const u8 code = insn->code; const u8 cond = BPF_OP(code); const u8 t1 = LOONGARCH_GPR_T1; @@ -903,12 +942,16 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext /* function call */ case BPF_JMP | BPF_CALL: - mark_call(ctx); ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass, &func_addr, &func_addr_fixed); if (ret < 0) return ret; + if (insn->src_reg == BPF_PSEUDO_CALL) { + tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size); + emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off); + } + move_addr(ctx, t1, func_addr); emit_insn(ctx, jirl, LOONGARCH_GPR_RA, t1, 0); @@ -919,7 +962,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext /* tail call */ case BPF_JMP | BPF_TAIL_CALL: - mark_tail_call(ctx); if (emit_bpf_tail_call(ctx, i) < 0) return -EINVAL; break; @@ -1412,7 +1454,7 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i { int i, ret, save_ret; int stack_size = 0, nargs = 0; - int retval_off, args_off, nargs_off, ip_off, run_ctx_off, sreg_off; + int retval_off, args_off, nargs_off, ip_off, run_ctx_off, sreg_off, tcc_ptr_off; bool is_struct_ops = flags & BPF_TRAMP_F_INDIRECT; void *orig_call = func_addr; struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY]; @@ -1447,6 +1489,7 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i * * FP - sreg_off [ callee saved reg ] * + * FP - tcc_ptr_off [ tail_call_cnt_ptr ] */ if (m->nr_args > LOONGARCH_MAX_REG_ARGS) @@ -1489,6 +1532,12 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i stack_size += 8; sreg_off = stack_size; + /* Room of trampoline frame to store tail_call_cnt_ptr */ + if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) { + stack_size += 8; + tcc_ptr_off = stack_size; + } + stack_size = round_up(stack_size, 16); if (is_struct_ops) { @@ -1519,6 +1568,9 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_size); } + if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) + emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off); + /* callee saved register S1 to pass start time */ emit_insn(ctx, std, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off); @@ -1565,6 +1617,10 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i if (flags & BPF_TRAMP_F_CALL_ORIG) { restore_args(ctx, m->nr_args, args_off); + + if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) + emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off); + ret = emit_call(ctx, (const u64)orig_call); if (ret) goto out; @@ -1605,6 +1661,9 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i emit_insn(ctx, ldd, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off); + if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) + emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off); + if (is_struct_ops) { /* trampoline called directly */ emit_insn(ctx, ldd, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, stack_size - 8);