union bpf_attr;
 struct btf_show;
 struct btf_id_set;
+struct bpf_prog;
+
+typedef int (*btf_kfunc_filter_t)(const struct bpf_prog *prog, u32 kfunc_id);
 
 struct btf_kfunc_id_set {
        struct module *owner;
        struct btf_id_set8 *set;
+       btf_kfunc_filter_t filter;
 };
 
 struct btf_id_dtor_kfunc {
        return bsearch(&id, set->pairs, set->cnt, sizeof(set->pairs[0]), btf_id_cmp_func);
 }
 
-struct bpf_prog;
 struct bpf_verifier_log;
 
 #ifdef CONFIG_BPF_SYSCALL
 const char *btf_name_by_offset(const struct btf *btf, u32 offset);
 struct btf *btf_parse_vmlinux(void);
 struct btf *bpf_prog_get_target_btf(const struct bpf_prog *prog);
-u32 *btf_kfunc_id_set_contains(const struct btf *btf,
-                              enum bpf_prog_type prog_type,
-                              u32 kfunc_btf_id);
-u32 *btf_kfunc_is_modify_return(const struct btf *btf, u32 kfunc_btf_id);
+u32 *btf_kfunc_id_set_contains(const struct btf *btf, u32 kfunc_btf_id,
+                              const struct bpf_prog *prog);
+u32 *btf_kfunc_is_modify_return(const struct btf *btf, u32 kfunc_btf_id,
+                               const struct bpf_prog *prog);
 int register_btf_kfunc_id_set(enum bpf_prog_type prog_type,
                              const struct btf_kfunc_id_set *s);
 int register_btf_fmodret_id_set(const struct btf_kfunc_id_set *kset);
        return NULL;
 }
 static inline u32 *btf_kfunc_id_set_contains(const struct btf *btf,
-                                            enum bpf_prog_type prog_type,
-                                            u32 kfunc_btf_id)
+                                            u32 kfunc_btf_id,
+                                            struct bpf_prog *prog)
+
 {
        return NULL;
 }
 
 enum {
        BTF_KFUNC_SET_MAX_CNT = 256,
        BTF_DTOR_KFUNC_MAX_CNT = 256,
+       BTF_KFUNC_FILTER_MAX_CNT = 16,
+};
+
+struct btf_kfunc_hook_filter {
+       btf_kfunc_filter_t filters[BTF_KFUNC_FILTER_MAX_CNT];
+       u32 nr_filters;
 };
 
 struct btf_kfunc_set_tab {
        struct btf_id_set8 *sets[BTF_KFUNC_HOOK_MAX];
+       struct btf_kfunc_hook_filter hook_filters[BTF_KFUNC_HOOK_MAX];
 };
 
 struct btf_id_dtor_kfunc_tab {
 /* Kernel Function (kfunc) BTF ID set registration API */
 
 static int btf_populate_kfunc_set(struct btf *btf, enum btf_kfunc_hook hook,
-                                 struct btf_id_set8 *add_set)
+                                 const struct btf_kfunc_id_set *kset)
 {
+       struct btf_kfunc_hook_filter *hook_filter;
+       struct btf_id_set8 *add_set = kset->set;
        bool vmlinux_set = !btf_is_module(btf);
+       bool add_filter = !!kset->filter;
        struct btf_kfunc_set_tab *tab;
        struct btf_id_set8 *set;
        u32 set_cnt;
                return 0;
 
        tab = btf->kfunc_set_tab;
+
+       if (tab && add_filter) {
+               u32 i;
+
+               hook_filter = &tab->hook_filters[hook];
+               for (i = 0; i < hook_filter->nr_filters; i++) {
+                       if (hook_filter->filters[i] == kset->filter) {
+                               add_filter = false;
+                               break;
+                       }
+               }
+
+               if (add_filter && hook_filter->nr_filters == BTF_KFUNC_FILTER_MAX_CNT) {
+                       ret = -E2BIG;
+                       goto end;
+               }
+       }
+
        if (!tab) {
                tab = kzalloc(sizeof(*tab), GFP_KERNEL | __GFP_NOWARN);
                if (!tab)
         */
        if (!vmlinux_set) {
                tab->sets[hook] = add_set;
-               return 0;
+               goto do_add_filter;
        }
 
        /* In case of vmlinux sets, there may be more than one set being
 
        sort(set->pairs, set->cnt, sizeof(set->pairs[0]), btf_id_cmp_func, NULL);
 
+do_add_filter:
+       if (add_filter) {
+               hook_filter = &tab->hook_filters[hook];
+               hook_filter->filters[hook_filter->nr_filters++] = kset->filter;
+       }
        return 0;
 end:
        btf_free_kfunc_set_tab(btf);
 
 static u32 *__btf_kfunc_id_set_contains(const struct btf *btf,
                                        enum btf_kfunc_hook hook,
-                                       u32 kfunc_btf_id)
+                                       u32 kfunc_btf_id,
+                                       const struct bpf_prog *prog)
 {
+       struct btf_kfunc_hook_filter *hook_filter;
        struct btf_id_set8 *set;
-       u32 *id;
+       u32 *id, i;
 
        if (hook >= BTF_KFUNC_HOOK_MAX)
                return NULL;
        if (!btf->kfunc_set_tab)
                return NULL;
+       hook_filter = &btf->kfunc_set_tab->hook_filters[hook];
+       for (i = 0; i < hook_filter->nr_filters; i++) {
+               if (hook_filter->filters[i](prog, kfunc_btf_id))
+                       return NULL;
+       }
        set = btf->kfunc_set_tab->sets[hook];
        if (!set)
                return NULL;
  * protection for looking up a well-formed btf->kfunc_set_tab.
  */
 u32 *btf_kfunc_id_set_contains(const struct btf *btf,
-                              enum bpf_prog_type prog_type,
-                              u32 kfunc_btf_id)
+                              u32 kfunc_btf_id,
+                              const struct bpf_prog *prog)
 {
+       enum bpf_prog_type prog_type = resolve_prog_type(prog);
        enum btf_kfunc_hook hook;
        u32 *kfunc_flags;
 
-       kfunc_flags = __btf_kfunc_id_set_contains(btf, BTF_KFUNC_HOOK_COMMON, kfunc_btf_id);
+       kfunc_flags = __btf_kfunc_id_set_contains(btf, BTF_KFUNC_HOOK_COMMON, kfunc_btf_id, prog);
        if (kfunc_flags)
                return kfunc_flags;
 
        hook = bpf_prog_type_to_kfunc_hook(prog_type);
-       return __btf_kfunc_id_set_contains(btf, hook, kfunc_btf_id);
+       return __btf_kfunc_id_set_contains(btf, hook, kfunc_btf_id, prog);
 }
 
-u32 *btf_kfunc_is_modify_return(const struct btf *btf, u32 kfunc_btf_id)
+u32 *btf_kfunc_is_modify_return(const struct btf *btf, u32 kfunc_btf_id,
+                               const struct bpf_prog *prog)
 {
-       return __btf_kfunc_id_set_contains(btf, BTF_KFUNC_HOOK_FMODRET, kfunc_btf_id);
+       return __btf_kfunc_id_set_contains(btf, BTF_KFUNC_HOOK_FMODRET, kfunc_btf_id, prog);
 }
 
 static int __register_btf_kfunc_id_set(enum btf_kfunc_hook hook,
                        goto err_out;
        }
 
-       ret = btf_populate_kfunc_set(btf, hook, kset->set);
+       ret = btf_populate_kfunc_set(btf, hook, kset);
+
 err_out:
        btf_put(btf);
        return ret;
 
                *kfunc_name = func_name;
        func_proto = btf_type_by_id(desc_btf, func->type);
 
-       kfunc_flags = btf_kfunc_id_set_contains(desc_btf, resolve_prog_type(env->prog), func_id);
+       kfunc_flags = btf_kfunc_id_set_contains(desc_btf, func_id, env->prog);
        if (!kfunc_flags) {
                return -EACCES;
        }
                                 * in the fmodret id set with the KF_SLEEPABLE flag.
                                 */
                                else {
-                                       u32 *flags = btf_kfunc_is_modify_return(btf, btf_id);
+                                       u32 *flags = btf_kfunc_is_modify_return(btf, btf_id,
+                                                                               prog);
 
                                        if (flags && (*flags & KF_SLEEPABLE))
                                                ret = 0;
                                return -EINVAL;
                        }
                        ret = -EINVAL;
-                       if (btf_kfunc_is_modify_return(btf, btf_id) ||
+                       if (btf_kfunc_is_modify_return(btf, btf_id, prog) ||
                            !check_attach_modify_return(addr, tname))
                                ret = 0;
                        if (ret) {