void mptcp_sock_graft(struct sock *sk, struct socket *parent);
 struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk);
 bool __mptcp_close(struct sock *sk, long timeout);
+void mptcp_cancel_work(struct sock *sk);
 
 bool mptcp_addresses_equal(const struct mptcp_addr_info *a,
                           const struct mptcp_addr_info *b, bool use_port);
 
        return !crypto_memneq(hmac, mp_opt->hmac, MPTCPOPT_HMAC_LEN);
 }
 
-static void mptcp_sock_destruct(struct sock *sk)
-{
-       /* if new mptcp socket isn't accepted, it is free'd
-        * from the tcp listener sockets request queue, linked
-        * from req->sk.  The tcp socket is released.
-        * This calls the ULP release function which will
-        * also remove the mptcp socket, via
-        * sock_put(ctx->conn).
-        *
-        * Problem is that the mptcp socket will be in
-        * ESTABLISHED state and will not have the SOCK_DEAD flag.
-        * Both result in warnings from inet_sock_destruct.
-        */
-       if ((1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_CLOSE_WAIT)) {
-               sk->sk_state = TCP_CLOSE;
-               WARN_ON_ONCE(sk->sk_socket);
-               sock_orphan(sk);
-       }
-
-       /* We don't need to clear msk->subflow, as it's still NULL at this point */
-       mptcp_destroy_common(mptcp_sk(sk), 0);
-       inet_sock_destruct(sk);
-}
-
 static void mptcp_force_close(struct sock *sk)
 {
        /* the msk is not yet exposed to user-space */
                        /* new mpc subflow takes ownership of the newly
                         * created mptcp socket
                         */
-                       new_msk->sk_destruct = mptcp_sock_destruct;
                        mptcp_sk(new_msk)->setsockopt_seq = ctx->setsockopt_seq;
                        mptcp_pm_new_connection(mptcp_sk(new_msk), child, 1);
                        mptcp_token_accept(subflow_req, mptcp_sk(new_msk));
 
        for (msk = head; msk; msk = next) {
                struct sock *sk = (struct sock *)msk;
-               bool slow;
+               bool slow, do_cancel_work;
 
+               sock_hold(sk);
                slow = lock_sock_fast_nested(sk);
                next = msk->dl_next;
                msk->first = NULL;
                msk->dl_next = NULL;
+
+               do_cancel_work = __mptcp_close(sk, 0);
                unlock_sock_fast(sk, slow);
+               if (do_cancel_work)
+                       mptcp_cancel_work(sk);
+               sock_put(sk);
        }
 
        /* we are still under the listener msk socket lock */