dax_wake_mapping_entry_waiter(mapping, index, entry, false);
 }
 
+static unsigned long dax_entry_size(void *entry)
+{
+       if (dax_is_zero_entry(entry))
+               return 0;
+       else if (dax_is_empty_entry(entry))
+               return 0;
+       else if (dax_is_pmd_entry(entry))
+               return PMD_SIZE;
+       else
+               return PAGE_SIZE;
+}
+
+static unsigned long dax_radix_end_pfn(void *entry)
+{
+       return dax_radix_pfn(entry) + dax_entry_size(entry) / PAGE_SIZE;
+}
+
+/*
+ * Iterate through all mapped pfns represented by an entry, i.e. skip
+ * 'empty' and 'zero' entries.
+ */
+#define for_each_mapped_pfn(entry, pfn) \
+       for (pfn = dax_radix_pfn(entry); \
+                       pfn < dax_radix_end_pfn(entry); pfn++)
+
+static void dax_associate_entry(void *entry, struct address_space *mapping)
+{
+       unsigned long pfn;
+
+       if (IS_ENABLED(CONFIG_FS_DAX_LIMITED))
+               return;
+
+       for_each_mapped_pfn(entry, pfn) {
+               struct page *page = pfn_to_page(pfn);
+
+               WARN_ON_ONCE(page->mapping);
+               page->mapping = mapping;
+       }
+}
+
+static void dax_disassociate_entry(void *entry, struct address_space *mapping,
+               bool trunc)
+{
+       unsigned long pfn;
+
+       if (IS_ENABLED(CONFIG_FS_DAX_LIMITED))
+               return;
+
+       for_each_mapped_pfn(entry, pfn) {
+               struct page *page = pfn_to_page(pfn);
+
+               WARN_ON_ONCE(trunc && page_ref_count(page) > 1);
+               WARN_ON_ONCE(page->mapping && page->mapping != mapping);
+               page->mapping = NULL;
+       }
+}
+
 /*
  * Find radix tree entry at given index. If it points to an exceptional entry,
  * return it with the radix tree entry locked. If the radix tree doesn't
                }
 
                if (pmd_downgrade) {
+                       dax_disassociate_entry(entry, mapping, false);
                        radix_tree_delete(&mapping->page_tree, index);
                        mapping->nrexceptional--;
                        dax_wake_mapping_entry_waiter(mapping, index, entry,
            (radix_tree_tag_get(page_tree, index, PAGECACHE_TAG_DIRTY) ||
             radix_tree_tag_get(page_tree, index, PAGECACHE_TAG_TOWRITE)))
                goto out;
+       dax_disassociate_entry(entry, mapping, trunc);
        radix_tree_delete(page_tree, index);
        mapping->nrexceptional--;
        ret = 1;
 
        spin_lock_irq(&mapping->tree_lock);
        new_entry = dax_radix_locked_entry(pfn, flags);
+       if (dax_entry_size(entry) != dax_entry_size(new_entry)) {
+               dax_disassociate_entry(entry, mapping, false);
+               dax_associate_entry(new_entry, mapping);
+       }
 
        if (dax_is_zero_entry(entry) || dax_is_empty_entry(entry)) {
                /*