* struct vmemmap_remap_walk - walk vmemmap page table
  *
  * @remap_pte:         called for each lowest-level entry (PTE).
+ * @nr_walked:         the number of walked PTEs.
  * @reuse_page:                the page which is reused for the tail vmemmap pages.
  * @reuse_addr:                the virtual address of the @reuse_page page.
  * @vmemmap_pages:     the list head of the vmemmap pages that can be freed
 struct vmemmap_remap_walk {
        void (*remap_pte)(pte_t *pte, unsigned long addr,
                          struct vmemmap_remap_walk *walk);
+       unsigned long nr_walked;
        struct page *reuse_page;
        unsigned long reuse_addr;
        struct list_head *vmemmap_pages;
 };
 
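+/*
+ * Split the huge PMD mapping of the vmemmap range starting at @start into
+ * base-page mappings so that individual vmemmap pages can be remapped later.
+ */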
+static int split_vmemmap_huge_pmd(pmd_t *pmd, unsigned long start,
+                                 struct vmemmap_remap_walk *walk)
+{
+       pmd_t __pmd;
+       int i;
+       unsigned long addr = start;
+       struct page *page = pmd_page(*pmd);
+       pte_t *pgtable = pte_alloc_one_kernel(&init_mm);
+
+       if (!pgtable)
+               return -ENOMEM;
+
+       pmd_populate_kernel(&init_mm, &__pmd, pgtable);
+
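+       /*
+        * Fill the new page table with base-page entries covering the same
+        * range that the huge PMD currently maps.
+        */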
+       for (i = 0; i < PMD_SIZE / PAGE_SIZE; i++, addr += PAGE_SIZE) {
+               pte_t entry, *pte;
+               pgprot_t pgprot = PAGE_KERNEL;
+
+               entry = mk_pte(page + i, pgprot);
+               pte = pte_offset_kernel(&__pmd, addr);
+               set_pte_at(&init_mm, addr, pte, entry);
+       }
+
+       /* Make pte visible before pmd. See comment in __pte_alloc(). */
+       smp_wmb();
+       pmd_populate_kernel(&init_mm, pmd, pgtable);
+
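+       /* Flush stale huge-page TLB entries for the range that was split. */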
+       flush_tlb_kernel_range(start, start + PMD_SIZE);
+
+       return 0;
+}
+
 static void vmemmap_pte_range(pmd_t *pmd, unsigned long addr,
                              unsigned long end,
                              struct vmemmap_remap_walk *walk)
                 */
                addr += PAGE_SIZE;
                pte++;
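+               /* Count walked PTEs so a failed remap can be rolled back. */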
+               walk->nr_walked++;
        }
 
-       for (; addr != end; addr += PAGE_SIZE, pte++)
+       for (; addr != end; addr += PAGE_SIZE, pte++) {
                walk->remap_pte(pte, addr, walk);
+               walk->nr_walked++;
+       }
 }
 
-static void vmemmap_pmd_range(pud_t *pud, unsigned long addr,
-                             unsigned long end,
-                             struct vmemmap_remap_walk *walk)
+static int vmemmap_pmd_range(pud_t *pud, unsigned long addr,
+                            unsigned long end,
+                            struct vmemmap_remap_walk *walk)
 {
        pmd_t *pmd;
        unsigned long next;
 
        pmd = pmd_offset(pud, addr);
        do {
-               BUG_ON(pmd_leaf(*pmd));
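+               /* A huge vmemmap PMD must be split into base pages first. */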
+               if (pmd_leaf(*pmd)) {
+                       int ret;
 
+                       ret = split_vmemmap_huge_pmd(pmd, addr & PMD_MASK, walk);
+                       if (ret)
+                               return ret;
+               }
                next = pmd_addr_end(addr, end);
                vmemmap_pte_range(pmd, addr, next, walk);
        } while (pmd++, addr = next, addr != end);
+
+       return 0;
 }
 
-static void vmemmap_pud_range(p4d_t *p4d, unsigned long addr,
-                             unsigned long end,
-                             struct vmemmap_remap_walk *walk)
+static int vmemmap_pud_range(p4d_t *p4d, unsigned long addr,
+                            unsigned long end,
+                            struct vmemmap_remap_walk *walk)
 {
        pud_t *pud;
        unsigned long next;
 
        pud = pud_offset(p4d, addr);
        do {
+               int ret;
+
                next = pud_addr_end(addr, end);
-               vmemmap_pmd_range(pud, addr, next, walk);
+               ret = vmemmap_pmd_range(pud, addr, next, walk);
+               if (ret)
+                       return ret;
        } while (pud++, addr = next, addr != end);
+
+       return 0;
 }
 
-static void vmemmap_p4d_range(pgd_t *pgd, unsigned long addr,
-                             unsigned long end,
-                             struct vmemmap_remap_walk *walk)
+static int vmemmap_p4d_range(pgd_t *pgd, unsigned long addr,
+                            unsigned long end,
+                            struct vmemmap_remap_walk *walk)
 {
        p4d_t *p4d;
        unsigned long next;
 
        p4d = p4d_offset(pgd, addr);
        do {
+               int ret;
+
                next = p4d_addr_end(addr, end);
-               vmemmap_pud_range(p4d, addr, next, walk);
+               ret = vmemmap_pud_range(p4d, addr, next, walk);
+               if (ret)
+                       return ret;
        } while (p4d++, addr = next, addr != end);
+
+       return 0;
 }
 
-static void vmemmap_remap_range(unsigned long start, unsigned long end,
-                               struct vmemmap_remap_walk *walk)
+static int vmemmap_remap_range(unsigned long start, unsigned long end,
+                              struct vmemmap_remap_walk *walk)
 {
        unsigned long addr = start;
        unsigned long next;
 
        pgd = pgd_offset_k(addr);
        do {
+               int ret;
+
                next = pgd_addr_end(addr, end);
-               vmemmap_p4d_range(pgd, addr, next, walk);
+               ret = vmemmap_p4d_range(pgd, addr, next, walk);
+               if (ret)
+                       return ret;
        } while (pgd++, addr = next, addr != end);
 
        /*
         * belongs to the range.
         */
        flush_tlb_kernel_range(start + PAGE_SIZE, end);
+
+       return 0;
 }
 
 /*
        pte_t entry = mk_pte(walk->reuse_page, pgprot);
        struct page *page = pte_page(*pte);
 
-       list_add(&page->lru, walk->vmemmap_pages);
+       list_add_tail(&page->lru, walk->vmemmap_pages);
        set_pte_at(&init_mm, addr, pte, entry);
 }
 
+static void vmemmap_restore_pte(pte_t *pte, unsigned long addr,
+                               struct vmemmap_remap_walk *walk)
+{
+       pgprot_t pgprot = PAGE_KERNEL;
+       struct page *page;
+       void *to;
+
+       BUG_ON(pte_page(*pte) != walk->reuse_page);
+
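+       /*
+        * Take one page from the walk's list, copy the contents of the
+        * shared reuse page into it and make the PTE point at it.
+        */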
+       page = list_first_entry(walk->vmemmap_pages, struct page, lru);
+       list_del(&page->lru);
+       to = page_to_virt(page);
+       copy_page(to, (void *)walk->reuse_addr);
+
+       set_pte_at(&init_mm, addr, pte, mk_pte(page, pgprot));
+}
+
 /**
  * vmemmap_remap_free - remap the vmemmap virtual address range [@start, @end)
  *                     to the page which @reuse is mapped to, then free vmemmap
  *             remap.
  * @reuse:     reuse address.
  *
- * Note: This function depends on vmemmap being base page mapped. Please make
- * sure that we disable PMD mapping of vmemmap pages when calling this function.
+ * Return: %0 on success, negative error code otherwise.
  */
-void vmemmap_remap_free(unsigned long start, unsigned long end,
-                       unsigned long reuse)
+int vmemmap_remap_free(unsigned long start, unsigned long end,
+                      unsigned long reuse)
 {
+       int ret;
        LIST_HEAD(vmemmap_pages);
        struct vmemmap_remap_walk walk = {
                .remap_pte      = vmemmap_remap_pte,
         */
        BUG_ON(start - reuse != PAGE_SIZE);
 
-       vmemmap_remap_range(reuse, end, &walk);
-       free_vmemmap_page_list(&vmemmap_pages);
-}
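+       /*
+        * Hold init_mm's mmap lock for write while the walk may split huge
+        * PMDs, then downgrade it so the restore path below runs under the
+        * read lock.
+        */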
+       mmap_write_lock(&init_mm);
+       ret = vmemmap_remap_range(reuse, end, &walk);
+       mmap_write_downgrade(&init_mm);
 
-static void vmemmap_restore_pte(pte_t *pte, unsigned long addr,
-                               struct vmemmap_remap_walk *walk)
-{
-       pgprot_t pgprot = PAGE_KERNEL;
-       struct page *page;
-       void *to;
+       if (ret && walk.nr_walked) {
+               end = reuse + walk.nr_walked * PAGE_SIZE;
+               /*
+                * vmemmap_pages contains the pages that the failed
+                * vmemmap_remap_range call above already removed from
+                * the vmemmap.  They are restored by the call below.
+                */
+               walk = (struct vmemmap_remap_walk) {
+                       .remap_pte      = vmemmap_restore_pte,
+                       .reuse_addr     = reuse,
+                       .vmemmap_pages  = &vmemmap_pages,
+               };
 
-       BUG_ON(pte_page(*pte) != walk->reuse_page);
+               vmemmap_remap_range(reuse, end, &walk);
+       }
+       mmap_read_unlock(&init_mm);
 
-       page = list_first_entry(walk->vmemmap_pages, struct page, lru);
-       list_del(&page->lru);
-       to = page_to_virt(page);
-       copy_page(to, (void *)walk->reuse_addr);
+       free_vmemmap_page_list(&vmemmap_pages);
 
-       set_pte_at(&init_mm, addr, pte, mk_pte(page, pgprot));
+       return ret;
 }
 
 static int alloc_vmemmap_page_list(unsigned long start, unsigned long end,
  *             remap.
  * @reuse:     reuse address.
  * @gfp_mask:  GFP flag for allocating vmemmap pages.
+ *
+ * Return: %0 on success, negative error code otherwise.
  */
 int vmemmap_remap_alloc(unsigned long start, unsigned long end,
                        unsigned long reuse, gfp_t gfp_mask)
        /* See the comment in the vmemmap_remap_free(). */
        BUG_ON(start - reuse != PAGE_SIZE);
 
-       might_sleep_if(gfpflags_allow_blocking(gfp_mask));
-
        if (alloc_vmemmap_page_list(start, end, gfp_mask, &vmemmap_pages))
                return -ENOMEM;
 
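+       /*
+        * A read lock is sufficient here: the huge PMD covering this range
+        * should already have been split when its vmemmap pages were freed.
+        */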
+       mmap_read_lock(&init_mm);
        vmemmap_remap_range(reuse, end, &walk);
+       mmap_read_unlock(&init_mm);
 
        return 0;
 }