WRITE_ONCE(stock->cached_objcg, objcg);
 }
 
-static void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
-                    enum node_stat_item idx, int nr)
+static void __account_obj_stock(struct obj_cgroup *objcg,
+                               struct memcg_stock_pcp *stock, int nr,
+                               struct pglist_data *pgdat, enum node_stat_item idx)
 {
-       struct memcg_stock_pcp *stock;
-       unsigned long flags;
        int *bytes;
 
-       local_lock_irqsave(&memcg_stock.stock_lock, flags);
-       stock = this_cpu_ptr(&memcg_stock);
-
        /*
         * Save vmstat data in stock and skip vmstat array update unless
-        * accumulating over a page of vmstat data or when pgdat or idx
-        * changes.
+        * accumulating over a page of vmstat data or when pgdat changes.
         */
-       if (READ_ONCE(stock->cached_objcg) != objcg) {
-               replace_stock_objcg(stock, objcg);
-               stock->cached_pgdat = pgdat;
-       } else if (stock->cached_pgdat != pgdat) {
+       if (stock->cached_pgdat != pgdat) {
                /* Flush the existing cached vmstat data */
                struct pglist_data *oldpg = stock->cached_pgdat;
 
        }
        if (nr)
                __mod_objcg_mlstate(objcg, pgdat, idx, nr);
-
-       local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
 }
 
-static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
+static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes,
+                             struct pglist_data *pgdat, enum node_stat_item idx)
 {
        struct memcg_stock_pcp *stock;
        unsigned long flags;
        if (objcg == READ_ONCE(stock->cached_objcg) && stock->nr_bytes >= nr_bytes) {
                stock->nr_bytes -= nr_bytes;
                ret = true;
+
+               if (pgdat)
+                       __account_obj_stock(objcg, stock, nr_bytes, pgdat, idx);
        }
 
        local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
 }
 
 static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes,
-                            bool allow_uncharge)
+               bool allow_uncharge, int nr_acct, struct pglist_data *pgdat,
+               enum node_stat_item idx)
 {
        struct memcg_stock_pcp *stock;
        unsigned long flags;
        }
        stock->nr_bytes += nr_bytes;
 
+       if (pgdat)
+               __account_obj_stock(objcg, stock, nr_acct, pgdat, idx);
+
        if (allow_uncharge && (stock->nr_bytes > PAGE_SIZE)) {
                nr_pages = stock->nr_bytes >> PAGE_SHIFT;
                stock->nr_bytes &= (PAGE_SIZE - 1);
                obj_cgroup_uncharge_pages(objcg, nr_pages);
 }
 
-int obj_cgroup_charge(struct obj_cgroup *objcg, gfp_t gfp, size_t size)
+static int obj_cgroup_charge_account(struct obj_cgroup *objcg, gfp_t gfp, size_t size,
+                                    struct pglist_data *pgdat, enum node_stat_item idx)
 {
        unsigned int nr_pages, nr_bytes;
        int ret;
 
-       if (consume_obj_stock(objcg, size))
+       if (likely(consume_obj_stock(objcg, size, pgdat, idx)))
                return 0;
 
        /*
                nr_pages += 1;
 
        ret = obj_cgroup_charge_pages(objcg, gfp, nr_pages);
-       if (!ret && nr_bytes)
-               refill_obj_stock(objcg, PAGE_SIZE - nr_bytes, false);
+       if (!ret && (nr_bytes || pgdat))
+               refill_obj_stock(objcg, nr_bytes ? PAGE_SIZE - nr_bytes : 0,
+                                        false, size, pgdat, idx);
 
        return ret;
 }
 
+int obj_cgroup_charge(struct obj_cgroup *objcg, gfp_t gfp, size_t size)
+{
+       return obj_cgroup_charge_account(objcg, gfp, size, NULL, 0);
+}
+
 void obj_cgroup_uncharge(struct obj_cgroup *objcg, size_t size)
 {
-       refill_obj_stock(objcg, size, true);
+       refill_obj_stock(objcg, size, true, 0, NULL, 0);
 }
 
 static inline size_t obj_full_size(struct kmem_cache *s)
                        return false;
        }
 
-       if (obj_cgroup_charge(objcg, flags, size * obj_full_size(s)))
-               return false;
-
        for (i = 0; i < size; i++) {
                slab = virt_to_slab(p[i]);
 
                if (!slab_obj_exts(slab) &&
                    alloc_slab_obj_exts(slab, s, flags, false)) {
-                       obj_cgroup_uncharge(objcg, obj_full_size(s));
                        continue;
                }
 
+               /*
+                * if we fail and size is 1, memcg_alloc_abort_single() will
+                * just free the object, which is ok as we have not assigned
+                * objcg to its obj_ext yet
+                *
+                * for larger sizes, kmem_cache_free_bulk() will uncharge
+                * any objects that were already charged and obj_ext assigned
+                *
+                * TODO: we could batch this until slab_pgdat(slab) changes
+                * between iterations, with a more complicated undo
+                */
+               if (obj_cgroup_charge_account(objcg, flags, obj_full_size(s),
+                                       slab_pgdat(slab), cache_vmstat_idx(s)))
+                       return false;
+
                off = obj_to_index(s, slab, p[i]);
                obj_cgroup_get(objcg);
                slab_obj_exts(slab)[off].objcg = objcg;
-               mod_objcg_state(objcg, slab_pgdat(slab),
-                               cache_vmstat_idx(s), obj_full_size(s));
        }
 
        return true;
 void __memcg_slab_free_hook(struct kmem_cache *s, struct slab *slab,
                            void **p, int objects, struct slabobj_ext *obj_exts)
 {
+       size_t obj_size = obj_full_size(s);
+
        for (int i = 0; i < objects; i++) {
                struct obj_cgroup *objcg;
                unsigned int off;
                        continue;
 
                obj_exts[off].objcg = NULL;
-               obj_cgroup_uncharge(objcg, obj_full_size(s));
-               mod_objcg_state(objcg, slab_pgdat(slab), cache_vmstat_idx(s),
-                               -obj_full_size(s));
+               refill_obj_stock(objcg, obj_size, true, -obj_size,
+                                slab_pgdat(slab), cache_vmstat_idx(s));
                obj_cgroup_put(objcg);
        }
 }