#include <asm/patch.h>
 #include "bpf_jit.h"
 
+#define RV_FENTRY_NINSNS 2
+
 #define RV_REG_TCC RV_REG_A6
 #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
 
        if (!is_tail_call)
                emit_mv(RV_REG_A0, RV_REG_A5, ctx);
        emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
-                 is_tail_call ? 20 : 0, /* skip reserved nops and TCC init */
+                 is_tail_call ? (RV_FENTRY_NINSNS + 1) * 4 : 0, /* skip reserved nops and TCC init */
                  ctx);
 }
 
        return 0;
 }
 
-static int gen_call_or_nops(void *target, void *ip, u32 *insns)
-{
-       s64 rvoff;
-       int i, ret;
-       struct rv_jit_context ctx;
-
-       ctx.ninsns = 0;
-       ctx.insns = (u16 *)insns;
-
-       if (!target) {
-               for (i = 0; i < 4; i++)
-                       emit(rv_nop(), &ctx);
-               return 0;
-       }
-
-       rvoff = (s64)(target - (ip + 4));
-       emit(rv_sd(RV_REG_SP, -8, RV_REG_RA), &ctx);
-       ret = emit_jump_and_link(RV_REG_RA, rvoff, false, &ctx);
-       if (ret)
-               return ret;
-       emit(rv_ld(RV_REG_RA, -8, RV_REG_SP), &ctx);
-
-       return 0;
-}
-
-static int gen_jump_or_nops(void *target, void *ip, u32 *insns)
+static int gen_jump_or_nops(void *target, void *ip, u32 *insns, bool is_call)
 {
        s64 rvoff;
        struct rv_jit_context ctx;
        }
 
        rvoff = (s64)(target - ip);
-       return emit_jump_and_link(RV_REG_ZERO, rvoff, false, &ctx);
+       return emit_jump_and_link(is_call ? RV_REG_T0 : RV_REG_ZERO, rvoff, false, &ctx);
 }
 
 int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
                       void *old_addr, void *new_addr)
 {
-       u32 old_insns[4], new_insns[4];
+       u32 old_insns[RV_FENTRY_NINSNS], new_insns[RV_FENTRY_NINSNS];
        bool is_call = poke_type == BPF_MOD_CALL;
-       int (*gen_insns)(void *target, void *ip, u32 *insns);
-       int ninsns = is_call ? 4 : 2;
        int ret;
 
-       if (!is_bpf_text_address((unsigned long)ip))
+       if (!is_kernel_text((unsigned long)ip) &&
+           !is_bpf_text_address((unsigned long)ip))
                return -ENOTSUPP;
 
-       gen_insns = is_call ? gen_call_or_nops : gen_jump_or_nops;
-
-       ret = gen_insns(old_addr, ip, old_insns);
+       ret = gen_jump_or_nops(old_addr, ip, old_insns, is_call);
        if (ret)
                return ret;
 
-       if (memcmp(ip, old_insns, ninsns * 4))
+       if (memcmp(ip, old_insns, RV_FENTRY_NINSNS * 4))
                return -EFAULT;
 
-       ret = gen_insns(new_addr, ip, new_insns);
+       ret = gen_jump_or_nops(new_addr, ip, new_insns, is_call);
        if (ret)
                return ret;
 
        cpus_read_lock();
        mutex_lock(&text_mutex);
-       if (memcmp(ip, new_insns, ninsns * 4))
-               ret = patch_text(ip, new_insns, ninsns);
+       if (memcmp(ip, new_insns, RV_FENTRY_NINSNS * 4))
+               ret = patch_text(ip, new_insns, RV_FENTRY_NINSNS);
        mutex_unlock(&text_mutex);
        cpus_read_unlock();
 
        int i, ret, offset;
        int *branches_off = NULL;
        int stack_size = 0, nregs = m->nr_args;
-       int retaddr_off, fp_off, retval_off, args_off;
-       int nregs_off, ip_off, run_ctx_off, sreg_off;
+       int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off;
        struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
        struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
        struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
        bool save_ret;
        u32 insn;
 
-       /* Generated trampoline stack layout:
+       /* Two types of generated trampoline stack layout:
+        *
+        * 1. trampoline called from function entry
+        * --------------------------------------
+        * FP + 8           [ RA to parent func ] return address to parent
+        *                                        function
+        * FP + 0           [ FP of parent func ] frame pointer of parent
+        *                                        function
+        * FP - 8           [ T0 to traced func ] return address of traced
+        *                                        function
+        * FP - 16          [ FP of traced func ] frame pointer of traced
+        *                                        function
+        * --------------------------------------
         *
-        * FP - 8           [ RA of parent func ] return address of parent
+        * 2. trampoline called directly
+        * --------------------------------------
+        * FP - 8           [ RA to caller func ] return address to caller
         *                                        function
-        * FP - retaddr_off [ RA of traced func ] return address of traced
+        * FP - 16          [ FP of caller func ] frame pointer of caller
         *                                        function
-        * FP - fp_off      [ FP of parent func ]
+        * --------------------------------------
         *
         * FP - retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
         *                                        BPF_TRAMP_F_RET_FENTRY_RET
        if (nregs > 8)
                return -ENOTSUPP;
 
-       /* room for parent function return address */
-       stack_size += 8;
-
-       stack_size += 8;
-       retaddr_off = stack_size;
-
-       stack_size += 8;
-       fp_off = stack_size;
+       /* room of trampoline frame to store return address and frame pointer */
+       stack_size += 16;
 
        save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
        if (save_ret) {
 
        stack_size = round_up(stack_size, 16);
 
-       emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
-
-       emit_sd(RV_REG_SP, stack_size - retaddr_off, RV_REG_RA, ctx);
-       emit_sd(RV_REG_SP, stack_size - fp_off, RV_REG_FP, ctx);
-
-       emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
+       if (func_addr) {
+               /* For the trampoline called from function entry,
+                * the frame of traced function and the frame of
+                * trampoline need to be considered.
+                */
+               emit_addi(RV_REG_SP, RV_REG_SP, -16, ctx);
+               emit_sd(RV_REG_SP, 8, RV_REG_RA, ctx);
+               emit_sd(RV_REG_SP, 0, RV_REG_FP, ctx);
+               emit_addi(RV_REG_FP, RV_REG_SP, 16, ctx);
+
+               emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
+               emit_sd(RV_REG_SP, stack_size - 8, RV_REG_T0, ctx);
+               emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
+               emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
+       } else {
+               /* For the trampoline called directly, just handle
+                * the frame of trampoline.
+                */
+               emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
+               emit_sd(RV_REG_SP, stack_size - 8, RV_REG_RA, ctx);
+               emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
+               emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
+       }
 
        /* callee saved register S1 to pass start time */
        emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
 
        /* skip to actual body of traced function */
        if (flags & BPF_TRAMP_F_SKIP_FRAME)
-               orig_call += 16;
+               orig_call += RV_FENTRY_NINSNS * 4;
 
        if (flags & BPF_TRAMP_F_CALL_ORIG) {
                emit_imm(RV_REG_A0, (const s64)im, ctx);
 
        emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);
 
-       if (flags & BPF_TRAMP_F_SKIP_FRAME)
-               /* return address of parent function */
-               emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
-       else
-               /* return address of traced function */
-               emit_ld(RV_REG_RA, stack_size - retaddr_off, RV_REG_SP, ctx);
+       if (func_addr) {
+               /* trampoline called from function entry */
+               emit_ld(RV_REG_T0, stack_size - 8, RV_REG_SP, ctx);
+               emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
+               emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
 
-       emit_ld(RV_REG_FP, stack_size - fp_off, RV_REG_SP, ctx);
-       emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
+               emit_ld(RV_REG_RA, 8, RV_REG_SP, ctx);
+               emit_ld(RV_REG_FP, 0, RV_REG_SP, ctx);
+               emit_addi(RV_REG_SP, RV_REG_SP, 16, ctx);
 
-       emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
+               if (flags & BPF_TRAMP_F_SKIP_FRAME)
+                       /* return to parent function */
+                       emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
+               else
+                       /* return to traced function */
+                       emit_jalr(RV_REG_ZERO, RV_REG_T0, 0, ctx);
+       } else {
+               /* trampoline called directly */
+               emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
+               emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
+               emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
+
+               emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
+       }
 
        ret = ctx->ninsns;
 out:
 
        store_offset = stack_adjust - 8;
 
-       /* reserve 4 nop insns */
-       for (i = 0; i < 4; i++)
+       /* nops reserved for auipc+jalr pair */
+       for (i = 0; i < RV_FENTRY_NINSNS; i++)
                emit(rv_nop(), ctx);
 
        /* First instruction is always setting the tail-call-counter