size_t size)
 {
        if (addr) {
-               unsigned long alloc_end = addr + (PAGE_SIZE << order);
-               unsigned long used = addr + PAGE_ALIGN(size);
-
-               split_page(virt_to_page((void *)addr), order);
-               while (used < alloc_end) {
-                       free_page(used);
-                       used += PAGE_SIZE;
-               }
+               unsigned long nr = DIV_ROUND_UP(size, PAGE_SIZE);
+               struct page *page = virt_to_page((void *)addr);
+               struct page *last = page + nr;
+
+               split_page_owner(page, 1 << order);
+               split_page_memcg(page, 1 << order);
+               while (page < --last)
+                       set_page_refcounted(last);
+
+               last = page + (1UL << order);
+               for (page += nr; page < last; page++)
+                       __free_pages_ok(page, 0, FPI_TO_TAIL);
        }
        return (void *)addr;
 }