/* Number of zpools in zswap_pool (empirically determined for scalability) */
 #define ZSWAP_NR_ZPOOLS 32
 
+/* Enable/disable memory pressure-based shrinker. */
+static bool zswap_shrinker_enabled = IS_ENABLED(
+               CONFIG_ZSWAP_SHRINKER_DEFAULT_ON);
+module_param_named(shrinker_enabled, zswap_shrinker_enabled, bool, 0644);
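+/*
+ * Writable at runtime via /sys/module/zswap/parameters/shrinker_enabled,
+ * or set at boot with zswap.shrinker_enabled= on the kernel command line.
+ */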
+
 /*********************************
 * data structures
 **********************************/
        char tfm_name[CRYPTO_MAX_ALG_NAME];
        struct list_lru list_lru;
        struct mem_cgroup *next_shrink;
+       struct shrinker *shrinker;      /* memory pressure-driven shrinker */
+       atomic_t nr_stored;             /* number of compressed pages stored in the zpools */
 };
 
 /*
                        DIV_ROUND_UP(zswap_pool_total_size, PAGE_SIZE);
 }
 
+static u64 get_zswap_pool_size(struct zswap_pool *pool)
+{
+       u64 pool_size = 0;
+       int i;
+
+       for (i = 0; i < ZSWAP_NR_ZPOOLS; i++)
+               pool_size += zpool_get_total_size(pool->zpools[i]);
+
+       return pool_size;
+}
+
 static void zswap_update_total_size(void)
 {
        struct zswap_pool *pool;
        u64 total = 0;
-       int i;
 
        rcu_read_lock();
 
        list_for_each_entry_rcu(pool, &zswap_pools, list)
-               for (i = 0; i < ZSWAP_NR_ZPOOLS; i++)
-                       total += zpool_get_total_size(pool->zpools[i]);
+               total += get_zswap_pool_size(pool);
 
        rcu_read_unlock();
 
        kmem_cache_free(zswap_entry_cache, entry);
 }
 
+/*********************************
+* zswap lruvec functions
+**********************************/
+void zswap_lruvec_state_init(struct lruvec *lruvec)
+{
+       atomic_long_set(&lruvec->zswap_lruvec_state.nr_zswap_protected, 0);
+}
+
+void zswap_page_swapin(struct page *page)
+{
+       struct lruvec *lruvec;
+
+       if (page) {
+               lruvec = folio_lruvec(page_folio(page));
+               atomic_long_inc(&lruvec->zswap_lruvec_state.nr_zswap_protected);
+       }
+}
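+
+/*
+ * The nr_zswap_protected counter lives in a per-lruvec zswap_lruvec_state,
+ * presumably declared along these lines elsewhere in this series (e.g. in
+ * include/linux/zswap.h):
+ *
+ *        struct zswap_lruvec_state {
+ *                atomic_long_t nr_zswap_protected;
+ *        };
+ *
+ * It estimates how many entries at the tail of the zswap LRU are still warm
+ * and should be spared from writeback; zswap_page_swapin() bumps it on the
+ * theory that refaults from swap indicate reclaim has been too aggressive.
+ */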
+
 /*********************************
 * lru functions
 **********************************/
 static void zswap_lru_add(struct list_lru *list_lru, struct zswap_entry *entry)
 {
+       atomic_long_t *nr_zswap_protected;
+       unsigned long lru_size, old, new;
        int nid = entry_to_nid(entry);
        struct mem_cgroup *memcg;
+       struct lruvec *lruvec;
 
        /*
         * Note that it is safe to use rcu_read_lock() here, even in the face of
        memcg = mem_cgroup_from_entry(entry);
        /* will always succeed */
        list_lru_add(list_lru, &entry->lru, nid, memcg);
+
+       /* Update the protection area */
+       lru_size = list_lru_count_one(list_lru, nid, memcg);
+       lruvec = mem_cgroup_lruvec(memcg, NODE_DATA(nid));
+       nr_zswap_protected = &lruvec->zswap_lruvec_state.nr_zswap_protected;
+       old = atomic_long_inc_return(nr_zswap_protected);
+       /*
+        * Decay to avoid overflow and adapt to changing workloads.
+        * This is based on LRU reclaim cost decaying heuristics.
+        */
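+       /* e.g. with lru_size == 1000, a count that climbs past 250 is halved */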
+       do {
+               new = old > lru_size / 4 ? old / 2 : old;
+       } while (!atomic_long_try_cmpxchg(nr_zswap_protected, &old, new));
        rcu_read_unlock();
 }
 
        int nid = entry_to_nid(entry);
        spinlock_t *lock = &list_lru->node[nid].lock;
        struct mem_cgroup *memcg;
+       struct lruvec *lruvec;
 
        rcu_read_lock();
        memcg = mem_cgroup_from_entry(entry);
        /* we cannot use list_lru_add here, because it increments node's lru count */
        list_lru_putback(list_lru, &entry->lru, nid, memcg);
        spin_unlock(lock);
+
+       lruvec = mem_cgroup_lruvec(memcg, NODE_DATA(nid));
+       /* Increment the protection area to account for the LRU rotation. */
+       atomic_long_inc(&lruvec->zswap_lruvec_state.nr_zswap_protected);
        rcu_read_unlock();
 }
 
        else {
                zswap_lru_del(&entry->pool->list_lru, entry);
                zpool_free(zswap_find_zpool(entry), entry->handle);
+               atomic_dec(&entry->pool->nr_stored);
                zswap_pool_put(entry->pool);
        }
        zswap_entry_cache_free(entry);
        return entry;
 }
 
+/*********************************
+* shrinker functions
+**********************************/
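+/* Implemented further down; forward-declared for zswap_shrinker_scan() below. */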
+static enum lru_status shrink_memcg_cb(struct list_head *item, struct list_lru_one *l,
+                                      spinlock_t *lock, void *arg);
+
+static unsigned long zswap_shrinker_scan(struct shrinker *shrinker,
+               struct shrink_control *sc)
+{
+       struct lruvec *lruvec = mem_cgroup_lruvec(sc->memcg, NODE_DATA(sc->nid));
+       unsigned long shrink_ret, nr_protected, lru_size;
+       struct zswap_pool *pool = shrinker->private_data;
+       bool encountered_page_in_swapcache = false;
+
+       if (!zswap_shrinker_enabled) {
+               sc->nr_scanned = 0;
+               return SHRINK_STOP;
+       }
+
+       nr_protected =
+               atomic_long_read(&lruvec->zswap_lruvec_state.nr_zswap_protected);
+       lru_size = list_lru_shrink_count(&pool->list_lru, sc);
+
+       /*
+        * Abort if we are shrinking into the protected region.
+        *
+        * This short-circuiting is necessary because if we have too many
+        * concurrent reclaimers getting the freeable zswap object counts at the
+        * same time (before any of them made reasonable progress), the total
+        * number of reclaimed objects might be more than the number of
+        * unprotected objects (i.e. the reclaimers will reclaim into the
+        * protected area of the zswap LRU).
+        */
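+       /*
+        * e.g. with nr_to_scan == 128 and lru_size == 1000, any nr_protected
+        * of 872 or more stops the walk here.
+        */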
+       if (nr_protected >= lru_size - sc->nr_to_scan) {
+               sc->nr_scanned = 0;
+               return SHRINK_STOP;
+       }
+
+       shrink_ret = list_lru_shrink_walk(&pool->list_lru, sc, &shrink_memcg_cb,
+               &encountered_page_in_swapcache);
+
+       if (encountered_page_in_swapcache)
+               return SHRINK_STOP;
+
+       return shrink_ret ? shrink_ret : SHRINK_STOP;
+}
+
+static unsigned long zswap_shrinker_count(struct shrinker *shrinker,
+               struct shrink_control *sc)
+{
+       struct zswap_pool *pool = shrinker->private_data;
+       struct mem_cgroup *memcg = sc->memcg;
+       struct lruvec *lruvec = mem_cgroup_lruvec(memcg, NODE_DATA(sc->nid));
+       unsigned long nr_backing, nr_stored, nr_freeable, nr_protected;
+
+       if (!zswap_shrinker_enabled)
+               return 0;
+
+#ifdef CONFIG_MEMCG_KMEM
+       mem_cgroup_flush_stats();
+       nr_backing = memcg_page_state(memcg, MEMCG_ZSWAP_B) >> PAGE_SHIFT;
+       nr_stored = memcg_page_state(memcg, MEMCG_ZSWAPPED);
+#else
+       /* use pool stats instead of memcg stats */
+       nr_backing = get_zswap_pool_size(pool) >> PAGE_SHIFT;
+       nr_stored = atomic_read(&pool->nr_stored);
+#endif
+
+       if (!nr_stored)
+               return 0;
+
+       nr_protected =
+               atomic_long_read(&lruvec->zswap_lruvec_state.nr_zswap_protected);
+       nr_freeable = list_lru_shrink_count(&pool->list_lru, sc);
+       /*
+        * Subtract from the lru size an estimate of the number of pages
+        * that should be protected.
+        */
+       nr_freeable = nr_freeable > nr_protected ? nr_freeable - nr_protected : 0;
+
+       /*
+        * Scale the number of freeable pages by the memory saving factor.
+        * This way, the better zswap compresses memory, the fewer pages we
+        * evict to swap, since writing them back would incur IO while freeing
+        * only a small amount of memory.
+        */
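+       /*
+        * e.g. 1000 stored pages compressed down to 250 backing pages (4:1):
+        * only a quarter of the unprotected LRU is reported as freeable.
+        */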
+       return mult_frac(nr_freeable, nr_backing, nr_stored);
+}
+
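+/*
+ * The shrinker allocated here stays inert until zswap_pool_create() has set
+ * up the pool's list_lru and called shrinker_register(), so the count/scan
+ * callbacks never run against a half-constructed pool.
+ */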
+static void zswap_alloc_shrinker(struct zswap_pool *pool)
+{
+       pool->shrinker =
+               shrinker_alloc(SHRINKER_NUMA_AWARE | SHRINKER_MEMCG_AWARE, "mm-zswap");
+       if (!pool->shrinker)
+               return;
+
+       pool->shrinker->private_data = pool;
+       pool->shrinker->scan_objects = zswap_shrinker_scan;
+       pool->shrinker->count_objects = zswap_shrinker_count;
+       pool->shrinker->batch = 0;
+       pool->shrinker->seeks = DEFAULT_SEEKS;
+}
+
 /*********************************
 * per-cpu code
 **********************************/
                                       spinlock_t *lock, void *arg)
 {
        struct zswap_entry *entry = container_of(item, struct zswap_entry, lru);
+       bool *encountered_page_in_swapcache = (bool *)arg;
        struct zswap_tree *tree;
        pgoff_t swpoffset;
        enum lru_status ret = LRU_REMOVED_RETRY;
                zswap_reject_reclaim_fail++;
                zswap_lru_putback(&entry->pool->list_lru, entry);
                ret = LRU_RETRY;
+
+               /*
+                * Encountering a page already in swap cache is a sign that we
+                * are shrinking into the warmer region. We should terminate
+                * shrinking (if we're in the dynamic shrinker context).
+                */
+               if (writeback_result == -EEXIST && encountered_page_in_swapcache) {
+                       ret = LRU_SKIP;
+                       *encountered_page_in_swapcache = true;
+               }
+
                goto put_unlock;
        }
        zswap_written_back_pages++;
                                       &pool->node);
        if (ret)
                goto error;
+
+       zswap_alloc_shrinker(pool);
+       if (!pool->shrinker)
+               goto error;
+
        pr_debug("using %s compressor\n", pool->tfm_name);
 
        /* being the current pool takes 1 ref; this func expects the
         */
        kref_init(&pool->kref);
        INIT_LIST_HEAD(&pool->list);
-       list_lru_init_memcg(&pool->list_lru, NULL);
+       if (list_lru_init_memcg(&pool->list_lru, pool->shrinker))
+               goto lru_fail;
+       shrinker_register(pool->shrinker);
        INIT_WORK(&pool->shrink_work, shrink_worker);
+       atomic_set(&pool->nr_stored, 0);
 
        zswap_pool_debug("created", pool);
 
        return pool;
 
+lru_fail:
+       list_lru_destroy(&pool->list_lru);
+       shrinker_free(pool->shrinker);
 error:
        if (pool->acomp_ctx)
                free_percpu(pool->acomp_ctx);
 
        zswap_pool_debug("destroying", pool);
 
+       shrinker_free(pool->shrinker);
        cpuhp_state_remove_instance(CPUHP_MM_ZSWP_POOL_PREPARE, &pool->node);
        free_percpu(pool->acomp_ctx);
        list_lru_destroy(&pool->list_lru);
        if (entry->length) {
                INIT_LIST_HEAD(&entry->lru);
                zswap_lru_add(&entry->pool->list_lru, entry);
+               atomic_inc(&entry->pool->nr_stored);
        }
        spin_unlock(&tree->lock);