struct sk_psock *psock,
                                         struct proto *ops)
 {
-       psock->saved_unhash = sk->sk_prot->unhash;
-       psock->saved_close = sk->sk_prot->close;
-       psock->saved_write_space = sk->sk_write_space;
+       /* Initialize saved callbacks and original proto only once, since this
+        * function may be called multiple times for a psock, e.g. when
+        * psock->progs.msg_parser is updated.
+        *
+        * Since we've not installed the new proto, psock is not yet in use and
+        * we can initialize it without synchronization.
+        */
+       if (!psock->sk_proto) {
+               struct proto *orig = READ_ONCE(sk->sk_prot);
+
+               psock->saved_unhash = orig->unhash;
+               psock->saved_close = orig->close;
+               psock->saved_write_space = sk->sk_write_space;
+
+               psock->sk_proto = orig;
+       }
 
-       psock->sk_proto = sk->sk_prot;
        /* Pairs with lockless read in sk_clone_lock() */
        WRITE_ONCE(sk->sk_prot, ops);
 }
 
        sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
 }
 
-static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
-{
-       int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
-       int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
-
-       /* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
-        * or added requiring sk_prot hook updates. We keep original saved
-        * hooks in this case.
-        *
-        * Pairs with lockless read in sk_clone_lock().
-        */
-       WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
-}
-
 static int tcp_bpf_assert_proto_ops(struct proto *ops)
 {
        /* In order to avoid retpoline, we make assumptions when we call
 
        rcu_read_lock();
        psock = sk_psock(sk);
-       tcp_bpf_reinit_sk_prot(sk, psock);
+       tcp_bpf_update_sk_prot(sk, psock);
        rcu_read_unlock();
 }