        /* vhost zerocopy support fields below: */
        /* last used idx for outstanding DMA zerocopy buffers */
        int upend_idx;
-       /* first used idx for DMA done zerocopy buffers */
+       /* For TX, first used idx for DMA done zerocopy buffers
+        * For RX, number of batched heads
+        */
        int done_idx;
        /* an array of userspace buffers info */
        struct ubuf_info *ubuf_info;
        return skb_queue_empty(&sk->sk_receive_queue);
 }
 
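+/* Flush used heads batched in vq->heads to the used ring, signal the
+ * guest, and reset the RX batch counter (done_idx).
+ */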
+static void vhost_rx_signal_used(struct vhost_net_virtqueue *nvq)
+{
+       struct vhost_virtqueue *vq = &nvq->vq;
+       struct vhost_dev *dev = vq->dev;
+
+       if (!nvq->done_idx)
+               return;
+
+       vhost_add_used_and_signal_n(dev, vq, vq->heads, nvq->done_idx);
+       nvq->done_idx = 0;
+}
+
 static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk)
 {
        struct vhost_net_virtqueue *rvq = &net->vqs[VHOST_NET_VQ_RX];
        struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
        struct vhost_virtqueue *vq = &nvq->vq;
        int len = peek_head_len(rvq, sk);
 
        if (!len && vq->busyloop_timeout) {
+               /* Flush batched heads first */
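+               /* The guest may be waiting for these buffers, so make
+                * them visible before busy waiting for more packets.
+                */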
+               vhost_rx_signal_used(rvq);
                /* Both tx vq and rx socket were polled here */
                mutex_lock_nested(&vq->mutex, 1);
                vhost_disable_notify(&net->dev, vq);
        }
        size_t total_len = 0;
        int err, mergeable;
-       s16 headcount, nheads = 0;
+       s16 headcount;
        size_t vhost_hlen, sock_hlen;
        size_t vhost_len, sock_len;
        struct socket *sock;
        while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk))) {
                sock_len += sock_hlen;
                vhost_len = sock_len + vhost_hlen;
-               headcount = get_rx_bufs(vq, vq->heads + nheads, vhost_len,
-                                       &in, vq_log, &log,
+               headcount = get_rx_bufs(vq, vq->heads + nvq->done_idx,
+                                       vhost_len, &in, vq_log, &log,
                                        likely(mergeable) ? UIO_MAXIOV : 1);
                /* On error, stop handling until the next kick. */
                if (unlikely(headcount < 0))
                        goto out;
-               nheads += headcount;
-               if (nheads > VHOST_RX_BATCH) {
-                       vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
-                                                   nheads);
-                       nheads = 0;
-               }
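+               /* Batch the used heads in vq->heads (stored starting at
+                * done_idx) and flush them once the batch exceeds
+                * VHOST_RX_BATCH.
+                */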
+               nvq->done_idx += headcount;
+               if (nvq->done_idx > VHOST_RX_BATCH)
+                       vhost_rx_signal_used(nvq);
                if (unlikely(vq_log))
                        vhost_log_write(vq, vq_log, log, vhost_len);
                total_len += vhost_len;
        }
        vhost_net_enable_vq(net, vq);
 out:
-       if (nheads)
-               vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
-                                           nheads);
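+       /* Flush any heads still batched before dropping the vq mutex */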
+       vhost_rx_signal_used(nvq);
        mutex_unlock(&vq->mutex);
 }