}
 }
 
+/*
+ * This function should be called when a memcg is being offlined.
+ *
+ * Since the global shrinker shrink_worker() may hold a reference
+ * of the memcg, we must check and release the reference in
+ * zswap_next_shrink.
+ *
+ * shrink_worker() must handle the case where this function releases
+ * the reference of memcg being shrunk.
+ */
 void zswap_memcg_offline_cleanup(struct mem_cgroup *memcg)
 {
        /* lock out zswap shrinker walking memcg tree */
        spin_lock(&zswap_shrink_lock);
-       if (zswap_next_shrink == memcg)
-               zswap_next_shrink = mem_cgroup_iter(NULL, zswap_next_shrink, NULL);
+       if (zswap_next_shrink == memcg) {
+               do {
+                       zswap_next_shrink = mem_cgroup_iter(NULL, zswap_next_shrink, NULL);
+               } while (zswap_next_shrink && !mem_cgroup_online(zswap_next_shrink));
+       }
        spin_unlock(&zswap_shrink_lock);
 }
 
        /* Reclaim down to the accept threshold */
        thr = zswap_accept_thr_pages();
 
-       /* global reclaim will select cgroup in a round-robin fashion. */
+       /*
+        * Global reclaim will select cgroup in a round-robin fashion.
+        *
+        * We save iteration cursor memcg into zswap_next_shrink,
+        * which can be modified by the offline memcg cleaner
+        * zswap_memcg_offline_cleanup().
+        *
+        * Since the offline cleaner is called only once, we cannot leave an
+        * offline memcg reference in zswap_next_shrink.
+        * We can rely on the cleaner only if we get online memcg under lock.
+        *
+        * If we get an offline memcg, we cannot determine if the cleaner has
+        * already been called or will be called later. We must put back the
+        * reference before returning from this function. Otherwise, the
+        * offline memcg left in zswap_next_shrink will hold the reference
+        * until the next run of shrink_worker().
+        */
        do {
-               spin_lock(&zswap_shrink_lock);
-               zswap_next_shrink = mem_cgroup_iter(NULL, zswap_next_shrink, NULL);
-               memcg = zswap_next_shrink;
-
                /*
-                * We need to retry if we have gone through a full round trip, or if we
-                * got an offline memcg (or else we risk undoing the effect of the
-                * zswap memcg offlining cleanup callback). This is not catastrophic
-                * per se, but it will keep the now offlined memcg hostage for a while.
+                * Start shrinking from the next memcg after zswap_next_shrink.
+                * When the offline cleaner has already advanced the cursor,
+                * advancing the cursor here overlooks one memcg, but this
+                * should be negligibly rare.
                 *
-                * Note that if we got an online memcg, we will keep the extra
-                * reference in case the original reference obtained by mem_cgroup_iter
-                * is dropped by the zswap memcg offlining callback, ensuring that the
-                * memcg is not killed when we are reclaiming.
+                * If we get an online memcg, keep the extra reference in case
+                * the original one obtained by mem_cgroup_iter() is dropped by
+                * zswap_memcg_offline_cleanup() while we are shrinking the
+                * memcg.
                 */
-               if (!memcg) {
-                       spin_unlock(&zswap_shrink_lock);
-                       if (++failures == MAX_RECLAIM_RETRIES)
-                               break;
-
-                       goto resched;
-               }
-
-               if (!mem_cgroup_tryget_online(memcg)) {
-                       /* drop the reference from mem_cgroup_iter() */
-                       mem_cgroup_iter_break(NULL, memcg);
-                       zswap_next_shrink = NULL;
-                       spin_unlock(&zswap_shrink_lock);
+               spin_lock(&zswap_shrink_lock);
+               do {
+                       memcg = mem_cgroup_iter(NULL, zswap_next_shrink, NULL);
+                       zswap_next_shrink = memcg;
+               } while (memcg && !mem_cgroup_tryget_online(memcg));
+               spin_unlock(&zswap_shrink_lock);
 
+               if (!memcg) {
                        if (++failures == MAX_RECLAIM_RETRIES)
                                break;
 
                        goto resched;
                }
-               spin_unlock(&zswap_shrink_lock);
 
                ret = shrink_memcg(memcg);
                /* drop the extra reference */