insn->src_reg == BPF_PSEUDO_CALL;
 }
 
+static bool bpf_pseudo_func(const struct bpf_insn *insn)
+{
+       return insn->code == (BPF_LD | BPF_IMM | BPF_DW) &&
+              insn->src_reg == BPF_PSEUDO_FUNC;
+}
+
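A BPF_PSEUDO_FUNC load is the first slot of a two-slot ld_imm64 that
takes the address of a subprog. A minimal sketch of the pair a loader
might emit, assuming a purely illustrative imm delta of 5:

	struct bpf_insn fp[2] = {
		/* slot 0: pseudo-func marker; imm is the insn delta to
		 * the subprog start, minus one (resolved by the verifier)
		 */
		{ .code = BPF_LD | BPF_IMM | BPF_DW,
		  .src_reg = BPF_PSEUDO_FUNC, .imm = 5 },
		/* slot 1: upper half of the 64-bit immediate; the
		 * verifier caches the subprog index in its imm field
		 */
		{ .imm = 0 },
	};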
 struct bpf_call_arg_meta {
        struct bpf_map *map_ptr;
        bool raw_mode;
        u32 btf_id;
        struct btf *ret_btf;
        u32 ret_btf_id;
+       u32 subprogno;
 };
 
 struct btf *btf_vmlinux;
        return type == PTR_TO_SOCKET ||
                type == PTR_TO_TCP_SOCK ||
                type == PTR_TO_MAP_VALUE ||
+               type == PTR_TO_MAP_KEY ||
                type == PTR_TO_SOCK_COMMON;
 }
 
               type == ARG_PTR_TO_MEM_OR_NULL ||
               type == ARG_PTR_TO_CTX_OR_NULL ||
               type == ARG_PTR_TO_SOCKET_OR_NULL ||
-              type == ARG_PTR_TO_ALLOC_MEM_OR_NULL;
+              type == ARG_PTR_TO_ALLOC_MEM_OR_NULL ||
+              type == ARG_PTR_TO_STACK_OR_NULL;
 }
 
 /* Determine whether the function releases some resources allocated by another
        [PTR_TO_RDONLY_BUF_OR_NULL] = "rdonly_buf_or_null",
        [PTR_TO_RDWR_BUF]       = "rdwr_buf",
        [PTR_TO_RDWR_BUF_OR_NULL] = "rdwr_buf_or_null",
+       [PTR_TO_FUNC]           = "func",
+       [PTR_TO_MAP_KEY]        = "map_key",
 };
 
 static char slot_type_char[] = {
                        if (type_is_pkt_pointer(t))
                                verbose(env, ",r=%d", reg->range);
                        else if (t == CONST_PTR_TO_MAP ||
+                                t == PTR_TO_MAP_KEY ||
                                 t == PTR_TO_MAP_VALUE ||
                                 t == PTR_TO_MAP_VALUE_OR_NULL)
                                verbose(env, ",ks=%d,vs=%d",
 
        /* determine subprog starts. The end is one before the next starts */
        for (i = 0; i < insn_cnt; i++) {
+               if (bpf_pseudo_func(insn + i)) {
+                       if (!env->bpf_capable) {
+                               verbose(env,
+                                       "function pointers are allowed for CAP_BPF and CAP_SYS_ADMIN\n");
+                               return -EPERM;
+                       }
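+                       /* Like bpf_pseudo_call, the target subprog starts
+                        * at i + insn[i].imm + 1, e.g. imm = 4 at insn 10
+                        * marks insn 15 as a subprog start.
+                        */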
+                       ret = add_subprog(env, i + insn[i].imm + 1);
+                       if (ret < 0)
+                               return ret;
+                       /* remember subprog */
+                       insn[i + 1].imm = ret;
+                       continue;
+               }
                if (!bpf_pseudo_call(insn + i))
                        continue;
                if (!env->bpf_capable) {
        case PTR_TO_PERCPU_BTF_ID:
        case PTR_TO_MEM:
        case PTR_TO_MEM_OR_NULL:
+       case PTR_TO_FUNC:
+       case PTR_TO_MAP_KEY:
                return true;
        default:
                return false;
 
        reg = &cur_regs(env)[regno];
        switch (reg->type) {
+       case PTR_TO_MAP_KEY:
+               verbose(env, "invalid access to map key, key_size=%d off=%d size=%d\n",
+                       mem_size, off, size);
+               break;
        case PTR_TO_MAP_VALUE:
                verbose(env, "invalid access to map value, value_size=%d off=%d size=%d\n",
                        mem_size, off, size);
        case PTR_TO_FLOW_KEYS:
                pointer_desc = "flow keys ";
                break;
+       case PTR_TO_MAP_KEY:
+               pointer_desc = "key ";
+               break;
        case PTR_TO_MAP_VALUE:
                pointer_desc = "value ";
                break;
 continue_func:
        subprog_end = subprog[idx + 1].start;
        for (; i < subprog_end; i++) {
-               if (!bpf_pseudo_call(insn + i))
+               if (!bpf_pseudo_call(insn + i) && !bpf_pseudo_func(insn + i))
                        continue;
                /* remember insn and function to return to */
                ret_insn[frame] = i + 1;
        /* for access checks, reg->off is just part of off */
        off += reg->off;
 
-       if (reg->type == PTR_TO_MAP_VALUE) {
+       if (reg->type == PTR_TO_MAP_KEY) {
+               if (t == BPF_WRITE) {
+                       verbose(env, "write to change key R%d not allowed\n", regno);
+                       return -EACCES;
+               }
+
+               err = check_mem_region_access(env, regno, off, size,
+                                             reg->map_ptr->key_size, false);
+               if (err)
+                       return err;
+               if (value_regno >= 0)
+                       mark_reg_unknown(env, regs, value_regno);
+       } else if (reg->type == PTR_TO_MAP_VALUE) {
                if (t == BPF_WRITE && value_regno >= 0 &&
                    is_pointer_value(env, value_regno)) {
                        verbose(env, "R%d leaks addr into map\n", value_regno);
        case PTR_TO_PACKET_META:
                return check_packet_access(env, regno, reg->off, access_size,
                                           zero_size_allowed);
+       case PTR_TO_MAP_KEY:
+               return check_mem_region_access(env, regno, reg->off, access_size,
+                                              reg->map_ptr->key_size, false);
        case PTR_TO_MAP_VALUE:
                if (check_map_access_type(env, regno, reg->off, access_size,
                                          meta && meta->raw_mode ? BPF_WRITE :
                PTR_TO_STACK,
                PTR_TO_PACKET,
                PTR_TO_PACKET_META,
+               PTR_TO_MAP_KEY,
                PTR_TO_MAP_VALUE,
        },
 };
                PTR_TO_STACK,
                PTR_TO_PACKET,
                PTR_TO_PACKET_META,
+               PTR_TO_MAP_KEY,
                PTR_TO_MAP_VALUE,
                PTR_TO_MEM,
                PTR_TO_RDONLY_BUF,
                PTR_TO_STACK,
                PTR_TO_PACKET,
                PTR_TO_PACKET_META,
+               PTR_TO_MAP_KEY,
                PTR_TO_MAP_VALUE,
        },
 };
 static const struct bpf_reg_types btf_ptr_types = { .types = { PTR_TO_BTF_ID } };
 static const struct bpf_reg_types spin_lock_types = { .types = { PTR_TO_MAP_VALUE } };
 static const struct bpf_reg_types percpu_btf_ptr_types = { .types = { PTR_TO_PERCPU_BTF_ID } };
+static const struct bpf_reg_types func_ptr_types = { .types = { PTR_TO_FUNC } };
+static const struct bpf_reg_types stack_ptr_types = { .types = { PTR_TO_STACK } };
 
 static const struct bpf_reg_types *compatible_reg_types[__BPF_ARG_TYPE_MAX] = {
        [ARG_PTR_TO_MAP_KEY]            = &map_key_value_types,
        [ARG_PTR_TO_INT]                = &int_ptr_types,
        [ARG_PTR_TO_LONG]               = &int_ptr_types,
        [ARG_PTR_TO_PERCPU_BTF_ID]      = &percpu_btf_ptr_types,
+       [ARG_PTR_TO_FUNC]               = &func_ptr_types,
+       [ARG_PTR_TO_STACK_OR_NULL]      = &stack_ptr_types,
 };
 
 static int check_reg_type(struct bpf_verifier_env *env, u32 regno,
                        verbose(env, "verifier internal error\n");
                        return -EFAULT;
                }
+       } else if (arg_type == ARG_PTR_TO_FUNC) {
+               meta->subprogno = reg->subprogno;
        } else if (arg_type_is_mem_ptr(arg_type)) {
                /* The access to this pointer is only checked when we hit the
                 * next is_mem_size argument below.
        return __check_func_call(env, insn, insn_idx, subprog, set_callee_state);
 }
 
+static int set_map_elem_callback_state(struct bpf_verifier_env *env,
+                                      struct bpf_func_state *caller,
+                                      struct bpf_func_state *callee,
+                                      int insn_idx)
+{
+       struct bpf_insn_aux_data *insn_aux = &env->insn_aux_data[insn_idx];
+       struct bpf_map *map;
+       int err;
+
+       if (bpf_map_ptr_poisoned(insn_aux)) {
+               verbose(env, "tail_call abusing map_ptr\n");
+               return -EINVAL;
+       }
+
+       map = BPF_MAP_PTR(insn_aux->map_ptr_state);
+       if (!map->ops->map_set_for_each_callback_args ||
+           !map->ops->map_for_each_callback) {
+               verbose(env, "callback function not allowed for map\n");
+               return -ENOTSUPP;
+       }
+
+       err = map->ops->map_set_for_each_callback_args(env, caller, callee);
+       if (err)
+               return err;
+
+       callee->in_callback_fn = true;
+       return 0;
+}
+
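The map_set_for_each_callback_args hook invoked above seeds the callee
frame to match the callback signature long (*cb)(struct bpf_map *map,
void *key, void *value, void *ctx). A rough sketch of the shape such a
hook could take; illustrative only, with register-state bookkeeping
(known-zero offsets etc.) omitted:

	static int example_set_callback_args(struct bpf_verifier_env *env,
					     struct bpf_func_state *caller,
					     struct bpf_func_state *callee)
	{
		/* R1: pass the map pointer through from the helper's R1 */
		callee->regs[BPF_REG_1] = caller->regs[BPF_REG_1];

		/* R2/R3: key and value pointers bounded by this map's
		 * key_size/value_size
		 */
		callee->regs[BPF_REG_2].type = PTR_TO_MAP_KEY;
		callee->regs[BPF_REG_2].map_ptr = caller->regs[BPF_REG_1].map_ptr;
		callee->regs[BPF_REG_3].type = PTR_TO_MAP_VALUE;
		callee->regs[BPF_REG_3].map_ptr = caller->regs[BPF_REG_1].map_ptr;

		/* R4: the caller-supplied callback_ctx (helper arg 3) */
		callee->regs[BPF_REG_4] = caller->regs[BPF_REG_3];

		/* R5: unused by the callback */
		__mark_reg_not_init(env, &callee->regs[BPF_REG_5]);
		return 0;
	}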
 static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
 {
        struct bpf_verifier_state *state = env->cur_state;
 
        state->curframe--;
        caller = state->frame[state->curframe];
-       /* return to the caller whatever r0 had in the callee */
-       caller->regs[BPF_REG_0] = *r0;
+       if (callee->in_callback_fn) {
+               /* enforce R0 return value range [0, 1]. */
+               struct tnum range = tnum_range(0, 1);
+
+               if (r0->type != SCALAR_VALUE) {
+                       verbose(env, "R0 not a scalar value\n");
+                       return -EACCES;
+               }
+               if (!tnum_in(range, r0->var_off)) {
+                       verbose_invalid_scalar(env, r0, &range, "callback return", "R0");
+                       return -EINVAL;
+               }
+       } else {
+               /* return to the caller whatever r0 had in the callee */
+               caller->regs[BPF_REG_0] = *r0;
+       }
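From the program side, the [0, 1] range enforced here means a callback
can only signal "continue" (0) or "stop" (1). A hypothetical callback
as a BPF C author might write it (key/value types depend on the map):

	static long check_elem(struct bpf_map *map, __u32 *key, __u64 *val,
			       void *ctx)
	{
		if (*val > 100)
			return 1;	/* stop iterating */
		return 0;		/* continue with the next element */
	}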
 
        /* Transfer references to the caller */
        err = transfer_reference_state(caller, callee);
            func_id != BPF_FUNC_map_delete_elem &&
            func_id != BPF_FUNC_map_push_elem &&
            func_id != BPF_FUNC_map_pop_elem &&
-           func_id != BPF_FUNC_map_peek_elem)
+           func_id != BPF_FUNC_map_peek_elem &&
+           func_id != BPF_FUNC_for_each_map_elem)
                return 0;
 
        if (map == NULL) {
        return state->acquired_refs ? -EINVAL : 0;
 }
 
-static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn_idx)
+static int check_helper_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
+                            int *insn_idx_p)
 {
        const struct bpf_func_proto *fn = NULL;
        struct bpf_reg_state *regs;
        struct bpf_call_arg_meta meta;
+       int insn_idx = *insn_idx_p;
        bool changes_data;
-       int i, err;
+       int i, err, func_id;
 
        /* find function prototype */
+       func_id = insn->imm;
        if (func_id < 0 || func_id >= __BPF_FUNC_MAX_ID) {
                verbose(env, "invalid func %s#%d\n", func_id_name(func_id),
                        func_id);
                return -EINVAL;
        }
 
+       if (func_id == BPF_FUNC_for_each_map_elem) {
+               err = __check_func_call(env, insn, insn_idx_p, meta.subprogno,
+                                       set_map_elem_callback_state);
+               if (err < 0)
+                       return -EINVAL;
+       }
+
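Note that meta.subprogno consumed here was recorded by the
ARG_PTR_TO_FUNC branch earlier, while the helper's arguments were being
checked, so the callee frame for the callback is set up before the
caller-saved registers are clobbered below.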
        /* reset caller saved regs */
        for (i = 0; i < CALLER_SAVED_REGS; i++) {
                mark_reg_not_init(env, regs, caller_saved[i]);
                else
                        *ptr_limit = -off;
                return 0;
+       case PTR_TO_MAP_KEY:
+               /* Currently, this code is not exercised as the only use
+                * is the bpf_for_each_map_elem() helper, which requires
+                * bpf_capable. The code has been tested manually for
+                * future use.
+                */
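+               /* Example: key_size = 16 with smin_value + off = 4 gives
+                * a masking limit of 16 - 4 = 12 bytes.
+                */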
+               if (mask_to_left) {
+                       *ptr_limit = ptr_reg->umax_value + ptr_reg->off;
+               } else {
+                       off = ptr_reg->smin_value + ptr_reg->off;
+                       *ptr_limit = ptr_reg->map_ptr->key_size - off;
+               }
+               return 0;
        case PTR_TO_MAP_VALUE:
                if (mask_to_left) {
                        *ptr_limit = ptr_reg->umax_value + ptr_reg->off;
                verbose(env, "R%d pointer arithmetic on %s prohibited\n",
                        dst, reg_type_str[ptr_reg->type]);
                return -EACCES;
+       case PTR_TO_MAP_KEY:
        case PTR_TO_MAP_VALUE:
                if (!env->allow_ptr_leaks && !known && (smin_val < 0) != (smax_val < 0)) {
                        verbose(env, "R%d has unknown scalar with mixed signed bounds, pointer arithmetic with it prohibited for !root\n",
                return 0;
        }
 
+       if (insn->src_reg == BPF_PSEUDO_FUNC) {
+               struct bpf_prog_aux *aux = env->prog->aux;
+               u32 subprogno = insn[1].imm;
+
+               if (!aux->func_info) {
+                       verbose(env, "missing btf func_info\n");
+                       return -EINVAL;
+               }
+               if (aux->func_info_aux[subprogno].linkage != BTF_FUNC_STATIC) {
+                       verbose(env, "callback function not static\n");
+                       return -EINVAL;
+               }
+
+               dst_reg->type = PTR_TO_FUNC;
+               dst_reg->subprogno = subprogno;
+               return 0;
+       }
+
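This is the pattern a compiler emits when a program passes a static
function to the helper. A hedged usage example, reusing the
hypothetical check_elem above (my_map is illustrative); note that
check_elem must be static to pass the linkage check:

	__u64 ctx = 0;

	/* the callback argument becomes the ld_imm64 handled above */
	bpf_for_each_map_elem(&my_map, check_elem, &ctx, 0);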
        map = env->used_maps[aux->map_index];
        mark_reg_known_zero(env, regs, insn->dst_reg);
        dst_reg->map_ptr = map;
        struct bpf_insn *insns = env->prog->insnsi;
        int ret;
 
+       if (bpf_pseudo_func(insns + t))
+               return visit_func_call_insn(t, insn_cnt, insns, env, true);
+
        /* All non-branch instructions have a single fall-through edge. */
        if (BPF_CLASS(insns[t].code) != BPF_JMP &&
            BPF_CLASS(insns[t].code) != BPF_JMP32)
                         */
                        return false;
                }
+       case PTR_TO_MAP_KEY:
        case PTR_TO_MAP_VALUE:
                /* If the new min/max/var_off satisfy the old ones and
                 * everything else matches, we are OK.
                                if (insn->src_reg == BPF_PSEUDO_CALL)
                                        err = check_func_call(env, insn, &env->insn_idx);
                                else
-                                       err = check_helper_call(env, insn->imm, env->insn_idx);
+                                       err = check_helper_call(env, insn, &env->insn_idx);
                                if (err)
                                        return err;
-
                        } else if (opcode == BPF_JA) {
                                if (BPF_SRC(insn->code) != BPF_K ||
                                    insn->imm != 0 ||
                                goto next_insn;
                        }
 
+                       if (insn[0].src_reg == BPF_PSEUDO_FUNC) {
+                               aux = &env->insn_aux_data[i];
+                               aux->ptr_type = PTR_TO_FUNC;
+                               goto next_insn;
+                       }
+
                        /* In final convert_pseudo_ld_imm64() step, this is
                         * converted into regular 64-bit imm load insn.
                         */
        int insn_cnt = env->prog->len;
        int i;
 
-       for (i = 0; i < insn_cnt; i++, insn++)
-               if (insn->code == (BPF_LD | BPF_IMM | BPF_DW))
-                       insn->src_reg = 0;
+       for (i = 0; i < insn_cnt; i++, insn++) {
+               if (insn->code != (BPF_LD | BPF_IMM | BPF_DW))
+                       continue;
+               if (insn->src_reg == BPF_PSEUDO_FUNC)
+                       continue;
+               insn->src_reg = 0;
+       }
 }
 
 /* single env->prog->insni[off] instruction was replaced with the range
                return 0;
 
        for (i = 0, insn = prog->insnsi; i < prog->len; i++, insn++) {
+               if (bpf_pseudo_func(insn)) {
+                       env->insn_aux_data[i].call_imm = insn->imm;
+                       /* subprog is encoded in insn[1].imm */
+                       continue;
+               }
+
                if (!bpf_pseudo_call(insn))
                        continue;
                /* Upon error here we cannot fall back to interpreter but
        for (i = 0; i < env->subprog_cnt; i++) {
                insn = func[i]->insnsi;
                for (j = 0; j < func[i]->len; j++, insn++) {
+                       if (bpf_pseudo_func(insn)) {
+                               subprog = insn[1].imm;
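+                               /* split the 64-bit func address across the
+                                * two imm slots, e.g. a hypothetical address
+                                * 0xffffffffc0123456 becomes
+                                * insn[0].imm = 0xc0123456,
+                                * insn[1].imm = 0xffffffff
+                                */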
+                               insn[0].imm = (u32)(long)func[subprog]->bpf_func;
+                               insn[1].imm = ((u64)(long)func[subprog]->bpf_func) >> 32;
+                               continue;
+                       }
                        if (!bpf_pseudo_call(insn))
                                continue;
                        subprog = insn->off;
         * later look the same as if they were interpreted only.
         */
        for (i = 0, insn = prog->insnsi; i < prog->len; i++, insn++) {
+               if (bpf_pseudo_func(insn)) {
+                       insn[0].imm = env->insn_aux_data[i].call_imm;
+                       insn[1].imm = find_subprog(env, i + insn[0].imm + 1);
+                       continue;
+               }
                if (!bpf_pseudo_call(insn))
                        continue;
                insn->off = env->insn_aux_data[i].call_imm;
                return -EINVAL;
        }
        for (i = 0; i < prog->len; i++, insn++) {
+               if (bpf_pseudo_func(insn)) {
+                       /* When the JIT fails, programs with callback calls
+                        * have to be rejected, since the interpreter doesn't
+                        * support them yet.
+                        */
+                       verbose(env, "callbacks are not allowed in non-JITed programs\n");
+                       return -EINVAL;
+               }
+
                if (!bpf_pseudo_call(insn))
                        continue;
                depth = get_callee_stack_depth(env, insn, i);