static void io_req_free_batch_finish(struct io_ring_ctx *ctx,
                                     struct req_batch *rb)
 {
-       if (rb->task)
-               io_put_task(rb->task, rb->task_refs);
        if (rb->ctx_refs)
                percpu_ref_put_many(&ctx->refs, rb->ctx_refs);
+       if (rb->task == current)
+               current->io_uring->cached_refs += rb->task_refs;
+       else if (rb->task)
+               io_put_task(rb->task, rb->task_refs);
 }
 
 static void io_req_free_batch(struct req_batch *rb, struct io_kiocb *req,
        struct io_uring_task *tctx = task->io_uring;
        unsigned int refs = tctx->cached_refs;
 
-       tctx->cached_refs = 0;
-       percpu_counter_sub(&tctx->inflight, refs);
-       put_task_struct_many(task, refs);
+       if (refs) {
+               tctx->cached_refs = 0;
+               percpu_counter_sub(&tctx->inflight, refs);
+               put_task_struct_many(task, refs);
+       }
 }
 
 /*
        if (tctx->io_wq)
                io_wq_exit_start(tctx->io_wq);
 
-       io_uring_drop_tctx_refs(current);
        atomic_inc(&tctx->in_idle);
        do {
+               io_uring_drop_tctx_refs(current);
                /* read completions before cancelations */
                inflight = tctx_inflight(tctx, !cancel_all);
                if (!inflight)
                }
 
                prepare_to_wait(&tctx->wait, &wait, TASK_UNINTERRUPTIBLE);
+               io_uring_drop_tctx_refs(current);
                /*
                 * If we've seen completions, retry without waiting. This
                 * avoids a race where a completion comes in before we did