if (cgroup_memory_nokmem)
                return 0;
 
-       BUG_ON(memcg->kmemcg_id >= 0);
+       if (unlikely(mem_cgroup_is_root(memcg)))
+               return 0;
 
        memcg_id = memcg_alloc_cache_id();
        if (memcg_id < 0)
        struct mem_cgroup *parent;
        int kmemcg_id;
 
-       if (memcg->kmemcg_id == -1)
+       if (cgroup_memory_nokmem)
+               return;
+
+       if (unlikely(mem_cgroup_is_root(memcg)))
                return;
 
        parent = parent_mem_cgroup(memcg);
        memcg_reparent_objcgs(memcg, parent);
 
        kmemcg_id = memcg->kmemcg_id;
-       BUG_ON(kmemcg_id < 0);
 
        /*
         * After we have finished memcg_reparent_objcgs(), all list_lrus
        memcg_drain_all_list_lrus(kmemcg_id, parent);
 
        memcg_free_cache_id(kmemcg_id);
-       memcg->kmemcg_id = -1;
 }
 #else
 static int memcg_online_kmem(struct mem_cgroup *memcg)
 {
        struct mem_cgroup *parent = mem_cgroup_from_css(parent_css);
        struct mem_cgroup *memcg, *old_memcg;
-       long error = -ENOMEM;
 
        old_memcg = set_active_memcg(parent);
        memcg = mem_cgroup_alloc();
                return &memcg->css;
        }
 
-       /* The following stuff does not apply to the root */
-       error = memcg_online_kmem(memcg);
-       if (error)
-               goto fail;
-
        if (cgroup_subsys_on_dfl(memory_cgrp_subsys) && !cgroup_memory_nosocket)
                static_branch_inc(&memcg_sockets_enabled_key);
 
        return &memcg->css;
-fail:
-       mem_cgroup_id_remove(memcg);
-       mem_cgroup_free(memcg);
-       return ERR_PTR(error);
 }
 
 static int mem_cgroup_css_online(struct cgroup_subsys_state *css)
 {
        struct mem_cgroup *memcg = mem_cgroup_from_css(css);
 
+       if (memcg_online_kmem(memcg))
+               goto remove_id;
+
        /*
         * A memcg must be visible for expand_shrinker_info()
         * by the time the maps are allocated. So, we allocate maps
         * here, when for_each_mem_cgroup() can't skip it.
         */
-       if (alloc_shrinker_info(memcg)) {
-               mem_cgroup_id_remove(memcg);
-               return -ENOMEM;
-       }
+       if (alloc_shrinker_info(memcg))
+               goto offline_kmem;
 
        /* Online state pins memcg ID, memcg ID pins CSS */
        refcount_set(&memcg->id.ref, 1);
                queue_delayed_work(system_unbound_wq, &stats_flush_dwork,
                                   2UL*HZ);
        return 0;
+offline_kmem:
+       memcg_offline_kmem(memcg);
+remove_id:
+       mem_cgroup_id_remove(memcg);
+       return -ENOMEM;
 }
 
 static void mem_cgroup_css_offline(struct cgroup_subsys_state *css)
        cancel_work_sync(&memcg->high_work);
        mem_cgroup_remove_from_trees(memcg);
        free_shrinker_info(memcg);
-
-       /* Need to offline kmem if online_css() fails */
-       memcg_offline_kmem(memcg);
        mem_cgroup_free(memcg);
 }