{ S8_MIN, IWL_RATE_MCS_0_INDEX },
 };
 
+#define IWL_RS_LOW_RSSI_THRESHOLD (-76) /* dBm */
+
 /* Init the optimal rate based on STA caps
  * This combined with rssi is used to report the last tx rate
  * to userspace when we haven't transmitted enough frames.
  * of last Rx
  */
 static void rs_get_initial_rate(struct iwl_mvm *mvm,
+                               struct ieee80211_sta *sta,
                                struct iwl_lq_sta *lq_sta,
                                enum ieee80211_band band,
                                struct rs_rate *rate)
 {
        int i, nentries;
+       unsigned long active_rate;
        s8 best_rssi = S8_MIN;
        u8 best_ant = ANT_NONE;
        u8 valid_tx_ant = iwl_mvm_get_valid_tx_ant(mvm);
                nentries = ARRAY_SIZE(rs_optimal_rates_24ghz_legacy);
        }
 
-       if (IWL_MVM_RS_RSSI_BASED_INIT_RATE) {
-               for (i = 0; i < nentries; i++) {
-                       int rate_idx = initial_rates[i].rate_idx;
-                       if ((best_rssi >= initial_rates[i].rssi) &&
-                           (BIT(rate_idx) & lq_sta->active_legacy_rate)) {
-                               rate->index = rate_idx;
-                               break;
-                       }
+       if (!IWL_MVM_RS_RSSI_BASED_INIT_RATE)
+               goto out;
+
+       /* Start from a higher rate if the corresponding debug capability
+        * is enabled. The rate is chosen according to AP capabilities.
+        * In case of VHT/HT when the rssi is low fallback to the case of
+        * legacy rates.
+        */
+       if (sta->vht_cap.vht_supported &&
+           best_rssi > IWL_RS_LOW_RSSI_THRESHOLD) {
+               if (sta->bandwidth >= IEEE80211_STA_RX_BW_40) {
+                       initial_rates = rs_optimal_rates_vht_40_80mhz;
+                       nentries = ARRAY_SIZE(rs_optimal_rates_vht_40_80mhz);
+                       if (sta->bandwidth >= IEEE80211_STA_RX_BW_80)
+                               rate->bw = RATE_MCS_CHAN_WIDTH_80;
+                       else
+                               rate->bw = RATE_MCS_CHAN_WIDTH_40;
+               } else if (sta->bandwidth == IEEE80211_STA_RX_BW_20) {
+                       initial_rates = rs_optimal_rates_vht_20mhz;
+                       nentries = ARRAY_SIZE(rs_optimal_rates_vht_20mhz);
+                       rate->bw = RATE_MCS_CHAN_WIDTH_20;
+               } else {
+                       IWL_ERR(mvm, "Invalid BW %d\n", sta->bandwidth);
+                       goto out;
                }
+               active_rate = lq_sta->active_siso_rate;
+               rate->type = LQ_VHT_SISO;
+       } else if (sta->ht_cap.ht_supported &&
+                  best_rssi > IWL_RS_LOW_RSSI_THRESHOLD) {
+               initial_rates = rs_optimal_rates_ht;
+               nentries = ARRAY_SIZE(rs_optimal_rates_ht);
+               active_rate = lq_sta->active_siso_rate;
+               rate->type = LQ_HT_SISO;
+       } else {
+               active_rate = lq_sta->active_legacy_rate;
        }
 
-       IWL_DEBUG_RATE(mvm, "rate_idx %d ANT %s\n", rate->index,
-                      rs_pretty_ant(rate->ant));
+       for (i = 0; i < nentries; i++) {
+               int rate_idx = initial_rates[i].rate_idx;
+
+               if ((best_rssi >= initial_rates[i].rssi) &&
+                   (BIT(rate_idx) & active_rate)) {
+                       rate->index = rate_idx;
+                       break;
+               }
+       }
+
+out:
+       rs_dump_rate(mvm, rate, "INITIAL");
 }
 
 /* Save info about RSSI of last Rx */
        tbl = &(lq_sta->lq_info[active_tbl]);
        rate = &tbl->rate;
 
-       rs_get_initial_rate(mvm, lq_sta, band, rate);
+       rs_get_initial_rate(mvm, sta, lq_sta, band, rate);
        rs_init_optimal_rate(mvm, sta, lq_sta);
 
        WARN_ON_ONCE(rate->ant != ANT_A && rate->ant != ANT_B);
-       if (rate->ant == ANT_A)
-               tbl->column = RS_COLUMN_LEGACY_ANT_A;
-       else
-               tbl->column = RS_COLUMN_LEGACY_ANT_B;
+       tbl->column = rs_get_column_from_rate(rate);
 
        rs_set_expected_tpt_table(lq_sta, tbl);
        rs_fill_lq_cmd(mvm, sta, lq_sta, rate);