/**
- * wait_iff_congested - Conditionally wait for a backing_dev to become uncongested or a pgdat to complete writes
+ * wait_iff_congested - Conditionally wait for a backing_dev to become uncongested or a write to complete
- * @pgdat: A pgdat to check if it is heavily congested
  * @sync: SYNC or ASYNC IO
  * @timeout: timeout in jiffies
  *
- * In the event of a congested backing_dev (any backing_dev) and the given
- * @pgdat has experienced recent congestion, this waits for up to @timeout
- * jiffies for either a BDI to exit congestion of the given @sync queue
- * or a write to complete.
- *
- * In the absence of pgdat congestion, cond_resched() is called to yield
- * the processor if necessary but otherwise does not sleep.
+ * In the event of a congested backing_dev (any backing_dev), this waits
+ * for up to @timeout jiffies for either a BDI to exit congestion of the
+ * given @sync queue or a write to complete.
  *
  * The return value is 0 if the sleep is for the full timeout. Otherwise,
  * it is the number of jiffies that were still remaining when the function
  * returned. return_value == timeout implies the function did not sleep.
  */
-long wait_iff_congested(struct pglist_data *pgdat, int sync, long timeout)
+long wait_iff_congested(int sync, long timeout)
 {
        long ret;
        unsigned long start = jiffies;
        wait_queue_head_t *wqh = &congestion_wqh[sync];
 
        /*
-        * If there is no congestion, or heavy congestion is not being
-        * encountered in the current pgdat, yield if necessary instead
+        * If there is no congestion, yield if necessary instead
         * of sleeping on the congestion queue
         */
-       if (atomic_read(&nr_wb_congested[sync]) == 0 ||
-           !test_bit(PGDAT_CONGESTED, &pgdat->flags)) {
+       if (atomic_read(&nr_wb_congested[sync]) == 0) {
                cond_resched();
 
                /* In case we scheduled, work out time remaining */
 
 #endif
        return false;
 }
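
For illustration only, a minimal sketch of how a caller might read the return value of the new two-argument wait_iff_congested(); the helper below is hypothetical and not part of this patch, and only wait_iff_congested(), BLK_RW_ASYNC and HZ come from the kernel.

/* Hypothetical caller, sketched to illustrate the documented return value
 * semantics: 0 means the full timeout was slept, a value equal to the
 * timeout means no sleep happened, anything in between means the task was
 * woken early by a BDI clearing congestion or a write completing.
 */
static void example_throttle_on_congestion(void)
{
	long timeout = HZ / 10;
	long remaining = wait_iff_congested(BLK_RW_ASYNC, timeout);

	if (remaining == timeout) {
		/* no congestion was seen; at most cond_resched() ran */
	} else if (remaining == 0) {
		/* slept for the full timeout without being woken */
	} else {
		/* woken early; 'remaining' jiffies of the timeout were left */
	}
}
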
+
+static void set_memcg_congestion(pg_data_t *pgdat,
+                               struct mem_cgroup *memcg,
+                               bool congested)
+{
+       struct mem_cgroup_per_node *mn;
+
+       if (!memcg)
+               return;
+
+       mn = mem_cgroup_nodeinfo(memcg, pgdat->node_id);
+       WRITE_ONCE(mn->congested, congested);
+}
+
+static bool memcg_congested(pg_data_t *pgdat,
+                       struct mem_cgroup *memcg)
+{
+       struct mem_cgroup_per_node *mn;
+
+       mn = mem_cgroup_nodeinfo(memcg, pgdat->node_id);
+       return READ_ONCE(mn->congested);
+}
 #else
 static bool global_reclaim(struct scan_control *sc)
 {
        return true;
 }
+
+static inline void set_memcg_congestion(struct pglist_data *pgdat,
+                               struct mem_cgroup *memcg, bool congested)
+{
+}
+
+static inline bool memcg_congested(struct pglist_data *pgdat,
+                       struct mem_cgroup *memcg)
+{
+       return false;
+}
 #endif
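
The CONFIG_MEMCG helpers above dereference an mn->congested member, but the field itself is not visible in this hunk; the following is an assumption about its shape in struct mem_cgroup_per_node (include/linux/memcontrol.h), shown only so the helpers read in isolation.

/* Assumed shape of the flag used by set_memcg_congestion()/memcg_congested();
 * the real declaration lives in include/linux/memcontrol.h and is not part
 * of this hunk.
 */
struct mem_cgroup_per_node_sketch {
	/* ... existing per-node state (lruvec, reclaim iterators, ...) ... */
	bool	congested;	/* this memcg/node pair saw only dirty pages
				 * backed by a congested BDI during reclaim */
};
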
 
 /*
        return true;
 }
 
+static bool pgdat_memcg_congested(pg_data_t *pgdat, struct mem_cgroup *memcg)
+{
+       return test_bit(PGDAT_CONGESTED, &pgdat->flags) ||
+               (memcg && memcg_congested(pgdat, memcg));
+}
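
A short, hypothetical usage sketch of pgdat_memcg_congested(): with a NULL memcg (global reclaim) only the node-wide PGDAT_CONGESTED bit is consulted, while memcg reclaim also honours the per-memcg flag set above.

/* Hypothetical helper, not in the patch: shows which state feeds the
 * congestion decision for each reclaim type.
 */
static bool example_reclaim_congested(pg_data_t *pgdat, struct scan_control *sc)
{
	/* Global reclaim: sc->target_mem_cgroup is NULL, so only the
	 * PGDAT_CONGESTED node flag matters; memcg reclaim additionally
	 * checks the per-memcg, per-node flag via memcg_congested().
	 */
	return pgdat_memcg_congested(pgdat, sc->target_mem_cgroup);
}
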
+
 static bool shrink_node(pg_data_t *pgdat, struct scan_control *sc)
 {
        struct reclaim_state *reclaim_state = current->reclaim_state;
                if (sc->nr_reclaimed - nr_reclaimed)
                        reclaimable = true;
 
-               /*
-                * If reclaim is isolating dirty pages under writeback, it
-                * implies that the long-lived page allocation rate is exceeding
-                * the page laundering rate. Either the global limits are not
-                * being effective at throttling processes due to the page
-                * distribution throughout zones or there is heavy usage of a
-                * slow backing device. The only option is to throttle from
-                * reclaim context which is not ideal as there is no guarantee
-                * the dirtying process is throttled in the same way
-                * balance_dirty_pages() manages.
-                *
-                * Once a node is flagged PGDAT_WRITEBACK, kswapd will count the
-                * number of pages under pages flagged for immediate reclaim and
-                * stall if any are encountered in the nr_immediate check below.
-                */
-               if (sc->nr.writeback && sc->nr.writeback == sc->nr.taken)
-                       set_bit(PGDAT_WRITEBACK, &pgdat->flags);
+               if (current_is_kswapd()) {
+                       /*
+                        * If reclaim is isolating dirty pages under writeback,
+                        * it implies that the long-lived page allocation rate
+                        * is exceeding the page laundering rate. Either the
+                        * global limits are not being effective at throttling
+                        * processes due to the page distribution throughout
+                        * zones or there is heavy usage of a slow backing
+                        * device. The only option is to throttle from reclaim
+                        * context which is not ideal as there is no guarantee
+                        * the dirtying process is throttled in the same way
+                        * balance_dirty_pages() manages.
+                        *
+                        * Once a node is flagged PGDAT_WRITEBACK, kswapd will
+                        * count the number of pages under writeback that are
+                        * flagged for immediate reclaim and stall if any are
+                        * encountered in the nr_immediate check below.
+                        */
+                       if (sc->nr.writeback && sc->nr.writeback == sc->nr.taken)
+                               set_bit(PGDAT_WRITEBACK, &pgdat->flags);
 
-               /*
-                * Legacy memcg will stall in page writeback so avoid forcibly
-                * stalling here.
-                */
-               if (sane_reclaim(sc)) {
                        /*
                         * Tag a node as congested if all the dirty pages
                         * scanned were backed by a congested BDI and
                                congestion_wait(BLK_RW_ASYNC, HZ/10);
                }
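
As a worked reading of the PGDAT_WRITEBACK condition above (the numbers are made up, not from the patch): if every page isolated in this round was already under writeback, reclaim is cycling pages faster than the device can clean them.

/* Worked example with made-up numbers, mirroring
 * "sc->nr.writeback && sc->nr.writeback == sc->nr.taken" above.
 */
static bool example_all_taken_under_writeback(void)
{
	unsigned int nr_taken = 32;	/* pages isolated from the LRU */
	unsigned int nr_writeback = 32;	/* of those, already under writeback */

	/* true here => the node would be flagged PGDAT_WRITEBACK */
	return nr_writeback && nr_writeback == nr_taken;
}
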
 
+               /*
+                * Legacy memcg will stall in page writeback so avoid forcibly
+                * stalling in wait_iff_congested().
+                */
+               if (!global_reclaim(sc) && sane_reclaim(sc) &&
+                   sc->nr.dirty && sc->nr.dirty == sc->nr.congested)
+                       set_memcg_congestion(pgdat, root, true);
+
                /*
                 * Stall direct reclaim for IO completions if underlying BDIs
                 * and node is congested. Allow kswapd to continue until it
                 * starts encountering unqueued dirty pages or cycling through
                 * the LRU too quickly.
                 */
                if (!sc->hibernation_mode && !current_is_kswapd() &&
-                   current_may_throttle())
-                       wait_iff_congested(pgdat, BLK_RW_ASYNC, HZ/10);
+                  current_may_throttle() && pgdat_memcg_congested(pgdat, root))
+                       wait_iff_congested(BLK_RW_ASYNC, HZ/10);
 
        } while (should_continue_reclaim(pgdat, sc->nr_reclaimed - nr_reclaimed,
                                         sc->nr_scanned - nr_scanned, sc));
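
To summarize the throttling decision the loop above now makes for direct reclaim, here is a hypothetical condensation (not in the patch); it simply restates the gate added to the wait_iff_congested() call.

/* Hypothetical summary of the direct-reclaim stall decision after this
 * change; kswapd never reaches wait_iff_congested() here and instead
 * throttles through the current_is_kswapd() branch above.
 */
static bool example_should_stall_direct_reclaim(pg_data_t *pgdat,
						struct scan_control *sc)
{
	if (sc->hibernation_mode || current_is_kswapd())
		return false;
	if (!current_may_throttle())
		return false;
	/* only stall when this node, or this memcg on this node, is congested */
	return pgdat_memcg_congested(pgdat, sc->target_mem_cgroup);
}
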
                        continue;
                last_pgdat = zone->zone_pgdat;
                snapshot_refaults(sc->target_mem_cgroup, zone->zone_pgdat);
+               set_memcg_congestion(last_pgdat, sc->target_mem_cgroup, false);
        }
 
        delayacct_freepages_end();
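
As a design note on the clearing added above: the per-memcg congestion hint is only meant to live for one reclaim pass, so it is dropped once the pass finishes. A rough lifecycle sketch follows; the names mirror the patch, but the flow itself is an interpretation.

/*
 * Rough lifecycle of the per-memcg congestion flag (interpretation, not
 * taken verbatim from the patch):
 *
 *   shrink_node()
 *     set_memcg_congestion(pgdat, root, true)    all scanned dirty pages
 *                                                sat on a congested BDI
 *     pgdat_memcg_congested(pgdat, root)         gates wait_iff_congested()
 *                                                for direct reclaim only
 *   do_try_to_free_pages()
 *     set_memcg_congestion(pgdat, memcg, false)  flag dropped when the
 *                                                reclaim pass completes
 */
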