struct list_lru_one {
        struct list_head        list;
-       /* kept as signed so we can catch imbalance bugs */
+       /* may become negative during memcg reparenting */
        long                    nr_items;
 };
 
 #define list_lru_init_memcg(lru)       __list_lru_init((lru), true, NULL)
 
 int memcg_update_all_list_lrus(int num_memcgs);
+void memcg_drain_all_list_lrus(int src_idx, int dst_idx);
 
 /**
  * list_lru_add: add an element to the lru list's tail
 
 
        spin_lock(&nlru->lock);
        l = list_lru_from_kmem(nlru, item);
-       WARN_ON_ONCE(l->nr_items < 0);
        if (list_empty(item)) {
                list_add_tail(item, &l->list);
                l->nr_items++;
        if (!list_empty(item)) {
                list_del_init(item);
                l->nr_items--;
-               WARN_ON_ONCE(l->nr_items < 0);
                spin_unlock(&nlru->lock);
                return true;
        }
 
        spin_lock(&nlru->lock);
        l = list_lru_from_memcg_idx(nlru, memcg_idx);
-       WARN_ON_ONCE(l->nr_items < 0);
        count = l->nr_items;
        spin_unlock(&nlru->lock);
 
                memcg_cancel_update_list_lru(lru, old_size, new_size);
        goto out;
 }
+
+static void memcg_drain_list_lru_node(struct list_lru_node *nlru,
+                                     int src_idx, int dst_idx)
+{
+       struct list_lru_one *src, *dst;
+
+       /*
+        * Since list_lru_{add,del} may be called under an IRQ-safe lock,
+        * we have to use IRQ-safe primitives here to avoid deadlock.
+        */
+       spin_lock_irq(&nlru->lock);
+
+       src = list_lru_from_memcg_idx(nlru, src_idx);
+       dst = list_lru_from_memcg_idx(nlru, dst_idx);
+
+       list_splice_init(&src->list, &dst->list);
+       dst->nr_items += src->nr_items;
+       src->nr_items = 0;
+
+       spin_unlock_irq(&nlru->lock);
+}
+
+static void memcg_drain_list_lru(struct list_lru *lru,
+                                int src_idx, int dst_idx)
+{
+       int i;
+
+       if (!list_lru_memcg_aware(lru))
+               return;
+
+       for (i = 0; i < nr_node_ids; i++)
+               memcg_drain_list_lru_node(&lru->node[i], src_idx, dst_idx);
+}
+
+void memcg_drain_all_list_lrus(int src_idx, int dst_idx)
+{
+       struct list_lru *lru;
+
+       mutex_lock(&list_lrus_mutex);
+       list_for_each_entry(lru, &list_lrus, list)
+               memcg_drain_list_lru(lru, src_idx, dst_idx);
+       mutex_unlock(&list_lrus_mutex);
+}
 #else
 static int memcg_init_list_lru(struct list_lru *lru, bool memcg_aware)
 {
 
 #if defined(CONFIG_MEMCG_KMEM)
         /* Index in the kmem_cache->memcg_params.memcg_caches array */
        int kmemcg_id;
+       bool kmem_acct_activated;
        bool kmem_acct_active;
 #endif
 
 struct static_key memcg_kmem_enabled_key;
 EXPORT_SYMBOL(memcg_kmem_enabled_key);
 
-static void memcg_free_cache_id(int id);
-
 static void disarm_kmem_keys(struct mem_cgroup *memcg)
 {
-       if (memcg->kmemcg_id >= 0) {
+       if (memcg->kmem_acct_activated)
                static_key_slow_dec(&memcg_kmem_enabled_key);
-               memcg_free_cache_id(memcg->kmemcg_id);
-       }
        /*
         * This check can't live in kmem destruction function,
         * since the charges will outlive the cgroup
        int memcg_id;
 
        BUG_ON(memcg->kmemcg_id >= 0);
+       BUG_ON(memcg->kmem_acct_activated);
        BUG_ON(memcg->kmem_acct_active);
 
        /*
         * patched.
         */
        memcg->kmemcg_id = memcg_id;
+       memcg->kmem_acct_activated = true;
        memcg->kmem_acct_active = true;
 out:
        return err;
 
 static void memcg_deactivate_kmem(struct mem_cgroup *memcg)
 {
+       struct cgroup_subsys_state *css;
+       struct mem_cgroup *parent, *child;
+       int kmemcg_id;
+
        if (!memcg->kmem_acct_active)
                return;
 
        memcg->kmem_acct_active = false;
 
        memcg_deactivate_kmem_caches(memcg);
+
+       kmemcg_id = memcg->kmemcg_id;
+       BUG_ON(kmemcg_id < 0);
+
+       parent = parent_mem_cgroup(memcg);
+       if (!parent)
+               parent = root_mem_cgroup;
+
+       /*
+        * Change kmemcg_id of this cgroup and all its descendants to the
+        * parent's id, and then move all entries from this cgroup's list_lrus
+        * to ones of the parent. After we have finished, all list_lrus
+        * corresponding to this cgroup are guaranteed to remain empty. The
+        * ordering is imposed by list_lru_node->lock taken by
+        * memcg_drain_all_list_lrus().
+        */
+       css_for_each_descendant_pre(css, &memcg->css) {
+               child = mem_cgroup_from_css(css);
+               BUG_ON(child->kmemcg_id != kmemcg_id);
+               child->kmemcg_id = parent->kmemcg_id;
+               if (!memcg->use_hierarchy)
+                       break;
+       }
+       memcg_drain_all_list_lrus(kmemcg_id, parent->kmemcg_id);
+
+       memcg_free_cache_id(kmemcg_id);
 }
 
 static void memcg_destroy_kmem(struct mem_cgroup *memcg)