}
 
 static int bpf_cgroup_storages_alloc(struct bpf_cgroup_storage *storages[],
-                                    struct bpf_prog *prog)
+                                    struct bpf_cgroup_storage *new_storages[],
+                                    enum bpf_attach_type type,
+                                    struct bpf_prog *prog,
+                                    struct cgroup *cgrp)
 {
        enum bpf_cgroup_storage_type stype;
+       struct bpf_cgroup_storage_key key;
+       struct bpf_map *map;
+
+       key.cgroup_inode_id = cgroup_id(cgrp);
+       key.attach_type = type;
 
        for_each_cgroup_storage_type(stype) {
+               map = prog->aux->cgroup_storage[stype];
+               if (!map)
+                       continue;
+
+               storages[stype] = cgroup_storage_lookup((void *)map, &key, false);
+               if (storages[stype])
+                       continue;
+
                storages[stype] = bpf_cgroup_storage_alloc(prog, stype);
                if (IS_ERR(storages[stype])) {
-                       storages[stype] = NULL;
-                       bpf_cgroup_storages_free(storages);
+                       bpf_cgroup_storages_free(new_storages);
                        return -ENOMEM;
                }
+
+               new_storages[stype] = storages[stype];
        }
 
        return 0;
 }
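
Note: with this change, storage allocation is idempotent per (cgroup, attach_type) pair. The function first looks up an existing element and only allocates when none is found, recording fresh allocations in new_storages so the error path frees only what this attachment created. From userspace the sharing is visible through the ordinary map syscalls; a minimal sketch, assuming a libbpf environment, an attach_type-isolated map whose fd is map_fd, and an example attach point:

/* Userspace sketch (not part of this patch): two programs attached at
 * the same (cgroup, attach_type) now observe a single element. */
#include <bpf/bpf.h>
#include <linux/bpf.h>

static int read_storage(int map_fd, __u64 cgroup_id, void *value)
{
	struct bpf_cgroup_storage_key key = {
		.cgroup_inode_id = cgroup_id,
		.attach_type     = BPF_CGROUP_INET_EGRESS, /* example attach point */
	};

	return bpf_map_lookup_elem(map_fd, &key, value);
}
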
 
 static void bpf_cgroup_storages_link(struct bpf_cgroup_storage *storages[],
-                                    struct cgroup* cgrp,
+                                    struct cgroup *cgrp,
                                     enum bpf_attach_type attach_type)
 {
        enum bpf_cgroup_storage_type stype;
                bpf_cgroup_storage_link(storages[stype], cgrp, attach_type);
 }
 
-static void bpf_cgroup_storages_unlink(struct bpf_cgroup_storage *storages[])
-{
-       enum bpf_cgroup_storage_type stype;
-
-       for_each_cgroup_storage_type(stype)
-               bpf_cgroup_storage_unlink(storages[stype]);
-}
-
 /* Called when bpf_cgroup_link is auto-detached from dying cgroup.
  * It drops cgroup and bpf_prog refcounts, and marks bpf_link as defunct. It
  * doesn't free link memory, which will eventually be done by bpf_link's
        struct cgroup *p, *cgrp = container_of(work, struct cgroup,
                                               bpf.release_work);
        struct bpf_prog_array *old_array;
+       struct list_head *storages = &cgrp->bpf.storages;
+       struct bpf_cgroup_storage *storage, *stmp;
+
        unsigned int type;
 
        mutex_lock(&cgroup_mutex);
 
        for (type = 0; type < ARRAY_SIZE(cgrp->bpf.progs); type++) {
                struct list_head *progs = &cgrp->bpf.progs[type];
-               struct bpf_prog_list *pl, *tmp;
+               struct bpf_prog_list *pl, *pltmp;
 
-               list_for_each_entry_safe(pl, tmp, progs, node) {
+               list_for_each_entry_safe(pl, pltmp, progs, node) {
                        list_del(&pl->node);
                        if (pl->prog)
                                bpf_prog_put(pl->prog);
                        if (pl->link)
                                bpf_cgroup_link_auto_detach(pl->link);
-                       bpf_cgroup_storages_unlink(pl->storage);
-                       bpf_cgroup_storages_free(pl->storage);
                        kfree(pl);
                        static_branch_dec(&cgroup_bpf_enabled_key);
                }
                bpf_prog_array_free(old_array);
        }
 
+       list_for_each_entry_safe(storage, stmp, storages, list_cg) {
+               bpf_cgroup_storage_unlink(storage);
+               bpf_cgroup_storage_free(storage);
+       }
+
        mutex_unlock(&cgroup_mutex);
 
        for (p = cgroup_parent(cgrp); p; p = cgroup_parent(p))
        for (i = 0; i < NR; i++)
                INIT_LIST_HEAD(&cgrp->bpf.progs[i]);
 
+       INIT_LIST_HEAD(&cgrp->bpf.storages);
+
        for (i = 0; i < NR; i++)
                if (compute_effective_progs(cgrp, i, &arrays[i]))
                        goto cleanup;
        struct list_head *progs = &cgrp->bpf.progs[type];
        struct bpf_prog *old_prog = NULL;
        struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE] = {};
-       struct bpf_cgroup_storage *old_storage[MAX_BPF_CGROUP_STORAGE_TYPE] = {};
+       struct bpf_cgroup_storage *new_storage[MAX_BPF_CGROUP_STORAGE_TYPE] = {};
        struct bpf_prog_list *pl;
        int err;
 
        if (IS_ERR(pl))
                return PTR_ERR(pl);
 
-       if (bpf_cgroup_storages_alloc(storage, prog ? : link->link.prog))
+       if (bpf_cgroup_storages_alloc(storage, new_storage, type,
+                                     prog ? : link->link.prog, cgrp))
                return -ENOMEM;
 
        if (pl) {
                old_prog = pl->prog;
-               bpf_cgroup_storages_unlink(pl->storage);
-               bpf_cgroup_storages_assign(old_storage, pl->storage);
        } else {
                pl = kmalloc(sizeof(*pl), GFP_KERNEL);
                if (!pl) {
-                       bpf_cgroup_storages_free(storage);
+                       bpf_cgroup_storages_free(new_storage);
                        return -ENOMEM;
                }
                list_add_tail(&pl->node, progs);
        if (err)
                goto cleanup;
 
-       bpf_cgroup_storages_free(old_storage);
        if (old_prog)
                bpf_prog_put(old_prog);
        else
                static_branch_inc(&cgroup_bpf_enabled_key);
-       bpf_cgroup_storages_link(pl->storage, cgrp, type);
+       bpf_cgroup_storages_link(new_storage, cgrp, type);
        return 0;
 
 cleanup:
                pl->prog = old_prog;
                pl->link = NULL;
        }
-       bpf_cgroup_storages_free(pl->storage);
-       bpf_cgroup_storages_assign(pl->storage, old_storage);
-       bpf_cgroup_storages_link(pl->storage, cgrp, type);
+       bpf_cgroup_storages_free(new_storage);
        if (!old_prog) {
                list_del(&pl->node);
                kfree(pl);
 
        /* now can actually delete it from this cgroup list */
        list_del(&pl->node);
-       bpf_cgroup_storages_unlink(pl->storage);
-       bpf_cgroup_storages_free(pl->storage);
        kfree(pl);
        if (list_empty(progs))
                /* last program was detached, reset flags to zero */
 
 #include <linux/slab.h>
 #include <uapi/linux/btf.h>
 
+#include "../cgroup/cgroup-internal.h"
+
 DEFINE_PER_CPU(struct bpf_cgroup_storage*, bpf_cgroup_storage[MAX_BPF_CGROUP_STORAGE_TYPE]);
 
 #ifdef CONFIG_CGROUP_BPF
        struct bpf_map map;
 
        spinlock_t lock;
-       struct bpf_prog_aux *aux;
        struct rb_root root;
        struct list_head list;
 };
        return container_of(map, struct bpf_cgroup_storage_map, map);
 }
 
-static int bpf_cgroup_storage_key_cmp(
-       const struct bpf_cgroup_storage_key *key1,
-       const struct bpf_cgroup_storage_key *key2)
+static bool attach_type_isolated(const struct bpf_map *map)
 {
-       if (key1->cgroup_inode_id < key2->cgroup_inode_id)
-               return -1;
-       else if (key1->cgroup_inode_id > key2->cgroup_inode_id)
-               return 1;
-       else if (key1->attach_type < key2->attach_type)
-               return -1;
-       else if (key1->attach_type > key2->attach_type)
-               return 1;
+       return map->key_size == sizeof(struct bpf_cgroup_storage_key);
+}
+
+static int bpf_cgroup_storage_key_cmp(const struct bpf_cgroup_storage_map *map,
+                                     const void *_key1, const void *_key2)
+{
+       if (attach_type_isolated(&map->map)) {
+               const struct bpf_cgroup_storage_key *key1 = _key1;
+               const struct bpf_cgroup_storage_key *key2 = _key2;
+
+               if (key1->cgroup_inode_id < key2->cgroup_inode_id)
+                       return -1;
+               else if (key1->cgroup_inode_id > key2->cgroup_inode_id)
+                       return 1;
+               else if (key1->attach_type < key2->attach_type)
+                       return -1;
+               else if (key1->attach_type > key2->attach_type)
+                       return 1;
+       } else {
+               const __u64 *cgroup_inode_id1 = _key1;
+               const __u64 *cgroup_inode_id2 = _key2;
+
+               if (*cgroup_inode_id1 < *cgroup_inode_id2)
+                       return -1;
+               else if (*cgroup_inode_id1 > *cgroup_inode_id2)
+                       return 1;
+       }
        return 0;
 }
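
The comparator now serves both key layouts, and attach_type_isolated() selects between them from key_size alone, which the map-allocation path below constrains to sizeof(struct bpf_cgroup_storage_key) or sizeof(__u64). A creation sketch, assuming a recent libbpf that provides bpf_map_create(); the map name and value type are placeholders:

/* Userspace sketch: a u64 key selects the shared layout. Cgroup storage
 * maps must be created with max_entries == 0, since elements are only
 * ever created by the kernel at attach time. */
#include <bpf/bpf.h>

static int create_shared_storage(void)
{
	return bpf_map_create(BPF_MAP_TYPE_CGROUP_STORAGE, "cg_shared",
			      sizeof(__u64),	/* key: cgroup_inode_id only */
			      sizeof(__u64),	/* value: one counter (placeholder) */
			      0,		/* max_entries must be 0 */
			      NULL);
}
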
 
-static struct bpf_cgroup_storage *cgroup_storage_lookup(
-       struct bpf_cgroup_storage_map *map, struct bpf_cgroup_storage_key *key,
-       bool locked)
+struct bpf_cgroup_storage *
+cgroup_storage_lookup(struct bpf_cgroup_storage_map *map,
+                     void *key, bool locked)
 {
        struct rb_root *root = &map->root;
        struct rb_node *node;
 
                storage = container_of(node, struct bpf_cgroup_storage, node);
 
-               switch (bpf_cgroup_storage_key_cmp(key, &storage->key)) {
+               switch (bpf_cgroup_storage_key_cmp(map, key, &storage->key)) {
                case -1:
                        node = node->rb_left;
                        break;
                this = container_of(*new, struct bpf_cgroup_storage, node);
 
                parent = *new;
-               switch (bpf_cgroup_storage_key_cmp(&storage->key, &this->key)) {
+               switch (bpf_cgroup_storage_key_cmp(map, &storage->key, &this->key)) {
                case -1:
                        new = &((*new)->rb_left);
                        break;
        return 0;
 }
 
-static void *cgroup_storage_lookup_elem(struct bpf_map *_map, void *_key)
+static void *cgroup_storage_lookup_elem(struct bpf_map *_map, void *key)
 {
        struct bpf_cgroup_storage_map *map = map_to_storage(_map);
-       struct bpf_cgroup_storage_key *key = _key;
        struct bpf_cgroup_storage *storage;
 
        storage = cgroup_storage_lookup(map, key, false);
        return &READ_ONCE(storage->buf)->data[0];
 }
 
-static int cgroup_storage_update_elem(struct bpf_map *map, void *_key,
+static int cgroup_storage_update_elem(struct bpf_map *map, void *key,
                                      void *value, u64 flags)
 {
-       struct bpf_cgroup_storage_key *key = _key;
        struct bpf_cgroup_storage *storage;
        struct bpf_storage_buffer *new;
 
-       if (unlikely(flags & ~(BPF_F_LOCK | BPF_EXIST | BPF_NOEXIST)))
-               return -EINVAL;
-
-       if (unlikely(flags & BPF_NOEXIST))
+       if (unlikely(flags & ~(BPF_F_LOCK | BPF_EXIST)))
                return -EINVAL;
 
        if (unlikely((flags & BPF_F_LOCK) &&
        return 0;
 }
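
Dropping BPF_NOEXIST from the accepted flags follows from the element lifecycle: storage comes into existence only when a program is attached, so a userspace update can never create an element. A hedged sketch of the surviving update path (helper name is illustrative):

/* Userspace sketch: flags of 0 or BPF_EXIST are accepted; BPF_NOEXIST
 * now fails with -EINVAL, and updating a key with no attached program
 * fails with -ENOENT. */
#include <bpf/bpf.h>

static int reset_counter(int map_fd, const void *key)
{
	__u64 val = 0;

	return bpf_map_update_elem(map_fd, key, &val, BPF_EXIST);
}
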
 
-int bpf_percpu_cgroup_storage_copy(struct bpf_map *_map, void *_key,
+int bpf_percpu_cgroup_storage_copy(struct bpf_map *_map, void *key,
                                   void *value)
 {
        struct bpf_cgroup_storage_map *map = map_to_storage(_map);
-       struct bpf_cgroup_storage_key *key = _key;
        struct bpf_cgroup_storage *storage;
        int cpu, off = 0;
        u32 size;
        return 0;
 }
 
-int bpf_percpu_cgroup_storage_update(struct bpf_map *_map, void *_key,
+int bpf_percpu_cgroup_storage_update(struct bpf_map *_map, void *key,
                                     void *value, u64 map_flags)
 {
        struct bpf_cgroup_storage_map *map = map_to_storage(_map);
-       struct bpf_cgroup_storage_key *key = _key;
        struct bpf_cgroup_storage *storage;
        int cpu, off = 0;
        u32 size;
        return 0;
 }
 
-static int cgroup_storage_get_next_key(struct bpf_map *_map, void *_key,
+static int cgroup_storage_get_next_key(struct bpf_map *_map, void *key,
                                       void *_next_key)
 {
        struct bpf_cgroup_storage_map *map = map_to_storage(_map);
-       struct bpf_cgroup_storage_key *key = _key;
-       struct bpf_cgroup_storage_key *next = _next_key;
        struct bpf_cgroup_storage *storage;
 
        spin_lock_bh(&map->lock);
                if (!storage)
                        goto enoent;
 
-               storage = list_next_entry(storage, list);
+               storage = list_next_entry(storage, list_map);
                if (!storage)
                        goto enoent;
        } else {
                storage = list_first_entry(&map->list,
-                                        struct bpf_cgroup_storage, list);
+                                        struct bpf_cgroup_storage, list_map);
        }
 
        spin_unlock_bh(&map->lock);
-       next->attach_type = storage->key.attach_type;
-       next->cgroup_inode_id = storage->key.cgroup_inode_id;
+
+       if (attach_type_isolated(&map->map)) {
+               struct bpf_cgroup_storage_key *next = _next_key;
+               *next = storage->key;
+       } else {
+               __u64 *next = _next_key;
+               *next = storage->key.cgroup_inode_id;
+       }
        return 0;
 
 enoent:
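
get_next_key keeps the usual iteration contract for both layouts. A sketch of walking a shared (u64-keyed) map from userspace; for an attach_type-isolated map the key variables would be struct bpf_cgroup_storage_key instead:

/* Userspace sketch: enumerate all cgroups that currently hold storage. */
#include <stdio.h>
#include <bpf/bpf.h>

static void dump_map(int map_fd)
{
	__u64 key, next;

	for (int err = bpf_map_get_next_key(map_fd, NULL, &next); !err;
	     err = bpf_map_get_next_key(map_fd, &key, &next)) {
		__u64 val;

		key = next;
		if (!bpf_map_lookup_elem(map_fd, &key, &val))
			printf("cgroup %llu: %llu\n",
			       (unsigned long long)key,
			       (unsigned long long)val);
	}
}
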
        struct bpf_map_memory mem;
        int ret;
 
-       if (attr->key_size != sizeof(struct bpf_cgroup_storage_key))
+       if (attr->key_size != sizeof(struct bpf_cgroup_storage_key) &&
+           attr->key_size != sizeof(__u64))
                return ERR_PTR(-EINVAL);
 
        if (attr->value_size == 0)
 static void cgroup_storage_map_free(struct bpf_map *_map)
 {
        struct bpf_cgroup_storage_map *map = map_to_storage(_map);
+       struct list_head *storages = &map->list;
+       struct bpf_cgroup_storage *storage, *stmp;
+
+       mutex_lock(&cgroup_mutex);
+
+       list_for_each_entry_safe(storage, stmp, storages, list_map) {
+               bpf_cgroup_storage_unlink(storage);
+               bpf_cgroup_storage_free(storage);
+       }
+
+       mutex_unlock(&cgroup_mutex);
 
        WARN_ON(!RB_EMPTY_ROOT(&map->root));
        WARN_ON(!list_empty(&map->list));
                                    const struct btf_type *key_type,
                                    const struct btf_type *value_type)
 {
-       struct btf_member *m;
-       u32 offset, size;
-
-       /* Key is expected to be of struct bpf_cgroup_storage_key type,
-        * which is:
-        * struct bpf_cgroup_storage_key {
-        *      __u64   cgroup_inode_id;
-        *      __u32   attach_type;
-        * };
-        */
+       if (attach_type_isolated(map)) {
+               struct btf_member *m;
+               u32 offset, size;
+
+               /* Key is expected to be of struct bpf_cgroup_storage_key type,
+                * which is:
+                * struct bpf_cgroup_storage_key {
+                *      __u64   cgroup_inode_id;
+                *      __u32   attach_type;
+                * };
+                */
+
+               /*
+                * The key_type must be a structure with two fields.
+                */
+               if (BTF_INFO_KIND(key_type->info) != BTF_KIND_STRUCT ||
+                   BTF_INFO_VLEN(key_type->info) != 2)
+                       return -EINVAL;
+
+               /*
+                * The first field must be a 64-bit integer at offset 0.
+                */
+               m = (struct btf_member *)(key_type + 1);
+               size = sizeof_field(struct bpf_cgroup_storage_key, cgroup_inode_id);
+               if (!btf_member_is_reg_int(btf, key_type, m, 0, size))
+                       return -EINVAL;
+
+               /*
+                * The second field must be a 32-bit integer at a 64-bit offset.
+                */
+               m++;
+               offset = offsetof(struct bpf_cgroup_storage_key, attach_type);
+               size = sizeof_field(struct bpf_cgroup_storage_key, attach_type);
+               if (!btf_member_is_reg_int(btf, key_type, m, offset, size))
+                       return -EINVAL;
+       } else {
+               u32 int_data;
 
-       /*
-        * Key_type must be a structure with two fields.
-        */
-       if (BTF_INFO_KIND(key_type->info) != BTF_KIND_STRUCT ||
-           BTF_INFO_VLEN(key_type->info) != 2)
-               return -EINVAL;
+               /*
+                * The key is expected to be a u64 that stores the cgroup_inode_id.
+                */
 
-       /*
-        * The first field must be a 64 bit integer at 0 offset.
-        */
-       m = (struct btf_member *)(key_type + 1);
-       size = sizeof_field(struct bpf_cgroup_storage_key, cgroup_inode_id);
-       if (!btf_member_is_reg_int(btf, key_type, m, 0, size))
-               return -EINVAL;
+               if (BTF_INFO_KIND(key_type->info) != BTF_KIND_INT)
+                       return -EINVAL;
 
-       /*
-        * The second field must be a 32 bit integer at 64 bit offset.
-        */
-       m++;
-       offset = offsetof(struct bpf_cgroup_storage_key, attach_type);
-       size = sizeof_field(struct bpf_cgroup_storage_key, attach_type);
-       if (!btf_member_is_reg_int(btf, key_type, m, offset, size))
-               return -EINVAL;
+               int_data = *(u32 *)(key_type + 1);
+               if (BTF_INT_BITS(int_data) != 64 || BTF_INT_OFFSET(int_data))
+                       return -EINVAL;
+       }
 
        return 0;
 }
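
The relaxed BTF check mirrors the two raw key sizes accepted at map creation. On the BPF program side, both shapes can be declared with BTF-style map definitions; a sketch modeled on the pattern the accompanying selftests use (struct cgroup_value is a placeholder value type):

/* BPF-side sketch of the two key declarations this check accepts. */
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

struct cgroup_value {
	__u64 egress_packets;
};

struct {
	__uint(type, BPF_MAP_TYPE_CGROUP_STORAGE);
	__type(key, struct bpf_cgroup_storage_key);	/* attach_type-isolated */
	__type(value, struct cgroup_value);
} isolated SEC(".maps");

struct {
	__uint(type, BPF_MAP_TYPE_CGROUP_STORAGE);
	__type(key, __u64);				/* shared across attach types */
	__type(value, struct cgroup_value);
} shared SEC(".maps");
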
 
-static void cgroup_storage_seq_show_elem(struct bpf_map *map, void *_key,
+static void cgroup_storage_seq_show_elem(struct bpf_map *map, void *key,
                                         struct seq_file *m)
 {
        enum bpf_cgroup_storage_type stype = cgroup_storage_type(map);
-       struct bpf_cgroup_storage_key *key = _key;
        struct bpf_cgroup_storage *storage;
        int cpu;
 
 int bpf_cgroup_storage_assign(struct bpf_prog_aux *aux, struct bpf_map *_map)
 {
        enum bpf_cgroup_storage_type stype = cgroup_storage_type(_map);
-       struct bpf_cgroup_storage_map *map = map_to_storage(_map);
-       int ret = -EBUSY;
-
-       spin_lock_bh(&map->lock);
 
-       if (map->aux && map->aux != aux)
-               goto unlock;
        if (aux->cgroup_storage[stype] &&
            aux->cgroup_storage[stype] != _map)
-               goto unlock;
+               return -EBUSY;
 
-       map->aux = aux;
        aux->cgroup_storage[stype] = _map;
-       ret = 0;
-unlock:
-       spin_unlock_bh(&map->lock);
-
-       return ret;
-}
-
-void bpf_cgroup_storage_release(struct bpf_prog_aux *aux, struct bpf_map *_map)
-{
-       enum bpf_cgroup_storage_type stype = cgroup_storage_type(_map);
-       struct bpf_cgroup_storage_map *map = map_to_storage(_map);
-
-       spin_lock_bh(&map->lock);
-       if (map->aux == aux) {
-               WARN_ON(aux->cgroup_storage[stype] != _map);
-               map->aux = NULL;
-               aux->cgroup_storage[stype] = NULL;
-       }
-       spin_unlock_bh(&map->lock);
+       return 0;
 }
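
With the map-to-aux binding gone, bpf_cgroup_storage_release() disappears entirely, and -EBUSY is left only for a single program that tries to use two different maps of the same storage type. The practical effect, as a hedged sketch (object path hypothetical): an object whose programs share one storage map now loads, where the second program previously failed at load time with -EBUSY:

#include <bpf/libbpf.h>

static int load_shared(void)
{
	struct bpf_object *obj;

	obj = bpf_object__open_file("cg_storage_multi.bpf.o", NULL);
	if (libbpf_get_error(obj))
		return -1;

	/* Both programs in this object reference one cgroup storage map;
	 * before this change the second program's load returned -EBUSY. */
	return bpf_object__load(obj);
}
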
 
 static size_t bpf_cgroup_storage_calculate_size(struct bpf_map *map, u32 *pages)
 
        spin_lock_bh(&map->lock);
        WARN_ON(cgroup_storage_insert(map, storage));
-       list_add(&storage->list, &map->list);
+       list_add(&storage->list_map, &map->list);
+       list_add(&storage->list_cg, &cgroup->bpf.storages);
        spin_unlock_bh(&map->lock);
 }
 
        root = &map->root;
        rb_erase(&storage->node, root);
 
-       list_del(&storage->list);
+       list_del(&storage->list_map);
+       list_del(&storage->list_cg);
        spin_unlock_bh(&map->lock);
 }
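
For reference, the fields these link/unlink helpers manipulate, as implied by this patch (the authoritative definition lives in include/linux/bpf-cgroup.h): each storage element is now threaded onto two lists in addition to the per-map rbtree.

/* Sketch of the relevant bpf_cgroup_storage members after this change;
 * unrelated members elided. */
struct bpf_cgroup_storage {
	/* ... buffer, owning map and key ... */
	struct rb_node node;		/* entry in bpf_cgroup_storage_map::root */
	struct list_head list_map;	/* entry in bpf_cgroup_storage_map::list */
	struct list_head list_cg;	/* entry in cgroup_bpf::storages */
	/* ... */
};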