#include <linux/seq_file.h>
 #include <linux/highmem.h>
 #include <linux/pci.h>
+#include <linux/ptdump.h>
 
 #include <asm/e820/types.h>
 #include <asm/pgtable.h>
  * when a "break" in the continuity is found.
  */
 struct pg_state {
+       struct ptdump_state ptdump;
        int level;
-       pgprot_t current_prot;
+       pgprotval_t current_prot;
        pgprotval_t effective_prot;
+       pgprotval_t prot_levels[5];
        unsigned long start_address;
-       unsigned long current_address;
        const struct addr_marker *marker;
        unsigned long lines;
        bool to_dmesg;
 /*
  * Print a readable form of a pgprot_t to the seq_file
  */
-static void printk_prot(struct seq_file *m, pgprot_t prot, int level, bool dmsg)
+static void printk_prot(struct seq_file *m, pgprotval_t pr, int level, bool dmsg)
 {
-       pgprotval_t pr = pgprot_val(prot);
        static const char * const level_name[] =
                { "cr3", "pgd", "p4d", "pud", "pmd", "pte" };
 
        pt_dump_cont_printf(m, dmsg, "%s\n", level_name[level]);
 }
 
-/*
- * On 64 bits, sign-extend the 48 bit address to 64 bit
- */
-static unsigned long normalize_addr(unsigned long u)
-{
-       int shift;
-       if (!IS_ENABLED(CONFIG_X86_64))
-               return u;
-
-       shift = 64 - (__VIRTUAL_MASK_SHIFT + 1);
-       return (signed long)(u << shift) >> shift;
-}
-
-static void note_wx(struct pg_state *st)
+static void note_wx(struct pg_state *st, unsigned long addr)
 {
        unsigned long npages;
 
-       npages = (st->current_address - st->start_address) / PAGE_SIZE;
+       npages = (addr - st->start_address) / PAGE_SIZE;
 
 #ifdef CONFIG_PCI_BIOS
        /*
         * Inform about it, but avoid the warning.
         */
        if (pcibios_enabled && st->start_address >= PAGE_OFFSET + BIOS_BEGIN &&
-           st->current_address <= PAGE_OFFSET + BIOS_END) {
+           addr <= PAGE_OFFSET + BIOS_END) {
                pr_warn_once("x86/mm: PCI BIOS W+X mapping %lu pages\n", npages);
                return;
        }
                  (void *)st->start_address);
 }
 
+static inline pgprotval_t effective_prot(pgprotval_t prot1, pgprotval_t prot2)
+{
+       return (prot1 & prot2 & (_PAGE_USER | _PAGE_RW)) |
+              ((prot1 | prot2) & _PAGE_NX);
+}
+
 /*
  * This function gets called on a break in a continuous series
  * of PTE entries; the next one is different so we need to
  * print what we collected so far.
  */
-static void note_page(struct pg_state *st, pgprot_t new_prot,
-                     pgprotval_t new_eff, int level)
+static void note_page(struct ptdump_state *pt_st, unsigned long addr, int level,
+                     unsigned long val)
 {
-       pgprotval_t prot, cur, eff;
+       struct pg_state *st = container_of(pt_st, struct pg_state, ptdump);
+       pgprotval_t new_prot, new_eff;
+       pgprotval_t cur, eff;
        static const char units[] = "BKMGTPE";
        struct seq_file *m = st->seq;
 
+       new_prot = val & PTE_FLAGS_MASK;
+
+       if (level > 1) {
+               new_eff = effective_prot(st->prot_levels[level - 2],
+                                        new_prot);
+       } else {
+               new_eff = new_prot;
+       }
+
+       if (level > 0)
+               st->prot_levels[level - 1] = new_eff;
+
        /*
         * If we have a "break" in the series, we need to flush the state that
         * we have now. "break" is either changing perms, levels or
         * address space marker.
         */
-       prot = pgprot_val(new_prot);
-       cur = pgprot_val(st->current_prot);
+       cur = st->current_prot;
        eff = st->effective_prot;
 
        if (!st->level) {
                st->lines = 0;
                pt_dump_seq_printf(m, st->to_dmesg, "---[ %s ]---\n",
                                   st->marker->name);
-       } else if (prot != cur || new_eff != eff || level != st->level ||
-                  st->current_address >= st->marker[1].start_address) {
+       } else if (new_prot != cur || new_eff != eff || level != st->level ||
+                  addr >= st->marker[1].start_address) {
                const char *unit = units;
                unsigned long delta;
                int width = sizeof(unsigned long) * 2;
 
                if (st->check_wx && (eff & _PAGE_RW) && !(eff & _PAGE_NX))
-                       note_wx(st);
+                       note_wx(st, addr);
 
                /*
                 * Now print the actual finished series
                        pt_dump_seq_printf(m, st->to_dmesg,
                                           "0x%0*lx-0x%0*lx   ",
                                           width, st->start_address,
-                                          width, st->current_address);
+                                          width, addr);
 
-                       delta = st->current_address - st->start_address;
+                       delta = addr - st->start_address;
                        while (!(delta & 1023) && unit[1]) {
                                delta >>= 10;
                                unit++;
                 * such as the start of vmalloc space etc.
                 * This helps in the interpretation.
                 */
-               if (st->current_address >= st->marker[1].start_address) {
+               if (addr >= st->marker[1].start_address) {
                        if (st->marker->max_lines &&
                            st->lines > st->marker->max_lines) {
                                unsigned long nskip =
                                           st->marker->name);
                }
 
-               st->start_address = st->current_address;
+               st->start_address = addr;
                st->current_prot = new_prot;
                st->effective_prot = new_eff;
                st->level = level;
        }
 }
 
-static inline pgprotval_t effective_prot(pgprotval_t prot1, pgprotval_t prot2)
-{
-       return (prot1 & prot2 & (_PAGE_USER | _PAGE_RW)) |
-              ((prot1 | prot2) & _PAGE_NX);
-}
-
-static void walk_pte_level(struct pg_state *st, pmd_t addr, pgprotval_t eff_in,
-                          unsigned long P)
-{
-       int i;
-       pte_t *pte;
-       pgprotval_t prot, eff;
-
-       for (i = 0; i < PTRS_PER_PTE; i++) {
-               st->current_address = normalize_addr(P + i * PTE_LEVEL_MULT);
-               pte = pte_offset_map(&addr, st->current_address);
-               prot = pte_flags(*pte);
-               eff = effective_prot(eff_in, prot);
-               note_page(st, __pgprot(prot), eff, 5);
-               pte_unmap(pte);
-       }
-}
-#ifdef CONFIG_KASAN
-
-/*
- * This is an optimization for KASAN=y case. Since all kasan page tables
- * eventually point to the kasan_early_shadow_page we could call note_page()
- * right away without walking through lower level page tables. This saves
- * us dozens of seconds (minutes for 5-level config) while checking for
- * W+X mapping or reading kernel_page_tables debugfs file.
- */
-static inline bool kasan_page_table(struct pg_state *st, void *pt)
-{
-       if (__pa(pt) == __pa(kasan_early_shadow_pmd) ||
-           (pgtable_l5_enabled() &&
-                       __pa(pt) == __pa(kasan_early_shadow_p4d)) ||
-           __pa(pt) == __pa(kasan_early_shadow_pud)) {
-               pgprotval_t prot = pte_flags(kasan_early_shadow_pte[0]);
-               note_page(st, __pgprot(prot), 0, 5);
-               return true;
-       }
-       return false;
-}
-#else
-static inline bool kasan_page_table(struct pg_state *st, void *pt)
-{
-       return false;
-}
-#endif
-
-#if PTRS_PER_PMD > 1
-
-static void walk_pmd_level(struct pg_state *st, pud_t addr,
-                          pgprotval_t eff_in, unsigned long P)
-{
-       int i;
-       pmd_t *start, *pmd_start;
-       pgprotval_t prot, eff;
-
-       pmd_start = start = (pmd_t *)pud_page_vaddr(addr);
-       for (i = 0; i < PTRS_PER_PMD; i++) {
-               st->current_address = normalize_addr(P + i * PMD_LEVEL_MULT);
-               if (!pmd_none(*start)) {
-                       prot = pmd_flags(*start);
-                       eff = effective_prot(eff_in, prot);
-                       if (pmd_large(*start) || !pmd_present(*start)) {
-                               note_page(st, __pgprot(prot), eff, 4);
-                       } else if (!kasan_page_table(st, pmd_start)) {
-                               walk_pte_level(st, *start, eff,
-                                              P + i * PMD_LEVEL_MULT);
-                       }
-               } else
-                       note_page(st, __pgprot(0), 0, 4);
-               start++;
-       }
-}
-
-#else
-#define walk_pmd_level(s,a,e,p) walk_pte_level(s,__pmd(pud_val(a)),e,p)
-#define pud_large(a) pmd_large(__pmd(pud_val(a)))
-#define pud_none(a)  pmd_none(__pmd(pud_val(a)))
-#endif
-
-#if PTRS_PER_PUD > 1
-
-static void walk_pud_level(struct pg_state *st, p4d_t addr, pgprotval_t eff_in,
-                          unsigned long P)
+static void ptdump_walk_pgd_level_core(struct seq_file *m, pgd_t *pgd,
+                                      bool checkwx, bool dmesg)
 {
-       int i;
-       pud_t *start, *pud_start;
-       pgprotval_t prot, eff;
-
-       pud_start = start = (pud_t *)p4d_page_vaddr(addr);
-
-       for (i = 0; i < PTRS_PER_PUD; i++) {
-               st->current_address = normalize_addr(P + i * PUD_LEVEL_MULT);
-               if (!pud_none(*start)) {
-                       prot = pud_flags(*start);
-                       eff = effective_prot(eff_in, prot);
-                       if (pud_large(*start) || !pud_present(*start)) {
-                               note_page(st, __pgprot(prot), eff, 3);
-                       } else if (!kasan_page_table(st, pud_start)) {
-                               walk_pmd_level(st, *start, eff,
-                                              P + i * PUD_LEVEL_MULT);
-                       }
-               } else
-                       note_page(st, __pgprot(0), 0, 3);
+       const struct ptdump_range ptdump_ranges[] = {
+#ifdef CONFIG_X86_64
 
-               start++;
-       }
-}
+#define normalize_addr_shift (64 - (__VIRTUAL_MASK_SHIFT + 1))
+#define normalize_addr(u) ((signed long)((u) << normalize_addr_shift) >> \
+                          normalize_addr_shift)
 
+       {0, PTRS_PER_PGD * PGD_LEVEL_MULT / 2},
+       {normalize_addr(PTRS_PER_PGD * PGD_LEVEL_MULT / 2), ~0UL},
 #else
-#define walk_pud_level(s,a,e,p) walk_pmd_level(s,__pud(p4d_val(a)),e,p)
-#define p4d_large(a) pud_large(__pud(p4d_val(a)))
-#define p4d_none(a)  pud_none(__pud(p4d_val(a)))
+       {0, ~0UL},
 #endif
+       {0, 0}
+};
 
-static void walk_p4d_level(struct pg_state *st, pgd_t addr, pgprotval_t eff_in,
-                          unsigned long P)
-{
-       int i;
-       p4d_t *start, *p4d_start;
-       pgprotval_t prot, eff;
-
-       if (PTRS_PER_P4D == 1)
-               return walk_pud_level(st, __p4d(pgd_val(addr)), eff_in, P);
-
-       p4d_start = start = (p4d_t *)pgd_page_vaddr(addr);
-
-       for (i = 0; i < PTRS_PER_P4D; i++) {
-               st->current_address = normalize_addr(P + i * P4D_LEVEL_MULT);
-               if (!p4d_none(*start)) {
-                       prot = p4d_flags(*start);
-                       eff = effective_prot(eff_in, prot);
-                       if (p4d_large(*start) || !p4d_present(*start)) {
-                               note_page(st, __pgprot(prot), eff, 2);
-                       } else if (!kasan_page_table(st, p4d_start)) {
-                               walk_pud_level(st, *start, eff,
-                                              P + i * P4D_LEVEL_MULT);
-                       }
-               } else
-                       note_page(st, __pgprot(0), 0, 2);
-
-               start++;
-       }
-}
+       struct pg_state st = {
+               .ptdump = {
+                       .note_page      = note_page,
+                       .range          = ptdump_ranges
+               },
+               .to_dmesg       = dmesg,
+               .check_wx       = checkwx,
+               .seq            = m
+       };
 
-#define pgd_large(a) (pgtable_l5_enabled() ? pgd_large(a) : p4d_large(__p4d(pgd_val(a))))
-#define pgd_none(a)  (pgtable_l5_enabled() ? pgd_none(a) : p4d_none(__p4d(pgd_val(a))))
+       struct mm_struct fake_mm = {
+               .pgd = pgd
+       };
+       init_rwsem(&fake_mm.mmap_sem);
 
-static inline bool is_hypervisor_range(int idx)
-{
-#ifdef CONFIG_X86_64
-       /*
-        * A hole in the beginning of kernel address space reserved
-        * for a hypervisor.
-        */
-       return  (idx >= pgd_index(GUARD_HOLE_BASE_ADDR)) &&
-               (idx <  pgd_index(GUARD_HOLE_END_ADDR));
-#else
-       return false;
-#endif
-}
-
-static void ptdump_walk_pgd_level_core(struct seq_file *m, pgd_t *pgd,
-                                      bool checkwx, bool dmesg)
-{
-       pgd_t *start = pgd;
-       pgprotval_t prot, eff;
-       int i;
-       struct pg_state st = {};
-
-       st.to_dmesg = dmesg;
-       st.check_wx = checkwx;
-       st.seq = m;
-       if (checkwx)
-               st.wx_pages = 0;
-
-       for (i = 0; i < PTRS_PER_PGD; i++) {
-               st.current_address = normalize_addr(i * PGD_LEVEL_MULT);
-               if (!pgd_none(*start) && !is_hypervisor_range(i)) {
-                       prot = pgd_flags(*start);
-#ifdef CONFIG_X86_PAE
-                       eff = _PAGE_USER | _PAGE_RW;
-#else
-                       eff = prot;
-#endif
-                       if (pgd_large(*start) || !pgd_present(*start)) {
-                               note_page(&st, __pgprot(prot), eff, 1);
-                       } else {
-                               walk_p4d_level(&st, *start, eff,
-                                              i * PGD_LEVEL_MULT);
-                       }
-               } else
-                       note_page(&st, __pgprot(0), 0, 1);
-
-               cond_resched();
-               start++;
-       }
+       ptdump_walk_pgd(&st.ptdump, &fake_mm);
 
-       /* Flush out the last page */
-       st.current_address = normalize_addr(PTRS_PER_PGD*PGD_LEVEL_MULT);
-       note_page(&st, __pgprot(0), 0, 0);
        if (!checkwx)
                return;
        if (st.wx_pages)