#include <asm/barrier.h>
 #include <asm/cputype.h>
 #include <asm/fixmap.h>
+#include <asm/kasan.h>
 #include <asm/kernel-pgtable.h>
 #include <asm/sections.h>
 #include <asm/setup.h>
                             late_pgtable_alloc);
 }
 
-#ifdef CONFIG_DEBUG_RODATA
-static void __init __map_memblock(phys_addr_t start, phys_addr_t end)
+static void __init __map_memblock(pgd_t *pgd, phys_addr_t start, phys_addr_t end)
 {
+
+       unsigned long kernel_start = __pa(_stext);
+       unsigned long kernel_end = __pa(_end);
+
        /*
-        * Set up the executable regions using the existing section mappings
-        * for now. This will get more fine grained later once all memory
-        * is mapped
+        * The kernel itself is mapped at page granularity. Map all other
+        * memory, making sure we don't overwrite the existing kernel mappings.
         */
-       unsigned long kernel_x_start = round_down(__pa(_stext), SWAPPER_BLOCK_SIZE);
-       unsigned long kernel_x_end = round_up(__pa(__init_end), SWAPPER_BLOCK_SIZE);
-
-       if (end < kernel_x_start) {
-               create_mapping(start, __phys_to_virt(start),
-                       end - start, PAGE_KERNEL);
-       } else if (start >= kernel_x_end) {
-               create_mapping(start, __phys_to_virt(start),
-                       end - start, PAGE_KERNEL);
-       } else {
-               if (start < kernel_x_start)
-                       create_mapping(start, __phys_to_virt(start),
-                               kernel_x_start - start,
-                               PAGE_KERNEL);
-               create_mapping(kernel_x_start,
-                               __phys_to_virt(kernel_x_start),
-                               kernel_x_end - kernel_x_start,
-                               PAGE_KERNEL_EXEC);
-               if (kernel_x_end < end)
-                       create_mapping(kernel_x_end,
-                               __phys_to_virt(kernel_x_end),
-                               end - kernel_x_end,
-                               PAGE_KERNEL);
+
+       /* No overlap with the kernel. */
+       if (end < kernel_start || start >= kernel_end) {
+               __create_pgd_mapping(pgd, start, __phys_to_virt(start),
+                                    end - start, PAGE_KERNEL,
+                                    early_pgtable_alloc);
+               return;
        }
 
+       /*
+        * This block overlaps the kernel mapping. Map the portion(s) which
+        * don't overlap.
+        */
+       if (start < kernel_start)
+               __create_pgd_mapping(pgd, start,
+                                    __phys_to_virt(start),
+                                    kernel_start - start, PAGE_KERNEL,
+                                    early_pgtable_alloc);
+       if (kernel_end < end)
+               __create_pgd_mapping(pgd, kernel_end,
+                                    __phys_to_virt(kernel_end),
+                                    end - kernel_end, PAGE_KERNEL,
+                                    early_pgtable_alloc);
 }
-#else
-static void __init __map_memblock(phys_addr_t start, phys_addr_t end)
-{
-       create_mapping(start, __phys_to_virt(start), end - start,
-                       PAGE_KERNEL_EXEC);
-}
-#endif
 
-static void __init map_mem(void)
+static void __init map_mem(pgd_t *pgd)
 {
        struct memblock_region *reg;
 
                if (memblock_is_nomap(reg))
                        continue;
 
-               __map_memblock(start, end);
+               __map_memblock(pgd, start, end);
        }
 }
 
-static void __init fixup_executable(void)
-{
-#ifdef CONFIG_DEBUG_RODATA
-       /* now that we are actually fully mapped, make the start/end more fine grained */
-       if (!IS_ALIGNED((unsigned long)_stext, SWAPPER_BLOCK_SIZE)) {
-               unsigned long aligned_start = round_down(__pa(_stext),
-                                                        SWAPPER_BLOCK_SIZE);
-
-               create_mapping(aligned_start, __phys_to_virt(aligned_start),
-                               __pa(_stext) - aligned_start,
-                               PAGE_KERNEL);
-       }
-
-       if (!IS_ALIGNED((unsigned long)__init_end, SWAPPER_BLOCK_SIZE)) {
-               unsigned long aligned_end = round_up(__pa(__init_end),
-                                                         SWAPPER_BLOCK_SIZE);
-               create_mapping(__pa(__init_end), (unsigned long)__init_end,
-                               aligned_end - __pa(__init_end),
-                               PAGE_KERNEL);
-       }
-#endif
-}
-
 #ifdef CONFIG_DEBUG_RODATA
 void mark_rodata_ro(void)
 {
                        PAGE_KERNEL);
 }
 
+static void __init map_kernel_chunk(pgd_t *pgd, void *va_start, void *va_end,
+                                   pgprot_t prot)
+{
+       phys_addr_t pa_start = __pa(va_start);
+       unsigned long size = va_end - va_start;
+
+       BUG_ON(!PAGE_ALIGNED(pa_start));
+       BUG_ON(!PAGE_ALIGNED(size));
+
+       __create_pgd_mapping(pgd, pa_start, (unsigned long)va_start, size, prot,
+                            early_pgtable_alloc);
+}
+
+/*
+ * Create fine-grained mappings for the kernel.
+ */
+static void __init map_kernel(pgd_t *pgd)
+{
+
+       map_kernel_chunk(pgd, _stext, _etext, PAGE_KERNEL_EXEC);
+       map_kernel_chunk(pgd, __init_begin, __init_end, PAGE_KERNEL_EXEC);
+       map_kernel_chunk(pgd, _data, _end, PAGE_KERNEL);
+
+       /*
+        * The fixmap falls in a separate pgd to the kernel, and doesn't live
+        * in the carveout for the swapper_pg_dir. We can simply re-use the
+        * existing dir for the fixmap.
+        */
+       set_pgd(pgd_offset_raw(pgd, FIXADDR_START), *pgd_offset_k(FIXADDR_START));
+
+       kasan_copy_shadow(pgd);
+}
+
 /*
  * paging_init() sets up the page tables, initialises the zone memory
  * maps and sets up the zero page.
  */
 void __init paging_init(void)
 {
-       map_mem();
-       fixup_executable();
+       phys_addr_t pgd_phys = early_pgtable_alloc();
+       pgd_t *pgd = pgd_set_fixmap(pgd_phys);
+
+       map_kernel(pgd);
+       map_mem(pgd);
+
+       /*
+        * We want to reuse the original swapper_pg_dir so we don't have to
+        * communicate the new address to non-coherent secondaries in
+        * secondary_entry, and so cpu_switch_mm can generate the address with
+        * adrp+add rather than a load from some global variable.
+        *
+        * To do this we need to go via a temporary pgd.
+        */
+       cpu_replace_ttbr1(__va(pgd_phys));
+       memcpy(swapper_pg_dir, pgd, PAGE_SIZE);
+       cpu_replace_ttbr1(swapper_pg_dir);
+
+       pgd_clear_fixmap();
+       memblock_free(pgd_phys, PAGE_SIZE);
+
+       /*
+        * We only reuse the PGD from the swapper_pg_dir, not the pud + pmd
+        * allocated with it.
+        */
+       memblock_free(__pa(swapper_pg_dir) + PAGE_SIZE,
+                     SWAPPER_DIR_SIZE - PAGE_SIZE);
 
        bootmem_init();
 }