}
 }
 
-static struct page *free_pt_page(unsigned long pt, struct page *freelist)
+static struct page *free_pt_page(u64 *pt, struct page *freelist)
 {
-       struct page *p = virt_to_page((void *)pt);
+       struct page *p = virt_to_page(pt);
 
        p->freelist = freelist;
 
        return p;
 }
 
-#define DEFINE_FREE_PT_FN(LVL, FN)                                             \
-static struct page *free_pt_##LVL (unsigned long __pt, struct page *freelist)  \
-{                                                                              \
-       unsigned long p;                                                        \
-       u64 *pt;                                                                \
-       int i;                                                                  \
-                                                                               \
-       pt = (u64 *)__pt;                                                       \
-                                                                               \
-       for (i = 0; i < 512; ++i) {                                             \
-               /* PTE present? */                                              \
-               if (!IOMMU_PTE_PRESENT(pt[i]))                                  \
-                       continue;                                               \
-                                                                               \
-               /* Large PTE? */                                                \
-               if (PM_PTE_LEVEL(pt[i]) == 0 ||                                 \
-                   PM_PTE_LEVEL(pt[i]) == 7)                                   \
-                       continue;                                               \
-                                                                               \
-               p = (unsigned long)IOMMU_PTE_PAGE(pt[i]);                       \
-               freelist = FN(p, freelist);                                     \
-       }                                                                       \
-                                                                               \
-       return free_pt_page((unsigned long)pt, freelist);                       \
-}
+static struct page *free_pt_lvl(u64 *pt, struct page *freelist, int lvl)
+{
+       u64 *p;
+       int i;
+
+       for (i = 0; i < 512; ++i) {
+               /* PTE present? */
+               if (!IOMMU_PTE_PRESENT(pt[i]))
+                       continue;
 
-DEFINE_FREE_PT_FN(l2, free_pt_page)
-DEFINE_FREE_PT_FN(l3, free_pt_l2)
-DEFINE_FREE_PT_FN(l4, free_pt_l3)
-DEFINE_FREE_PT_FN(l5, free_pt_l4)
-DEFINE_FREE_PT_FN(l6, free_pt_l5)
+               /* Large PTE? */
+               if (PM_PTE_LEVEL(pt[i]) == 0 ||
+                   PM_PTE_LEVEL(pt[i]) == 7)
+                       continue;
 
-static struct page *free_sub_pt(unsigned long root, int mode,
-                               struct page *freelist)
+               /*
+                * Free the next level. No need to look at l1 tables here since
+                * they can only contain leaf PTEs; just free them directly.
+                */
+               p = IOMMU_PTE_PAGE(pt[i]);
+               if (lvl > 2)
+                       freelist = free_pt_lvl(p, freelist, lvl - 1);
+               else
+                       freelist = free_pt_page(p, freelist);
+       }
+
+       return free_pt_page(pt, freelist);
+}
+
+static struct page *free_sub_pt(u64 *root, int mode, struct page *freelist)
 {
        switch (mode) {
        case PAGE_MODE_NONE:
                freelist = free_pt_page(root, freelist);
                break;
        case PAGE_MODE_2_LEVEL:
-               freelist = free_pt_l2(root, freelist);
-               break;
        case PAGE_MODE_3_LEVEL:
-               freelist = free_pt_l3(root, freelist);
-               break;
        case PAGE_MODE_4_LEVEL:
-               freelist = free_pt_l4(root, freelist);
-               break;
        case PAGE_MODE_5_LEVEL:
-               freelist = free_pt_l5(root, freelist);
-               break;
        case PAGE_MODE_6_LEVEL:
-               freelist = free_pt_l6(root, freelist);
+               free_pt_lvl(root, freelist, mode);
                break;
        default:
                BUG();
 
 static struct page *free_clear_pte(u64 *pte, u64 pteval, struct page *freelist)
 {
-       unsigned long pt;
+       u64 *pt;
        int mode;
 
        while (cmpxchg64(pte, pteval, 0) != pteval) {
        if (!IOMMU_PTE_PRESENT(pteval))
                return freelist;
 
-       pt   = (unsigned long)IOMMU_PTE_PAGE(pteval);
+       pt   = IOMMU_PTE_PAGE(pteval);
        mode = IOMMU_PTE_MODE(pteval);
 
        return free_sub_pt(pt, mode, freelist);
        struct amd_io_pgtable *pgtable = container_of(iop, struct amd_io_pgtable, iop);
        struct protection_domain *dom;
        struct page *freelist = NULL;
-       unsigned long root;
 
        if (pgtable->mode == PAGE_MODE_NONE)
                return;
        BUG_ON(pgtable->mode < PAGE_MODE_NONE ||
               pgtable->mode > PAGE_MODE_6_LEVEL);
 
-       root = (unsigned long)pgtable->root;
-       freelist = free_sub_pt(root, pgtable->mode, freelist);
+       freelist = free_sub_pt(pgtable->root, pgtable->mode, freelist);
 
        free_page_list(freelist);
 }