u32 jmp_history_cnt;
 };
 
-#define bpf_get_spilled_reg(slot, frame)                               \
+#define bpf_get_spilled_reg(slot, frame, mask)                         \
        (((slot < frame->allocated_stack / BPF_REG_SIZE) &&             \
-         (frame->stack[slot].slot_type[0] == STACK_SPILL))             \
+         ((1 << frame->stack[slot].slot_type[0]) & (mask)))             \
         ? &frame->stack[slot].spilled_ptr : NULL)
 
 /* Iterate over 'frame', setting 'reg' to either NULL or a spilled register. */
-#define bpf_for_each_spilled_reg(iter, frame, reg)                     \
-       for (iter = 0, reg = bpf_get_spilled_reg(iter, frame);          \
+#define bpf_for_each_spilled_reg(iter, frame, reg, mask)                       \
+       for (iter = 0, reg = bpf_get_spilled_reg(iter, frame, mask);            \
             iter < frame->allocated_stack / BPF_REG_SIZE;              \
-            iter++, reg = bpf_get_spilled_reg(iter, frame))
+            iter++, reg = bpf_get_spilled_reg(iter, frame, mask))
 
-/* Invoke __expr over regsiters in __vst, setting __state and __reg */
-#define bpf_for_each_reg_in_vstate(__vst, __state, __reg, __expr)   \
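+/* Invoke __expr over all registers in __vst, plus the spilled-register
+ * state of any stack slot whose slot_type bit is set in __mask.
+ */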
+#define bpf_for_each_reg_in_vstate_mask(__vst, __state, __reg, __mask, __expr)   \
        ({                                                               \
                struct bpf_verifier_state *___vstate = __vst;            \
                int ___i, ___j;                                          \
                for (___i = 0; ___i <= ___vstate->curframe; ___i++) {   \
                        struct bpf_reg_state *___regs;                   \
                        __state = ___vstate->frame[___i];                \
                        ___regs = __state->regs;                         \
                        for (___j = 0; ___j < MAX_BPF_REG; ___j++) {     \
                                __reg = &___regs[___j];                  \
                                (void)(__expr);                          \
                        }                                                \
-                       bpf_for_each_spilled_reg(___j, __state, __reg) { \
+                       bpf_for_each_spilled_reg(___j, __state, __reg, __mask) { \
                                if (!__reg)                              \
                                        continue;                        \
                                (void)(__expr);                          \
                        }                                                \
                }                                                        \
        })
 
+/* Invoke __expr over registers in __vst, setting __state and __reg */
+#define bpf_for_each_reg_in_vstate(__vst, __state, __reg, __expr) \
+       bpf_for_each_reg_in_vstate_mask(__vst, __state, __reg, 1 << STACK_SPILL, __expr)
+
 /* linked list of verifier states used to prune search */
 struct bpf_verifier_state_list {
        struct bpf_verifier_state state;
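
The new __mask parameter generalizes the walk: every frame's registers are
always visited, while stack slots are visited only when their slot_type bit is
set in the mask. A minimal sketch of a verifier-side caller (hypothetical
helper name; the body mirrors the rcu_unlock hunk at the end of this patch):

        static void demote_rcu_pointers(struct bpf_verifier_env *env)
        {
                struct bpf_func_state *state;
                struct bpf_reg_state *reg;
                /* visit spilled registers and iterator state slots alike */
                u32 mask = (1 << STACK_SPILL) | (1 << STACK_ITER);

                bpf_for_each_reg_in_vstate_mask(env->cur_state, state, reg, mask, ({
                        if (reg->type & MEM_RCU) {
                                reg->type &= ~(MEM_RCU | PTR_MAYBE_NULL);
                                reg->type |= PTR_UNTRUSTED;
                        }
                }));
        }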
 
 BTF_ID_FLAGS(func, bpf_iter_css_task_new, KF_ITER_NEW | KF_TRUSTED_ARGS)
 BTF_ID_FLAGS(func, bpf_iter_css_task_next, KF_ITER_NEXT | KF_RET_NULL)
 BTF_ID_FLAGS(func, bpf_iter_css_task_destroy, KF_ITER_DESTROY)
-BTF_ID_FLAGS(func, bpf_iter_task_new, KF_ITER_NEW | KF_TRUSTED_ARGS)
+BTF_ID_FLAGS(func, bpf_iter_task_new, KF_ITER_NEW | KF_TRUSTED_ARGS | KF_RCU_PROTECTED)
 BTF_ID_FLAGS(func, bpf_iter_task_next, KF_ITER_NEXT | KF_RET_NULL)
 BTF_ID_FLAGS(func, bpf_iter_task_destroy, KF_ITER_DESTROY)
-BTF_ID_FLAGS(func, bpf_iter_css_new, KF_ITER_NEW | KF_TRUSTED_ARGS)
+BTF_ID_FLAGS(func, bpf_iter_css_new, KF_ITER_NEW | KF_TRUSTED_ARGS | KF_RCU_PROTECTED)
 BTF_ID_FLAGS(func, bpf_iter_css_next, KF_ITER_NEXT | KF_RET_NULL)
 BTF_ID_FLAGS(func, bpf_iter_css_destroy, KF_ITER_DESTROY)
 BTF_ID_FLAGS(func, bpf_dynptr_adjust)
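
Marking bpf_iter_task_new and bpf_iter_css_new KF_RCU_PROTECTED means the
verifier only trusts these iterators while inside an RCU read-side critical
section. A minimal BPF-side sketch, assuming the open-coded task iterator
kfuncs and the BPF_TASK_ITER_ALL_PROCS flag from this series (the attach
point is illustrative):

        SEC("fentry/vfs_read")
        int iter_all_procs(void *ctx)
        {
                struct task_struct *task = bpf_get_current_task_btf();
                struct task_struct *cur;
                struct bpf_iter_task it;

                bpf_rcu_read_lock();
                bpf_iter_task_new(&it, task, BPF_TASK_ITER_ALL_PROCS);
                while ((cur = bpf_iter_task_next(&it))) {
                        /* cur is RCU-protected (MEM_RCU) here */
                }
                bpf_iter_task_destroy(&it);
                bpf_rcu_read_unlock();
                return 0;
        }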
 
 
 static void __mark_reg_known_zero(struct bpf_reg_state *reg);
 
+static bool in_rcu_cs(struct bpf_verifier_env *env);
+
+static bool is_kfunc_rcu_protected(struct bpf_kfunc_call_arg_meta *meta);
+
 static int mark_stack_slots_iter(struct bpf_verifier_env *env,
+                                struct bpf_kfunc_call_arg_meta *meta,
                                 struct bpf_reg_state *reg, int insn_idx,
                                 struct btf *btf, u32 btf_id, int nr_slots)
 {
 
                __mark_reg_known_zero(st);
                st->type = PTR_TO_STACK; /* we don't have dedicated reg type */
+               if (is_kfunc_rcu_protected(meta)) {
+                       if (in_rcu_cs(env))
+                               st->type |= MEM_RCU;
+                       else
+                               st->type |= PTR_UNTRUSTED;
+               }
                st->live |= REG_LIVE_WRITTEN;
                st->ref_obj_id = i == 0 ? id : 0;
                st->iter.btf = btf;
        return true;
 }
 
-static bool is_iter_reg_valid_init(struct bpf_verifier_env *env, struct bpf_reg_state *reg,
+static int is_iter_reg_valid_init(struct bpf_verifier_env *env, struct bpf_reg_state *reg,
                                   struct btf *btf, u32 btf_id, int nr_slots)
 {
        struct bpf_func_state *state = func(env, reg);
        int spi, i, j;

        spi = iter_get_spi(env, reg, nr_slots);
        if (spi < 0)
-               return false;
+               return -EINVAL;
 
        for (i = 0; i < nr_slots; i++) {
                struct bpf_stack_state *slot = &state->stack[spi - i];
                struct bpf_reg_state *st = &slot->spilled_ptr;
 
+               if (st->type & PTR_UNTRUSTED)
+                       return -EPROTO;
                /* only main (first) slot has ref_obj_id set */
                if (i == 0 && !st->ref_obj_id)
-                       return false;
+                       return -EINVAL;
                if (i != 0 && st->ref_obj_id)
-                       return false;
+                       return -EINVAL;
                if (st->iter.btf != btf || st->iter.btf_id != btf_id)
-                       return false;
+                       return -EINVAL;
 
                for (j = 0; j < BPF_REG_SIZE; j++)
                        if (slot->slot_type[j] != STACK_ITER)
-                               return false;
+                               return -EINVAL;
        }
 
-       return true;
+       return 0;
 }
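
Returning an int rather than a bool lets the caller tell a malformed or
uninitialized iterator (-EINVAL) apart from one whose slots were tagged
PTR_UNTRUSTED because the RCU critical section was missing or already closed
(-EPROTO). A sketch of BPF code the verifier rejects via the new -EPROTO path
(fragment; SEC() and includes omitted):

        struct task_struct *task = bpf_get_current_task_btf();
        struct bpf_iter_task it;

        /* No bpf_rcu_read_lock() here: mark_stack_slots_iter() tags the
         * iterator slots PTR_UNTRUSTED at bpf_iter_task_new() time, so the
         * next call fails verification with
         * "expected an RCU CS when using bpf_iter_task_next".
         */
        bpf_iter_task_new(&it, task, BPF_TASK_ITER_ALL_PROCS);
        bpf_iter_task_next(&it);
        bpf_iter_task_destroy(&it);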
 
 /* Check if given stack slot is "special":
                                return err;
                }
 
-               err = mark_stack_slots_iter(env, reg, insn_idx, meta->btf, btf_id, nr_slots);
+               err = mark_stack_slots_iter(env, meta, reg, insn_idx, meta->btf, btf_id, nr_slots);
                if (err)
                        return err;
        } else {
                /* iter_next() or iter_destroy() expect initialized iter state */
-               if (!is_iter_reg_valid_init(env, reg, meta->btf, btf_id, nr_slots)) {
+               err = is_iter_reg_valid_init(env, reg, meta->btf, btf_id, nr_slots);
+               switch (err) {
+               case 0:
+                       break;
+               case -EINVAL:
                        verbose(env, "expected an initialized iter_%s as arg #%d\n",
                                iter_type_str(meta->btf, btf_id), regno);
-                       return -EINVAL;
+                       return err;
+               case -EPROTO:
+                       verbose(env, "expected an RCU CS when using %s\n", meta->func_name);
+                       return err;
+               default:
+                       return err;
                }
 
                spi = iter_get_spi(env, reg, nr_slots);
        return meta->kfunc_flags & KF_RCU;
 }
 
+static bool is_kfunc_rcu_protected(struct bpf_kfunc_call_arg_meta *meta)
+{
+       return meta->kfunc_flags & KF_RCU_PROTECTED;
+}
+
 static bool __kfunc_param_match_suffix(const struct btf *btf,
                                       const struct btf_param *arg,
                                       const char *suffix)
        if (env->cur_state->active_rcu_lock) {
                struct bpf_func_state *state;
                struct bpf_reg_state *reg;
+               u32 clear_mask = (1 << STACK_SPILL) | (1 << STACK_ITER);
 
                if (in_rbtree_lock_required_cb(env) && (rcu_lock || rcu_unlock)) {
                        verbose(env, "Calling bpf_rcu_read_{lock,unlock} in unnecessary rbtree callback\n");
                        return -EACCES;
                }

                if (rcu_lock) {
                        verbose(env, "nested rcu read lock (kernel function %s)\n", func_name);
                        return -EINVAL;
                } else if (rcu_unlock) {
-                       bpf_for_each_reg_in_vstate(env->cur_state, state, reg, ({
+                       bpf_for_each_reg_in_vstate_mask(env->cur_state, state, reg, clear_mask, ({
                                if (reg->type & MEM_RCU) {
                                        reg->type &= ~(MEM_RCU | PTR_MAYBE_NULL);
                                        reg->type |= PTR_UNTRUSTED;