static void io_req_task_submit(struct callback_head *cb)
 {
        struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
+       struct io_ring_ctx *ctx = req->ctx;
 
        __io_req_task_submit(req);
+       percpu_ref_put(&ctx->refs);
 }
 
 static void io_req_task_queue(struct io_kiocb *req)
        int ret;
 
        init_task_work(&req->task_work, io_req_task_submit);
+       percpu_ref_get(&req->ctx->refs);
 
        ret = io_req_task_work_add(req, &req->task_work);
        if (unlikely(ret)) {
                refcount_inc(&req->refs);
                io_queue_async_work(req);
        }
+
+       percpu_ref_put(&ctx->refs);
 }
 #endif
 
                return false;
 
        init_task_work(&req->task_work, io_rw_resubmit);
+       percpu_ref_get(&req->ctx->refs);
+
        ret = io_req_task_work_add(req, &req->task_work);
        if (!ret)
                return true;
        list_del_init(&wait->entry);
 
        init_task_work(&req->task_work, io_req_task_submit);
+       percpu_ref_get(&req->ctx->refs);
+
        /* submit ref gets dropped, acquire a new one */
        refcount_inc(&req->refs);
        ret = io_req_task_work_add(req, &req->task_work);
 
        req->result = mask;
        init_task_work(&req->task_work, func);
+       percpu_ref_get(&req->ctx->refs);
+
        /*
         * If this fails, then the task is exiting. When a task exits, the
         * work gets canceled, so just cancel this request as well instead
 static void io_poll_task_func(struct callback_head *cb)
 {
        struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
+       struct io_ring_ctx *ctx = req->ctx;
        struct io_kiocb *nxt = NULL;
 
        io_poll_task_handler(req, &nxt);
        if (nxt)
                __io_req_task_submit(nxt);
+       percpu_ref_put(&ctx->refs);
 }
 
 static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
 
        if (io_poll_rewait(req, &apoll->poll)) {
                spin_unlock_irq(&ctx->completion_lock);
+               percpu_ref_put(&ctx->refs);
                return;
        }
 
        else
                __io_req_task_cancel(req, -ECANCELED);
 
+       percpu_ref_put(&ctx->refs);
        kfree(apoll->double_poll);
        kfree(apoll);
 }