unsigned int                    stats_updates;
 
        /* Cached pointers for fast iteration in memcg_rstat_updated() */
-       struct memcg_vmstats_percpu     *parent;
-       struct memcg_vmstats            *vmstats;
+       struct memcg_vmstats_percpu __percpu    *parent_pcpu;
+       struct memcg_vmstats                    *vmstats;
 
        /* The above should fit a single cacheline for memcg_rstat_updated() */
 
 
 static inline void memcg_rstat_updated(struct mem_cgroup *memcg, int val)
 {
+       struct memcg_vmstats_percpu __percpu *statc_pcpu;
        struct memcg_vmstats_percpu *statc;
-       int cpu = smp_processor_id();
+       int cpu;
        unsigned int stats_updates;
 
        if (!val)
                return;
 
+       /* Don't assume callers have preemption disabled. */
+       cpu = get_cpu();
+
        cgroup_rstat_updated(memcg->css.cgroup, cpu);
-       statc = this_cpu_ptr(memcg->vmstats_percpu);
-       for (; statc; statc = statc->parent) {
+       statc_pcpu = memcg->vmstats_percpu;
+       for (; statc_pcpu; statc_pcpu = statc->parent_pcpu) {
+               statc = this_cpu_ptr(statc_pcpu);
                /*
                 * If @memcg is already flushable then all its ancestors are
                 * flushable as well and also there is no need to increase
                if (memcg_vmstats_needs_flush(statc->vmstats))
                        break;
 
-               stats_updates = READ_ONCE(statc->stats_updates) + abs(val);
-               WRITE_ONCE(statc->stats_updates, stats_updates);
+               stats_updates = this_cpu_add_return(statc_pcpu->stats_updates,
+                                                   abs(val));
                if (stats_updates < MEMCG_CHARGE_BATCH)
                        continue;
 
+               stats_updates = this_cpu_xchg(statc_pcpu->stats_updates, 0);
                atomic64_add(stats_updates, &statc->vmstats->stats_updates);
-               WRITE_ONCE(statc->stats_updates, 0);
        }
+       put_cpu();
 }
 
 static void __mem_cgroup_flush_stats(struct mem_cgroup *memcg, bool force)
 
 static struct mem_cgroup *mem_cgroup_alloc(struct mem_cgroup *parent)
 {
-       struct memcg_vmstats_percpu *statc, *pstatc;
+       struct memcg_vmstats_percpu *statc;
+       struct memcg_vmstats_percpu __percpu *pstatc_pcpu;
        struct mem_cgroup *memcg;
        int node, cpu;
        int __maybe_unused i;
 
        for_each_possible_cpu(cpu) {
                if (parent)
-                       pstatc = per_cpu_ptr(parent->vmstats_percpu, cpu);
+                       pstatc_pcpu = parent->vmstats_percpu;
                statc = per_cpu_ptr(memcg->vmstats_percpu, cpu);
-               statc->parent = parent ? pstatc : NULL;
+               statc->parent_pcpu = parent ? pstatc_pcpu : NULL;
                statc->vmstats = memcg->vmstats;
        }