return cmd.handler_status;
 }
 
-static void iwl_set_ht_add_station(struct iwl_priv *priv, u8 index,
-                                  struct ieee80211_sta *sta,
-                                  struct iwl_rxon_context *ctx)
+static void iwl_sta_calc_ht_flags(struct iwl_priv *priv,
+                                 struct ieee80211_sta *sta,
+                                 struct iwl_rxon_context *ctx,
+                                 __le32 *flags, __le32 *mask)
 {
        struct ieee80211_sta_ht_cap *sta_ht_inf = &sta->ht_cap;
-       __le32 sta_flags;
        u8 mimo_ps_mode;
 
+       *mask = STA_FLG_RTS_MIMO_PROT_MSK |
+               STA_FLG_MIMO_DIS_MSK |
+               STA_FLG_HT40_EN_MSK |
+               STA_FLG_MAX_AGG_SIZE_MSK |
+               STA_FLG_AGG_MPDU_DENSITY_MSK;
+       *flags = 0;
+
        if (!sta || !sta_ht_inf->ht_supported)
-               goto done;
+               return;
 
        mimo_ps_mode = (sta_ht_inf->cap & IEEE80211_HT_CAP_SM_PS) >> 2;
-       IWL_DEBUG_ASSOC(priv, "spatial multiplexing power save mode: %s\n",
+
+       IWL_DEBUG_INFO(priv, "STA %pM SM PS mode: %s\n",
                        (mimo_ps_mode == WLAN_HT_CAP_SM_PS_STATIC) ?
                        "static" :
                        (mimo_ps_mode == WLAN_HT_CAP_SM_PS_DYNAMIC) ?
                        "dynamic" : "disabled");
 
-       sta_flags = priv->stations[index].sta.station_flags;
-
-       sta_flags &= ~(STA_FLG_RTS_MIMO_PROT_MSK | STA_FLG_MIMO_DIS_MSK);
-
        switch (mimo_ps_mode) {
        case WLAN_HT_CAP_SM_PS_STATIC:
-               sta_flags |= STA_FLG_MIMO_DIS_MSK;
+               *flags |= STA_FLG_MIMO_DIS_MSK;
                break;
        case WLAN_HT_CAP_SM_PS_DYNAMIC:
-               sta_flags |= STA_FLG_RTS_MIMO_PROT_MSK;
+               *flags |= STA_FLG_RTS_MIMO_PROT_MSK;
                break;
        case WLAN_HT_CAP_SM_PS_DISABLED:
                break;
                break;
        }
 
-       sta_flags |= cpu_to_le32(
-             (u32)sta_ht_inf->ampdu_factor << STA_FLG_MAX_AGG_SIZE_POS);
+       *flags |= cpu_to_le32(
+               (u32)sta_ht_inf->ampdu_factor << STA_FLG_MAX_AGG_SIZE_POS);
 
-       sta_flags |= cpu_to_le32(
-             (u32)sta_ht_inf->ampdu_density << STA_FLG_AGG_MPDU_DENSITY_POS);
+       *flags |= cpu_to_le32(
+               (u32)sta_ht_inf->ampdu_density << STA_FLG_AGG_MPDU_DENSITY_POS);
 
        if (iwl_is_ht40_tx_allowed(priv, ctx, &sta->ht_cap))
-               sta_flags |= STA_FLG_HT40_EN_MSK;
-       else
-               sta_flags &= ~STA_FLG_HT40_EN_MSK;
+               *flags |= STA_FLG_HT40_EN_MSK;
+}
+
+int iwl_sta_update_ht(struct iwl_priv *priv, struct iwl_rxon_context *ctx,
+                     struct ieee80211_sta *sta)
+{
+       u8 sta_id = iwl_sta_id(sta);
+       __le32 flags, mask;
+       struct iwl_addsta_cmd cmd;
+
+       if (WARN_ON_ONCE(sta_id == IWL_INVALID_STATION))
+               return -EINVAL;
 
-       priv->stations[index].sta.station_flags = sta_flags;
- done:
-       return;
+       iwl_sta_calc_ht_flags(priv, sta, ctx, &flags, &mask);
+
+       spin_lock_bh(&priv->sta_lock);
+       priv->stations[sta_id].sta.station_flags &= ~mask;
+       priv->stations[sta_id].sta.station_flags |= flags;
+       spin_unlock_bh(&priv->sta_lock);
+
+       memset(&cmd, 0, sizeof(cmd));
+       cmd.mode = STA_CONTROL_MODIFY_MSK;
+       cmd.station_flags_msk = mask;
+       cmd.station_flags = flags;
+       cmd.sta.sta_id = sta_id;
+
+       return iwl_send_add_sta(priv, &cmd, CMD_SYNC);
+}
+
+static void iwl_set_ht_add_station(struct iwl_priv *priv, u8 index,
+                                  struct ieee80211_sta *sta,
+                                  struct iwl_rxon_context *ctx)
+{
+       __le32 flags, mask;
+
+       iwl_sta_calc_ht_flags(priv, sta, ctx, &flags, &mask);
+
+       lockdep_assert_held(&priv->sta_lock);
+       priv->stations[index].sta.station_flags &= ~mask;
+       priv->stations[index].sta.station_flags |= flags;
 }
 
 /**
 
                                enum ieee80211_sta_state new_state)
 {
        struct iwl_priv *priv = IWL_MAC80211_GET_DVM(hw);
+       struct iwl_vif_priv *vif_priv = (void *)vif->drv_priv;
        enum {
-               NONE, ADD, REMOVE, RATE_INIT, ADD_RATE_INIT,
+               NONE, ADD, REMOVE, HT_RATE_INIT, ADD_RATE_INIT,
        } op = NONE;
        int ret;
 
                        op = REMOVE;
                else if (old_state == IEEE80211_STA_AUTH &&
                         new_state == IEEE80211_STA_ASSOC)
-                       op = RATE_INIT;
+                       op = HT_RATE_INIT;
        } else {
                if (old_state == IEEE80211_STA_AUTH &&
                    new_state == IEEE80211_STA_ASSOC)
                ret = iwlagn_mac_sta_add(hw, vif, sta);
                if (ret)
                        break;
-               /* fall through */
-       case RATE_INIT:
                /* Initialize rate scaling */
                IWL_DEBUG_INFO(priv,
                               "Initializing rate scaling for station %pM\n",
                iwl_rs_rate_init(priv, sta, iwl_sta_id(sta));
                ret = 0;
                break;
+       case HT_RATE_INIT:
+               /* Initialize rate scaling */
+               ret = iwl_sta_update_ht(priv, vif_priv->ctx, sta);
+               if (ret)
+                       break;
+               IWL_DEBUG_INFO(priv,
+                              "Initializing rate scaling for station %pM\n",
+                              sta->addr);
+               iwl_rs_rate_init(priv, sta, iwl_sta_id(sta));
+               ret = 0;
+               break;
        default:
                ret = 0;
                break;