#include <linux/pgtable.h>
 #include <asm/tlbflush.h>
 #include <asm/fixmap.h>
+#include <asm/pgalloc.h>
+
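+/*
+ * Allocate naturally aligned, zeroed memory from memblock for the early
+ * shadow page tables.  There is no way to recover from an allocation
+ * failure this early in boot, so panic instead of returning NULL.
+ */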
+static __init void *early_alloc(size_t size, int node)
+{
+       void *ptr = memblock_alloc_try_nid(size, size,
+               __pa(MAX_DMA_ADDRESS), MEMBLOCK_ALLOC_ACCESSIBLE, node);
+
+       if (!ptr)
+               panic("%s: Failed to allocate %zu bytes align=%zx nid=%d from=%llx\n",
+                       __func__, size, size, node, (u64)__pa(MAX_DMA_ADDRESS));
+
+       return ptr;
+}
 
 extern pgd_t early_pg_dir[PTRS_PER_PGD];
 asmlinkage void __init kasan_early_init(void)
        memset(start, 0, end - start);
 }
 
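+/*
+ * "Shallow" population only wires up the top-level page-table entries for the
+ * vmalloc shadow region; the leaf shadow pages themselves are installed on
+ * demand by the vmalloc core when CONFIG_KASAN_VMALLOC is enabled.
+ */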
+static void __init kasan_shallow_populate(void *start, void *end)
+{
+       unsigned long vaddr = (unsigned long)start & PAGE_MASK;
+       unsigned long vend = PAGE_ALIGN((unsigned long)end);
+       unsigned long next;
+       void *p;
+       pgd_t *pgd_k;
+
+       while (vaddr < vend) {
+               next = pgd_addr_end(vaddr, vend);
+               pgd_k = pgd_offset_k(vaddr);
+
+               /*
+                * kasan_early_init() pointed every PGD entry of the shadow
+                * region at the early shadow tables.  Replace such an entry
+                * with a freshly allocated, zeroed page table so that the
+                * vmalloc shadow can later be populated underneath it.
+                */
+               if (pgd_page_vaddr(*pgd_k) ==
+                   (unsigned long)lm_alias(kasan_early_shadow_pmd)) {
+                       p = early_alloc(PAGE_SIZE, NUMA_NO_NODE);
+                       set_pgd(pgd_k, pfn_pgd(PFN_DOWN(__pa(p)), PAGE_TABLE));
+               }
+
+               vaddr = next;
+       }
+
+       local_flush_tlb_all();
+}
+
 void __init kasan_init(void)
 {
        phys_addr_t _start, _end;
 
        kasan_populate_early_shadow((void *)KASAN_SHADOW_START,
                                    (void *)kasan_mem_to_shadow((void *)
-                                                               VMALLOC_END));
+                                                               VMEMMAP_END));
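+
+       /*
+        * The early shadow above stops at the vmemmap shadow.  The vmalloc
+        * shadow is handled next: shallowly when KASAN_VMALLOC installs real
+        * shadow pages on demand, otherwise with the early shadow page.
+        */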
+       if (IS_ENABLED(CONFIG_KASAN_VMALLOC))
+               kasan_shallow_populate(
+                       (void *)kasan_mem_to_shadow((void *)VMALLOC_START),
+                       (void *)kasan_mem_to_shadow((void *)VMALLOC_END));
+       else
+               kasan_populate_early_shadow(
+                       (void *)kasan_mem_to_shadow((void *)VMALLOC_START),
+                       (void *)kasan_mem_to_shadow((void *)VMALLOC_END));
 
        for_each_mem_range(i, &_start, &_end) {
                void *start = (void *)_start;