int flags);
 #endif
 
-static phys_addr_t __pgd_pgtable_alloc(enum pgtable_type pgtable_type)
+static phys_addr_t __pgd_pgtable_alloc(struct mm_struct *mm,
+                                      enum pgtable_type pgtable_type)
 {
        /* Page is zeroed by init_clear_pgtable() so don't duplicate effort. */
-       void *ptr = (void *)__get_free_page(GFP_PGTABLE_KERNEL & ~__GFP_ZERO);
+       struct ptdesc *ptdesc = pagetable_alloc(GFP_PGTABLE_KERNEL & ~__GFP_ZERO, 0);
+       phys_addr_t pa;
 
-       BUG_ON(!ptr);
-       return __pa(ptr);
-}
-
-static phys_addr_t pgd_pgtable_alloc(enum pgtable_type pgtable_type)
-{
-       phys_addr_t pa = __pgd_pgtable_alloc(pgtable_type);
-       struct ptdesc *ptdesc = page_ptdesc(phys_to_page(pa));
+       BUG_ON(!ptdesc);
+       pa = page_to_phys(ptdesc_page(ptdesc));
 
-       /*
-        * Call proper page table ctor in case later we need to
-        * call core mm functions like apply_to_page_range() on
-        * this pre-allocated page table.
-        */
        switch (pgtable_type) {
        case TABLE_PTE:
-               BUG_ON(!pagetable_pte_ctor(NULL, ptdesc));
+               BUG_ON(!pagetable_pte_ctor(mm, ptdesc));
                break;
        case TABLE_PMD:
-               BUG_ON(!pagetable_pmd_ctor(NULL, ptdesc));
+               BUG_ON(!pagetable_pmd_ctor(mm, ptdesc));
                break;
        default:
                break;
        return pa;
 }
 
+static phys_addr_t __maybe_unused
+pgd_pgtable_alloc_init_mm(enum pgtable_type pgtable_type)
+{
+       return __pgd_pgtable_alloc(&init_mm, pgtable_type);
+}
+
+static phys_addr_t
+pgd_pgtable_alloc_special_mm(enum pgtable_type pgtable_type)
+{
+       return __pgd_pgtable_alloc(NULL, pgtable_type);
+}
+
 /*
  * This function can only be used to modify existing table entries,
  * without allocating new levels of table. Note that this permits the
                flags = NO_BLOCK_MAPPINGS | NO_CONT_MAPPINGS;
 
        __create_pgd_mapping(mm->pgd, phys, virt, size, prot,
-                            pgd_pgtable_alloc, flags);
+                            pgd_pgtable_alloc_special_mm, flags);
 }
 
 static void update_mapping_prot(phys_addr_t phys, unsigned long virt,
        memset(tramp_pg_dir, 0, PGD_SIZE);
        __create_pgd_mapping(tramp_pg_dir, pa_start, TRAMP_VALIAS,
                             entry_tramp_text_size(), prot,
-                            __pgd_pgtable_alloc, NO_BLOCK_MAPPINGS);
+                            pgd_pgtable_alloc_init_mm, NO_BLOCK_MAPPINGS);
 
        /* Map both the text and data into the kernel page table */
        for (i = 0; i < DIV_ROUND_UP(entry_tramp_text_size(), PAGE_SIZE); i++)
                flags |= NO_BLOCK_MAPPINGS | NO_CONT_MAPPINGS;
 
        __create_pgd_mapping(swapper_pg_dir, start, __phys_to_virt(start),
-                            size, params->pgprot, __pgd_pgtable_alloc,
+                            size, params->pgprot, pgd_pgtable_alloc_init_mm,
                             flags);
 
        memblock_clear_nomap(start, size);