}
 
 static int invoke_bpf_prog(const struct btf_func_model *m, u8 **pprog,
-                          struct bpf_prog *p, int stack_size)
+                          struct bpf_prog *p, int stack_size, bool mod_ret)
 {
        u8 *prog = *pprog;
        int cnt = 0;
        if (emit_call(&prog, p->bpf_func, prog))
                return -EINVAL;
 
+       /* BPF_TRAMP_MODIFY_RETURN trampolines can modify the return
+        * of the previous call which is then passed on the stack to
+        * the next BPF program.
+        */
+       if (mod_ret)
+               emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
+
        /* arg1: mov rdi, progs[i] */
        emit_mov_imm64(&prog, BPF_REG_1, (long) p >> 32,
                       (u32) (long) p);
        return 0;
 }
 
+static int emit_mod_ret_check_imm8(u8 **pprog, int value)
+{
+       u8 *prog = *pprog;
+       int cnt = 0;
+
+       if (!is_imm8(value))
+               return -EINVAL;
+
+       if (value == 0)
+               EMIT2(0x85, add_2reg(0xC0, BPF_REG_0, BPF_REG_0));
+       else
+               EMIT3(0x83, add_1reg(0xF8, BPF_REG_0), value);
+
+       *pprog = prog;
+       return 0;
+}
+
 static int invoke_bpf(const struct btf_func_model *m, u8 **pprog,
                      struct bpf_tramp_progs *tp, int stack_size)
 {
        u8 *prog = *pprog;
 
        for (i = 0; i < tp->nr_progs; i++) {
-               if (invoke_bpf_prog(m, &prog, tp->progs[i], stack_size))
+               if (invoke_bpf_prog(m, &prog, tp->progs[i], stack_size, false))
+                       return -EINVAL;
+       }
+       *pprog = prog;
+       return 0;
+}
+
+static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
+                             struct bpf_tramp_progs *tp, int stack_size,
+                             u8 **branches)
+{
+       u8 *prog = *pprog;
+       int i;
+
+       /* The first fmod_ret program will receive a garbage return value.
+        * Set this to 0 to avoid confusing the program.
+        */
+       emit_mov_imm32(&prog, false, BPF_REG_0, 0);
+       emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
+       for (i = 0; i < tp->nr_progs; i++) {
+               if (invoke_bpf_prog(m, &prog, tp->progs[i], stack_size, true))
                        return -EINVAL;
+
+               /* Generate a branch:
+                *
+                * if (ret !=  0)
+                *      goto do_fexit;
+                *
+                * If needed this can be extended to any integer value which can
+                * be passed by user-space when the program is loaded.
+                */
+               if (emit_mod_ret_check_imm8(&prog, 0))
+                       return -EINVAL;
+
+               /* Save the location of the branch and Generate 6 nops
+                * (4 bytes for an offset and 2 bytes for the jump) These nops
+                * are replaced with a conditional jump once do_fexit (i.e. the
+                * start of the fexit invocation) is finalized.
+                */
+               branches[i] = prog;
+               emit_nops(&prog, 4 + 2);
        }
+
        *pprog = prog;
        return 0;
 }
                                struct bpf_tramp_progs *tprogs,
                                void *orig_call)
 {
-       int cnt = 0, nr_args = m->nr_args;
+       int ret, i, cnt = 0, nr_args = m->nr_args;
        int stack_size = nr_args * 8;
        struct bpf_tramp_progs *fentry = &tprogs[BPF_TRAMP_FENTRY];
        struct bpf_tramp_progs *fexit = &tprogs[BPF_TRAMP_FEXIT];
+       struct bpf_tramp_progs *fmod_ret = &tprogs[BPF_TRAMP_MODIFY_RETURN];
+       u8 **branches = NULL;
        u8 *prog;
 
        /* x86-64 supports up to 6 arguments. 7+ can be added in the future */
                if (invoke_bpf(m, &prog, fentry, stack_size))
                        return -EINVAL;
 
+       if (fmod_ret->nr_progs) {
+               branches = kcalloc(fmod_ret->nr_progs, sizeof(u8 *),
+                                  GFP_KERNEL);
+               if (!branches)
+                       return -ENOMEM;
+
+               if (invoke_bpf_mod_ret(m, &prog, fmod_ret, stack_size,
+                                      branches)) {
+                       ret = -EINVAL;
+                       goto cleanup;
+               }
+       }
+
        if (flags & BPF_TRAMP_F_CALL_ORIG) {
-               if (fentry->nr_progs)
+               if (fentry->nr_progs || fmod_ret->nr_progs)
                        restore_regs(m, &prog, nr_args, stack_size);
 
                /* call original function */
-               if (emit_call(&prog, orig_call, prog))
-                       return -EINVAL;
+               if (emit_call(&prog, orig_call, prog)) {
+                       ret = -EINVAL;
+                       goto cleanup;
+               }
                /* remember return value in a stack for bpf prog to access */
                emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
        }
 
+       if (fmod_ret->nr_progs) {
+               /* From Intel 64 and IA-32 Architectures Optimization
+                * Reference Manual, 3.4.1.4 Code Alignment, Assembly/Compiler
+                * Coding Rule 11: All branch targets should be 16-byte
+                * aligned.
+                */
+               emit_align(&prog, 16);
+               /* Update the branches saved in invoke_bpf_mod_ret with the
+                * aligned address of do_fexit.
+                */
+               for (i = 0; i < fmod_ret->nr_progs; i++)
+                       emit_cond_near_jump(&branches[i], prog, branches[i],
+                                           X86_JNE);
+       }
+
        if (fexit->nr_progs)
-               if (invoke_bpf(m, &prog, fexit, stack_size))
-                       return -EINVAL;
+               if (invoke_bpf(m, &prog, fexit, stack_size)) {
+                       ret = -EINVAL;
+                       goto cleanup;
+               }
 
        if (flags & BPF_TRAMP_F_RESTORE_REGS)
                restore_regs(m, &prog, nr_args, stack_size);
 
+       /* This needs to be done regardless. If there were fmod_ret programs,
+        * the return value is only updated on the stack and still needs to be
+        * restored to R0.
+        */
        if (flags & BPF_TRAMP_F_CALL_ORIG)
                /* restore original return value back into RAX */
                emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8);
                EMIT4(0x48, 0x83, 0xC4, 8); /* add rsp, 8 */
        EMIT1(0xC3); /* ret */
        /* Make sure the trampoline generation logic doesn't overflow */
-       if (WARN_ON_ONCE(prog > (u8 *)image_end - BPF_INSN_SAFETY))
-               return -EFAULT;
-       return prog - (u8 *)image;
+       if (WARN_ON_ONCE(prog > (u8 *)image_end - BPF_INSN_SAFETY)) {
+               ret = -EFAULT;
+               goto cleanup;
+       }
+       ret = prog - (u8 *)image;
+
+cleanup:
+       kfree(branches);
+       return ret;
 }
 
 static int emit_fallback_jump(u8 **pprog)