u32 n_buckets;
        u32 elem_size;
        struct bpf_sock_progs progs;
+       struct rcu_head rcu;
 };
 
 struct htab_elem {
 struct smap_psock_map_entry {
        struct list_head list;
        struct sock **entry;
-       struct htab_elem *hash_link;
-       struct bpf_htab *htab;
+       struct htab_elem __rcu *hash_link;
+       struct bpf_htab __rcu *htab;
 };
 
 struct smap_psock {
        struct bpf_prog *bpf_parse;
        struct bpf_prog *bpf_verdict;
        struct list_head maps;
+       spinlock_t maps_lock;
 
        /* Back reference used when sock callback trigger sockmap operations */
        struct sock *sock;
        rcu_read_unlock();
 }
 
+static struct htab_elem *lookup_elem_raw(struct hlist_head *head,
+                                        u32 hash, void *key, u32 key_size)
+{
+       struct htab_elem *l;
+
+       hlist_for_each_entry_rcu(l, head, hash_node) {
+               if (l->hash == hash && !memcmp(&l->key, key, key_size))
+                       return l;
+       }
+
+       return NULL;
+}
+
+static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash)
+{
+       return &htab->buckets[hash & (htab->n_buckets - 1)];
+}
+
+static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash)
+{
+       return &__select_bucket(htab, hash)->head;
+}
+
 static void free_htab_elem(struct bpf_htab *htab, struct htab_elem *l)
 {
        atomic_dec(&htab->count);
        kfree_rcu(l, rcu);
 }
 
+static struct smap_psock_map_entry *psock_map_pop(struct sock *sk,
+                                                 struct smap_psock *psock)
+{
+       struct smap_psock_map_entry *e;
+
+       spin_lock_bh(&psock->maps_lock);
+       e = list_first_entry_or_null(&psock->maps,
+                                    struct smap_psock_map_entry,
+                                    list);
+       if (e)
+               list_del(&e->list);
+       spin_unlock_bh(&psock->maps_lock);
+       return e;
+}
+
 static void bpf_tcp_close(struct sock *sk, long timeout)
 {
        void (*close_fun)(struct sock *sk, long timeout);
-       struct smap_psock_map_entry *e, *tmp;
+       struct smap_psock_map_entry *e;
        struct sk_msg_buff *md, *mtmp;
        struct smap_psock *psock;
        struct sock *osk;
         */
        close_fun = psock->save_close;
 
-       write_lock_bh(&sk->sk_callback_lock);
        if (psock->cork) {
                free_start_sg(psock->sock, psock->cork);
                kfree(psock->cork);
                kfree(md);
        }
 
-       list_for_each_entry_safe(e, tmp, &psock->maps, list) {
+       e = psock_map_pop(sk, psock);
+       while (e) {
                if (e->entry) {
                        osk = cmpxchg(e->entry, sk, NULL);
                        if (osk == sk) {
-                               list_del(&e->list);
                                smap_release_sock(psock, sk);
                        }
                } else {
-                       hlist_del_rcu(&e->hash_link->hash_node);
-                       smap_release_sock(psock, e->hash_link->sk);
-                       free_htab_elem(e->htab, e->hash_link);
+                       struct htab_elem *link = rcu_dereference(e->hash_link);
+                       struct bpf_htab *htab = rcu_dereference(e->htab);
+                       struct hlist_head *head;
+                       struct htab_elem *l;
+                       struct bucket *b;
+
+                       b = __select_bucket(htab, link->hash);
+                       head = &b->head;
+                       raw_spin_lock_bh(&b->lock);
+                       l = lookup_elem_raw(head,
+                                           link->hash, link->key,
+                                           htab->map.key_size);
+                       /* If another thread deleted this object skip deletion.
+                        * The refcnt on psock may or may not be zero.
+                        */
+                       if (l) {
+                               hlist_del_rcu(&link->hash_node);
+                               smap_release_sock(psock, link->sk);
+                               free_htab_elem(htab, link);
+                       }
+                       raw_spin_unlock_bh(&b->lock);
                }
+               e = psock_map_pop(sk, psock);
        }
-       write_unlock_bh(&sk->sk_callback_lock);
        rcu_read_unlock();
        close_fun(sk, timeout);
 }
 {
        if (refcount_dec_and_test(&psock->refcnt)) {
                tcp_cleanup_ulp(sock);
+               write_lock_bh(&sock->sk_callback_lock);
                smap_stop_sock(psock, sock);
+               write_unlock_bh(&sock->sk_callback_lock);
                clear_bit(SMAP_TX_RUNNING, &psock->state);
                rcu_assign_sk_user_data(sock, NULL);
                call_rcu_sched(&psock->rcu, smap_destroy_psock);
        INIT_LIST_HEAD(&psock->maps);
        INIT_LIST_HEAD(&psock->ingress);
        refcount_set(&psock->refcnt, 1);
+       spin_lock_init(&psock->maps_lock);
 
        rcu_assign_sk_user_data(sock, psock);
        sock_hold(sock);
 {
        struct smap_psock_map_entry *e, *tmp;
 
+       spin_lock_bh(&psock->maps_lock);
        list_for_each_entry_safe(e, tmp, &psock->maps, list) {
                if (e->entry == entry)
                        list_del(&e->list);
        }
+       spin_unlock_bh(&psock->maps_lock);
 }
 
 static void smap_list_hash_remove(struct smap_psock *psock,
 {
        struct smap_psock_map_entry *e, *tmp;
 
+       spin_lock_bh(&psock->maps_lock);
        list_for_each_entry_safe(e, tmp, &psock->maps, list) {
-               struct htab_elem *c = e->hash_link;
+               struct htab_elem *c = rcu_dereference(e->hash_link);
 
                if (c == hash_link)
                        list_del(&e->list);
        }
+       spin_unlock_bh(&psock->maps_lock);
 }
 
 static void sock_map_free(struct bpf_map *map)
                if (!sock)
                        continue;
 
-               write_lock_bh(&sock->sk_callback_lock);
                psock = smap_psock_sk(sock);
                /* This check handles a racing sock event that can get the
                 * sk_callback_lock before this case but after xchg happens
                        smap_list_map_remove(psock, &stab->sock_map[i]);
                        smap_release_sock(psock, sock);
                }
-               write_unlock_bh(&sock->sk_callback_lock);
        }
        rcu_read_unlock();
 
        if (!sock)
                return -EINVAL;
 
-       write_lock_bh(&sock->sk_callback_lock);
        psock = smap_psock_sk(sock);
        if (!psock)
                goto out;
        smap_list_map_remove(psock, &stab->sock_map[k]);
        smap_release_sock(psock, sock);
 out:
-       write_unlock_bh(&sock->sk_callback_lock);
        return 0;
 }
 
                }
        }
 
-       write_lock_bh(&sock->sk_callback_lock);
        psock = smap_psock_sk(sock);
 
        /* 2. Do not allow inheriting programs if psock exists and has
                if (err)
                        goto out_free;
                smap_init_progs(psock, verdict, parse);
+               write_lock_bh(&sock->sk_callback_lock);
                smap_start_sock(psock, sock);
+               write_unlock_bh(&sock->sk_callback_lock);
        }
 
        /* 4. Place psock in sockmap for use and stop any programs on
         */
        if (map_link) {
                e->entry = map_link;
+               spin_lock_bh(&psock->maps_lock);
                list_add_tail(&e->list, &psock->maps);
+               spin_unlock_bh(&psock->maps_lock);
        }
-       write_unlock_bh(&sock->sk_callback_lock);
        return err;
 out_free:
        smap_release_sock(psock, sock);
        }
        if (tx_msg)
                bpf_prog_put(tx_msg);
-       write_unlock_bh(&sock->sk_callback_lock);
        kfree(e);
        return err;
 }
        if (osock) {
                struct smap_psock *opsock = smap_psock_sk(osock);
 
-               write_lock_bh(&osock->sk_callback_lock);
                smap_list_map_remove(opsock, &stab->sock_map[i]);
                smap_release_sock(opsock, osock);
-               write_unlock_bh(&osock->sk_callback_lock);
        }
 out:
        return err;
        return ERR_PTR(err);
 }
 
-static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash)
+static void __bpf_htab_free(struct rcu_head *rcu)
 {
-       return &htab->buckets[hash & (htab->n_buckets - 1)];
-}
+       struct bpf_htab *htab;
 
-static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash)
-{
-       return &__select_bucket(htab, hash)->head;
+       htab = container_of(rcu, struct bpf_htab, rcu);
+       bpf_map_area_free(htab->buckets);
+       kfree(htab);
 }
 
 static void sock_hash_free(struct bpf_map *map)
         */
        rcu_read_lock();
        for (i = 0; i < htab->n_buckets; i++) {
-               struct hlist_head *head = select_bucket(htab, i);
+               struct bucket *b = __select_bucket(htab, i);
+               struct hlist_head *head;
                struct hlist_node *n;
                struct htab_elem *l;
 
+               raw_spin_lock_bh(&b->lock);
+               head = &b->head;
                hlist_for_each_entry_safe(l, n, head, hash_node) {
                        struct sock *sock = l->sk;
                        struct smap_psock *psock;
 
                        hlist_del_rcu(&l->hash_node);
-                       write_lock_bh(&sock->sk_callback_lock);
                        psock = smap_psock_sk(sock);
                        /* This check handles a racing sock event that can get
                         * the sk_callback_lock before this case but after xchg
                                smap_list_hash_remove(psock, l);
                                smap_release_sock(psock, sock);
                        }
-                       write_unlock_bh(&sock->sk_callback_lock);
-                       kfree(l);
+                       free_htab_elem(htab, l);
                }
+               raw_spin_unlock_bh(&b->lock);
        }
        rcu_read_unlock();
-       bpf_map_area_free(htab->buckets);
-       kfree(htab);
+       call_rcu(&htab->rcu, __bpf_htab_free);
 }
 
 static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
        return l_new;
 }
 
-static struct htab_elem *lookup_elem_raw(struct hlist_head *head,
-                                        u32 hash, void *key, u32 key_size)
-{
-       struct htab_elem *l;
-
-       hlist_for_each_entry_rcu(l, head, hash_node) {
-               if (l->hash == hash && !memcmp(&l->key, key, key_size))
-                       return l;
-       }
-
-       return NULL;
-}
-
 static inline u32 htab_map_hash(const void *key, u32 key_len)
 {
        return jhash(key, key_len, 0);
                goto bucket_err;
        }
 
-       e->hash_link = l_new;
-       e->htab = container_of(map, struct bpf_htab, map);
+       rcu_assign_pointer(e->hash_link, l_new);
+       rcu_assign_pointer(e->htab,
+                          container_of(map, struct bpf_htab, map));
+       spin_lock_bh(&psock->maps_lock);
        list_add_tail(&e->list, &psock->maps);
+       spin_unlock_bh(&psock->maps_lock);
 
        /* add new element to the head of the list, so that
         * concurrent search will find it before old elem
                struct smap_psock *psock;
 
                hlist_del_rcu(&l->hash_node);
-               write_lock_bh(&sock->sk_callback_lock);
                psock = smap_psock_sk(sock);
                /* This check handles a racing sock event that can get the
                 * sk_callback_lock before this case but after xchg happens
                        smap_list_hash_remove(psock, l);
                        smap_release_sock(psock, sock);
                }
-               write_unlock_bh(&sock->sk_callback_lock);
                free_htab_elem(htab, l);
                ret = 0;
        }