#include "dp_tx.h"
 #include "debug.h"
 #include "hw.h"
+#include "peer.h"
+#include "mac.h"
 
 static enum hal_tcl_encap_type
 ath12k_dp_tx_get_encap_type(struct ath12k_link_vif *arvif, struct sk_buff *skb)
        }
 }
 
+static void ath12k_dp_tx_update_txcompl(struct ath12k *ar, struct hal_tx_status *ts)
+{
+       struct ath12k_base *ab = ar->ab;
+       struct ath12k_peer *peer;
+       struct ieee80211_sta *sta;
+       struct ath12k_sta *ahsta;
+       struct ath12k_link_sta *arsta;
+       struct rate_info txrate = {0};
+       u16 rate, ru_tones;
+       u8 rate_idx = 0;
+       int ret;
+
+       spin_lock_bh(&ab->base_lock);
+       peer = ath12k_peer_find_by_id(ab, ts->peer_id);
+       if (!peer || !peer->sta) {
+               ath12k_dbg(ab, ATH12K_DBG_DP_TX,
+                          "failed to find the peer by id %u\n", ts->peer_id);
+               spin_unlock_bh(&ab->base_lock);
+               return;
+       }
+       sta = peer->sta;
+       ahsta = ath12k_sta_to_ahsta(sta);
+       arsta = &ahsta->deflink;
+
+       /* This is to prefer choose the real NSS value arsta->last_txrate.nss,
+        * if it is invalid, then choose the NSS value while assoc.
+        */
+       if (arsta->last_txrate.nss)
+               txrate.nss = arsta->last_txrate.nss;
+       else
+               txrate.nss = arsta->peer_nss;
+       spin_unlock_bh(&ab->base_lock);
+
+       switch (ts->pkt_type) {
+       case HAL_TX_RATE_STATS_PKT_TYPE_11A:
+       case HAL_TX_RATE_STATS_PKT_TYPE_11B:
+               ret = ath12k_mac_hw_ratecode_to_legacy_rate(ts->mcs,
+                                                           ts->pkt_type,
+                                                           &rate_idx,
+                                                           &rate);
+               if (ret < 0) {
+                       ath12k_warn(ab, "Invalid tx legacy rate %d\n", ret);
+                       return;
+               }
+
+               txrate.legacy = rate;
+               break;
+       case HAL_TX_RATE_STATS_PKT_TYPE_11N:
+               if (ts->mcs > ATH12K_HT_MCS_MAX) {
+                       ath12k_warn(ab, "Invalid HT mcs index %d\n", ts->mcs);
+                       return;
+               }
+
+               if (txrate.nss != 0)
+                       txrate.mcs = ts->mcs + 8 * (txrate.nss - 1);
+
+               txrate.flags = RATE_INFO_FLAGS_MCS;
+
+               if (ts->sgi)
+                       txrate.flags |= RATE_INFO_FLAGS_SHORT_GI;
+               break;
+       case HAL_TX_RATE_STATS_PKT_TYPE_11AC:
+               if (ts->mcs > ATH12K_VHT_MCS_MAX) {
+                       ath12k_warn(ab, "Invalid VHT mcs index %d\n", ts->mcs);
+                       return;
+               }
+
+               txrate.mcs = ts->mcs;
+               txrate.flags = RATE_INFO_FLAGS_VHT_MCS;
+
+               if (ts->sgi)
+                       txrate.flags |= RATE_INFO_FLAGS_SHORT_GI;
+               break;
+       case HAL_TX_RATE_STATS_PKT_TYPE_11AX:
+               if (ts->mcs > ATH12K_HE_MCS_MAX) {
+                       ath12k_warn(ab, "Invalid HE mcs index %d\n", ts->mcs);
+                       return;
+               }
+
+               txrate.mcs = ts->mcs;
+               txrate.flags = RATE_INFO_FLAGS_HE_MCS;
+               txrate.he_gi = ath12k_he_gi_to_nl80211_he_gi(ts->sgi);
+               break;
+       case HAL_TX_RATE_STATS_PKT_TYPE_11BE:
+               if (ts->mcs > ATH12K_EHT_MCS_MAX) {
+                       ath12k_warn(ab, "Invalid EHT mcs index %d\n", ts->mcs);
+                       return;
+               }
+
+               txrate.mcs = ts->mcs;
+               txrate.flags = RATE_INFO_FLAGS_EHT_MCS;
+               txrate.eht_gi = ath12k_mac_eht_gi_to_nl80211_eht_gi(ts->sgi);
+               break;
+       default:
+               ath12k_warn(ab, "Invalid tx pkt type: %d\n", ts->pkt_type);
+               return;
+       }
+
+       txrate.bw = ath12k_mac_bw_to_mac80211_bw(ts->bw);
+
+       if (ts->ofdma && ts->pkt_type == HAL_TX_RATE_STATS_PKT_TYPE_11AX) {
+               txrate.bw = RATE_INFO_BW_HE_RU;
+               ru_tones = ath12k_mac_he_convert_tones_to_ru_tones(ts->tones);
+               txrate.he_ru_alloc =
+                       ath12k_he_ru_tones_to_nl80211_he_ru_alloc(ru_tones);
+       }
+
+       if (ts->ofdma && ts->pkt_type == HAL_TX_RATE_STATS_PKT_TYPE_11BE) {
+               txrate.bw = RATE_INFO_BW_EHT_RU;
+               txrate.eht_ru_alloc =
+                       ath12k_mac_eht_ru_tones_to_nl80211_eht_ru_alloc(ts->tones);
+       }
+
+       spin_lock_bh(&ab->base_lock);
+       arsta->txrate = txrate;
+       spin_unlock_bh(&ab->base_lock);
+}
+
 static void ath12k_dp_tx_complete_msdu(struct ath12k *ar,
                                       struct sk_buff *msdu,
                                       struct hal_tx_status *ts)
         * Might end up reporting it out-of-band from HTT stats.
         */
 
+       ath12k_dp_tx_update_txcompl(ar, ts);
+
        ieee80211_tx_status_skb(ath12k_ar_to_hw(ar), msdu);
 
 exit:
                                      struct hal_wbm_completion_ring_tx *desc,
                                      struct hal_tx_status *ts)
 {
+       u32 info0 = le32_to_cpu(desc->rate_stats.info0);
+
        ts->buf_rel_source =
                le32_get_bits(desc->info0, HAL_WBM_COMPL_TX_INFO0_REL_SRC_MODULE);
        if (ts->buf_rel_source != HAL_WBM_REL_SRC_MODULE_FW &&
 
        ts->ppdu_id = le32_get_bits(desc->info1,
                                    HAL_WBM_COMPL_TX_INFO1_TQM_STATUS_NUMBER);
-       if (le32_to_cpu(desc->rate_stats.info0) & HAL_TX_RATE_STATS_INFO0_VALID)
-               ts->rate_stats = le32_to_cpu(desc->rate_stats.info0);
-       else
-               ts->rate_stats = 0;
+
+       ts->peer_id = le32_get_bits(desc->info3, HAL_WBM_COMPL_TX_INFO3_PEER_ID);
+
+       if (info0 & HAL_TX_RATE_STATS_INFO0_VALID) {
+               ts->pkt_type = u32_get_bits(info0, HAL_TX_RATE_STATS_INFO0_PKT_TYPE);
+               ts->mcs = u32_get_bits(info0, HAL_TX_RATE_STATS_INFO0_MCS);
+               ts->sgi = u32_get_bits(info0, HAL_TX_RATE_STATS_INFO0_SGI);
+               ts->bw = u32_get_bits(info0, HAL_TX_RATE_STATS_INFO0_BW);
+               ts->tones = u32_get_bits(info0, HAL_TX_RATE_STATS_INFO0_TONES_IN_RU);
+               ts->ofdma = u32_get_bits(info0, HAL_TX_RATE_STATS_INFO0_OFDMA_TX);
+       }
 }
 
 void ath12k_dp_tx_completion_handler(struct ath12k_base *ab, int ring_id)
 
        return "<unknown>";
 }
 
+u16 ath12k_mac_he_convert_tones_to_ru_tones(u16 tones)
+{
+       switch (tones) {
+       case 26:
+               return RU_26;
+       case 52:
+               return RU_52;
+       case 106:
+               return RU_106;
+       case 242:
+               return RU_242;
+       case 484:
+               return RU_484;
+       case 996:
+               return RU_996;
+       case (996 * 2):
+               return RU_2X996;
+       default:
+               return RU_26;
+       }
+}
+
+enum nl80211_eht_gi ath12k_mac_eht_gi_to_nl80211_eht_gi(u8 sgi)
+{
+       switch (sgi) {
+       case RX_MSDU_START_SGI_0_8_US:
+               return NL80211_RATE_INFO_EHT_GI_0_8;
+       case RX_MSDU_START_SGI_1_6_US:
+               return NL80211_RATE_INFO_EHT_GI_1_6;
+       case RX_MSDU_START_SGI_3_2_US:
+               return NL80211_RATE_INFO_EHT_GI_3_2;
+       default:
+               return NL80211_RATE_INFO_EHT_GI_0_8;
+       }
+}
+
+enum nl80211_eht_ru_alloc ath12k_mac_eht_ru_tones_to_nl80211_eht_ru_alloc(u16 ru_tones)
+{
+       switch (ru_tones) {
+       case 26:
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_26;
+       case 52:
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_52;
+       case (52 + 26):
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_52P26;
+       case 106:
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_106;
+       case (106 + 26):
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_106P26;
+       case 242:
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_242;
+       case 484:
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_484;
+       case (484 + 242):
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_484P242;
+       case 996:
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_996;
+       case (996 + 484):
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_996P484;
+       case (996 + 484 + 242):
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_996P484P242;
+       case (2 * 996):
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_2x996;
+       case (2 * 996 + 484):
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_2x996P484;
+       case (3 * 996):
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_3x996;
+       case (3 * 996 + 484):
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_3x996P484;
+       case (4 * 996):
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_4x996;
+       default:
+               return NL80211_RATE_INFO_EHT_RU_ALLOC_26;
+       }
+}
+
 enum rate_info_bw
 ath12k_mac_bw_to_mac80211_bw(enum ath12k_supported_bw bw)
 {
        ath12k_peer_assoc_h_smps(arsta, arg);
        ath12k_peer_assoc_h_mlo(arsta, arg);
 
+       arsta->peer_nss = arg->peer_nss;
        /* TODO: amsdu_disable req? */
 }
 
                sinfo->txrate.he_gi = arsta->txrate.he_gi;
                sinfo->txrate.he_dcm = arsta->txrate.he_dcm;
                sinfo->txrate.he_ru_alloc = arsta->txrate.he_ru_alloc;
+               sinfo->txrate.eht_gi = arsta->txrate.eht_gi;
+               sinfo->txrate.eht_ru_alloc = arsta->txrate.eht_ru_alloc;
        }
        sinfo->txrate.flags = arsta->txrate.flags;
        sinfo->filled |= BIT_ULL(NL80211_STA_INFO_TX_BITRATE);