#include <linux/module.h>
 #include <linux/virtio.h>
 #include <linux/virtio_net.h>
+#include <linux/bpf.h>
 #include <linux/scatterlist.h>
 #include <linux/if_vlan.h>
 #include <linux/slab.h>
 
        struct napi_struct napi;
 
+       struct bpf_prog __rcu *xdp_prog;
+
        /* Chain pages by the private ptr. */
        struct page *pages;
 
        return skb;
 }
 
+static u32 do_xdp_prog(struct virtnet_info *vi,
+                      struct bpf_prog *xdp_prog,
+                      struct page *page, int offset, int len)
+{
+       int hdr_padded_len;
+       struct xdp_buff xdp;
+       u32 act;
+       u8 *buf;
+
+       buf = page_address(page) + offset;
+
+       if (vi->mergeable_rx_bufs)
+               hdr_padded_len = sizeof(struct virtio_net_hdr_mrg_rxbuf);
+       else
+               hdr_padded_len = sizeof(struct padded_vnet_hdr);
+
+       xdp.data = buf + hdr_padded_len;
+       xdp.data_end = xdp.data + (len - vi->hdr_len);
+
+       act = bpf_prog_run_xdp(xdp_prog, &xdp);
+       switch (act) {
+       case XDP_PASS:
+               return XDP_PASS;
+       default:
+               bpf_warn_invalid_xdp_action(act);
+       case XDP_TX:
+       case XDP_ABORTED:
+       case XDP_DROP:
+               return XDP_DROP;
+       }
+}
+
 static struct sk_buff *receive_small(struct virtnet_info *vi, void *buf, unsigned int len)
 {
        struct sk_buff * skb = buf;
                                   void *buf,
                                   unsigned int len)
 {
+       struct bpf_prog *xdp_prog;
        struct page *page = buf;
-       struct sk_buff *skb = page_to_skb(vi, rq, page, 0, len, PAGE_SIZE);
+       struct sk_buff *skb;
 
+       rcu_read_lock();
+       xdp_prog = rcu_dereference(rq->xdp_prog);
+       if (xdp_prog) {
+               struct virtio_net_hdr_mrg_rxbuf *hdr = buf;
+               u32 act;
+
+               if (unlikely(hdr->hdr.gso_type || hdr->hdr.flags))
+                       goto err_xdp;
+               act = do_xdp_prog(vi, xdp_prog, page, 0, len);
+               if (act == XDP_DROP)
+                       goto err_xdp;
+       }
+       rcu_read_unlock();
+
+       skb = page_to_skb(vi, rq, page, 0, len, PAGE_SIZE);
        if (unlikely(!skb))
                goto err;
 
        return skb;
 
+err_xdp:
+       rcu_read_unlock();
 err:
        dev->stats.rx_dropped++;
        give_pages(rq, page);
        u16 num_buf = virtio16_to_cpu(vi->vdev, hdr->num_buffers);
        struct page *page = virt_to_head_page(buf);
        int offset = buf - page_address(page);
-       unsigned int truesize = max(len, mergeable_ctx_to_buf_truesize(ctx));
+       struct sk_buff *head_skb, *curr_skb;
+       struct bpf_prog *xdp_prog;
+       unsigned int truesize;
+
+       rcu_read_lock();
+       xdp_prog = rcu_dereference(rq->xdp_prog);
+       if (xdp_prog) {
+               u32 act;
+
+               /* No known backend devices should send packets with
+                * more than a single buffer when XDP conditions are
+                * met. However it is not strictly illegal so the case
+                * is handled as an exception and a warning is thrown.
+                */
+               if (unlikely(num_buf > 1)) {
+                       bpf_warn_invalid_xdp_buffer();
+                       goto err_xdp;
+               }
+
+               /* Transient failure which in theory could occur if
+                * in-flight packets from before XDP was enabled reach
+                * the receive path after XDP is loaded. In practice I
+                * was not able to create this condition.
+                */
+               if (unlikely(hdr->hdr.gso_type || hdr->hdr.flags))
+                       goto err_xdp;
+
+               act = do_xdp_prog(vi, xdp_prog, page, offset, len);
+               if (act == XDP_DROP)
+                       goto err_xdp;
+       }
+       rcu_read_unlock();
 
-       struct sk_buff *head_skb = page_to_skb(vi, rq, page, offset, len,
-                                              truesize);
-       struct sk_buff *curr_skb = head_skb;
+       truesize = max(len, mergeable_ctx_to_buf_truesize(ctx));
+       head_skb = page_to_skb(vi, rq, page, offset, len, truesize);
+       curr_skb = head_skb;
 
        if (unlikely(!curr_skb))
                goto err_skb;
        ewma_pkt_len_add(&rq->mrg_avg_pkt_len, head_skb->len);
        return head_skb;
 
+err_xdp:
+       rcu_read_unlock();
 err_skb:
        put_page(page);
        while (--num_buf) {
        if (queue_pairs > vi->max_queue_pairs || queue_pairs == 0)
                return -EINVAL;
 
+       /* For now we don't support modifying channels while XDP is loaded
+        * also when XDP is loaded all RX queues have XDP programs so we only
+        * need to check a single RX queue.
+        */
+       if (vi->rq[0].xdp_prog)
+               return -EINVAL;
+
        get_online_cpus();
        err = virtnet_set_queues(vi, queue_pairs);
        if (!err) {
        .set_settings = virtnet_set_settings,
 };
 
+static int virtnet_xdp_set(struct net_device *dev, struct bpf_prog *prog)
+{
+       unsigned long int max_sz = PAGE_SIZE - sizeof(struct padded_vnet_hdr);
+       struct virtnet_info *vi = netdev_priv(dev);
+       struct bpf_prog *old_prog;
+       int i;
+
+       if (virtio_has_feature(vi->vdev, VIRTIO_NET_F_GUEST_TSO4) ||
+           virtio_has_feature(vi->vdev, VIRTIO_NET_F_GUEST_TSO6)) {
+               netdev_warn(dev, "can't set XDP while host is implementing LRO, disable LRO first\n");
+               return -EOPNOTSUPP;
+       }
+
+       if (vi->mergeable_rx_bufs && !vi->any_header_sg) {
+               netdev_warn(dev, "XDP expects header/data in single page, any_header_sg required\n");
+               return -EINVAL;
+       }
+
+       if (dev->mtu > max_sz) {
+               netdev_warn(dev, "XDP requires MTU less than %lu\n", max_sz);
+               return -EINVAL;
+       }
+
+       if (prog) {
+               prog = bpf_prog_add(prog, vi->max_queue_pairs - 1);
+               if (IS_ERR(prog))
+                       return PTR_ERR(prog);
+       }
+
+       for (i = 0; i < vi->max_queue_pairs; i++) {
+               old_prog = rtnl_dereference(vi->rq[i].xdp_prog);
+               rcu_assign_pointer(vi->rq[i].xdp_prog, prog);
+               if (old_prog)
+                       bpf_prog_put(old_prog);
+       }
+
+       return 0;
+}
+
+static bool virtnet_xdp_query(struct net_device *dev)
+{
+       struct virtnet_info *vi = netdev_priv(dev);
+       int i;
+
+       for (i = 0; i < vi->max_queue_pairs; i++) {
+               if (vi->rq[i].xdp_prog)
+                       return true;
+       }
+       return false;
+}
+
+static int virtnet_xdp(struct net_device *dev, struct netdev_xdp *xdp)
+{
+       switch (xdp->command) {
+       case XDP_SETUP_PROG:
+               return virtnet_xdp_set(dev, xdp->prog);
+       case XDP_QUERY_PROG:
+               xdp->prog_attached = virtnet_xdp_query(dev);
+               return 0;
+       default:
+               return -EINVAL;
+       }
+}
+
 static const struct net_device_ops virtnet_netdev = {
        .ndo_open            = virtnet_open,
        .ndo_stop            = virtnet_close,
 #ifdef CONFIG_NET_RX_BUSY_POLL
        .ndo_busy_poll          = virtnet_busy_poll,
 #endif
+       .ndo_xdp                = virtnet_xdp,
 };
 
 static void virtnet_config_changed_work(struct work_struct *work)
 
 static void free_receive_bufs(struct virtnet_info *vi)
 {
+       struct bpf_prog *old_prog;
        int i;
 
+       rtnl_lock();
        for (i = 0; i < vi->max_queue_pairs; i++) {
                while (vi->rq[i].pages)
                        __free_pages(get_a_page(&vi->rq[i], GFP_KERNEL), 0);
+
+               old_prog = rtnl_dereference(vi->rq[i].xdp_prog);
+               RCU_INIT_POINTER(vi->rq[i].xdp_prog, NULL);
+               if (old_prog)
+                       bpf_prog_put(old_prog);
        }
+       rtnl_unlock();
 }
 
 static void free_receive_page_frags(struct virtnet_info *vi)