struct kmem_cache *
 __memcg_kmem_get_cache(struct kmem_cache *cachep, gfp_t gfp);
 
+int memcg_charge_kmem(struct mem_cgroup *memcg, gfp_t gfp, u64 size);
+void memcg_uncharge_kmem(struct mem_cgroup *memcg, u64 size);
+
 void mem_cgroup_destroy_cache(struct kmem_cache *cachep);
 int __kmem_cache_destroy_memcg_children(struct kmem_cache *s);
 
  * @cachep: the original global kmem cache
  * @gfp: allocation flags.
  *
- * This function assumes that the task allocating, which determines the memcg
- * in the page allocator, belongs to the same cgroup throughout the whole
- * process.  Misacounting can happen if the task calls memcg_kmem_get_cache()
- * while belonging to a cgroup, and later on changes. This is considered
- * acceptable, and should only happen upon task migration.
- *
- * Before the cache is created by the memcg core, there is also a possible
- * imbalance: the task belongs to a memcg, but the cache being allocated from
- * is the global cache, since the child cache is not yet guaranteed to be
- * ready. This case is also fine, since in this case the GFP_KMEMCG will not be
- * passed and the page allocator will not attempt any cgroup accounting.
+ * All memory allocated from a per-memcg cache is charged to the owner memcg.
  */
 static __always_inline struct kmem_cache *
 memcg_kmem_get_cache(struct kmem_cache *cachep, gfp_t gfp)
 
 }
 #endif
 
-static int memcg_charge_kmem(struct mem_cgroup *memcg, gfp_t gfp, u64 size)
+int memcg_charge_kmem(struct mem_cgroup *memcg, gfp_t gfp, u64 size)
 {
        struct res_counter *fail_res;
        int ret = 0;
        return ret;
 }
 
-static void memcg_uncharge_kmem(struct mem_cgroup *memcg, u64 size)
+void memcg_uncharge_kmem(struct mem_cgroup *memcg, u64 size)
 {
        res_counter_uncharge(&memcg->res, size);
        if (do_swap_account)
 
        if (cachep->flags & SLAB_RECLAIM_ACCOUNT)
                flags |= __GFP_RECLAIMABLE;
 
+       if (memcg_charge_slab(cachep, flags, cachep->gfporder))
+               return NULL;
+
        page = alloc_pages_exact_node(nodeid, flags | __GFP_NOTRACK, cachep->gfporder);
        if (!page) {
+               memcg_uncharge_slab(cachep, cachep->gfporder);
                slab_out_of_memory(cachep, flags, nodeid);
                return NULL;
        }
        memcg_release_pages(cachep, cachep->gfporder);
        if (current->reclaim_state)
                current->reclaim_state->reclaimed_slab += nr_freed;
-       __free_memcg_kmem_pages(page, cachep->gfporder);
+       __free_pages(page, cachep->gfporder);
+       memcg_uncharge_slab(cachep, cachep->gfporder);
 }
 
 static void kmem_rcu_free(struct rcu_head *head)
 
                return s;
        return s->memcg_params->root_cache;
 }
+
+static __always_inline int memcg_charge_slab(struct kmem_cache *s,
+                                            gfp_t gfp, int order)
+{
+       if (!memcg_kmem_enabled())
+               return 0;
+       if (is_root_cache(s))
+               return 0;
+       return memcg_charge_kmem(s->memcg_params->memcg, gfp,
+                                PAGE_SIZE << order);
+}
+
+static __always_inline void memcg_uncharge_slab(struct kmem_cache *s, int order)
+{
+       if (!memcg_kmem_enabled())
+               return;
+       if (is_root_cache(s))
+               return;
+       memcg_uncharge_kmem(s->memcg_params->memcg, PAGE_SIZE << order);
+}
 #else
 static inline bool is_root_cache(struct kmem_cache *s)
 {
 {
        return s;
 }
+
+static inline int memcg_charge_slab(struct kmem_cache *s, gfp_t gfp, int order)
+{
+       return 0;
+}
+
+static inline void memcg_uncharge_slab(struct kmem_cache *s, int order)
+{
+}
 #endif
 
 static inline struct kmem_cache *cache_from_obj(struct kmem_cache *s, void *x)
 
                                 root_cache->size, root_cache->align,
                                 root_cache->flags, root_cache->ctor,
                                 memcg, root_cache);
-       if (IS_ERR(s)) {
+       if (IS_ERR(s))
                kfree(cache_name);
-               goto out_unlock;
-       }
-
-       s->allocflags |= __GFP_KMEMCG;
 
 out_unlock:
        mutex_unlock(&slab_mutex);
 
 /*
  * Slab allocation and freeing
  */
-static inline struct page *alloc_slab_page(gfp_t flags, int node,
-                                       struct kmem_cache_order_objects oo)
+static inline struct page *alloc_slab_page(struct kmem_cache *s,
+               gfp_t flags, int node, struct kmem_cache_order_objects oo)
 {
+       struct page *page;
        int order = oo_order(oo);
 
        flags |= __GFP_NOTRACK;
 
+       if (memcg_charge_slab(s, flags, order))
+               return NULL;
+
        if (node == NUMA_NO_NODE)
-               return alloc_pages(flags, order);
+               page = alloc_pages(flags, order);
        else
-               return alloc_pages_exact_node(node, flags, order);
+               page = alloc_pages_exact_node(node, flags, order);
+
+       if (!page)
+               memcg_uncharge_slab(s, order);
+
+       return page;
 }
 
 static struct page *allocate_slab(struct kmem_cache *s, gfp_t flags, int node)
         */
        alloc_gfp = (flags | __GFP_NOWARN | __GFP_NORETRY) & ~__GFP_NOFAIL;
 
-       page = alloc_slab_page(alloc_gfp, node, oo);
+       page = alloc_slab_page(s, alloc_gfp, node, oo);
        if (unlikely(!page)) {
                oo = s->min;
                alloc_gfp = flags;
                 * Allocation may have failed due to fragmentation.
                 * Try a lower order alloc if possible
                 */
-               page = alloc_slab_page(alloc_gfp, node, oo);
+               page = alloc_slab_page(s, alloc_gfp, node, oo);
 
                if (page)
                        stat(s, ORDER_FALLBACK);
        page_mapcount_reset(page);
        if (current->reclaim_state)
                current->reclaim_state->reclaimed_slab += pages;
-       __free_memcg_kmem_pages(page, order);
+       __free_pages(page, order);
+       memcg_uncharge_slab(s, order);
 }
 
 #define need_reserve_slab_rcu                                          \