// SPDX-License-Identifier: BSD-3-Clause-Clear
 /*
  * Copyright (c) 2018-2019 The Linux Foundation. All rights reserved.
+ * Copyright (c) 2021-2022 Qualcomm Innovation Center, Inc. All rights reserved.
  */
 
 #include "core.h"
 #include "peer.h"
 #include "debug.h"
 
-struct ath11k_peer *ath11k_peer_find(struct ath11k_base *ab, int vdev_id,
-                                    const u8 *addr)
+static struct ath11k_peer *ath11k_peer_find_list_by_id(struct ath11k_base *ab,
+                                                      int peer_id)
 {
        struct ath11k_peer *peer;
 
        lockdep_assert_held(&ab->base_lock);
 
        list_for_each_entry(peer, &ab->peers, list) {
-               if (peer->vdev_id != vdev_id)
-                       continue;
-               if (!ether_addr_equal(peer->addr, addr))
+               if (peer->peer_id != peer_id)
                        continue;
 
                return peer;
        return NULL;
 }
 
-static struct ath11k_peer *ath11k_peer_find_by_pdev_idx(struct ath11k_base *ab,
-                                                       u8 pdev_idx, const u8 *addr)
+struct ath11k_peer *ath11k_peer_find(struct ath11k_base *ab, int vdev_id,
+                                    const u8 *addr)
 {
        struct ath11k_peer *peer;
 
        lockdep_assert_held(&ab->base_lock);
 
        list_for_each_entry(peer, &ab->peers, list) {
-               if (peer->pdev_idx != pdev_idx)
+               if (peer->vdev_id != vdev_id)
                        continue;
                if (!ether_addr_equal(peer->addr, addr))
                        continue;
 
        lockdep_assert_held(&ab->base_lock);
 
-       list_for_each_entry(peer, &ab->peers, list) {
-               if (!ether_addr_equal(peer->addr, addr))
-                       continue;
+       if (!ab->rhead_peer_addr)
+               return NULL;
 
-               return peer;
-       }
+       peer = rhashtable_lookup_fast(ab->rhead_peer_addr, addr,
+                                     ab->rhash_peer_addr_param);
 
-       return NULL;
+       return peer;
 }
 
 struct ath11k_peer *ath11k_peer_find_by_id(struct ath11k_base *ab,
 
        lockdep_assert_held(&ab->base_lock);
 
-       list_for_each_entry(peer, &ab->peers, list)
-               if (peer_id == peer->peer_id)
-                       return peer;
+       if (!ab->rhead_peer_id)
+               return NULL;
 
-       return NULL;
+       peer = rhashtable_lookup_fast(ab->rhead_peer_id, &peer_id,
+                                     ab->rhash_peer_id_param);
+
+       return peer;
 }
 
 struct ath11k_peer *ath11k_peer_find_by_vdev_id(struct ath11k_base *ab,
 
        spin_lock_bh(&ab->base_lock);
 
-       peer = ath11k_peer_find_by_id(ab, peer_id);
+       peer = ath11k_peer_find_list_by_id(ab, peer_id);
        if (!peer) {
                ath11k_warn(ab, "peer-unmap-event: unknown peer id %d\n",
                            peer_id);
        return 0;
 }
 
+static inline int ath11k_peer_rhash_insert(struct ath11k_base *ab,
+                                          struct rhashtable *rtbl,
+                                          struct rhash_head *rhead,
+                                          struct rhashtable_params *params,
+                                          void *key)
+{
+       struct ath11k_peer *tmp;
+
+       lockdep_assert_held(&ab->tbl_mtx_lock);
+
+       tmp = rhashtable_lookup_get_insert_fast(rtbl, rhead, *params);
+
+       if (!tmp)
+               return 0;
+       else if (IS_ERR(tmp))
+               return PTR_ERR(tmp);
+       else
+               return -EEXIST;
+}
+
+static inline int ath11k_peer_rhash_remove(struct ath11k_base *ab,
+                                          struct rhashtable *rtbl,
+                                          struct rhash_head *rhead,
+                                          struct rhashtable_params *params)
+{
+       int ret;
+
+       lockdep_assert_held(&ab->tbl_mtx_lock);
+
+       ret = rhashtable_remove_fast(rtbl, rhead, *params);
+       if (ret && ret != -ENOENT)
+               return ret;
+
+       return 0;
+}
+
+static int ath11k_peer_rhash_add(struct ath11k_base *ab, struct ath11k_peer *peer)
+{
+       int ret;
+
+       lockdep_assert_held(&ab->base_lock);
+       lockdep_assert_held(&ab->tbl_mtx_lock);
+
+       if (!ab->rhead_peer_id || !ab->rhead_peer_addr)
+               return -EPERM;
+
+       ret = ath11k_peer_rhash_insert(ab, ab->rhead_peer_id, &peer->rhash_id,
+                                      &ab->rhash_peer_id_param, &peer->peer_id);
+       if (ret) {
+               ath11k_warn(ab, "failed to add peer %pM with id %d in rhash_id ret %d\n",
+                           peer->addr, peer->peer_id, ret);
+               return ret;
+       }
+
+       ret = ath11k_peer_rhash_insert(ab, ab->rhead_peer_addr, &peer->rhash_addr,
+                                      &ab->rhash_peer_addr_param, &peer->addr);
+       if (ret) {
+               ath11k_warn(ab, "failed to add peer %pM with id %d in rhash_addr ret %d\n",
+                           peer->addr, peer->peer_id, ret);
+               goto err_clean;
+       }
+
+       return 0;
+
+err_clean:
+       ath11k_peer_rhash_remove(ab, ab->rhead_peer_id, &peer->rhash_id,
+                                &ab->rhash_peer_id_param);
+       return ret;
+}
+
 void ath11k_peer_cleanup(struct ath11k *ar, u32 vdev_id)
 {
        struct ath11k_peer *peer, *tmp;
 
        lockdep_assert_held(&ar->conf_mutex);
 
+       mutex_lock(&ab->tbl_mtx_lock);
        spin_lock_bh(&ab->base_lock);
        list_for_each_entry_safe(peer, tmp, &ab->peers, list) {
                if (peer->vdev_id != vdev_id)
                ath11k_warn(ab, "removing stale peer %pM from vdev_id %d\n",
                            peer->addr, vdev_id);
 
+               ath11k_peer_rhash_delete(ab, peer);
                list_del(&peer->list);
                kfree(peer);
                ar->num_peers--;
        }
 
        spin_unlock_bh(&ab->base_lock);
+       mutex_unlock(&ab->tbl_mtx_lock);
 }
 
 static int ath11k_wait_for_peer_deleted(struct ath11k *ar, int vdev_id, const u8 *addr)
 static int __ath11k_peer_delete(struct ath11k *ar, u32 vdev_id, const u8 *addr)
 {
        int ret;
+       struct ath11k_peer *peer;
+       struct ath11k_base *ab = ar->ab;
 
        lockdep_assert_held(&ar->conf_mutex);
 
+       mutex_lock(&ab->tbl_mtx_lock);
+       spin_lock_bh(&ab->base_lock);
+
+       peer = ath11k_peer_find_by_addr(ab, addr);
+       if (!peer) {
+               spin_unlock_bh(&ab->base_lock);
+               mutex_unlock(&ab->tbl_mtx_lock);
+
+               ath11k_warn(ab,
+                           "failed to find peer vdev_id %d addr %pM in delete\n",
+                           vdev_id, addr);
+               return -EINVAL;
+       }
+
+       ath11k_peer_rhash_delete(ab, peer);
+
+       spin_unlock_bh(&ab->base_lock);
+       mutex_unlock(&ab->tbl_mtx_lock);
+
        reinit_completion(&ar->peer_delete_done);
 
        ret = ath11k_wmi_send_peer_delete_cmd(ar, addr, vdev_id);
        if (ret) {
-               ath11k_warn(ar->ab,
+               ath11k_warn(ab,
                            "failed to delete peer vdev_id %d addr %pM ret %d\n",
                            vdev_id, addr, ret);
                return ret;
        }
 
        spin_lock_bh(&ar->ab->base_lock);
-       peer = ath11k_peer_find_by_pdev_idx(ar->ab, ar->pdev_idx, param->peer_addr);
+       peer = ath11k_peer_find_by_addr(ar->ab, param->peer_addr);
        if (peer) {
                spin_unlock_bh(&ar->ab->base_lock);
                return -EINVAL;
        if (ret)
                return ret;
 
+       mutex_lock(&ar->ab->tbl_mtx_lock);
        spin_lock_bh(&ar->ab->base_lock);
 
        peer = ath11k_peer_find(ar->ab, param->vdev_id, param->peer_addr);
        if (!peer) {
                spin_unlock_bh(&ar->ab->base_lock);
+               mutex_unlock(&ar->ab->tbl_mtx_lock);
                ath11k_warn(ar->ab, "failed to find peer %pM on vdev %i after creation\n",
                            param->peer_addr, param->vdev_id);
 
                goto cleanup;
        }
 
+       ret = ath11k_peer_rhash_add(ar->ab, peer);
+       if (ret) {
+               spin_unlock_bh(&ar->ab->base_lock);
+               mutex_unlock(&ar->ab->tbl_mtx_lock);
+               goto cleanup;
+       }
+
        peer->pdev_idx = ar->pdev_idx;
        peer->sta = sta;
 
        ar->num_peers++;
 
        spin_unlock_bh(&ar->ab->base_lock);
+       mutex_unlock(&ar->ab->tbl_mtx_lock);
 
        return 0;
 
 
        return ret;
 }
+
+int ath11k_peer_rhash_delete(struct ath11k_base *ab, struct ath11k_peer *peer)
+{
+       int ret;
+
+       lockdep_assert_held(&ab->base_lock);
+       lockdep_assert_held(&ab->tbl_mtx_lock);
+
+       if (!ab->rhead_peer_id || !ab->rhead_peer_addr)
+               return -EPERM;
+
+       ret = ath11k_peer_rhash_remove(ab, ab->rhead_peer_addr, &peer->rhash_addr,
+                                      &ab->rhash_peer_addr_param);
+       if (ret) {
+               ath11k_warn(ab, "failed to remove peer %pM id %d in rhash_addr ret %d\n",
+                           peer->addr, peer->peer_id, ret);
+               return ret;
+       }
+
+       ret = ath11k_peer_rhash_remove(ab, ab->rhead_peer_id, &peer->rhash_id,
+                                      &ab->rhash_peer_id_param);
+       if (ret) {
+               ath11k_warn(ab, "failed to remove peer %pM id %d in rhash_id ret %d\n",
+                           peer->addr, peer->peer_id, ret);
+               return ret;
+       }
+
+       return 0;
+}
+
+static int ath11k_peer_rhash_id_tbl_init(struct ath11k_base *ab)
+{
+       struct rhashtable_params *param;
+       struct rhashtable *rhash_id_tbl;
+       int ret;
+       size_t size;
+
+       lockdep_assert_held(&ab->tbl_mtx_lock);
+
+       if (ab->rhead_peer_id)
+               return 0;
+
+       size = sizeof(*ab->rhead_peer_id);
+       rhash_id_tbl = kzalloc(size, GFP_KERNEL);
+       if (!rhash_id_tbl) {
+               ath11k_warn(ab, "failed to init rhash id table due to no mem (size %zu)\n",
+                           size);
+               return -ENOMEM;
+       }
+
+       param = &ab->rhash_peer_id_param;
+
+       param->key_offset = offsetof(struct ath11k_peer, peer_id);
+       param->head_offset = offsetof(struct ath11k_peer, rhash_id);
+       param->key_len = sizeof_field(struct ath11k_peer, peer_id);
+       param->automatic_shrinking = true;
+       param->nelem_hint = ab->num_radios * TARGET_NUM_PEERS_PDEV(ab);
+
+       ret = rhashtable_init(rhash_id_tbl, param);
+       if (ret) {
+               ath11k_warn(ab, "failed to init peer id rhash table %d\n", ret);
+               goto err_free;
+       }
+
+       spin_lock_bh(&ab->base_lock);
+
+       if (!ab->rhead_peer_id) {
+               ab->rhead_peer_id = rhash_id_tbl;
+       } else {
+               spin_unlock_bh(&ab->base_lock);
+               goto cleanup_tbl;
+       }
+
+       spin_unlock_bh(&ab->base_lock);
+
+       return 0;
+
+cleanup_tbl:
+       rhashtable_destroy(rhash_id_tbl);
+err_free:
+       kfree(rhash_id_tbl);
+
+       return ret;
+}
+
+static int ath11k_peer_rhash_addr_tbl_init(struct ath11k_base *ab)
+{
+       struct rhashtable_params *param;
+       struct rhashtable *rhash_addr_tbl;
+       int ret;
+       size_t size;
+
+       lockdep_assert_held(&ab->tbl_mtx_lock);
+
+       if (ab->rhead_peer_addr)
+               return 0;
+
+       size = sizeof(*ab->rhead_peer_addr);
+       rhash_addr_tbl = kzalloc(size, GFP_KERNEL);
+       if (!rhash_addr_tbl) {
+               ath11k_warn(ab, "failed to init rhash addr table due to no mem (size %zu)\n",
+                           size);
+               return -ENOMEM;
+       }
+
+       param = &ab->rhash_peer_addr_param;
+
+       param->key_offset = offsetof(struct ath11k_peer, addr);
+       param->head_offset = offsetof(struct ath11k_peer, rhash_addr);
+       param->key_len = sizeof_field(struct ath11k_peer, addr);
+       param->automatic_shrinking = true;
+       param->nelem_hint = ab->num_radios * TARGET_NUM_PEERS_PDEV(ab);
+
+       ret = rhashtable_init(rhash_addr_tbl, param);
+       if (ret) {
+               ath11k_warn(ab, "failed to init peer addr rhash table %d\n", ret);
+               goto err_free;
+       }
+
+       spin_lock_bh(&ab->base_lock);
+
+       if (!ab->rhead_peer_addr) {
+               ab->rhead_peer_addr = rhash_addr_tbl;
+       } else {
+               spin_unlock_bh(&ab->base_lock);
+               goto cleanup_tbl;
+       }
+
+       spin_unlock_bh(&ab->base_lock);
+
+       return 0;
+
+cleanup_tbl:
+       rhashtable_destroy(rhash_addr_tbl);
+err_free:
+       kfree(rhash_addr_tbl);
+
+       return ret;
+}
+
+static inline void ath11k_peer_rhash_id_tbl_destroy(struct ath11k_base *ab)
+{
+       lockdep_assert_held(&ab->tbl_mtx_lock);
+
+       if (!ab->rhead_peer_id)
+               return;
+
+       rhashtable_destroy(ab->rhead_peer_id);
+       kfree(ab->rhead_peer_id);
+       ab->rhead_peer_id = NULL;
+}
+
+static inline void ath11k_peer_rhash_addr_tbl_destroy(struct ath11k_base *ab)
+{
+       lockdep_assert_held(&ab->tbl_mtx_lock);
+
+       if (!ab->rhead_peer_addr)
+               return;
+
+       rhashtable_destroy(ab->rhead_peer_addr);
+       kfree(ab->rhead_peer_addr);
+       ab->rhead_peer_addr = NULL;
+}
+
+int ath11k_peer_rhash_tbl_init(struct ath11k_base *ab)
+{
+       int ret;
+
+       mutex_lock(&ab->tbl_mtx_lock);
+
+       ret = ath11k_peer_rhash_id_tbl_init(ab);
+       if (ret)
+               goto out;
+
+       ret = ath11k_peer_rhash_addr_tbl_init(ab);
+       if (ret)
+               goto cleanup_tbl;
+
+       mutex_unlock(&ab->tbl_mtx_lock);
+
+       return 0;
+
+cleanup_tbl:
+       ath11k_peer_rhash_id_tbl_destroy(ab);
+out:
+       mutex_unlock(&ab->tbl_mtx_lock);
+       return ret;
+}
+
+void ath11k_peer_rhash_tbl_destroy(struct ath11k_base *ab)
+{
+       mutex_lock(&ab->tbl_mtx_lock);
+
+       ath11k_peer_rhash_addr_tbl_destroy(ab);
+       ath11k_peer_rhash_id_tbl_destroy(ab);
+
+       mutex_unlock(&ab->tbl_mtx_lock);
+}