return ctx->idx;
 }
 
-int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
-                               void *image_end, const struct btf_func_model *m,
-                               u32 flags, struct bpf_tramp_links *tlinks,
-                               void *func_addr)
+static int btf_func_model_nregs(const struct btf_func_model *m)
 {
-       int i, ret;
        int nregs = m->nr_args;
-       int max_insns = ((long)image_end - (long)image) / AARCH64_INSN_SIZE;
-       struct jit_ctx ctx = {
-               .image = NULL,
-               .idx = 0,
-       };
+       int i;
 
        /* extra registers needed for struct argument */
        for (i = 0; i < MAX_BPF_FUNC_ARGS; i++) {
                        nregs += (m->arg_size[i] + 7) / 8 - 1;
        }
 
+       return nregs;
+}
+
+int arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
+                            struct bpf_tramp_links *tlinks, void *func_addr)
+{
+       struct jit_ctx ctx = {
+               .image = NULL,
+               .idx = 0,
+       };
+       struct bpf_tramp_image im;
+       int nregs, ret;
+
+       nregs = btf_func_model_nregs(m);
        /* the first 8 registers are used for arguments */
        if (nregs > 8)
                return -ENOTSUPP;
 
-       ret = prepare_trampoline(&ctx, im, tlinks, func_addr, nregs, flags);
+       ret = prepare_trampoline(&ctx, &im, tlinks, func_addr, nregs, flags);
        if (ret < 0)
                return ret;
 
-       if (ret > max_insns)
-               return -EFBIG;
+       return ret < 0 ? ret : ret * AARCH64_INSN_SIZE;
+}
 
-       ctx.image = image;
-       ctx.idx = 0;
+int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
+                               void *image_end, const struct btf_func_model *m,
+                               u32 flags, struct bpf_tramp_links *tlinks,
+                               void *func_addr)
+{
+       int ret, nregs;
+       struct jit_ctx ctx = {
+               .image = image,
+               .idx = 0,
+       };
+
+       nregs = btf_func_model_nregs(m);
+       /* the first 8 registers are used for arguments */
+       if (nregs > 8)
+               return -ENOTSUPP;
+
+       ret = arch_bpf_trampoline_size(m, flags, tlinks, func_addr);
+       if (ret < 0)
+               return ret;
+
+       if (ret > ((long)image_end - (long)image))
+               return -EFBIG;
 
        jit_fill_hole(image, (unsigned int)(image_end - image));
        ret = prepare_trampoline(&ctx, im, tlinks, func_addr, nregs, flags);
 
        return ret;
 }
 
+int arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
+                            struct bpf_tramp_links *tlinks, void *func_addr)
+{
+       struct bpf_tramp_image im;
+       struct rv_jit_context ctx;
+       int ret;
+
+       ctx.ninsns = 0;
+       ctx.insns = NULL;
+       ctx.ro_insns = NULL;
+       ret = __arch_prepare_bpf_trampoline(&im, m, tlinks, func_addr, flags, &ctx);
+
+       return ret < 0 ? ret : ninsns_rvoff(ctx.ninsns);
+}
+
 int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
                                void *image_end, const struct btf_func_model *m,
                                u32 flags, struct bpf_tramp_links *tlinks,
        int ret;
        struct rv_jit_context ctx;
 
-       ctx.ninsns = 0;
-       ctx.insns = NULL;
-       ctx.ro_insns = NULL;
-       ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
+       ret = arch_bpf_trampoline_size(im, m, flags, tlinks, func_addr);
        if (ret < 0)
                return ret;
 
-       if (ninsns_rvoff(ret) > (long)image_end - (long)image)
+       if (ret > (long)image_end - (long)image)
                return -EFBIG;
 
        ctx.ninsns = 0;
 
        return 0;
 }
 
+int arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
+                            struct bpf_tramp_links *tlinks, void *orig_call)
+{
+       struct bpf_tramp_image im;
+       struct bpf_tramp_jit tjit;
+       int ret;
+
+       memset(&tjit, 0, sizeof(tjit));
+
+       ret = __arch_prepare_bpf_trampoline(&im, &tjit, m, flags,
+                                           tlinks, orig_call);
+
+       return ret < 0 ? ret : tjit.common.prg;
+}
+
 int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
                                void *image_end, const struct btf_func_model *m,
                                u32 flags, struct bpf_tramp_links *tlinks,
 {
        struct bpf_tramp_jit tjit;
        int ret;
-       int i;
 
-       for (i = 0; i < 2; i++) {
-               if (i == 0) {
-                       /* Compute offsets, check whether the code fits. */
-                       memset(&tjit, 0, sizeof(tjit));
-               } else {
-                       /* Generate the code. */
-                       tjit.common.prg = 0;
-                       tjit.common.prg_buf = image;
-               }
-               ret = __arch_prepare_bpf_trampoline(im, &tjit, m, flags,
-                                                   tlinks, func_addr);
-               if (ret < 0)
-                       return ret;
-               if (tjit.common.prg > (char *)image_end - (char *)image)
-                       /*
-                        * Use the same error code as for exceeding
-                        * BPF_MAX_TRAMP_LINKS.
-                        */
-                       return -E2BIG;
-       }
+       /* Compute offsets, check whether the code fits. */
+       memset(&tjit, 0, sizeof(tjit));
+       ret = __arch_prepare_bpf_trampoline(im, &tjit, m, flags,
+                                           tlinks, func_addr);
+
+       if (ret < 0)
+               return ret;
+       if (tjit.common.prg > (char *)image_end - (char *)image)
+               /*
+                * Use the same error code as for exceeding
+                * BPF_MAX_TRAMP_LINKS.
+                */
+               return -E2BIG;
+
+       tjit.common.prg = 0;
+       tjit.common.prg_buf = image;
+       ret = __arch_prepare_bpf_trampoline(im, &tjit, m, flags,
+                                           tlinks, func_addr);
 
-       return tjit.common.prg;
+       return ret < 0 ? ret : tjit.common.prg;
 }
 
 bool bpf_jit_supports_subprog_tailcalls(void)
 
  * add rsp, 8                      // skip eth_type_trans's frame
  * ret                             // return to its caller
  */
-int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *image_end,
-                               const struct btf_func_model *m, u32 flags,
-                               struct bpf_tramp_links *tlinks,
-                               void *func_addr)
+static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *image_end,
+                                        const struct btf_func_model *m, u32 flags,
+                                        struct bpf_tramp_links *tlinks,
+                                        void *func_addr)
 {
        int i, ret, nr_regs = m->nr_args, stack_size = 0;
        int regs_off, nregs_off, ip_off, run_ctx_off, arg_stack_off, rbx_off;
        return ret;
 }
 
+int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *image_end,
+                               const struct btf_func_model *m, u32 flags,
+                               struct bpf_tramp_links *tlinks,
+                               void *func_addr)
+{
+       return __arch_prepare_bpf_trampoline(im, image, image_end, m, flags, tlinks, func_addr);
+}
+
+int arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
+                            struct bpf_tramp_links *tlinks, void *func_addr)
+{
+       struct bpf_tramp_image im;
+       void *image;
+       int ret;
+
+       /* Allocate a temporary buffer for __arch_prepare_bpf_trampoline().
+        * This will NOT cause fragmentation in direct map, as we do not
+        * call set_memory_*() on this buffer.
+        *
+        * We cannot use kvmalloc here, because we need image to be in
+        * module memory range.
+        */
+       image = bpf_jit_alloc_exec(PAGE_SIZE);
+       if (!image)
+               return -ENOMEM;
+
+       ret = __arch_prepare_bpf_trampoline(&im, image, image + PAGE_SIZE, m, flags,
+                                           tlinks, func_addr);
+       bpf_jit_free_exec(image);
+       return ret;
+}
+
 static int emit_bpf_dispatcher(u8 **pprog, int a, int b, s64 *progs, u8 *image, u8 *buf)
 {
        u8 *jg_reloc, *prog = *pprog;
 
 void arch_free_bpf_trampoline(void *image, unsigned int size);
 void arch_protect_bpf_trampoline(void *image, unsigned int size);
 void arch_unprotect_bpf_trampoline(void *image, unsigned int size);
+int arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
+                            struct bpf_tramp_links *tlinks, void *func_addr);
 
 u64 notrace __bpf_prog_enter_sleepable_recur(struct bpf_prog *prog,
                                             struct bpf_tramp_run_ctx *run_ctx);
 
        set_memory_rw((long)image, 1);
 }
 
+int __weak arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
+                                   struct bpf_tramp_links *tlinks, void *func_addr)
+{
+       return -ENOTSUPP;
+}
+
 static int __init init_trampolines(void)
 {
        int i;