return link_conf;
 }
 
+static struct ieee80211_link_sta *ath12k_mac_get_link_sta(struct ath12k_link_sta *arsta)
+{
+       struct ath12k_sta *ahsta = arsta->ahsta;
+       struct ieee80211_sta *sta = ath12k_ahsta_to_sta(ahsta);
+       struct ieee80211_link_sta *link_sta;
+
+       lockdep_assert_wiphy(ahsta->ahvif->ah->hw->wiphy);
+
+       if (arsta->link_id >= IEEE80211_MLD_MAX_NUM_LINKS)
+               return NULL;
+
+       link_sta = wiphy_dereference(ahsta->ahvif->ah->hw->wiphy,
+                                    sta->link[arsta->link_id]);
+
+       return link_sta;
+}
+
 static bool ath12k_mac_bitrate_is_cck(int bitrate)
 {
        switch (bitrate) {
        struct ieee80211_vif *vif = ath12k_ahvif_to_vif(arvif->ahvif);
        struct ieee80211_sta *sta = ath12k_ahsta_to_sta(arsta->ahsta);
        struct wmi_rate_set_arg *rateset = &arg->peer_legacy_rates;
+       struct ieee80211_link_sta *link_sta;
        struct cfg80211_chan_def def;
        const struct ieee80211_supported_band *sband;
        const struct ieee80211_rate *rates;
        if (WARN_ON(ath12k_mac_vif_link_chan(vif, arvif->link_id, &def)))
                return;
 
+       link_sta = ath12k_mac_get_link_sta(arsta);
+       if (!link_sta) {
+               ath12k_warn(ar->ab, "unable to access link sta in peer assoc rates for sta %pM link %u\n",
+                           sta->addr, arsta->link_id);
+               return;
+       }
+
        band = def.chan->band;
        sband = hw->wiphy->bands[band];
-       ratemask = sta->deflink.supp_rates[band];
+       ratemask = link_sta->supp_rates[band];
        ratemask &= arvif->bitrate_mask.control[band].legacy;
        rates = sband->bitrates;
 
 {
        struct ieee80211_vif *vif = ath12k_ahvif_to_vif(arvif->ahvif);
        struct ieee80211_sta *sta = ath12k_ahsta_to_sta(arsta->ahsta);
-       const struct ieee80211_sta_ht_cap *ht_cap = &sta->deflink.ht_cap;
+       const struct ieee80211_sta_ht_cap *ht_cap;
+       struct ieee80211_link_sta *link_sta;
        struct cfg80211_chan_def def;
        enum nl80211_band band;
        const u8 *ht_mcs_mask;
        if (WARN_ON(ath12k_mac_vif_link_chan(vif, arvif->link_id, &def)))
                return;
 
+       link_sta = ath12k_mac_get_link_sta(arsta);
+       if (!link_sta) {
+               ath12k_warn(ar->ab, "unable to access link sta in peer assoc ht for sta %pM link %u\n",
+                           sta->addr, arsta->link_id);
+               return;
+       }
+
+       ht_cap = &link_sta->ht_cap;
        if (!ht_cap->ht_supported)
                return;
 
        if (ht_cap->cap & IEEE80211_HT_CAP_LDPC_CODING)
                arg->ldpc_flag = true;
 
-       if (sta->deflink.bandwidth >= IEEE80211_STA_RX_BW_40) {
+       if (link_sta->bandwidth >= IEEE80211_STA_RX_BW_40) {
                arg->bw_40 = true;
                arg->peer_rate_caps |= WMI_HOST_RC_CW40_FLAG;
        }
                        arg->peer_ht_rates.rates[i] = i;
        } else {
                arg->peer_ht_rates.num_rates = n;
-               arg->peer_nss = min(sta->deflink.rx_nss, max_nss);
+               arg->peer_nss = min(link_sta->rx_nss, max_nss);
        }
 
        ath12k_dbg(ar->ab, ATH12K_DBG_MAC, "mac ht peer %pM mcs cnt %d nss %d\n",
 {
        struct ieee80211_vif *vif = ath12k_ahvif_to_vif(arvif->ahvif);
        struct ieee80211_sta *sta = ath12k_ahsta_to_sta(arsta->ahsta);
-       const struct ieee80211_sta_vht_cap *vht_cap = &sta->deflink.vht_cap;
+       const struct ieee80211_sta_vht_cap *vht_cap;
+       struct ieee80211_link_sta *link_sta;
        struct cfg80211_chan_def def;
        enum nl80211_band band;
        const u16 *vht_mcs_mask;
        if (WARN_ON(ath12k_mac_vif_link_chan(vif, arvif->link_id, &def)))
                return;
 
+       link_sta = ath12k_mac_get_link_sta(arsta);
+       if (!link_sta) {
+               ath12k_warn(ar->ab, "unable to access link sta in peer assoc vht for sta %pM link %u\n",
+                           sta->addr, arsta->link_id);
+               return;
+       }
+
+       vht_cap = &link_sta->vht_cap;
        if (!vht_cap->vht_supported)
                return;
 
                                 (1U << (IEEE80211_HT_MAX_AMPDU_FACTOR +
                                        ampdu_factor)) - 1);
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_80)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_80)
                arg->bw_80 = true;
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_160)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_160)
                arg->bw_160 = true;
 
        /* Calculate peer NSS capability from VHT capabilities if STA
                    vht_mcs_mask[i])
                        max_nss = i + 1;
        }
-       arg->peer_nss = min(sta->deflink.rx_nss, max_nss);
+       arg->peer_nss = min(link_sta->rx_nss, max_nss);
        arg->rx_max_rate = __le16_to_cpu(vht_cap->vht_mcs.rx_highest);
        arg->rx_mcs_set = __le16_to_cpu(vht_cap->vht_mcs.rx_mcs_map);
        arg->tx_max_rate = __le16_to_cpu(vht_cap->vht_mcs.tx_highest);
 {
        struct ieee80211_vif *vif = ath12k_ahvif_to_vif(arvif->ahvif);
        struct ieee80211_sta *sta = ath12k_ahsta_to_sta(arsta->ahsta);
-       const struct ieee80211_sta_he_cap *he_cap = &sta->deflink.he_cap;
+       const struct ieee80211_sta_he_cap *he_cap;
        struct ieee80211_bss_conf *link_conf;
+       struct ieee80211_link_sta *link_sta;
        int i;
        u8 ampdu_factor, max_nss;
        u8 rx_mcs_80 = IEEE80211_HE_MCS_NOT_SUPPORTED;
                return;
        }
 
+       link_sta = ath12k_mac_get_link_sta(arsta);
+       if (!link_sta) {
+               ath12k_warn(ar->ab, "unable to access link sta in peer assoc he for sta %pM link %u\n",
+                           sta->addr, arsta->link_id);
+               return;
+       }
+
+       he_cap = &link_sta->he_cap;
        if (!he_cap->has_he)
                return;
 
        else
                max_nss = rx_mcs_80;
 
-       arg->peer_nss = min(sta->deflink.rx_nss, max_nss);
+       arg->peer_nss = min(link_sta->rx_nss, max_nss);
 
        memcpy(&arg->peer_he_cap_macinfo, he_cap->he_cap_elem.mac_cap_info,
               sizeof(he_cap->he_cap_elem.mac_cap_info));
                                   IEEE80211_HE_MAC_CAP3_MAX_AMPDU_LEN_EXP_MASK);
 
        if (ampdu_factor) {
-               if (sta->deflink.vht_cap.vht_supported)
+               if (link_sta->vht_cap.vht_supported)
                        arg->peer_max_mpdu = (1 << (IEEE80211_HE_VHT_MAX_AMPDU_FACTOR +
                                                    ampdu_factor)) - 1;
-               else if (sta->deflink.ht_cap.ht_supported)
+               else if (link_sta->ht_cap.ht_supported)
                        arg->peer_max_mpdu = (1 << (IEEE80211_HE_HT_MAX_AMPDU_FACTOR +
                                                    ampdu_factor)) - 1;
        }
        if (he_cap->he_cap_elem.mac_cap_info[0] & IEEE80211_HE_MAC_CAP0_TWT_REQ)
                arg->twt_requester = true;
 
-       switch (sta->deflink.bandwidth) {
+       switch (link_sta->bandwidth) {
        case IEEE80211_STA_RX_BW_160:
                if (he_cap->he_cap_elem.phy_cap_info[0] &
                    IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_80PLUS80_MHZ_IN_5G) {
 {
        struct ieee80211_vif *vif = ath12k_ahvif_to_vif(arvif->ahvif);
        struct ieee80211_sta *sta = ath12k_ahsta_to_sta(arsta->ahsta);
-       const struct ieee80211_sta_he_cap *he_cap = &sta->deflink.he_cap;
+       const struct ieee80211_sta_he_cap *he_cap;
+       struct ieee80211_link_sta *link_sta;
        struct cfg80211_chan_def def;
        enum nl80211_band band;
        u8 ampdu_factor, mpdu_density;
 
        band = def.chan->band;
 
-       if (!arg->he_flag || band != NL80211_BAND_6GHZ || !sta->deflink.he_6ghz_capa.capa)
+       link_sta = ath12k_mac_get_link_sta(arsta);
+       if (!link_sta) {
+               ath12k_warn(ar->ab, "unable to access link sta in peer assoc he 6ghz for sta %pM link %u\n",
+                           sta->addr, arsta->link_id);
+               return;
+       }
+
+       he_cap = &link_sta->he_cap;
+
+       if (!arg->he_flag || band != NL80211_BAND_6GHZ || !link_sta->he_6ghz_capa.capa)
                return;
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_40)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_40)
                arg->bw_40 = true;
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_80)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_80)
                arg->bw_80 = true;
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_160)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_160)
                arg->bw_160 = true;
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_320)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_320)
                arg->bw_320 = true;
 
-       arg->peer_he_caps_6ghz = le16_to_cpu(sta->deflink.he_6ghz_capa.capa);
+       arg->peer_he_caps_6ghz = le16_to_cpu(link_sta->he_6ghz_capa.capa);
 
        mpdu_density = u32_get_bits(arg->peer_he_caps_6ghz,
                                    IEEE80211_HE_6GHZ_CAP_MIN_MPDU_START);
                                     struct ath12k_wmi_peer_assoc_arg *arg)
 {
        struct ieee80211_sta *sta = ath12k_ahsta_to_sta(arsta->ahsta);
-       const struct ieee80211_he_6ghz_capa *he_6ghz_capa = &sta->deflink.he_6ghz_capa;
-       const struct ieee80211_sta_ht_cap *ht_cap = &sta->deflink.ht_cap;
+       const struct ieee80211_he_6ghz_capa *he_6ghz_capa;
+       struct ath12k_link_vif *arvif = arsta->arvif;
+       const struct ieee80211_sta_ht_cap *ht_cap;
+       struct ieee80211_link_sta *link_sta;
+       struct ath12k *ar = arvif->ar;
        int smps;
 
+       link_sta = ath12k_mac_get_link_sta(arsta);
+       if (!link_sta) {
+               ath12k_warn(ar->ab, "unable to access link sta in peer assoc he for sta %pM link %u\n",
+                           sta->addr, arsta->link_id);
+               return;
+       }
+
+       he_6ghz_capa = &link_sta->he_6ghz_capa;
+       ht_cap = &link_sta->ht_cap;
+
        if (!ht_cap->ht_supported && !he_6ghz_capa->capa)
                return;
 
        return ret;
 }
 
-static bool ath12k_mac_sta_has_ofdm_only(struct ieee80211_sta *sta)
+static bool ath12k_mac_sta_has_ofdm_only(struct ieee80211_link_sta *sta)
 {
-       return sta->deflink.supp_rates[NL80211_BAND_2GHZ] >>
+       return sta->supp_rates[NL80211_BAND_2GHZ] >>
               ATH12K_MAC_FIRST_OFDM_RATE_IDX;
 }
 
 static enum wmi_phy_mode ath12k_mac_get_phymode_vht(struct ath12k *ar,
-                                                   struct ieee80211_sta *sta)
+                                                   struct ieee80211_link_sta *link_sta)
 {
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_160) {
-               switch (sta->deflink.vht_cap.cap &
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_160) {
+               switch (link_sta->vht_cap.cap &
                        IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK) {
                case IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160MHZ:
                        return MODE_11AC_VHT160;
                }
        }
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_80)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_80)
                return MODE_11AC_VHT80;
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_40)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_40)
                return MODE_11AC_VHT40;
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_20)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_20)
                return MODE_11AC_VHT20;
 
        return MODE_UNKNOWN;
 }
 
 static enum wmi_phy_mode ath12k_mac_get_phymode_he(struct ath12k *ar,
-                                                  struct ieee80211_sta *sta)
+                                                  struct ieee80211_link_sta *link_sta)
 {
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_160) {
-               if (sta->deflink.he_cap.he_cap_elem.phy_cap_info[0] &
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_160) {
+               if (link_sta->he_cap.he_cap_elem.phy_cap_info[0] &
                     IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_160MHZ_IN_5G)
                        return MODE_11AX_HE160;
-               else if (sta->deflink.he_cap.he_cap_elem.phy_cap_info[0] &
+               else if (link_sta->he_cap.he_cap_elem.phy_cap_info[0] &
                     IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_80PLUS80_MHZ_IN_5G)
                        return MODE_11AX_HE80_80;
                /* not sure if this is a valid case? */
                return MODE_11AX_HE160;
        }
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_80)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_80)
                return MODE_11AX_HE80;
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_40)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_40)
                return MODE_11AX_HE40;
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_20)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_20)
                return MODE_11AX_HE20;
 
        return MODE_UNKNOWN;
 }
 
 static enum wmi_phy_mode ath12k_mac_get_phymode_eht(struct ath12k *ar,
-                                                   struct ieee80211_sta *sta)
+                                                   struct ieee80211_link_sta *link_sta)
 {
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_320)
-               if (sta->deflink.eht_cap.eht_cap_elem.phy_cap_info[0] &
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_320)
+               if (link_sta->eht_cap.eht_cap_elem.phy_cap_info[0] &
                    IEEE80211_EHT_PHY_CAP0_320MHZ_IN_6GHZ)
                        return MODE_11BE_EHT320;
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_160) {
-               if (sta->deflink.he_cap.he_cap_elem.phy_cap_info[0] &
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_160) {
+               if (link_sta->he_cap.he_cap_elem.phy_cap_info[0] &
                    IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_160MHZ_IN_5G)
                        return MODE_11BE_EHT160;
 
-               if (sta->deflink.he_cap.he_cap_elem.phy_cap_info[0] &
+               if (link_sta->he_cap.he_cap_elem.phy_cap_info[0] &
                         IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_80PLUS80_MHZ_IN_5G)
                        return MODE_11BE_EHT80_80;
 
                ath12k_warn(ar->ab, "invalid EHT PHY capability info for 160 Mhz: %d\n",
-                           sta->deflink.he_cap.he_cap_elem.phy_cap_info[0]);
+                           link_sta->he_cap.he_cap_elem.phy_cap_info[0]);
 
                return MODE_11BE_EHT160;
        }
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_80)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_80)
                return MODE_11BE_EHT80;
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_40)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_40)
                return MODE_11BE_EHT40;
 
-       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_20)
+       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_20)
                return MODE_11BE_EHT20;
 
        return MODE_UNKNOWN;
                                        struct ath12k_link_sta *arsta,
                                        struct ath12k_wmi_peer_assoc_arg *arg)
 {
+       struct ieee80211_link_sta *link_sta;
        struct cfg80211_chan_def def;
        enum nl80211_band band;
        const u8 *ht_mcs_mask;
        ht_mcs_mask = arvif->bitrate_mask.control[band].ht_mcs;
        vht_mcs_mask = arvif->bitrate_mask.control[band].vht_mcs;
 
+       link_sta = ath12k_mac_get_link_sta(arsta);
+       if (!link_sta) {
+               ath12k_warn(ar->ab, "unable to access link sta in peer assoc he for sta %pM link %u\n",
+                           sta->addr, arsta->link_id);
+               return;
+       }
+
        switch (band) {
        case NL80211_BAND_2GHZ:
-               if (sta->deflink.eht_cap.has_eht) {
-                       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_40)
+               if (link_sta->eht_cap.has_eht) {
+                       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_40)
                                phymode = MODE_11BE_EHT40_2G;
                        else
                                phymode = MODE_11BE_EHT20_2G;
-               } else if (sta->deflink.he_cap.has_he) {
-                       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_80)
+               } else if (link_sta->he_cap.has_he) {
+                       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_80)
                                phymode = MODE_11AX_HE80_2G;
-                       else if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_40)
+                       else if (link_sta->bandwidth == IEEE80211_STA_RX_BW_40)
                                phymode = MODE_11AX_HE40_2G;
                        else
                                phymode = MODE_11AX_HE20_2G;
-               } else if (sta->deflink.vht_cap.vht_supported &&
+               } else if (link_sta->vht_cap.vht_supported &&
                    !ath12k_peer_assoc_h_vht_masked(vht_mcs_mask)) {
-                       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_40)
+                       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_40)
                                phymode = MODE_11AC_VHT40;
                        else
                                phymode = MODE_11AC_VHT20;
-               } else if (sta->deflink.ht_cap.ht_supported &&
+               } else if (link_sta->ht_cap.ht_supported &&
                           !ath12k_peer_assoc_h_ht_masked(ht_mcs_mask)) {
-                       if (sta->deflink.bandwidth == IEEE80211_STA_RX_BW_40)
+                       if (link_sta->bandwidth == IEEE80211_STA_RX_BW_40)
                                phymode = MODE_11NG_HT40;
                        else
                                phymode = MODE_11NG_HT20;
-               } else if (ath12k_mac_sta_has_ofdm_only(sta)) {
+               } else if (ath12k_mac_sta_has_ofdm_only(link_sta)) {
                        phymode = MODE_11G;
                } else {
                        phymode = MODE_11B;
        case NL80211_BAND_5GHZ:
        case NL80211_BAND_6GHZ:
                /* Check EHT first */
-               if (sta->deflink.eht_cap.has_eht) {
-                       phymode = ath12k_mac_get_phymode_eht(ar, sta);
-               } else if (sta->deflink.he_cap.has_he) {
-                       phymode = ath12k_mac_get_phymode_he(ar, sta);
-               } else if (sta->deflink.vht_cap.vht_supported &&
+               if (link_sta->eht_cap.has_eht) {
+                       phymode = ath12k_mac_get_phymode_eht(ar, link_sta);
+               } else if (link_sta->he_cap.has_he) {
+                       phymode = ath12k_mac_get_phymode_he(ar, link_sta);
+               } else if (link_sta->vht_cap.vht_supported &&
                    !ath12k_peer_assoc_h_vht_masked(vht_mcs_mask)) {
-                       phymode = ath12k_mac_get_phymode_vht(ar, sta);
-               } else if (sta->deflink.ht_cap.ht_supported &&
+                       phymode = ath12k_mac_get_phymode_vht(ar, link_sta);
+               } else if (link_sta->ht_cap.ht_supported &&
                           !ath12k_peer_assoc_h_ht_masked(ht_mcs_mask)) {
-                       if (sta->deflink.bandwidth >= IEEE80211_STA_RX_BW_40)
+                       if (link_sta->bandwidth >= IEEE80211_STA_RX_BW_40)
                                phymode = MODE_11NA_HT40;
                        else
                                phymode = MODE_11NA_HT20;
                                    struct ath12k_wmi_peer_assoc_arg *arg)
 {
        struct ieee80211_sta *sta = ath12k_ahsta_to_sta(arsta->ahsta);
-       const struct ieee80211_sta_eht_cap *eht_cap = &sta->deflink.eht_cap;
-       const struct ieee80211_sta_he_cap *he_cap = &sta->deflink.he_cap;
        const struct ieee80211_eht_mcs_nss_supp_20mhz_only *bw_20;
        const struct ieee80211_eht_mcs_nss_supp_bw *bw;
+       const struct ieee80211_sta_eht_cap *eht_cap;
+       const struct ieee80211_sta_he_cap *he_cap;
+       struct ieee80211_link_sta *link_sta;
        u32 *rx_mcs, *tx_mcs;
 
        lockdep_assert_wiphy(ath12k_ar_to_hw(ar)->wiphy);
 
-       if (!sta->deflink.he_cap.has_he || !eht_cap->has_eht)
+       link_sta = ath12k_mac_get_link_sta(arsta);
+       if (!link_sta) {
+               ath12k_warn(ar->ab, "unable to access link sta in peer assoc eht for sta %pM link %u\n",
+                           sta->addr, arsta->link_id);
+               return;
+       }
+
+       eht_cap = &link_sta->eht_cap;
+       he_cap = &link_sta->he_cap;
+       if (!he_cap->has_he || !eht_cap->has_eht)
                return;
 
        arg->eht_flag = true;
        rx_mcs = arg->peer_eht_rx_mcs_set;
        tx_mcs = arg->peer_eht_tx_mcs_set;
 
-       switch (sta->deflink.bandwidth) {
+       switch (link_sta->bandwidth) {
        case IEEE80211_STA_RX_BW_320:
                bw = &eht_cap->eht_mcs_nss_supp.bw._320;
                ath12k_mac_set_eht_mcs(bw->rx_tx_mcs9_max_nss,
        struct ieee80211_vif *vif = ath12k_ahvif_to_vif(arvif->ahvif);
        struct ieee80211_sta *sta = ath12k_ahsta_to_sta(arsta->ahsta);
        struct ath12k_wmi_peer_assoc_arg peer_arg;
+       struct ieee80211_link_sta *link_sta;
        int ret;
        struct cfg80211_chan_def def;
        enum nl80211_band band;
         * fixed param.
         * Note that all other rates and NSS will be disabled for this peer.
         */
-       if (sta->deflink.vht_cap.vht_supported && num_vht_rates == 1) {
+       link_sta = ath12k_mac_get_link_sta(arsta);
+       if (!link_sta) {
+               ath12k_warn(ar->ab, "unable to access link sta in station assoc\n");
+               return -EINVAL;
+       }
+
+       if (link_sta->vht_cap.vht_supported && num_vht_rates == 1) {
                ret = ath12k_mac_set_peer_vht_fixed_rate(arvif, arsta, mask,
                                                         band);
                if (ret)
                return 0;
 
        ret = ath12k_setup_peer_smps(ar, arvif, arsta->addr,
-                                    &sta->deflink.ht_cap,
-                                    &sta->deflink.he_6ghz_capa);
+                                    &link_sta->ht_cap, &link_sta->he_6ghz_capa);
        if (ret) {
                ath12k_warn(ar->ab, "failed to setup peer SMPS for vdev %d: %d\n",
                            arvif->vdev_id, ret);
 
 static void ath12k_sta_rc_update_wk(struct wiphy *wiphy, struct wiphy_work *wk)
 {
+       struct ieee80211_link_sta *link_sta;
        struct ath12k *ar;
        struct ath12k_link_vif *arvif;
        struct ieee80211_sta *sta;
                 * TODO: Check RATEMASK_CMDID to support auto rates selection
                 * across HT/VHT and for multiple VHT MCS support.
                 */
-               if (sta->deflink.vht_cap.vht_supported && num_vht_rates == 1) {
+               link_sta = ath12k_mac_get_link_sta(arsta);
+               if (!link_sta) {
+                       ath12k_warn(ar->ab, "unable to access link sta in peer assoc he for sta %pM link %u\n",
+                                   sta->addr, arsta->link_id);
+                       return;
+               }
+
+               if (link_sta->vht_cap.vht_supported && num_vht_rates == 1) {
                        ath12k_mac_set_peer_vht_fixed_rate(arvif, arsta, mask,
                                                           band);
                } else {
 
        spin_unlock_bh(&ar->ab->base_lock);
 
+       if (arsta->link_id >= IEEE80211_MLD_MAX_NUM_LINKS) {
+               rcu_read_unlock();
+               return;
+       }
+
+       link_sta = rcu_dereference(sta->link[arsta->link_id]);
+       if (!link_sta) {
+               rcu_read_unlock();
+               ath12k_warn(ar->ab, "unable to access link sta in rc update for sta %pM link %u\n",
+                           sta->addr, arsta->link_id);
+               return;
+       }
+
        ath12k_dbg(ar->ab, ATH12K_DBG_MAC,
                   "mac sta rc update for %pM changed %08x bw %d nss %d smps %d\n",
-                  arsta->addr, changed, sta->deflink.bandwidth, sta->deflink.rx_nss,
-                  sta->deflink.smps_mode);
+                  arsta->addr, changed, link_sta->bandwidth, link_sta->rx_nss,
+                  link_sta->smps_mode);
 
        spin_lock_bh(&ar->data_lock);
 
        }
 
        if (changed & IEEE80211_RC_NSS_CHANGED)
-               arsta->nss = sta->deflink.rx_nss;
+               arsta->nss = link_sta->rx_nss;
 
        if (changed & IEEE80211_RC_SMPS_CHANGED) {
                smps = WMI_PEER_SMPS_PS_NONE;
 
-               switch (sta->deflink.smps_mode) {
+               switch (link_sta->smps_mode) {
                case IEEE80211_SMPS_AUTOMATIC:
                case IEEE80211_SMPS_OFF:
                        smps = WMI_PEER_SMPS_PS_NONE;
                        smps = WMI_PEER_SMPS_DYNAMIC;
                        break;
                default:
-                       ath12k_warn(ar->ab, "Invalid smps %d in sta rc update for %pM\n",
-                                   sta->deflink.smps_mode, arsta->addr);
+                       ath12k_warn(ar->ab, "Invalid smps %d in sta rc update for %pM link %u\n",
+                                   link_sta->smps_mode, arsta->addr, link_sta->link_id);
                        smps = WMI_PEER_SMPS_PS_NONE;
                        break;
                }
 {
        struct ath12k_link_vif *arvif = data;
        struct ath12k_sta *ahsta = ath12k_sta_to_ahsta(sta);
-       struct ath12k_link_sta *arsta = &ahsta->deflink;
+       struct ath12k_link_sta *arsta;
        struct ath12k *ar = arvif->ar;
 
-       if (arsta->arvif != arvif)
+       arsta = rcu_dereference(ahsta->link[arvif->link_id]);
+       if (!arsta || arsta->arvif != arvif)
                return;
 
        spin_lock_bh(&ar->data_lock);
 {
        struct ath12k_link_vif *arvif = data;
        struct ath12k_sta *ahsta = ath12k_sta_to_ahsta(sta);
-       struct ath12k_link_sta *arsta = &ahsta->deflink;
+       struct ath12k_link_sta *arsta;
        struct ath12k *ar = arvif->ar;
        int ret;
 
-       if (arsta->arvif != arvif)
+       lockdep_assert_wiphy(ath12k_ar_to_hw(ar)->wiphy);
+
+       arsta = wiphy_dereference(ath12k_ar_to_hw(ar)->wiphy,
+                                 ahsta->link[arvif->link_id]);
+
+       if (!arsta || arsta->arvif != arvif)
                return;
 
        ret = ath12k_wmi_set_peer_param(ar, arsta->addr,