#define NO_CONT_MAPPINGS       BIT(1)
 #define NO_EXEC_MAPPINGS       BIT(2)  /* assumes FEAT_HPDS is not used */
 
+enum pgtable_type {
+       TABLE_PTE,
+       TABLE_PMD,
+       TABLE_PUD,
+       TABLE_P4D,
+};
+
 u64 kimage_voffset __ro_after_init;
 EXPORT_SYMBOL(kimage_voffset);
 
 }
 EXPORT_SYMBOL(phys_mem_access_prot);
 
-static phys_addr_t __init early_pgtable_alloc(int shift)
+static phys_addr_t __init early_pgtable_alloc(enum pgtable_type pgtable_type)
 {
        phys_addr_t phys;
 
 static void alloc_init_cont_pte(pmd_t *pmdp, unsigned long addr,
                                unsigned long end, phys_addr_t phys,
                                pgprot_t prot,
-                               phys_addr_t (*pgtable_alloc)(int),
+                               phys_addr_t (*pgtable_alloc)(enum pgtable_type),
                                int flags)
 {
        unsigned long next;
                if (flags & NO_EXEC_MAPPINGS)
                        pmdval |= PMD_TABLE_PXN;
                BUG_ON(!pgtable_alloc);
-               pte_phys = pgtable_alloc(PAGE_SHIFT);
+               pte_phys = pgtable_alloc(TABLE_PTE);
                ptep = pte_set_fixmap(pte_phys);
                init_clear_pgtable(ptep);
                ptep += pte_index(addr);
 
 static void init_pmd(pmd_t *pmdp, unsigned long addr, unsigned long end,
                     phys_addr_t phys, pgprot_t prot,
-                    phys_addr_t (*pgtable_alloc)(int), int flags)
+                    phys_addr_t (*pgtable_alloc)(enum pgtable_type), int flags)
 {
        unsigned long next;
 
 static void alloc_init_cont_pmd(pud_t *pudp, unsigned long addr,
                                unsigned long end, phys_addr_t phys,
                                pgprot_t prot,
-                               phys_addr_t (*pgtable_alloc)(int), int flags)
+                               phys_addr_t (*pgtable_alloc)(enum pgtable_type),
+                               int flags)
 {
        unsigned long next;
        pud_t pud = READ_ONCE(*pudp);
                if (flags & NO_EXEC_MAPPINGS)
                        pudval |= PUD_TABLE_PXN;
                BUG_ON(!pgtable_alloc);
-               pmd_phys = pgtable_alloc(PMD_SHIFT);
+               pmd_phys = pgtable_alloc(TABLE_PMD);
                pmdp = pmd_set_fixmap(pmd_phys);
                init_clear_pgtable(pmdp);
                pmdp += pmd_index(addr);
 
 static void alloc_init_pud(p4d_t *p4dp, unsigned long addr, unsigned long end,
                           phys_addr_t phys, pgprot_t prot,
-                          phys_addr_t (*pgtable_alloc)(int),
+                          phys_addr_t (*pgtable_alloc)(enum pgtable_type),
                           int flags)
 {
        unsigned long next;
                if (flags & NO_EXEC_MAPPINGS)
                        p4dval |= P4D_TABLE_PXN;
                BUG_ON(!pgtable_alloc);
-               pud_phys = pgtable_alloc(PUD_SHIFT);
+               pud_phys = pgtable_alloc(TABLE_PUD);
                pudp = pud_set_fixmap(pud_phys);
                init_clear_pgtable(pudp);
                pudp += pud_index(addr);
 
 static void alloc_init_p4d(pgd_t *pgdp, unsigned long addr, unsigned long end,
                           phys_addr_t phys, pgprot_t prot,
-                          phys_addr_t (*pgtable_alloc)(int),
+                          phys_addr_t (*pgtable_alloc)(enum pgtable_type),
                           int flags)
 {
        unsigned long next;
                if (flags & NO_EXEC_MAPPINGS)
                        pgdval |= PGD_TABLE_PXN;
                BUG_ON(!pgtable_alloc);
-               p4d_phys = pgtable_alloc(P4D_SHIFT);
+               p4d_phys = pgtable_alloc(TABLE_P4D);
                p4dp = p4d_set_fixmap(p4d_phys);
                init_clear_pgtable(p4dp);
                p4dp += p4d_index(addr);
 static void __create_pgd_mapping_locked(pgd_t *pgdir, phys_addr_t phys,
                                        unsigned long virt, phys_addr_t size,
                                        pgprot_t prot,
-                                       phys_addr_t (*pgtable_alloc)(int),
+                                       phys_addr_t (*pgtable_alloc)(enum pgtable_type),
                                        int flags)
 {
        unsigned long addr, end, next;
 static void __create_pgd_mapping(pgd_t *pgdir, phys_addr_t phys,
                                 unsigned long virt, phys_addr_t size,
                                 pgprot_t prot,
-                                phys_addr_t (*pgtable_alloc)(int),
+                                phys_addr_t (*pgtable_alloc)(enum pgtable_type),
                                 int flags)
 {
        mutex_lock(&fixmap_lock);
 extern __alias(__create_pgd_mapping_locked)
 void create_kpti_ng_temp_pgd(pgd_t *pgdir, phys_addr_t phys, unsigned long virt,
                             phys_addr_t size, pgprot_t prot,
-                            phys_addr_t (*pgtable_alloc)(int), int flags);
+                            phys_addr_t (*pgtable_alloc)(enum pgtable_type),
+                            int flags);
 #endif
 
-static phys_addr_t __pgd_pgtable_alloc(int shift)
+static phys_addr_t __pgd_pgtable_alloc(enum pgtable_type pgtable_type)
 {
        /* Page is zeroed by init_clear_pgtable() so don't duplicate effort. */
        void *ptr = (void *)__get_free_page(GFP_PGTABLE_KERNEL & ~__GFP_ZERO);
        return __pa(ptr);
 }
 
-static phys_addr_t pgd_pgtable_alloc(int shift)
+static phys_addr_t pgd_pgtable_alloc(enum pgtable_type pgtable_type)
 {
-       phys_addr_t pa = __pgd_pgtable_alloc(shift);
+       phys_addr_t pa = __pgd_pgtable_alloc(pgtable_type);
        struct ptdesc *ptdesc = page_ptdesc(phys_to_page(pa));
 
        /*
         * Call proper page table ctor in case later we need to
         * call core mm functions like apply_to_page_range() on
         * this pre-allocated page table.
-        *
-        * We don't select ARCH_ENABLE_SPLIT_PMD_PTLOCK if pmd is
-        * folded, and if so pagetable_pte_ctor() becomes nop.
         */
-       if (shift == PAGE_SHIFT)
+       switch (pgtable_type) {
+       case TABLE_PTE:
                BUG_ON(!pagetable_pte_ctor(NULL, ptdesc));
-       else if (shift == PMD_SHIFT)
+               break;
+       case TABLE_PMD:
                BUG_ON(!pagetable_pmd_ctor(NULL, ptdesc));
+               break;
+       default:
+               break;
+       }
 
        return pa;
 }