#if defined(CONFIG_BPF_STREAM_PARSER)
 int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, u32 which);
 int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog);
+void sock_map_unhash(struct sock *sk);
+void sock_map_close(struct sock *sk, long timeout);
 #else
 static inline int sock_map_prog_update(struct bpf_map *map,
                                       struct bpf_prog *prog, u32 which)
 {
        return -EINVAL;
 }
-#endif
+#endif /* CONFIG_BPF_STREAM_PARSER */
 
 #if defined(CONFIG_INET) && defined(CONFIG_BPF_SYSCALL)
 void bpf_sk_reuseport_detach(struct sock *sk);
 
 }
 
 struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
-#if defined(CONFIG_BPF_STREAM_PARSER)
-void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link);
-#else
-static inline void sk_psock_unlink(struct sock *sk,
-                                  struct sk_psock_link *link)
-{
-}
-#endif
 
 void __sk_psock_purge_ingress_msg(struct sk_psock *psock);
 
        return test_bit(bit, &psock->state);
 }
 
-static inline struct sk_psock *sk_psock_get_checked(struct sock *sk)
-{
-       struct sk_psock *psock;
-
-       rcu_read_lock();
-       psock = sk_psock(sk);
-       if (psock) {
-               if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
-                       psock = ERR_PTR(-EBUSY);
-                       goto out;
-               }
-
-               if (!refcount_inc_not_zero(&psock->refcnt))
-                       psock = ERR_PTR(-EBUSY);
-       }
-out:
-       rcu_read_unlock();
-       return psock;
-}
-
 static inline struct sk_psock *sk_psock_get(struct sock *sk)
 {
        struct sk_psock *psock;
 
 struct sk_msg;
 struct sk_psock;
 
+#ifdef CONFIG_BPF_STREAM_PARSER
+struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
+void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
+#else
+static inline void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
+{
+}
+#endif /* CONFIG_BPF_STREAM_PARSER */
+
 #ifdef CONFIG_NET_SOCK_MSG
-int tcp_bpf_init(struct sock *sk);
 int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
                          int flags);
 int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                    int nonblock, int flags, int *addr_len);
 int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
                      struct msghdr *msg, int len, int flags);
-void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
-#else
-static inline void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
-{
-}
 #endif /* CONFIG_NET_SOCK_MSG */
 
 /* Call BPF_SOCK_OPS program that returns an int. If the return value
 
        }
 }
 
+static int sock_map_init_proto(struct sock *sk)
+{
+       struct sk_psock *psock;
+       struct proto *prot;
+
+       sock_owned_by_me(sk);
+
+       rcu_read_lock();
+       psock = sk_psock(sk);
+       if (unlikely(!psock)) {
+               rcu_read_unlock();
+               return -EINVAL;
+       }
+
+       prot = tcp_bpf_get_proto(sk, psock);
+       if (IS_ERR(prot)) {
+               rcu_read_unlock();
+               return PTR_ERR(prot);
+       }
+
+       sk_psock_update_proto(sk, psock, prot);
+       rcu_read_unlock();
+       return 0;
+}
+
+static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
+{
+       struct sk_psock *psock;
+
+       rcu_read_lock();
+       psock = sk_psock(sk);
+       if (psock) {
+               if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
+                       psock = ERR_PTR(-EBUSY);
+                       goto out;
+               }
+
+               if (!refcount_inc_not_zero(&psock->refcnt))
+                       psock = ERR_PTR(-EBUSY);
+       }
+out:
+       rcu_read_unlock();
+       return psock;
+}
+
 static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
                         struct sock *sk)
 {
                }
        }
 
-       psock = sk_psock_get_checked(sk);
+       psock = sock_map_psock_get_checked(sk);
        if (IS_ERR(psock)) {
                ret = PTR_ERR(psock);
                goto out_progs;
        if (msg_parser)
                psock_set_prog(&psock->progs.msg_parser, msg_parser);
 
-       ret = tcp_bpf_init(sk);
+       ret = sock_map_init_proto(sk);
        if (ret < 0)
                goto out_drop;
 
        struct sk_psock *psock;
        int ret;
 
-       psock = sk_psock_get_checked(sk);
+       psock = sock_map_psock_get_checked(sk);
        if (IS_ERR(psock))
                return PTR_ERR(psock);
 
                        return -ENOMEM;
        }
 
-       ret = tcp_bpf_init(sk);
+       ret = sock_map_init_proto(sk);
        if (ret < 0)
                sk_psock_put(sk, psock);
        return ret;
        return 0;
 }
 
-void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
+static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
 {
        switch (link->map->map_type) {
        case BPF_MAP_TYPE_SOCKMAP:
                break;
        }
 }
+
+static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
+{
+       struct sk_psock_link *link;
+
+       while ((link = sk_psock_link_pop(psock))) {
+               sock_map_unlink(sk, link);
+               sk_psock_free_link(link);
+       }
+}
+
+void sock_map_unhash(struct sock *sk)
+{
+       void (*saved_unhash)(struct sock *sk);
+       struct sk_psock *psock;
+
+       rcu_read_lock();
+       psock = sk_psock(sk);
+       if (unlikely(!psock)) {
+               rcu_read_unlock();
+               if (sk->sk_prot->unhash)
+                       sk->sk_prot->unhash(sk);
+               return;
+       }
+
+       saved_unhash = psock->saved_unhash;
+       sock_map_remove_links(sk, psock);
+       rcu_read_unlock();
+       saved_unhash(sk);
+}
+
+void sock_map_close(struct sock *sk, long timeout)
+{
+       void (*saved_close)(struct sock *sk, long timeout);
+       struct sk_psock *psock;
+
+       lock_sock(sk);
+       rcu_read_lock();
+       psock = sk_psock(sk);
+       if (unlikely(!psock)) {
+               rcu_read_unlock();
+               release_sock(sk);
+               return sk->sk_prot->close(sk, timeout);
+       }
+
+       saved_close = psock->saved_close;
+       sock_map_remove_links(sk, psock);
+       rcu_read_unlock();
+       release_sock(sk);
+       saved_close(sk, timeout);
+}
 
        return copied ? copied : err;
 }
 
-static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
-{
-       struct sk_psock_link *link;
-
-       while ((link = sk_psock_link_pop(psock))) {
-               sk_psock_unlink(sk, link);
-               sk_psock_free_link(link);
-       }
-}
-
-static void tcp_bpf_unhash(struct sock *sk)
-{
-       void (*saved_unhash)(struct sock *sk);
-       struct sk_psock *psock;
-
-       rcu_read_lock();
-       psock = sk_psock(sk);
-       if (unlikely(!psock)) {
-               rcu_read_unlock();
-               if (sk->sk_prot->unhash)
-                       sk->sk_prot->unhash(sk);
-               return;
-       }
-
-       saved_unhash = psock->saved_unhash;
-       tcp_bpf_remove(sk, psock);
-       rcu_read_unlock();
-       saved_unhash(sk);
-}
-
-static void tcp_bpf_close(struct sock *sk, long timeout)
-{
-       void (*saved_close)(struct sock *sk, long timeout);
-       struct sk_psock *psock;
-
-       lock_sock(sk);
-       rcu_read_lock();
-       psock = sk_psock(sk);
-       if (unlikely(!psock)) {
-               rcu_read_unlock();
-               release_sock(sk);
-               return sk->sk_prot->close(sk, timeout);
-       }
-
-       saved_close = psock->saved_close;
-       tcp_bpf_remove(sk, psock);
-       rcu_read_unlock();
-       release_sock(sk);
-       saved_close(sk, timeout);
-}
-
+#ifdef CONFIG_BPF_STREAM_PARSER
 enum {
        TCP_BPF_IPV4,
        TCP_BPF_IPV6,
                                   struct proto *base)
 {
        prot[TCP_BPF_BASE]                      = *base;
-       prot[TCP_BPF_BASE].unhash               = tcp_bpf_unhash;
-       prot[TCP_BPF_BASE].close                = tcp_bpf_close;
+       prot[TCP_BPF_BASE].unhash               = sock_map_unhash;
+       prot[TCP_BPF_BASE].close                = sock_map_close;
        prot[TCP_BPF_BASE].recvmsg              = tcp_bpf_recvmsg;
        prot[TCP_BPF_BASE].stream_memory_read   = tcp_bpf_stream_read;
 
               ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 }
 
-static struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
+struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
 {
        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
        int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
        return &tcp_bpf_prots[family][config];
 }
 
-int tcp_bpf_init(struct sock *sk)
-{
-       struct sk_psock *psock;
-       struct proto *prot;
-
-       sock_owned_by_me(sk);
-
-       rcu_read_lock();
-       psock = sk_psock(sk);
-       if (unlikely(!psock)) {
-               rcu_read_unlock();
-               return -EINVAL;
-       }
-
-       prot = tcp_bpf_get_proto(sk, psock);
-       if (IS_ERR(prot)) {
-               rcu_read_unlock();
-               return PTR_ERR(prot);
-       }
-
-       sk_psock_update_proto(sk, psock, prot);
-       rcu_read_unlock();
-       return 0;
-}
-
 /* If a child got cloned from a listening socket that had tcp_bpf
  * protocol callbacks installed, we need to restore the callbacks to
  * the default ones because the child does not inherit the psock state
        if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
                newsk->sk_prot = sk->sk_prot_creator;
 }
+#endif /* CONFIG_BPF_STREAM_PARSER */