return 0;
 }
 
+static ssize_t
+blkdev_direct_write(struct kiocb *iocb, struct iov_iter *from)
+{
+       size_t count = iov_iter_count(from);
+       ssize_t written;
+
+       written = kiocb_invalidate_pages(iocb, count);
+       if (written) {
+               if (written == -EBUSY)
+                       return 0;
+               return written;
+       }
+
+       written = blkdev_direct_IO(iocb, from);
+       if (written > 0) {
+               kiocb_invalidate_post_direct_write(iocb, count);
+               iocb->ki_pos += written;
+               count -= written;
+       }
+       if (written != -EIOCBQUEUED)
+               iov_iter_revert(from, count - iov_iter_count(from));
+       return written;
+}
+
 /*
  * Write data to the block device.  Only intended for the block device itself
  * and the raw driver which basically is a fake block device.
  */
 static ssize_t blkdev_write_iter(struct kiocb *iocb, struct iov_iter *from)
 {
-       struct block_device *bdev = I_BDEV(iocb->ki_filp->f_mapping->host);
+       struct file *file = iocb->ki_filp;
+       struct block_device *bdev = I_BDEV(file->f_mapping->host);
        struct inode *bd_inode = bdev->bd_inode;
        loff_t size = bdev_nr_bytes(bdev);
        size_t shorted = 0;
                iov_iter_truncate(from, size);
        }
 
-       ret = __generic_file_write_iter(iocb, from);
+       ret = file_remove_privs(file);
+       if (ret)
+               return ret;
+
+       ret = file_update_time(file);
+       if (ret)
+               return ret;
+
+       if (iocb->ki_flags & IOCB_DIRECT) {
+               ret = blkdev_direct_write(iocb, from);
+               if (ret >= 0 && iov_iter_count(from))
+                       ret = direct_write_fallback(iocb, from, ret,
+                                       generic_perform_write(iocb, from));
+       } else {
+               ret = generic_perform_write(iocb, from);
+       }
+
        if (ret > 0)
                ret = generic_write_sync(iocb, ret);
        iov_iter_reexpand(from, iov_iter_count(from) + shorted);