}
 }
 
-static void __mem_cgroup_flush_stats(void)
+static void do_flush_stats(bool atomic)
 {
        /*
         * We always flush the entire tree, so concurrent flushers can just
                return;
 
        WRITE_ONCE(flush_next_time, jiffies_64 + 2*FLUSH_TIME);
-       cgroup_rstat_flush_atomic(root_mem_cgroup->css.cgroup);
+
+       if (atomic)
+               cgroup_rstat_flush_atomic(root_mem_cgroup->css.cgroup);
+       else
+               cgroup_rstat_flush(root_mem_cgroup->css.cgroup);
+
        atomic_set(&stats_flush_threshold, 0);
        atomic_set(&stats_flush_ongoing, 0);
 }
 
+static bool should_flush_stats(void)
+{
+       return atomic_read(&stats_flush_threshold) > num_online_cpus();
+}
+
 void mem_cgroup_flush_stats(void)
 {
-       if (atomic_read(&stats_flush_threshold) > num_online_cpus())
-               __mem_cgroup_flush_stats();
+       if (should_flush_stats())
+               do_flush_stats(false);
 }
 
-void mem_cgroup_flush_stats_ratelimited(void)
+void mem_cgroup_flush_stats_atomic(void)
+{
+       if (should_flush_stats())
+               do_flush_stats(true);
+}
+
+void mem_cgroup_flush_stats_atomic_ratelimited(void)
 {
        if (time_after64(jiffies_64, READ_ONCE(flush_next_time)))
-               mem_cgroup_flush_stats();
+               mem_cgroup_flush_stats_atomic();
 }
 
 static void flush_memcg_stats_dwork(struct work_struct *w)
 {
-       __mem_cgroup_flush_stats();
+       /*
+        * Always flush here so that flushing in latency-sensitive paths is
+        * as cheap as possible.
+        */
+       do_flush_stats(false);
        queue_delayed_work(system_unbound_wq, &stats_flush_dwork, FLUSH_TIME);
 }
 
                 * done from irq context; use stale stats in this case.
                 * Arguably, usage threshold events are not reliable on the root
                 * memcg anyway since its usage is ill-defined.
+                *
+                * Additionally, other call paths through memcg_check_events()
+                * disable irqs, so make sure we are flushing stats atomically.
                 */
                if (in_task())
-                       mem_cgroup_flush_stats();
+                       mem_cgroup_flush_stats_atomic();
                val = memcg_page_state(memcg, NR_FILE_PAGES) +
                        memcg_page_state(memcg, NR_ANON_MAPPED);
                if (swap)
        struct mem_cgroup *memcg = mem_cgroup_from_css(wb->memcg_css);
        struct mem_cgroup *parent;
 
-       mem_cgroup_flush_stats();
+       /*
+        * wb_writeback() takes a spinlock and calls
+        * wb_over_bg_thresh()->mem_cgroup_wb_stats(). Do not sleep.
+        */
+       mem_cgroup_flush_stats_atomic();
 
        *pdirty = memcg_page_state(memcg, NR_FILE_DIRTY);
        *pwriteback = memcg_page_state(memcg, NR_WRITEBACK);