const struct bpf_func_state *func,
                              int off)
 {
-       u16 stack = env->subprog_stack_depth[func->subprogno], total = 0;
-       struct bpf_verifier_state *cur = env->cur_state;
-       int i;
+       u16 stack = env->subprog_stack_depth[func->subprogno];
 
        if (stack >= -off)
                return 0;
 
        /* update known max for given subprogram */
        env->subprog_stack_depth[func->subprogno] = -off;
+       return 0;
+}
 
-       /* compute the total for current call chain */
-       for (i = 0; i <= cur->curframe; i++) {
-               u32 depth = env->subprog_stack_depth[cur->frame[i]->subprogno];
-
-               /* round up to 32-bytes, since this is granularity
-                * of interpreter stack sizes
-                */
-               depth = round_up(depth, 32);
-               total += depth;
-       }
+/* starting from main bpf function walk all instructions of the function
+ * and recursively walk all callees that given function can call.
+ * Ignore jump and exit insns.
+ * Since recursion is prevented by check_cfg() this algorithm
+ * only needs a local stack of MAX_CALL_FRAMES to remember callsites
+ */
+static int check_max_stack_depth(struct bpf_verifier_env *env)
+{
+       int depth = 0, frame = 0, subprog = 0, i = 0, subprog_end;
+       struct bpf_insn *insn = env->prog->insnsi;
+       int insn_cnt = env->prog->len;
+       int ret_insn[MAX_CALL_FRAMES];
+       int ret_prog[MAX_CALL_FRAMES];
 
-       if (total > MAX_BPF_STACK) {
+process_func:
+       /* round up to 32-bytes, since this is granularity
+        * of interpreter stack size
+        */
+       depth += round_up(max_t(u32, env->subprog_stack_depth[subprog], 1), 32);
+       if (depth > MAX_BPF_STACK) {
                verbose(env, "combined stack size of %d calls is %d. Too large\n",
-                       cur->curframe, total);
+                       frame + 1, depth);
                return -EACCES;
        }
-       return 0;
+continue_func:
+       if (env->subprog_cnt == subprog)
+               subprog_end = insn_cnt;
+       else
+               subprog_end = env->subprog_starts[subprog];
+       for (; i < subprog_end; i++) {
+               if (insn[i].code != (BPF_JMP | BPF_CALL))
+                       continue;
+               if (insn[i].src_reg != BPF_PSEUDO_CALL)
+                       continue;
+               /* remember insn and function to return to */
+               ret_insn[frame] = i + 1;
+               ret_prog[frame] = subprog;
+
+               /* find the callee */
+               i = i + insn[i].imm + 1;
+               subprog = find_subprog(env, i);
+               if (subprog < 0) {
+                       WARN_ONCE(1, "verifier bug. No program starts at insn %d\n",
+                                 i);
+                       return -EFAULT;
+               }
+               subprog++;
+               frame++;
+               if (frame >= MAX_CALL_FRAMES) {
+                       WARN_ONCE(1, "verifier bug. Call stack is too deep\n");
+                       return -EFAULT;
+               }
+               goto process_func;
+       }
+       /* end of for() loop means the last insn of the 'subprog'
+        * was reached. Doesn't matter whether it was JA or EXIT
+        */
+       if (frame == 0)
+               return 0;
+       depth -= round_up(max_t(u32, env->subprog_stack_depth[subprog], 1), 32);
+       frame--;
+       i = ret_insn[frame];
+       subprog = ret_prog[frame];
+       goto continue_func;
 }
 
 static int get_callee_stack_depth(struct bpf_verifier_env *env,
        if (ret == 0)
                sanitize_dead_code(env);
 
+       if (ret == 0)
+               ret = check_max_stack_depth(env);
+
        if (ret == 0)
                /* program is valid, convert *(u32*)(ctx + off) accesses */
                ret = convert_ctx_accesses(env);