]> www.infradead.org Git - users/hch/misc.git/commitdiff
io_uring: move struct io_kiocb from task_struct to io_uring_task
authorJens Axboe <axboe@kernel.dk>
Sun, 3 Nov 2024 17:23:38 +0000 (10:23 -0700)
committerJens Axboe <axboe@kernel.dk>
Wed, 6 Nov 2024 20:55:38 +0000 (13:55 -0700)
Rather than store the task_struct itself in struct io_kiocb, store
the io_uring specific task_struct. The life times are the same in terms
of io_uring, and this avoids doing some dereferences through the
task_struct. For the hot path of putting local task references, we can
deref req->tctx instead, which we'll need anyway in that function
regardless of whether it's local or remote references.

This is mostly straight forward, except the original task PF_EXITING
check needs a bit of tweaking. task_work is _always_ run from the
originating task, except in the fallback case, where it's run from a
kernel thread. Replace the potentially racy (in case of fallback work)
checks for req->task->flags with current->flags. It's either the still
the original task, in which case PF_EXITING will be sane, or it has
PF_KTHREAD set, in which case it's fallback work. Both cases should
prevent moving forward with the given request.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
14 files changed:
include/linux/io_uring/cmd.h
include/linux/io_uring_types.h
io_uring/cancel.c
io_uring/fdinfo.c
io_uring/io_uring.c
io_uring/io_uring.h
io_uring/msg_ring.c
io_uring/notif.c
io_uring/poll.c
io_uring/rw.c
io_uring/tctx.c
io_uring/timeout.c
io_uring/uring_cmd.c
io_uring/waitid.c

index c189d36ad55ea6c28056aacd199d40deb1900865..578a3fdf5c719cf45fd4b6f9c894204d6b4f946c 100644 (file)
@@ -110,7 +110,7 @@ static inline void io_uring_cmd_complete_in_task(struct io_uring_cmd *ioucmd,
 
 static inline struct task_struct *io_uring_cmd_get_task(struct io_uring_cmd *cmd)
 {
-       return cmd_to_io_kiocb(cmd)->task;
+       return cmd_to_io_kiocb(cmd)->tctx->task;
 }
 
 #endif /* _LINUX_IO_URING_CMD_H */
index 01e7fb9fcfe2a22db6c93959270bf07d3574c963..fba2988accc3a9715e763e62be783a9206addc4c 100644 (file)
@@ -84,6 +84,7 @@ struct io_uring_task {
        /* submission side */
        int                             cached_refs;
        const struct io_ring_ctx        *last;
+       struct task_struct              *task;
        struct io_wq                    *io_wq;
        struct file                     *registered_rings[IO_RINGFD_REG_MAX];
 
@@ -625,7 +626,7 @@ struct io_kiocb {
        struct io_cqe                   cqe;
 
        struct io_ring_ctx              *ctx;
-       struct task_struct              *task;
+       struct io_uring_task            *tctx;
 
        union {
                /* stores selected buf, valid IFF REQ_F_BUFFER_SELECTED is set */
index bbca5cb69cb50b33c736356a2c89440d28cb324f..48419356783944a2efb8cf6b8eb8f4b22a6e7414 100644 (file)
@@ -205,7 +205,7 @@ int io_async_cancel(struct io_kiocb *req, unsigned int issue_flags)
                .opcode = cancel->opcode,
                .seq    = atomic_inc_return(&req->ctx->cancel_seq),
        };
-       struct io_uring_task *tctx = req->task->io_uring;
+       struct io_uring_task *tctx = req->tctx;
        int ret;
 
        if (cd.flags & IORING_ASYNC_CANCEL_FD) {
index 8da0d9e4533a190d7919cfdae22b5dfd2a27f65c..efbec34ccb18d7543c8d87e49d233c85bb186ef2 100644 (file)
@@ -203,7 +203,7 @@ __cold void io_uring_show_fdinfo(struct seq_file *m, struct file *file)
 
                hlist_for_each_entry(req, &hb->list, hash_node)
                        seq_printf(m, "  op=%d, task_works=%d\n", req->opcode,
-                                       task_work_pending(req->task));
+                                       task_work_pending(req->tctx->task));
        }
 
        if (has_lock)
index 43afd9da7d0742c3e0bf2d4c30aa871c5986e036..7c1ca36b117b38ed00ceb9943761c708cd48ddc5 100644 (file)
@@ -206,7 +206,7 @@ bool io_match_task_safe(struct io_kiocb *head, struct io_uring_task *tctx,
 {
        bool matched;
 
-       if (tctx && head->task->io_uring != tctx)
+       if (tctx && head->tctx != tctx)
                return false;
        if (cancel_all)
                return true;
@@ -407,11 +407,8 @@ static void io_clean_op(struct io_kiocb *req)
                kfree(req->apoll);
                req->apoll = NULL;
        }
-       if (req->flags & REQ_F_INFLIGHT) {
-               struct io_uring_task *tctx = req->task->io_uring;
-
-               atomic_dec(&tctx->inflight_tracked);
-       }
+       if (req->flags & REQ_F_INFLIGHT)
+               atomic_dec(&req->tctx->inflight_tracked);
        if (req->flags & REQ_F_CREDS)
                put_cred(req->creds);
        if (req->flags & REQ_F_ASYNC_DATA) {
@@ -425,7 +422,7 @@ static inline void io_req_track_inflight(struct io_kiocb *req)
 {
        if (!(req->flags & REQ_F_INFLIGHT)) {
                req->flags |= REQ_F_INFLIGHT;
-               atomic_inc(&req->task->io_uring->inflight_tracked);
+               atomic_inc(&req->tctx->inflight_tracked);
        }
 }
 
@@ -514,7 +511,7 @@ static void io_prep_async_link(struct io_kiocb *req)
 static void io_queue_iowq(struct io_kiocb *req)
 {
        struct io_kiocb *link = io_prep_linked_timeout(req);
-       struct io_uring_task *tctx = req->task->io_uring;
+       struct io_uring_task *tctx = req->tctx;
 
        BUG_ON(!tctx);
        BUG_ON(!tctx->io_wq);
@@ -529,7 +526,7 @@ static void io_queue_iowq(struct io_kiocb *req)
         * procedure rather than attempt to run this request (or create a new
         * worker for it).
         */
-       if (WARN_ON_ONCE(!same_thread_group(req->task, current)))
+       if (WARN_ON_ONCE(!same_thread_group(tctx->task, current)))
                atomic_or(IO_WQ_WORK_CANCEL, &req->work.flags);
 
        trace_io_uring_queue_async_work(req, io_wq_is_hashed(&req->work));
@@ -678,17 +675,17 @@ static void io_cqring_do_overflow_flush(struct io_ring_ctx *ctx)
 }
 
 /* must to be called somewhat shortly after putting a request */
-static inline void io_put_task(struct task_struct *task)
+static inline void io_put_task(struct io_kiocb *req)
 {
-       struct io_uring_task *tctx = task->io_uring;
+       struct io_uring_task *tctx = req->tctx;
 
-       if (likely(task == current)) {
+       if (likely(tctx->task == current)) {
                tctx->cached_refs++;
        } else {
                percpu_counter_sub(&tctx->inflight, 1);
                if (unlikely(atomic_read(&tctx->in_cancel)))
                        wake_up(&tctx->wait);
-               put_task_struct(task);
+               put_task_struct(tctx->task);
        }
 }
 
@@ -1207,7 +1204,7 @@ static inline void io_req_local_work_add(struct io_kiocb *req,
 
 static void io_req_normal_work_add(struct io_kiocb *req)
 {
-       struct io_uring_task *tctx = req->task->io_uring;
+       struct io_uring_task *tctx = req->tctx;
        struct io_ring_ctx *ctx = req->ctx;
 
        /* task_work already pending, we're done */
@@ -1226,7 +1223,7 @@ static void io_req_normal_work_add(struct io_kiocb *req)
                return;
        }
 
-       if (likely(!task_work_add(req->task, &tctx->task_work, ctx->notify_method)))
+       if (likely(!task_work_add(tctx->task, &tctx->task_work, ctx->notify_method)))
                return;
 
        io_fallback_tw(tctx, false);
@@ -1343,8 +1340,7 @@ static void io_req_task_cancel(struct io_kiocb *req, struct io_tw_state *ts)
 void io_req_task_submit(struct io_kiocb *req, struct io_tw_state *ts)
 {
        io_tw_lock(req->ctx, ts);
-       /* req->task == current here, checking PF_EXITING is safe */
-       if (unlikely(req->task->flags & PF_EXITING))
+       if (unlikely(io_should_terminate_tw()))
                io_req_defer_failed(req, -EFAULT);
        else if (req->flags & REQ_F_FORCE_ASYNC)
                io_queue_iowq(req);
@@ -1403,7 +1399,7 @@ static void io_free_batch_list(struct io_ring_ctx *ctx,
                }
                io_put_file(req);
                io_req_put_rsrc_nodes(req);
-               io_put_task(req->task);
+               io_put_task(req);
 
                node = req->comp_list.next;
                io_req_add_to_cache(req, ctx);
@@ -2019,7 +2015,7 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
        req->flags = (__force io_req_flags_t) sqe_flags;
        req->cqe.user_data = READ_ONCE(sqe->user_data);
        req->file = NULL;
-       req->task = current;
+       req->tctx = current->io_uring;
        req->cancel_seq_set = false;
 
        if (unlikely(opcode >= IORING_OP_LAST)) {
index 17ffdb1e41c5bf79b90c48425e450d302088762b..702c8e987430de1728023720caa7af1f68866300 100644 (file)
@@ -426,6 +426,19 @@ static inline bool io_allowed_run_tw(struct io_ring_ctx *ctx)
                      ctx->submitter_task == current);
 }
 
+/*
+ * Terminate the request if either of these conditions are true:
+ *
+ * 1) It's being executed by the original task, but that task is marked
+ *    with PF_EXITING as it's exiting.
+ * 2) PF_KTHREAD is set, in which case the invoker of the task_work is
+ *    our fallback task_work.
+ */
+static inline bool io_should_terminate_tw(void)
+{
+       return current->flags & (PF_KTHREAD | PF_EXITING);
+}
+
 static inline void io_req_queue_tw_complete(struct io_kiocb *req, s32 res)
 {
        io_req_set_res(req, res, 0);
index 99af39e1d0fb4e46833b1e2c2dd11f85d1342ddb..e63af34004b70eb59838b03a89e4fab4221e229d 100644 (file)
@@ -89,8 +89,8 @@ static void io_msg_tw_complete(struct io_kiocb *req, struct io_tw_state *ts)
 static int io_msg_remote_post(struct io_ring_ctx *ctx, struct io_kiocb *req,
                              int res, u32 cflags, u64 user_data)
 {
-       req->task = READ_ONCE(ctx->submitter_task);
-       if (!req->task) {
+       req->tctx = READ_ONCE(ctx->submitter_task->io_uring);
+       if (!req->tctx) {
                kmem_cache_free(req_cachep, req);
                return -EOWNERDEAD;
        }
index 8dfbb0bd8e4d575f14898a5c7c45318c53dcd98f..ee3a33510b3c2ad84b60962377e75ed1e1910e86 100644 (file)
@@ -89,7 +89,7 @@ static int io_link_skb(struct sk_buff *skb, struct ubuf_info *uarg)
 
        /* make sure all noifications can be finished in the same task_work */
        if (unlikely(notif->ctx != prev_notif->ctx ||
-                    notif->task != prev_notif->task))
+                    notif->tctx != prev_notif->tctx))
                return -EEXIST;
 
        nd->head = prev_nd->head;
@@ -115,7 +115,7 @@ struct io_kiocb *io_alloc_notif(struct io_ring_ctx *ctx)
        notif->opcode = IORING_OP_NOP;
        notif->flags = 0;
        notif->file = NULL;
-       notif->task = current;
+       notif->tctx = current->io_uring;
        io_get_task_refs(1);
        notif->file_node = NULL;
        notif->buf_node = NULL;
index 7db3010b5733ed5c57b59eb2a124f9214d0fa2c1..bced9edd52335ac42a3390524d42e486e6784f70 100644 (file)
@@ -224,8 +224,7 @@ static int io_poll_check_events(struct io_kiocb *req, struct io_tw_state *ts)
 {
        int v;
 
-       /* req->task == current here, checking PF_EXITING is safe */
-       if (unlikely(req->task->flags & PF_EXITING))
+       if (unlikely(io_should_terminate_tw()))
                return -ECANCELED;
 
        do {
index 144730344c0f47a149e90454a91168b2efb86744..e368b9afde03864a4ddd1459266254a27659841d 100644 (file)
@@ -435,7 +435,7 @@ static bool io_rw_should_reissue(struct io_kiocb *req)
         * Play it safe and assume not safe to re-import and reissue if we're
         * not in the original thread group (or in task context).
         */
-       if (!same_thread_group(req->task, current) || !in_task())
+       if (!same_thread_group(req->tctx->task, current) || !in_task())
                return false;
        return true;
 }
index c043fe93a3f2327bfc4f90bb1a68f113bc485dc9..503f3ff8bc4f90e891e3e06b2cd0bed536606be8 100644 (file)
@@ -81,6 +81,7 @@ __cold int io_uring_alloc_task_context(struct task_struct *task,
                return ret;
        }
 
+       tctx->task = task;
        xa_init(&tctx->xa);
        init_waitqueue_head(&tctx->wait);
        atomic_set(&tctx->in_cancel, 0);
index 18286cb53a69f72eff176d3795ead1df47e49b67..5b12bd6a804c80d596b1fd46ed15f5b8b480b114 100644 (file)
@@ -300,16 +300,18 @@ static void io_req_task_link_timeout(struct io_kiocb *req, struct io_tw_state *t
 {
        struct io_timeout *timeout = io_kiocb_to_cmd(req, struct io_timeout);
        struct io_kiocb *prev = timeout->prev;
-       int ret = -ENOENT;
+       int ret;
 
        if (prev) {
-               if (!(req->task->flags & PF_EXITING)) {
+               if (!io_should_terminate_tw()) {
                        struct io_cancel_data cd = {
                                .ctx            = req->ctx,
                                .data           = prev->cqe.user_data,
                        };
 
-                       ret = io_try_cancel(req->task->io_uring, &cd, 0);
+                       ret = io_try_cancel(req->tctx, &cd, 0);
+               } else {
+                       ret = -ECANCELED;
                }
                io_req_set_res(req, ret ?: -ETIME, 0);
                io_req_task_complete(req, ts);
@@ -643,7 +645,7 @@ static bool io_match_task(struct io_kiocb *head, struct io_uring_task *tctx,
 {
        struct io_kiocb *req;
 
-       if (tctx && head->task->io_uring != tctx)
+       if (tctx && head->tctx != tctx)
                return false;
        if (cancel_all)
                return true;
index f88fbc9869d0da35dba7c6683dc3bb9e6fe52810..40b8b777ba120357ea1b87bc0a3a10a572ea7016 100644 (file)
@@ -61,7 +61,7 @@ bool io_uring_try_cancel_uring_cmd(struct io_ring_ctx *ctx,
                                struct io_uring_cmd);
                struct file *file = req->file;
 
-               if (!cancel_all && req->task->io_uring != tctx)
+               if (!cancel_all && req->tctx != tctx)
                        continue;
 
                if (cmd->flags & IORING_URING_CMD_CANCELABLE) {
index 9b7c23f96c4714d50d9d8a7ee00a28c1a650af7a..daef5dd644f0497ceb9f2a72d25b70003fb9b846 100644 (file)
@@ -331,7 +331,7 @@ int io_waitid(struct io_kiocb *req, unsigned int issue_flags)
        hlist_add_head(&req->hash_node, &ctx->waitid_list);
 
        init_waitqueue_func_entry(&iwa->wo.child_wait, io_waitid_wait);
-       iwa->wo.child_wait.private = req->task;
+       iwa->wo.child_wait.private = req->tctx->task;
        iw->head = &current->signal->wait_chldexit;
        add_wait_queue(iw->head, &iwa->wo.child_wait);