for_each_node(nid) {
                struct lruvec *lruvec = get_lruvec(memcg, nid);
 
-               /* where the last iteration ended (exclusive) */
+               /* the entry after which the current iteration continues */
+               if (lruvec->mm_state.head == &mm->lru_gen.list)
+                       lruvec->mm_state.head = lruvec->mm_state.head->prev;
+
+               /* the entry before which the last iteration ended */
                if (lruvec->mm_state.tail == &mm->lru_gen.list)
                        lruvec->mm_state.tail = lruvec->mm_state.tail->next;
-
-               /* where the current iteration continues (inclusive) */
-               if (lruvec->mm_state.head != &mm->lru_gen.list)
-                       continue;
-
-               lruvec->mm_state.head = lruvec->mm_state.head->next;
-               /* the deletion ends the current iteration */
-               if (lruvec->mm_state.head == &mm_list->fifo)
-                       WRITE_ONCE(lruvec->mm_state.seq, lruvec->mm_state.seq + 1);
        }
 
        list_del_init(&mm->lru_gen.list);
                            struct mm_struct **iter)
 {
        bool first = false;
-       bool last = true;
+       bool last = false;
        struct mm_struct *mm = NULL;
        struct mem_cgroup *memcg = lruvec_memcg(lruvec);
        struct lru_gen_mm_list *mm_list = get_mm_list(memcg);
        struct lru_gen_mm_state *mm_state = &lruvec->mm_state;
 
        /*
-        * There are four interesting cases for this page table walker:
-        * 1. It tries to start a new iteration of mm_list with a stale max_seq;
-        *    there is nothing left to do.
-        * 2. It's the first of the current generation, and it needs to reset
-        *    the Bloom filter for the next generation.
-        * 3. It reaches the end of mm_list, and it needs to increment
-        *    mm_state->seq; the iteration is done.
-        * 4. It's the last of the current generation, and it needs to reset the
-        *    mm stats counters for the next generation.
+        * mm_state->seq is incremented after each iteration of mm_list. There
+        * are three interesting cases for this page table walker:
+        * 1. It tries to start a new iteration with a stale max_seq: there is
+        *    nothing left to do.
+        * 2. It started the next iteration: it needs to reset the Bloom filter
+        *    so that a fresh set of PTE tables can be recorded.
+        * 3. It ended the current iteration: it needs to reset the mm stats
+        *    counters and tell its caller to increment max_seq.
         */
        spin_lock(&mm_list->lock);
 
        VM_WARN_ON_ONCE(mm_state->seq + 1 < walk->max_seq);
-       VM_WARN_ON_ONCE(*iter && mm_state->seq > walk->max_seq);
-       VM_WARN_ON_ONCE(*iter && !mm_state->nr_walkers);
 
-       if (walk->max_seq <= mm_state->seq) {
-               if (!*iter)
-                       last = false;
+       if (walk->max_seq <= mm_state->seq)
                goto done;
-       }
 
-       if (!mm_state->nr_walkers) {
-               VM_WARN_ON_ONCE(mm_state->head && mm_state->head != &mm_list->fifo);
+       if (!mm_state->head)
+               mm_state->head = &mm_list->fifo;
 
-               mm_state->head = mm_list->fifo.next;
+       if (mm_state->head == &mm_list->fifo)
                first = true;
-       }
-
-       while (!mm && mm_state->head != &mm_list->fifo) {
-               mm = list_entry(mm_state->head, struct mm_struct, lru_gen.list);
 
+       do {
                mm_state->head = mm_state->head->next;
+               if (mm_state->head == &mm_list->fifo) {
+                       WRITE_ONCE(mm_state->seq, mm_state->seq + 1);
+                       last = true;
+                       break;
+               }
 
                /* force scan for those added after the last iteration */
-               if (!mm_state->tail || mm_state->tail == &mm->lru_gen.list) {
-                       mm_state->tail = mm_state->head;
+               if (!mm_state->tail || mm_state->tail == mm_state->head) {
+                       mm_state->tail = mm_state->head->next;
                        walk->force_scan = true;
                }
 
+               mm = list_entry(mm_state->head, struct mm_struct, lru_gen.list);
                if (should_skip_mm(mm, walk))
                        mm = NULL;
-       }
-
-       if (mm_state->head == &mm_list->fifo)
-               WRITE_ONCE(mm_state->seq, mm_state->seq + 1);
+       } while (!mm);
 done:
-       if (*iter && !mm)
-               mm_state->nr_walkers--;
-       if (!*iter && mm)
-               mm_state->nr_walkers++;
-
-       if (mm_state->nr_walkers)
-               last = false;
-
        if (*iter || last)
                reset_mm_stats(lruvec, walk, last);
 
 
        VM_WARN_ON_ONCE(mm_state->seq + 1 < max_seq);
 
-       if (max_seq > mm_state->seq && !mm_state->nr_walkers) {
-               VM_WARN_ON_ONCE(mm_state->head && mm_state->head != &mm_list->fifo);
-
+       if (max_seq > mm_state->seq) {
+               mm_state->head = NULL;
+               mm_state->tail = NULL;
                WRITE_ONCE(mm_state->seq, mm_state->seq + 1);
                reset_mm_stats(lruvec, NULL, true);
                success = true;
 
                walk_pmd_range(&val, addr, next, args);
 
-               /* a racy check to curtail the waiting time */
-               if (wq_has_sleeper(&walk->lruvec->mm_state.wait))
-                       return 1;
-
                if (need_resched() || walk->batched >= MAX_LRU_BATCH) {
                        end = (addr | ~PUD_MASK) + 1;
                        goto done;
        walk->next_addr = FIRST_USER_ADDRESS;
 
        do {
+               DEFINE_MAX_SEQ(lruvec);
+
                err = -EBUSY;
 
+               /* another thread might have called inc_max_seq() */
+               if (walk->max_seq != max_seq)
+                       break;
+
                /* folio_update_gen() requires stable folio_memcg() */
                if (!mem_cgroup_trylock_pages(memcg))
                        break;
                success = iterate_mm_list(lruvec, walk, &mm);
                if (mm)
                        walk_mm(lruvec, mm, walk);
-
-               cond_resched();
        } while (mm);
 done:
-       if (!success) {
-               if (sc->priority <= DEF_PRIORITY - 2)
-                       wait_event_killable(lruvec->mm_state.wait,
-                                           max_seq < READ_ONCE(lrugen->max_seq));
-               return false;
-       }
+       if (success)
+               inc_max_seq(lruvec, can_swap, force_scan);
 
-       VM_WARN_ON_ONCE(max_seq != READ_ONCE(lrugen->max_seq));
-
-       inc_max_seq(lruvec, can_swap, force_scan);
-       /* either this sees any waiters or they will see updated max_seq */
-       if (wq_has_sleeper(&lruvec->mm_state.wait))
-               wake_up_all(&lruvec->mm_state.wait);
-
-       return true;
+       return success;
 }
 
 /******************************************************************************
                INIT_LIST_HEAD(&lrugen->folios[gen][type][zone]);
 
        lruvec->mm_state.seq = MIN_NR_GENS;
-       init_waitqueue_head(&lruvec->mm_state.wait);
 }
 
 #ifdef CONFIG_MEMCG
        for_each_node(nid) {
                struct lruvec *lruvec = get_lruvec(memcg, nid);
 
-               VM_WARN_ON_ONCE(lruvec->mm_state.nr_walkers);
                VM_WARN_ON_ONCE(memchr_inv(lruvec->lrugen.nr_pages, 0,
                                           sizeof(lruvec->lrugen.nr_pages)));