btf_get_prog_ctx_type(struct bpf_verifier_log *log, const struct btf *btf,
                      const struct btf_type *t, enum bpf_prog_type prog_type,
                      int arg);
+int get_kern_ctx_btf_id(struct bpf_verifier_log *log, enum bpf_prog_type prog_type);
 bool btf_types_are_same(const struct btf *btf1, u32 id1,
                        const struct btf *btf2, u32 id2);
 #else
 {
        return NULL;
 }
+static inline int get_kern_ctx_btf_id(struct bpf_verifier_log *log,
+                                     enum bpf_prog_type prog_type) {
+       return -EINVAL;
+}
 static inline bool btf_types_are_same(const struct btf *btf1, u32 id1,
                                      const struct btf *btf2, u32 id2)
 {
 
        return kern_ctx_type->type;
 }
 
+int get_kern_ctx_btf_id(struct bpf_verifier_log *log, enum bpf_prog_type prog_type)
+{
+       const struct btf_member *kctx_member;
+       const struct btf_type *conv_struct;
+       const struct btf_type *kctx_type;
+       u32 kctx_type_id;
+
+       conv_struct = bpf_ctx_convert.t;
+       /* get member for kernel ctx type */
+       kctx_member = btf_type_member(conv_struct) + bpf_ctx_convert_map[prog_type] * 2 + 1;
+       kctx_type_id = kctx_member->type;
+       kctx_type = btf_type_by_id(btf_vmlinux, kctx_type_id);
+       if (!btf_type_is_struct(kctx_type)) {
+               bpf_log(log, "kern ctx type id %u is not a struct\n", kctx_type_id);
+               return -EINVAL;
+       }
+
+       return kctx_type_id;
+}
+
 BTF_ID_LIST(bpf_ctx_convert_btf_id)
 BTF_ID(struct, bpf_ctx_convert)
 
 
        u32 ref_obj_id;
        u8 release_regno;
        bool r0_rdonly;
+       u32 ret_btf_id;
        u64 r0_size;
        struct {
                u64 value;
        KF_bpf_list_push_back,
        KF_bpf_list_pop_front,
        KF_bpf_list_pop_back,
+       KF_bpf_cast_to_kern_ctx,
 };
 
 BTF_SET_START(special_kfunc_set)
 BTF_ID(func, bpf_list_push_back)
 BTF_ID(func, bpf_list_pop_front)
 BTF_ID(func, bpf_list_pop_back)
+BTF_ID(func, bpf_cast_to_kern_ctx)
 BTF_SET_END(special_kfunc_set)
 
 BTF_ID_LIST(special_kfunc_list)
 BTF_ID(func, bpf_list_push_back)
 BTF_ID(func, bpf_list_pop_front)
 BTF_ID(func, bpf_list_pop_back)
+BTF_ID(func, bpf_cast_to_kern_ctx)
 
 static enum kfunc_ptr_arg_type
 get_kfunc_ptr_arg_type(struct bpf_verifier_env *env,
        struct bpf_reg_state *reg = ®s[regno];
        bool arg_mem_size = false;
 
+       if (meta->func_id == special_kfunc_list[KF_bpf_cast_to_kern_ctx])
+               return KF_ARG_PTR_TO_CTX;
+
        /* In this function, we verify the kfunc's BTF as per the argument type,
         * leaving the rest of the verification with respect to the register
         * type to our caller. When a set of conditions hold in the BTF type of
                                verbose(env, "arg#%d expected pointer to ctx, but got %s\n", i, btf_type_str(t));
                                return -EINVAL;
                        }
+
+                       if (meta->func_id == special_kfunc_list[KF_bpf_cast_to_kern_ctx]) {
+                               ret = get_kern_ctx_btf_id(&env->log, resolve_prog_type(env->prog));
+                               if (ret < 0)
+                                       return -EINVAL;
+                               meta->ret_btf_id  = ret;
+                       }
                        break;
                case KF_ARG_PTR_TO_ALLOC_BTF_ID:
                        if (reg->type != (PTR_TO_BTF_ID | MEM_ALLOC)) {
                                regs[BPF_REG_0].btf = field->list_head.btf;
                                regs[BPF_REG_0].btf_id = field->list_head.value_btf_id;
                                regs[BPF_REG_0].off = field->list_head.node_offset;
+                       } else if (meta.func_id == special_kfunc_list[KF_bpf_cast_to_kern_ctx]) {
+                               mark_reg_known_zero(env, regs, BPF_REG_0);
+                               regs[BPF_REG_0].type = PTR_TO_BTF_ID | PTR_TRUSTED;
+                               regs[BPF_REG_0].btf = desc_btf;
+                               regs[BPF_REG_0].btf_id = meta.ret_btf_id;
                        } else {
                                verbose(env, "kernel function %s unhandled dynamic return type\n",
                                        meta.func_name);
                insn_buf[1] = addr[1];
                insn_buf[2] = *insn;
                *cnt = 3;
+       } else if (desc->func_id == special_kfunc_list[KF_bpf_cast_to_kern_ctx]) {
+               insn_buf[0] = BPF_MOV64_REG(BPF_REG_0, BPF_REG_1);
+               *cnt = 1;
        }
        return 0;
 }