u64 bytes;
        u64 napi_packets;
        u64 napi_bytes;
+       u64 xsk;
 };
 
 struct virtnet_sq_stats {
                                               struct sk_buff *curr_skb,
                                               struct page *page, void *buf,
                                               int len, int truesize);
+static void virtnet_xsk_completed(struct send_queue *sq, int num);
 
 enum virtnet_xmit_type {
        VIRTNET_XMIT_TYPE_SKB,
        VIRTNET_XMIT_TYPE_SKB_ORPHAN,
        VIRTNET_XMIT_TYPE_XDP,
+       VIRTNET_XMIT_TYPE_XSK,
 };
 
 static int rss_indirection_table_alloc(struct virtio_net_ctrl_rss *rss, u16 indir_table_size)
 /* We use the last two bits of the pointer to distinguish the xmit type. */
 #define VIRTNET_XMIT_TYPE_MASK (BIT(0) | BIT(1))
 
+#define VIRTIO_XSK_FLAG_OFFSET 2
+
 static enum virtnet_xmit_type virtnet_xmit_ptr_unpack(void **ptr)
 {
        unsigned long p = (unsigned long)*ptr;
                                    GFP_ATOMIC);
 }
 
+static u32 virtnet_ptr_to_xsk_buff_len(void *ptr)
+{
+       return ((unsigned long)ptr) >> VIRTIO_XSK_FLAG_OFFSET;
+}
+
 static void sg_fill_dma(struct scatterlist *sg, dma_addr_t addr, u32 len)
 {
        sg_dma_address(sg) = addr;
                        stats->bytes += xdp_get_frame_len(frame);
                        xdp_return_frame(frame);
                        break;
+
+               case VIRTNET_XMIT_TYPE_XSK:
+                       stats->bytes += virtnet_ptr_to_xsk_buff_len(ptr);
+                       stats->xsk++;
+                       break;
                }
        }
        netdev_tx_completed_queue(txq, stats->napi_packets, stats->napi_bytes);
 }
 
+static void virtnet_free_old_xmit(struct send_queue *sq,
+                                 struct netdev_queue *txq,
+                                 bool in_napi,
+                                 struct virtnet_sq_free_stats *stats)
+{
+       __free_old_xmit(sq, txq, in_napi, stats);
+
+       if (stats->xsk)
+               virtnet_xsk_completed(sq, stats->xsk);
+}
+
 /* Converting between virtqueue no. and kernel tx/rx queue no.
  * 0:rx0 1:tx0 2:rx1 3:tx1 ... 2N:rxN 2N+1:txN 2N+2:cvq
  */
 {
        struct virtnet_sq_free_stats stats = {0};
 
-       __free_old_xmit(sq, txq, in_napi, &stats);
+       virtnet_free_old_xmit(sq, txq, in_napi, &stats);
 
        /* Avoid overhead when no packets have been processed
         * happens when called speculatively from start_xmit.
        return err;
 }
 
+static void *virtnet_xsk_to_ptr(u32 len)
+{
+       unsigned long p;
+
+       p = len << VIRTIO_XSK_FLAG_OFFSET;
+
+       return virtnet_xmit_ptr_pack((void *)p, VIRTNET_XMIT_TYPE_XSK);
+}
+
+static int virtnet_xsk_xmit_one(struct send_queue *sq,
+                               struct xsk_buff_pool *pool,
+                               struct xdp_desc *desc)
+{
+       struct virtnet_info *vi;
+       dma_addr_t addr;
+
+       vi = sq->vq->vdev->priv;
+
+       addr = xsk_buff_raw_get_dma(pool, desc->addr);
+       xsk_buff_raw_dma_sync_for_device(pool, addr, desc->len);
+
+       sg_init_table(sq->sg, 2);
+       sg_fill_dma(sq->sg, sq->xsk_hdr_dma_addr, vi->hdr_len);
+       sg_fill_dma(sq->sg + 1, addr, desc->len);
+
+       return virtqueue_add_outbuf_premapped(sq->vq, sq->sg, 2,
+                                             virtnet_xsk_to_ptr(desc->len),
+                                             GFP_ATOMIC);
+}
+
+static int virtnet_xsk_xmit_batch(struct send_queue *sq,
+                                 struct xsk_buff_pool *pool,
+                                 unsigned int budget,
+                                 u64 *kicks)
+{
+       struct xdp_desc *descs = pool->tx_descs;
+       bool kick = false;
+       u32 nb_pkts, i;
+       int err;
+
+       budget = min_t(u32, budget, sq->vq->num_free);
+
+       nb_pkts = xsk_tx_peek_release_desc_batch(pool, budget);
+       if (!nb_pkts)
+               return 0;
+
+       for (i = 0; i < nb_pkts; i++) {
+               err = virtnet_xsk_xmit_one(sq, pool, &descs[i]);
+               if (unlikely(err)) {
+                       xsk_tx_completed(sq->xsk_pool, nb_pkts - i);
+                       break;
+               }
+
+               kick = true;
+       }
+
+       if (kick && virtqueue_kick_prepare(sq->vq) && virtqueue_notify(sq->vq))
+               (*kicks)++;
+
+       return i;
+}
+
+static bool virtnet_xsk_xmit(struct send_queue *sq, struct xsk_buff_pool *pool,
+                            int budget)
+{
+       struct virtnet_info *vi = sq->vq->vdev->priv;
+       struct virtnet_sq_free_stats stats = {};
+       struct net_device *dev = vi->dev;
+       u64 kicks = 0;
+       int sent;
+
+       /* Avoid to wakeup napi meanless, so call __free_old_xmit instead of
+        * free_old_xmit().
+        */
+       __free_old_xmit(sq, netdev_get_tx_queue(dev, sq - vi->sq), true, &stats);
+
+       if (stats.xsk)
+               xsk_tx_completed(sq->xsk_pool, stats.xsk);
+
+       sent = virtnet_xsk_xmit_batch(sq, pool, budget, &kicks);
+
+       if (!is_xdp_raw_buffer_queue(vi, sq - vi->sq))
+               check_sq_full_and_disable(vi, vi->dev, sq);
+
+       u64_stats_update_begin(&sq->stats.syncp);
+       u64_stats_add(&sq->stats.packets, stats.packets);
+       u64_stats_add(&sq->stats.bytes,   stats.bytes);
+       u64_stats_add(&sq->stats.kicks,   kicks);
+       u64_stats_add(&sq->stats.xdp_tx,  sent);
+       u64_stats_update_end(&sq->stats.syncp);
+
+       if (xsk_uses_need_wakeup(pool))
+               xsk_set_tx_need_wakeup(pool);
+
+       return sent;
+}
+
+static void xsk_wakeup(struct send_queue *sq)
+{
+       if (napi_if_scheduled_mark_missed(&sq->napi))
+               return;
+
+       local_bh_disable();
+       virtqueue_napi_schedule(&sq->napi, sq->vq);
+       local_bh_enable();
+}
+
 static int virtnet_xsk_wakeup(struct net_device *dev, u32 qid, u32 flag)
 {
        struct virtnet_info *vi = netdev_priv(dev);
 
        sq = &vi->sq[qid];
 
-       if (napi_if_scheduled_mark_missed(&sq->napi))
-               return 0;
+       xsk_wakeup(sq);
+       return 0;
+}
 
-       local_bh_disable();
-       virtqueue_napi_schedule(&sq->napi, sq->vq);
-       local_bh_enable();
+static void virtnet_xsk_completed(struct send_queue *sq, int num)
+{
+       xsk_tx_completed(sq->xsk_pool, num);
 
-       return 0;
+       /* If this is called by rx poll, start_xmit and xdp xmit we should
+        * wakeup the tx napi to consume the xsk tx queue, because the tx
+        * interrupt may not be triggered.
+        */
+       xsk_wakeup(sq);
 }
 
 static int __virtnet_xdp_xmit_one(struct virtnet_info *vi,
        }
 
        /* Free up any pending old buffers before queueing new ones. */
-       __free_old_xmit(sq, netdev_get_tx_queue(dev, sq - vi->sq),
-                       false, &stats);
+       virtnet_free_old_xmit(sq, netdev_get_tx_queue(dev, sq - vi->sq),
+                             false, &stats);
 
        for (i = 0; i < n; i++) {
                struct xdp_frame *xdpf = frames[i];
        struct virtnet_info *vi = sq->vq->vdev->priv;
        unsigned int index = vq2txq(sq->vq);
        struct netdev_queue *txq;
-       int opaque;
+       int opaque, xsk_done = 0;
        bool done;
 
        if (unlikely(is_xdp_raw_buffer_queue(vi, index))) {
        txq = netdev_get_tx_queue(vi->dev, index);
        __netif_tx_lock(txq, raw_smp_processor_id());
        virtqueue_disable_cb(sq->vq);
-       free_old_xmit(sq, txq, !!budget);
+
+       if (sq->xsk_pool)
+               xsk_done = virtnet_xsk_xmit(sq, sq->xsk_pool, budget);
+       else
+               free_old_xmit(sq, txq, !!budget);
 
        if (sq->vq->num_free >= 2 + MAX_SKB_FRAGS) {
                if (netif_tx_queue_stopped(txq)) {
                netif_tx_wake_queue(txq);
        }
 
+       if (xsk_done >= budget) {
+               __netif_tx_unlock(txq);
+               return budget;
+       }
+
        opaque = virtqueue_enable_cb_prepare(sq->vq);
 
        done = napi_complete_done(napi, 0);
 
 static void virtnet_sq_free_unused_buf(struct virtqueue *vq, void *buf)
 {
+       struct virtnet_info *vi = vq->vdev->priv;
+       struct send_queue *sq;
+       int i = vq2rxq(vq);
+
+       sq = &vi->sq[i];
+
        switch (virtnet_xmit_ptr_unpack(&buf)) {
        case VIRTNET_XMIT_TYPE_SKB:
        case VIRTNET_XMIT_TYPE_SKB_ORPHAN:
        case VIRTNET_XMIT_TYPE_XDP:
                xdp_return_frame(buf);
                break;
+
+       case VIRTNET_XMIT_TYPE_XSK:
+               xsk_tx_completed(sq->xsk_pool, 1);
+               break;
        }
 }