struct io_wqe_acct *acct,
                                        struct io_cb_cancel_data *match);
 static void create_worker_cb(struct callback_head *cb);
+static void io_wq_cancel_tw_create(struct io_wq *wq);
 
 static bool io_worker_get(struct io_worker *worker)
 {
            test_and_set_bit_lock(0, &worker->create_state))
                goto fail_release;
 
+       atomic_inc(&wq->worker_refs);
        init_task_work(&worker->create_work, func);
        worker->create_index = acct->index;
-       if (!task_work_add(wq->task, &worker->create_work, TWA_SIGNAL))
+       if (!task_work_add(wq->task, &worker->create_work, TWA_SIGNAL)) {
+               /*
+                * EXIT may have been set after checking it above, check after
+                * adding the task_work and remove any creation item if it is
+                * now set. wq exit does that too, but we can have added this
+                * work item after we canceled in io_wq_exit_workers().
+                */
+               if (test_bit(IO_WQ_BIT_EXIT, &wq->state))
+                       io_wq_cancel_tw_create(wq);
+               io_worker_ref_put(wq);
                return true;
+       }
+       io_worker_ref_put(wq);
        clear_bit_unlock(0, &worker->create_state);
 fail_release:
        io_worker_release(worker);
        set_bit(IO_WQ_BIT_EXIT, &wq->state);
 }
 
-static void io_wq_exit_workers(struct io_wq *wq)
+static void io_wq_cancel_tw_create(struct io_wq *wq)
 {
        struct callback_head *cb;
-       int node;
-
-       if (!wq->task)
-               return;
 
        while ((cb = task_work_cancel_match(wq->task, io_task_work_match, wq)) != NULL) {
                struct io_worker *worker;
                worker = container_of(cb, struct io_worker, create_work);
                io_worker_cancel_cb(worker);
        }
+}
+
+static void io_wq_exit_workers(struct io_wq *wq)
+{
+       int node;
+
+       if (!wq->task)
+               return;
+
+       io_wq_cancel_tw_create(wq);
 
        rcu_read_lock();
        for_each_node(node) {