 struct bpf_map * __must_check bpf_map_inc(struct bpf_map *map, bool uref);
 void bpf_map_put_with_uref(struct bpf_map *map);
 void bpf_map_put(struct bpf_map *map);
-int bpf_map_precharge_memlock(u32 pages);
 int bpf_map_charge_memlock(struct bpf_map *map, u32 pages);
 void bpf_map_uncharge_memlock(struct bpf_map *map, u32 pages);
+int bpf_map_charge_init(struct bpf_map_memory *mem, u32 pages);
+void bpf_map_charge_finish(struct bpf_map_memory *mem);
+void bpf_map_charge_move(struct bpf_map_memory *dst,
+                        struct bpf_map_memory *src);
 void *bpf_map_area_alloc(size_t size, int numa_node);
 void bpf_map_area_free(void *base);
 void bpf_map_init_from_attr(struct bpf_map *map, union bpf_attr *attr);
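
The three new helpers split memlock accounting from the map object: a charge is taken against the current user before anything is allocated, carried in a standalone struct bpf_map_memory, and handed to the map only once allocation has succeeded, so failure paths can drop it without touching a half-built map. A minimal sketch of the pattern the converted ->map_alloc callbacks follow is shown below; the map type, its layout and the cost formula are illustrative only, not part of this patch:

	/* illustrative map type, not part of this patch */
	struct example_map {
		struct bpf_map map;	/* must be the first member */
	};

	static struct bpf_map *example_map_alloc(union bpf_attr *attr)
	{
		struct bpf_map_memory mem;
		struct example_map *m;
		u64 cost;
		int err;

		/* compute the full footprint before charging anything */
		cost = sizeof(*m) + (u64)attr->max_entries * sizeof(void *);
		if (cost >= U32_MAX - PAGE_SIZE)
			return ERR_PTR(-E2BIG);

		/* charge the current user's RLIMIT_MEMLOCK before allocating */
		err = bpf_map_charge_init(&mem,
					  round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
		if (err)
			return ERR_PTR(err);

		m = bpf_map_area_alloc(cost, bpf_map_attr_numa_node(attr));
		if (!m) {
			/* drop the charge if the allocation fails */
			bpf_map_charge_finish(&mem);
			return ERR_PTR(-ENOMEM);
		}

		bpf_map_init_from_attr(&m->map, attr);
		/* transfer ownership of the charge to the map; mem is zeroed */
		bpf_map_charge_move(&m->map.memory, &mem);
		return &m->map;
	}
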
 
        u32 elem_size, index_mask, max_entries;
        bool unpriv = !capable(CAP_SYS_ADMIN);
        u64 cost, array_size, mask64;
+       struct bpf_map_memory mem;
        struct bpf_array *array;
 
        elem_size = round_up(attr->value_size, 8);
        }
        cost = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
 
-       ret = bpf_map_precharge_memlock(cost);
+       ret = bpf_map_charge_init(&mem, cost);
        if (ret < 0)
                return ERR_PTR(ret);
 
        /* allocate all map elements and zero-initialize them */
        array = bpf_map_area_alloc(array_size, numa_node);
-       if (!array)
+       if (!array) {
+               bpf_map_charge_finish(&mem);
                return ERR_PTR(-ENOMEM);
+       }
        array->index_mask = index_mask;
        array->map.unpriv_array = unpriv;
 
        /* copy mandatory map attributes */
        bpf_map_init_from_attr(&array->map, attr);
-       array->map.memory.pages = cost;
+       bpf_map_charge_move(&array->map.memory, &mem);
        array->elem_size = elem_size;
 
        if (percpu && bpf_array_alloc_percpu(array)) {
+               bpf_map_charge_finish(&array->map.memory);
                bpf_map_area_free(array);
                return ERR_PTR(-ENOMEM);
        }
 
        cost += cpu_map_bitmap_size(attr) * num_possible_cpus();
        if (cost >= U32_MAX - PAGE_SIZE)
                goto free_cmap;
-       cmap->map.memory.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
 
        /* Notice returns -EPERM on if map size is larger than memlock limit */
-       ret = bpf_map_precharge_memlock(cmap->map.memory.pages);
+       ret = bpf_map_charge_init(&cmap->map.memory,
+                                 round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
        if (ret) {
                err = ret;
                goto free_cmap;
        cmap->flush_needed = __alloc_percpu(cpu_map_bitmap_size(attr),
                                            __alignof__(unsigned long));
        if (!cmap->flush_needed)
-               goto free_cmap;
+               goto free_charge;
 
        /* Alloc array for possible remote "destination" CPUs */
        cmap->cpu_map = bpf_map_area_alloc(cmap->map.max_entries *
        return &cmap->map;
 free_percpu:
        free_percpu(cmap->flush_needed);
+free_charge:
+       bpf_map_charge_finish(&cmap->map.memory);
 free_cmap:
        kfree(cmap);
        return ERR_PTR(err);
 
        if (cost >= U32_MAX - PAGE_SIZE)
                goto free_dtab;
 
-       dtab->map.memory.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
-
-       /* if map size is larger than memlock limit, reject it early */
-       err = bpf_map_precharge_memlock(dtab->map.memory.pages);
+       /* if map size is larger than memlock limit, reject it */
+       err = bpf_map_charge_init(&dtab->map.memory,
+                                 round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
        if (err)
                goto free_dtab;
 
                                                __alignof__(unsigned long),
                                                GFP_KERNEL | __GFP_NOWARN);
        if (!dtab->flush_needed)
-               goto free_dtab;
+               goto free_charge;
 
        dtab->netdev_map = bpf_map_area_alloc(dtab->map.max_entries *
                                              sizeof(struct bpf_dtab_netdev *),
                                              dtab->map.numa_node);
        if (!dtab->netdev_map)
-               goto free_dtab;
+               goto free_charge;
 
        spin_lock(&dev_map_lock);
        list_add_tail_rcu(&dtab->list, &dev_map_list);
        spin_unlock(&dev_map_lock);
 
        return &dtab->map;
+free_charge:
+       bpf_map_charge_finish(&dtab->map.memory);
 free_dtab:
        free_percpu(dtab->flush_needed);
        kfree(dtab);
 
                /* make sure page count doesn't overflow */
                goto free_htab;
 
-       htab->map.memory.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
-
-       /* if map size is larger than memlock limit, reject it early */
-       err = bpf_map_precharge_memlock(htab->map.memory.pages);
+       /* if map size is larger than memlock limit, reject it */
+       err = bpf_map_charge_init(&htab->map.memory,
+                                 round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
        if (err)
                goto free_htab;
 
                                           sizeof(struct bucket),
                                           htab->map.numa_node);
        if (!htab->buckets)
-               goto free_htab;
+               goto free_charge;
 
        if (htab->map.map_flags & BPF_F_ZERO_SEED)
                htab->hashrnd = 0;
        prealloc_destroy(htab);
 free_buckets:
        bpf_map_area_free(htab->buckets);
+free_charge:
+       bpf_map_charge_finish(&htab->map.memory);
 free_htab:
        kfree(htab);
        return ERR_PTR(err);
 
 {
        int numa_node = bpf_map_attr_numa_node(attr);
        struct bpf_cgroup_storage_map *map;
+       struct bpf_map_memory mem;
        u32 pages;
        int ret;
 
 
        pages = round_up(sizeof(struct bpf_cgroup_storage_map), PAGE_SIZE) >>
                PAGE_SHIFT;
-       ret = bpf_map_precharge_memlock(pages);
+       ret = bpf_map_charge_init(&mem, pages);
        if (ret < 0)
                return ERR_PTR(ret);
 
        map = kmalloc_node(sizeof(struct bpf_cgroup_storage_map),
                           __GFP_ZERO | GFP_USER, numa_node);
-       if (!map)
+       if (!map) {
+               bpf_map_charge_finish(&mem);
                return ERR_PTR(-ENOMEM);
+       }
 
-       map->map.memory.pages = pages;
+       bpf_map_charge_move(&map->map.memory, &mem);
 
        /* copy mandatory map attributes */
        bpf_map_init_from_attr(&map->map, attr);
 
                goto out_err;
        }
 
-       trie->map.memory.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
-
-       ret = bpf_map_precharge_memlock(trie->map.memory.pages);
+       ret = bpf_map_charge_init(&trie->map.memory,
+                                 round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
        if (ret)
                goto out_err;
 
 
 static struct bpf_map *queue_stack_map_alloc(union bpf_attr *attr)
 {
        int ret, numa_node = bpf_map_attr_numa_node(attr);
+       struct bpf_map_memory mem = {0};
        struct bpf_queue_stack *qs;
        u64 size, queue_size, cost;
 
 
        cost = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
 
-       ret = bpf_map_precharge_memlock(cost);
+       ret = bpf_map_charge_init(&mem, cost);
        if (ret < 0)
                return ERR_PTR(ret);
 
        qs = bpf_map_area_alloc(queue_size, numa_node);
-       if (!qs)
+       if (!qs) {
+               bpf_map_charge_finish(&mem);
                return ERR_PTR(-ENOMEM);
+       }
 
        memset(qs, 0, sizeof(*qs));
 
        bpf_map_init_from_attr(&qs->map, attr);
 
-       qs->map.memory.pages = cost;
+       bpf_map_charge_move(&qs->map.memory, &mem);
        qs->size = size;
 
        raw_spin_lock_init(&qs->lock);
 
 {
        int err, numa_node = bpf_map_attr_numa_node(attr);
        struct reuseport_array *array;
+       struct bpf_map_memory mem;
        u64 cost, array_size;
 
        if (!capable(CAP_SYS_ADMIN))
                return ERR_PTR(-ENOMEM);
        cost = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
 
-       err = bpf_map_precharge_memlock(cost);
+       err = bpf_map_charge_init(&mem, cost);
        if (err)
                return ERR_PTR(err);
 
        /* allocate all map elements and zero-initialize them */
        array = bpf_map_area_alloc(array_size, numa_node);
-       if (!array)
+       if (!array) {
+               bpf_map_charge_finish(&mem);
                return ERR_PTR(-ENOMEM);
+       }
 
        /* copy mandatory map attributes */
        bpf_map_init_from_attr(&array->map, attr);
-       array->map.memory.pages = cost;
+       bpf_map_charge_move(&array->map.memory, &mem);
 
        return &array->map;
 }
 
 {
        u32 value_size = attr->value_size;
        struct bpf_stack_map *smap;
+       struct bpf_map_memory mem;
        u64 cost, n_buckets;
        int err;
 
        n_buckets = roundup_pow_of_two(attr->max_entries);
 
        cost = n_buckets * sizeof(struct stack_map_bucket *) + sizeof(*smap);
        if (cost >= U32_MAX - PAGE_SIZE)
                return ERR_PTR(-E2BIG);
+       cost += n_buckets * (value_size + sizeof(struct stack_map_bucket));
+       if (cost >= U32_MAX - PAGE_SIZE)
+               return ERR_PTR(-E2BIG);
 
+       err = bpf_map_charge_init(&mem,
+                                 round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
+       if (err)
+               return ERR_PTR(err);
+
        smap = bpf_map_area_alloc(cost, bpf_map_attr_numa_node(attr));
-       if (!smap)
+       if (!smap) {
+               bpf_map_charge_finish(&mem);
                return ERR_PTR(-ENOMEM);
-
-       err = -E2BIG;
-       cost += n_buckets * (value_size + sizeof(struct stack_map_bucket));
-       if (cost >= U32_MAX - PAGE_SIZE)
-               goto free_smap;
+       }
 
        bpf_map_init_from_attr(&smap->map, attr);
        smap->map.value_size = value_size;
        smap->n_buckets = n_buckets;
-       smap->map.memory.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
-
-       err = bpf_map_precharge_memlock(smap->map.memory.pages);
-       if (err)
-               goto free_smap;
 
        err = get_callchain_buffers(sysctl_perf_event_max_stack);
        if (err)
-               goto free_smap;
+               goto free_charge;
 
        err = prealloc_elems_and_freelist(smap);
        if (err)
                goto put_buffers;
 
+       bpf_map_charge_move(&smap->map.memory, &mem);
+
        return &smap->map;
 
 put_buffers:
        put_callchain_buffers();
-free_smap:
+free_charge:
+       bpf_map_charge_finish(&mem);
        bpf_map_area_free(smap);
        return ERR_PTR(err);
 }
 
        map->numa_node = bpf_map_attr_numa_node(attr);
 }
 
-int bpf_map_precharge_memlock(u32 pages)
-{
-       struct user_struct *user = get_current_user();
-       unsigned long memlock_limit, cur;
-
-       memlock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
-       cur = atomic_long_read(&user->locked_vm);
-       free_uid(user);
-       if (cur + pages > memlock_limit)
-               return -EPERM;
-       return 0;
-}
-
 static int bpf_charge_memlock(struct user_struct *user, u32 pages)
 {
        unsigned long memlock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
 
 static void bpf_uncharge_memlock(struct user_struct *user, u32 pages)
 {
-       atomic_long_sub(pages, &user->locked_vm);
+       if (user)
+               atomic_long_sub(pages, &user->locked_vm);
 }
 
-static int bpf_map_init_memlock(struct bpf_map *map)
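+/* Charge @pages against the current user's RLIMIT_MEMLOCK and, on success,
+ * record the user reference and page count in @mem for a later
+ * bpf_map_charge_move() or bpf_map_charge_finish().
+ */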
+int bpf_map_charge_init(struct bpf_map_memory *mem, u32 pages)
 {
        struct user_struct *user = get_current_user();
        int ret;
 
-       ret = bpf_charge_memlock(user, map->memory.pages);
+       ret = bpf_charge_memlock(user, pages);
        if (ret) {
                free_uid(user);
                return ret;
        }
-       map->memory.user = user;
-       return ret;
+
+       mem->pages = pages;
+       mem->user = user;
+
+       return 0;
 }
 
-static void bpf_map_release_memlock(struct bpf_map *map)
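+/* Undo a charge taken by bpf_map_charge_init(): uncharge the recorded
+ * pages and drop the user reference held in @mem.
+ */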
+void bpf_map_charge_finish(struct bpf_map_memory *mem)
 {
-       struct user_struct *user = map->memory.user;
+       bpf_uncharge_memlock(mem->user, mem->pages);
+       free_uid(mem->user);
+}
 
-       bpf_uncharge_memlock(user, map->memory.pages);
-       free_uid(user);
+void bpf_map_charge_move(struct bpf_map_memory *dst,
+                        struct bpf_map_memory *src)
+{
+       *dst = *src;
+
+       /* Make sure src will not be used for the redundant uncharging. */
+       memset(src, 0, sizeof(struct bpf_map_memory));
 }
 
 int bpf_map_charge_memlock(struct bpf_map *map, u32 pages)
 static void bpf_map_free_deferred(struct work_struct *work)
 {
        struct bpf_map *map = container_of(work, struct bpf_map, work);
+       struct bpf_map_memory mem;
 
-       bpf_map_release_memlock(map);
+       bpf_map_charge_move(&mem, &map->memory);
        security_bpf_map_free(map);
        /* implementation dependent freeing */
        map->ops->map_free(map);
+       bpf_map_charge_finish(&mem);
 }
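
The release side mirrors this: the charge must outlive the memory it accounts for, so bpf_map_free_deferred() above first moves the charge out of the map, lets ->map_free() release the map, and only then calls bpf_map_charge_finish(). Since bpf_map_charge_move() zeroes its source and bpf_uncharge_memlock() now tolerates a NULL user, finishing an already-moved bpf_map_memory does nothing. The matching teardown for the hypothetical example_map sketched earlier would therefore not touch the charge at all (again illustrative, not part of this patch):

	static void example_map_free(struct bpf_map *map)
	{
		struct example_map *m = container_of(map, struct example_map, map);

		/* bpf_map_free_deferred() has already detached the charge via
		 * bpf_map_charge_move(); only the memory itself is released
		 * here, and the generic code uncharges afterwards
		 */
		bpf_map_area_free(m);
	}
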
 
 static void bpf_map_put_uref(struct bpf_map *map)
 static int map_create(union bpf_attr *attr)
 {
        int numa_node = bpf_map_attr_numa_node(attr);
+       struct bpf_map_memory mem;
        struct bpf_map *map;
        int f_flags;
        int err;
 
        err = bpf_obj_name_cpy(map->name, attr->map_name);
        if (err)
-               goto free_map_nouncharge;
+               goto free_map;
 
        atomic_set(&map->refcnt, 1);
        atomic_set(&map->usercnt, 1);
 
                if (!attr->btf_value_type_id) {
                        err = -EINVAL;
-                       goto free_map_nouncharge;
+                       goto free_map;
                }
 
                btf = btf_get_by_fd(attr->btf_fd);
                if (IS_ERR(btf)) {
                        err = PTR_ERR(btf);
-                       goto free_map_nouncharge;
+                       goto free_map;
                }
 
                err = map_check_btf(map, btf, attr->btf_key_type_id,
                                    attr->btf_value_type_id);
                if (err) {
                        btf_put(btf);
-                       goto free_map_nouncharge;
+                       goto free_map;
                }
 
                map->btf = btf;
 
        err = security_bpf_map_alloc(map);
        if (err)
-               goto free_map_nouncharge;
-
-       err = bpf_map_init_memlock(map);
-       if (err)
-               goto free_map_sec;
+               goto free_map;
 
        err = bpf_map_alloc_id(map);
        if (err)
-               goto free_map;
+               goto free_map_sec;
 
        err = bpf_map_new_fd(map, f_flags);
        if (err < 0) {
 
        return err;
 
-free_map:
-       bpf_map_release_memlock(map);
 free_map_sec:
        security_bpf_map_free(map);
-free_map_nouncharge:
+free_map:
        btf_put(map->btf);
+       bpf_map_charge_move(&mem, &map->memory);
        map->ops->map_free(map);
+       bpf_map_charge_finish(&mem);
        return err;
 }
 
 
        if (cost >= U32_MAX - PAGE_SIZE)
                goto free_m;
 
-       m->map.memory.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
-
        /* Notice returns -EPERM on if map size is larger than memlock limit */
-       err = bpf_map_precharge_memlock(m->map.memory.pages);
+       err = bpf_map_charge_init(&m->map.memory,
+                                 round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
        if (err)
                goto free_m;
 
 
        m->flush_list = alloc_percpu(struct list_head);
        if (!m->flush_list)
-               goto free_m;
+               goto free_charge;
 
        for_each_possible_cpu(cpu)
                INIT_LIST_HEAD(per_cpu_ptr(m->flush_list, cpu));
 
 free_percpu:
        free_percpu(m->flush_list);
+free_charge:
+       bpf_map_charge_finish(&m->map.memory);
 free_m:
        kfree(m);
        return ERR_PTR(err);
 
        cost = sizeof(*smap->buckets) * nbuckets + sizeof(*smap);
        pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
 
-       ret = bpf_map_precharge_memlock(pages);
-       if (ret < 0)
+       ret = bpf_map_charge_init(&smap->map.memory, pages);
+       if (ret < 0) {
+               kfree(smap);
                return ERR_PTR(ret);
+       }
 
        smap->buckets = kvcalloc(sizeof(*smap->buckets), nbuckets,
                                 GFP_USER | __GFP_NOWARN);
        if (!smap->buckets) {
+               bpf_map_charge_finish(&smap->map.memory);
                kfree(smap);
                return ERR_PTR(-ENOMEM);
        }
        smap->elem_size = sizeof(struct bpf_sk_storage_elem) + attr->value_size;
        smap->cache_idx = (unsigned int)atomic_inc_return(&cache_idx) %
                BPF_SK_STORAGE_CACHE_SIZE;
-       smap->map.memory.pages = pages;
 
        return &smap->map;
 }
 
                goto free_stab;
        }
 
-       stab->map.memory.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
-       err = bpf_map_precharge_memlock(stab->map.memory.pages);
+       err = bpf_map_charge_init(&stab->map.memory,
+                                 round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
        if (err)
                goto free_stab;
 
        if (stab->sks)
                return &stab->map;
        err = -ENOMEM;
+       bpf_map_charge_finish(&stab->map.memory);
 free_stab:
        kfree(stab);
        return ERR_PTR(err);