while (!list_empty_careful(&ctx->inflight_list)) {
                enum io_wq_cancel ret = IO_WQ_CANCEL_NOTFOUND;
+               struct io_kiocb *cancel_req = NULL;
 
                spin_lock_irq(&ctx->inflight_lock);
                list_for_each_entry(req, &ctx->inflight_list, inflight_entry) {
-                       if (req->work.files == files) {
-                               ret = io_wq_cancel_work(ctx->io_wq, &req->work);
-                               break;
-                       }
+                       if (req->work.files != files)
+                               continue;
+                       /* req is being completed, ignore */
+                       if (!refcount_inc_not_zero(&req->refs))
+                               continue;
+                       cancel_req = req;
+                       break;
                }
-               if (ret == IO_WQ_CANCEL_RUNNING)
+               if (cancel_req)
                        prepare_to_wait(&ctx->inflight_wait, &wait,
-                                       TASK_UNINTERRUPTIBLE);
-
+                                               TASK_UNINTERRUPTIBLE);
                spin_unlock_irq(&ctx->inflight_lock);
 
-               /*
-                * We need to keep going until we get NOTFOUND. We only cancel
-                * one work at the time.
-                *
-                * If we get CANCEL_RUNNING, then wait for a work to complete
-                * before continuing.
-                */
-               if (ret == IO_WQ_CANCEL_OK)
-                       continue;
-               else if (ret != IO_WQ_CANCEL_RUNNING)
+               if (cancel_req) {
+                       ret = io_wq_cancel_work(ctx->io_wq, &cancel_req->work);
+                       io_put_req(cancel_req);
+               }
+
+               /* We need to keep going until we don't find a matching req */
+               if (!cancel_req)
                        break;
                schedule();
        }
+       finish_wait(&ctx->inflight_wait, &wait);
 }
 
 static int io_uring_flush(struct file *file, void *data)