up_read(&sqd->rw_lock);
                        cond_resched();
                        down_read(&sqd->rw_lock);
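+                       /* run task_work a parking task may have queued, e.g. a cancellation */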
+                       io_run_task_work();
                        timeout = jiffies + sqd->sq_thread_idle;
                        continue;
                }
                finish_wait(&sqd->wait, &wait);
                timeout = jiffies + sqd->sq_thread_idle;
        }
-
-       list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
-               io_uring_cancel_sqpoll(ctx);
        up_read(&sqd->rw_lock);
-
+       down_write(&sqd->rw_lock);
+       /*
+        * Someone may have parked and queued a cancellation task_work; run
+        * it first because we don't want it running inside
+        * io_uring_cancel_sqpoll().
+        */
        io_run_task_work();
 
-       down_write(&sqd->rw_lock);
+       list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
+               io_uring_cancel_sqpoll(ctx);
        sqd->thread = NULL;
        list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
                io_ring_set_wakeup_flag(ctx);
        up_write(&sqd->rw_lock);
+
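+       /* flush any remaining task_work before signalling exit */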
+       io_run_task_work();
        complete(&sqd->exited);
        do_exit(0);
 }
 static void io_sq_thread_unpark(struct io_sq_data *sqd)
        __releases(&sqd->rw_lock)
 {
-       if (sqd->thread == current)
-               return;
+       WARN_ON_ONCE(sqd->thread == current);
+
        clear_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
        up_write(&sqd->rw_lock);
 }
 static void io_sq_thread_park(struct io_sq_data *sqd)
        __acquires(&sqd->rw_lock)
 {
-       if (sqd->thread == current)
-               return;
+       WARN_ON_ONCE(sqd->thread == current);
+
        set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
        down_write(&sqd->rw_lock);
        /* set again for consistency, in case concurrent parks are happening */
 
 static void io_sq_thread_stop(struct io_sq_data *sqd)
 {
-       if (test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state))
-               return;
+       WARN_ON_ONCE(sqd->thread == current);
+
        down_write(&sqd->rw_lock);
        set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
        if (sqd->thread)
 
        if (sqd) {
                io_sq_thread_park(sqd);
-               list_del(&ctx->sqd_list);
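+               /* _init so the entry can safely be unlinked more than once */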
+               list_del_init(&ctx->sqd_list);
                io_sqd_update_thread_idle(sqd);
                io_sq_thread_unpark(sqd);
 
        init_waitqueue_head(&tctx->wait);
        tctx->last = NULL;
        atomic_set(&tctx->in_idle, 0);
-       tctx->sqpoll = false;
        task->io_uring = tctx;
        spin_lock_init(&tctx->task_lock);
        INIT_WQ_LIST(&tctx->task_list);
 
                io_uring_try_cancel_requests(ctx, task, files);
 
-               if (ctx->sq_data)
-                       io_sq_thread_unpark(ctx->sq_data);
                prepare_to_wait(&task->io_uring->wait, &wait,
                                TASK_UNINTERRUPTIBLE);
                if (inflight == io_uring_count_inflight(ctx, task, files))
                        schedule();
                finish_wait(&task->io_uring->wait, &wait);
-               if (ctx->sq_data)
-                       io_sq_thread_park(ctx->sq_data);
-       }
-}
-
-/*
- * We need to iteratively cancel requests, in case a request has dependent
- * hard links. These persist even for failure of cancelations, hence keep
- * looping until none are found.
- */
-static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
-                                         struct files_struct *files)
-{
-       struct task_struct *task = current;
-
-       if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
-               io_sq_thread_park(ctx->sq_data);
-               task = ctx->sq_data->thread;
-               if (task)
-                       atomic_inc(&task->io_uring->in_idle);
        }
-
-       io_uring_cancel_files(ctx, task, files);
-       if (!files)
-               io_uring_try_cancel_requests(ctx, task, NULL);
-
-       if (task)
-               atomic_dec(&task->io_uring->in_idle);
-       if (ctx->sq_data)
-               io_sq_thread_unpark(ctx->sq_data);
 }
 
 /*
                }
                tctx->last = ctx;
        }
-
-       /*
-        * This is race safe in that the task itself is doing this, hence it
-        * cannot be going through the exit/cancel paths at the same time.
-        * This cannot be modified while exit/cancel is running.
-        */
-       if (!tctx->sqpoll && (ctx->flags & IORING_SETUP_SQPOLL))
-               tctx->sqpoll = true;
-
        return 0;
 }
 
        }
 }
 
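+/* number of requests this task context currently has in flight */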
+static s64 tctx_inflight(struct io_uring_task *tctx)
+{
+       return percpu_counter_sum(&tctx->inflight);
+}
+
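+/*
+ * Runs as task_work on the SQPOLL thread: cancel this ctx's requests and
+ * wake up the task waiting in io_sqpoll_cancel_sync().
+ */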
+static void io_sqpoll_cancel_cb(struct callback_head *cb)
+{
+       struct io_tctx_exit *work = container_of(cb, struct io_tctx_exit, task_work);
+       struct io_ring_ctx *ctx = work->ctx;
+       struct io_sq_data *sqd = ctx->sq_data;
+
+       if (sqd->thread)
+               io_uring_cancel_sqpoll(ctx);
+       complete(&work->completion);
+}
+
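+/*
+ * Unlink the ctx from its sqd and, if the SQ thread is still alive, ask it
+ * via task_work to cancel the ctx's requests, waiting for that to finish.
+ */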
+static void io_sqpoll_cancel_sync(struct io_ring_ctx *ctx)
+{
+       struct io_sq_data *sqd = ctx->sq_data;
+       struct io_tctx_exit work = { .ctx = ctx, };
+       struct task_struct *task;
+
+       io_sq_thread_park(sqd);
+       list_del_init(&ctx->sqd_list);
+       io_sqd_update_thread_idle(sqd);
+       task = sqd->thread;
+       if (task) {
+               init_completion(&work.completion);
+               init_task_work(&work.task_work, io_sqpoll_cancel_cb);
+               WARN_ON_ONCE(task_work_add(task, &work.task_work, TWA_SIGNAL));
+               wake_up_process(task);
+       }
+       io_sq_thread_unpark(sqd);
+
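+       /* wait only after unparking, a parked SQ thread can't run task_work */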
+       if (task)
+               wait_for_completion(&work.completion);
+}
+
 void __io_uring_files_cancel(struct files_struct *files)
 {
        struct io_uring_task *tctx = current->io_uring;
 
        /* make sure overflow events are dropped */
        atomic_inc(&tctx->in_idle);
-       xa_for_each(&tctx->xa, index, node)
-               io_uring_cancel_task_requests(node->ctx, files);
+       xa_for_each(&tctx->xa, index, node) {
+               struct io_ring_ctx *ctx = node->ctx;
+
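+               /* SQPOLL: cancellation is driven by the SQ thread itself */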
+               if (ctx->sq_data) {
+                       io_sqpoll_cancel_sync(ctx);
+                       continue;
+               }
+               io_uring_cancel_files(ctx, current, files);
+               if (!files)
+                       io_uring_try_cancel_requests(ctx, current, NULL);
+       }
        atomic_dec(&tctx->in_idle);
 
        if (files)
                io_uring_clean_tctx(tctx);
 }
 
-static s64 tctx_inflight(struct io_uring_task *tctx)
-{
-       return percpu_counter_sum(&tctx->inflight);
-}
-
+/* should only be called by SQPOLL task */
 static void io_uring_cancel_sqpoll(struct io_ring_ctx *ctx)
 {
        struct io_sq_data *sqd = ctx->sq_data;
-       struct io_uring_task *tctx;
+       struct io_uring_task *tctx = current->io_uring;
        s64 inflight;
        DEFINE_WAIT(wait);
 
-       if (!sqd)
-               return;
-       io_sq_thread_park(sqd);
-       if (!sqd->thread || !sqd->thread->io_uring) {
-               io_sq_thread_unpark(sqd);
-               return;
-       }
-       tctx = ctx->sq_data->thread->io_uring;
+       WARN_ON_ONCE(!sqd || ctx->sq_data->thread != current);
+
        atomic_inc(&tctx->in_idle);
        do {
                /* read completions before cancelations */
                inflight = tctx_inflight(tctx);
                if (!inflight)
                        break;
-               io_uring_cancel_task_requests(ctx, NULL);
+               io_uring_try_cancel_requests(ctx, current, NULL);
 
                prepare_to_wait(&tctx->wait, &wait, TASK_UNINTERRUPTIBLE);
                /*
                finish_wait(&tctx->wait, &wait);
        } while (1);
        atomic_dec(&tctx->in_idle);
-       io_sq_thread_unpark(sqd);
 }
 
 /*
 
        /* make sure overflow events are dropped */
        atomic_inc(&tctx->in_idle);
-
-       if (tctx->sqpoll) {
-               struct io_tctx_node *node;
-               unsigned long index;
-
-               xa_for_each(&tctx->xa, index, node)
-                       io_uring_cancel_sqpoll(node->ctx);
-       }
-
        do {
                /* read completions before cancelations */
                inflight = tctx_inflight(tctx);