static unsigned long hyp_idmap_end;
 static phys_addr_t hyp_idmap_vector;
 
+static unsigned long io_map_base;
+
 #define S2_PGD_SIZE    (PTRS_PER_S2_PGD * sizeof(pgd_t))
 #define hyp_pgd_order get_order(PTRS_PER_PGD * sizeof(pgd_t))
 
  *
  * Assumes hyp_pgd is a page table used strictly in Hyp-mode and
  * therefore contains either mappings in the kernel memory area (above
- * PAGE_OFFSET), or device mappings in the vmalloc range (from
- * VMALLOC_START to VMALLOC_END).
+ * PAGE_OFFSET), or device mappings in the idmap range.
  *
- * boot_hyp_pgd should only map two pages for the init code.
+ * boot_hyp_pgd should only map the idmap range, and is only used in
+ * the extended idmap case.
  */
 void free_hyp_pgds(void)
 {
+       pgd_t *id_pgd;
+
        mutex_lock(&kvm_hyp_pgd_mutex);
 
+       id_pgd = boot_hyp_pgd ? boot_hyp_pgd : hyp_pgd;
+
+       if (id_pgd) {
+               /* In case we never called hyp_mmu_init() */
+               if (!io_map_base)
+                       io_map_base = hyp_idmap_start;
+               unmap_hyp_idmap_range(id_pgd, io_map_base,
+                                     hyp_idmap_start + PAGE_SIZE - io_map_base);
+       }
+
        if (boot_hyp_pgd) {
-               unmap_hyp_idmap_range(boot_hyp_pgd, hyp_idmap_start, PAGE_SIZE);
                free_pages((unsigned long)boot_hyp_pgd, hyp_pgd_order);
                boot_hyp_pgd = NULL;
        }
 
        if (hyp_pgd) {
-               unmap_hyp_idmap_range(hyp_pgd, hyp_idmap_start, PAGE_SIZE);
                unmap_hyp_range(hyp_pgd, kern_hyp_va(PAGE_OFFSET),
                                (uintptr_t)high_memory - PAGE_OFFSET);
-               unmap_hyp_range(hyp_pgd, kern_hyp_va(VMALLOC_START),
-                               VMALLOC_END - VMALLOC_START);
 
                free_pages((unsigned long)hyp_pgd, hyp_pgd_order);
                hyp_pgd = NULL;
                           void __iomem **kaddr,
                           void __iomem **haddr)
 {
-       unsigned long start, end;
-       int ret;
+       pgd_t *pgd = hyp_pgd;
+       unsigned long base;
+       int ret = 0;
 
        *kaddr = ioremap(phys_addr, size);
        if (!*kaddr)
                return 0;
        }
 
+       mutex_lock(&kvm_hyp_pgd_mutex);
 
-       start = kern_hyp_va((unsigned long)*kaddr);
-       end = kern_hyp_va((unsigned long)*kaddr + size);
-       ret = __create_hyp_mappings(hyp_pgd, PTRS_PER_PGD, start, end,
-                                    __phys_to_pfn(phys_addr), PAGE_HYP_DEVICE);
+       /*
+        * This assumes that we we have enough space below the idmap
+        * page to allocate our VAs. If not, the check below will
+        * kick. A potential alternative would be to detect that
+        * overflow and switch to an allocation above the idmap.
+        *
+        * The allocated size is always a multiple of PAGE_SIZE.
+        */
+       size = PAGE_ALIGN(size + offset_in_page(phys_addr));
+       base = io_map_base - size;
 
+       /*
+        * Verify that BIT(VA_BITS - 1) hasn't been flipped by
+        * allocating the new area, as it would indicate we've
+        * overflowed the idmap/IO address range.
+        */
+       if ((base ^ io_map_base) & BIT(VA_BITS - 1))
+               ret = -ENOMEM;
+       else
+               io_map_base = base;
+
+       mutex_unlock(&kvm_hyp_pgd_mutex);
+
+       if (ret)
+               goto out;
+
+       if (__kvm_cpu_uses_extended_idmap())
+               pgd = boot_hyp_pgd;
+
+       ret = __create_hyp_mappings(pgd, __kvm_idmap_ptrs_per_pgd(),
+                                   base, base + size,
+                                   __phys_to_pfn(phys_addr), PAGE_HYP_DEVICE);
+       if (ret)
+               goto out;
+
+       *haddr = (void __iomem *)base + offset_in_page(phys_addr);
+
+out:
        if (ret) {
                iounmap(*kaddr);
                *kaddr = NULL;
                return ret;
        }
 
-       *haddr = (void __iomem *)start;
        return 0;
 }
 
                        goto out;
        }
 
+       io_map_base = hyp_idmap_start;
        return 0;
 out:
        free_hyp_pgds();