Commit 
b121b341e598 ("bpf: Add PTR_TO_BTF_ID_OR_NULL
support") adds a field btf_id_or_null_non0_off to
bpf_prog->aux structure to indicate that the
first ctx argument is PTR_TO_BTF_ID reg_type and
all others are PTR_TO_BTF_ID_OR_NULL.
This approach does not really scale if we have
other different reg types in the future, e.g.,
a pointer to a buffer.
This patch enables bpf_iter targets registering ctx argument
reg types which may be different from the default one.
For example, for pointers to structures, the default reg_type
is PTR_TO_BTF_ID for tracing program. The target can register
a particular pointer type as PTR_TO_BTF_ID_OR_NULL which can
be used by the verifier to enforce accesses.
Signed-off-by: Yonghong Song <yhs@fb.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Acked-by: Andrii Nakryiko <andriin@fb.com>
Link: https://lore.kernel.org/bpf/20200513180221.2949882-1-yhs@fb.com
        u16 reason;
 };
 
+/* reg_type info for ctx arguments */
+struct bpf_ctx_arg_aux {
+       u32 offset;
+       enum bpf_reg_type reg_type;
+};
+
 struct bpf_prog_aux {
        atomic64_t refcnt;
        u32 used_map_cnt;
        u32 func_cnt; /* used by non-func prog as the number of func progs */
        u32 func_idx; /* 0 for non-func prog, the index in func array for func prog */
        u32 attach_btf_id; /* in-kernel BTF type id to attach to */
+       u32 ctx_arg_info_size;
+       const struct bpf_ctx_arg_aux *ctx_arg_info;
        struct bpf_prog *linked_prog;
        bool verifier_zext; /* Zero extensions has been inserted by verifier. */
        bool offload_requested;
        bool attach_btf_trace; /* true if attaching to BTF-enabled raw tp */
        bool func_proto_unreliable;
-       bool btf_id_or_null_non0_off;
        enum bpf_tramp_prog_type trampoline_prog_type;
        struct bpf_trampoline *trampoline;
        struct hlist_node tramp_hlist;
 typedef int (*bpf_iter_init_seq_priv_t)(void *private_data);
 typedef void (*bpf_iter_fini_seq_priv_t)(void *private_data);
 
+#define BPF_ITER_CTX_ARG_MAX 2
 struct bpf_iter_reg {
        const char *target;
        const struct seq_operations *seq_ops;
        bpf_iter_init_seq_priv_t init_seq_private;
        bpf_iter_fini_seq_priv_t fini_seq_private;
        u32 seq_priv_size;
+       u32 ctx_arg_info_size;
+       struct bpf_ctx_arg_aux ctx_arg_info[BPF_ITER_CTX_ARG_MAX];
 };
 
 struct bpf_iter_meta {
 
        return !!(f6i->fib6_metrics->metrics[RTAX_LOCK - 1] & (1 << metric));
 }
 
+#if IS_BUILTIN(CONFIG_IPV6) && defined(CONFIG_BPF_SYSCALL)
+struct bpf_iter__ipv6_route {
+       __bpf_md_ptr(struct bpf_iter_meta *, meta);
+       __bpf_md_ptr(struct fib6_info *, rt);
+};
+#endif
+
 #ifdef CONFIG_IPV6_MULTIPLE_TABLES
 static inline bool fib6_has_custom_rules(const struct net *net)
 {
 
        }
        mutex_unlock(&targets_mutex);
 
+       if (supported) {
+               prog->aux->ctx_arg_info_size = tinfo->reg_info->ctx_arg_info_size;
+               prog->aux->ctx_arg_info = tinfo->reg_info->ctx_arg_info;
+       }
+
        return supported;
 }
 
 
        struct bpf_verifier_log *log = info->log;
        const struct btf_param *args;
        u32 nr_args, arg;
-       int ret;
+       int i, ret;
 
        if (off % 8) {
                bpf_log(log, "func '%s' offset %d is not multiple of 8\n",
                return true;
 
        /* this is a pointer to another type */
-       if (off != 0 && prog->aux->btf_id_or_null_non0_off)
-               info->reg_type = PTR_TO_BTF_ID_OR_NULL;
-       else
-               info->reg_type = PTR_TO_BTF_ID;
+       info->reg_type = PTR_TO_BTF_ID;
+       for (i = 0; i < prog->aux->ctx_arg_info_size; i++) {
+               const struct bpf_ctx_arg_aux *ctx_arg_info = &prog->aux->ctx_arg_info[i];
+
+               if (ctx_arg_info->offset == off) {
+                       info->reg_type = ctx_arg_info->reg_type;
+                       break;
+               }
+       }
 
        if (tgt_prog) {
                ret = btf_translate_to_vmlinux(log, btf, t, tgt_prog->type, arg);
 
        .init_seq_private       = NULL,
        .fini_seq_private       = NULL,
        .seq_priv_size          = sizeof(struct bpf_iter_seq_map_info),
+       .ctx_arg_info_size      = 1,
+       .ctx_arg_info           = {
+               { offsetof(struct bpf_iter__bpf_map, map),
+                 PTR_TO_BTF_ID_OR_NULL },
+       },
 };
 
 static int __init bpf_map_iter_init(void)
 
        .init_seq_private       = init_seq_pidns,
        .fini_seq_private       = fini_seq_pidns,
        .seq_priv_size          = sizeof(struct bpf_iter_seq_task_info),
+       .ctx_arg_info_size      = 1,
+       .ctx_arg_info           = {
+               { offsetof(struct bpf_iter__task, task),
+                 PTR_TO_BTF_ID_OR_NULL },
+       },
 };
 
 static const struct bpf_iter_reg task_file_reg_info = {
        .init_seq_private       = init_seq_pidns,
        .fini_seq_private       = fini_seq_pidns,
        .seq_priv_size          = sizeof(struct bpf_iter_seq_task_file_info),
+       .ctx_arg_info_size      = 2,
+       .ctx_arg_info           = {
+               { offsetof(struct bpf_iter__task_file, task),
+                 PTR_TO_BTF_ID_OR_NULL },
+               { offsetof(struct bpf_iter__task_file, file),
+                 PTR_TO_BTF_ID_OR_NULL },
+       },
 };
 
 static int __init task_iter_init(void)
 
                prog->aux->attach_func_proto = t;
                if (!bpf_iter_prog_supported(prog))
                        return -EINVAL;
-               prog->aux->btf_id_or_null_non0_off = true;
                ret = btf_distill_func_proto(&env->log, btf, t,
                                             tname, &fmodel);
                return ret;
 
 }
 
 #if IS_BUILTIN(CONFIG_IPV6) && defined(CONFIG_BPF_SYSCALL)
-struct bpf_iter__ipv6_route {
-       __bpf_md_ptr(struct bpf_iter_meta *, meta);
-       __bpf_md_ptr(struct fib6_info *, rt);
-};
-
 static int ipv6_route_prog_seq_show(struct bpf_prog *prog,
                                    struct bpf_iter_meta *meta,
                                    void *v)
 
        .init_seq_private       = bpf_iter_init_seq_net,
        .fini_seq_private       = bpf_iter_fini_seq_net,
        .seq_priv_size          = sizeof(struct ipv6_route_iter),
+       .ctx_arg_info_size      = 1,
+       .ctx_arg_info           = {
+               { offsetof(struct bpf_iter__ipv6_route, rt),
+                 PTR_TO_BTF_ID_OR_NULL },
+       },
 };
 
 static int __init bpf_iter_register(void)
 
        .init_seq_private       = bpf_iter_init_seq_net,
        .fini_seq_private       = bpf_iter_fini_seq_net,
        .seq_priv_size          = sizeof(struct nl_seq_iter),
+       .ctx_arg_info_size      = 1,
+       .ctx_arg_info           = {
+               { offsetof(struct bpf_iter__netlink, sk),
+                 PTR_TO_BTF_ID_OR_NULL },
+       },
 };
 
 static int __init bpf_iter_register(void)