#ifdef CONFIG_MEMCG
 
 static int memcg_shrinker_map_size;
-static DEFINE_MUTEX(memcg_shrinker_map_mutex);
 
 static void free_shrinker_map_rcu(struct rcu_head *head)
 {
        struct mem_cgroup_per_node *pn;
        int nid;
 
-       lockdep_assert_held(&memcg_shrinker_map_mutex);
-
        for_each_node(nid) {
                pn = memcg->nodeinfo[nid];
                old = rcu_dereference_protected(pn->shrinker_map, true);
        if (mem_cgroup_is_root(memcg))
                return 0;
 
-       mutex_lock(&memcg_shrinker_map_mutex);
+       down_write(&shrinker_rwsem);
        size = memcg_shrinker_map_size;
        for_each_node(nid) {
                map = kvzalloc_node(sizeof(*map) + size, GFP_KERNEL, nid);
                }
                rcu_assign_pointer(memcg->nodeinfo[nid]->shrinker_map, map);
        }
-       mutex_unlock(&memcg_shrinker_map_mutex);
+       up_write(&shrinker_rwsem);
 
        return ret;
 }
        if (size <= old_size)
                return 0;
 
-       mutex_lock(&memcg_shrinker_map_mutex);
        if (!root_mem_cgroup)
-               goto unlock;
+               goto out;
+
+       lockdep_assert_held(&shrinker_rwsem);
 
        memcg = mem_cgroup_iter(NULL, NULL, NULL);
        do {
                ret = expand_one_shrinker_map(memcg, size, old_size);
                if (ret) {
                        mem_cgroup_iter_break(NULL, memcg);
-                       goto unlock;
+                       goto out;
                }
        } while ((memcg = mem_cgroup_iter(NULL, memcg, NULL)) != NULL);
-unlock:
+out:
        if (!ret)
                memcg_shrinker_map_size = size;
-       mutex_unlock(&memcg_shrinker_map_mutex);
+
        return ret;
 }