int btf_check_subprog_call(struct bpf_verifier_env *env, int subprog,
                           struct bpf_reg_state *regs);
 int btf_prepare_func_args(struct bpf_verifier_env *env, int subprog,
-                         struct bpf_reg_state *reg, bool is_ex_cb);
+                         struct bpf_reg_state *reg, u32 *nargs);
 int btf_check_type_match(struct bpf_verifier_log *log, const struct bpf_prog *prog,
                         struct btf *btf, const struct btf_type *t);
 const char *btf_find_decl_tag_value(const struct btf *btf, const struct btf_type *pt,
 
  * (either PTR_TO_CTX or SCALAR_VALUE).
  */
 int btf_prepare_func_args(struct bpf_verifier_env *env, int subprog,
-                         struct bpf_reg_state *regs, bool is_ex_cb)
+                         struct bpf_reg_state *regs, u32 *arg_cnt)
 {
        struct bpf_verifier_log *log = &env->log;
        struct bpf_prog *prog = env->prog;
                        tname, nargs, MAX_BPF_FUNC_REG_ARGS);
                return -EINVAL;
        }
+       *arg_cnt = nargs;
        /* check that function returns int, exception cb also requires this */
        t = btf_type_by_id(btf, t->type);
        while (btf_type_is_modifier(t))
                        i, btf_type_str(t), tname);
                return -EINVAL;
        }
-       /* We have already ensured that the callback returns an integer, just
-        * like all global subprogs. We need to determine it only has a single
-        * scalar argument.
-        */
-       if (is_ex_cb && (nargs != 1 || regs[BPF_REG_1].type != SCALAR_VALUE)) {
-               bpf_log(log, "exception cb only supports single integer argument\n");
-               return -EINVAL;
-       }
        return 0;
 }
 
 
        return &env->prog->aux->func_info_aux[subprog];
 }
 
+static struct bpf_subprog_info *subprog_info(struct bpf_verifier_env *env, int subprog)
+{
+       return &env->subprog_info[subprog];
+}
+
+static void mark_subprog_exc_cb(struct bpf_verifier_env *env, int subprog)
+{
+       struct bpf_subprog_info *info = subprog_info(env, subprog);
+
+       info->is_cb = true;
+       info->is_async_cb = true;
+       info->is_exception_cb = true;
+}
+
+static bool subprog_is_exc_cb(struct bpf_verifier_env *env, int subprog)
+{
+       return subprog_info(env, subprog)->is_exception_cb;
+}
+
 static bool reg_may_point_to_spin_lock(const struct bpf_reg_state *reg)
 {
        return btf_record_has_field(reg_btf_record(reg), BPF_SPIN_LOCK);
                        if (env->subprog_info[i].start != ex_cb_insn)
                                continue;
                        env->exception_callback_subprog = i;
+                       mark_subprog_exc_cb(env, i);
                        break;
                }
        }
 
                env->exception_callback_subprog = env->subprog_cnt - 1;
                /* Don't update insn_cnt, as add_hidden_subprog always appends insns */
-               env->subprog_info[env->exception_callback_subprog].is_cb = true;
-               env->subprog_info[env->exception_callback_subprog].is_async_cb = true;
-               env->subprog_info[env->exception_callback_subprog].is_exception_cb = true;
+               mark_subprog_exc_cb(env, env->exception_callback_subprog);
        }
 
        for (i = 0; i < insn_cnt; i++, insn++) {
        }
 }
 
-static int do_check_common(struct bpf_verifier_env *env, int subprog, bool is_ex_cb)
+static int do_check_common(struct bpf_verifier_env *env, int subprog)
 {
        bool pop_log = !(env->log.level & BPF_LOG_LEVEL2);
        struct bpf_verifier_state *state;
 
        regs = state->frame[state->curframe]->regs;
        if (subprog || env->prog->type == BPF_PROG_TYPE_EXT) {
-               ret = btf_prepare_func_args(env, subprog, regs, is_ex_cb);
+               u32 nargs;
+
+               ret = btf_prepare_func_args(env, subprog, regs, &nargs);
                if (ret)
                        goto out;
+               if (subprog_is_exc_cb(env, subprog)) {
+                       state->frame[0]->in_exception_callback_fn = true;
+                       /* We have already ensured that the callback returns an integer, just
+                        * like all global subprogs. We need to determine it only has a single
+                        * scalar argument.
+                        */
+                       if (nargs != 1 || regs[BPF_REG_1].type != SCALAR_VALUE) {
+                               verbose(env, "exception cb only supports single integer argument\n");
+                               ret = -EINVAL;
+                               goto out;
+                       }
+               }
                for (i = BPF_REG_1; i <= BPF_REG_5; i++) {
                        if (regs[i].type == PTR_TO_CTX)
                                mark_reg_known_zero(env, regs, i);
                                regs[i].id = ++env->id_gen;
                        }
                }
-               if (is_ex_cb) {
-                       state->frame[0]->in_exception_callback_fn = true;
-                       env->subprog_info[subprog].is_cb = true;
-                       env->subprog_info[subprog].is_async_cb = true;
-                       env->subprog_info[subprog].is_exception_cb = true;
-               }
        } else {
                /* 1st arg to a function */
                regs[BPF_REG_1].type = PTR_TO_CTX;
 
                env->insn_idx = env->subprog_info[i].start;
                WARN_ON_ONCE(env->insn_idx == 0);
-               ret = do_check_common(env, i, env->exception_callback_subprog == i);
+               ret = do_check_common(env, i);
                if (ret) {
                        return ret;
                } else if (env->log.level & BPF_LOG_LEVEL) {
        int ret;
 
        env->insn_idx = 0;
-       ret = do_check_common(env, 0, false);
+       ret = do_check_common(env, 0);
        if (!ret)
                env->prog->aux->stack_depth = env->subprog_info[0].stack_depth;
        return ret;