static void freeze_page_vma(struct vm_area_struct *vma, struct page *page,
                unsigned long address)
 {
+       unsigned long haddr = address & HPAGE_PMD_MASK;
        spinlock_t *ptl;
        pgd_t *pgd;
        pud_t *pud;
        }
        if (pmd_trans_huge(*pmd)) {
                if (page == pmd_page(*pmd))
-                       __split_huge_pmd_locked(vma, pmd, address, true);
+                       __split_huge_pmd_locked(vma, pmd, haddr, true);
                spin_unlock(ptl);
                return;
        }
        spin_unlock(ptl);
 
        pte = pte_offset_map_lock(vma->vm_mm, pmd, address, &ptl);
-       for (i = 0; i < nr; i++, address += PAGE_SIZE, page++) {
+       for (i = 0; i < nr; i++, address += PAGE_SIZE, page++, pte++) {
                pte_t entry, swp_pte;
                swp_entry_t swp_entry;
 
-               if (!pte_present(pte[i]))
+               /*
+                * We've just crossed page table boundary: need to map next one.
+                * It can happen if THP was mremaped to non PMD-aligned address.
+                */
+               if (unlikely(address == haddr + HPAGE_PMD_SIZE)) {
+                       pte_unmap_unlock(pte - 1, ptl);
+                       pmd = mm_find_pmd(vma->vm_mm, address);
+                       if (!pmd)
+                               return;
+                       pte = pte_offset_map_lock(vma->vm_mm, pmd,
+                                       address, &ptl);
+               }
+
+               if (!pte_present(*pte))
                        continue;
-               if (page_to_pfn(page) != pte_pfn(pte[i]))
+               if (page_to_pfn(page) != pte_pfn(*pte))
                        continue;
                flush_cache_page(vma, address, page_to_pfn(page));
-               entry = ptep_clear_flush(vma, address, pte + i);
+               entry = ptep_clear_flush(vma, address, pte);
                if (pte_dirty(entry))
                        SetPageDirty(page);
                swp_entry = make_migration_entry(page, pte_write(entry));
                swp_pte = swp_entry_to_pte(swp_entry);
                if (pte_soft_dirty(entry))
                        swp_pte = pte_swp_mksoft_dirty(swp_pte);
-               set_pte_at(vma->vm_mm, address, pte + i, swp_pte);
+               set_pte_at(vma->vm_mm, address, pte, swp_pte);
                page_remove_rmap(page, false);
                put_page(page);
        }
-       pte_unmap_unlock(pte, ptl);
+       pte_unmap_unlock(pte - 1, ptl);
 }
 
 static void freeze_page(struct anon_vma *anon_vma, struct page *page)
 
        anon_vma_interval_tree_foreach(avc, &anon_vma->rb_root, pgoff,
                        pgoff + HPAGE_PMD_NR - 1) {
-               unsigned long haddr;
+               unsigned long address = __vma_address(page, avc->vma);
 
-               haddr = __vma_address(page, avc->vma) & HPAGE_PMD_MASK;
                mmu_notifier_invalidate_range_start(avc->vma->vm_mm,
-                               haddr, haddr + HPAGE_PMD_SIZE);
-               freeze_page_vma(avc->vma, page, haddr);
+                               address, address + HPAGE_PMD_SIZE);
+               freeze_page_vma(avc->vma, page, address);
                mmu_notifier_invalidate_range_end(avc->vma->vm_mm,
-                               haddr, haddr + HPAGE_PMD_SIZE);
+                               address, address + HPAGE_PMD_SIZE);
        }
 }
 
        pmd_t *pmd;
        pte_t *pte, entry;
        swp_entry_t swp_entry;
+       unsigned long haddr = address & HPAGE_PMD_MASK;
        int i, nr = HPAGE_PMD_NR;
 
        /* Skip pages which doesn't belong to the VMA */
        pmd = mm_find_pmd(vma->vm_mm, address);
        if (!pmd)
                return;
+
        pte = pte_offset_map_lock(vma->vm_mm, pmd, address, &ptl);
-       for (i = 0; i < nr; i++, address += PAGE_SIZE, page++) {
-               if (!is_swap_pte(pte[i]))
+       for (i = 0; i < nr; i++, address += PAGE_SIZE, page++, pte++) {
+               /*
+                * We've just crossed page table boundary: need to map next one.
+                * It can happen if THP was mremaped to non-PMD aligned address.
+                */
+               if (unlikely(address == haddr + HPAGE_PMD_SIZE)) {
+                       pte_unmap_unlock(pte - 1, ptl);
+                       pmd = mm_find_pmd(vma->vm_mm, address);
+                       if (!pmd)
+                               return;
+                       pte = pte_offset_map_lock(vma->vm_mm, pmd,
+                                       address, &ptl);
+               }
+
+               if (!is_swap_pte(*pte))
                        continue;
 
-               swp_entry = pte_to_swp_entry(pte[i]);
+               swp_entry = pte_to_swp_entry(*pte);
                if (!is_migration_entry(swp_entry))
                        continue;
                if (migration_entry_to_page(swp_entry) != page)
                        entry = maybe_mkwrite(entry, vma);
 
                flush_dcache_page(page);
-               set_pte_at(vma->vm_mm, address, pte + i, entry);
+               set_pte_at(vma->vm_mm, address, pte, entry);
 
                /* No need to invalidate - it was non-present before */
-               update_mmu_cache(vma, address, pte + i);
+               update_mmu_cache(vma, address, pte);
        }
-       pte_unmap_unlock(pte, ptl);
+       pte_unmap_unlock(pte - 1, ptl);
 }
 
 static void unfreeze_page(struct anon_vma *anon_vma, struct page *page)
        spin_lock(&split_queue_lock);
        count = page_count(head);
        mapcount = total_mapcount(head);
-       if (mapcount == count - 1) {
+       if (!mapcount && count == 1) {
                if (!list_empty(page_deferred_list(head))) {
                        split_queue_len--;
                        list_del(page_deferred_list(head));
                spin_unlock(&split_queue_lock);
                __split_huge_page(page, list);
                ret = 0;
-       } else if (IS_ENABLED(CONFIG_DEBUG_VM) && mapcount > count - 1) {
+       } else if (IS_ENABLED(CONFIG_DEBUG_VM) && mapcount) {
                spin_unlock(&split_queue_lock);
                pr_alert("total_mapcount: %u, page_count(): %u\n",
                                mapcount, count);
                if (PageTail(page))
                        dump_page(head, NULL);
-               dump_page(page, "total_mapcount(head) > page_count(head) - 1");
+               dump_page(page, "total_mapcount(head) > 0");
                BUG();
        } else {
                spin_unlock(&split_queue_lock);