return rc;
 }
 
+/**
+ * dax_iomap_cow_copy - Copy the data from source to destination before write
+ * @pos:       file offset to start copying from.
+ * @length:    length of the range being written, in bytes.
+ * @align_size:        alignment of the range, either PMD_SIZE or PAGE_SIZE.
+ * @srcmap:    iomap of the source extent to copy from.
+ * @daddr:     kernel-mapped destination address to copy to.
+ *
+ * This is called from two places.  During a DAX write fault the range is
+ * block aligned, so all @length bytes are copied to @daddr.  During a normal
+ * DAX write, dax_iomap_actor() calls this to copy only the unaligned head
+ * and/or tail of the range; the copy of the remaining, aligned part of the
+ * range is handled by dax_iomap_actor() itself.
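+ *
+ * For example, assuming align_size == PAGE_SIZE == 4096: a write at
+ * @pos = 4200 with @length = 100 gives head_off = 104 and pg_end = 8192,
+ * so bytes [4096, 4200) of the source block are copied as the head and
+ * bytes [4300, 8192) as the tail, while [4200, 4300) is later filled in
+ * by the caller.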
+ */
+static int dax_iomap_cow_copy(loff_t pos, uint64_t length, size_t align_size,
+               const struct iomap *srcmap, void *daddr)
+{
+       loff_t head_off = pos & (align_size - 1);
+       size_t size = ALIGN(head_off + length, align_size);
+       loff_t end = pos + length;
+       loff_t pg_end = round_up(end, align_size);
+       bool copy_all = head_off == 0 && end == pg_end;
+       void *saddr = NULL;
+       int ret = 0;
+
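+       /*
+        * Map the source range.  Both @saddr and @daddr are expected to
+        * point at the start of the aligned block containing @pos, so the
+        * head/tail offsets below index both mappings directly.
+        */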
+       ret = dax_iomap_direct_access(srcmap, pos, size, &saddr, NULL);
+       if (ret)
+               return ret;
+
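+       /* The whole range is block aligned: copy it in one go. */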
+       if (copy_all) {
+               ret = copy_mc_to_kernel(daddr, saddr, length);
+               return ret ? -EIO : 0;
+       }
+
+       /* Copy the head part of the range */
+       if (head_off) {
+               ret = copy_mc_to_kernel(daddr, saddr, head_off);
+               if (ret)
+                       return -EIO;
+       }
+
+       /* Copy the tail part of the range */
+       if (end < pg_end) {
+               loff_t tail_off = head_off + length;
+               loff_t tail_len = pg_end - end;
+
+               ret = copy_mc_to_kernel(daddr + tail_off, saddr + tail_off,
+                                       tail_len);
+               if (ret)
+                       return -EIO;
+       }
+       return 0;
+}
+
 /*
  * The user has performed a load from a hole in the file.  Allocating a new
  * page in the file would cause excessive storage usage for workloads with
                struct iov_iter *iter)
 {
        const struct iomap *iomap = &iomi->iomap;
+       const struct iomap *srcmap = &iomi->srcmap;
        loff_t length = iomap_length(iomi);
        loff_t pos = iomi->pos;
        struct dax_device *dax_dev = iomap->dax_dev;
        loff_t end = pos + length, done = 0;
+       bool write = iov_iter_rw(iter) == WRITE;
        ssize_t ret = 0;
        size_t xfer;
        int id;
 
-       if (iov_iter_rw(iter) == READ) {
+       if (!write) {
                end = min(end, i_size_read(iomi->inode));
                if (pos >= end)
                        return 0;
                        return iov_iter_zero(min(length, end - pos), iter);
        }
 
-       if (WARN_ON_ONCE(iomap->type != IOMAP_MAPPED))
+       /*
+        * In DAX mode, enforce either pure overwrites of written extents, or
+        * writes to unwritten extents as part of a copy-on-write operation.
+        */
+       if (WARN_ON_ONCE(iomap->type != IOMAP_MAPPED &&
+                       !(iomap->flags & IOMAP_F_SHARED)))
                return -EIO;
 
        /*
                        break;
                }
 
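+               /*
+                * Copy-on-write: the old data has to be carried over from
+                * the source extent before the new data is written.  No
+                * copy is needed when the source is a hole or maps to the
+                * same blocks as the destination.
+                */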
+               if (write &&
+                   srcmap->type != IOMAP_HOLE && srcmap->addr != iomap->addr) {
+                       ret = dax_iomap_cow_copy(pos, length, PAGE_SIZE, srcmap,
+                                                kaddr);
+                       if (ret)
+                               break;
+               }
+
                map_len = PFN_PHYS(map_len);
                kaddr += offset;
                map_len -= offset;
                if (recovery)
                        xfer = dax_recovery_write(dax_dev, pgoff, kaddr,
                                        map_len, iter);
-               else if (iov_iter_rw(iter) == WRITE)
+               else if (write)
                        xfer = dax_copy_from_iter(dax_dev, pgoff, kaddr,
                                        map_len, iter);
                else
 {
        struct address_space *mapping = vmf->vma->vm_file->f_mapping;
        const struct iomap *iomap = &iter->iomap;
+       const struct iomap *srcmap = &iter->srcmap;
        size_t size = pmd ? PMD_SIZE : PAGE_SIZE;
        loff_t pos = (loff_t)xas->xa_index << PAGE_SHIFT;
        bool write = vmf->flags & FAULT_FLAG_WRITE;
        unsigned long entry_flags = pmd ? DAX_PMD : 0;
        int err = 0;
        pfn_t pfn;
+       void *kaddr;
 
        if (!pmd && vmf->cow_page)
                return dax_fault_cow_page(vmf, iter);
                return dax_pmd_load_hole(xas, vmf, iomap, entry);
        }
 
-       if (iomap->type != IOMAP_MAPPED) {
+       if (iomap->type != IOMAP_MAPPED && !(iomap->flags & IOMAP_F_SHARED)) {
                WARN_ON_ONCE(1);
                return pmd ? VM_FAULT_FALLBACK : VM_FAULT_SIGBUS;
        }
 
-       err = dax_iomap_direct_access(&iter->iomap, pos, size, NULL, &pfn);
+       err = dax_iomap_direct_access(iomap, pos, size, &kaddr, &pfn);
        if (err)
                return pmd ? VM_FAULT_FALLBACK : dax_fault_return(err);
 
        *entry = dax_insert_entry(xas, mapping, vmf, *entry, pfn, entry_flags,
                                  write && !sync);
 
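+       /*
+        * Copy-on-write fault: copy the whole block from the source extent,
+        * since @pos is block aligned here and @size covers the entire
+        * PTE or PMD sized block.
+        */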
+       if (write &&
+           srcmap->type != IOMAP_HOLE && srcmap->addr != iomap->addr) {
+               err = dax_iomap_cow_copy(pos, size, size, srcmap, kaddr);
+               if (err)
+                       return dax_fault_return(err);
+       }
+
        if (sync)
                return dax_fault_synchronous_pfnp(pfnp, pfn);