#include <linux/slab.h>
 #include <linux/workqueue.h>
 #include <linux/blkdev.h>
+#include <linux/bvec.h>
 #include <linux/net.h>
 #include <net/sock.h>
 #include <net/af_unix.h>
 #include <linux/sched/mm.h>
 #include <linux/uaccess.h>
 #include <linux/nospec.h>
+#include <linux/sizes.h>
+#include <linux/hugetlb.h>
 
 #include <uapi/linux/io_uring.h>
 
        struct io_uring_cqe     cqes[];
 };
 
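+/*
+ * One registered buffer: the original user address and length, plus the
+ * pinned pages described as a bio_vec table.
+ */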
+struct io_mapped_ubuf {
+       u64             ubuf;
+       size_t          len;
+       struct          bio_vec *bvec;
+       unsigned int    nr_bvecs;
+};
+
 struct io_ring_ctx {
        struct {
                struct percpu_ref       refs;
                struct fasync_struct    *cq_fasync;
        } ____cacheline_aligned_in_smp;
 
+       /* if used, fixed mapped user buffers */
+       unsigned                nr_user_bufs;
+       struct io_mapped_ubuf   *user_bufs;
+
        struct user_struct      *user;
 
        struct completion       ctx_done;
        }
 }
 
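+/*
+ * Set up the iov_iter over a registered buffer for a READ_FIXED/WRITE_FIXED
+ * request, after checking that the requested range lies inside it.
+ */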
+static int io_import_fixed(struct io_ring_ctx *ctx, int rw,
+                          const struct io_uring_sqe *sqe,
+                          struct iov_iter *iter)
+{
+       size_t len = READ_ONCE(sqe->len);
+       struct io_mapped_ubuf *imu;
+       unsigned index, buf_index;
+       size_t offset;
+       u64 buf_addr;
+
+       /* attempt to use fixed buffers without having provided iovecs */
+       if (unlikely(!ctx->user_bufs))
+               return -EFAULT;
+
+       buf_index = READ_ONCE(sqe->buf_index);
+       if (unlikely(buf_index >= ctx->nr_user_bufs))
+               return -EFAULT;
+
+       index = array_index_nospec(buf_index, ctx->nr_user_bufs);
+       imu = &ctx->user_bufs[index];
+       buf_addr = READ_ONCE(sqe->addr);
+
+       /* overflow */
+       if (buf_addr + len < buf_addr)
+               return -EFAULT;
+       /* not inside the mapped region */
+       if (buf_addr < imu->ubuf || buf_addr + len > imu->ubuf + imu->len)
+               return -EFAULT;
+
+       /*
+        * May not be a start of buffer, set size appropriately
+        * and advance us to the beginning.
+        */
+       offset = buf_addr - imu->ubuf;
+       iov_iter_bvec(iter, rw, imu->bvec, imu->nr_bvecs, offset + len);
+       if (offset)
+               iov_iter_advance(iter, offset);
+       return 0;
+}
+
 static int io_import_iovec(struct io_ring_ctx *ctx, int rw,
                           const struct sqe_submit *s, struct iovec **iovec,
                           struct iov_iter *iter)
        const struct io_uring_sqe *sqe = s->sqe;
        void __user *buf = u64_to_user_ptr(READ_ONCE(sqe->addr));
        size_t sqe_len = READ_ONCE(sqe->len);
+       u8 opcode;
+
+       /*
+        * We're reading ->opcode for the second time, but the first read
+        * doesn't care whether it's _FIXED or not, so it doesn't matter
+        * whether ->opcode changes concurrently. The first read does care
+        * about whether it is a READ or a WRITE, so we don't trust this read
+        * for that purpose and instead let the caller pass in the read/write
+        * flag.
+        */
+       opcode = READ_ONCE(sqe->opcode);
+       if (opcode == IORING_OP_READ_FIXED ||
+           opcode == IORING_OP_WRITE_FIXED) {
+               ssize_t ret = io_import_fixed(ctx, rw, sqe, iter);
+               *iovec = NULL;
+               return ret;
+       }
 
        if (!s->has_user)
                return -EFAULT;
 
        if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
                return -EINVAL;
-       if (unlikely(sqe->addr || sqe->ioprio))
+       if (unlikely(sqe->addr || sqe->ioprio || sqe->buf_index))
                return -EINVAL;
 
        fd = READ_ONCE(sqe->fd);
                ret = io_nop(req, req->user_data);
                break;
        case IORING_OP_READV:
+               if (unlikely(s->sqe->buf_index))
+                       return -EINVAL;
                ret = io_read(req, s, force_nonblock, state);
                break;
        case IORING_OP_WRITEV:
+               if (unlikely(s->sqe->buf_index))
+                       return -EINVAL;
+               ret = io_write(req, s, force_nonblock, state);
+               break;
+       case IORING_OP_READ_FIXED:
+               ret = io_read(req, s, force_nonblock, state);
+               break;
+       case IORING_OP_WRITE_FIXED:
                ret = io_write(req, s, force_nonblock, state);
                break;
        case IORING_OP_FSYNC:
        return 0;
 }
 
+static inline bool io_sqe_needs_user(const struct io_uring_sqe *sqe)
+{
+       u8 opcode = READ_ONCE(sqe->opcode);
+
+       return !(opcode == IORING_OP_READ_FIXED ||
+                opcode == IORING_OP_WRITE_FIXED);
+}
+
 static void io_sq_wq_submit_work(struct work_struct *work)
 {
        struct io_kiocb *req = container_of(work, struct io_kiocb, work);
        struct sqe_submit *s = &req->submit;
        const struct io_uring_sqe *sqe = s->sqe;
        struct io_ring_ctx *ctx = req->ctx;
-       mm_segment_t old_fs = get_fs();
+       mm_segment_t old_fs;
+       bool needs_user;
        int ret;
 
         /* Ensure we clear previously set forced non-block flag */
        req->flags &= ~REQ_F_FORCE_NONBLOCK;
        req->rw.ki_flags &= ~IOCB_NOWAIT;
 
-       if (!mmget_not_zero(ctx->sqo_mm)) {
-               ret = -EFAULT;
-               goto err;
-       }
-
-       use_mm(ctx->sqo_mm);
-       set_fs(USER_DS);
-       s->has_user = true;
        s->needs_lock = true;
+       s->has_user = false;
+
+       /*
+        * If we're doing IO to fixed buffers, we don't need to get/set
+        * user context
+        */
+       needs_user = io_sqe_needs_user(s->sqe);
+       if (needs_user) {
+               if (!mmget_not_zero(ctx->sqo_mm)) {
+                       ret = -EFAULT;
+                       goto err;
+               }
+               use_mm(ctx->sqo_mm);
+               old_fs = get_fs();
+               set_fs(USER_DS);
+               s->has_user = true;
+       }
 
        do {
                ret = __io_submit_sqe(ctx, req, s, false, NULL);
                cond_resched();
        } while (1);
 
-       set_fs(old_fs);
-       unuse_mm(ctx->sqo_mm);
-       mmput(ctx->sqo_mm);
+       if (needs_user) {
+               set_fs(old_fs);
+               unuse_mm(ctx->sqo_mm);
+               mmput(ctx->sqo_mm);
+       }
 err:
        if (ret) {
                io_cqring_add_event(ctx, sqe->user_data, ret, 0);
        return (bytes + PAGE_SIZE - 1) / PAGE_SIZE;
 }
 
+static int io_sqe_buffer_unregister(struct io_ring_ctx *ctx)
+{
+       int i, j;
+
+       if (!ctx->user_bufs)
+               return -ENXIO;
+
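+       /* drop the page pins taken at registration and unaccount the memory */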
+       for (i = 0; i < ctx->nr_user_bufs; i++) {
+               struct io_mapped_ubuf *imu = &ctx->user_bufs[i];
+
+               for (j = 0; j < imu->nr_bvecs; j++)
+                       put_page(imu->bvec[j].bv_page);
+
+               if (ctx->account_mem)
+                       io_unaccount_mem(ctx->user, imu->nr_bvecs);
+               kfree(imu->bvec);
+               imu->nr_bvecs = 0;
+       }
+
+       kfree(ctx->user_bufs);
+       ctx->user_bufs = NULL;
+       ctx->nr_user_bufs = 0;
+       return 0;
+}
+
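+/*
+ * Copy the index'th iovec in from userspace, converting from the compat
+ * layout if the ring was set up by a 32-bit task.
+ */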
+static int io_copy_iov(struct io_ring_ctx *ctx, struct iovec *dst,
+                      void __user *arg, unsigned index)
+{
+       struct iovec __user *src;
+
+#ifdef CONFIG_COMPAT
+       if (ctx->compat) {
+               struct compat_iovec __user *ciovs;
+               struct compat_iovec ciov;
+
+               ciovs = (struct compat_iovec __user *) arg;
+               if (copy_from_user(&ciov, &ciovs[index], sizeof(ciov)))
+                       return -EFAULT;
+
+               dst->iov_base = (void __user *) (unsigned long) ciov.iov_base;
+               dst->iov_len = ciov.iov_len;
+               return 0;
+       }
+#endif
+       src = (struct iovec __user *) arg;
+       if (copy_from_user(dst, &src[index], sizeof(*dst)))
+               return -EFAULT;
+       return 0;
+}
+
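+/*
+ * Pin down each buffer described by the iovec array, account it as locked
+ * memory if needed, and build a bio_vec table for it so fixed IO can use
+ * the pages directly.
+ */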
+static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg,
+                                 unsigned nr_args)
+{
+       struct vm_area_struct **vmas = NULL;
+       struct page **pages = NULL;
+       int i, j, got_pages = 0;
+       int ret = -EINVAL;
+
+       if (ctx->user_bufs)
+               return -EBUSY;
+       if (!nr_args || nr_args > UIO_MAXIOV)
+               return -EINVAL;
+
+       ctx->user_bufs = kcalloc(nr_args, sizeof(struct io_mapped_ubuf),
+                                       GFP_KERNEL);
+       if (!ctx->user_bufs)
+               return -ENOMEM;
+
+       for (i = 0; i < nr_args; i++) {
+               struct io_mapped_ubuf *imu = &ctx->user_bufs[i];
+               unsigned long off, start, end, ubuf;
+               int pret, nr_pages;
+               struct iovec iov;
+               size_t size;
+
+               ret = io_copy_iov(ctx, &iov, arg, i);
+               if (ret)
+                       break;
+
+               /*
+                * Don't impose further limits on the size and buffer
+                * constraints here, we'll -EINVAL later when IO is
+                * submitted if they are wrong.
+                */
+               ret = -EFAULT;
+               if (!iov.iov_base || !iov.iov_len)
+                       goto err;
+
+               /* arbitrary limit, but we need something */
+               if (iov.iov_len > SZ_1G)
+                       goto err;
+
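+               /* pages spanned, including a partial first/last page */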
+               ubuf = (unsigned long) iov.iov_base;
+               end = (ubuf + iov.iov_len + PAGE_SIZE - 1) >> PAGE_SHIFT;
+               start = ubuf >> PAGE_SHIFT;
+               nr_pages = end - start;
+
+               if (ctx->account_mem) {
+                       ret = io_account_mem(ctx->user, nr_pages);
+                       if (ret)
+                               goto err;
+               }
+
+               ret = 0;
+               if (!pages || nr_pages > got_pages) {
+                       kfree(vmas);
+                       kfree(pages);
+                       pages = kmalloc_array(nr_pages, sizeof(struct page *),
+                                               GFP_KERNEL);
+                       vmas = kmalloc_array(nr_pages,
+                                       sizeof(struct vm_area_struct *),
+                                       GFP_KERNEL);
+                       if (!pages || !vmas) {
+                               ret = -ENOMEM;
+                               if (ctx->account_mem)
+                                       io_unaccount_mem(ctx->user, nr_pages);
+                               goto err;
+                       }
+                       got_pages = nr_pages;
+               }
+
+               imu->bvec = kmalloc_array(nr_pages, sizeof(struct bio_vec),
+                                               GFP_KERNEL);
+               ret = -ENOMEM;
+               if (!imu->bvec) {
+                       if (ctx->account_mem)
+                               io_unaccount_mem(ctx->user, nr_pages);
+                       goto err;
+               }
+
+               ret = 0;
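+               /* pin the pages; keep the vmas so file backed memory can be rejected */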
+               down_read(&current->mm->mmap_sem);
+               pret = get_user_pages_longterm(ubuf, nr_pages, FOLL_WRITE,
+                                               pages, vmas);
+               if (pret == nr_pages) {
+                       /* don't support file backed memory */
+                       for (j = 0; j < nr_pages; j++) {
+                               struct vm_area_struct *vma = vmas[j];
+
+                               if (vma->vm_file &&
+                                   !is_file_hugepages(vma->vm_file)) {
+                                       ret = -EOPNOTSUPP;
+                                       break;
+                               }
+                       }
+               } else {
+                       ret = pret < 0 ? pret : -EFAULT;
+               }
+               up_read(&current->mm->mmap_sem);
+               if (ret) {
+                       /*
+                        * if we did partial map, or found file backed vmas,
+                        * release any pages we did get
+                        */
+                       if (pret > 0) {
+                               for (j = 0; j < pret; j++)
+                                       put_page(pages[j]);
+                       }
+                       if (ctx->account_mem)
+                               io_unaccount_mem(ctx->user, nr_pages);
+                       goto err;
+               }
+
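+               /* build the bvec table; only the first entry carries a sub-page offset */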
+               off = ubuf & ~PAGE_MASK;
+               size = iov.iov_len;
+               for (j = 0; j < nr_pages; j++) {
+                       size_t vec_len;
+
+                       vec_len = min_t(size_t, size, PAGE_SIZE - off);
+                       imu->bvec[j].bv_page = pages[j];
+                       imu->bvec[j].bv_len = vec_len;
+                       imu->bvec[j].bv_offset = off;
+                       off = 0;
+                       size -= vec_len;
+               }
+               /* store original address for later verification */
+               imu->ubuf = ubuf;
+               imu->len = iov.iov_len;
+               imu->nr_bvecs = nr_pages;
+
+               ctx->nr_user_bufs++;
+       }
+       kfree(pages);
+       kfree(vmas);
+       return 0;
+err:
+       kfree(pages);
+       kfree(vmas);
+       io_sqe_buffer_unregister(ctx);
+       return ret;
+}
+
 static void io_ring_ctx_free(struct io_ring_ctx *ctx)
 {
        if (ctx->sqo_wq)
                mmdrop(ctx->sqo_mm);
 
        io_iopoll_reap_events(ctx);
+       io_sqe_buffer_unregister(ctx);
 
 #if defined(CONFIG_UNIX)
        if (ctx->ring_sock)
        return io_uring_setup(entries, params);
 }
 
+static int __io_uring_register(struct io_ring_ctx *ctx, unsigned opcode,
+                              void __user *arg, unsigned nr_args)
+{
+       int ret;
+
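+       /* Drop our initial ref and wait for the ctx to be fully idle */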
+       percpu_ref_kill(&ctx->refs);
+       wait_for_completion(&ctx->ctx_done);
+
+       switch (opcode) {
+       case IORING_REGISTER_BUFFERS:
+               ret = io_sqe_buffer_register(ctx, arg, nr_args);
+               break;
+       case IORING_UNREGISTER_BUFFERS:
+               ret = -EINVAL;
+               if (arg || nr_args)
+                       break;
+               ret = io_sqe_buffer_unregister(ctx);
+               break;
+       default:
+               ret = -EINVAL;
+               break;
+       }
+
+       /* bring the ctx back to life */
+       reinit_completion(&ctx->ctx_done);
+       percpu_ref_reinit(&ctx->refs);
+       return ret;
+}
+
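+/*
+ * io_uring_register() entry: resolve the ring fd and run the requested
+ * registration op with the ring's uring_lock held.
+ */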
+SYSCALL_DEFINE4(io_uring_register, unsigned int, fd, unsigned int, opcode,
+               void __user *, arg, unsigned int, nr_args)
+{
+       struct io_ring_ctx *ctx;
+       long ret = -EBADF;
+       struct fd f;
+
+       f = fdget(fd);
+       if (!f.file)
+               return -EBADF;
+
+       ret = -EOPNOTSUPP;
+       if (f.file->f_op != &io_uring_fops)
+               goto out_fput;
+
+       ctx = f.file->private_data;
+
+       mutex_lock(&ctx->uring_lock);
+       ret = __io_uring_register(ctx, opcode, arg, nr_args);
+       mutex_unlock(&ctx->uring_lock);
+out_fput:
+       fdput(f);
+       return ret;
+}
+
 static int __init io_uring_init(void)
 {
        req_cachep = KMEM_CACHE(io_kiocb, SLAB_HWCACHE_ALIGN | SLAB_PANIC);