#include "xfs_iomap.h"
 #include "xfs_reflink.h"
 
+#include <linux/dax.h>
 #include <linux/falloc.h>
 #include <linux/backing-dev.h>
 #include <linux/mman.h>
        pos = iocb->ki_pos;
 
        trace_xfs_file_dax_write(iocb, from);
-       ret = dax_iomap_rw(iocb, from, &xfs_direct_write_iomap_ops);
+       ret = dax_iomap_rw(iocb, from, &xfs_dax_write_iomap_ops);
        if (ret > 0 && iocb->ki_pos > i_size_read(inode)) {
                i_size_write(inode, iocb->ki_pos);
                error = xfs_setfilesize(ip, pos, ret);
        return vfs_setpos(file, offset, inode->i_sb->s_maxbytes);
 }
 
+#ifdef CONFIG_FS_DAX
+static int
+xfs_dax_fault(
+       struct vm_fault         *vmf,
+       enum page_entry_size    pe_size,
+       bool                    write_fault,
+       pfn_t                   *pfn)
+{
+       return dax_iomap_fault(vmf, pe_size, pfn, NULL,
+                       (write_fault && !vmf->cow_page) ?
+                               &xfs_dax_write_iomap_ops :
+                               &xfs_read_iomap_ops);
+}
+#else
+static int
+xfs_dax_fault(
+       struct vm_fault         *vmf,
+       enum page_entry_size    pe_size,
+       bool                    write_fault,
+       pfn_t                   *pfn)
+{
+       return 0;
+}
+#endif
+
 /*
  * Locking for serialisation of IO during page faults. This results in a lock
  * ordering of:
                pfn_t pfn;
 
                xfs_ilock(XFS_I(inode), XFS_MMAPLOCK_SHARED);
-               ret = dax_iomap_fault(vmf, pe_size, &pfn, NULL,
-                               (write_fault && !vmf->cow_page) ?
-                                &xfs_direct_write_iomap_ops :
-                                &xfs_read_iomap_ops);
+               ret = xfs_dax_fault(vmf, pe_size, write_fault, &pfn);
                if (ret & VM_FAULT_NEEDDSYNC)
                        ret = dax_finish_sync_fault(vmf, pe_size, pfn);
                xfs_iunlock(XFS_I(inode), XFS_MMAPLOCK_SHARED);
 
 
                /* may drop and re-acquire the ilock */
                error = xfs_reflink_allocate_cow(ip, &imap, &cmap, &shared,
-                               &lockmode, flags & IOMAP_DIRECT);
+                               &lockmode,
+                               (flags & IOMAP_DIRECT) || IS_DAX(inode));
                if (error)
                        goto out_unlock;
                if (shared)
        .iomap_begin            = xfs_direct_write_iomap_begin,
 };
 
+static int
+xfs_dax_write_iomap_end(
+       struct inode            *inode,
+       loff_t                  pos,
+       loff_t                  length,
+       ssize_t                 written,
+       unsigned                flags,
+       struct iomap            *iomap)
+{
+       struct xfs_inode        *ip = XFS_I(inode);
+
+       if (!xfs_is_cow_inode(ip))
+               return 0;
+
+       if (!written) {
+               xfs_reflink_cancel_cow_range(ip, pos, length, true);
+               return 0;
+       }
+
+       return xfs_reflink_end_cow(ip, pos, written);
+}
+
+const struct iomap_ops xfs_dax_write_iomap_ops = {
+       .iomap_begin    = xfs_direct_write_iomap_begin,
+       .iomap_end      = xfs_dax_write_iomap_end,
+};
+
 static int
 xfs_buffered_write_iomap_begin(
        struct inode            *inode,