*pprog = prog;
 }
 
+/*
+ * Emit a relative branch (@opcode followed by a 32-bit displacement) into
+ * the buffer at *@pprog, transferring control from the patch site at @ip
+ * to @func.
+ *
+ * The displacement is computed relative to the end of the patch site,
+ * i.e. @ip + X86_PATCH_SIZE.
+ *
+ * Returns 0 on success and writes the local write cursor back to *@pprog,
+ * or -EINVAL when the displacement does not pass is_simm32().
+ */
+static int emit_patch(u8 **pprog, void *func, void *ip, u8 opcode)
+{
+       /*
+        * NOTE(review): 'prog' and 'cnt' are not touched by any visible
+        * statement; presumably EMIT1_off32() refers to them by name, so
+        * they must not be renamed or dropped.
+        */
+       u8 *prog = *pprog;
+       int cnt = 0;
+       s64 offset;
+
+       offset = func - (ip + X86_PATCH_SIZE);
+       if (!is_simm32(offset)) {
+               pr_err("Target call %p is out of range\n", func);
+               /*
+                * This helper is only being relocated; keep returning the
+                * errno it returned before the move so that callers of
+                * bpf_arch_text_poke() see no behavioural change.
+                */
+               return -EINVAL;
+       }
+       EMIT1_off32(opcode, offset);
+       *pprog = prog;
+       return 0;
+}
+
+/*
+ * Emit a relative call from the patch site at @ip to @func into *@pprog.
+ * Thin wrapper around emit_patch(); its return value is passed through.
+ */
+static int emit_call(u8 **pprog, void *func, void *ip)
+{
+       const u8 call_rel32_opcode = 0xE8;      /* x86 near CALL rel32 */
+
+       return emit_patch(pprog, func, ip, call_rel32_opcode);
+}
+
+/*
+ * Emit a relative jump from the patch site at @ip to @func into *@pprog.
+ * Thin wrapper around emit_patch(); its return value is passed through.
+ */
+static int emit_jump(u8 **pprog, void *func, void *ip)
+{
+       const u8 jmp_rel32_opcode = 0xE9;       /* x86 near JMP rel32 */
+
+       return emit_patch(pprog, func, ip, jmp_rel32_opcode);
+}
+
+/*
+ * Patch the X86_PATCH_SIZE-byte instruction slot at @ip according to @t.
+ *
+ * Two candidate encodings are built first: old_insn, the bytes expected to
+ * be at @ip right now, and new_insn, the bytes to install. Depending on @t
+ * each of them is either the atomic NOP or a call/jump to @old_addr /
+ * @new_addr. The slot is only rewritten when it currently holds old_insn.
+ *
+ * @text_live selects how the bytes are written: text_poke_bp() when true,
+ * a plain memcpy() when false.
+ *
+ * Returns 0 on success, -ENOTSUPP for an unknown @t, -ENXIO when the
+ * NULL-ness of @old_addr/@new_addr does not fit @t, the error returned by
+ * emit_call()/emit_jump() when a target cannot be encoded, or -EBUSY when
+ * the bytes at @ip differ from the expected old instruction.
+ */
+static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
+                               void *old_addr, void *new_addr,
+                               const bool text_live)
+{
+       int (*emit_patch_fn)(u8 **pprog, void *func, void *ip);
+       const u8 *nop_insn = ideal_nops[NOP_ATOMIC5];
+       u8 old_insn[X86_PATCH_SIZE] = {};
+       u8 new_insn[X86_PATCH_SIZE] = {};
+       u8 *prog;
+       int ret;
+
+       /*
+        * Pick the emitter by instruction family. The case ranges cover
+        * every enumerator between the two named bounds, so this depends on
+        * the declaration order of enum bpf_text_poke_type.
+        */
+       switch (t) {
+       case BPF_MOD_NOP_TO_CALL ... BPF_MOD_CALL_TO_NOP:
+               emit_patch_fn = emit_call;
+               break;
+       case BPF_MOD_NOP_TO_JUMP ... BPF_MOD_JUMP_TO_NOP:
+               emit_patch_fn = emit_jump;
+               break;
+       default:
+               return -ENOTSUPP;
+       }
+
+       /*
+        * Build old_insn/new_insn for the requested transition and reject
+        * address combinations that make no sense for it with -ENXIO.
+        */
+       switch (t) {
+       case BPF_MOD_NOP_TO_CALL:
+       case BPF_MOD_NOP_TO_JUMP:
+               /* NOP -> branch: no old target, a new target is required */
+               if (!old_addr && new_addr) {
+                       memcpy(old_insn, nop_insn, X86_PATCH_SIZE);
+
+                       prog = new_insn;
+                       ret = emit_patch_fn(&prog, new_addr, ip);
+                       if (ret)
+                               return ret;
+                       break;
+               }
+               return -ENXIO;
+       case BPF_MOD_CALL_TO_CALL:
+       case BPF_MOD_JUMP_TO_JUMP:
+               /* branch -> branch: both targets are required */
+               if (old_addr && new_addr) {
+                       prog = old_insn;
+                       ret = emit_patch_fn(&prog, old_addr, ip);
+                       if (ret)
+                               return ret;
+
+                       prog = new_insn;
+                       ret = emit_patch_fn(&prog, new_addr, ip);
+                       if (ret)
+                               return ret;
+                       break;
+               }
+               return -ENXIO;
+       case BPF_MOD_CALL_TO_NOP:
+       case BPF_MOD_JUMP_TO_NOP:
+               /* branch -> NOP: an old target is required, no new target */
+               if (old_addr && !new_addr) {
+                       memcpy(new_insn, nop_insn, X86_PATCH_SIZE);
+
+                       prog = old_insn;
+                       ret = emit_patch_fn(&prog, old_addr, ip);
+                       if (ret)
+                               return ret;
+                       break;
+               }
+               return -ENXIO;
+       default:
+               return -ENOTSUPP;
+       }
+
+       /*
+        * Compare-and-replace under text_mutex: bail out with -EBUSY if the
+        * slot does not hold the instruction the caller expected.
+        */
+       ret = -EBUSY;
+       mutex_lock(&text_mutex);
+       if (memcmp(ip, old_insn, X86_PATCH_SIZE))
+               goto out;
+       if (text_live)
+               text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL);
+       else
+               memcpy(ip, new_insn, X86_PATCH_SIZE);
+       ret = 0;
+out:
+       mutex_unlock(&text_mutex);
+       return ret;
+}
+
+/*
+ * Patch the instruction slot at @ip on live text.
+ *
+ * @ip must be accepted by is_kernel_text() or is_bpf_text_address();
+ * any other address is rejected with -EINVAL. Otherwise the request is
+ * handed to __bpf_arch_text_poke() with text_live set, and its return
+ * value is passed through.
+ */
+int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
+                      void *old_addr, void *new_addr)
+{
+       const long addr = (long)ip;
+
+       if (is_kernel_text(addr) || is_bpf_text_address(addr))
+               return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true);
+
+       /* BPF poking in modules is not supported */
+       return -EINVAL;
+}
+
 /*
  * Generate the following code:
  *
  *   goto *(prog->bpf_func + prologue_size);
  * out:
  */
-static void emit_bpf_tail_call(u8 **pprog)
+static void emit_bpf_tail_call_indirect(u8 **pprog)
 {
        u8 *prog = *pprog;
        int label1, label2, label3;
        *pprog = prog;
 }
 
+/*
+ * Emit the direct (patchable) form of a BPF tail call into *@pprog.
+ *
+ * The emitted sequence checks and increments the tail call counter and
+ * ends in an atomic NOP of X86_PATCH_SIZE bytes. The address of that NOP
+ * is recorded in @poke->ip and PROLOGUE_SIZE in @poke->adj_off, for use by
+ * whoever patches the slot later.
+ *
+ * NOTE(review): @addr is supplied by the caller as addrs[i]; presumably the
+ * offset within @image of the end of this instruction's emitted code, which
+ * is why X86_PATCH_SIZE is subtracted to locate the NOP - confirm against
+ * the caller.
+ */
+static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
+                                     u8 **pprog, int addr, u8 *image)
+{
+       /*
+        * NOTE(review): 'cnt' is not referenced by any visible statement;
+        * presumably the EMIT*() macros use 'prog' and 'cnt' by name.
+        */
+       u8 *prog = *pprog;
+       int cnt = 0;
+
+       /*
+        * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
+        *      goto out;
+        */
+       EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
+       EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);         /* cmp eax, MAX_TAIL_CALL_CNT */
+       /*
+        * NOTE(review): the hard-coded 14 presumably is the byte length of
+        * everything up to "out:" (add + mov + NOP slot); it must be kept
+        * in sync if the instructions below change.
+        */
+       EMIT2(X86_JA, 14);                            /* ja out */
+       EMIT3(0x83, 0xC0, 0x01);                      /* add eax, 1 */
+       EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */
+
+       /* Remember where the patchable slot lives and the target offset */
+       poke->ip = image + (addr - X86_PATCH_SIZE);
+       poke->adj_off = PROLOGUE_SIZE;
+
+       /* Placeholder NOP; copied directly, so 'prog' is advanced by hand */
+       memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
+       prog += X86_PATCH_SIZE;
+       /* out: */
+
+       *pprog = prog;
+}
+
+/*
+ * Walk @prog's poke table and resolve every tail call entry.
+ *
+ * For each descriptor whose reason is BPF_POKE_REASON_TAIL_CALL, the
+ * program currently stored at the descriptor's key in its map is looked
+ * up under the map's poke_mutex. If one is present, the NOP at poke->ip is
+ * turned into a jump to target->bpf_func + poke->adj_off using the
+ * non-live (memcpy) patching path. The descriptor is then marked
+ * ip_stable whether or not a target was found.
+ *
+ * A patching failure is treated as fatal (BUG_ON).
+ */
+static void bpf_tail_call_direct_fixup(struct bpf_prog *prog)
+{
+       static const enum bpf_text_poke_type type = BPF_MOD_NOP_TO_JUMP;
+       struct bpf_jit_poke_descriptor *poke;
+       struct bpf_array *array;
+       struct bpf_prog *target;
+       int i, ret;
+
+       for (i = 0; i < prog->aux->size_poke_tab; i++) {
+               poke = &prog->aux->poke_tab[i];
+               /* No entry is expected to be marked stable before this pass */
+               WARN_ON_ONCE(READ_ONCE(poke->ip_stable));
+
+               if (poke->reason != BPF_POKE_REASON_TAIL_CALL)
+                       continue;
+
+               array = container_of(poke->tail_call.map, struct bpf_array, map);
+               /* Lookup, patch and the ip_stable store share one lock scope */
+               mutex_lock(&array->aux->poke_mutex);
+               target = array->ptrs[poke->tail_call.key];
+               if (target) {
+                       /* Plain memcpy is used when image is not live yet
+                        * and still not locked as read-only. Once poke
+                        * location is active (poke->ip_stable), any parallel
+                        * bpf_arch_text_poke() might occur still on the
+                        * read-write image until we finally locked it as
+                        * read-only. Both modifications on the given image
+                        * are under text_mutex to avoid interference.
+                        */
+                       ret = __bpf_arch_text_poke(poke->ip, type, NULL,
+                                                  (u8 *)target->bpf_func +
+                                                  poke->adj_off, false);
+                       BUG_ON(ret < 0);
+               }
+               WRITE_ONCE(poke->ip_stable, true);
+               mutex_unlock(&array->aux->poke_mutex);
+       }
+}
+
 static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
                           u32 dst_reg, const u32 imm32)
 {
        *pprog = prog;
 }
 
-static int emit_patch(u8 **pprog, void *func, void *ip, u8 opcode)
-{
-       u8 *prog = *pprog;
-       int cnt = 0;
-       s64 offset;
-
-       offset = func - (ip + X86_PATCH_SIZE);
-       if (!is_simm32(offset)) {
-               pr_err("Target call %p is out of range\n", func);
-               return -EINVAL;
-       }
-       EMIT1_off32(opcode, offset);
-       *pprog = prog;
-       return 0;
-}
-
-static int emit_call(u8 **pprog, void *func, void *ip)
-{
-       return emit_patch(pprog, func, ip, 0xE8);
-}
-
-static int emit_jump(u8 **pprog, void *func, void *ip)
-{
-       return emit_patch(pprog, func, ip, 0xE9);
-}
-
-int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
-                      void *old_addr, void *new_addr)
-{
-       int (*emit_patch_fn)(u8 **pprog, void *func, void *ip);
-       u8 old_insn[X86_PATCH_SIZE] = {};
-       u8 new_insn[X86_PATCH_SIZE] = {};
-       u8 *prog;
-       int ret;
-
-       if (!is_kernel_text((long)ip) &&
-           !is_bpf_text_address((long)ip))
-               /* BPF poking in modules is not supported */
-               return -EINVAL;
-
-       switch (t) {
-       case BPF_MOD_NOP_TO_CALL ... BPF_MOD_CALL_TO_NOP:
-               emit_patch_fn = emit_call;
-               break;
-       case BPF_MOD_NOP_TO_JUMP ... BPF_MOD_JUMP_TO_NOP:
-               emit_patch_fn = emit_jump;
-               break;
-       default:
-               return -ENOTSUPP;
-       }
-
-       if (old_addr) {
-               prog = old_insn;
-               ret = emit_patch_fn(&prog, old_addr, (void *)ip);
-               if (ret)
-                       return ret;
-       }
-       if (new_addr) {
-               prog = new_insn;
-               ret = emit_patch_fn(&prog, new_addr, (void *)ip);
-               if (ret)
-                       return ret;
-       }
-
-       ret = -EBUSY;
-       mutex_lock(&text_mutex);
-       switch (t) {
-       case BPF_MOD_NOP_TO_CALL:
-       case BPF_MOD_NOP_TO_JUMP:
-               if (memcmp(ip, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE))
-                       goto out;
-               text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL);
-               break;
-       case BPF_MOD_CALL_TO_CALL:
-       case BPF_MOD_JUMP_TO_JUMP:
-               if (memcmp(ip, old_insn, X86_PATCH_SIZE))
-                       goto out;
-               text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL);
-               break;
-       case BPF_MOD_CALL_TO_NOP:
-       case BPF_MOD_JUMP_TO_NOP:
-               if (memcmp(ip, old_insn, X86_PATCH_SIZE))
-                       goto out;
-               text_poke_bp(ip, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE,
-                            NULL);
-               break;
-       }
-       ret = 0;
-out:
-       mutex_unlock(&text_mutex);
-       return ret;
-}
-
 static bool ex_handler_bpf(const struct exception_table_entry *x,
                           struct pt_regs *regs, int trapnr,
                           unsigned long error_code, unsigned long fault_addr)
                        break;
 
                case BPF_JMP | BPF_TAIL_CALL:
-                       emit_bpf_tail_call(&prog);
+                       if (imm32)
+                               emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1],
+                                                         &prog, addrs[i], image);
+                       else
+                               emit_bpf_tail_call_indirect(&prog);
                        break;
 
                        /* cond jump */
 
        if (image) {
                if (!prog->is_func || extra_pass) {
+                       bpf_tail_call_direct_fixup(prog);
                        bpf_jit_binary_lock_ro(header);
                } else {
                        jit_data->addrs = addrs;