df->data_seq + df->data_len == msk->write_seq;
 }
 
+/* Estimate the forward-allocated memory needed to send @size bytes of
+ * user data: the payload itself plus bookkeeping overhead of roughly
+ * one struct mptcp_data_frag per page of data.
+ */
+static int mptcp_wmem_with_overhead(int size)
+{
+       return size + ((sizeof(struct mptcp_data_frag) * size) >> PAGE_SHIFT);
+}
+
+/* Pre-reserve forward-allocated memory for a whole sendmsg() call so the
+ * per-fragment fast path (mptcp_wmem_alloc) does not need to invoke the
+ * memory scheduler each time.  On failure, wmem_reserved is set to -1 so
+ * the actual allocation knows it must wait for memory.
+ * NOTE(review): appears to assume the msk socket lock is held — confirm
+ * at the call site (invoked via mptcp_lock_sock()).
+ */
+static void __mptcp_wmem_reserve(struct sock *sk, int size)
+{
+       int amount = mptcp_wmem_with_overhead(size);
+       struct mptcp_sock *msk = mptcp_sk(sk);
+
+       /* any previous reservation must have been flushed already */
+       WARN_ON_ONCE(msk->wmem_reserved);
+       if (amount <= sk->sk_forward_alloc)
+               goto reserve;
+
+       /* under memory pressure try to reserve at most a single page
+        * otherwise try to reserve the full estimate and fallback
+        * to a single page before entering the error path
+        */
+       if ((tcp_under_memory_pressure(sk) && amount > PAGE_SIZE) ||
+           !sk_wmem_schedule(sk, amount)) {
+               if (amount <= PAGE_SIZE)
+                       goto nomem;
+
+               amount = PAGE_SIZE;
+               if (!sk_wmem_schedule(sk, amount))
+                       goto nomem;
+       }
+
+reserve:
+       msk->wmem_reserved = amount;
+       sk->sk_forward_alloc -= amount;
+       return;
+
+nomem:
+       /* we will wait for memory on next allocation */
+       msk->wmem_reserved = -1;
+}
+
+/* Return any unused reservation to sk_forward_alloc and clear a pending
+ * error marker (-1); called from mptcp_release_cb() when the socket lock
+ * is released.
+ */
+static void __mptcp_update_wmem(struct sock *sk)
+{
+       struct mptcp_sock *msk = mptcp_sk(sk);
+
+       /* nothing reserved and no error pending */
+       if (!msk->wmem_reserved)
+               return;
+
+       /* a failed reserve left the error sentinel; just clear it */
+       if (msk->wmem_reserved < 0)
+               msk->wmem_reserved = 0;
+       if (msk->wmem_reserved > 0) {
+               /* give the leftover back to the generic accounting */
+               sk->sk_forward_alloc += msk->wmem_reserved;
+               msk->wmem_reserved = 0;
+       }
+}
+
+/* Charge @size bytes against the per-sendmsg reservation, topping it up
+ * through the memory scheduler when it does not cover the request.
+ * Returns false when memory cannot be obtained; the caller is expected
+ * to handle that as a wait-for-memory condition.
+ */
+static bool mptcp_wmem_alloc(struct sock *sk, int size)
+{
+       struct mptcp_sock *msk = mptcp_sk(sk);
+
+       /* check for pre-existing error condition */
+       if (msk->wmem_reserved < 0)
+               return false;
+
+       /* fast path: the current reservation already covers @size */
+       if (msk->wmem_reserved >= size)
+               goto account;
+
+       /* slow path: grow the reservation via the memory scheduler */
+       if (!sk_wmem_schedule(sk, size))
+               return false;
+
+       sk->sk_forward_alloc -= size;
+       msk->wmem_reserved += size;
+
+account:
+       msk->wmem_reserved -= size;
+       return true;
+}
+
 static void dfrag_uncharge(struct sock *sk, int len)
 {
        sk_mem_uncharge(sk, len);
        }
 
 out:
-       if (cleaned)
+       if (cleaned && tcp_under_memory_pressure(sk))
                sk_mem_reclaim_partial(sk);
 }
 
        if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
                return -EOPNOTSUPP;
 
-       lock_sock(sk);
+       mptcp_lock_sock(sk, __mptcp_wmem_reserve(sk, len));
 
        timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
 
                offset = dfrag->offset + dfrag->data_len;
                psize = pfrag->size - offset;
                psize = min_t(size_t, psize, msg_data_left(msg));
-               if (!sk_wmem_schedule(sk, psize + frag_truesize))
+               if (!mptcp_wmem_alloc(sk, psize + frag_truesize))
                        goto wait_for_memory;
 
                if (copy_page_from_iter(dfrag->page, offset, psize,
                                        &msg->msg_iter) != psize) {
+                       msk->wmem_reserved += psize + frag_truesize;
                        ret = -EFAULT;
                        goto out;
                }
                 * Note: we charge such data both to sk and ssk
                 */
                sk_wmem_queued_add(sk, frag_truesize);
-               sk->sk_forward_alloc -= frag_truesize;
                if (!dfrag_collapsed) {
                        get_page(dfrag->page);
                        list_add_tail(&dfrag->list, &msk->rtx_queue);
        INIT_WORK(&msk->work, mptcp_worker);
        msk->out_of_order_queue = RB_ROOT;
        msk->first_pending = NULL;
+       msk->wmem_reserved = 0;
 
        msk->ack_hint = NULL;
        msk->first = NULL;
 
        sk->sk_prot->destroy(sk);
 
+       WARN_ON_ONCE(msk->wmem_reserved);
        sk_stream_kill_queues(sk);
        xfrm_sk_free_policy(sk);
        sk_refcnt_debug_release(sk);
 
 #define MPTCP_DEFERRED_ALL (TCPF_WRITE_TIMER_DEFERRED)
 
-/* this is very alike tcp_release_cb() but we must handle differently a
- * different set of events
- */
+/* processes deferred events and flushes wmem */
 static void mptcp_release_cb(struct sock *sk)
 {
        unsigned long flags, nflags;
 
+       /* clear any wmem reservation and errors */
+       __mptcp_update_wmem(sk);
+
        do {
                flags = sk->sk_tsq_flags;
                if (!(flags & MPTCP_DEFERRED_ALL))