}
        return &lru->node[nid].lru;
 }
+
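+/*
+ * The caller must hold the lru lock of this node. If the XArray entry for
+ * this memcg is gone, the memcg is dying and its lists were already
+ * spliced into an ancestor's, so retry with the parent.
+ */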
+static inline struct list_lru_one *
+list_lru_from_memcg(struct list_lru *lru, int nid, struct mem_cgroup *memcg)
+{
+       struct list_lru_one *l;
+again:
+       l = list_lru_from_memcg_idx(lru, nid, memcg_kmem_id(memcg));
+       if (likely(l))
+               return l;
+
+       VM_WARN_ON(!css_is_dying(&memcg->css));
+       memcg = parent_mem_cgroup(memcg);
+       goto again;
+}
 #else
 static void list_lru_register(struct list_lru *lru)
 {
 }

 static inline struct list_lru_one *
 list_lru_from_memcg_idx(struct list_lru *lru, int nid, int idx)
 {
        return &lru->node[nid].lru;
 }
+
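+/* With !CONFIG_MEMCG there is a single list per node; the memcg is ignored. */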
+static inline struct list_lru_one *
+list_lru_from_memcg(struct list_lru *lru, int nid, struct mem_cgroup *memcg)
+{
+       return &lru->node[nid].lru;
+}
 #endif /* CONFIG_MEMCG */
 
 /* The caller must ensure the memcg lifetime. */
 
        spin_lock(&nlru->lock);
        if (list_empty(item)) {
-               l = list_lru_from_memcg_idx(lru, nid, memcg_kmem_id(memcg));
+               l = list_lru_from_memcg(lru, nid, memcg);
                list_add_tail(item, &l->list);
                /* Set shrinker bit if the first element was added */
                if (!l->nr_items++)
 
        spin_lock(&nlru->lock);
        if (!list_empty(item)) {
-               l = list_lru_from_memcg_idx(lru, nid, memcg_kmem_id(memcg));
+               l = list_lru_from_memcg(lru, nid, memcg);
                list_del_init(item);
                l->nr_items--;
                nlru->nr_items--;
        return mlru;
 }
 
-static void memcg_list_lru_free(struct list_lru *lru, int src_idx)
-{
-       struct list_lru_memcg *mlru = xa_erase_irq(&lru->xa, src_idx);
-
-       /*
-        * The __list_lru_walk_one() can walk the list of this node.
-        * We need kvfree_rcu() here. And the walking of the list
-        * is under lru->node[nid]->lock, which can serve as a RCU
-        * read-side critical section.
-        */
-       if (mlru)
-               kvfree_rcu(mlru, rcu);
-}
-
 static inline void memcg_init_list_lru(struct list_lru *lru, bool memcg_aware)
 {
        if (memcg_aware)
 }
 
 static void memcg_reparent_list_lru_node(struct list_lru *lru, int nid,
-                                        int src_idx, struct mem_cgroup *dst_memcg)
+                                        struct list_lru_one *src,
+                                        struct mem_cgroup *dst_memcg)
 {
        struct list_lru_node *nlru = &lru->node[nid];
-       int dst_idx = dst_memcg->kmemcg_id;
-       struct list_lru_one *src, *dst;
+       struct list_lru_one *dst;
 
        /*
         * Since list_lru_{add,del} may be called under an IRQ-safe lock,
         * we have to use IRQ-safe primitives here to avoid deadlock.
         */
        spin_lock_irq(&nlru->lock);
-
-       src = list_lru_from_memcg_idx(lru, nid, src_idx);
-       if (!src)
-               goto out;
-       dst = list_lru_from_memcg_idx(lru, nid, dst_idx);
+       dst = list_lru_from_memcg_idx(lru, nid, memcg_kmem_id(dst_memcg));
 
        list_splice_init(&src->list, &dst->list);
 
                set_shrinker_bit(dst_memcg, nid, lru_shrinker_id(lru));
                src->nr_items = 0;
        }
-out:
        spin_unlock_irq(&nlru->lock);
 }
 
 void memcg_reparent_list_lrus(struct mem_cgroup *memcg, struct mem_cgroup *parent)
 {
-       struct cgroup_subsys_state *css;
        struct list_lru *lru;
-       int src_idx = memcg->kmemcg_id, i;
-
-       /*
-        * Change kmemcg_id of this cgroup and all its descendants to the
-        * parent's id, and then move all entries from this cgroup's list_lrus
-        * to ones of the parent.
-        */
-       rcu_read_lock();
-       css_for_each_descendant_pre(css, &memcg->css) {
-               struct mem_cgroup *child;
-
-               child = mem_cgroup_from_css(css);
-               WRITE_ONCE(child->kmemcg_id, parent->kmemcg_id);
-       }
-       rcu_read_unlock();
+       int i;
 
-       /*
-        * With kmemcg_id set to parent, holding the lock of each list_lru_node
-        * below can prevent list_lru_{add,del,isolate} from touching the lru,
-        * safe to reparent.
-        */
        mutex_lock(&list_lrus_mutex);
        list_for_each_entry(lru, &memcg_list_lrus, list) {
+               struct list_lru_memcg *mlru;
+               XA_STATE(xas, &lru->xa, memcg->kmemcg_id);
+
+               /*
+                * Lock the XArray to serialize against any ongoing
+                * list_lru_memcg allocation; every allocation that follows
+                * will see css_is_dying().
+                */
+               xas_lock_irq(&xas);
+               mlru = xas_store(&xas, NULL);
+               xas_unlock_irq(&xas);
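+               /* Nothing was ever allocated here, nothing to reparent. */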
+               if (!mlru)
+                       continue;
+
+               /*
+                * With the XArray entry cleared, holding the lru lock below
+                * prevents list_lru_{add,del,isolate} from touching the lru,
+                * so it is safe to reparent.
+                */
                for_each_node(i)
-                       memcg_reparent_list_lru_node(lru, i, src_idx, parent);
+                       memcg_reparent_list_lru_node(lru, i, &mlru->node[i], parent);
 
                /*
                 * Here all list_lrus corresponding to the cgroup are guaranteed
                 * to remain empty, so we can safely free this mlru; any further
                 * memcg_list_lru_alloc() call will simply bail out.
                 */
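+               /*
+                * __list_lru_walk_one() may still be walking this mlru under
+                * lru->node[nid]->lock, which serves as the RCU read-side
+                * critical section, hence kvfree_rcu().
+                */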
-               memcg_list_lru_free(lru, src_idx);
+               kvfree_rcu(mlru, rcu);
        }
        mutex_unlock(&list_lrus_mutex);
 }
 int memcg_list_lru_alloc(struct mem_cgroup *memcg, struct list_lru *lru,
                         gfp_t gfp)
 {
-       int i;
        unsigned long flags;
-       struct list_lru_memcg_table {
-               struct list_lru_memcg *mlru;
-               struct mem_cgroup *memcg;
-       } *table;
+       struct list_lru_memcg *mlru;
+       struct mem_cgroup *pos, *parent;
        XA_STATE(xas, &lru->xa, 0);
 
        if (!list_lru_memcg_aware(lru) || memcg_list_lru_allocated(memcg, lru))
                return 0;
 
        gfp &= GFP_RECLAIM_MASK;
-       table = kmalloc_array(memcg->css.cgroup->level, sizeof(*table), gfp);
-       if (!table)
-               return -ENOMEM;
-
        /*
         * Because the list_lru can be reparented to the parent cgroup's
         * list_lru, we should make sure that this cgroup and all its
         * ancestors have allocated list_lru_memcg.
         */
-       for (i = 0; memcg; memcg = parent_mem_cgroup(memcg), i++) {
-               if (memcg_list_lru_allocated(memcg, lru))
-                       break;
-
-               table[i].memcg = memcg;
-               table[i].mlru = memcg_init_list_lru_one(gfp);
-               if (!table[i].mlru) {
-                       while (i--)
-                               kfree(table[i].mlru);
-                       kfree(table);
-                       return -ENOMEM;
+       do {
+               /*
+                * Find the farthest ancestor that has not been populated
+                * yet; the outer loop repeats, moving down one level each
+                * round, until memcg itself has been populated.
+                */
+               pos = memcg;
+               parent = parent_mem_cgroup(pos);
+               while (!memcg_list_lru_allocated(parent, lru)) {
+                       pos = parent;
+                       parent = parent_mem_cgroup(pos);
                }
-       }
-
-       xas_lock_irqsave(&xas, flags);
-       while (i--) {
-               int index = READ_ONCE(table[i].memcg->kmemcg_id);
-               struct list_lru_memcg *mlru = table[i].mlru;
 
-               xas_set(&xas, index);
-retry:
-               if (unlikely(index < 0 || xas_error(&xas) || xas_load(&xas))) {
-                       kfree(mlru);
-               } else {
-                       xas_store(&xas, mlru);
-                       if (xas_error(&xas) == -ENOMEM) {
-                               xas_unlock_irqrestore(&xas, flags);
-                               if (xas_nomem(&xas, gfp))
-                                       xas_set_err(&xas, 0);
-                               xas_lock_irqsave(&xas, flags);
-                               /*
-                                * The xas lock has been released, this memcg
-                                * can be reparented before us. So reload
-                                * memcg id. More details see the comments
-                                * in memcg_reparent_list_lrus().
-                                */
-                               index = READ_ONCE(table[i].memcg->kmemcg_id);
-                               if (index < 0)
-                                       xas_set_err(&xas, 0);
-                               else if (!xas_error(&xas) && index != xas.xa_index)
-                                       xas_set(&xas, index);
-                               goto retry;
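+
+               /*
+                * pos is now the highest memcg on the path that still lacks
+                * a list_lru_memcg, so a parent is always populated before
+                * any of its children.
+                */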
+               mlru = memcg_init_list_lru_one(gfp);
+               if (!mlru)
+                       return -ENOMEM;
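+               /*
+                * xas_nomem() preallocates the memory a failed store needs
+                * and returns true to retry. On a successful store, mlru is
+                * set to NULL: ownership has passed to the XArray.
+                */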
+               xas_set(&xas, pos->kmemcg_id);
+               do {
+                       xas_lock_irqsave(&xas, flags);
+                       if (!xas_load(&xas) && !css_is_dying(&pos->css)) {
+                               xas_store(&xas, mlru);
+                               if (!xas_error(&xas))
+                                       mlru = NULL;
                        }
-               }
-       }
-       /* xas_nomem() is used to free memory instead of memory allocation. */
-       if (xas.xa_alloc)
-               xas_nomem(&xas, gfp);
-       xas_unlock_irqrestore(&xas, flags);
-       kfree(table);
+                       xas_unlock_irqrestore(&xas, flags);
+               } while (xas_nomem(&xas, gfp));
+               if (mlru)
+                       kfree(mlru);
+       } while (pos != memcg && !css_is_dying(&pos->css));
 
        return xas_error(&xas);
 }