if (err)
                return err;
 
-       msk->first = ssock->sk;
+       WRITE_ONCE(msk->first, ssock->sk);
        WRITE_ONCE(msk->subflow, ssock);
        subflow = mptcp_subflow_ctx(ssock->sk);
        list_add(&subflow->node, &msk->conn_list);
        sock_put(ssk);
 
        if (ssk == msk->first)
-               msk->first = NULL;
+               WRITE_ONCE(msk->first, NULL);
 
 out:
        if (ssk == msk->last_snd)
        WRITE_ONCE(msk->rmem_released, 0);
        msk->timer_ival = TCP_RTO_MIN;
 
-       msk->first = NULL;
+       WRITE_ONCE(msk->first, NULL);
        inet_csk(sk)->icsk_sync_mss = mptcp_sync_mss;
        WRITE_ONCE(msk->csum_enabled, mptcp_is_checksum_enabled(sock_net(sk)));
        WRITE_ONCE(msk->allow_infinite_fallback, true);