From: Johannes Berg <johannes.berg@intel.com>
Date: Wed, 15 Jun 2022 07:20:45 +0000 (+0200)
Subject: wifi: mac80211: RCU-ify link STA pointers
X-Git-Url: https://www.infradead.org/git/?a=commitdiff_plain;h=c71420db653aba30a234d1e4cf86dde376e604fa;p=linux.git

wifi: mac80211: RCU-ify link STA pointers

We need to be able to access these in a race-free way under
traffic while adding/removing them, so RCU-ify the pointers.
This requires passing a link_sta to a lot of functions so
we don't have to do the RCU handling everywhere.

Signed-off-by: Johannes Berg <johannes.berg@intel.com>
---

diff --git a/include/net/mac80211.h b/include/net/mac80211.h
index 06545ef1d41c..27f24ac0426d 100644
--- a/include/net/mac80211.h
+++ b/include/net/mac80211.h
@@ -2201,7 +2201,7 @@ struct ieee80211_sta {
 
 	u16 valid_links;
 	struct ieee80211_link_sta deflink;
-	struct ieee80211_link_sta *link[IEEE80211_MLD_MAX_NUM_LINKS];
+	struct ieee80211_link_sta __rcu *link[IEEE80211_MLD_MAX_NUM_LINKS];
 
 	/* must be last */
 	u8 drv_priv[] __aligned(sizeof(void *));
diff --git a/net/mac80211/cfg.c b/net/mac80211/cfg.c
index 009f1723c990..b387f5f4fef0 100644
--- a/net/mac80211/cfg.c
+++ b/net/mac80211/cfg.c
@@ -1771,19 +1771,21 @@ static int sta_apply_parameters(struct ieee80211_local *local,
 
 	if (params->ht_capa)
 		ieee80211_ht_cap_ie_to_sta_ht_cap(sdata, sband,
-						  params->ht_capa, sta, 0);
+						  params->ht_capa,
+						  &sta->deflink);
 
 	/* VHT can override some HT caps such as the A-MSDU max length */
 	if (params->vht_capa)
 		ieee80211_vht_cap_ie_to_sta_vht_cap(sdata, sband,
-						    params->vht_capa, sta, 0);
+						    params->vht_capa,
+						    &sta->deflink);
 
 	if (params->he_capa)
 		ieee80211_he_cap_ie_to_sta_he_cap(sdata, sband,
 						  (void *)params->he_capa,
 						  params->he_capa_len,
 						  (void *)params->he_6ghz_capa,
-						  sta, 0);
+						  &sta->deflink);
 
 	if (params->eht_capa)
 		ieee80211_eht_cap_ie_to_sta_eht_cap(sdata, sband,
@@ -1791,13 +1793,13 @@ static int sta_apply_parameters(struct ieee80211_local *local,
 						    params->he_capa_len,
 						    params->eht_capa,
 						    params->eht_capa_len,
-						    sta, 0);
+						    &sta->deflink);
 
 	if (params->opmode_notif_used) {
 		/* returned value is only needed for rc update, but the
 		 * rc isn't initialized here yet, so ignore it
 		 */
-		__ieee80211_vht_handle_opmode(sdata, sta, 0,
+		__ieee80211_vht_handle_opmode(sdata, &sta->deflink,
 					      params->opmode_notif,
 					      sband->band);
 	}
diff --git a/net/mac80211/chan.c b/net/mac80211/chan.c
index 4f25660d0eeb..6853b563fb6c 100644
--- a/net/mac80211/chan.c
+++ b/net/mac80211/chan.c
@@ -207,16 +207,20 @@ ieee80211_find_reservation_chanctx(struct ieee80211_local *local,
 static enum nl80211_chan_width ieee80211_get_sta_bw(struct sta_info *sta,
 						    unsigned int link_id)
 {
-	enum ieee80211_sta_rx_bandwidth width =
-		ieee80211_sta_cap_rx_bw(sta, link_id);
+	enum ieee80211_sta_rx_bandwidth width;
+	struct link_sta_info *link_sta;
+
+	link_sta = rcu_dereference(sta->link[link_id]);
 
 	/* no effect if this STA has no presence on this link */
-	if (!sta->sta.link[link_id])
+	if (!link_sta)
 		return NL80211_CHAN_WIDTH_20_NOHT;
 
+	width = ieee80211_sta_cap_rx_bw(link_sta);
+
 	switch (width) {
 	case IEEE80211_STA_RX_BW_20:
-		if (sta->sta.link[link_id]->ht_cap.ht_supported)
+		if (link_sta->pub->ht_cap.ht_supported)
 			return NL80211_CHAN_WIDTH_20;
 		else
 			return NL80211_CHAN_WIDTH_20_NOHT;
@@ -416,6 +420,7 @@ static void ieee80211_chan_bw_change(struct ieee80211_local *local,
 		for (link_id = 0; link_id < ARRAY_SIZE(sta->sdata->link); link_id++) {
 			struct ieee80211_bss_conf *link_conf =
 				sdata->vif.link_conf[link_id];
+			struct link_sta_info *link_sta;
 
 			if (!link_conf)
 				continue;
@@ -423,18 +428,22 @@ static void ieee80211_chan_bw_change(struct ieee80211_local *local,
 			if (rcu_access_pointer(link_conf->chanctx_conf) != &ctx->conf)
 				continue;
 
-			new_sta_bw = ieee80211_sta_cur_vht_bw(sta, link_id);
+			link_sta = rcu_dereference(sta->link[link_id]);
+			if (!link_sta)
+				continue;
+
+			new_sta_bw = ieee80211_sta_cur_vht_bw(link_sta);
 
 			/* nothing change */
-			if (new_sta_bw == sta->sta.link[link_id]->bandwidth)
+			if (new_sta_bw == link_sta->pub->bandwidth)
 				continue;
 
 			/* vif changed to narrow BW and narrow BW for station wasn't
 			 * requested or vise versa */
-			if ((new_sta_bw < sta->sta.link[link_id]->bandwidth) == !narrowed)
+			if ((new_sta_bw < link_sta->pub->bandwidth) == !narrowed)
 				continue;
 
-			sta->sta.link[link_id]->bandwidth = new_sta_bw;
+			link_sta->pub->bandwidth = new_sta_bw;
 			rate_control_rate_update(local, sband, sta, link_id,
 						 IEEE80211_RC_BW_CHANGED);
 		}
diff --git a/net/mac80211/eht.c b/net/mac80211/eht.c
index de762a803c38..31e20a342f21 100644
--- a/net/mac80211/eht.c
+++ b/net/mac80211/eht.c
@@ -12,11 +12,10 @@ ieee80211_eht_cap_ie_to_sta_eht_cap(struct ieee80211_sub_if_data *sdata,
 				    struct ieee80211_supported_band *sband,
 				    const u8 *he_cap_ie, u8 he_cap_len,
 				    const struct ieee80211_eht_cap_elem *eht_cap_ie_elem,
-				    u8 eht_cap_len, struct sta_info *sta,
-				    unsigned int link_id)
+				    u8 eht_cap_len,
+				    struct link_sta_info *link_sta)
 {
-	struct ieee80211_sta_eht_cap *eht_cap =
-		&sta->sta.link[link_id]->eht_cap;
+	struct ieee80211_sta_eht_cap *eht_cap = &link_sta->pub->eht_cap;
 	struct ieee80211_he_cap_elem *he_cap_ie_elem = (void *)he_cap_ie;
 	u8 eht_ppe_size = 0;
 	u8 mcs_nss_size;
@@ -73,8 +72,6 @@ ieee80211_eht_cap_ie_to_sta_eht_cap(struct ieee80211_sub_if_data *sdata,
 
 	eht_cap->has_eht = true;
 
-	sta->link[link_id]->cur_max_bandwidth =
-		ieee80211_sta_cap_rx_bw(sta, link_id);
-	sta->sta.link[link_id]->bandwidth =
-		ieee80211_sta_cur_vht_bw(sta, link_id);
+	link_sta->cur_max_bandwidth = ieee80211_sta_cap_rx_bw(link_sta);
+	link_sta->pub->bandwidth = ieee80211_sta_cur_vht_bw(link_sta);
 }
diff --git a/net/mac80211/he.c b/net/mac80211/he.c
index e48e9a021622..d9228fd3f77a 100644
--- a/net/mac80211/he.c
+++ b/net/mac80211/he.c
@@ -10,8 +10,9 @@
 
 static void
 ieee80211_update_from_he_6ghz_capa(const struct ieee80211_he_6ghz_capa *he_6ghz_capa,
-				   struct sta_info *sta, unsigned int link_id)
+				   struct link_sta_info *link_sta)
 {
+	struct sta_info *sta = link_sta->sta;
 	enum ieee80211_smps_mode smps_mode;
 
 	if (sta->sdata->vif.type == NL80211_IFTYPE_AP ||
@@ -49,7 +50,7 @@ ieee80211_update_from_he_6ghz_capa(const struct ieee80211_he_6ghz_capa *he_6ghz_
 		break;
 	}
 
-	sta->sta.link[link_id]->he_6ghz_capa = *he_6ghz_capa;
+	link_sta->pub->he_6ghz_capa = *he_6ghz_capa;
 }
 
 static void ieee80211_he_mcs_disable(__le16 *he_mcs)
@@ -108,9 +109,9 @@ ieee80211_he_cap_ie_to_sta_he_cap(struct ieee80211_sub_if_data *sdata,
 				  struct ieee80211_supported_band *sband,
 				  const u8 *he_cap_ie, u8 he_cap_len,
 				  const struct ieee80211_he_6ghz_capa *he_6ghz_capa,
-				  struct sta_info *sta, unsigned int link_id)
+				  struct link_sta_info *link_sta)
 {
-	struct ieee80211_sta_he_cap *he_cap = &sta->sta.link[link_id]->he_cap;
+	struct ieee80211_sta_he_cap *he_cap = &link_sta->pub->he_cap;
 	struct ieee80211_sta_he_cap own_he_cap;
 	struct ieee80211_he_cap_elem *he_cap_ie_elem = (void *)he_cap_ie;
 	u8 he_ppe_size;
@@ -153,13 +154,11 @@ ieee80211_he_cap_ie_to_sta_he_cap(struct ieee80211_sub_if_data *sdata,
 
 	he_cap->has_he = true;
 
-	sta->link[link_id]->cur_max_bandwidth =
-		ieee80211_sta_cap_rx_bw(sta, link_id);
-	sta->sta.link[link_id]->bandwidth =
-		ieee80211_sta_cur_vht_bw(sta, link_id);
+	link_sta->cur_max_bandwidth = ieee80211_sta_cap_rx_bw(link_sta);
+	link_sta->pub->bandwidth = ieee80211_sta_cur_vht_bw(link_sta);
 
 	if (sband->band == NL80211_BAND_6GHZ && he_6ghz_capa)
-		ieee80211_update_from_he_6ghz_capa(he_6ghz_capa, sta, link_id);
+		ieee80211_update_from_he_6ghz_capa(he_6ghz_capa, link_sta);
 
 	ieee80211_he_mcs_intersection(&own_he_cap.he_mcs_nss_supp.rx_mcs_80,
 				      &he_cap->he_mcs_nss_supp.rx_mcs_80,
diff --git a/net/mac80211/ht.c b/net/mac80211/ht.c
index 111828b559a0..2eb3a409b70f 100644
--- a/net/mac80211/ht.c
+++ b/net/mac80211/ht.c
@@ -138,8 +138,9 @@ void ieee80211_apply_htcap_overrides(struct ieee80211_sub_if_data *sdata,
 bool ieee80211_ht_cap_ie_to_sta_ht_cap(struct ieee80211_sub_if_data *sdata,
 				       struct ieee80211_supported_band *sband,
 				       const struct ieee80211_ht_cap *ht_cap_ie,
-				       struct sta_info *sta, unsigned int link_id)
+				       struct link_sta_info *link_sta)
 {
+	struct sta_info *sta = link_sta->sta;
 	struct ieee80211_sta_ht_cap ht_cap, own_cap;
 	u8 ampdu_info, tx_mcs_set_cap;
 	int i, max_tx_streams;
@@ -243,12 +244,11 @@ bool ieee80211_ht_cap_ie_to_sta_ht_cap(struct ieee80211_sub_if_data *sdata,
 		sta->sta.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_HT_3839;
 
  apply:
-	changed = memcmp(&sta->sta.link[link_id]->ht_cap,
-			 &ht_cap, sizeof(ht_cap));
+	changed = memcmp(&link_sta->pub->ht_cap, &ht_cap, sizeof(ht_cap));
 
-	memcpy(&sta->sta.link[link_id]->ht_cap, &ht_cap, sizeof(ht_cap));
+	memcpy(&link_sta->pub->ht_cap, &ht_cap, sizeof(ht_cap));
 
-	switch (sdata->vif.link_conf[link_id]->chandef.width) {
+	switch (sdata->vif.link_conf[link_sta->link_id]->chandef.width) {
 	default:
 		WARN_ON_ONCE(1);
 		fallthrough;
@@ -265,9 +265,9 @@ bool ieee80211_ht_cap_ie_to_sta_ht_cap(struct ieee80211_sub_if_data *sdata,
 		break;
 	}
 
-	sta->sta.link[link_id]->bandwidth = bw;
+	link_sta->pub->bandwidth = bw;
 
-	sta->link[link_id]->cur_max_bandwidth =
+	link_sta->cur_max_bandwidth =
 		ht_cap.cap & IEEE80211_HT_CAP_SUP_WIDTH_20_40 ?
 				IEEE80211_STA_RX_BW_40 : IEEE80211_STA_RX_BW_20;
 
diff --git a/net/mac80211/ibss.c b/net/mac80211/ibss.c
index 65a3808dc92a..65b6255c6747 100644
--- a/net/mac80211/ibss.c
+++ b/net/mac80211/ibss.c
@@ -1051,7 +1051,7 @@ static void ieee80211_update_sta_info(struct ieee80211_sub_if_data *sdata,
 		memcpy(&htcap_ie, elems->ht_cap_elem, sizeof(htcap_ie));
 		rates_updated |= ieee80211_ht_cap_ie_to_sta_ht_cap(sdata, sband,
 								   &htcap_ie,
-								   sta, 0);
+								   &sta->deflink);
 
 		if (elems->vht_operation && elems->vht_cap_elem &&
 		    sdata->u.ibss.chandef.width != NL80211_CHAN_WIDTH_20 &&
@@ -1068,7 +1068,8 @@ static void ieee80211_update_sta_info(struct ieee80211_sub_if_data *sdata,
 						   &chandef);
 			memcpy(&cap_ie, elems->vht_cap_elem, sizeof(cap_ie));
 			ieee80211_vht_cap_ie_to_sta_vht_cap(sdata, sband,
-							    &cap_ie, sta, 0);
+							    &cap_ie,
+							    &sta->deflink);
 			if (memcmp(&cap, &sta->sta.deflink.vht_cap, sizeof(cap)))
 				rates_updated |= true;
 		}
diff --git a/net/mac80211/ieee80211_i.h b/net/mac80211/ieee80211_i.h
index aa5e4c2734b7..46f4e89825a0 100644
--- a/net/mac80211/ieee80211_i.h
+++ b/net/mac80211/ieee80211_i.h
@@ -2061,7 +2061,7 @@ void ieee80211_apply_htcap_overrides(struct ieee80211_sub_if_data *sdata,
 bool ieee80211_ht_cap_ie_to_sta_ht_cap(struct ieee80211_sub_if_data *sdata,
 				       struct ieee80211_supported_band *sband,
 				       const struct ieee80211_ht_cap *ht_cap_ie,
-				       struct sta_info *sta, unsigned int link_id);
+				       struct link_sta_info *link_sta);
 void ieee80211_send_delba(struct ieee80211_sub_if_data *sdata,
 			  const u8 *da, u16 tid,
 			  u16 initiator, u16 reason_code);
@@ -2117,31 +2117,31 @@ void
 ieee80211_vht_cap_ie_to_sta_vht_cap(struct ieee80211_sub_if_data *sdata,
 				    struct ieee80211_supported_band *sband,
 				    const struct ieee80211_vht_cap *vht_cap_ie,
-				    struct sta_info *sta, unsigned int link_id);
-enum ieee80211_sta_rx_bandwidth ieee80211_sta_cap_rx_bw(struct sta_info *sta,
-							unsigned int link_id);
-enum ieee80211_sta_rx_bandwidth ieee80211_sta_cur_vht_bw(struct sta_info *sta,
-							 unsigned int link_id);
-void ieee80211_sta_set_rx_nss(struct sta_info *sta, unsigned int link_id);
+				    struct link_sta_info *link_sta);
+enum ieee80211_sta_rx_bandwidth
+ieee80211_sta_cap_rx_bw(struct link_sta_info *link_sta);
+enum ieee80211_sta_rx_bandwidth
+ieee80211_sta_cur_vht_bw(struct link_sta_info *link_sta);
+void ieee80211_sta_set_rx_nss(struct link_sta_info *link_sta);
 enum ieee80211_sta_rx_bandwidth
 ieee80211_chan_width_to_rx_bw(enum nl80211_chan_width width);
-enum nl80211_chan_width ieee80211_sta_cap_chan_bw(struct sta_info *sta,
-						  unsigned int link_id);
+enum nl80211_chan_width
+ieee80211_sta_cap_chan_bw(struct link_sta_info *link_sta);
 void ieee80211_process_mu_groups(struct ieee80211_sub_if_data *sdata,
 				 unsigned int link_id,
 				 struct ieee80211_mgmt *mgmt);
 u32 __ieee80211_vht_handle_opmode(struct ieee80211_sub_if_data *sdata,
-				  struct sta_info *sta, unsigned int link_id,
+				  struct link_sta_info *sta,
 				  u8 opmode, enum nl80211_band band);
 void ieee80211_vht_handle_opmode(struct ieee80211_sub_if_data *sdata,
-				 struct sta_info *sta, unsigned int link_id,
+				 struct link_sta_info *sta,
 				 u8 opmode, enum nl80211_band band);
 void ieee80211_apply_vhtcap_overrides(struct ieee80211_sub_if_data *sdata,
 				      struct ieee80211_sta_vht_cap *vht_cap);
 void ieee80211_get_vht_mask_from_cap(__le16 vht_cap,
 				     u16 vht_mask[NL80211_VHT_NSS_MAX]);
 enum nl80211_chan_width
-ieee80211_sta_rx_bw_to_chan_width(struct sta_info *sta, unsigned int link_id);
+ieee80211_sta_rx_bw_to_chan_width(struct link_sta_info *sta);
 
 /* HE */
 void
@@ -2149,7 +2149,7 @@ ieee80211_he_cap_ie_to_sta_he_cap(struct ieee80211_sub_if_data *sdata,
 				  struct ieee80211_supported_band *sband,
 				  const u8 *he_cap_ie, u8 he_cap_len,
 				  const struct ieee80211_he_6ghz_capa *he_6ghz_capa,
-				  struct sta_info *sta, unsigned int link_id);
+				  struct link_sta_info *link_sta);
 void
 ieee80211_he_spr_ie_to_bss_conf(struct ieee80211_vif *vif,
 				const struct ieee80211_he_spr *he_spr_ie_elem);
@@ -2560,6 +2560,6 @@ ieee80211_eht_cap_ie_to_sta_eht_cap(struct ieee80211_sub_if_data *sdata,
 				    struct ieee80211_supported_band *sband,
 				    const u8 *he_cap_ie, u8 he_cap_len,
 				    const struct ieee80211_eht_cap_elem *eht_cap_ie_elem,
-				    u8 eht_cap_len, struct sta_info *sta,
-				    unsigned int link_id);
+				    u8 eht_cap_len,
+				    struct link_sta_info *link_sta);
 #endif /* IEEE80211_I_H */
diff --git a/net/mac80211/iface.c b/net/mac80211/iface.c
index f118e7710fb1..0cf5a395d925 100644
--- a/net/mac80211/iface.c
+++ b/net/mac80211/iface.c
@@ -1525,7 +1525,8 @@ static void ieee80211_iface_process_skb(struct ieee80211_local *local,
 			sta = sta_info_get_bss(sdata, mgmt->sa);
 
 			if (sta)
-				ieee80211_vht_handle_opmode(sdata, sta, 0,
+				ieee80211_vht_handle_opmode(sdata,
+							    &sta->deflink,
 							    opmode, band);
 
 			mutex_unlock(&local->sta_mtx);
diff --git a/net/mac80211/mesh_plink.c b/net/mac80211/mesh_plink.c
index fe54ac431202..55bed9ce98fe 100644
--- a/net/mac80211/mesh_plink.c
+++ b/net/mac80211/mesh_plink.c
@@ -438,16 +438,18 @@ static void mesh_sta_info_init(struct ieee80211_sub_if_data *sdata,
 	sta->sta.deflink.supp_rates[sband->band] = rates;
 
 	if (ieee80211_ht_cap_ie_to_sta_ht_cap(sdata, sband,
-					      elems->ht_cap_elem, sta, 0))
+					      elems->ht_cap_elem,
+					      &sta->deflink))
 		changed |= IEEE80211_RC_BW_CHANGED;
 
 	ieee80211_vht_cap_ie_to_sta_vht_cap(sdata, sband,
-					    elems->vht_cap_elem, sta, 0);
+					    elems->vht_cap_elem,
+					    &sta->deflink);
 
 	ieee80211_he_cap_ie_to_sta_he_cap(sdata, sband, elems->he_cap,
 					  elems->he_cap_len,
 					  elems->he_6ghz_capa,
-					  sta, 0);
+					  &sta->deflink);
 
 	if (bw != sta->sta.deflink.bandwidth)
 		changed |= IEEE80211_RC_BW_CHANGED;
diff --git a/net/mac80211/mlme.c b/net/mac80211/mlme.c
index c2c997086553..01a72d1fcfcc 100644
--- a/net/mac80211/mlme.c
+++ b/net/mac80211/mlme.c
@@ -3566,12 +3566,13 @@ static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata,
 	/* Set up internal HT/VHT capabilities */
 	if (elems->ht_cap_elem && !(ifmgd->flags & IEEE80211_STA_DISABLE_HT))
 		ieee80211_ht_cap_ie_to_sta_ht_cap(sdata, sband,
-						  elems->ht_cap_elem, sta, 0);
+						  elems->ht_cap_elem,
+						  &sta->deflink);
 
 	if (elems->vht_cap_elem && !(ifmgd->flags & IEEE80211_STA_DISABLE_VHT))
 		ieee80211_vht_cap_ie_to_sta_vht_cap(sdata, sband,
 						    elems->vht_cap_elem,
-						    sta, 0);
+						    &sta->deflink);
 
 	if (elems->he_operation && !(ifmgd->flags & IEEE80211_STA_DISABLE_HE) &&
 	    elems->he_cap) {
@@ -3579,7 +3580,7 @@ static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata,
 						  elems->he_cap,
 						  elems->he_cap_len,
 						  elems->he_6ghz_capa,
-						  sta, 0);
+						  &sta->deflink);
 
 		bss_conf->he_support = sta->sta.deflink.he_cap.has_he;
 		if (elems->rsnx && elems->rsnx_len &&
@@ -3599,7 +3600,7 @@ static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata,
 							    elems->he_cap_len,
 							    elems->eht_cap,
 							    elems->eht_cap_len,
-							    sta, 0);
+							    &sta->deflink);
 
 			bss_conf->eht_support = sta->sta.deflink.eht_cap.has_eht;
 		} else {
@@ -4378,7 +4379,8 @@ static void ieee80211_rx_mgmt_beacon(struct ieee80211_sub_if_data *sdata,
 	}
 
 	if (sta && elems->opmode_notif)
-		ieee80211_vht_handle_opmode(sdata, sta, 0,
+		ieee80211_vht_handle_opmode(sdata,
+					    &sta->deflink,
 					    *elems->opmode_notif,
 					    rx_status->band);
 	mutex_unlock(&local->sta_mtx);
diff --git a/net/mac80211/rate.c b/net/mac80211/rate.c
index c58d9689f51f..7947e9a162a9 100644
--- a/net/mac80211/rate.c
+++ b/net/mac80211/rate.c
@@ -37,7 +37,7 @@ void rate_control_rate_init(struct sta_info *sta)
 	struct ieee80211_supported_band *sband;
 	struct ieee80211_chanctx_conf *chanctx_conf;
 
-	ieee80211_sta_set_rx_nss(sta, 0);
+	ieee80211_sta_set_rx_nss(&sta->deflink);
 
 	if (!ref)
 		return;
diff --git a/net/mac80211/rx.c b/net/mac80211/rx.c
index 0b5c34d00c14..d017ad14d7db 100644
--- a/net/mac80211/rx.c
+++ b/net/mac80211/rx.c
@@ -3391,11 +3391,11 @@ ieee80211_rx_h_action(struct ieee80211_rx_data *rx)
 			if (chanwidth == IEEE80211_HT_CHANWIDTH_20MHZ)
 				max_bw = IEEE80211_STA_RX_BW_20;
 			else
-				max_bw = ieee80211_sta_cap_rx_bw(rx->sta, 0);
+				max_bw = ieee80211_sta_cap_rx_bw(&rx->sta->deflink);
 
 			/* set cur_max_bandwidth and recalc sta bw */
 			rx->sta->deflink.cur_max_bandwidth = max_bw;
-			new_bw = ieee80211_sta_cur_vht_bw(rx->sta, 0);
+			new_bw = ieee80211_sta_cur_vht_bw(&rx->sta->deflink);
 
 			if (rx->sta->sta.deflink.bandwidth == new_bw)
 				goto handled;
@@ -3403,7 +3403,7 @@ ieee80211_rx_h_action(struct ieee80211_rx_data *rx)
 			rx->sta->sta.deflink.bandwidth = new_bw;
 			sband = rx->local->hw.wiphy->bands[status->band];
 			sta_opmode.bw =
-				ieee80211_sta_rx_bw_to_chan_width(rx->sta, 0);
+				ieee80211_sta_rx_bw_to_chan_width(&rx->sta->deflink);
 			sta_opmode.changed = STA_OPMODE_MAX_BW_CHANGED;
 
 			rate_control_rate_update(local, sband, rx->sta, 0,
diff --git a/net/mac80211/sta_info.c b/net/mac80211/sta_info.c
index b1426a2459e8..1f4189e08675 100644
--- a/net/mac80211/sta_info.c
+++ b/net/mac80211/sta_info.c
@@ -67,6 +67,7 @@
 struct sta_link_alloc {
 	struct link_sta_info info;
 	struct ieee80211_link_sta sta;
+	struct rcu_head rcu_head;
 };
 
 static const struct rhashtable_params sta_rht_params = {
@@ -258,19 +259,23 @@ static void sta_info_free_link(struct link_sta_info *link_sta)
 static void sta_remove_link(struct sta_info *sta, unsigned int link_id)
 {
 	struct sta_link_alloc *alloc = NULL;
+	struct link_sta_info *link_sta;
 
-	if (WARN_ON(!sta->link[link_id]))
+	link_sta = rcu_dereference_protected(sta->link[link_id],
+					     lockdep_is_held(&sta->local->sta_mtx));
+
+	if (WARN_ON(!link_sta))
 		return;
 
-	if (sta->link[link_id] != &sta->deflink)
-		alloc = container_of(sta->link[link_id], typeof(*alloc), info);
+	if (link_sta != &sta->deflink)
+		alloc = container_of(link_sta, typeof(*alloc), info);
 
 	sta->sta.valid_links &= ~BIT(link_id);
-	sta->link[link_id] = NULL;
-	sta->sta.link[link_id] = NULL;
+	RCU_INIT_POINTER(sta->link[link_id], NULL);
+	RCU_INIT_POINTER(sta->sta.link[link_id], NULL);
 	if (alloc) {
 		sta_info_free_link(&alloc->info);
-		kfree(alloc);
+		kfree_rcu(alloc, rcu_head);
 	}
 }
 
@@ -404,8 +409,9 @@ static void sta_info_add_link(struct sta_info *sta,
 {
 	link_info->sta = sta;
 	link_info->link_id = link_id;
-	sta->link[link_id] = link_info;
-	sta->sta.link[link_id] = link_sta;
+	link_info->pub = link_sta;
+	rcu_assign_pointer(sta->link[link_id], link_info);
+	rcu_assign_pointer(sta->sta.link[link_id], link_sta);
 }
 
 struct sta_info *sta_info_alloc(struct ieee80211_sub_if_data *sdata,
@@ -2682,13 +2688,15 @@ int ieee80211_sta_allocate_link(struct sta_info *sta, unsigned int link_id)
 int ieee80211_sta_activate_link(struct sta_info *sta, unsigned int link_id)
 {
 	struct ieee80211_sub_if_data *sdata = sta->sdata;
+	struct link_sta_info *link_sta;
 	u16 old_links = sta->sta.valid_links;
 	u16 new_links = old_links | BIT(link_id);
 	int ret;
 
-	lockdep_assert_held(&sdata->local->sta_mtx);
+	link_sta = rcu_dereference_protected(sta->link[link_id],
+					     lockdep_is_held(&sdata->local->sta_mtx));
 
-	if (WARN_ON(old_links == new_links || !sta->link[link_id]))
+	if (WARN_ON(old_links == new_links || !link_sta))
 		return -EINVAL;
 
 	sta->sta.valid_links = new_links;
diff --git a/net/mac80211/sta_info.h b/net/mac80211/sta_info.h
index 8ec65bb7d13e..27c96c04b13f 100644
--- a/net/mac80211/sta_info.h
+++ b/net/mac80211/sta_info.h
@@ -516,6 +516,7 @@ struct ieee80211_fragment_cache {
  * @status_stats.last_ack_signal: last ACK signal
  * @status_stats.ack_signal_filled: last ACK signal validity
  * @status_stats.avg_ack_signal: average ACK signal
+ * @pub: public (driver visible) link STA data
  * TODO Move other link params from sta_info as required for MLD operation
  */
 struct link_sta_info {
@@ -561,6 +562,8 @@ struct link_sta_info {
 	} tx_stats;
 
 	enum ieee80211_sta_rx_bandwidth cur_max_bandwidth;
+
+	struct ieee80211_link_sta *pub;
 };
 
 /**
@@ -708,7 +711,7 @@ struct sta_info {
 	struct ieee80211_fragment_cache frags;
 
 	struct link_sta_info deflink;
-	struct link_sta_info *link[IEEE80211_MLD_MAX_NUM_LINKS];
+	struct link_sta_info __rcu *link[IEEE80211_MLD_MAX_NUM_LINKS];
 
 	/* keep last! */
 	struct ieee80211_sta sta;
diff --git a/net/mac80211/tdls.c b/net/mac80211/tdls.c
index 4fc120c64022..c531fa17f426 100644
--- a/net/mac80211/tdls.c
+++ b/net/mac80211/tdls.c
@@ -308,7 +308,8 @@ ieee80211_tdls_chandef_vht_upgrade(struct ieee80211_sub_if_data *sdata,
 	/* IEEE802.11ac-2013 Table E-4 */
 	u16 centers_80mhz[] = { 5210, 5290, 5530, 5610, 5690, 5775 };
 	struct cfg80211_chan_def uc = sta->tdls_chandef;
-	enum nl80211_chan_width max_width = ieee80211_sta_cap_chan_bw(sta, 0);
+	enum nl80211_chan_width max_width =
+		ieee80211_sta_cap_chan_bw(&sta->deflink);
 	int i;
 
 	/* only support upgrading non-narrow channels up to 80Mhz */
@@ -1268,7 +1269,7 @@ static void iee80211_tdls_recalc_chanctx(struct ieee80211_sub_if_data *sdata,
 			enum ieee80211_sta_rx_bandwidth bw;
 
 			bw = ieee80211_chan_width_to_rx_bw(conf->def.width);
-			bw = min(bw, ieee80211_sta_cap_rx_bw(sta, 0));
+			bw = min(bw, ieee80211_sta_cap_rx_bw(&sta->deflink));
 			if (bw != sta->sta.deflink.bandwidth) {
 				sta->sta.deflink.bandwidth = bw;
 				rate_control_rate_update(local, sband, sta, 0,
diff --git a/net/mac80211/vht.c b/net/mac80211/vht.c
index acfe1459535f..fa14627b499a 100644
--- a/net/mac80211/vht.c
+++ b/net/mac80211/vht.c
@@ -116,16 +116,16 @@ void
 ieee80211_vht_cap_ie_to_sta_vht_cap(struct ieee80211_sub_if_data *sdata,
 				    struct ieee80211_supported_band *sband,
 				    const struct ieee80211_vht_cap *vht_cap_ie,
-				    struct sta_info *sta, unsigned int link_id)
+				    struct link_sta_info *link_sta)
 {
-	struct ieee80211_sta_vht_cap *vht_cap = &sta->sta.link[link_id]->vht_cap;
+	struct ieee80211_sta_vht_cap *vht_cap = &link_sta->pub->vht_cap;
 	struct ieee80211_sta_vht_cap own_cap;
 	u32 cap_info, i;
 	bool have_80mhz;
 
 	memset(vht_cap, 0, sizeof(*vht_cap));
 
-	if (!sta->sta.link[link_id]->ht_cap.ht_supported)
+	if (!link_sta->pub->ht_cap.ht_supported)
 		return;
 
 	if (!vht_cap_ie || !sband->vht_cap.vht_supported)
@@ -162,7 +162,7 @@ ieee80211_vht_cap_ie_to_sta_vht_cap(struct ieee80211_sub_if_data *sdata,
 	 * our own capabilities and then use those below.
 	 */
 	if (sdata->vif.type == NL80211_IFTYPE_STATION &&
-	    !test_sta_flag(sta, WLAN_STA_TDLS_PEER))
+	    !test_sta_flag(link_sta->sta, WLAN_STA_TDLS_PEER))
 		ieee80211_apply_vhtcap_overrides(sdata, &own_cap);
 
 	/* take some capabilities as-is */
@@ -286,8 +286,9 @@ ieee80211_vht_cap_ie_to_sta_vht_cap(struct ieee80211_sub_if_data *sdata,
 	 */
 	if (vht_cap->vht_mcs.rx_mcs_map == cpu_to_le16(0xFFFF)) {
 		vht_cap->vht_supported = false;
-		sdata_info(sdata, "Ignoring VHT IE from %pM due to invalid rx_mcs_map\n",
-			   sta->addr);
+		sdata_info(sdata,
+			   "Ignoring VHT IE from %pM (link:%pM) due to invalid rx_mcs_map\n",
+			   link_sta->sta->addr, link_sta->addr);
 		return;
 	}
 
@@ -295,10 +296,10 @@ ieee80211_vht_cap_ie_to_sta_vht_cap(struct ieee80211_sub_if_data *sdata,
 	switch (vht_cap->cap & IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK) {
 	case IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160MHZ:
 	case IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160_80PLUS80MHZ:
-		sta->link[link_id]->cur_max_bandwidth = IEEE80211_STA_RX_BW_160;
+		link_sta->cur_max_bandwidth = IEEE80211_STA_RX_BW_160;
 		break;
 	default:
-		sta->link[link_id]->cur_max_bandwidth = IEEE80211_STA_RX_BW_80;
+		link_sta->cur_max_bandwidth = IEEE80211_STA_RX_BW_80;
 
 		if (!(vht_cap->vht_mcs.tx_highest &
 				cpu_to_le16(IEEE80211_VHT_EXT_NSS_BW_CAPABLE)))
@@ -310,35 +311,40 @@ ieee80211_vht_cap_ie_to_sta_vht_cap(struct ieee80211_sub_if_data *sdata,
 		 * above) between 160 and 80+80 yet.
 		 */
 		if (cap_info & IEEE80211_VHT_CAP_EXT_NSS_BW_MASK)
-			sta->link[link_id]->cur_max_bandwidth =
+			link_sta->cur_max_bandwidth =
 				IEEE80211_STA_RX_BW_160;
 	}
 
-	sta->sta.link[link_id]->bandwidth = ieee80211_sta_cur_vht_bw(sta, link_id);
+	link_sta->pub->bandwidth = ieee80211_sta_cur_vht_bw(link_sta);
 
+	/*
+	 * FIXME - should the amsdu len be per link? store per link
+	 * and maintain a minimum?
+	 */
 	switch (vht_cap->cap & IEEE80211_VHT_CAP_MAX_MPDU_MASK) {
 	case IEEE80211_VHT_CAP_MAX_MPDU_LENGTH_11454:
-		sta->sta.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_VHT_11454;
+		link_sta->sta->sta.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_VHT_11454;
 		break;
 	case IEEE80211_VHT_CAP_MAX_MPDU_LENGTH_7991:
-		sta->sta.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_VHT_7991;
+		link_sta->sta->sta.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_VHT_7991;
 		break;
 	case IEEE80211_VHT_CAP_MAX_MPDU_LENGTH_3895:
 	default:
-		sta->sta.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_VHT_3895;
+		link_sta->sta->sta.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_VHT_3895;
 		break;
 	}
 }
 
 /* FIXME: move this to some better location - parses HE/EHT now */
-enum ieee80211_sta_rx_bandwidth ieee80211_sta_cap_rx_bw(struct sta_info *sta,
-							unsigned int link_id)
+enum ieee80211_sta_rx_bandwidth
+ieee80211_sta_cap_rx_bw(struct link_sta_info *link_sta)
 {
-	struct ieee80211_bss_conf *link_conf = sta->sdata->vif.link_conf[link_id];
-	struct ieee80211_link_sta *link_sta = sta->sta.link[link_id];
-	struct ieee80211_sta_vht_cap *vht_cap = &link_sta->vht_cap;
-	struct ieee80211_sta_he_cap *he_cap = &link_sta->he_cap;
-	struct ieee80211_sta_eht_cap *eht_cap = &link_sta->eht_cap;
+	unsigned int link_id = link_sta->link_id;
+	struct ieee80211_sub_if_data *sdata = link_sta->sta->sdata;
+	struct ieee80211_bss_conf *link_conf = sdata->vif.link_conf[link_id];
+	struct ieee80211_sta_vht_cap *vht_cap = &link_sta->pub->vht_cap;
+	struct ieee80211_sta_he_cap *he_cap = &link_sta->pub->he_cap;
+	struct ieee80211_sta_eht_cap *eht_cap = &link_sta->pub->eht_cap;
 	u32 cap_width;
 
 	if (he_cap->has_he) {
@@ -371,7 +377,7 @@ enum ieee80211_sta_rx_bandwidth ieee80211_sta_cap_rx_bw(struct sta_info *sta,
 	}
 
 	if (!vht_cap->vht_supported)
-		return link_sta->ht_cap.cap & IEEE80211_HT_CAP_SUP_WIDTH_20_40 ?
+		return link_sta->pub->ht_cap.cap & IEEE80211_HT_CAP_SUP_WIDTH_20_40 ?
 				IEEE80211_STA_RX_BW_40 :
 				IEEE80211_STA_RX_BW_20;
 
@@ -392,17 +398,17 @@ enum ieee80211_sta_rx_bandwidth ieee80211_sta_cap_rx_bw(struct sta_info *sta,
 	return IEEE80211_STA_RX_BW_80;
 }
 
-enum nl80211_chan_width ieee80211_sta_cap_chan_bw(struct sta_info *sta,
-						  unsigned int link_id)
+enum nl80211_chan_width
+ieee80211_sta_cap_chan_bw(struct link_sta_info *link_sta)
 {
-	struct ieee80211_sta_vht_cap *vht_cap = &sta->sta.link[link_id]->vht_cap;
+	struct ieee80211_sta_vht_cap *vht_cap = &link_sta->pub->vht_cap;
 	u32 cap_width;
 
 	if (!vht_cap->vht_supported) {
-		if (!sta->sta.link[link_id]->ht_cap.ht_supported)
+		if (!link_sta->pub->ht_cap.ht_supported)
 			return NL80211_CHAN_WIDTH_20_NOHT;
 
-		return sta->sta.link[link_id]->ht_cap.cap & IEEE80211_HT_CAP_SUP_WIDTH_20_40 ?
+		return link_sta->pub->ht_cap.cap & IEEE80211_HT_CAP_SUP_WIDTH_20_40 ?
 				NL80211_CHAN_WIDTH_40 : NL80211_CHAN_WIDTH_20;
 	}
 
@@ -417,17 +423,17 @@ enum nl80211_chan_width ieee80211_sta_cap_chan_bw(struct sta_info *sta,
 }
 
 enum nl80211_chan_width
-ieee80211_sta_rx_bw_to_chan_width(struct sta_info *sta, unsigned int link_id)
+ieee80211_sta_rx_bw_to_chan_width(struct link_sta_info *link_sta)
 {
 	enum ieee80211_sta_rx_bandwidth cur_bw =
-		sta->sta.link[link_id]->bandwidth;
+		link_sta->pub->bandwidth;
 	struct ieee80211_sta_vht_cap *vht_cap =
-		&sta->sta.link[link_id]->vht_cap;
+		&link_sta->pub->vht_cap;
 	u32 cap_width;
 
 	switch (cur_bw) {
 	case IEEE80211_STA_RX_BW_20:
-		if (!sta->sta.link[link_id]->ht_cap.ht_supported)
+		if (!link_sta->pub->ht_cap.ht_supported)
 			return NL80211_CHAN_WIDTH_20_NOHT;
 		else
 			return NL80211_CHAN_WIDTH_20;
@@ -471,15 +477,17 @@ ieee80211_chan_width_to_rx_bw(enum nl80211_chan_width width)
 }
 
 /* FIXME: rename/move - this deals with everything not just VHT */
-enum ieee80211_sta_rx_bandwidth ieee80211_sta_cur_vht_bw(struct sta_info *sta,
-							 unsigned int link_id)
+enum ieee80211_sta_rx_bandwidth
+ieee80211_sta_cur_vht_bw(struct link_sta_info *link_sta)
 {
-	struct ieee80211_bss_conf *link_conf = sta->sdata->vif.link_conf[link_id];
+	struct sta_info *sta = link_sta->sta;
+	struct ieee80211_bss_conf *link_conf =
+		sta->sdata->vif.link_conf[link_sta->link_id];
 	enum nl80211_chan_width bss_width = link_conf->chandef.width;
 	enum ieee80211_sta_rx_bandwidth bw;
 
-	bw = ieee80211_sta_cap_rx_bw(sta, link_id);
-	bw = min(bw, sta->link[link_id]->cur_max_bandwidth);
+	bw = ieee80211_sta_cap_rx_bw(link_sta);
+	bw = min(bw, link_sta->cur_max_bandwidth);
 
 	/* Don't consider AP's bandwidth for TDLS peers, section 11.23.1 of
 	 * IEEE80211-2016 specification makes higher bandwidth operation
@@ -501,18 +509,18 @@ enum ieee80211_sta_rx_bandwidth ieee80211_sta_cur_vht_bw(struct sta_info *sta,
 	return bw;
 }
 
-void ieee80211_sta_set_rx_nss(struct sta_info *sta, unsigned int link_id)
+void ieee80211_sta_set_rx_nss(struct link_sta_info *link_sta)
 {
 	u8 ht_rx_nss = 0, vht_rx_nss = 0, he_rx_nss = 0, eht_rx_nss = 0, rx_nss;
 	bool support_160;
 
 	/* if we received a notification already don't overwrite it */
-	if (sta->sta.link[link_id]->rx_nss)
+	if (link_sta->pub->rx_nss)
 		return;
 
-	if (sta->sta.link[link_id]->eht_cap.has_eht) {
+	if (link_sta->pub->eht_cap.has_eht) {
 		int i;
-		const u8 *rx_nss_mcs = (void *)&sta->sta.link[link_id]->eht_cap.eht_mcs_nss_supp;
+		const u8 *rx_nss_mcs = (void *)&link_sta->pub->eht_cap.eht_mcs_nss_supp;
 
 		/* get the max nss for EHT over all possible bandwidths and mcs */
 		for (i = 0; i < sizeof(struct ieee80211_eht_mcs_nss_supp); i++)
@@ -521,10 +529,10 @@ void ieee80211_sta_set_rx_nss(struct sta_info *sta, unsigned int link_id)
 						       IEEE80211_EHT_MCS_NSS_RX));
 	}
 
-	if (sta->sta.link[link_id]->he_cap.has_he) {
+	if (link_sta->pub->he_cap.has_he) {
 		int i;
 		u8 rx_mcs_80 = 0, rx_mcs_160 = 0;
-		const struct ieee80211_sta_he_cap *he_cap = &sta->sta.link[link_id]->he_cap;
+		const struct ieee80211_sta_he_cap *he_cap = &link_sta->pub->he_cap;
 		u16 mcs_160_map =
 			le16_to_cpu(he_cap->he_mcs_nss_supp.rx_mcs_160);
 		u16 mcs_80_map = le16_to_cpu(he_cap->he_mcs_nss_supp.rx_mcs_80);
@@ -555,23 +563,23 @@ void ieee80211_sta_set_rx_nss(struct sta_info *sta, unsigned int link_id)
 			he_rx_nss = rx_mcs_80;
 	}
 
-	if (sta->sta.link[link_id]->ht_cap.ht_supported) {
-		if (sta->sta.link[link_id]->ht_cap.mcs.rx_mask[0])
+	if (link_sta->pub->ht_cap.ht_supported) {
+		if (link_sta->pub->ht_cap.mcs.rx_mask[0])
 			ht_rx_nss++;
-		if (sta->sta.link[link_id]->ht_cap.mcs.rx_mask[1])
+		if (link_sta->pub->ht_cap.mcs.rx_mask[1])
 			ht_rx_nss++;
-		if (sta->sta.link[link_id]->ht_cap.mcs.rx_mask[2])
+		if (link_sta->pub->ht_cap.mcs.rx_mask[2])
 			ht_rx_nss++;
-		if (sta->sta.link[link_id]->ht_cap.mcs.rx_mask[3])
+		if (link_sta->pub->ht_cap.mcs.rx_mask[3])
 			ht_rx_nss++;
 		/* FIXME: consider rx_highest? */
 	}
 
-	if (sta->sta.link[link_id]->vht_cap.vht_supported) {
+	if (link_sta->pub->vht_cap.vht_supported) {
 		int i;
 		u16 rx_mcs_map;
 
-		rx_mcs_map = le16_to_cpu(sta->sta.link[link_id]->vht_cap.vht_mcs.rx_mcs_map);
+		rx_mcs_map = le16_to_cpu(link_sta->pub->vht_cap.vht_mcs.rx_mcs_map);
 
 		for (i = 7; i >= 0; i--) {
 			u8 mcs = (rx_mcs_map >> (2 * i)) & 3;
@@ -587,11 +595,11 @@ void ieee80211_sta_set_rx_nss(struct sta_info *sta, unsigned int link_id)
 	rx_nss = max(vht_rx_nss, ht_rx_nss);
 	rx_nss = max(he_rx_nss, rx_nss);
 	rx_nss = max(eht_rx_nss, rx_nss);
-	sta->sta.link[link_id]->rx_nss = max_t(u8, 1, rx_nss);
+	link_sta->pub->rx_nss = max_t(u8, 1, rx_nss);
 }
 
 u32 __ieee80211_vht_handle_opmode(struct ieee80211_sub_if_data *sdata,
-				  struct sta_info *sta, unsigned int link_id,
+				  struct link_sta_info *link_sta,
 				  u8 opmode, enum nl80211_band band)
 {
 	enum ieee80211_sta_rx_bandwidth new_bw;
@@ -607,8 +615,8 @@ u32 __ieee80211_vht_handle_opmode(struct ieee80211_sub_if_data *sdata,
 	nss >>= IEEE80211_OPMODE_NOTIF_RX_NSS_SHIFT;
 	nss += 1;
 
-	if (sta->sta.link[link_id]->rx_nss != nss) {
-		sta->sta.link[link_id]->rx_nss = nss;
+	if (link_sta->pub->rx_nss != nss) {
+		link_sta->pub->rx_nss = nss;
 		sta_opmode.rx_nss = nss;
 		changed |= IEEE80211_RC_NSS_CHANGED;
 		sta_opmode.changed |= STA_OPMODE_N_SS_CHANGED;
@@ -617,34 +625,34 @@ u32 __ieee80211_vht_handle_opmode(struct ieee80211_sub_if_data *sdata,
 	switch (opmode & IEEE80211_OPMODE_NOTIF_CHANWIDTH_MASK) {
 	case IEEE80211_OPMODE_NOTIF_CHANWIDTH_20MHZ:
 		/* ignore IEEE80211_OPMODE_NOTIF_BW_160_80P80 must not be set */
-		sta->link[link_id]->cur_max_bandwidth = IEEE80211_STA_RX_BW_20;
+		link_sta->cur_max_bandwidth = IEEE80211_STA_RX_BW_20;
 		break;
 	case IEEE80211_OPMODE_NOTIF_CHANWIDTH_40MHZ:
 		/* ignore IEEE80211_OPMODE_NOTIF_BW_160_80P80 must not be set */
-		sta->link[link_id]->cur_max_bandwidth = IEEE80211_STA_RX_BW_40;
+		link_sta->cur_max_bandwidth = IEEE80211_STA_RX_BW_40;
 		break;
 	case IEEE80211_OPMODE_NOTIF_CHANWIDTH_80MHZ:
 		if (opmode & IEEE80211_OPMODE_NOTIF_BW_160_80P80)
-			sta->link[link_id]->cur_max_bandwidth = IEEE80211_STA_RX_BW_160;
+			link_sta->cur_max_bandwidth = IEEE80211_STA_RX_BW_160;
 		else
-			sta->link[link_id]->cur_max_bandwidth = IEEE80211_STA_RX_BW_80;
+			link_sta->cur_max_bandwidth = IEEE80211_STA_RX_BW_80;
 		break;
 	case IEEE80211_OPMODE_NOTIF_CHANWIDTH_160MHZ:
 		/* legacy only, no longer used by newer spec */
-		sta->link[link_id]->cur_max_bandwidth = IEEE80211_STA_RX_BW_160;
+		link_sta->cur_max_bandwidth = IEEE80211_STA_RX_BW_160;
 		break;
 	}
 
-	new_bw = ieee80211_sta_cur_vht_bw(sta, link_id);
-	if (new_bw != sta->sta.link[link_id]->bandwidth) {
-		sta->sta.link[link_id]->bandwidth = new_bw;
-		sta_opmode.bw = ieee80211_sta_rx_bw_to_chan_width(sta, link_id);
+	new_bw = ieee80211_sta_cur_vht_bw(link_sta);
+	if (new_bw != link_sta->pub->bandwidth) {
+		link_sta->pub->bandwidth = new_bw;
+		sta_opmode.bw = ieee80211_sta_rx_bw_to_chan_width(link_sta);
 		changed |= IEEE80211_RC_BW_CHANGED;
 		sta_opmode.changed |= STA_OPMODE_MAX_BW_CHANGED;
 	}
 
 	if (sta_opmode.changed)
-		cfg80211_sta_opmode_change_notify(sdata->dev, sta->addr,
+		cfg80211_sta_opmode_change_notify(sdata->dev, link_sta->addr,
 						  &sta_opmode, GFP_KERNEL);
 
 	return changed;
@@ -689,18 +697,19 @@ void ieee80211_update_mu_groups(struct ieee80211_vif *vif, unsigned int link_id,
 EXPORT_SYMBOL_GPL(ieee80211_update_mu_groups);
 
 void ieee80211_vht_handle_opmode(struct ieee80211_sub_if_data *sdata,
-				 struct sta_info *sta, unsigned int link_id,
+				 struct link_sta_info *link_sta,
 				 u8 opmode, enum nl80211_band band)
 {
 	struct ieee80211_local *local = sdata->local;
 	struct ieee80211_supported_band *sband = local->hw.wiphy->bands[band];
 
-	u32 changed = __ieee80211_vht_handle_opmode(sdata, sta, link_id,
+	u32 changed = __ieee80211_vht_handle_opmode(sdata, link_sta,
 						    opmode, band);
 
 	if (changed > 0) {
 		ieee80211_recalc_min_chandef(sdata);
-		rate_control_rate_update(local, sband, sta, link_id, changed);
+		rate_control_rate_update(local, sband, link_sta->sta,
+					 link_sta->link_id, changed);
 	}
 }