return !cgroup_subsys_enabled(memory_cgrp_subsys);
 }
 
-static inline void mem_cgroup_protection(struct mem_cgroup *memcg,
-                                        unsigned long *min, unsigned long *low)
+static inline unsigned long mem_cgroup_protection(struct mem_cgroup *memcg,
+                                                 bool in_low_reclaim)
 {
-       if (mem_cgroup_disabled()) {
-               *min = 0;
-               *low = 0;
-               return;
-       }
+       if (mem_cgroup_disabled())
+               return 0;
+
+       if (in_low_reclaim)
+               return READ_ONCE(memcg->memory.emin);
 
-       *min = READ_ONCE(memcg->memory.emin);
-       *low = READ_ONCE(memcg->memory.elow);
+       return max(READ_ONCE(memcg->memory.emin),
+                  READ_ONCE(memcg->memory.elow));
 }
 
 enum mem_cgroup_protection mem_cgroup_protected(struct mem_cgroup *root,
 {
 }
 
-static inline void mem_cgroup_protection(struct mem_cgroup *memcg,
-                                        unsigned long *min, unsigned long *low)
+static inline unsigned long mem_cgroup_protection(struct mem_cgroup *memcg,
+                                                 bool in_low_reclaim)
 {
-       *min = 0;
-       *low = 0;
+       return 0;
 }
 
 static inline enum mem_cgroup_protection mem_cgroup_protected(
 
                int file = is_file_lru(lru);
                unsigned long lruvec_size;
                unsigned long scan;
-               unsigned long min, low;
+               unsigned long protection;
 
                lruvec_size = lruvec_lru_size(lruvec, lru, sc->reclaim_idx);
-               mem_cgroup_protection(memcg, &min, &low);
+               protection = mem_cgroup_protection(memcg,
+                                                  sc->memcg_low_reclaim);
 
-               if (min || low) {
+               if (protection) {
                        /*
                         * Scale a cgroup's reclaim pressure by proportioning
                         * its current usage to its memory.low or memory.min
                         * setting extremely liberal protection thresholds. It
                         * also means we simply get no protection at all if we
                         * set it too low, which is not ideal.
-                        */
-                       unsigned long cgroup_size = mem_cgroup_size(memcg);
-
-                       /*
-                        * If there is any protection in place, we adjust scan
-                        * pressure in proportion to how much a group's current
-                        * usage exceeds that, in percent.
+                        *
+                        * If there is any protection in place, we reduce scan
+                        * pressure by how much of the total memory used is
+                        * within protection thresholds.
                         *
                         * There is one special case: in the first reclaim pass,
                         * we skip over all groups that are within their low
                         * ideally want to honor how well-behaved groups are in
                         * that case instead of simply punishing them all
                         * equally. As such, we reclaim them based on how much
-                        * of their best-effort protection they are using. Usage
-                        * below memory.min is excluded from consideration when
-                        * calculating utilisation, as it isn't ever
-                        * reclaimable, so it might as well not exist for our
-                        * purposes.
+                        * memory they are using, reducing the scan pressure
+                        * again by how much of the total memory used is under
+                        * hard protection.
                         */
-                       if (sc->memcg_low_reclaim && low > min) {
-                               /*
-                                * Reclaim according to utilisation between min
-                                * and low
-                                */
-                               scan = lruvec_size * (cgroup_size - min) /
-                                       (low - min);
-                       } else {
-                               /* Reclaim according to protection overage */
-                               scan = lruvec_size * cgroup_size /
-                                       max(min, low) - lruvec_size;
-                       }
+                       unsigned long cgroup_size = mem_cgroup_size(memcg);
+
+                       /* Avoid TOCTOU with earlier protection check */
+                       cgroup_size = max(cgroup_size, protection);
+
+                       scan = lruvec_size - lruvec_size * protection /
+                               cgroup_size;
 
                        /*
-                        * Don't allow the scan target to exceed the lruvec
-                        * size, which otherwise could happen if we have >200%
-                        * overage in the normal case, or >100% overage when
-                        * sc->memcg_low_reclaim is set.
-                        *
-                        * This is important because other cgroups without
-                        * memory.low have their scan target initially set to
-                        * their lruvec size, so allowing values >100% of the
-                        * lruvec size here could result in penalising cgroups
-                        * with memory.low set even *more* than their peers in
-                        * some cases in the case of large overages.
-                        *
-                        * Also, minimally target SWAP_CLUSTER_MAX pages to keep
+                        * Minimally target SWAP_CLUSTER_MAX pages to keep
                         * reclaim moving forwards, avoiding decremeting
                         * sc->priority further than desirable.
                         */
-                       scan = clamp(scan, SWAP_CLUSTER_MAX, lruvec_size);
+                       scan = max(scan, SWAP_CLUSTER_MAX);
                } else {
                        scan = lruvec_size;
                }