static inline struct task_struct *io_uring_cmd_get_task(struct io_uring_cmd *cmd)
 {
-       return cmd_to_io_kiocb(cmd)->task;
+       return cmd_to_io_kiocb(cmd)->tctx->task;
 }
 
 #endif /* _LINUX_IO_URING_CMD_H */
 
        /* submission side */
        int                             cached_refs;
        const struct io_ring_ctx        *last;
+       struct task_struct              *task;
        struct io_wq                    *io_wq;
        struct file                     *registered_rings[IO_RINGFD_REG_MAX];
 
        struct io_cqe                   cqe;
 
        struct io_ring_ctx              *ctx;
-       struct task_struct              *task;
+       struct io_uring_task            *tctx;
 
        union {
                /* stores selected buf, valid IFF REQ_F_BUFFER_SELECTED is set */
 
                .opcode = cancel->opcode,
                .seq    = atomic_inc_return(&req->ctx->cancel_seq),
        };
-       struct io_uring_task *tctx = req->task->io_uring;
+       struct io_uring_task *tctx = req->tctx;
        int ret;
 
        if (cd.flags & IORING_ASYNC_CANCEL_FD) {
 
 
                hlist_for_each_entry(req, &hb->list, hash_node)
                        seq_printf(m, "  op=%d, task_works=%d\n", req->opcode,
-                                       task_work_pending(req->task));
+                                       task_work_pending(req->tctx->task));
        }
 
        if (has_lock)
 
 {
        bool matched;
 
-       if (tctx && head->task->io_uring != tctx)
+       if (tctx && head->tctx != tctx)
                return false;
        if (cancel_all)
                return true;
                kfree(req->apoll);
                req->apoll = NULL;
        }
-       if (req->flags & REQ_F_INFLIGHT) {
-               struct io_uring_task *tctx = req->task->io_uring;
-
-               atomic_dec(&tctx->inflight_tracked);
-       }
+       if (req->flags & REQ_F_INFLIGHT)
+               atomic_dec(&req->tctx->inflight_tracked);
        if (req->flags & REQ_F_CREDS)
                put_cred(req->creds);
        if (req->flags & REQ_F_ASYNC_DATA) {
 {
        if (!(req->flags & REQ_F_INFLIGHT)) {
                req->flags |= REQ_F_INFLIGHT;
-               atomic_inc(&req->task->io_uring->inflight_tracked);
+               atomic_inc(&req->tctx->inflight_tracked);
        }
 }
 
 static void io_queue_iowq(struct io_kiocb *req)
 {
        struct io_kiocb *link = io_prep_linked_timeout(req);
-       struct io_uring_task *tctx = req->task->io_uring;
+       struct io_uring_task *tctx = req->tctx;
 
        BUG_ON(!tctx);
        BUG_ON(!tctx->io_wq);
         * procedure rather than attempt to run this request (or create a new
         * worker for it).
         */
-       if (WARN_ON_ONCE(!same_thread_group(req->task, current)))
+       if (WARN_ON_ONCE(!same_thread_group(tctx->task, current)))
                atomic_or(IO_WQ_WORK_CANCEL, &req->work.flags);
 
        trace_io_uring_queue_async_work(req, io_wq_is_hashed(&req->work));
 }
 
 /* must to be called somewhat shortly after putting a request */
-static inline void io_put_task(struct task_struct *task)
+static inline void io_put_task(struct io_kiocb *req)
 {
-       struct io_uring_task *tctx = task->io_uring;
+       struct io_uring_task *tctx = req->tctx;
 
-       if (likely(task == current)) {
+       if (likely(tctx->task == current)) {
                tctx->cached_refs++;
        } else {
                percpu_counter_sub(&tctx->inflight, 1);
                if (unlikely(atomic_read(&tctx->in_cancel)))
                        wake_up(&tctx->wait);
-               put_task_struct(task);
+               put_task_struct(tctx->task);
        }
 }
 
 
 static void io_req_normal_work_add(struct io_kiocb *req)
 {
-       struct io_uring_task *tctx = req->task->io_uring;
+       struct io_uring_task *tctx = req->tctx;
        struct io_ring_ctx *ctx = req->ctx;
 
        /* task_work already pending, we're done */
                return;
        }
 
-       if (likely(!task_work_add(req->task, &tctx->task_work, ctx->notify_method)))
+       if (likely(!task_work_add(tctx->task, &tctx->task_work, ctx->notify_method)))
                return;
 
        io_fallback_tw(tctx, false);
 void io_req_task_submit(struct io_kiocb *req, struct io_tw_state *ts)
 {
        io_tw_lock(req->ctx, ts);
-       /* req->task == current here, checking PF_EXITING is safe */
-       if (unlikely(req->task->flags & PF_EXITING))
+       if (unlikely(io_should_terminate_tw()))
                io_req_defer_failed(req, -EFAULT);
        else if (req->flags & REQ_F_FORCE_ASYNC)
                io_queue_iowq(req);
                }
                io_put_file(req);
                io_req_put_rsrc_nodes(req);
-               io_put_task(req->task);
+               io_put_task(req);
 
                node = req->comp_list.next;
                io_req_add_to_cache(req, ctx);
        req->flags = (__force io_req_flags_t) sqe_flags;
        req->cqe.user_data = READ_ONCE(sqe->user_data);
        req->file = NULL;
-       req->task = current;
+       req->tctx = current->io_uring;
        req->cancel_seq_set = false;
 
        if (unlikely(opcode >= IORING_OP_LAST)) {
 
                      ctx->submitter_task == current);
 }
 
+/*
+ * Terminate the request if either of these conditions are true:
+ *
+ * 1) It's being executed by the original task, but that task is marked
+ *    with PF_EXITING as it's exiting.
+ * 2) PF_KTHREAD is set, in which case the invoker of the task_work is
+ *    our fallback task_work.
+ */
+static inline bool io_should_terminate_tw(void)
+{
+       return current->flags & (PF_KTHREAD | PF_EXITING);
+}
+
 static inline void io_req_queue_tw_complete(struct io_kiocb *req, s32 res)
 {
        io_req_set_res(req, res, 0);
 
 static int io_msg_remote_post(struct io_ring_ctx *ctx, struct io_kiocb *req,
                              int res, u32 cflags, u64 user_data)
 {
-       req->task = READ_ONCE(ctx->submitter_task);
-       if (!req->task) {
+       req->tctx = READ_ONCE(ctx->submitter_task->io_uring);
+       if (!req->tctx) {
                kmem_cache_free(req_cachep, req);
                return -EOWNERDEAD;
        }
 
 
        /* make sure all noifications can be finished in the same task_work */
        if (unlikely(notif->ctx != prev_notif->ctx ||
-                    notif->task != prev_notif->task))
+                    notif->tctx != prev_notif->tctx))
                return -EEXIST;
 
        nd->head = prev_nd->head;
        notif->opcode = IORING_OP_NOP;
        notif->flags = 0;
        notif->file = NULL;
-       notif->task = current;
+       notif->tctx = current->io_uring;
        io_get_task_refs(1);
        notif->file_node = NULL;
        notif->buf_node = NULL;
 
 {
        int v;
 
-       /* req->task == current here, checking PF_EXITING is safe */
-       if (unlikely(req->task->flags & PF_EXITING))
+       if (unlikely(io_should_terminate_tw()))
                return -ECANCELED;
 
        do {
 
         * Play it safe and assume not safe to re-import and reissue if we're
         * not in the original thread group (or in task context).
         */
-       if (!same_thread_group(req->task, current) || !in_task())
+       if (!same_thread_group(req->tctx->task, current) || !in_task())
                return false;
        return true;
 }
 
                return ret;
        }
 
+       tctx->task = task;
        xa_init(&tctx->xa);
        init_waitqueue_head(&tctx->wait);
        atomic_set(&tctx->in_cancel, 0);
 
 {
        struct io_timeout *timeout = io_kiocb_to_cmd(req, struct io_timeout);
        struct io_kiocb *prev = timeout->prev;
-       int ret = -ENOENT;
+       int ret;
 
        if (prev) {
-               if (!(req->task->flags & PF_EXITING)) {
+               if (!io_should_terminate_tw()) {
                        struct io_cancel_data cd = {
                                .ctx            = req->ctx,
                                .data           = prev->cqe.user_data,
                        };
 
-                       ret = io_try_cancel(req->task->io_uring, &cd, 0);
+                       ret = io_try_cancel(req->tctx, &cd, 0);
+               } else {
+                       ret = -ECANCELED;
                }
                io_req_set_res(req, ret ?: -ETIME, 0);
                io_req_task_complete(req, ts);
 {
        struct io_kiocb *req;
 
-       if (tctx && head->task->io_uring != tctx)
+       if (tctx && head->tctx != tctx)
                return false;
        if (cancel_all)
                return true;
 
                                struct io_uring_cmd);
                struct file *file = req->file;
 
-               if (!cancel_all && req->task->io_uring != tctx)
+               if (!cancel_all && req->tctx != tctx)
                        continue;
 
                if (cmd->flags & IORING_URING_CMD_CANCELABLE) {
 
        hlist_add_head(&req->hash_node, &ctx->waitid_list);
 
        init_waitqueue_func_entry(&iwa->wo.child_wait, io_waitid_wait);
-       iwa->wo.child_wait.private = req->task;
+       iwa->wo.child_wait.private = req->tctx->task;
        iw->head = ¤t->signal->wait_chldexit;
        add_wait_queue(iw->head, &iwa->wo.child_wait);