#include "acl.h"
 
 #ifdef CONFIG_FS_DAX
+/*
+ * The lock ordering for ext2 DAX fault paths is:
+ *
+ * mmap_sem (MM)
+ *   sb_start_pagefault (vfs, freeze)
+ *     ext2_inode_info->dax_sem
+ *       address_space->i_mmap_rwsem or page_lock (mutually exclusive in DAX)
+ *         ext2_inode_info->truncate_mutex
+ *
+ * The default page_lock and i_size verification done by non-DAX fault paths
+ * is sufficient because ext2 doesn't support hole punching.
+ */
 static int ext2_dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf)
 {
-       return dax_fault(vma, vmf, ext2_get_block, NULL);
+       struct inode *inode = file_inode(vma->vm_file);
+       struct ext2_inode_info *ei = EXT2_I(inode);
+       int ret;
+
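+       /* Write faults need freeze protection and a file time update */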
+       if (vmf->flags & FAULT_FLAG_WRITE) {
+               sb_start_pagefault(inode->i_sb);
+               file_update_time(vma->vm_file);
+       }
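+       /* Serialize against truncate, which takes dax_sem for write */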
+       down_read(&ei->dax_sem);
+
+       ret = __dax_fault(vma, vmf, ext2_get_block, NULL);
+
+       up_read(&ei->dax_sem);
+       if (vmf->flags & FAULT_FLAG_WRITE)
+               sb_end_pagefault(inode->i_sb);
+       return ret;
 }
 
 static int ext2_dax_pmd_fault(struct vm_area_struct *vma, unsigned long addr,
                                                pmd_t *pmd, unsigned int flags)
 {
-       return dax_pmd_fault(vma, addr, pmd, flags, ext2_get_block, NULL);
+       struct inode *inode = file_inode(vma->vm_file);
+       struct ext2_inode_info *ei = EXT2_I(inode);
+       int ret;
+
+       if (flags & FAULT_FLAG_WRITE) {
+               sb_start_pagefault(inode->i_sb);
+               file_update_time(vma->vm_file);
+       }
+       down_read(&ei->dax_sem);
+
+       ret = __dax_pmd_fault(vma, addr, pmd, flags, ext2_get_block, NULL);
+
+       up_read(&ei->dax_sem);
+       if (flags & FAULT_FLAG_WRITE)
+               sb_end_pagefault(inode->i_sb);
+       return ret;
 }
 
 static int ext2_dax_mkwrite(struct vm_area_struct *vma, struct vm_fault *vmf)
 {
-       return dax_mkwrite(vma, vmf, ext2_get_block, NULL);
+       struct inode *inode = file_inode(vma->vm_file);
+       struct ext2_inode_info *ei = EXT2_I(inode);
+       int ret;
+
+       sb_start_pagefault(inode->i_sb);
+       file_update_time(vma->vm_file);
+       down_read(&ei->dax_sem);
+
+       ret = __dax_mkwrite(vma, vmf, ext2_get_block, NULL);
+
+       up_read(&ei->dax_sem);
+       sb_end_pagefault(inode->i_sb);
+       return ret;
+}
+
+static int ext2_dax_pfn_mkwrite(struct vm_area_struct *vma,
+               struct vm_fault *vmf)
+{
+       struct inode *inode = file_inode(vma->vm_file);
+       struct ext2_inode_info *ei = EXT2_I(inode);
+       int ret = VM_FAULT_NOPAGE;
+       loff_t size;
+
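+       /* No blocks to allocate; the pfn was mapped by an earlier fault */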
+       sb_start_pagefault(inode->i_sb);
+       file_update_time(vma->vm_file);
+       down_read(&ei->dax_sem);
+
+       /* check that the faulting page hasn't raced with truncate */
+       size = (i_size_read(inode) + PAGE_SIZE - 1) >> PAGE_SHIFT;
+       if (vmf->pgoff >= size)
+               ret = VM_FAULT_SIGBUS;
+
+       up_read(&ei->dax_sem);
+       sb_end_pagefault(inode->i_sb);
+       return ret;
 }
 
 static const struct vm_operations_struct ext2_dax_vm_ops = {
        .fault          = ext2_dax_fault,
        .pmd_fault      = ext2_dax_pmd_fault,
        .page_mkwrite   = ext2_dax_mkwrite,
-       .pfn_mkwrite    = dax_pfn_mkwrite,
+       .pfn_mkwrite    = ext2_dax_pfn_mkwrite,
 };
 
 static int ext2_file_mmap(struct file *file, struct vm_area_struct *vma)
 
                ext2_free_data(inode, p, q);
 }
 
+/* dax_sem must be held for write when calling this function */
 static void __ext2_truncate_blocks(struct inode *inode, loff_t offset)
 {
        __le32 *i_data = EXT2_I(inode)->i_data;
        blocksize = inode->i_sb->s_blocksize;
        iblock = (offset + blocksize-1) >> EXT2_BLOCK_SIZE_BITS(inode->i_sb);
 
+#ifdef CONFIG_FS_DAX
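+       /* dax_sem only exists when CONFIG_FS_DAX is enabled */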
+       WARN_ON(!rwsem_is_locked(&ei->dax_sem));
+#endif
+
        n = ext2_block_to_path(inode, iblock, offsets, NULL);
        if (n == 0)
                return;
                return;
        if (IS_APPEND(inode) || IS_IMMUTABLE(inode))
                return;
+
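+       /* Wait for in-flight DAX faults before freeing blocks */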
+       dax_sem_down_write(EXT2_I(inode));
        __ext2_truncate_blocks(inode, offset);
+       dax_sem_up_write(EXT2_I(inode));
 }
 
 static int ext2_setsize(struct inode *inode, loff_t newsize)
        if (error)
                return error;
 
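+       /* Exclude DAX faults while i_size changes and blocks are freed */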
+       dax_sem_down_write(EXT2_I(inode));
        truncate_setsize(inode, newsize);
        __ext2_truncate_blocks(inode, newsize);
+       dax_sem_up_write(EXT2_I(inode));
 
        inode->i_mtime = inode->i_ctime = CURRENT_TIME_SEC;
        if (inode_needs_sync(inode)) {