return __io_req_find_next(req);
 }
 
+static void ctx_flush_and_put(struct io_ring_ctx *ctx)
+{
+       if (!ctx)
+               return;
+       if (ctx->submit_state.comp.nr) {
+               mutex_lock(&ctx->uring_lock);
+               io_submit_flush_completions(&ctx->submit_state.comp, ctx);
+               mutex_unlock(&ctx->uring_lock);
+       }
+       percpu_ref_put(&ctx->refs);
+}
+
 static bool __tctx_task_work(struct io_uring_task *tctx)
 {
        struct io_ring_ctx *ctx = NULL;
        node = list.first;
        while (node) {
                struct io_wq_work_node *next = node->next;
-               struct io_ring_ctx *this_ctx;
                struct io_kiocb *req;
 
                req = container_of(node, struct io_kiocb, io_task_work.node);
-               this_ctx = req->ctx;
-               req->task_work.func(&req->task_work);
-               node = next;
-
-               if (!ctx) {
-                       ctx = this_ctx;
-               } else if (ctx != this_ctx) {
-                       mutex_lock(&ctx->uring_lock);
-                       io_submit_flush_completions(&ctx->submit_state.comp, ctx);
-                       mutex_unlock(&ctx->uring_lock);
-                       ctx = this_ctx;
+               if (req->ctx != ctx) {
+                       ctx_flush_and_put(ctx);
+                       ctx = req->ctx;
+                       percpu_ref_get(&ctx->refs);
                }
-       }
 
-       if (ctx && ctx->submit_state.comp.nr) {
-               mutex_lock(&ctx->uring_lock);
-               io_submit_flush_completions(&ctx->submit_state.comp, ctx);
-               mutex_unlock(&ctx->uring_lock);
+               req->task_work.func(&req->task_work);
+               node = next;
        }
 
+       ctx_flush_and_put(ctx);
        return list.first != NULL;
 }