* For RX, number of batched heads
         */
        int done_idx;
+       /* Number of XDP frames batched */
+       int batched_xdp;
        /* an array of userspace buffers info */
        struct ubuf_info *ubuf_info;
        /* Reference counting for outstanding ubufs.
         * Protected by vq mutex. Writers must also take device mutex. */
        struct vhost_net_ubuf_ref *ubufs;
        struct ptr_ring *rx_ring;
        struct vhost_net_buf rxq;
+       /* Batched XDP buffers */
+       struct xdp_buff *xdp;
 };
 
 struct vhost_net {
                sock_flag(sock->sk, SOCK_ZEROCOPY);
 }
 
+static bool vhost_sock_xdp(struct socket *sock)
+{
+       return sock_flag(sock->sk, SOCK_XDP);
+}
+
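For context (not part of this patch): SOCK_XDP is toggled on tun's per-queue sockets when an XDP program is attached or detached, which is what lets vhost_sock_xdp() decide whether to reserve XDP_PACKET_HEADROOM when building buffers. A simplified sketch of that tun-side logic, assuming the shape of tun_xdp_set() in drivers/net/tun.c from the same series (error handling and the disabled-queue list omitted):

```c
static int tun_xdp_set(struct net_device *dev, struct bpf_prog *prog,
		       struct netlink_ext_ack *extack)
{
	struct tun_struct *tun = netdev_priv(dev);
	struct bpf_prog *old_prog;
	int i;

	old_prog = rtnl_dereference(tun->xdp_prog);
	rcu_assign_pointer(tun->xdp_prog, prog);
	if (old_prog)
		bpf_prog_put(old_prog);

	/* Flag every queue's socket so vhost_net can see whether an
	 * XDP program is attached.
	 */
	for (i = 0; i < tun->numqueues; i++) {
		struct tun_file *tfile = rtnl_dereference(tun->tfiles[i]);

		if (prog)
			sock_set_flag(&tfile->sk, SOCK_XDP);
		else
			sock_reset_flag(&tfile->sk, SOCK_XDP);
	}

	return 0;
}
```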
 /* In case of DMA done not in order in lower device driver for some reason.
  * upend_idx is used to track end of used idx, done_idx is used to track head
  * of used idx. Once lower device DMA done contiguously, we will signal KVM
        nvq->done_idx = 0;
 }
 
+static void vhost_tx_batch(struct vhost_net *net,
+                          struct vhost_net_virtqueue *nvq,
+                          struct socket *sock,
+                          struct msghdr *msghdr)
+{
+       struct tun_msg_ctl ctl = {
+               .type = TUN_MSG_PTR,
+               .num = nvq->batched_xdp,
+               .ptr = nvq->xdp,
+       };
+       int err;
+
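+       /* Nothing batched: just flush any pending used-ring updates */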
+       if (nvq->batched_xdp == 0)
+               goto signal_used;
+
+       msghdr->msg_control = &ctl;
+       err = sock->ops->sendmsg(sock, msghdr, 0);
+       if (unlikely(err < 0)) {
+               vq_err(&nvq->vq, "Failed to batch send packets\n");
+               return;
+       }
+
+signal_used:
+       vhost_net_signal_used(nvq);
+       nvq->batched_xdp = 0;
+}
+
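For reference, the consumer side of this batch: tun's sendmsg() recognizes a tun_msg_ctl of type TUN_MSG_PTR in msg_control and walks the xdp_buff array instead of copying a single packet from the iov. A simplified sketch based on the companion tuntap patch in this series (tun_xdp_one() processes one buffer; error paths trimmed):

```c
static int tun_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len)
{
	struct tun_file *tfile = container_of(sock, struct tun_file, socket);
	struct tun_struct *tun = tun_get(tfile);
	struct tun_msg_ctl *ctl = m->msg_control;

	if (!tun)
		return -EBADFD;

	if (ctl && ctl->type == TUN_MSG_PTR) {
		struct xdp_buff *xdp = ctl->ptr;
		int i, flush = 0;

		local_bh_disable();
		rcu_read_lock();

		for (i = 0; i < ctl->num; i++)
			tun_xdp_one(tun, tfile, &xdp[i], &flush);

		/* One flush for the whole batch amortizes the cost of
		 * any XDP_REDIRECT decisions made above.
		 */
		if (flush)
			xdp_do_flush_map();

		rcu_read_unlock();
		local_bh_enable();

		tun_put(tun);
		return total_len;
	}

	/* ... single-packet path unchanged ... */
}
```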
 static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
                                    struct vhost_net_virtqueue *nvq,
                                    unsigned int *out_num, unsigned int *in_num,
-                                   bool *busyloop_intr)
+                                   struct msghdr *msghdr, bool *busyloop_intr)
 {
        struct vhost_virtqueue *vq = &nvq->vq;
        unsigned long uninitialized_var(endtime);
                                  out_num, in_num, NULL, NULL);
 
        if (r == vq->num && vq->busyloop_timeout) {
+               /* Flush batched packets first */
                if (!vhost_sock_zcopy(vq->private_data))
-                       vhost_net_signal_used(nvq);
+                       vhost_tx_batch(net, nvq, vq->private_data, msghdr);
                preempt_disable();
                endtime = busy_clock() + vq->busyloop_timeout;
                while (vhost_can_busy_poll(endtime)) {
        struct vhost_virtqueue *vq = &nvq->vq;
        int ret;
 
-       ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, busyloop_intr);
+       ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, msg, busyloop_intr);
 
        if (ret < 0 || ret == vq->num)
                return ret;
               !vhost_vq_avail_empty(vq->dev, vq);
 }
 
+#define VHOST_NET_RX_PAD (NET_IP_ALIGN + NET_SKB_PAD)
+
+static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq,
+                              struct iov_iter *from)
+{
+       struct vhost_virtqueue *vq = &nvq->vq;
+       struct socket *sock = vq->private_data;
+       struct page_frag *alloc_frag = &current->task_frag;
+       struct virtio_net_hdr *gso;
+       struct xdp_buff *xdp = &nvq->xdp[nvq->batched_xdp];
+       struct tun_xdp_hdr *hdr;
+       size_t len = iov_iter_count(from);
+       int headroom = vhost_sock_xdp(sock) ? XDP_PACKET_HEADROOM : 0;
+       int buflen = SKB_DATA_ALIGN(sizeof(struct skb_shared_info));
+       int pad = SKB_DATA_ALIGN(VHOST_NET_RX_PAD + headroom + nvq->sock_hlen);
+       int sock_hlen = nvq->sock_hlen;
+       void *buf;
+       int copied;
+
+       if (unlikely(len < nvq->sock_hlen))
+               return -EFAULT;
+
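+       /* Packet, headroom and the trailing skb_shared_info must fit
+        * in a single page.
+        */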
+       if (SKB_DATA_ALIGN(len + pad) +
+           SKB_DATA_ALIGN(sizeof(struct skb_shared_info)) > PAGE_SIZE)
+               return -ENOSPC;
+
+       buflen += SKB_DATA_ALIGN(len + pad);
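+       /* Start each buffer on a cache-line boundary within the frag */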
+       alloc_frag->offset = ALIGN((u64)alloc_frag->offset, SMP_CACHE_BYTES);
+       if (unlikely(!skb_page_frag_refill(buflen, alloc_frag, GFP_KERNEL)))
+               return -ENOMEM;
+
+       buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
+       copied = copy_page_from_iter(alloc_frag->page,
+                                    alloc_frag->offset +
+                                    offsetof(struct tun_xdp_hdr, gso),
+                                    sock_hlen, from);
+       if (copied != sock_hlen)
+               return -EFAULT;
+
+       hdr = buf;
+       gso = &hdr->gso;
+
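+       /* hdr_len must cover the checksum area; grow it if needed, but
+        * never past the packet itself.
+        */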
+       if ((gso->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) &&
+           vhost16_to_cpu(vq, gso->csum_start) +
+           vhost16_to_cpu(vq, gso->csum_offset) + 2 >
+           vhost16_to_cpu(vq, gso->hdr_len)) {
+               gso->hdr_len = cpu_to_vhost16(vq,
+                              vhost16_to_cpu(vq, gso->csum_start) +
+                              vhost16_to_cpu(vq, gso->csum_offset) + 2);
+
+               if (vhost16_to_cpu(vq, gso->hdr_len) > len)
+                       return -EINVAL;
+       }
+
+       len -= sock_hlen;
+       copied = copy_page_from_iter(alloc_frag->page,
+                                    alloc_frag->offset + pad,
+                                    len, from);
+       if (copied != len)
+               return -EFAULT;
+
+       xdp->data_hard_start = buf;
+       xdp->data = buf + pad;
+       xdp->data_end = xdp->data + len;
+       hdr->buflen = buflen;
+
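+       /* The buffer is handed off to the socket; hold a page reference
+        * until the consumer releases it.
+        */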
+       get_page(alloc_frag->page);
+       alloc_frag->offset += buflen;
+
+       ++nvq->batched_xdp;
+
+       return 0;
+}
+
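To make the size arithmetic above concrete: the tun_xdp_hdr (carrying buflen and the virtio-net header) sits at the very start of the slice, the packet is copied in at offset pad, and the tail leaves room for skb_shared_info so tun can later wrap the buffer with build_skb() without a copy. A sketch of one batched buffer inside the page frag:

```c
/*
 *  buf                              buf + pad           buf + pad + len
 *   |                                   |                      |
 *   v                                   v                      v
 *   +--------------+--------------------+----------------+-----------------+
 *   | tun_xdp_hdr  | rest of headroom   | packet payload | skb_shared_info |
 *   | (buflen +    | (XDP headroom,     | (len bytes     | (reserved tail) |
 *   |  virtio hdr) |  NET_SKB_PAD, ...) |  from the iov) |                 |
 *   +--------------+--------------------+----------------+-----------------+
 *
 *  pad    = SKB_DATA_ALIGN(VHOST_NET_RX_PAD + headroom + sock_hlen)
 *  buflen = SKB_DATA_ALIGN(len + pad) +
 *           SKB_DATA_ALIGN(sizeof(struct skb_shared_info))
 */
```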
 static void handle_tx_copy(struct vhost_net *net, struct socket *sock)
 {
        struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
        size_t len, total_len = 0;
        int err;
        int sent_pkts = 0;
+       bool sock_can_batch = (sock->sk->sk_sndbuf == INT_MAX);
 
        for (;;) {
                bool busyloop_intr = false;
 
+               if (nvq->done_idx == VHOST_NET_BATCH)
+                       vhost_tx_batch(net, nvq, sock, &msg);
+
                head = get_tx_bufs(net, nvq, &msg, &out, &in, &len,
                                   &busyloop_intr);
                /* On error, stop handling until the next kick. */
                        break;
                }
 
-               vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head);
-               vq->heads[nvq->done_idx].len = 0;
-
                total_len += len;
-               if (tx_can_batch(vq, total_len))
-                       msg.msg_flags |= MSG_MORE;
-               else
-                       msg.msg_flags &= ~MSG_MORE;
+
+               /* For simplicity, TX batching is only enabled if
+                * sndbuf is unlimited: the XDP path bypasses the
+                * per-skb socket accounting that a finite sndbuf
+                * relies on.
+                */
+               if (sock_can_batch) {
+                       err = vhost_net_build_xdp(nvq, &msg.msg_iter);
+                       if (!err) {
+                               goto done;
+                       } else if (unlikely(err != -ENOSPC)) {
+                               vhost_tx_batch(net, nvq, sock, &msg);
+                               vhost_discard_vq_desc(vq, 1);
+                               vhost_net_enable_vq(net, vq);
+                               break;
+                       }
+
+                       /* The packet does not fit in an XDP buff
+                        * (-ENOSPC): flush batched packets and fall
+                        * back to the single-packet path.
+                        */
+                       vhost_tx_batch(net, nvq, sock, &msg);
+                       msg.msg_control = NULL;
+               } else {
+                       if (tx_can_batch(vq, total_len))
+                               msg.msg_flags |= MSG_MORE;
+                       else
+                               msg.msg_flags &= ~MSG_MORE;
+               }
 
                /* TODO: Check specific error and bomb out unless ENOBUFS? */
                err = sock->ops->sendmsg(sock, &msg, len);
                if (err != len)
                        pr_debug("Truncated TX packet: len %d != %zd\n",
                                 err, len);
-               if (++nvq->done_idx >= VHOST_NET_BATCH)
-                       vhost_net_signal_used(nvq);
+done:
+               vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head);
+               vq->heads[nvq->done_idx].len = 0;
+               ++nvq->done_idx;
                if (vhost_exceeds_weight(++sent_pkts, total_len)) {
                        vhost_poll_queue(&vq->poll);
                        break;
                }
        }
 
-       vhost_net_signal_used(nvq);
+       vhost_tx_batch(net, nvq, sock, &msg);
 }
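Putting the pieces together, the flush points are the subtle part of the reworked loop: a full VHOST_NET_BATCH of used heads, a packet that cannot be built as an XDP buff, a hard build error, busy polling, and loop exit all end up in vhost_tx_batch(). In outline:

```c
/* Control flow of the reworked handle_tx_copy() loop (sketch):
 *
 *   for (;;) {
 *           if (done_idx == VHOST_NET_BATCH)
 *                   vhost_tx_batch();         // full batch: flush
 *           head = get_tx_bufs();             // may flush via busy poll
 *           if (sock_can_batch) {
 *                   err = vhost_net_build_xdp();
 *                   if (!err)
 *                           goto done;        // queued in nvq->xdp[]
 *                   if (err != -ENOSPC)
 *                           break;            // flush, requeue desc, stop
 *                   vhost_tx_batch();         // oversized packet: flush,
 *           }                                 // fall through to sendmsg
 *           sendmsg(single packet);
 *   done:
 *           heads[done_idx++] = head;         // signalled later in batch
 *   }
 *   vhost_tx_batch();                         // final flush on exit
 */
```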
 
 static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock)
        struct vhost_dev *dev;
        struct vhost_virtqueue **vqs;
        void **queue;
+       struct xdp_buff *xdp;
        int i;
 
        n = kvmalloc(sizeof *n, GFP_KERNEL | __GFP_RETRY_MAYFAIL);
        }
        n->vqs[VHOST_NET_VQ_RX].rxq.queue = queue;
 
+       xdp = kmalloc_array(VHOST_NET_BATCH, sizeof(*xdp), GFP_KERNEL);
+       if (!xdp) {
+               kfree(vqs);
+               kvfree(n);
+               kfree(queue);
+               return -ENOMEM;
+       }
+       n->vqs[VHOST_NET_VQ_TX].xdp = xdp;
+
        dev = &n->dev;
        vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
        vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
                n->vqs[i].ubuf_info = NULL;
                n->vqs[i].upend_idx = 0;
                n->vqs[i].done_idx = 0;
+               n->vqs[i].batched_xdp = 0;
                n->vqs[i].vhost_hlen = 0;
                n->vqs[i].sock_hlen = 0;
                n->vqs[i].rx_ring = NULL;
         * since jobs can re-queue themselves. */
        vhost_net_flush(n);
        kfree(n->vqs[VHOST_NET_VQ_RX].rxq.queue);
+       kfree(n->vqs[VHOST_NET_VQ_TX].xdp);
        kfree(n->dev.vqs);
        kvfree(n);
        return 0;