 #include "allowedips.h"
 #include "peer.h"
 
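+/* Every allowedips_node is allocated from this dedicated slab cache rather
+ * than from the generic kmalloc buckets; see wg_allowedips_slab_init().
+ */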
+static struct kmem_cache *node_cache;
+
 static void swap_endian(u8 *dst, const u8 *src, u8 bits)
 {
        if (bits == 32) {
                *(u32 *)dst = be32_to_cpu(*(const __be32 *)src);
        } else if (bits == 128) {
                ((u64 *)dst)[0] = be64_to_cpu(((const __be64 *)src)[0]);
                ((u64 *)dst)[1] = be64_to_cpu(((const __be64 *)src)[1]);
        }
 }
 
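+/* RCU callback that returns a node to node_cache; kfree_rcu() is no longer
+ * suitable now that nodes must be freed with kmem_cache_free().
+ */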
+static void node_free_rcu(struct rcu_head *rcu)
+{
+       kmem_cache_free(node_cache, container_of(rcu, struct allowedips_node, rcu));
+}
+
 static void root_free_rcu(struct rcu_head *rcu)
 {
        struct allowedips_node *node, *stack[128] = {
                container_of(rcu, struct allowedips_node, rcu) };
        unsigned int len = 1;

        /* Nonrecursive depth-first walk that frees every node in the trie. */
        while (len > 0 && (node = stack[--len])) {
                push_rcu(stack, node->bit[0], &len);
                push_rcu(stack, node->bit[1], &len);
-               kfree(node);
+               kmem_cache_free(node_cache, node);
        }
 }
 
        if (unlikely(cidr > bits || !peer))
                return -EINVAL;
 
        if (!rcu_access_pointer(*trie)) {
-               node = kzalloc(sizeof(*node), GFP_KERNEL);
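+               /* kmem_cache_zalloc() is the slab-cache counterpart of
+                * kzalloc(): it returns one zeroed node from node_cache.
+                */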
+               node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
                if (unlikely(!node))
                        return -ENOMEM;
                RCU_INIT_POINTER(node->peer, peer);
                return 0;
        }
 
-       newnode = kzalloc(sizeof(*newnode), GFP_KERNEL);
+       newnode = kmem_cache_zalloc(node_cache, GFP_KERNEL);
        if (unlikely(!newnode))
                return -ENOMEM;
        RCU_INIT_POINTER(newnode->peer, peer);
                return 0;
        }
 
-       node = kzalloc(sizeof(*node), GFP_KERNEL);
+       node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
        if (unlikely(!node)) {
                list_del(&newnode->peer_list);
-               kfree(newnode);
+               kmem_cache_free(node_cache, newnode);
                return -ENOMEM;
        }
        INIT_LIST_HEAD(&node->peer_list);
                if (child)
                        child->parent_bit = node->parent_bit;
                *rcu_dereference_protected(node->parent_bit, lockdep_is_held(lock)) = child;
-               kfree_rcu(node, rcu);
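+               /* Defer the free through node_free_rcu() so the node returns
+                * to node_cache only after an RCU grace period.
+                */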
+               call_rcu(&node->rcu, node_free_rcu);
 
                /* TODO: Note that we currently don't walk up and down in order to
                 * free any potential filler nodes. This means that this function
        return NULL;
 }
 
+int __init wg_allowedips_slab_init(void)
+{
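+       /* KMEM_CACHE() derives the cache name, object size, and alignment
+        * from struct allowedips_node itself.
+        */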
+       node_cache = KMEM_CACHE(allowedips_node, 0);
+       return node_cache ? 0 : -ENOMEM;
+}
+
+void wg_allowedips_slab_uninit(void)
+{
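+       /* Wait for all in-flight node_free_rcu() callbacks to complete
+        * before destroying the cache they free into.
+        */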
+       rcu_barrier();
+       kmem_cache_destroy(node_cache);
+}
+
 #include "selftest/allowedips.c"
 
        u8 bits[16] __aligned(__alignof(u64));
 
        /* Keep rarely used members at bottom to be beyond cache line. */
-       struct allowedips_node *__rcu *parent_bit; /* XXX: this puts us at 68->128 bytes instead of 60->64 bytes!! */
+       struct allowedips_node *__rcu *parent_bit;
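+       /* The rcu head is only needed once a node is unlinked and off every
+        * peer_list, so the two members can share storage.
+        */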
        union {
                struct list_head peer_list;
                struct rcu_head rcu;
 bool wg_allowedips_selftest(void);
 #endif
 
+int wg_allowedips_slab_init(void);
+void wg_allowedips_slab_uninit(void);
+
 #endif /* _WG_ALLOWEDIPS_H */