__sk_mem_schedule(sk, size, SK_MEM_SEND);
 }
 
-static inline bool sk_rmem_schedule(struct sock *sk, int size)
+static inline bool
+sk_rmem_schedule(struct sock *sk, struct sk_buff *skb, unsigned int size)
 {
        if (!sk_has_account(sk))
                return true;
-       return size <= sk->sk_forward_alloc ||
-               __sk_mem_schedule(sk, size, SK_MEM_RECV);
+       return size<= sk->sk_forward_alloc ||
+               __sk_mem_schedule(sk, size, SK_MEM_RECV) ||
+               skb_pfmemalloc(skb);
 }
 
 static inline void sk_mem_reclaim(struct sock *sk)
 
        sock_reset_flag(sk, SOCK_MEMALLOC);
        sk->sk_allocation &= ~__GFP_MEMALLOC;
        static_key_slow_dec(&memalloc_socks);
+
+       /*
+        * SOCK_MEMALLOC is allowed to ignore rmem limits to ensure forward
+        * progress of swapping. However, if SOCK_MEMALLOC is cleared while
+        * it has rmem allocations there is a risk that the user of the
+        * socket cannot make forward progress due to exceeding the rmem
+        * limits. By rights, sk_clear_memalloc() should only be called
+        * on sockets being torn down but warn and reset the accounting if
+        * that assumption breaks.
+        */
+       if (WARN_ON(sk->sk_forward_alloc))
+               sk_mem_reclaim(sk);
 }
 EXPORT_SYMBOL_GPL(sk_clear_memalloc);
 
        if (err)
                return err;
 
-       if (!sk_rmem_schedule(sk, skb->truesize)) {
+       if (!sk_rmem_schedule(sk, skb, skb->truesize)) {
                atomic_inc(&sk->sk_drops);
                return -ENOBUFS;
        }
 
 static bool tcp_prune_ofo_queue(struct sock *sk);
 static int tcp_prune_queue(struct sock *sk);
 
-static int tcp_try_rmem_schedule(struct sock *sk, unsigned int size)
+static int tcp_try_rmem_schedule(struct sock *sk, struct sk_buff *skb,
+                                unsigned int size)
 {
        if (atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf ||
-           !sk_rmem_schedule(sk, size)) {
+           !sk_rmem_schedule(sk, skb, size)) {
 
                if (tcp_prune_queue(sk) < 0)
                        return -1;
 
-               if (!sk_rmem_schedule(sk, size)) {
+               if (!sk_rmem_schedule(sk, skb, size)) {
                        if (!tcp_prune_ofo_queue(sk))
                                return -1;
 
-                       if (!sk_rmem_schedule(sk, size))
+                       if (!sk_rmem_schedule(sk, skb, size))
                                return -1;
                }
        }
 
        TCP_ECN_check_ce(tp, skb);
 
-       if (unlikely(tcp_try_rmem_schedule(sk, skb->truesize))) {
+       if (unlikely(tcp_try_rmem_schedule(sk, skb, skb->truesize))) {
                NET_INC_STATS_BH(sock_net(sk), LINUX_MIB_TCPOFODROP);
                __kfree_skb(skb);
                return;
 
 int tcp_send_rcvq(struct sock *sk, struct msghdr *msg, size_t size)
 {
-       struct sk_buff *skb;
+       struct sk_buff *skb = NULL;
        struct tcphdr *th;
        bool fragstolen;
 
-       if (tcp_try_rmem_schedule(sk, size + sizeof(*th)))
-               goto err;
-
        skb = alloc_skb(size + sizeof(*th), sk->sk_allocation);
        if (!skb)
                goto err;
 
+       if (tcp_try_rmem_schedule(sk, skb, size + sizeof(*th)))
+               goto err_free;
+
        th = (struct tcphdr *)skb_put(skb, sizeof(*th));
        skb_reset_transport_header(skb);
        memset(th, 0, sizeof(*th));
                if (eaten <= 0) {
 queue_and_out:
                        if (eaten < 0 &&
-                           tcp_try_rmem_schedule(sk, skb->truesize))
+                           tcp_try_rmem_schedule(sk, skb, skb->truesize))
                                goto drop;
 
                        eaten = tcp_queue_rcv(sk, skb, 0, &fragstolen);