 #include <linux/if_xdp.h>
 #include <linux/types.h>
 #include <linux/dma-mapping.h>
+#include <linux/bpf.h>
 #include <net/xdp.h>
 
 struct xsk_buff_pool;
        u32 free_heads_cnt;
        u32 headroom;
        u32 chunk_size;
+       u32 chunk_shift;
        u32 frame_len;
        u8 cached_need_wakeup;
        bool uses_need_wakeup;
        struct xdp_buff_xsk *free_heads[];
 };
 
+/* Masks for xdp_umem_page flags.
+ * The low 12 bits of the addr will be 0 since this is the page address, so we
+ * can use them for flags.
+ */
+#define XSK_NEXT_PG_CONTIG_SHIFT 0
+#define XSK_NEXT_PG_CONTIG_MASK BIT_ULL(XSK_NEXT_PG_CONTIG_SHIFT)
+
 /* AF_XDP core. */
 struct xsk_buff_pool *xp_create_and_assign_umem(struct xdp_sock *xs,
                                                struct xdp_umem *umem);
 int xp_assign_dev_shared(struct xsk_buff_pool *pool, struct xdp_umem *umem,
                         struct net_device *dev, u16 queue_id);
 void xp_destroy(struct xsk_buff_pool *pool);
-void xp_release(struct xdp_buff_xsk *xskb);
 void xp_get_pool(struct xsk_buff_pool *pool);
 bool xp_put_pool(struct xsk_buff_pool *pool);
 void xp_clear_dev(struct xsk_buff_pool *pool);
 /* AF_XDP, and XDP core. */
 void xp_free(struct xdp_buff_xsk *xskb);
 
+/* Set the buffer's base address within the umem and derive its
+ * data_hard_start pointer, which sits past the pool's configured headroom.
+ */
+static inline void xp_init_xskb_addr(struct xdp_buff_xsk *xskb, struct xsk_buff_pool *pool,
+                                    u64 addr)
+{
+       xskb->orig_addr = addr;
+       xskb->xdp.data_hard_start = pool->addrs + addr + pool->headroom;
+}
+
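+/* Derive the buffer's DMA addresses from the per-page DMA mappings: mask off
+ * the contiguity flag kept in the low bits of the page address and add the
+ * offset within the page. xskb->dma then points at the start of packet data.
+ */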
+static inline void xp_init_xskb_dma(struct xdp_buff_xsk *xskb, struct xsk_buff_pool *pool,
+                                   dma_addr_t *dma_pages, u64 addr)
+{
+       xskb->frame_dma = (dma_pages[addr >> PAGE_SHIFT] & ~XSK_NEXT_PG_CONTIG_MASK) +
+               (addr & ~PAGE_MASK);
+       xskb->dma = xskb->frame_dma + pool->headroom + XDP_PACKET_HEADROOM;
+}
+
 /* AF_XDP ZC drivers, via xdp_sock_buff.h */
 void xp_set_rxq_info(struct xsk_buff_pool *pool, struct xdp_rxq_info *rxq);
 int xp_dma_map(struct xsk_buff_pool *pool, struct device *dev,
                xp_unaligned_extract_offset(addr);
 }
 
+/* In aligned mode all chunks have a fixed size, so an address converts to
+ * the index of its head with a mask and a shift.
+ */
+static inline u32 xp_aligned_extract_idx(struct xsk_buff_pool *pool, u64 addr)
+{
+       return xp_aligned_extract_addr(pool, addr) >> pool->chunk_shift;
+}
+
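+/* Recycling through the free list is only needed in unaligned mode; in
+ * aligned mode every address maps to a fixed head, so release is a no-op.
+ */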
+static inline void xp_release(struct xdp_buff_xsk *xskb)
+{
+       if (xskb->pool->unaligned)
+               xskb->pool->free_heads[xskb->pool->free_heads_cnt++] = xskb;
+}
+
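+/* Build the handle posted to user space: the original address plus the
+ * current offset of the data, which unaligned mode encodes in the upper
+ * bits of the address.
+ */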
+static inline u64 xp_get_handle(struct xdp_buff_xsk *xskb)
+{
+       u64 offset = xskb->xdp.data - xskb->xdp.data_hard_start;
+
+       offset += xskb->pool->headroom;
+       if (!xskb->pool->unaligned)
+               return xskb->orig_addr + offset;
+       return xskb->orig_addr + (offset << XSK_UNALIGNED_BUF_OFFSET_SHIFT);
+}
+
 #endif /* XSK_BUFF_POOL_H_ */
 
        return 0;
 }
 
-void xp_release(struct xdp_buff_xsk *xskb)
-{
-       xskb->pool->free_heads[xskb->pool->free_heads_cnt++] = xskb;
-}
-
-static u64 xp_get_handle(struct xdp_buff_xsk *xskb)
-{
-       u64 offset = xskb->xdp.data - xskb->xdp.data_hard_start;
-
-       offset += xskb->pool->headroom;
-       if (!xskb->pool->unaligned)
-               return xskb->orig_addr + offset;
-       return xskb->orig_addr + (offset << XSK_UNALIGNED_BUF_OFFSET_SHIFT);
-}
-
 static int __xsk_rcv_zc(struct xdp_sock *xs, struct xdp_buff *xdp, u32 len)
 {
        struct xdp_buff_xsk *xskb = container_of(xdp, struct xdp_buff_xsk, xdp);
 
 struct xsk_buff_pool *xp_create_and_assign_umem(struct xdp_sock *xs,
                                                struct xdp_umem *umem)
 {
+       bool unaligned = umem->flags & XDP_UMEM_UNALIGNED_CHUNK_FLAG;
        struct xsk_buff_pool *pool;
        struct xdp_buff_xsk *xskb;
-       u32 i;
+       u32 i, entries;
 
-       pool = kvzalloc(struct_size(pool, free_heads, umem->chunks),
-                       GFP_KERNEL);
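+       /* The free list is only used for recycling buffers in unaligned
+        * mode, so an aligned pool can skip allocating its entries.
+        */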
+       entries = unaligned ? umem->chunks : 0;
+       pool = kvzalloc(struct_size(pool, free_heads, entries), GFP_KERNEL);
        if (!pool)
                goto out;
 
        pool->free_heads_cnt = umem->chunks;
        pool->headroom = umem->headroom;
        pool->chunk_size = umem->chunk_size;
-       pool->unaligned = umem->flags & XDP_UMEM_UNALIGNED_CHUNK_FLAG;
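+       /* chunk_shift is only used in aligned mode, where chunk_size is
+        * guaranteed to be a power of two.
+        */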
+       pool->chunk_shift = ffs(umem->chunk_size) - 1;
+       pool->unaligned = unaligned;
        pool->frame_len = umem->chunk_size - umem->headroom -
                XDP_PACKET_HEADROOM;
        pool->umem = umem;
                xskb = &pool->heads[i];
                xskb->pool = pool;
                xskb->xdp.frame_sz = umem->chunk_size - umem->headroom;
-               pool->free_heads[i] = xskb;
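+               /* Aligned chunks have fixed addresses, so they can be
+                * initialized once here instead of on every allocation.
+                */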
+               if (pool->unaligned)
+                       pool->free_heads[i] = xskb;
+               else
+                       xp_init_xskb_addr(xskb, pool, i * pool->chunk_size);
        }
 
        return pool;
 
        if (pool->unaligned)
                xp_check_dma_contiguity(dma_map);
+       else
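+               /* Aligned mode: every buffer's DMA address is fixed, so
+                * initialize them all once at mapping time.
+                */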
+               for (i = 0; i < pool->heads_cnt; i++) {
+                       struct xdp_buff_xsk *xskb = &pool->heads[i];
+
+                       xp_init_xskb_dma(xskb, pool, dma_map->dma_pages, xskb->orig_addr);
+               }
 
        err = xp_init_dma_info(pool, dma_map);
        if (err) {
        if (pool->free_heads_cnt == 0)
                return NULL;
 
-       xskb = pool->free_heads[--pool->free_heads_cnt];
-
        for (;;) {
                if (!xskq_cons_peek_addr_unchecked(pool->fq, &addr)) {
                        pool->fq->queue_empty_descs++;
                }
                break;
        }
-       xskq_cons_release(pool->fq);
 
-       xskb->orig_addr = addr;
-       xskb->xdp.data_hard_start = pool->addrs + addr + pool->headroom;
-       if (pool->dma_pages_cnt) {
-               xskb->frame_dma = (pool->dma_pages[addr >> PAGE_SHIFT] &
-                                  ~XSK_NEXT_PG_CONTIG_MASK) +
-                                 (addr & ~PAGE_MASK);
-               xskb->dma = xskb->frame_dma + pool->headroom +
-                           XDP_PACKET_HEADROOM;
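+       /* Unaligned buffers can start anywhere, so bind a recycled head to
+        * the address now; in aligned mode the address identifies its
+        * preinitialized head directly.
+        */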
+       if (pool->unaligned) {
+               xskb = pool->free_heads[--pool->free_heads_cnt];
+               xp_init_xskb_addr(xskb, pool, addr);
+               if (pool->dma_pages_cnt)
+                       xp_init_xskb_dma(xskb, pool, pool->dma_pages, addr);
+       } else {
+               xskb = &pool->heads[xp_aligned_extract_idx(pool, addr)];
        }
+
+       xskq_cons_release(pool->fq);
        return xskb;
 }
 
                        continue;
                }
 
-               xskb = pool->free_heads[--pool->free_heads_cnt];
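+               /* Same as in the single-buffer allocation path above: pop
+                * and initialize a head in unaligned mode, or index it
+                * directly from the address when aligned.
+                */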
+               if (pool->unaligned) {
+                       xskb = pool->free_heads[--pool->free_heads_cnt];
+                       xp_init_xskb_addr(xskb, pool, addr);
+                       if (pool->dma_pages_cnt)
+                               xp_init_xskb_dma(xskb, pool, pool->dma_pages, addr);
+               } else {
+                       xskb = &pool->heads[xp_aligned_extract_idx(pool, addr)];
+               }
+
                *xdp = &xskb->xdp;
-               xskb->orig_addr = addr;
-               xskb->xdp.data_hard_start = pool->addrs + addr + pool->headroom;
-               xskb->frame_dma = (pool->dma_pages[addr >> PAGE_SHIFT] &
-                                  ~XSK_NEXT_PG_CONTIG_MASK) + (addr & ~PAGE_MASK);
-               xskb->dma = xskb->frame_dma + pool->headroom + XDP_PACKET_HEADROOM;
                xdp++;
        }