#include <linux/splice.h>
 #include <linux/task_work.h>
 #include <linux/pagemap.h>
+#include <linux/io_uring.h>
 
 #define CREATE_TRACE_POINTS
 #include <trace/events/io_uring.h>
         */
        struct fixed_file_data  *file_data;
        unsigned                nr_user_files;
-       int                     ring_fd;
-       struct file             *ring_file;
 
        /* if used, fixed mapped user buffers */
        unsigned                nr_user_bufs;
                WRITE_ONCE(cqe->user_data, req->user_data);
                WRITE_ONCE(cqe->res, res);
                WRITE_ONCE(cqe->flags, cflags);
-       } else if (ctx->cq_overflow_flushed) {
+       } else if (ctx->cq_overflow_flushed || req->task->io_uring->in_idle) {
+               /*
+                * If we're in ring overflow flush mode, or in task cancel mode,
+                * then we cannot store the request for later flushing; we have
+                * to drop it on the floor.
+                */
                WRITE_ONCE(ctx->rings->cq_overflow,
                                atomic_inc_return(&ctx->cached_cq_overflow));
        } else {
 
 static void __io_free_req_finish(struct io_kiocb *req)
 {
+       struct io_uring_task *tctx = req->task->io_uring;
        struct io_ring_ctx *ctx = req->ctx;
 
+       atomic_long_inc(&tctx->req_complete);
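+       /* wake a task waiting in __io_uring_task_cancel() to drain requests */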
+       if (tctx->in_idle)
+               wake_up(&tctx->wait);
        put_task_struct(req->task);
 
        if (likely(!io_is_fallback_req(req)))
        if (rb->to_free)
                __io_req_free_batch_flush(ctx, rb);
        if (rb->task) {
+               atomic_long_add(rb->task_refs, &rb->task->io_uring->req_complete);
                put_task_struct_many(rb->task, rb->task_refs);
                rb->task = NULL;
        }
                io_queue_next(req);
 
        if (req->task != rb->task) {
-               if (rb->task)
+               if (rb->task) {
+                       atomic_long_add(rb->task_refs, &rb->task->io_uring->req_complete);
                        put_task_struct_many(rb->task, rb->task_refs);
+               }
                rb->task = req->task;
                rb->task_refs = 0;
        }
                return -EBADF;
 
        req->close.fd = READ_ONCE(sqe->fd);
-       if ((req->file && req->file->f_op == &io_uring_fops) ||
-           req->close.fd == req->ctx->ring_fd)
+       if (req->file && req->file->f_op == &io_uring_fops)
                return -EBADF;
 
        req->close.put_file = NULL;
                wake_up(&ctx->inflight_wait);
        spin_unlock_irqrestore(&ctx->inflight_lock, flags);
        req->flags &= ~REQ_F_INFLIGHT;
+       put_files_struct(req->work.files);
        req->work.files = NULL;
 }
 
 
 static int io_grab_files(struct io_kiocb *req)
 {
-       int ret = -EBADF;
        struct io_ring_ctx *ctx = req->ctx;
 
        io_req_init_async(req);
 
        if (req->work.files || (req->flags & REQ_F_NO_FILE_TABLE))
                return 0;
-       if (!ctx->ring_file)
-               return -EBADF;
 
-       rcu_read_lock();
+       req->work.files = get_files_struct(current);
+       req->flags |= REQ_F_INFLIGHT;
+
        spin_lock_irq(&ctx->inflight_lock);
-       /*
-        * We use the f_ops->flush() handler to ensure that we can flush
-        * out work accessing these files if the fd is closed. Check if
-        * the fd has changed since we started down this path, and disallow
-        * this operation if it has.
-        */
-       if (fcheck(ctx->ring_fd) == ctx->ring_file) {
-               list_add(&req->inflight_entry, &ctx->inflight_list);
-               req->flags |= REQ_F_INFLIGHT;
-               req->work.files = current->files;
-               ret = 0;
-       }
+       list_add(&req->inflight_entry, &ctx->inflight_list);
        spin_unlock_irq(&ctx->inflight_lock);
-       rcu_read_unlock();
-
-       return ret;
+       return 0;
 }
 
 static inline int io_prep_work_files(struct io_kiocb *req)
        refcount_set(&req->refs, 2);
        req->task = current;
        get_task_struct(req->task);
+       atomic_long_inc(&req->task->io_uring->req_issue);
        req->result = 0;
 
        if (unlikely(req->opcode >= IORING_OP_LAST))
        return io_req_set_file(state, req, READ_ONCE(sqe->fd));
 }
 
-static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
-                         struct file *ring_file, int ring_fd)
+static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr)
 {
        struct io_submit_state state;
        struct io_kiocb *link = NULL;
 
        io_submit_state_start(&state, ctx, nr);
 
-       ctx->ring_fd = ring_fd;
-       ctx->ring_file = ring_file;
-
        for (i = 0; i < nr; i++) {
                const struct io_uring_sqe *sqe;
                struct io_kiocb *req;
 
                mutex_lock(&ctx->uring_lock);
                if (likely(!percpu_ref_is_dying(&ctx->refs)))
-                       ret = io_submit_sqes(ctx, to_submit, NULL, -1);
+                       ret = io_submit_sqes(ctx, to_submit);
                mutex_unlock(&ctx->uring_lock);
                timeout = jiffies + ctx->sq_thread_idle;
        }
        return ret;
 }
 
+static int io_uring_alloc_task_context(struct task_struct *task)
+{
+       struct io_uring_task *tctx;
+
+       tctx = kmalloc(sizeof(*tctx), GFP_KERNEL);
+       if (unlikely(!tctx))
+               return -ENOMEM;
+
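+       /*
+        * The xarray tracks which io_uring files this task has submitted to,
+        * and the issue/complete counters let cancelation wait for the task's
+        * inflight requests to finish.
+        */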
+       xa_init(&tctx->xa);
+       init_waitqueue_head(&tctx->wait);
+       tctx->last = NULL;
+       tctx->in_idle = 0;
+       atomic_long_set(&tctx->req_issue, 0);
+       atomic_long_set(&tctx->req_complete, 0);
+       task->io_uring = tctx;
+       return 0;
+}
+
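+/*
+ * Free the per-task io_uring context. By the time this runs, all task file
+ * notes should already have been dropped, hence the WARN_ON_ONCE() on a
+ * non-empty xarray.
+ */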
+void __io_uring_free(struct task_struct *tsk)
+{
+       struct io_uring_task *tctx = tsk->io_uring;
+
+       WARN_ON_ONCE(!xa_empty(&tctx->xa));
+       xa_destroy(&tctx->xa);
+       kfree(tctx);
+       tsk->io_uring = NULL;
+}
+
 static int io_sq_offload_start(struct io_ring_ctx *ctx,
                               struct io_uring_params *p)
 {
                        ctx->sqo_thread = NULL;
                        goto err;
                }
+               ret = io_uring_alloc_task_context(ctx->sqo_thread);
+               if (ret)
+                       goto err;
                wake_up_process(ctx->sqo_thread);
        } else if (p->flags & IORING_SETUP_SQ_AFF) {
                /* Can't have SQ_AFF without SQPOLL */
 {
        struct files_struct *files = data;
 
-       return work->files == files;
+       return !files || work->files == files;
 }
 
 /*
 
                spin_lock_irq(&ctx->inflight_lock);
                list_for_each_entry(req, &ctx->inflight_list, inflight_entry) {
-                       if (req->work.files != files)
+                       if (files && req->work.files != files)
                                continue;
                        /* req is being completed, ignore */
                        if (!refcount_inc_not_zero(&req->refs))
        return io_task_match(req, task);
 }
 
+static bool __io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
+                                           struct task_struct *task,
+                                           struct files_struct *files)
+{
+       bool ret;
+
+       ret = io_uring_cancel_files(ctx, files);
+       if (!files) {
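+               /*
+                * No specific files to match: cancel io-wq work owned by the
+                * task, reap iopoll completions, and kill the task's poll and
+                * timeout requests.
+                */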
+               enum io_wq_cancel cret;
+
+               cret = io_wq_cancel_cb(ctx->io_wq, io_cancel_task_cb, task, true);
+               if (cret != IO_WQ_CANCEL_NOTFOUND)
+                       ret = true;
+
+               /* SQPOLL thread does its own polling */
+               if (!(ctx->flags & IORING_SETUP_SQPOLL)) {
+                       while (!list_empty_careful(&ctx->iopoll_list)) {
+                               io_iopoll_try_reap_events(ctx);
+                               ret = true;
+                       }
+               }
+
+               ret |= io_poll_remove_all(ctx, task);
+               ret |= io_kill_timeouts(ctx, task);
+       }
+
+       return ret;
+}
+
+/*
+ * We need to cancel requests iteratively, in case a request has dependent
+ * hard links. These persist even when cancelation fails, so keep looping
+ * until none are found.
+ */
+static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
+                                         struct files_struct *files)
+{
+       struct task_struct *task = current;
+
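+       /* with SQPOLL, requests are issued by the sq thread, not by us */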
+       if (ctx->flags & IORING_SETUP_SQPOLL)
+               task = ctx->sqo_thread;
+
+       io_cqring_overflow_flush(ctx, true, task, files);
+
+       while (__io_uring_cancel_task_requests(ctx, task, files)) {
+               io_run_task_work();
+               cond_resched();
+       }
+}
+
+/*
+ * Note that this task has used io_uring. We use it for cancelation purposes.
+ */
+static int io_uring_add_task_file(struct file *file)
+{
+       if (unlikely(!current->io_uring)) {
+               int ret;
+
+               ret = io_uring_alloc_task_context(current);
+               if (unlikely(ret))
+                       return ret;
+       }
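+       /*
+        * tctx->last caches the most recently used ring file, so the xarray
+        * only needs to be consulted when the task switches rings.
+        */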
+       if (current->io_uring->last != file) {
+               XA_STATE(xas, &current->io_uring->xa, (unsigned long) file);
+               void *old;
+
+               rcu_read_lock();
+               old = xas_load(&xas);
+               if (old != file) {
+                       get_file(file);
+                       xas_lock(&xas);
+                       xas_store(&xas, file);
+                       xas_unlock(&xas);
+               }
+               rcu_read_unlock();
+               current->io_uring->last = file;
+       }
+
+       return 0;
+}
+
+/*
+ * Remove this io_uring_file -> task mapping.
+ */
+static void io_uring_del_task_file(struct file *file)
+{
+       struct io_uring_task *tctx = current->io_uring;
+       XA_STATE(xas, &tctx->xa, (unsigned long) file);
+
+       if (tctx->last == file)
+               tctx->last = NULL;
+
+       xas_lock(&xas);
+       file = xas_store(&xas, NULL);
+       xas_unlock(&xas);
+
+       if (file)
+               fput(file);
+}
+
+static void __io_uring_attempt_task_drop(struct file *file)
+{
+       XA_STATE(xas, &current->io_uring->xa, (unsigned long) file);
+       struct file *old;
+
+       rcu_read_lock();
+       old = xas_load(&xas);
+       rcu_read_unlock();
+
+       if (old == file)
+               io_uring_del_task_file(file);
+}
+
+/*
+ * Drop the task's note for this file if we're the only ones that would still
+ * hold it once the pending fput() completes.
+ */
+static void io_uring_attempt_task_drop(struct file *file, bool exiting)
+{
+       if (!current->io_uring)
+               return;
+       /*
+        * An fput() is pending, so f_count will be 2 if the only other
+        * reference is our potential task file note. If the task is exiting,
+        * drop the note regardless of the count.
+        */
+       if (!exiting && atomic_long_read(&file->f_count) != 2)
+               return;
+
+       __io_uring_attempt_task_drop(file);
+}
+
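+/*
+ * Cancel requests on every ring this task has a file note for. With a
+ * non-NULL @files, only requests using those files are canceled and each
+ * ring's task file note is dropped; a NULL @files cancels all of the task's
+ * requests.
+ */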
+void __io_uring_files_cancel(struct files_struct *files)
+{
+       struct io_uring_task *tctx = current->io_uring;
+       XA_STATE(xas, &tctx->xa, 0);
+
+       /* make sure overflow events are dropped */
+       tctx->in_idle = true;
+
+       do {
+               struct io_ring_ctx *ctx;
+               struct file *file;
+
+               xas_lock(&xas);
+               file = xas_next_entry(&xas, ULONG_MAX);
+               xas_unlock(&xas);
+
+               if (!file)
+                       break;
+
+               ctx = file->private_data;
+
+               io_uring_cancel_task_requests(ctx, files);
+               if (files)
+                       io_uring_del_task_file(file);
+       } while (1);
+}
+
+static inline bool io_uring_task_idle(struct io_uring_task *tctx)
+{
+       return atomic_long_read(&tctx->req_issue) ==
+               atomic_long_read(&tctx->req_complete);
+}
+
+/*
+ * Find any io_uring fd that this task has registered or done IO on, and cancel
+ * requests.
+ */
+void __io_uring_task_cancel(void)
+{
+       struct io_uring_task *tctx = current->io_uring;
+       DEFINE_WAIT(wait);
+       long completions;
+
+       /* make sure overflow events are dropped */
+       tctx->in_idle = true;
+
+       while (!io_uring_task_idle(tctx)) {
+               /* read completions before cancelations */
+               completions = atomic_long_read(&tctx->req_complete);
+               __io_uring_files_cancel(NULL);
+
+               prepare_to_wait(&tctx->wait, &wait, TASK_UNINTERRUPTIBLE);
+
+               /*
+                * If we've seen completions, retry. This avoids a race where
+                * a completion comes in before we did prepare_to_wait().
+                */
+               if (completions != atomic_long_read(&tctx->req_complete))
+                       continue;
+               if (io_uring_task_idle(tctx))
+                       break;
+               schedule();
+       }
+
+       finish_wait(&tctx->wait, &wait);
+       tctx->in_idle = false;
+}
+
 static int io_uring_flush(struct file *file, void *data)
 {
        struct io_ring_ctx *ctx = file->private_data;
 
-       io_uring_cancel_files(ctx, data);
-
        /*
         * If the task is going away, cancel work it may have pending
         */
        if (fatal_signal_pending(current) || (current->flags & PF_EXITING))
-               io_wq_cancel_cb(ctx->io_wq, io_cancel_task_cb, current, true);
+               data = NULL;
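+       /* a NULL files pointer cancels all of the task's pending requests */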
 
+       io_uring_cancel_task_requests(ctx, data);
+       io_uring_attempt_task_drop(file, !data);
        return 0;
 }
 
                        wake_up(&ctx->sqo_wait);
                submitted = to_submit;
        } else if (to_submit) {
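+               /* note the ring file in our task context before submitting */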
+               ret = io_uring_add_task_file(f.file);
+               if (unlikely(ret))
+                       goto out;
                mutex_lock(&ctx->uring_lock);
-               submitted = io_submit_sqes(ctx, to_submit, f.file, fd);
+               submitted = io_submit_sqes(ctx, to_submit);
                mutex_unlock(&ctx->uring_lock);
 
                if (submitted != to_submit)
        file = anon_inode_getfile("[io_uring]", &io_uring_fops, ctx,
                                        O_RDWR | O_CLOEXEC);
        if (IS_ERR(file)) {
+err_fd:
                put_unused_fd(ret);
                ret = PTR_ERR(file);
                goto err;
 #if defined(CONFIG_UNIX)
        ctx->ring_sock->file = file;
 #endif
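+       /* record the new ring file in the creating task's io_uring context */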
+       if (unlikely(io_uring_add_task_file(file))) {
+               file = ERR_PTR(-ENOMEM);
+               goto err_fd;
+       }
        fd_install(ret, file);
        return ret;
 err: