if (!ingress) {
                if (!sock_writeable(psock->sk))
                        return -EAGAIN;
-               return skb_send_sock_locked(psock->sk, skb, off, len);
+               return skb_send_sock(psock->sk, skb, off, len);
        }
        return sk_psock_skb_ingress(psock, skb);
 }
        u32 len, off;
        int ret;
 
-       /* Lock sock to avoid losing sk_socket during loop. */
-       lock_sock(psock->sk);
+       mutex_lock(&psock->work_mutex);
        if (state->skb) {
                skb = state->skb;
                len = state->len;
                skb_bpf_redirect_clear(skb);
                do {
                        ret = -EIO;
-                       if (likely(psock->sk->sk_socket))
+                       if (!sock_flag(psock->sk, SOCK_DEAD))
                                ret = sk_psock_handle_skb(psock, skb, off,
                                                          len, ingress);
                        if (ret <= 0) {
                        kfree_skb(skb);
        }
 end:
-       release_sock(psock->sk);
+       mutex_unlock(&psock->work_mutex);
 }
 
 struct sk_psock *sk_psock_init(struct sock *sk, int node)
        spin_lock_init(&psock->link_lock);
 
        INIT_WORK(&psock->work, sk_psock_backlog);
+       mutex_init(&psock->work_mutex);
        INIT_LIST_HEAD(&psock->ingress_msg);
        spin_lock_init(&psock->ingress_lock);
        skb_queue_head_init(&psock->ingress_skb);
        }
 }
 
-static void sk_psock_zap_ingress(struct sk_psock *psock)
+static void __sk_psock_zap_ingress(struct sk_psock *psock)
 {
        struct sk_buff *skb;
 
                skb_bpf_redirect_clear(skb);
                kfree_skb(skb);
        }
-       spin_lock_bh(&psock->ingress_lock);
        __sk_psock_purge_ingress_msg(psock);
-       spin_unlock_bh(&psock->ingress_lock);
 }
 
 static void sk_psock_link_destroy(struct sk_psock *psock)
        }
 }
 
+/* Stop TX on this psock: SK_PSOCK_TX_ENABLED is cleared while holding
+ * ingress_lock, so redirect paths that test the flag under the same lock
+ * (see sk_psock_skb_redirect below) cannot queue new skbs to a psock
+ * being torn down. Any pending cork is released and the queued ingress
+ * skbs/msgs are dropped via __sk_psock_zap_ingress(). With @wait true,
+ * also synchronously cancel the backlog work (psock->work is
+ * sk_psock_backlog, serialized by work_mutex).
+ */
+void sk_psock_stop(struct sk_psock *psock, bool wait)
+{
+       spin_lock_bh(&psock->ingress_lock);
+       sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
+       sk_psock_cork_free(psock);
+       __sk_psock_zap_ingress(psock);
+       spin_unlock_bh(&psock->ingress_lock);
+
+       if (wait)
+               cancel_work_sync(&psock->work);
+}
+
 static void sk_psock_done_strp(struct sk_psock *psock);
 
 static void sk_psock_destroy_deferred(struct work_struct *gc)
        sk_psock_done_strp(psock);
 
        cancel_work_sync(&psock->work);
+       mutex_destroy(&psock->work_mutex);
 
        psock_progs_drop(&psock->progs);
 
        sk_psock_link_destroy(psock);
        sk_psock_cork_free(psock);
-       sk_psock_zap_ingress(psock);
 
        if (psock->sk_redir)
                sock_put(psock->sk_redir);
 
 void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
 {
-       sk_psock_cork_free(psock);
-       sk_psock_zap_ingress(psock);
+       sk_psock_stop(psock, false);
 
        write_lock_bh(&sk->sk_callback_lock);
        sk_psock_restore_proto(sk, psock);
        else if (psock->progs.stream_verdict)
                sk_psock_stop_verdict(sk, psock);
        write_unlock_bh(&sk->sk_callback_lock);
-       sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
 
        call_rcu(&psock->rcu, sk_psock_destroy);
 }
         * error that caused the pipe to break. We can't send a packet on
         * a socket that is in this state so we drop the skb.
         */
-       if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
-           !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) {
+       if (!psock_other || sock_flag(sk_other, SOCK_DEAD)) {
+               kfree_skb(skb);
+               return;
+       }
+       spin_lock_bh(&psock_other->ingress_lock);
+       if (!sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) {
+               spin_unlock_bh(&psock_other->ingress_lock);
                kfree_skb(skb);
                return;
        }
 
        skb_queue_tail(&psock_other->ingress_skb, skb);
        schedule_work(&psock_other->work);
+       spin_unlock_bh(&psock_other->ingress_lock);
 }
 
 static void sk_psock_tls_verdict_apply(struct sk_buff *skb, struct sock *sk, int verdict)
                        err = sk_psock_skb_ingress_self(psock, skb);
                }
                if (err < 0) {
-                       skb_queue_tail(&psock->ingress_skb, skb);
-                       schedule_work(&psock->work);
+                       spin_lock_bh(&psock->ingress_lock);
+                       if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
+                               skb_queue_tail(&psock->ingress_skb, skb);
+                               schedule_work(&psock->work);
+                       }
+                       spin_unlock_bh(&psock->ingress_lock);
                }
                break;
        case __SK_REDIRECT: