int dax_pfn_mkwrite(struct vm_area_struct *vma, struct vm_fault *vmf)
{
        struct file *file = vma->vm_file;
        struct address_space *mapping = file->f_mapping;
-       void *entry;
+       void *entry, **slot;
        pgoff_t index = vmf->pgoff;
 
        spin_lock_irq(&mapping->tree_lock);
-       entry = get_unlocked_mapping_entry(mapping, index, NULL);
-       if (!entry || !radix_tree_exceptional_entry(entry))
-               goto out;
+       entry = get_unlocked_mapping_entry(mapping, index, &slot);
+       if (!entry || !radix_tree_exceptional_entry(entry)) {
+               if (entry)
+                       put_unlocked_mapping_entry(mapping, index, entry);
+               spin_unlock_irq(&mapping->tree_lock);
+               return VM_FAULT_NOPAGE;
+       }
        radix_tree_tag_set(&mapping->page_tree, index, PAGECACHE_TAG_DIRTY);
-       put_unlocked_mapping_entry(mapping, index, entry);
-out:
+       entry = lock_slot(mapping, slot);
        spin_unlock_irq(&mapping->tree_lock);
+       /*
+        * If we race with somebody updating the PTE and finish_mkwrite_fault()
+        * fails, we don't care. We need to return VM_FAULT_NOPAGE and retry
+        * the fault in either case.
+        */
+       finish_mkwrite_fault(vmf);
+       put_locked_mapping_entry(mapping, index, entry);
        return VM_FAULT_NOPAGE;
 }
 EXPORT_SYMBOL_GPL(dax_pfn_mkwrite);
 
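The hunk below is the caller side, wp_pfn_shared() in mm/memory.c. Because dax_pfn_mkwrite() now updates the PTE itself via finish_mkwrite_fault() while still holding the radix tree entry lock (keeping the PTE update consistent with the dirty tag in the radix tree), the caller must return VM_FAULT_NOPAGE directly rather than calling finish_mkwrite_fault() a second time:
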
                pte_unmap_unlock(vmf->pte, vmf->ptl);
                vmf->flags |= FAULT_FLAG_MKWRITE;
                ret = vma->vm_ops->pfn_mkwrite(vma, vmf);
-               if (ret & VM_FAULT_ERROR)
+               if (ret & (VM_FAULT_ERROR | VM_FAULT_NOPAGE))
                        return ret;
                return finish_mkwrite_fault(vmf);
        }
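
For reference, the race that the comment in dax_pfn_mkwrite() refers to is resolved inside finish_mkwrite_fault(), which re-checks the PTE under the page table lock before write-enabling it. A minimal sketch of that logic, simplified from the helper this series builds on (the exact mainline body may differ):

int finish_mkwrite_fault(struct vm_fault *vmf)
{
        /* Re-take the page table lock dropped before ->pfn_mkwrite() ran. */
        vmf->pte = pte_offset_map_lock(vmf->vma->vm_mm, vmf->pmd,
                                       vmf->address, &vmf->ptl);
        /* Another fault may have changed the PTE while the lock was dropped. */
        if (!pte_same(*vmf->pte, vmf->orig_pte)) {
                pte_unmap_unlock(vmf->pte, vmf->ptl);
                return VM_FAULT_NOPAGE; /* caller will retry the fault */
        }
        /* Still the PTE we faulted on: make it writable and dirty. */
        wp_page_reuse(vmf);
        return 0;
}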