#include <linux/bpf.h>
 #include <linux/jhash.h>
 #include <linux/filter.h>
+#include <linux/rculist_nulls.h>
 #include "percpu_freelist.h"
 #include "bpf_lru_list.h"
 
 struct bucket {
-       struct hlist_head head;
+       struct hlist_nulls_head head;
        raw_spinlock_t lock;
 };
 
 /* each htab element is struct htab_elem + key + value */
 struct htab_elem {
        union {
-               struct hlist_node hash_node;
+               struct hlist_nulls_node hash_node;
                struct {
                        void *padding;
                        union {
                goto free_htab;
 
        for (i = 0; i < htab->n_buckets; i++) {
-               INIT_HLIST_HEAD(&htab->buckets[i].head);
+               INIT_HLIST_NULLS_HEAD(&htab->buckets[i].head, i);
                raw_spin_lock_init(&htab->buckets[i].lock);
        }
 
        return &htab->buckets[hash & (htab->n_buckets - 1)];
 }
 
-static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash)
+static inline struct hlist_nulls_head *select_bucket(struct bpf_htab *htab, u32 hash)
 {
        return &__select_bucket(htab, hash)->head;
 }
 
-static struct htab_elem *lookup_elem_raw(struct hlist_head *head, u32 hash,
+/* this lookup function can only be called with bucket lock taken */
+static struct htab_elem *lookup_elem_raw(struct hlist_nulls_head *head, u32 hash,
                                         void *key, u32 key_size)
 {
+       struct hlist_nulls_node *n;
        struct htab_elem *l;
 
-       hlist_for_each_entry_rcu(l, head, hash_node)
+       hlist_nulls_for_each_entry_rcu(l, n, head, hash_node)
                if (l->hash == hash && !memcmp(&l->key, key, key_size))
                        return l;
 
        return NULL;
 }
 
+/* can be called without bucket lock. it will repeat the loop in
+ * the unlikely event when elements moved from one bucket into another
+ * while link list is being walked
+ */
+static struct htab_elem *lookup_nulls_elem_raw(struct hlist_nulls_head *head,
+                                              u32 hash, void *key,
+                                              u32 key_size, u32 n_buckets)
+{
+       struct hlist_nulls_node *n;
+       struct htab_elem *l;
+
+again:
+       hlist_nulls_for_each_entry_rcu(l, n, head, hash_node)
+               if (l->hash == hash && !memcmp(&l->key, key, key_size))
+                       return l;
+
+       if (unlikely(get_nulls_value(n) != (hash & (n_buckets - 1))))
+               goto again;
+
+       return NULL;
+}
+
 /* Called from syscall or from eBPF program */
 static void *__htab_map_lookup_elem(struct bpf_map *map, void *key)
 {
        struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
-       struct hlist_head *head;
+       struct hlist_nulls_head *head;
        struct htab_elem *l;
        u32 hash, key_size;
 
 
        head = select_bucket(htab, hash);
 
-       l = lookup_elem_raw(head, hash, key, key_size);
+       l = lookup_nulls_elem_raw(head, hash, key, key_size, htab->n_buckets);
 
        return l;
 }
 static bool htab_lru_map_delete_node(void *arg, struct bpf_lru_node *node)
 {
        struct bpf_htab *htab = (struct bpf_htab *)arg;
-       struct htab_elem *l, *tgt_l;
-       struct hlist_head *head;
+       struct htab_elem *l = NULL, *tgt_l;
+       struct hlist_nulls_head *head;
+       struct hlist_nulls_node *n;
        unsigned long flags;
        struct bucket *b;
 
 
        raw_spin_lock_irqsave(&b->lock, flags);
 
-       hlist_for_each_entry_rcu(l, head, hash_node)
+       hlist_nulls_for_each_entry_rcu(l, n, head, hash_node)
                if (l == tgt_l) {
-                       hlist_del_rcu(&l->hash_node);
+                       hlist_nulls_del_rcu(&l->hash_node);
                        break;
                }
 
 static int htab_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
 {
        struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
-       struct hlist_head *head;
+       struct hlist_nulls_head *head;
        struct htab_elem *l, *next_l;
        u32 hash, key_size;
        int i;
        head = select_bucket(htab, hash);
 
        /* lookup the key */
-       l = lookup_elem_raw(head, hash, key, key_size);
+       l = lookup_nulls_elem_raw(head, hash, key, key_size, htab->n_buckets);
 
        if (!l) {
                i = 0;
        }
 
        /* key was found, get next key in the same bucket */
-       next_l = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&l->hash_node)),
+       next_l = hlist_nulls_entry_safe(rcu_dereference_raw(hlist_nulls_next_rcu(&l->hash_node)),
                                  struct htab_elem, hash_node);
 
        if (next_l) {
                head = select_bucket(htab, i);
 
                /* pick first element in the bucket */
-               next_l = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
+               next_l = hlist_nulls_entry_safe(rcu_dereference_raw(hlist_nulls_first_rcu(head)),
                                          struct htab_elem, hash_node);
                if (next_l) {
                        /* if it's not empty, just return it */
 {
        struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
        struct htab_elem *l_new = NULL, *l_old;
-       struct hlist_head *head;
+       struct hlist_nulls_head *head;
        unsigned long flags;
        struct bucket *b;
        u32 key_size, hash;
        /* add new element to the head of the list, so that
         * concurrent search will find it before old elem
         */
-       hlist_add_head_rcu(&l_new->hash_node, head);
+       hlist_nulls_add_head_rcu(&l_new->hash_node, head);
        if (l_old) {
-               hlist_del_rcu(&l_old->hash_node);
+               hlist_nulls_del_rcu(&l_old->hash_node);
                free_htab_elem(htab, l_old);
        }
        ret = 0;
 {
        struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
        struct htab_elem *l_new, *l_old = NULL;
-       struct hlist_head *head;
+       struct hlist_nulls_head *head;
        unsigned long flags;
        struct bucket *b;
        u32 key_size, hash;
        /* add new element to the head of the list, so that
         * concurrent search will find it before old elem
         */
-       hlist_add_head_rcu(&l_new->hash_node, head);
+       hlist_nulls_add_head_rcu(&l_new->hash_node, head);
        if (l_old) {
                bpf_lru_node_set_ref(&l_new->lru_node);
-               hlist_del_rcu(&l_old->hash_node);
+               hlist_nulls_del_rcu(&l_old->hash_node);
        }
        ret = 0;
 
 {
        struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
        struct htab_elem *l_new = NULL, *l_old;
-       struct hlist_head *head;
+       struct hlist_nulls_head *head;
        unsigned long flags;
        struct bucket *b;
        u32 key_size, hash;
                        ret = PTR_ERR(l_new);
                        goto err;
                }
-               hlist_add_head_rcu(&l_new->hash_node, head);
+               hlist_nulls_add_head_rcu(&l_new->hash_node, head);
        }
        ret = 0;
 err:
 {
        struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
        struct htab_elem *l_new = NULL, *l_old;
-       struct hlist_head *head;
+       struct hlist_nulls_head *head;
        unsigned long flags;
        struct bucket *b;
        u32 key_size, hash;
        } else {
                pcpu_copy_value(htab, htab_elem_get_ptr(l_new, key_size),
                                value, onallcpus);
-               hlist_add_head_rcu(&l_new->hash_node, head);
+               hlist_nulls_add_head_rcu(&l_new->hash_node, head);
                l_new = NULL;
        }
        ret = 0;
 static int htab_map_delete_elem(struct bpf_map *map, void *key)
 {
        struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
-       struct hlist_head *head;
+       struct hlist_nulls_head *head;
        struct bucket *b;
        struct htab_elem *l;
        unsigned long flags;
        l = lookup_elem_raw(head, hash, key, key_size);
 
        if (l) {
-               hlist_del_rcu(&l->hash_node);
+               hlist_nulls_del_rcu(&l->hash_node);
                free_htab_elem(htab, l);
                ret = 0;
        }
 static int htab_lru_map_delete_elem(struct bpf_map *map, void *key)
 {
        struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
-       struct hlist_head *head;
+       struct hlist_nulls_head *head;
        struct bucket *b;
        struct htab_elem *l;
        unsigned long flags;
        l = lookup_elem_raw(head, hash, key, key_size);
 
        if (l) {
-               hlist_del_rcu(&l->hash_node);
+               hlist_nulls_del_rcu(&l->hash_node);
                ret = 0;
        }
 
        int i;
 
        for (i = 0; i < htab->n_buckets; i++) {
-               struct hlist_head *head = select_bucket(htab, i);
-               struct hlist_node *n;
+               struct hlist_nulls_head *head = select_bucket(htab, i);
+               struct hlist_nulls_node *n;
                struct htab_elem *l;
 
-               hlist_for_each_entry_safe(l, n, head, hash_node) {
-                       hlist_del_rcu(&l->hash_node);
+               hlist_nulls_for_each_entry_safe(l, n, head, hash_node) {
+                       hlist_nulls_del_rcu(&l->hash_node);
                        if (l->state != HTAB_EXTRA_ELEM_USED)
                                htab_elem_free(htab, l);
                }