#define SOCK_CREATE_FLAG_MASK \
        (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
 
-struct bpf_stab {
-       struct bpf_map map;
-       struct sock **sock_map;
+/* Programs attached to a sockmap, split out of struct bpf_stab so the
+ * attach/inherit logic can operate on the progs independently of the
+ * backing map layout.
+ */
+struct bpf_sock_progs {
        struct bpf_prog *bpf_tx_msg;
        struct bpf_prog *bpf_parse;
        struct bpf_prog *bpf_verdict;
 };
 
+/* A sockmap: generic bpf_map header, the socket slot array, and the
+ * programs inherited by sockets added to the map.
+ */
+struct bpf_stab {
+       struct bpf_map map;
+       struct sock **sock_map;
+       struct bpf_sock_progs progs;
+};
+
 enum smap_psock_state {
        SMAP_TX_RUNNING,
 };
 static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
 {
        return ((_rc == SK_PASS) ?
-              (md->map ? __SK_REDIRECT : __SK_PASS) :
+              (md->sk_redir ? __SK_REDIRECT : __SK_PASS) : /* non-NULL sk_redir selects redirect over plain pass */
               __SK_DROP);
 }
 
         * when we orphan the skb so that we don't have the possibility
         * to reference a stale map.
         */
-       TCP_SKB_CB(skb)->bpf.map = NULL;
+       TCP_SKB_CB(skb)->bpf.sk_redir = NULL;
        skb->sk = psock->sock;
        bpf_compute_data_pointers(skb);
        preempt_disable();
 
        /* Moving return codes from UAPI namespace into internal namespace */
        return rc == SK_PASS ?
-               (TCP_SKB_CB(skb)->bpf.map ? __SK_REDIRECT : __SK_PASS) :
+               (TCP_SKB_CB(skb)->bpf.sk_redir ? __SK_REDIRECT : __SK_PASS) :
                __SK_DROP;
 }
 
 }
 
 static void smap_init_progs(struct smap_psock *psock,
-                           struct bpf_stab *stab,
                            struct bpf_prog *verdict,
                            struct bpf_prog *parse)
 {
        kfree(psock);
 }
 
-static struct smap_psock *smap_init_psock(struct sock *sock,
-                                         struct bpf_stab *stab)
+static struct smap_psock *smap_init_psock(struct sock *sock, int node)
 {
        struct smap_psock *psock;
 
        psock = kzalloc_node(sizeof(struct smap_psock),
                             GFP_ATOMIC | __GFP_NOWARN,
-                            stab->map.numa_node);
+                            node);
        if (!psock)
                return ERR_PTR(-ENOMEM);
 
  *  - sock_map must use READ_ONCE and (cmp)xchg operations
  *  - BPF verdict/parse programs must use READ_ONCE and xchg operations
  */
-static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
-                                   struct bpf_map *map,
-                                   void *key, u64 flags)
+
+/* Map-type-agnostic half of the update path: inherit @progs into @sock's
+ * psock and, when @map_link is non-NULL, record the slot so the psock can
+ * be unlinked when the slot changes. Flag and bounds validation stays with
+ * the map-specific caller. Returns 0 on success or a negative errno.
+ */
+static int __sock_map_ctx_update_elem(struct bpf_map *map,
+                                     struct bpf_sock_progs *progs,
+                                     struct sock *sock,
+                                     struct sock **map_link,
+                                     void *key)
 {
-       struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
-       struct smap_psock_map_entry *e = NULL;
        struct bpf_prog *verdict, *parse, *tx_msg;
-       struct sock *osock, *sock;
+       struct smap_psock_map_entry *e = NULL;
        struct smap_psock *psock;
-       u32 i = *(u32 *)key;
        bool new = false;
        int err;
 
-       if (unlikely(flags > BPF_EXIST))
-               return -EINVAL;
-
-       if (unlikely(i >= stab->map.max_entries))
-               return -E2BIG;
-
-       sock = READ_ONCE(stab->sock_map[i]);
-       if (flags == BPF_EXIST && !sock)
-               return -ENOENT;
-       else if (flags == BPF_NOEXIST && sock)
-               return -EEXIST;
-
-       sock = skops->sk;
-
        /* 1. If sock map has BPF programs those will be inherited by the
         * sock being added. If the sock is already attached to BPF programs
         * this results in an error.
         */
-       verdict = READ_ONCE(stab->bpf_verdict);
-       parse = READ_ONCE(stab->bpf_parse);
-       tx_msg = READ_ONCE(stab->bpf_tx_msg);
+       verdict = READ_ONCE(progs->bpf_verdict);
+       parse = READ_ONCE(progs->bpf_parse);
+       tx_msg = READ_ONCE(progs->bpf_tx_msg);
 
        if (parse && verdict) {
                /* bpf prog refcnt may be zero if a concurrent attach operation
                 * we increment the refcnt. If this is the case abort with an
                 * error.
                 */
-               verdict = bpf_prog_inc_not_zero(stab->bpf_verdict);
+               verdict = bpf_prog_inc_not_zero(progs->bpf_verdict);
                if (IS_ERR(verdict))
                        return PTR_ERR(verdict);
 
-               parse = bpf_prog_inc_not_zero(stab->bpf_parse);
+               parse = bpf_prog_inc_not_zero(progs->bpf_parse);
                if (IS_ERR(parse)) {
                        bpf_prog_put(verdict);
                        return PTR_ERR(parse);
        }
 
        if (tx_msg) {
-               tx_msg = bpf_prog_inc_not_zero(stab->bpf_tx_msg);
+               tx_msg = bpf_prog_inc_not_zero(progs->bpf_tx_msg);
                if (IS_ERR(tx_msg)) {
                        if (verdict)
                                bpf_prog_put(verdict);
                        goto out_progs;
                }
        } else {
-               psock = smap_init_psock(sock, stab);
+               psock = smap_init_psock(sock, map->numa_node); /* NUMA node now comes from the generic map header */
                if (IS_ERR(psock)) {
                        err = PTR_ERR(psock);
                        goto out_progs;
                err = -ENOMEM;
                goto out_progs;
        }
-       e->entry = &stab->sock_map[i];
 
        /* 3. At this point we have a reference to a valid psock that is
         * running. Attach any BPF programs needed.
                err = smap_init_sock(psock, sock);
                if (err)
                        goto out_free;
-               smap_init_progs(psock, stab, verdict, parse);
+               smap_init_progs(psock, verdict, parse);
                smap_start_sock(psock, sock);
        }
 
         * it with. Because we can only have a single set of programs if
         * old_sock has a strp we can stop it.
         */
+       /* Only link the psock back to a map slot when the caller supplied
+        * one; hash-style callers may manage links differently.
+        */
        if (map_link) {
+               e->entry = map_link;
+               list_add_tail(&e->list, &psock->maps);
        }
-       return 0;
+       write_unlock_bh(&sock->sk_callback_lock);
+       return err; /* NOTE(review): confirm err is assigned on every path reaching here — not all assignments are visible in this hunk */
 out_free:
        smap_release_sock(psock, sock);
 out_progs:
        return err;
 }
 
-int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
+/* Install skops->sk into slot @key of a SOCKMAP. Validates the
+ * BPF_ANY/BPF_NOEXIST/BPF_EXIST flags and the index bound here, then
+ * delegates psock setup and program inheritance to
+ * __sock_map_ctx_update_elem(). Returns 0 on success or a negative errno.
+ */
+static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
+                                   struct bpf_map *map,
+                                   void *key, u64 flags)
 {
        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+       struct bpf_sock_progs *progs = &stab->progs;
+       struct sock *osock, *sock;
+       u32 i = *(u32 *)key;
+       int err;
+
+       if (unlikely(flags > BPF_EXIST))
+               return -EINVAL;
+
+       if (unlikely(i >= stab->map.max_entries))
+               return -E2BIG;
+
+       sock = READ_ONCE(stab->sock_map[i]);
+       if (flags == BPF_EXIST && !sock)
+               return -ENOENT;
+       else if (flags == BPF_NOEXIST && sock)
+               return -EEXIST;
+
+       sock = skops->sk;
+       err = __sock_map_ctx_update_elem(map, progs, sock, &stab->sock_map[i],
+                                        key);
+       if (err)
+               goto out;
+
+       osock = xchg(&stab->sock_map[i], sock);
+       if (osock) {
+               struct smap_psock *opsock = smap_psock_sk(osock);
+
+               write_lock_bh(&osock->sk_callback_lock);
+               smap_list_remove(opsock, &stab->sock_map[i]);
+               smap_release_sock(opsock, osock);
+               write_unlock_bh(&osock->sk_callback_lock);
+       }
+out:
+       /* Propagate failure from __sock_map_ctx_update_elem(); returning 0
+        * unconditionally here would report success for a socket that was
+        * never installed in the map.
+        */
+       return err;
+}
+
+int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
+{
+       struct bpf_sock_progs *progs;
        struct bpf_prog *orig;
 
-       if (unlikely(map->map_type != BPF_MAP_TYPE_SOCKMAP))
+       if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
+               struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+
+               progs = &stab->progs;
+       } else {
                return -EINVAL;
+       }
 
        switch (type) {
        case BPF_SK_MSG_VERDICT:
-               orig = xchg(&stab->bpf_tx_msg, prog);
+               orig = xchg(&progs->bpf_tx_msg, prog);
                break;
        case BPF_SK_SKB_STREAM_PARSER:
-               orig = xchg(&stab->bpf_parse, prog);
+               orig = xchg(&progs->bpf_parse, prog);
                break;
        case BPF_SK_SKB_STREAM_VERDICT:
-               orig = xchg(&stab->bpf_verdict, prog);
+               orig = xchg(&progs->bpf_verdict, prog);
                break;
        default:
                return -EOPNOTSUPP;
 static void sock_map_release(struct bpf_map *map)
 {
        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+       struct bpf_sock_progs *progs;
        struct bpf_prog *orig;
 
-       orig = xchg(&stab->bpf_parse, NULL);
+       progs = &stab->progs;
+       /* Atomically swap each attached program to NULL so its reference
+        * is dropped exactly once, then put it if one was attached.
+        */
+       orig = xchg(&progs->bpf_parse, NULL);
        if (orig)
                bpf_prog_put(orig);
-       orig = xchg(&stab->bpf_verdict, NULL);
+       orig = xchg(&progs->bpf_verdict, NULL);
        if (orig)
                bpf_prog_put(orig);
 
-       orig = xchg(&stab->bpf_tx_msg, NULL);
+       orig = xchg(&progs->bpf_tx_msg, NULL);
        if (orig)
                bpf_prog_put(orig);
 }