#include <linux/if_vlan.h>
 #include <linux/slab.h>
 #include <linux/cpu.h>
+#include <linux/average.h>
 
 static int napi_weight = NAPI_POLL_WEIGHT;
 module_param(napi_weight, int, 0444);
 
 /* FIXME: MTU in config. */
 #define GOOD_PACKET_LEN (ETH_HLEN + VLAN_HLEN + ETH_DATA_LEN)
-#define MERGE_BUFFER_LEN (ALIGN(GOOD_PACKET_LEN + \
-                                sizeof(struct virtio_net_hdr_mrg_rxbuf), \
-                                L1_CACHE_BYTES))
 #define GOOD_COPY_LEN  128
 
+/* Weight used for the RX packet size EWMA. The average packet size is used to
+ * determine the packet buffer size when refilling RX rings. As the entire RX
+ * ring may be refilled at once, the weight is chosen so that the EWMA will be
+ * insensitive to short-term, transient changes in packet size.
+ */
+#define RECEIVE_AVG_WEIGHT 64
+
+/* Minimum alignment for mergeable packet buffers. */
+#define MERGEABLE_BUFFER_ALIGN max(L1_CACHE_BYTES, 256)
+
 #define VIRTNET_DRIVER_VERSION "1.0.0"
 
 struct virtnet_stats {
        /* Chain pages by the private ptr. */
        struct page *pages;
 
+       /* Average packet length for mergeable receive buffers. */
+       struct ewma mrg_avg_pkt_len;
+
        /* Page frag for packet buffer allocation. */
        struct page_frag alloc_frag;
 
        netif_wake_subqueue(vi->dev, vq2txq(vq));
 }
 
+static unsigned int mergeable_ctx_to_buf_truesize(unsigned long mrg_ctx)
+{
+       unsigned int truesize = mrg_ctx & (MERGEABLE_BUFFER_ALIGN - 1);
+       return (truesize + 1) * MERGEABLE_BUFFER_ALIGN;
+}
+
+static void *mergeable_ctx_to_buf_address(unsigned long mrg_ctx)
+{
+       return (void *)(mrg_ctx & -MERGEABLE_BUFFER_ALIGN);
+
+}
+
+static unsigned long mergeable_buf_to_ctx(void *buf, unsigned int truesize)
+{
+       unsigned int size = truesize / MERGEABLE_BUFFER_ALIGN;
+       return (unsigned long)buf | (size - 1);
+}
+
 /* Called from bottom half context */
 static struct sk_buff *page_to_skb(struct receive_queue *rq,
                                   struct page *page, unsigned int offset,
 
 static struct sk_buff *receive_mergeable(struct net_device *dev,
                                         struct receive_queue *rq,
-                                        void *buf,
+                                        unsigned long ctx,
                                         unsigned int len)
 {
+       void *buf = mergeable_ctx_to_buf_address(ctx);
        struct skb_vnet_hdr *hdr = buf;
        int num_buf = hdr->mhdr.num_buffers;
        struct page *page = virt_to_head_page(buf);
        int offset = buf - page_address(page);
-       unsigned int truesize = max_t(unsigned int, len, MERGE_BUFFER_LEN);
+       unsigned int truesize = max(len, mergeable_ctx_to_buf_truesize(ctx));
+
        struct sk_buff *head_skb = page_to_skb(rq, page, offset, len, truesize);
        struct sk_buff *curr_skb = head_skb;
 
        if (unlikely(!curr_skb))
                goto err_skb;
-
        while (--num_buf) {
                int num_skb_frags;
 
-               buf = virtqueue_get_buf(rq->vq, &len);
-               if (unlikely(!buf)) {
+               ctx = (unsigned long)virtqueue_get_buf(rq->vq, &len);
+               if (unlikely(!ctx)) {
                        pr_debug("%s: rx error: %d buffers out of %d missing\n",
                                 dev->name, num_buf, hdr->mhdr.num_buffers);
                        dev->stats.rx_length_errors++;
                        goto err_buf;
                }
 
+               buf = mergeable_ctx_to_buf_address(ctx);
                page = virt_to_head_page(buf);
 
                num_skb_frags = skb_shinfo(curr_skb)->nr_frags;
                        head_skb->truesize += nskb->truesize;
                        num_skb_frags = 0;
                }
-               truesize = max_t(unsigned int, len, MERGE_BUFFER_LEN);
+               truesize = max(len, mergeable_ctx_to_buf_truesize(ctx));
                if (curr_skb != head_skb) {
                        head_skb->data_len += len;
                        head_skb->len += len;
                }
        }
 
+       ewma_add(&rq->mrg_avg_pkt_len, head_skb->len);
        return head_skb;
 
 err_skb:
        put_page(page);
        while (--num_buf) {
-               buf = virtqueue_get_buf(rq->vq, &len);
-               if (unlikely(!buf)) {
+               ctx = (unsigned long)virtqueue_get_buf(rq->vq, &len);
+               if (unlikely(!ctx)) {
                        pr_debug("%s: rx error: %d buffers missing\n",
                                 dev->name, num_buf);
                        dev->stats.rx_length_errors++;
                        break;
                }
-               page = virt_to_head_page(buf);
+               page = virt_to_head_page(mergeable_ctx_to_buf_address(ctx));
                put_page(page);
        }
 err_buf:
        if (unlikely(len < sizeof(struct virtio_net_hdr) + ETH_HLEN)) {
                pr_debug("%s: short packet %i\n", dev->name, len);
                dev->stats.rx_length_errors++;
-               if (vi->mergeable_rx_bufs)
-                       put_page(virt_to_head_page(buf));
-               else if (vi->big_packets)
+               if (vi->mergeable_rx_bufs) {
+                       unsigned long ctx = (unsigned long)buf;
+                       void *base = mergeable_ctx_to_buf_address(ctx);
+                       put_page(virt_to_head_page(base));
+               } else if (vi->big_packets) {
                        give_pages(rq, buf);
-               else
+               } else {
                        dev_kfree_skb(buf);
+               }
                return;
        }
 
        if (vi->mergeable_rx_bufs)
-               skb = receive_mergeable(dev, rq, buf, len);
+               skb = receive_mergeable(dev, rq, (unsigned long)buf, len);
        else if (vi->big_packets)
                skb = receive_big(dev, rq, buf, len);
        else
 
 static int add_recvbuf_mergeable(struct receive_queue *rq, gfp_t gfp)
 {
+       const size_t hdr_len = sizeof(struct virtio_net_hdr_mrg_rxbuf);
        struct page_frag *alloc_frag = &rq->alloc_frag;
        char *buf;
+       unsigned long ctx;
        int err;
        unsigned int len, hole;
 
-       if (unlikely(!skb_page_frag_refill(MERGE_BUFFER_LEN, alloc_frag, gfp)))
+       len = hdr_len + clamp_t(unsigned int, ewma_read(&rq->mrg_avg_pkt_len),
+                               GOOD_PACKET_LEN, PAGE_SIZE - hdr_len);
+       len = ALIGN(len, MERGEABLE_BUFFER_ALIGN);
+       if (unlikely(!skb_page_frag_refill(len, alloc_frag, gfp)))
                return -ENOMEM;
+
        buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
+       ctx = mergeable_buf_to_ctx(buf, len);
        get_page(alloc_frag->page);
-       len = MERGE_BUFFER_LEN;
        alloc_frag->offset += len;
        hole = alloc_frag->size - alloc_frag->offset;
-       if (hole < MERGE_BUFFER_LEN) {
+       if (hole < len) {
+               /* To avoid internal fragmentation, if there is very likely not
+                * enough space for another buffer, add the remaining space to
+                * the current buffer. This extra space is not included in
+                * the truesize stored in ctx.
+                */
                len += hole;
                alloc_frag->offset += hole;
        }
 
        sg_init_one(rq->sg, buf, len);
-       err = virtqueue_add_inbuf(rq->vq, rq->sg, 1, buf, gfp);
+       err = virtqueue_add_inbuf(rq->vq, rq->sg, 1, (void *)ctx, gfp);
        if (err < 0)
                put_page(virt_to_head_page(buf));
 
                struct virtqueue *vq = vi->rq[i].vq;
 
                while ((buf = virtqueue_detach_unused_buf(vq)) != NULL) {
-                       if (vi->mergeable_rx_bufs)
-                               put_page(virt_to_head_page(buf));
-                       else if (vi->big_packets)
+                       if (vi->mergeable_rx_bufs) {
+                               unsigned long ctx = (unsigned long)buf;
+                               void *base = mergeable_ctx_to_buf_address(ctx);
+                               put_page(virt_to_head_page(base));
+                       } else if (vi->big_packets) {
                                give_pages(&vi->rq[i], buf);
-                       else
+                       } else {
                                dev_kfree_skb(buf);
+                       }
                }
        }
 }
                               napi_weight);
 
                sg_init_table(vi->rq[i].sg, ARRAY_SIZE(vi->rq[i].sg));
+               ewma_init(&vi->rq[i].mrg_avg_pkt_len, 1, RECEIVE_AVG_WEIGHT);
                sg_init_table(vi->sq[i].sg, ARRAY_SIZE(vi->sq[i].sg));
        }