struct io_poll_iocb     *double_poll;
 };
 
-typedef void (*io_req_tw_func_t)(struct io_kiocb *req);
+typedef void (*io_req_tw_func_t)(struct io_kiocb *req, bool *locked);
 
 struct io_task_work {
        union {
 }
 EXPORT_SYMBOL(io_uring_get_socket);
 
+static inline void io_tw_lock(struct io_ring_ctx *ctx, bool *locked)
+{
+       if (!*locked) {
+               mutex_lock(&ctx->uring_lock);
+               *locked = true;
+       }
+}
+
 #define io_for_each_link(pos, head) \
        for (pos = (head); pos; pos = pos->link)
 
                                                fallback_work.work);
        struct llist_node *node = llist_del_all(&ctx->fallback_llist);
        struct io_kiocb *req, *tmp;
+       bool locked = false;
 
        percpu_ref_get(&ctx->refs);
        llist_for_each_entry_safe(req, tmp, node, io_task_work.fallback_node)
-               req->io_task_work.func(req);
+               req->io_task_work.func(req, &locked);
 
-       mutex_lock(&ctx->uring_lock);
-       if (ctx->submit_state.compl_nr)
-               io_submit_flush_completions(ctx);
-       mutex_unlock(&ctx->uring_lock);
+       if (locked) {
+               if (ctx->submit_state.compl_nr)
+                       io_submit_flush_completions(ctx);
+               mutex_unlock(&ctx->uring_lock);
+       }
        percpu_ref_put(&ctx->refs);
+
 }
 
 static struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
        }
 }
 
-static void io_queue_async_work(struct io_kiocb *req)
+static void io_queue_async_work(struct io_kiocb *req, bool *locked)
 {
        struct io_ring_ctx *ctx = req->ctx;
        struct io_kiocb *link = io_prep_linked_timeout(req);
        struct io_uring_task *tctx = req->task->io_uring;
 
+       /* must not take the lock, NULL it as a precaution */
+       locked = NULL;
+
        BUG_ON(!tctx);
        BUG_ON(!tctx->io_wq);
 
        return __io_req_find_next(req);
 }
 
-static void ctx_flush_and_put(struct io_ring_ctx *ctx)
+static void ctx_flush_and_put(struct io_ring_ctx *ctx, bool *locked)
 {
        if (!ctx)
                return;
-       if (ctx->submit_state.compl_nr) {
-               mutex_lock(&ctx->uring_lock);
+       if (*locked) {
                if (ctx->submit_state.compl_nr)
                        io_submit_flush_completions(ctx);
                mutex_unlock(&ctx->uring_lock);
+               *locked = false;
        }
        percpu_ref_put(&ctx->refs);
 }
 
 static void tctx_task_work(struct callback_head *cb)
 {
+       bool locked = false;
        struct io_ring_ctx *ctx = NULL;
        struct io_uring_task *tctx = container_of(cb, struct io_uring_task,
                                                  task_work);
                                                            io_task_work.node);
 
                        if (req->ctx != ctx) {
-                               ctx_flush_and_put(ctx);
+                               ctx_flush_and_put(ctx, &locked);
                                ctx = req->ctx;
                                percpu_ref_get(&ctx->refs);
                        }
-                       req->io_task_work.func(req);
+                       req->io_task_work.func(req, &locked);
                        node = next;
                } while (node);
 
                cond_resched();
        }
 
-       ctx_flush_and_put(ctx);
+       ctx_flush_and_put(ctx, &locked);
 }
 
 static void io_req_task_work_add(struct io_kiocb *req)
        }
 }
 
-static void io_req_task_cancel(struct io_kiocb *req)
+static void io_req_task_cancel(struct io_kiocb *req, bool *locked)
 {
        struct io_ring_ctx *ctx = req->ctx;
 
        /* ctx is guaranteed to stay alive while we hold uring_lock */
-       mutex_lock(&ctx->uring_lock);
+       io_tw_lock(ctx, locked);
        io_req_complete_failed(req, req->result);
-       mutex_unlock(&ctx->uring_lock);
 }
 
-static void io_req_task_submit(struct io_kiocb *req)
+static void io_req_task_submit(struct io_kiocb *req, bool *locked)
 {
        struct io_ring_ctx *ctx = req->ctx;
 
        /* ctx stays valid until unlock, even if we drop all ours ctx->refs */
-       mutex_lock(&ctx->uring_lock);
+       io_tw_lock(ctx, locked);
        /* req->task == current here, checking PF_EXITING is safe */
        if (likely(!(req->task->flags & PF_EXITING)))
                __io_queue_sqe(req);
        else
                io_req_complete_failed(req, -EFAULT);
-       mutex_unlock(&ctx->uring_lock);
 }
 
 static void io_req_task_queue_fail(struct io_kiocb *req, int ret)
        __io_free_req(req);
 }
 
+static void io_free_req_work(struct io_kiocb *req, bool *locked)
+{
+       io_free_req(req);
+}
+
 struct req_batch {
        struct task_struct      *task;
        int                     task_refs;
 static inline void io_put_req_deferred(struct io_kiocb *req)
 {
        if (req_ref_put_and_test(req)) {
-               req->io_task_work.func = io_free_req;
+               req->io_task_work.func = io_free_req_work;
                io_req_task_work_add(req);
        }
 }
        return false;
 }
 
-static void io_req_task_complete(struct io_kiocb *req)
+static void io_req_task_complete(struct io_kiocb *req, bool *locked)
 {
        __io_req_complete(req, 0, req->result, io_put_rw_kbuf(req));
 }
 {
        if (__io_complete_rw_common(req, res))
                return;
-       io_req_task_complete(req);
+       __io_req_complete(req, 0, req->result, io_put_rw_kbuf(req));
 }
 
 static void io_complete_rw(struct kiocb *kiocb, long res, long res2)
        return !(flags & IORING_CQE_F_MORE);
 }
 
-static void io_poll_task_func(struct io_kiocb *req)
+static void io_poll_task_func(struct io_kiocb *req, bool *locked)
 {
        struct io_ring_ctx *ctx = req->ctx;
        struct io_kiocb *nxt;
                if (done) {
                        nxt = io_put_req_find_next(req);
                        if (nxt)
-                               io_req_task_submit(nxt);
+                               io_req_task_submit(nxt, locked);
                }
        }
 }
        __io_queue_proc(&apoll->poll, pt, head, &apoll->double_poll);
 }
 
-static void io_async_task_func(struct io_kiocb *req)
+static void io_async_task_func(struct io_kiocb *req, bool *locked)
 {
        struct async_poll *apoll = req->apoll;
        struct io_ring_ctx *ctx = req->ctx;
        spin_unlock(&ctx->completion_lock);
 
        if (!READ_ONCE(apoll->poll.canceled))
-               io_req_task_submit(req);
+               io_req_task_submit(req, locked);
        else
                io_req_complete_failed(req, -ECANCELED);
 }
        return 0;
 }
 
-static void io_req_task_timeout(struct io_kiocb *req)
+static void io_req_task_timeout(struct io_kiocb *req, bool *locked)
 {
        req_set_fail(req);
        io_req_complete_post(req, -ETIME, 0);
        if (!req_need_defer(req, seq) && list_empty(&ctx->defer_list)) {
                spin_unlock(&ctx->completion_lock);
                kfree(de);
-               io_queue_async_work(req);
+               io_queue_async_work(req, NULL);
                return true;
        }
 
                return io_file_get_normal(ctx, req, fd);
 }
 
-static void io_req_task_link_timeout(struct io_kiocb *req)
+static void io_req_task_link_timeout(struct io_kiocb *req, bool *locked)
 {
        struct io_kiocb *prev = req->timeout.prev;
        int ret;
                         * Queued up for async execution, worker will release
                         * submit reference when the iocb is actually submitted.
                         */
-                       io_queue_async_work(req);
+                       io_queue_async_work(req, NULL);
                        break;
                }
 
                if (unlikely(ret))
                        io_req_complete_failed(req, ret);
                else
-                       io_queue_async_work(req);
+                       io_queue_async_work(req, NULL);
        }
 }