.write = ath9k_iowrite32,
 };
 
+static int count_streams(unsigned int chainmask, int max)
+{
+       int streams = 0;
+
+       do {
+               if (++streams == max)
+                       break;
+       } while ((chainmask = chainmask & (chainmask - 1)));
+
+       return streams;
+}
+
 /**************************/
 /*     Initialization     */
 /**************************/
 static void setup_ht_cap(struct ath_softc *sc,
                         struct ieee80211_sta_ht_cap *ht_info)
 {
-       struct ath_common *common = ath9k_hw_common(sc->sc_ah);
+       struct ath_hw *ah = sc->sc_ah;
+       struct ath_common *common = ath9k_hw_common(ah);
        u8 tx_streams, rx_streams;
+       int i, max_streams;
 
        ht_info->ht_supported = true;
        ht_info->cap = IEEE80211_HT_CAP_SUP_WIDTH_20_40 |
        ht_info->ampdu_factor = IEEE80211_HT_MAX_AMPDU_64K;
        ht_info->ampdu_density = IEEE80211_HT_MPDU_DENSITY_8;
 
+       if (AR_SREV_9300_20_OR_LATER(ah))
+               max_streams = 3;
+       else
+               max_streams = 2;
+
        /* set up supported mcs set */
        memset(&ht_info->mcs, 0, sizeof(ht_info->mcs));
-       tx_streams = !(common->tx_chainmask & (common->tx_chainmask - 1)) ?
-                    1 : 2;
-       rx_streams = !(common->rx_chainmask & (common->rx_chainmask - 1)) ?
-                    1 : 2;
+       tx_streams = count_streams(common->tx_chainmask, max_streams);
+       rx_streams = count_streams(common->rx_chainmask, max_streams);
+
+       ath_print(common, ATH_DBG_CONFIG,
+                 "TX streams %d, RX streams: %d\n",
+                 tx_streams, rx_streams);
 
        if (tx_streams != rx_streams) {
-               ath_print(common, ATH_DBG_CONFIG,
-                         "TX streams %d, RX streams: %d\n",
-                         tx_streams, rx_streams);
                ht_info->mcs.tx_params |= IEEE80211_HT_MCS_TX_RX_DIFF;
                ht_info->mcs.tx_params |= ((tx_streams - 1) <<
                                IEEE80211_HT_MCS_TX_MAX_STREAMS_SHIFT);
        }
 
-       ht_info->mcs.rx_mask[0] = 0xff;
-       if (rx_streams >= 2)
-               ht_info->mcs.rx_mask[1] = 0xff;
+       for (i = 0; i < rx_streams; i++)
+               ht_info->mcs.rx_mask[i] = 0xff;
 
        ht_info->mcs.tx_params |= IEEE80211_HT_MCS_TX_DEFINED;
 }