struct jit_context {
        int cleanup_addr; /* Epilogue code offset */
+
+       /*
+        * Program specific offsets of labels in the code; these rely on the
+        * JIT doing at least 2 passes, recording the position on the first
+        * pass, only to generate the correct offset on the second pass.
+        */
+       int tail_call_direct_label;
+       int tail_call_indirect_label;
 };
 
 /* Maximum number of bytes emitted while JITing one eBPF insn */
        return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true);
 }
 
-static int get_pop_bytes(bool *callee_regs_used)
-{
-       int bytes = 0;
-
-       if (callee_regs_used[3])
-               bytes += 2;
-       if (callee_regs_used[2])
-               bytes += 2;
-       if (callee_regs_used[1])
-               bytes += 2;
-       if (callee_regs_used[0])
-               bytes += 1;
-
-       return bytes;
-}
-
 /*
  * Generate the following code:
  *
  * out:
  */
 static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
-                                       u32 stack_depth)
+                                       u32 stack_depth, u8 *ip,
+                                       struct jit_context *ctx)
 {
        int tcc_off = -4 - round_up(stack_depth, 8);
-       u8 *prog = *pprog;
-       int pop_bytes = 0;
-       int off1 = 42;
-       int off2 = 31;
-       int off3 = 9;
-
-       /* count the additional bytes used for popping callee regs from stack
-        * that need to be taken into account for each of the offsets that
-        * are used for bailing out of the tail call
-        */
-       pop_bytes = get_pop_bytes(callee_regs_used);
-       off1 += pop_bytes;
-       off2 += pop_bytes;
-       off3 += pop_bytes;
-
-       if (stack_depth) {
-               off1 += 7;
-               off2 += 7;
-               off3 += 7;
-       }
+       u8 *prog = *pprog, *start = *pprog;
+       int offset;
 
        /*
         * rdi - pointer to ctx
        EMIT2(0x89, 0xD2);                        /* mov edx, edx */
        EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
              offsetof(struct bpf_array, map.max_entries));
-#define OFFSET1 (off1 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */
-       EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
+
+       offset = ctx->tail_call_indirect_label - (prog + 2 - start);
+       EMIT2(X86_JBE, offset);                   /* jbe out */
 
        /*
         * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
         */
        EMIT2_off32(0x8B, 0x85, tcc_off);         /* mov eax, dword ptr [rbp - tcc_off] */
        EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
-#define OFFSET2 (off2 + RETPOLINE_RCX_BPF_JIT_SIZE)
-       EMIT2(X86_JA, OFFSET2);                   /* ja out */
+
+       offset = ctx->tail_call_indirect_label - (prog + 2 - start);
+       EMIT2(X86_JA, offset);                    /* ja out */
        EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
        EMIT2_off32(0x89, 0x85, tcc_off);         /* mov dword ptr [rbp - tcc_off], eax */
 
         *      goto out;
         */
        EMIT3(0x48, 0x85, 0xC9);                  /* test rcx,rcx */
-#define OFFSET3 (off3 + RETPOLINE_RCX_BPF_JIT_SIZE)
-       EMIT2(X86_JE, OFFSET3);                   /* je out */
 
-       *pprog = prog;
-       pop_callee_regs(pprog, callee_regs_used);
-       prog = *pprog;
+       offset = ctx->tail_call_indirect_label - (prog + 2 - start);
+       EMIT2(X86_JE, offset);                    /* je out */
+
+       pop_callee_regs(&prog, callee_regs_used);
 
        EMIT1(0x58);                              /* pop rax */
        if (stack_depth)
        RETPOLINE_RCX_BPF_JIT();
 
        /* out: */
+       ctx->tail_call_indirect_label = prog - start;
        *pprog = prog;
 }
 
 static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
-                                     u8 **pprog, int addr, u8 *image,
-                                     bool *callee_regs_used, u32 stack_depth)
+                                     u8 **pprog, u8 *ip,
+                                     bool *callee_regs_used, u32 stack_depth,
+                                     struct jit_context *ctx)
 {
        int tcc_off = -4 - round_up(stack_depth, 8);
-       u8 *prog = *pprog;
-       int pop_bytes = 0;
-       int off1 = 20;
-       int poke_off;
-
-       /* count the additional bytes used for popping callee regs to stack
-        * that need to be taken into account for jump offset that is used for
-        * bailing out from of the tail call when limit is reached
-        */
-       pop_bytes = get_pop_bytes(callee_regs_used);
-       off1 += pop_bytes;
-
-       /*
-        * total bytes for:
-        * - nop5/ jmpq $off
-        * - pop callee regs
-        * - sub rsp, $val if depth > 0
-        * - pop rax
-        */
-       poke_off = X86_PATCH_SIZE + pop_bytes + 1;
-       if (stack_depth) {
-               poke_off += 7;
-               off1 += 7;
-       }
+       u8 *prog = *pprog, *start = *pprog;
+       int offset;
 
        /*
         * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
         */
        EMIT2_off32(0x8B, 0x85, tcc_off);             /* mov eax, dword ptr [rbp - tcc_off] */
        EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);         /* cmp eax, MAX_TAIL_CALL_CNT */
-       EMIT2(X86_JA, off1);                          /* ja out */
+
+       offset = ctx->tail_call_direct_label - (prog + 2 - start);
+       EMIT2(X86_JA, offset);                        /* ja out */
        EMIT3(0x83, 0xC0, 0x01);                      /* add eax, 1 */
        EMIT2_off32(0x89, 0x85, tcc_off);             /* mov dword ptr [rbp - tcc_off], eax */
 
-       poke->tailcall_bypass = image + (addr - poke_off - X86_PATCH_SIZE);
+       poke->tailcall_bypass = ip + (prog - start);
        poke->adj_off = X86_TAIL_CALL_OFFSET;
-       poke->tailcall_target = image + (addr - X86_PATCH_SIZE);
+       poke->tailcall_target = ip + ctx->tail_call_direct_label - X86_PATCH_SIZE;
        poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE;
 
        emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
                  poke->tailcall_bypass);
 
-       *pprog = prog;
-       pop_callee_regs(pprog, callee_regs_used);
-       prog = *pprog;
+       pop_callee_regs(&prog, callee_regs_used);
        EMIT1(0x58);                                  /* pop rax */
        if (stack_depth)
                EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
 
        memcpy(prog, x86_nops[5], X86_PATCH_SIZE);
        prog += X86_PATCH_SIZE;
+
        /* out: */
+       ctx->tail_call_direct_label = prog - start;
 
        *pprog = prog;
 }
                case BPF_JMP | BPF_TAIL_CALL:
                        if (imm32)
                                emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1],
-                                                         &prog, addrs[i], image,
+                                                         &prog, image + addrs[i - 1],
                                                          callee_regs_used,
-                                                         bpf_prog->aux->stack_depth);
+                                                         bpf_prog->aux->stack_depth,
+                                                         ctx);
                        else
                                emit_bpf_tail_call_indirect(&prog,
                                                            callee_regs_used,
-                                                           bpf_prog->aux->stack_depth);
+                                                           bpf_prog->aux->stack_depth,
+                                                           image + addrs[i - 1],
+                                                           ctx);
                        break;
 
                        /* cond jump */