 static s64 tctx_inflight(struct io_uring_task *tctx)
 {
-       unsigned long index;
-       struct file *file;
-       s64 inflight;
-
-       inflight = percpu_counter_sum(&tctx->inflight);
-       if (!tctx->sqpoll)
-               return inflight;
+       return percpu_counter_sum(&tctx->inflight);
+}
 
-       /*
-        * If we have SQPOLL rings, then we need to iterate and find them, and
-        * add the pending count for those.
-        */
-       xa_for_each(&tctx->xa, index, file) {
-               struct io_ring_ctx *ctx = file->private_data;
+static void io_uring_cancel_sqpoll(struct io_ring_ctx *ctx)
+{
+       struct io_uring_task *tctx;
+       s64 inflight;
+       DEFINE_WAIT(wait);
 
-               if (ctx->flags & IORING_SETUP_SQPOLL) {
-                       struct io_uring_task *__tctx = ctx->sqo_task->io_uring;
+       if (!ctx->sq_data)
+               return;
+       tctx = ctx->sq_data->thread->io_uring;
+       io_disable_sqo_submit(ctx);
 
-                       inflight += percpu_counter_sum(&__tctx->inflight);
-               }
-       }
+       atomic_inc(&tctx->in_idle);
+       do {
+               /* read completions before cancelations */
+               inflight = tctx_inflight(tctx);
+               if (!inflight)
+                       break;
+               io_uring_cancel_task_requests(ctx, NULL);
 
-       return inflight;
+               prepare_to_wait(&tctx->wait, &wait, TASK_UNINTERRUPTIBLE);
+               /*
+                * If we've seen completions, retry without waiting. This
+                * avoids a race where a completion comes in before we did
+                * prepare_to_wait().
+                */
+               if (inflight == tctx_inflight(tctx))
+                       schedule();
+               finish_wait(&tctx->wait, &wait);
+       } while (1);
+       atomic_dec(&tctx->in_idle);
 }
 
 /*
        atomic_inc(&tctx->in_idle);
 
        /* trigger io_disable_sqo_submit() */
-       if (tctx->sqpoll)
-               __io_uring_files_cancel(NULL);
+       if (tctx->sqpoll) {
+               struct file *file;
+               unsigned long index;
+
+               xa_for_each(&tctx->xa, index, file)
+                       io_uring_cancel_sqpoll(file->private_data);
+       }
 
        do {
                /* read completions before cancelations */