} while (i != md->sg_end);
 }
 
-static void free_bytes_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
+static void free_bytes_sg(struct sock *sk, int bytes,
+                         struct sk_msg_buff *md, bool charge)
 {
        struct scatterlist *sg = md->sg_data;
        int i = md->sg_start, free;
                if (bytes < free) {
                        sg[i].length -= bytes;
                        sg[i].offset += bytes;
-                       sk_mem_uncharge(sk, bytes);
+                       if (charge)
+                               sk_mem_uncharge(sk, bytes);
                        break;
                }
 
-               sk_mem_uncharge(sk, sg[i].length);
+               if (charge)
+                       sk_mem_uncharge(sk, sg[i].length);
                put_page(sg_page(&sg[i]));
                bytes -= sg[i].length;
                sg[i].length = 0;
                if (i == MAX_SKB_FRAGS)
                        i = 0;
        }
+       md->sg_start = i;
 }
 
 static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
                                       struct sk_msg_buff *md,
                                       int flags)
 {
+       bool ingress = !!(md->flags & BPF_F_INGRESS);
        struct smap_psock *psock;
        struct scatterlist *sg;
-       int i, err, free = 0;
-       bool ingress = !!(md->flags & BPF_F_INGRESS);
+       int err = 0;
 
        sg = md->sg_data;
 
 out_rcu:
        rcu_read_unlock();
 out:
-       i = md->sg_start;
-       while (sg[i].length) {
-               free += sg[i].length;
-               put_page(sg_page(&sg[i]));
-               sg[i].length = 0;
-               i++;
-               if (i == MAX_SKB_FRAGS)
-                       i = 0;
-       }
-       return free;
+       free_bytes_sg(NULL, send, md, false);
+       return err;
 }
 
 static inline void bpf_md_init(struct smap_psock *psock)
                break;
        case __SK_DROP:
        default:
-               free_bytes_sg(sk, send, m);
+               free_bytes_sg(sk, send, m, true);
                apply_bytes_dec(psock, send);
                *copied -= send;
                psock->sg_size -= send;