 #include <linux/vmalloc.h>
 #include <linux/stacktrace.h>
 #include <linux/perf_event.h>
+#include "percpu_freelist.h"
 
 struct stack_map_bucket {
-       struct rcu_head rcu;
+       struct pcpu_freelist_node fnode;
        u32 hash;
        u32 nr;
        u64 ip[];
 };
 
 struct bpf_stack_map {
        struct bpf_map map;
+       void *elems;
+       struct pcpu_freelist freelist;
        u32 n_buckets;
-       struct stack_map_bucket __rcu *buckets[];
+       struct stack_map_bucket *buckets[];
 };
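
Two structural changes drive the rest of the patch: each bucket embeds a struct pcpu_freelist_node in place of the rcu_head, and buckets[] loses its __rcu annotation because buckets are now recycled through a freelist instead of being freed under RCU. One detail worth noting: the update path below converts pcpu_freelist_pop()'s return value with a plain cast, which is safe only because fnode is the first member. A minimal, compilable illustration (struct node and struct bucket are stand-ins, not the kernel types):

    #include <stddef.h>

    /* fnode first => a node and its bucket share an address, so a
     * plain cast works; container_of() works for any member position */
    struct node { struct node *next; };
    struct bucket { struct node fnode; unsigned int hash; };

    #define container_of(ptr, type, member) \
            ((type *)((char *)(ptr) - offsetof(type, member)))

    static struct bucket *bucket_of(struct node *n)
    {
            return container_of(n, struct bucket, fnode);
    }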
 
+static int prealloc_elems_and_freelist(struct bpf_stack_map *smap)
+{
+       u32 elem_size = sizeof(struct stack_map_bucket) + smap->map.value_size;
+       int err;
+
+       smap->elems = vzalloc(elem_size * smap->map.max_entries);
+       if (!smap->elems)
+               return -ENOMEM;
+
+       err = pcpu_freelist_init(&smap->freelist);
+       if (err)
+               goto free_elems;
+
+       pcpu_freelist_populate(&smap->freelist, smap->elems, elem_size,
+                              smap->map.max_entries);
+       return 0;
+
+free_elems:
+       vfree(smap->elems);
+       return err;
+}
+
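
prealloc_elems_and_freelist() makes the map fully preallocated: one flat vzalloc() region holding max_entries slots of elem_size = sizeof(struct stack_map_bucket) + value_size bytes each, all handed to the freelist up front so the update path never touches an allocator. A userspace sketch of what the populate step is assumed to do (illustrative names, a single list instead of per-cpu, no locking):

    #include <stddef.h>

    struct node { struct node *next; };
    static struct node *freelist_head;

    /* walk the flat region in elem_size strides and push every slot */
    static void populate(void *buf, size_t elem_size, size_t nr_elems)
    {
            size_t i;

            for (i = 0; i < nr_elems; i++) {
                    struct node *n = (struct node *)((char *)buf + i * elem_size);

                    n->next = freelist_head;
                    freelist_head = n;
            }
    }

Note that elem_size matches what the old code kmalloc()ed per insertion, so the per-element memory bound is unchanged; it is just paid once at map creation.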
 /* Called from syscall */
 static struct bpf_map *stack_map_alloc(union bpf_attr *attr)
 {
@@ ... @@
        smap->n_buckets = n_buckets;
        smap->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
 
+       err = bpf_map_precharge_memlock(smap->map.pages);
+       if (err)
+               goto free_smap;
+
        err = get_callchain_buffers();
        if (err)
                goto free_smap;
 
+       err = prealloc_elems_and_freelist(smap);
+       if (err)
+               goto put_buffers;
+
        return &smap->map;
 
+put_buffers:
+       put_callchain_buffers();
 free_smap:
        kvfree(smap);
        return ERR_PTR(err);
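
The allocation path now verifies the memlock budget before committing to the large preallocation, and the error path unwinds in reverse order of setup: put_buffers undoes get_callchain_buffers() before free_smap releases the map. A hedged sketch of the check bpf_map_precharge_memlock() is assumed to perform, namely "would charging these pages exceed RLIMIT_MEMLOCK", without actually charging yet:

    /* illustrative only: reject up front if the final charge would
     * exceed the limit; the actual charge still happens later */
    static int precharge_sketch(unsigned long locked_pages,
                                unsigned long limit_pages,
                                unsigned long pages)
    {
            if (locked_pages + pages > limit_pages)
                    return -1;      /* the kernel helper returns -EPERM */
            return 0;
    }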
@@ ... @@
        ips = trace->ip + skip + init_nr;
        hash = jhash2((u32 *)ips, trace_len / sizeof(u32), 0);
        id = hash & (smap->n_buckets - 1);
-       bucket = rcu_dereference(smap->buckets[id]);
+       bucket = READ_ONCE(smap->buckets[id]);
 
        if (bucket && bucket->hash == hash) {
                if (flags & BPF_F_FAST_STACK_CMP)
@@ ... @@
        if (bucket && !(flags & BPF_F_REUSE_STACKID))
                return -EEXIST;
 
-       new_bucket = kmalloc(sizeof(struct stack_map_bucket) + map->value_size,
-                            GFP_ATOMIC | __GFP_NOWARN);
+       new_bucket = (struct stack_map_bucket *)
+               pcpu_freelist_pop(&smap->freelist);
        if (unlikely(!new_bucket))
                return -ENOMEM;
 
        memcpy(new_bucket->ip, ips, trace_len);
-       memset(new_bucket->ip + trace_len / 8, 0, map->value_size - trace_len);
        new_bucket->hash = hash;
        new_bucket->nr = trace_nr;
 
        old_bucket = xchg(&smap->buckets[id], new_bucket);
        if (old_bucket)
-               kfree_rcu(old_bucket, rcu);
+               pcpu_freelist_push(&smap->freelist, &old_bucket->fnode);
        return id;
 }
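
With every element preallocated, the update path in bpf_get_stackid() becomes allocation-free: pop a ready bucket, fill it, publish it with xchg(), and recycle whatever it displaced. The old kmalloc(GFP_ATOMIC)/kfree_rcu() pair is gone, which matters because this helper can run in NMI context, where even GFP_ATOMIC allocation is unsafe. A compilable userspace analogue of the pattern using C11 atomics (plain list, locking elided; the kernel side uses the per-cpu freelist for exactly that reason):

    #include <stdatomic.h>
    #include <stddef.h>

    struct node { struct node *next; int payload; };

    static struct node *freelist;          /* simplified, not per-cpu */
    static _Atomic(struct node *) slot;    /* plays smap->buckets[id] */

    static int update(int payload)
    {
            struct node *new = freelist, *old;

            if (!new)
                    return -1;             /* map full: -ENOMEM above */
            freelist = new->next;          /* pop a preallocated node */

            new->payload = payload;        /* fill before publishing  */
            old = atomic_exchange(&slot, new);
            if (old) {                     /* recycle displaced node  */
                    old->next = freelist;
                    freelist = old;
            }
            return 0;
    }

The dropped memset() is not an oversight: bpf_stackmap_copy() below zeroes the tail of the output buffer instead, so the hot path no longer clears value_size - trace_len bytes on every update.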
 
@@ ... @@
        .arg3_type      = ARG_ANYTHING,
 };
 
-/* Called from syscall or from eBPF program */
+/* Called from eBPF program */
 static void *stack_map_lookup_elem(struct bpf_map *map, void *key)
+{
+       return NULL;
+}
+
+/* Called from syscall */
+int bpf_stackmap_copy(struct bpf_map *map, void *key, void *value)
 {
        struct bpf_stack_map *smap = container_of(map, struct bpf_stack_map, map);
-       struct stack_map_bucket *bucket;
-       u32 id = *(u32 *)key;
+       struct stack_map_bucket *bucket, *old_bucket;
+       u32 id = *(u32 *)key, trace_len;
 
        if (unlikely(id >= smap->n_buckets))
-               return NULL;
-       bucket = rcu_dereference(smap->buckets[id]);
-       return bucket ? bucket->ip : NULL;
+               return -ENOENT;
+
+       bucket = xchg(&smap->buckets[id], NULL);
+       if (!bucket)
+               return -ENOENT;
+
+       trace_len = bucket->nr * sizeof(u64);
+       memcpy(value, bucket->ip, trace_len);
+       memset(value + trace_len, 0, map->value_size - trace_len);
+
+       old_bucket = xchg(&smap->buckets[id], bucket);
+       if (old_bucket)
+               pcpu_freelist_push(&smap->freelist, &old_bucket->fnode);
+       return 0;
 }
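
Because displaced and deleted buckets go straight back to the freelist with no grace period, handing a raw bucket pointer up through lookup is no longer safe; program-side lookup therefore returns NULL, and the syscall path copies instead. bpf_stackmap_copy() claims the bucket by swapping NULL into the slot, copies while nothing can recycle it, then swaps it back; if an update raced in meanwhile, the update's bucket is recycled exactly as in the update path. A userspace analogue (types and sizes illustrative):

    #include <stdatomic.h>
    #include <stddef.h>
    #include <string.h>

    struct bucket { unsigned int nr; unsigned long long ip[127]; };
    static _Atomic(struct bucket *) slot;

    static int copy_out(void *value, size_t value_size)
    {
            struct bucket *b = atomic_exchange(&slot, NULL);  /* claim */
            struct bucket *displaced;
            size_t len;

            if (!b)
                    return -1;             /* -ENOENT in the kernel   */

            len = b->nr * sizeof(b->ip[0]);
            memcpy(value, b->ip, len);
            memset((char *)value + len, 0, value_size - len);

            displaced = atomic_exchange(&slot, b);            /* restore */
            if (displaced) {
                    /* a racing update won the slot; its bucket would go
                     * back to the freelist here, as the kernel code does */
            }
            return 0;
    }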
 
@@ ... @@ static int stack_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
 
        old_bucket = xchg(&smap->buckets[id], NULL);
        if (old_bucket) {
-               kfree_rcu(old_bucket, rcu);
+               pcpu_freelist_push(&smap->freelist, &old_bucket->fnode);
                return 0;
        } else {
                return -ENOENT;
@@ ... @@
 static void stack_map_free(struct bpf_map *map)
 {
        struct bpf_stack_map *smap = container_of(map, struct bpf_stack_map, map);
-       int i;
 
+       /* wait for bpf programs to complete before freeing stack map */
        synchronize_rcu();
 
-       for (i = 0; i < smap->n_buckets; i++)
-               if (smap->buckets[i])
-                       kfree_rcu(smap->buckets[i], rcu);
+       vfree(smap->elems);
+       pcpu_freelist_destroy(&smap->freelist);
        kvfree(smap);
        put_callchain_buffers();
 }
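
Teardown collapses to one grace period plus one vfree(): every bucket lives inside the flat elems region, so the per-bucket kfree_rcu() loop disappears, and synchronize_rcu() only has to guarantee that no program is still walking a bucket when the region goes away. A userspace RCU analogue of that ordering (liburcu, analogue only):

    #include <urcu.h>           /* userspace RCU, for the analogue    */
    #include <stdlib.h>

    static void map_free_analogue(void *elems)
    {
            synchronize_rcu();  /* wait out all current readers       */
            free(elems);        /* one flat region frees every bucket */
    }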