]> www.infradead.org Git - users/hch/misc.git/commitdiff
io_uring: move cancelations to be io_uring_task based
authorJens Axboe <axboe@kernel.dk>
Sun, 3 Nov 2024 17:22:43 +0000 (10:22 -0700)
committerJens Axboe <axboe@kernel.dk>
Wed, 6 Nov 2024 20:55:38 +0000 (13:55 -0700)
Right now the task_struct pointer is used as the key to match a task,
but in preparation for some io_kiocb changes, move it to using struct
io_uring_task instead. No functional changes intended in this patch.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
12 files changed:
io_uring/futex.c
io_uring/futex.h
io_uring/io_uring.c
io_uring/io_uring.h
io_uring/poll.c
io_uring/poll.h
io_uring/timeout.c
io_uring/timeout.h
io_uring/uring_cmd.c
io_uring/uring_cmd.h
io_uring/waitid.c
io_uring/waitid.h

index 914848f46beb21345f0913c1399a896bb5884dca..e29662f039e1a13ad22e1e581ccc54aa7a390622 100644 (file)
@@ -141,7 +141,7 @@ int io_futex_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
        return -ENOENT;
 }
 
-bool io_futex_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
+bool io_futex_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
                         bool cancel_all)
 {
        struct hlist_node *tmp;
@@ -151,7 +151,7 @@ bool io_futex_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
        lockdep_assert_held(&ctx->uring_lock);
 
        hlist_for_each_entry_safe(req, tmp, &ctx->futex_list, hash_node) {
-               if (!io_match_task_safe(req, task, cancel_all))
+               if (!io_match_task_safe(req, tctx, cancel_all))
                        continue;
                hlist_del_init(&req->hash_node);
                __io_futex_cancel(ctx, req);
index b8bb09873d5795c8ddba677f7b81823e22996a4a..d789fcf715e3869418685727d3d67f8a98148956 100644 (file)
@@ -11,7 +11,7 @@ int io_futex_wake(struct io_kiocb *req, unsigned int issue_flags);
 #if defined(CONFIG_FUTEX)
 int io_futex_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
                    unsigned int issue_flags);
-bool io_futex_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
+bool io_futex_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
                         bool cancel_all);
 bool io_futex_cache_init(struct io_ring_ctx *ctx);
 void io_futex_cache_free(struct io_ring_ctx *ctx);
@@ -23,7 +23,7 @@ static inline int io_futex_cancel(struct io_ring_ctx *ctx,
        return 0;
 }
 static inline bool io_futex_remove_all(struct io_ring_ctx *ctx,
-                                      struct task_struct *task, bool cancel_all)
+                                      struct io_uring_task *tctx, bool cancel_all)
 {
        return false;
 }
index 5bab8a3b04566a79dad94984345abefe4e412e49..4a2282c8546420c219ef86252232086539fdc48a 100644 (file)
@@ -142,7 +142,7 @@ struct io_defer_entry {
 #define IO_CQ_WAKE_FORCE       (IO_CQ_WAKE_INIT >> 1)
 
 static bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
-                                        struct task_struct *task,
+                                        struct io_uring_task *tctx,
                                         bool cancel_all);
 
 static void io_queue_sqe(struct io_kiocb *req);
@@ -201,12 +201,12 @@ static bool io_match_linked(struct io_kiocb *head)
  * As io_match_task() but protected against racing with linked timeouts.
  * User must not hold timeout_lock.
  */
-bool io_match_task_safe(struct io_kiocb *head, struct task_struct *task,
+bool io_match_task_safe(struct io_kiocb *head, struct io_uring_task *tctx,
                        bool cancel_all)
 {
        bool matched;
 
-       if (task && head->task != task)
+       if (tctx && head->task->io_uring != tctx)
                return false;
        if (cancel_all)
                return true;
@@ -2987,7 +2987,7 @@ static int io_uring_release(struct inode *inode, struct file *file)
 }
 
 struct io_task_cancel {
-       struct task_struct *task;
+       struct io_uring_task *tctx;
        bool all;
 };
 
@@ -2996,11 +2996,11 @@ static bool io_cancel_task_cb(struct io_wq_work *work, void *data)
        struct io_kiocb *req = container_of(work, struct io_kiocb, work);
        struct io_task_cancel *cancel = data;
 
-       return io_match_task_safe(req, cancel->task, cancel->all);
+       return io_match_task_safe(req, cancel->tctx, cancel->all);
 }
 
 static __cold bool io_cancel_defer_files(struct io_ring_ctx *ctx,
-                                        struct task_struct *task,
+                                        struct io_uring_task *tctx,
                                         bool cancel_all)
 {
        struct io_defer_entry *de;
@@ -3008,7 +3008,7 @@ static __cold bool io_cancel_defer_files(struct io_ring_ctx *ctx,
 
        spin_lock(&ctx->completion_lock);
        list_for_each_entry_reverse(de, &ctx->defer_list, list) {
-               if (io_match_task_safe(de->req, task, cancel_all)) {
+               if (io_match_task_safe(de->req, tctx, cancel_all)) {
                        list_cut_position(&list, &ctx->defer_list, &de->list);
                        break;
                }
@@ -3051,11 +3051,10 @@ static __cold bool io_uring_try_cancel_iowq(struct io_ring_ctx *ctx)
 }
 
 static __cold bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
-                                               struct task_struct *task,
+                                               struct io_uring_task *tctx,
                                                bool cancel_all)
 {
-       struct io_task_cancel cancel = { .task = task, .all = cancel_all, };
-       struct io_uring_task *tctx = task ? task->io_uring : NULL;
+       struct io_task_cancel cancel = { .tctx = tctx, .all = cancel_all, };
        enum io_wq_cancel cret;
        bool ret = false;
 
@@ -3069,9 +3068,9 @@ static __cold bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
        if (!ctx->rings)
                return false;
 
-       if (!task) {
+       if (!tctx) {
                ret |= io_uring_try_cancel_iowq(ctx);
-       } else if (tctx && tctx->io_wq) {
+       } else if (tctx->io_wq) {
                /*
                 * Cancels requests of all rings, not only @ctx, but
                 * it's fine as the task is in exit/exec.
@@ -3094,15 +3093,15 @@ static __cold bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
        if ((ctx->flags & IORING_SETUP_DEFER_TASKRUN) &&
            io_allowed_defer_tw_run(ctx))
                ret |= io_run_local_work(ctx, INT_MAX) > 0;
-       ret |= io_cancel_defer_files(ctx, task, cancel_all);
+       ret |= io_cancel_defer_files(ctx, tctx, cancel_all);
        mutex_lock(&ctx->uring_lock);
-       ret |= io_poll_remove_all(ctx, task, cancel_all);
-       ret |= io_waitid_remove_all(ctx, task, cancel_all);
-       ret |= io_futex_remove_all(ctx, task, cancel_all);
-       ret |= io_uring_try_cancel_uring_cmd(ctx, task, cancel_all);
+       ret |= io_poll_remove_all(ctx, tctx, cancel_all);
+       ret |= io_waitid_remove_all(ctx, tctx, cancel_all);
+       ret |= io_futex_remove_all(ctx, tctx, cancel_all);
+       ret |= io_uring_try_cancel_uring_cmd(ctx, tctx, cancel_all);
        mutex_unlock(&ctx->uring_lock);
-       ret |= io_kill_timeouts(ctx, task, cancel_all);
-       if (task)
+       ret |= io_kill_timeouts(ctx, tctx, cancel_all);
+       if (tctx)
                ret |= io_run_task_work() > 0;
        else
                ret |= flush_delayed_work(&ctx->fallback_work);
@@ -3155,12 +3154,13 @@ __cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd)
                                if (node->ctx->sq_data)
                                        continue;
                                loop |= io_uring_try_cancel_requests(node->ctx,
-                                                       current, cancel_all);
+                                                       current->io_uring,
+                                                       cancel_all);
                        }
                } else {
                        list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
                                loop |= io_uring_try_cancel_requests(ctx,
-                                                                    current,
+                                                                    current->io_uring,
                                                                     cancel_all);
                }
 
index e3e6cb14de5dc7301b3cbe5e393780201f28c64f..17ffdb1e41c5bf79b90c48425e450d302088762b 100644 (file)
@@ -115,7 +115,7 @@ void io_queue_next(struct io_kiocb *req);
 void io_task_refs_refill(struct io_uring_task *tctx);
 bool __io_alloc_req_refill(struct io_ring_ctx *ctx);
 
-bool io_match_task_safe(struct io_kiocb *head, struct task_struct *task,
+bool io_match_task_safe(struct io_kiocb *head, struct io_uring_task *tctx,
                        bool cancel_all);
 
 void io_activate_pollwq(struct io_ring_ctx *ctx);
index 2d6698fb740008e443a303c55f2584a345b59eaf..7db3010b5733ed5c57b59eb2a124f9214d0fa2c1 100644 (file)
@@ -714,7 +714,7 @@ int io_arm_poll_handler(struct io_kiocb *req, unsigned issue_flags)
 /*
  * Returns true if we found and killed one or more poll requests
  */
-__cold bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
+__cold bool io_poll_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
                               bool cancel_all)
 {
        unsigned nr_buckets = 1U << ctx->cancel_table.hash_bits;
@@ -729,7 +729,7 @@ __cold bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
                struct io_hash_bucket *hb = &ctx->cancel_table.hbs[i];
 
                hlist_for_each_entry_safe(req, tmp, &hb->list, hash_node) {
-                       if (io_match_task_safe(req, tsk, cancel_all)) {
+                       if (io_match_task_safe(req, tctx, cancel_all)) {
                                hlist_del_init(&req->hash_node);
                                io_poll_cancel_req(req);
                                found = true;
index b0e3745f5a29b9ab0f8e1901af9abe3906c4e956..04ede93113dc72b888a7baec1966bf64dc6a8148 100644 (file)
@@ -40,7 +40,7 @@ struct io_cancel_data;
 int io_poll_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
                   unsigned issue_flags);
 int io_arm_poll_handler(struct io_kiocb *req, unsigned issue_flags);
-bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
+bool io_poll_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
                        bool cancel_all);
 
 void io_poll_task_func(struct io_kiocb *req, struct io_tw_state *ts);
index 9973876d91b0ef32010691e60b249988a76bdbe9..18286cb53a69f72eff176d3795ead1df47e49b67 100644 (file)
@@ -637,13 +637,13 @@ void io_queue_linked_timeout(struct io_kiocb *req)
        io_put_req(req);
 }
 
-static bool io_match_task(struct io_kiocb *head, struct task_struct *task,
+static bool io_match_task(struct io_kiocb *head, struct io_uring_task *tctx,
                          bool cancel_all)
        __must_hold(&head->ctx->timeout_lock)
 {
        struct io_kiocb *req;
 
-       if (task && head->task != task)
+       if (tctx && head->task->io_uring != tctx)
                return false;
        if (cancel_all)
                return true;
@@ -656,7 +656,7 @@ static bool io_match_task(struct io_kiocb *head, struct task_struct *task,
 }
 
 /* Returns true if we found and killed one or more timeouts */
-__cold bool io_kill_timeouts(struct io_ring_ctx *ctx, struct task_struct *tsk,
+__cold bool io_kill_timeouts(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
                             bool cancel_all)
 {
        struct io_timeout *timeout, *tmp;
@@ -671,7 +671,7 @@ __cold bool io_kill_timeouts(struct io_ring_ctx *ctx, struct task_struct *tsk,
        list_for_each_entry_safe(timeout, tmp, &ctx->timeout_list, list) {
                struct io_kiocb *req = cmd_to_io_kiocb(timeout);
 
-               if (io_match_task(req, tsk, cancel_all) &&
+               if (io_match_task(req, tctx, cancel_all) &&
                    io_kill_timeout(req, -ECANCELED))
                        canceled++;
        }
index a6939f18313e86e66574232c3b2d041a161add60..e91b32448dcf9077d17baaaa72dd7528cb4378fd 100644 (file)
@@ -24,7 +24,7 @@ static inline struct io_kiocb *io_disarm_linked_timeout(struct io_kiocb *req)
 __cold void io_flush_timeouts(struct io_ring_ctx *ctx);
 struct io_cancel_data;
 int io_timeout_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd);
-__cold bool io_kill_timeouts(struct io_ring_ctx *ctx, struct task_struct *tsk,
+__cold bool io_kill_timeouts(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
                             bool cancel_all);
 void io_queue_linked_timeout(struct io_kiocb *req);
 void io_disarm_next(struct io_kiocb *req);
index 88a73d21fc0bf09f752a8ef742df42acac9fb395..f88fbc9869d0da35dba7c6683dc3bb9e6fe52810 100644 (file)
@@ -47,7 +47,7 @@ static void io_req_uring_cleanup(struct io_kiocb *req, unsigned int issue_flags)
 }
 
 bool io_uring_try_cancel_uring_cmd(struct io_ring_ctx *ctx,
-                                  struct task_struct *task, bool cancel_all)
+                                  struct io_uring_task *tctx, bool cancel_all)
 {
        struct hlist_node *tmp;
        struct io_kiocb *req;
@@ -61,7 +61,7 @@ bool io_uring_try_cancel_uring_cmd(struct io_ring_ctx *ctx,
                                struct io_uring_cmd);
                struct file *file = req->file;
 
-               if (!cancel_all && req->task != task)
+               if (!cancel_all && req->task->io_uring != tctx)
                        continue;
 
                if (cmd->flags & IORING_URING_CMD_CANCELABLE) {
index a361f98664d2de6b6b7a260261fedba884f155a7..7dba0f1efc58268d3736d00baa6ed7982b394700 100644 (file)
@@ -8,4 +8,4 @@ int io_uring_cmd(struct io_kiocb *req, unsigned int issue_flags);
 int io_uring_cmd_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe);
 
 bool io_uring_try_cancel_uring_cmd(struct io_ring_ctx *ctx,
-                                  struct task_struct *task, bool cancel_all);
+                                  struct io_uring_task *tctx, bool cancel_all);
index 6362ec20abc0cf022a12ac7a83a8dd8bb073804f..9b7c23f96c4714d50d9d8a7ee00a28c1a650af7a 100644 (file)
@@ -184,7 +184,7 @@ int io_waitid_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
        return -ENOENT;
 }
 
-bool io_waitid_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
+bool io_waitid_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
                          bool cancel_all)
 {
        struct hlist_node *tmp;
@@ -194,7 +194,7 @@ bool io_waitid_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
        lockdep_assert_held(&ctx->uring_lock);
 
        hlist_for_each_entry_safe(req, tmp, &ctx->waitid_list, hash_node) {
-               if (!io_match_task_safe(req, task, cancel_all))
+               if (!io_match_task_safe(req, tctx, cancel_all))
                        continue;
                hlist_del_init(&req->hash_node);
                __io_waitid_cancel(ctx, req);
index 956a8adafe8c033d9c0373bef17ab321cc72f457..d5544aaf302ad6a4f07f2e9704d4ad85d499bc13 100644 (file)
@@ -11,5 +11,5 @@ int io_waitid_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe);
 int io_waitid(struct io_kiocb *req, unsigned int issue_flags);
 int io_waitid_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
                     unsigned int issue_flags);
-bool io_waitid_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
+bool io_waitid_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
                          bool cancel_all);