#define VIRTIO_XDP_TX          BIT(0)
 #define VIRTIO_XDP_REDIR       BIT(1)
 
-#define VIRTIO_XDP_FLAG                BIT(0)
-#define VIRTIO_ORPHAN_FLAG     BIT(1)
-
 /* RX packet size EWMA. The average packet size is used to determine the packet
  * buffer size when refilling RX rings. As the entire RX ring may be refilled
  * at once, the weight is chosen so that the EWMA will be insensitive to short-
                                               struct page *page, void *buf,
                                               int len, int truesize);
 
/* Tag packed into the low bits of every tx virtqueue token (see
 * virtnet_xmit_ptr_pack()/virtnet_xmit_ptr_unpack()); tells completion
 * handling how to account for and free the buffer.
 */
enum virtnet_xmit_type {
	VIRTNET_XMIT_TYPE_SKB,		/* skb: counted in napi_packets/napi_bytes */
	VIRTNET_XMIT_TYPE_SKB_ORPHAN,	/* skb: counted in packets/bytes instead */
	VIRTNET_XMIT_TYPE_XDP,		/* xdp_frame: freed via xdp_return_frame() */
};
+
 static int rss_indirection_table_alloc(struct virtio_net_ctrl_rss *rss, u16 indir_table_size)
 {
        if (!indir_table_size) {
        kfree(rss->indirection_table);
 }
 
-static bool is_xdp_frame(void *ptr)
-{
-       return (unsigned long)ptr & VIRTIO_XDP_FLAG;
-}
+/* We use the last two bits of the pointer to distinguish the xmit type. */
+#define VIRTNET_XMIT_TYPE_MASK (BIT(0) | BIT(1))
 
-static void *xdp_to_ptr(struct xdp_frame *ptr)
+static enum virtnet_xmit_type virtnet_xmit_ptr_unpack(void **ptr)
 {
-       return (void *)((unsigned long)ptr | VIRTIO_XDP_FLAG);
-}
+       unsigned long p = (unsigned long)*ptr;
 
-static struct xdp_frame *ptr_to_xdp(void *ptr)
-{
-       return (struct xdp_frame *)((unsigned long)ptr & ~VIRTIO_XDP_FLAG);
-}
+       *ptr = (void *)(p & ~VIRTNET_XMIT_TYPE_MASK);
 
-static bool is_orphan_skb(void *ptr)
-{
-       return (unsigned long)ptr & VIRTIO_ORPHAN_FLAG;
+       return p & VIRTNET_XMIT_TYPE_MASK;
 }
 
-static void *skb_to_ptr(struct sk_buff *skb, bool orphan)
+static void *virtnet_xmit_ptr_pack(void *ptr, enum virtnet_xmit_type type)
 {
-       return (void *)((unsigned long)skb | (orphan ? VIRTIO_ORPHAN_FLAG : 0));
+       return (void *)((unsigned long)ptr | type);
 }
 
/* Queue @data (already described by sq->sg, @num entries) on the send
 * virtqueue, tagging the token with @type so tx completion knows how to
 * account for and free it.  Propagates virtqueue_add_outbuf()'s return
 * value.
 */
static int virtnet_add_outbuf(struct send_queue *sq, int num, void *data,
			      enum virtnet_xmit_type type)
{
	return virtqueue_add_outbuf(sq->vq, sq->sg, num,
				    virtnet_xmit_ptr_pack(data, type),
				    GFP_ATOMIC);
}
 
 static void sg_fill_dma(struct scatterlist *sg, dma_addr_t addr, u32 len)
/* Reap completed tx buffers from the virtqueue, dispatching on the
 * xmit-type tag packed into each token, and accumulate counts into
 * @stats.  Only NAPI-tracked skbs (VIRTNET_XMIT_TYPE_SKB) feed the BQL
 * completion accounting below; orphaned skbs and XDP frames go into the
 * plain packets/bytes counters.
 */
static void __free_old_xmit(struct send_queue *sq, struct netdev_queue *txq,
			    bool in_napi, struct virtnet_sq_free_stats *stats)
{
	struct xdp_frame *frame;
	struct sk_buff *skb;
	unsigned int len;
	void *ptr;

	while ((ptr = virtqueue_get_buf(sq->vq, &len)) != NULL) {
		/* unpack strips the tag, leaving ptr a plain skb/frame pointer */
		switch (virtnet_xmit_ptr_unpack(&ptr)) {
		case VIRTNET_XMIT_TYPE_SKB:
			skb = ptr;

			pr_debug("Sent skb %p\n", skb);
			stats->napi_packets++;
			stats->napi_bytes += skb->len;
			napi_consume_skb(skb, in_napi);
			break;

		case VIRTNET_XMIT_TYPE_SKB_ORPHAN:
			skb = ptr;

			stats->packets++;
			stats->bytes += skb->len;
			napi_consume_skb(skb, in_napi);
			break;

		case VIRTNET_XMIT_TYPE_XDP:
			frame = ptr;

			stats->packets++;
			stats->bytes += xdp_get_frame_len(frame);
			xdp_return_frame(frame);
			break;
		}
	}
	/* Report only napi-tracked traffic to byte queue limits. */
	netdev_tx_completed_queue(txq, stats->napi_packets, stats->napi_bytes);
                            skb_frag_size(frag), skb_frag_off(frag));
        }
 
-       err = virtqueue_add_outbuf(sq->vq, sq->sg, nr_frags + 1,
-                                  xdp_to_ptr(xdpf), GFP_ATOMIC);
+       err = virtnet_add_outbuf(sq, nr_frags + 1, xdpf, VIRTNET_XMIT_TYPE_XDP);
        if (unlikely(err))
                return -ENOSPC; /* Caller handle free/refcnt */
 
                        return num_sg;
                num_sg++;
        }
-       return virtqueue_add_outbuf(sq->vq, sq->sg, num_sg,
-                                   skb_to_ptr(skb, orphan), GFP_ATOMIC);
+
+       return virtnet_add_outbuf(sq, num_sg, skb,
+                                 orphan ? VIRTNET_XMIT_TYPE_SKB_ORPHAN : VIRTNET_XMIT_TYPE_SKB);
 }
 
 static netdev_tx_t start_xmit(struct sk_buff *skb, struct net_device *dev)
 
/* Release a buffer recovered unused from the send virtqueue, choosing
 * the free routine from the xmit-type tag packed into the token pointer.
 * No default case: the switch is intentionally exhaustive over
 * enum virtnet_xmit_type so the compiler flags any new, unhandled type.
 */
static void virtnet_sq_free_unused_buf(struct virtqueue *vq, void *buf)
{
	switch (virtnet_xmit_ptr_unpack(&buf)) {
	case VIRTNET_XMIT_TYPE_SKB:
	case VIRTNET_XMIT_TYPE_SKB_ORPHAN:
		dev_kfree_skb(buf);
		break;

	case VIRTNET_XMIT_TYPE_XDP:
		xdp_return_frame(buf);
		break;
	}
}
 
 static void free_unused_bufs(struct virtnet_info *vi)