]> www.infradead.org Git - users/willy/xarray.git/commitdiff
bpf/tests: Add tail call limit test with external function call
authorJohan Almbladh <johan.almbladh@anyfinetworks.com>
Tue, 14 Sep 2021 09:18:42 +0000 (11:18 +0200)
committerDaniel Borkmann <daniel@iogearbox.net>
Tue, 28 Sep 2021 07:26:29 +0000 (09:26 +0200)
This patch adds a tail call limit test where the program also emits
a BPF_CALL to an external function prior to the tail call. Mainly
testing that JITed programs preserve its internal register state, for
example tail call count, across such external calls.

Signed-off-by: Johan Almbladh <johan.almbladh@anyfinetworks.com>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Link: https://lore.kernel.org/bpf/20210914091842.4186267-15-johan.almbladh@anyfinetworks.com
lib/test_bpf.c

index a94ab634f9475d8da8131b51077a2257dd5a8d2b..08f438e6fe9e26fdf4db363370ed212b6570ac7b 100644 (file)
@@ -12208,6 +12208,30 @@ struct tail_call_test {
                     offset, TAIL_CALL_MARKER),        \
        BPF_JMP_IMM(BPF_TAIL_CALL, 0, 0, 0)
 
+/*
+ * A test function to be called from a BPF program, clobbering a lot of
+ * CPU registers in the process. A JITed BPF program calling this function
+ * must save and restore any caller-saved registers it uses for internal
+ * state, for example the current tail call count.
+ */
+BPF_CALL_1(bpf_test_func, u64, arg)
+{
+       char buf[64];
+       long a = 0;
+       long b = 1;
+       long c = 2;
+       long d = 3;
+       long e = 4;
+       long f = 5;
+       long g = 6;
+       long h = 7;
+
+       return snprintf(buf, sizeof(buf),
+                       "%ld %lu %lx %ld %lu %lx %ld %lu %x",
+                       a, b, c, d, e, f, g, h, (int)arg);
+}
+#define BPF_FUNC_test_func __BPF_FUNC_MAX_ID
+
 /*
  * Tail call tests. Each test case may call any other test in the table,
  * including itself, specified as a relative index offset from the calling
@@ -12267,6 +12291,28 @@ static struct tail_call_test tail_call_tests[] = {
                .flags = FLAG_NEED_STATE | FLAG_RESULT_IN_STATE,
                .result = (MAX_TAIL_CALL_CNT + 1 + 1) * MAX_TESTRUNS,
        },
+       {
+               "Tail call count preserved across function calls",
+               .insns = {
+                       BPF_LDX_MEM(BPF_W, R2, R1, 0),
+                       BPF_ALU64_IMM(BPF_ADD, R2, 1),
+                       BPF_STX_MEM(BPF_W, R1, R2, 0),
+                       BPF_STX_MEM(BPF_DW, R10, R1, -8),
+                       BPF_CALL_REL(BPF_FUNC_get_numa_node_id),
+                       BPF_CALL_REL(BPF_FUNC_ktime_get_ns),
+                       BPF_CALL_REL(BPF_FUNC_ktime_get_boot_ns),
+                       BPF_CALL_REL(BPF_FUNC_ktime_get_coarse_ns),
+                       BPF_CALL_REL(BPF_FUNC_jiffies64),
+                       BPF_CALL_REL(BPF_FUNC_test_func),
+                       BPF_LDX_MEM(BPF_DW, R1, R10, -8),
+                       BPF_ALU32_REG(BPF_MOV, R0, R1),
+                       TAIL_CALL(0),
+                       BPF_EXIT_INSN(),
+               },
+               .stack_depth = 8,
+               .flags = FLAG_NEED_STATE | FLAG_RESULT_IN_STATE,
+               .result = (MAX_TAIL_CALL_CNT + 1 + 1) * MAX_TESTRUNS,
+       },
        {
                "Tail call error path, NULL target",
                .insns = {
@@ -12345,17 +12391,19 @@ static __init int prepare_tail_call_tests(struct bpf_array **pprogs)
                /* Relocate runtime tail call offsets and addresses */
                for (i = 0; i < len; i++) {
                        struct bpf_insn *insn = &fp->insnsi[i];
-
-                       if (insn->imm != TAIL_CALL_MARKER)
-                               continue;
+                       long addr = 0;
 
                        switch (insn->code) {
                        case BPF_LD | BPF_DW | BPF_IMM:
+                               if (insn->imm != TAIL_CALL_MARKER)
+                                       break;
                                insn[0].imm = (u32)(long)progs;
                                insn[1].imm = ((u64)(long)progs) >> 32;
                                break;
 
                        case BPF_ALU | BPF_MOV | BPF_K:
+                               if (insn->imm != TAIL_CALL_MARKER)
+                                       break;
                                if (insn->off == TAIL_CALL_NULL)
                                        insn->imm = ntests;
                                else if (insn->off == TAIL_CALL_INVALID)
@@ -12363,6 +12411,38 @@ static __init int prepare_tail_call_tests(struct bpf_array **pprogs)
                                else
                                        insn->imm = which + insn->off;
                                insn->off = 0;
+                               break;
+
+                       case BPF_JMP | BPF_CALL:
+                               if (insn->src_reg != BPF_PSEUDO_CALL)
+                                       break;
+                               switch (insn->imm) {
+                               case BPF_FUNC_get_numa_node_id:
+                                       addr = (long)&numa_node_id;
+                                       break;
+                               case BPF_FUNC_ktime_get_ns:
+                                       addr = (long)&ktime_get_ns;
+                                       break;
+                               case BPF_FUNC_ktime_get_boot_ns:
+                                       addr = (long)&ktime_get_boot_fast_ns;
+                                       break;
+                               case BPF_FUNC_ktime_get_coarse_ns:
+                                       addr = (long)&ktime_get_coarse_ns;
+                                       break;
+                               case BPF_FUNC_jiffies64:
+                                       addr = (long)&get_jiffies_64;
+                                       break;
+                               case BPF_FUNC_test_func:
+                                       addr = (long)&bpf_test_func;
+                                       break;
+                               default:
+                                       err = -EFAULT;
+                                       goto out_err;
+                               }
+                               *insn = BPF_EMIT_CALL(BPF_CAST_CALL(addr));
+                               if ((long)__bpf_call_base + insn->imm != addr)
+                                       *insn = BPF_JMP_A(0); /* Skip: NOP */
+                               break;
                        }
                }