/* Number of bytes emit_patch() needs to generate instructions */
 #define X86_PATCH_SIZE         5
+/* Number of bytes that will be skipped on tailcall: the 5-byte patchable
+ * nop, the 2-byte tail_call_cnt init (or nop2), push rbp (1 byte) and
+ * mov rbp, rsp (3 bytes) -- see emit_prologue()
+ */
+#define X86_TAIL_CALL_OFFSET   11
 
-#define PROLOGUE_SIZE          25
+static void push_callee_regs(u8 **pprog, bool *callee_regs_used)
+{
+       u8 *prog = *pprog;
+       int cnt = 0;
+
+       if (callee_regs_used[0])
+               EMIT1(0x53);         /* push rbx */
+       if (callee_regs_used[1])
+               EMIT2(0x41, 0x55);   /* push r13 */
+       if (callee_regs_used[2])
+               EMIT2(0x41, 0x56);   /* push r14 */
+       if (callee_regs_used[3])
+               EMIT2(0x41, 0x57);   /* push r15 */
+       *pprog = prog;
+}
+
+static void pop_callee_regs(u8 **pprog, bool *callee_regs_used)
+{
+       u8 *prog = *pprog;
+       int cnt = 0;
+
+       if (callee_regs_used[3])
+               EMIT2(0x41, 0x5F);   /* pop r15 */
+       if (callee_regs_used[2])
+               EMIT2(0x41, 0x5E);   /* pop r14 */
+       if (callee_regs_used[1])
+               EMIT2(0x41, 0x5D);   /* pop r13 */
+       if (callee_regs_used[0])
+               EMIT1(0x5B);         /* pop rbx */
+       *pprog = prog;
+}
 
 /*
- * Emit x86-64 prologue code for BPF program and check its size.
- * bpf_tail_call helper will skip it while jumping into another program
+ * Emit x86-64 prologue code for BPF program.
+ * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes
+ * while jumping to another program
  */
-static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
+static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
+                         bool tail_call_reachable, bool is_subprog)
 {
        u8 *prog = *pprog;
        int cnt = X86_PATCH_SIZE;
         */
        memcpy(prog, ideal_nops[NOP_ATOMIC5], cnt);
        prog += cnt;
+       if (!ebpf_from_cbpf) {
+               if (tail_call_reachable && !is_subprog)
+                       EMIT2(0x31, 0xC0); /* xor eax, eax */
+               else
+                       EMIT2(0x66, 0x90); /* nop2 */
+       }
        EMIT1(0x55);             /* push rbp */
        EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
        /* sub rsp, rounded_stack_depth */
        EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
-       EMIT1(0x53);             /* push rbx */
-       EMIT2(0x41, 0x55);       /* push r13 */
-       EMIT2(0x41, 0x56);       /* push r14 */
-       EMIT2(0x41, 0x57);       /* push r15 */
-       if (!ebpf_from_cbpf) {
-               /* zero init tail_call_cnt */
-               EMIT2(0x6a, 0x00);
-               BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
-       }
+       if (tail_call_reachable)
+               EMIT1(0x50);         /* push rax */
        *pprog = prog;
 }
 
        mutex_lock(&text_mutex);
        if (memcmp(ip, old_insn, X86_PATCH_SIZE))
                goto out;
+       ret = 1;
        if (memcmp(ip, new_insn, X86_PATCH_SIZE)) {
                if (text_live)
                        text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL);
                else
                        memcpy(ip, new_insn, X86_PATCH_SIZE);
+               ret = 0;
        }
-       ret = 0;
 out:
        mutex_unlock(&text_mutex);
        return ret;
        return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true);
 }
 
+static int get_pop_bytes(bool *callee_regs_used)
+{
+       int bytes = 0;
+
+       if (callee_regs_used[3])
+               bytes += 2;
+       if (callee_regs_used[2])
+               bytes += 2;
+       if (callee_regs_used[1])
+               bytes += 2;
+       if (callee_regs_used[0])
+               bytes += 1;
+
+       return bytes;
+}
+
 /*
  * Generate the following code:
  *
  *   goto *(prog->bpf_func + prologue_size);
  * out:
  */
-static void emit_bpf_tail_call_indirect(u8 **pprog)
+static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
+                                       u32 stack_depth)
 {
+       int tcc_off = -4 - round_up(stack_depth, 8);
        u8 *prog = *pprog;
-       int label1, label2, label3;
+       int pop_bytes = 0;
+       int off1 = 49;
+       int off2 = 38;
+       int off3 = 16;
        int cnt = 0;
 
+       /* count the additional bytes used for popping callee regs from stack
+        * that need to be taken into account for each of the offsets that
+        * are used for bailing out of the tail call
+        */
+       pop_bytes = get_pop_bytes(callee_regs_used);
+       off1 += pop_bytes;
+       off2 += pop_bytes;
+       off3 += pop_bytes;
+
        /*
         * rdi - pointer to ctx
         * rsi - pointer to bpf_array
        EMIT2(0x89, 0xD2);                        /* mov edx, edx */
        EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
              offsetof(struct bpf_array, map.max_entries));
-#define OFFSET1 (41 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */
+#define OFFSET1 (off1 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */
        EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
-       label1 = cnt;
 
        /*
         * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
         *      goto out;
         */
-       EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
+       EMIT2_off32(0x8B, 0x85, tcc_off);         /* mov eax, dword ptr [rbp - tcc_off] */
        EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
-#define OFFSET2 (30 + RETPOLINE_RCX_BPF_JIT_SIZE)
+#define OFFSET2 (off2 + RETPOLINE_RCX_BPF_JIT_SIZE)
        EMIT2(X86_JA, OFFSET2);                   /* ja out */
-       label2 = cnt;
        EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
-       EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */
+       EMIT2_off32(0x89, 0x85, tcc_off);         /* mov dword ptr [rbp - tcc_off], eax */
 
        /* prog = array->ptrs[index]; */
        EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6,       /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */
         * if (prog == NULL)
         *      goto out;
         */
-       EMIT3(0x48, 0x85, 0xC9);                  /* test rcx,rcx */
-#define OFFSET3 (8 + RETPOLINE_RCX_BPF_JIT_SIZE)
+       EMIT3(0x48, 0x85, 0xC9);                  /* test rcx,rcx */
+#define OFFSET3 (off3 + RETPOLINE_RCX_BPF_JIT_SIZE)
        EMIT2(X86_JE, OFFSET3);                   /* je out */
-       label3 = cnt;
 
-       /* goto *(prog->bpf_func + prologue_size); */
+       *pprog = prog;
+       pop_callee_regs(pprog, callee_regs_used);
+       prog = *pprog;
+
+       EMIT1(0x58);                              /* pop rax */
+       EMIT3_off32(0x48, 0x81, 0xC4,             /* add rsp, sd */
+                   round_up(stack_depth, 8));
+
+       /* goto *(prog->bpf_func + X86_TAIL_CALL_OFFSET); */
        EMIT4(0x48, 0x8B, 0x49,                   /* mov rcx, qword ptr [rcx + 32] */
              offsetof(struct bpf_prog, bpf_func));
-       EMIT4(0x48, 0x83, 0xC1, PROLOGUE_SIZE);   /* add rcx, prologue_size */
-
+       EMIT4(0x48, 0x83, 0xC1,                   /* add rcx, X86_TAIL_CALL_OFFSET */
+             X86_TAIL_CALL_OFFSET);
        /*
         * Now we're ready to jump into next BPF program
         * rdi == ctx (1st arg)
-        * rcx == prog->bpf_func + prologue_size
+        * rcx == prog->bpf_func + X86_TAIL_CALL_OFFSET
         */
        RETPOLINE_RCX_BPF_JIT();
 
        /* out: */
-       BUILD_BUG_ON(cnt - label1 != OFFSET1);
-       BUILD_BUG_ON(cnt - label2 != OFFSET2);
-       BUILD_BUG_ON(cnt - label3 != OFFSET3);
        *pprog = prog;
 }
 
 static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
-                                     u8 **pprog, int addr, u8 *image)
+                                     u8 **pprog, int addr, u8 *image,
+                                     bool *callee_regs_used, u32 stack_depth)
 {
+       int tcc_off = -4 - round_up(stack_depth, 8);
        u8 *prog = *pprog;
+       int pop_bytes = 0;
+       int off1 = 27;
+       int poke_off;
        int cnt = 0;
 
+       /* count the additional bytes used for popping callee regs from stack
+        * that need to be taken into account for the jump offset that is used
+        * for bailing out of the tail call when the limit is reached
+        */
+       pop_bytes = get_pop_bytes(callee_regs_used);
+       off1 += pop_bytes;
+
+       /*
+        * total bytes for:
+        * - nop5/ jmpq $off
+        * - pop callee regs
+        * - pop rax
+        * - add rsp, $val
+        */
+       poke_off = X86_PATCH_SIZE + pop_bytes + 7 + 1;
+
        /*
         * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
         *      goto out;
         */
-       EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
+       EMIT2_off32(0x8B, 0x85, tcc_off);             /* mov eax, dword ptr [rbp - tcc_off] */
        EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);         /* cmp eax, MAX_TAIL_CALL_CNT */
-       EMIT2(X86_JA, 14);                            /* ja out */
+       EMIT2(X86_JA, off1);                          /* ja out */
        EMIT3(0x83, 0xC0, 0x01);                      /* add eax, 1 */
-       EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */
+       EMIT2_off32(0x89, 0x85, tcc_off);             /* mov dword ptr [rbp - tcc_off], eax */
 
+       poke->tailcall_bypass = image + (addr - poke_off - X86_PATCH_SIZE);
+       poke->adj_off = X86_TAIL_CALL_OFFSET;
        poke->tailcall_target = image + (addr - X86_PATCH_SIZE);
-       poke->adj_off = PROLOGUE_SIZE;
+       poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE;
+
+       emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
+                 poke->tailcall_bypass);
+
+       *pprog = prog;
+       pop_callee_regs(pprog, callee_regs_used);
+       prog = *pprog;
+       EMIT1(0x58);                                  /* pop rax */
+       EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
 
        memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
        prog += X86_PATCH_SIZE;
                                                   (u8 *)target->bpf_func +
                                                   poke->adj_off, false);
                        BUG_ON(ret < 0);
+                       ret = __bpf_arch_text_poke(poke->tailcall_bypass,
+                                                  BPF_MOD_JUMP,
+                                                  (u8 *)poke->tailcall_target +
+                                                  X86_PATCH_SIZE, NULL, false);
+                       BUG_ON(ret < 0);
                }
                WRITE_ONCE(poke->tailcall_target_stable, true);
                mutex_unlock(&array->aux->poke_mutex);
        return true;
 }
 
+static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt,
+                            bool *regs_used, bool *tail_call_seen)
+{
+       int i;
+
+       for (i = 1; i <= insn_cnt; i++, insn++) {
+               if (insn->code == (BPF_JMP | BPF_TAIL_CALL))
+                       *tail_call_seen = true;
+               if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6)
+                       regs_used[0] = true;
+               if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7)
+                       regs_used[1] = true;
+               if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8)
+                       regs_used[2] = true;
+               if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9)
+                       regs_used[3] = true;
+       }
+}
+
 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                  int oldproglen, struct jit_context *ctx)
 {
+       bool tail_call_reachable = bpf_prog->aux->tail_call_reachable;
        struct bpf_insn *insn = bpf_prog->insnsi;
+       bool callee_regs_used[4] = {};
        int insn_cnt = bpf_prog->len;
+       bool tail_call_seen = false;
        bool seen_exit = false;
        u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
        int i, cnt = 0, excnt = 0;
        int proglen = 0;
        u8 *prog = temp;
 
+       detect_reg_usage(insn, insn_cnt, callee_regs_used,
+                        &tail_call_seen);
+
+       /* tail call's presence in current prog implies it is reachable */
+       tail_call_reachable |= tail_call_seen;
+
        emit_prologue(&prog, bpf_prog->aux->stack_depth,
-                     bpf_prog_was_classic(bpf_prog));
+                     bpf_prog_was_classic(bpf_prog), tail_call_reachable,
+                     bpf_prog->aux->func_idx != 0);
+       push_callee_regs(&prog, callee_regs_used);
        addrs[0] = prog - temp;
 
        for (i = 1; i <= insn_cnt; i++, insn++) {
                        /* call */
                case BPF_JMP | BPF_CALL:
                        func = (u8 *) __bpf_call_base + imm32;
-                       if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
-                               return -EINVAL;
+                       if (tail_call_reachable) {
+                               EMIT3_off32(0x48, 0x8B, 0x85,
+                                           -(bpf_prog->aux->stack_depth + 8));
+                               if (!imm32 || emit_call(&prog, func, image + addrs[i - 1] + 7))
+                                       return -EINVAL;
+                       } else {
+                               if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
+                                       return -EINVAL;
+                       }
                        break;
 
                case BPF_JMP | BPF_TAIL_CALL:
                        if (imm32)
                                emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1],
-                                                         &prog, addrs[i], image);
+                                                         &prog, addrs[i], image,
+                                                         callee_regs_used,
+                                                         bpf_prog->aux->stack_depth);
                        else
-                               emit_bpf_tail_call_indirect(&prog);
+                               emit_bpf_tail_call_indirect(&prog,
+                                                           callee_regs_used,
+                                                           bpf_prog->aux->stack_depth);
                        break;
 
                        /* cond jump */
                        seen_exit = true;
                        /* Update cleanup_addr */
                        ctx->cleanup_addr = proglen;
-                       if (!bpf_prog_was_classic(bpf_prog))
-                               EMIT1(0x5B); /* get rid of tail_call_cnt */
-                       EMIT2(0x41, 0x5F);   /* pop r15 */
-                       EMIT2(0x41, 0x5E);   /* pop r14 */
-                       EMIT2(0x41, 0x5D);   /* pop r13 */
-                       EMIT1(0x5B);         /* pop rbx */
+                       pop_callee_regs(&prog, callee_regs_used);
+                       if (tail_call_reachable)
+                               EMIT1(0x59); /* pop rcx, get rid of tail_call_cnt */
                        EMIT1(0xC9);         /* leave */
                        EMIT1(0xC3);         /* ret */
                        break;