#include "zcrx.h"
 #include "rsrc.h"
 
+#define IO_DMA_ATTR (DMA_ATTR_SKIP_CPU_SYNC | DMA_ATTR_WEAK_ORDERING)
+
 static inline struct io_zcrx_ifq *io_pp_to_ifq(struct page_pool *pp)
 {
        return pp->mp_priv;
 {
        struct io_zcrx_area *area = io_zcrx_iov_to_area(niov);
 
-       return area->pages[net_iov_idx(niov)];
+       return area->mem.pages[net_iov_idx(niov)];
 }
 
-#define IO_DMA_ATTR (DMA_ATTR_SKIP_CPU_SYNC | DMA_ATTR_WEAK_ORDERING)
+static void io_release_area_mem(struct io_zcrx_mem *mem)
+{
+       if (mem->pages) {
+               unpin_user_pages(mem->pages, mem->nr_folios);
+               kvfree(mem->pages);
+       }
+}
+
+static int io_import_area(struct io_zcrx_ifq *ifq,
+                         struct io_zcrx_mem *mem,
+                         struct io_uring_zcrx_area_reg *area_reg)
+{
+       struct page **pages;
+       int nr_pages;
+       int ret;
+
+       ret = io_validate_user_buf_range(area_reg->addr, area_reg->len);
+       if (ret)
+               return ret;
+       if (!area_reg->addr)
+               return -EFAULT;
+       if (area_reg->addr & ~PAGE_MASK || area_reg->len & ~PAGE_MASK)
+               return -EINVAL;
+
+       pages = io_pin_pages((unsigned long)area_reg->addr, area_reg->len,
+                                  &nr_pages);
+       if (IS_ERR(pages))
+               return PTR_ERR(pages);
+
+       mem->pages = pages;
+       mem->nr_folios = nr_pages;
+       mem->size = area_reg->len;
+       return 0;
+}
 
 static void __io_zcrx_unmap_area(struct io_zcrx_ifq *ifq,
                                 struct io_zcrx_area *area, int nr_mapped)
                struct net_iov *niov = &area->nia.niovs[i];
                dma_addr_t dma;
 
-               dma = dma_map_page_attrs(ifq->dev, area->pages[i], 0, PAGE_SIZE,
-                                        DMA_FROM_DEVICE, IO_DMA_ATTR);
+               dma = dma_map_page_attrs(ifq->dev, area->mem.pages[i], 0,
+                                        PAGE_SIZE, DMA_FROM_DEVICE, IO_DMA_ATTR);
                if (dma_mapping_error(ifq->dev, dma))
                        break;
                if (net_mp_niov_set_dma_addr(niov, dma)) {
 static void io_zcrx_free_area(struct io_zcrx_area *area)
 {
        io_zcrx_unmap_area(area->ifq, area);
+       io_release_area_mem(&area->mem);
 
        kvfree(area->freelist);
        kvfree(area->nia.niovs);
        kvfree(area->user_refs);
-       if (area->pages) {
-               unpin_user_pages(area->pages, area->nr_folios);
-               kvfree(area->pages);
-       }
        kfree(area);
 }
 
                               struct io_uring_zcrx_area_reg *area_reg)
 {
        struct io_zcrx_area *area;
-       int i, ret, nr_pages, nr_iovs;
+       unsigned nr_iovs;
+       int i, ret;
 
        if (area_reg->flags || area_reg->rq_area_token)
                return -EINVAL;
        if (area_reg->__resv1 || area_reg->__resv2[0] || area_reg->__resv2[1])
                return -EINVAL;
-       if (area_reg->addr & ~PAGE_MASK || area_reg->len & ~PAGE_MASK)
-               return -EINVAL;
-
-       ret = io_validate_user_buf_range(area_reg->addr, area_reg->len);
-       if (ret)
-               return ret;
-       if (!area_reg->addr)
-               return -EFAULT;
 
        ret = -ENOMEM;
        area = kzalloc(sizeof(*area), GFP_KERNEL);
        if (!area)
                goto err;
 
-       area->pages = io_pin_pages((unsigned long)area_reg->addr, area_reg->len,
-                                  &nr_pages);
-       if (IS_ERR(area->pages)) {
-               ret = PTR_ERR(area->pages);
-               area->pages = NULL;
+       ret = io_import_area(ifq, &area->mem, area_reg);
+       if (ret)
                goto err;
-       }
-       area->nr_folios = nr_iovs = nr_pages;
+
+       nr_iovs = area->mem.size >> PAGE_SHIFT;
        area->nia.num_niovs = nr_iovs;
 
+       ret = -ENOMEM;
        area->nia.niovs = kvmalloc_array(nr_iovs, sizeof(area->nia.niovs[0]),
                                         GFP_KERNEL | __GFP_ZERO);
        if (!area->nia.niovs)