lq_sta->visited_columns = 0;
 }
 
+static int rs_get_max_allowed_rate(struct iwl_lq_sta *lq_sta,
+                                  const struct rs_tx_column *column)
+{
+       switch (column->mode) {
+       case RS_LEGACY:
+               return lq_sta->max_legacy_rate_idx;
+       case RS_SISO:
+               return lq_sta->max_siso_rate_idx;
+       case RS_MIMO2:
+               return lq_sta->max_mimo2_rate_idx;
+       default:
+               WARN_ON_ONCE(1);
+       }
+
+       return lq_sta->max_legacy_rate_idx;
+}
+
 static const u16 *rs_get_expected_tpt_table(struct iwl_lq_sta *lq_sta,
-                                     const struct rs_tx_column *column,
-                                     u32 bw)
+                                           const struct rs_tx_column *column,
+                                           u32 bw)
 {
        /* Used to choose among HT tables */
        const u16 (*ht_tbl_pointer)[IWL_RATE_COUNT];
                                         struct ieee80211_sta *sta,
                                         struct iwl_scale_tbl_info *tbl)
 {
-       int i, j, n;
+       int i, j, max_rate;
        enum rs_column next_col_id;
        const struct rs_tx_column *curr_col = &rs_tx_columns[tbl->column];
        const struct rs_tx_column *next_col;
        allow_column_func_t allow_func;
        u8 valid_ants = mvm->fw->valid_tx_ant;
        const u16 *expected_tpt_tbl;
-       s32 tpt, max_expected_tpt;
+       u16 tpt, max_expected_tpt;
 
        for (i = 0; i < MAX_NEXT_COLUMNS; i++) {
                next_col_id = curr_col->next_columns[i];
                if (WARN_ON_ONCE(!expected_tpt_tbl))
                        continue;
 
-               max_expected_tpt = 0;
-               for (n = 0; n < IWL_RATE_COUNT; n++)
-                       if (expected_tpt_tbl[n] > max_expected_tpt)
-                               max_expected_tpt = expected_tpt_tbl[n];
+               max_rate = rs_get_max_allowed_rate(lq_sta, next_col);
+               if (WARN_ON_ONCE(max_rate == IWL_RATE_INVALID))
+                       continue;
 
+               max_expected_tpt = expected_tpt_tbl[max_rate];
                if (tpt >= max_expected_tpt) {
                        IWL_DEBUG_RATE(mvm,
                                       "Skip column %d: can't beat current TPT. Max expected %d current %d\n",
                        continue;
                }
 
+               IWL_DEBUG_RATE(mvm,
+                              "Found potential column %d. Max expected %d current %d\n",
+                              next_col_id, max_expected_tpt, tpt);
                break;
        }
 
        if (i == MAX_NEXT_COLUMNS)
                return RS_COLUMN_INVALID;
 
-       IWL_DEBUG_RATE(mvm, "Found potential column %d\n", next_col_id);
-
        return next_col_id;
 }
 
                lq_sta->is_vht = true;
        }
 
-       IWL_DEBUG_RATE(mvm,
-                      "SISO-RATE=%X MIMO2-RATE=%X VHT=%d\n",
+       lq_sta->max_legacy_rate_idx = find_last_bit(&lq_sta->active_legacy_rate,
+                                                   BITS_PER_LONG);
+       lq_sta->max_siso_rate_idx = find_last_bit(&lq_sta->active_siso_rate,
+                                                 BITS_PER_LONG);
+       lq_sta->max_mimo2_rate_idx = find_last_bit(&lq_sta->active_mimo2_rate,
+                                                  BITS_PER_LONG);
+
+       IWL_DEBUG_RATE(mvm, "RATE MASK: LEGACY=%lX SISO=%lX MIMO2=%lX VHT=%d\n",
+                      lq_sta->active_legacy_rate,
                       lq_sta->active_siso_rate,
                       lq_sta->active_mimo2_rate,
                       lq_sta->is_vht);
+       IWL_DEBUG_RATE(mvm, "MAX RATE: LEGACY=%d SISO=%d MIMO2=%d\n",
+                      lq_sta->max_legacy_rate_idx,
+                      lq_sta->max_siso_rate_idx,
+                      lq_sta->max_mimo2_rate_idx);
 
        /* These values will be overridden later */
        lq_sta->lq.single_stream_ant_msk =
                return -ENOMEM;
 
        desc += sprintf(buff+desc, "sta_id %d\n", lq_sta->lq.sta_id);
-       desc += sprintf(buff+desc, "failed=%d success=%d rate=0%X\n",
+       desc += sprintf(buff+desc, "failed=%d success=%d rate=0%lX\n",
                        lq_sta->total_failed, lq_sta->total_success,
                        lq_sta->active_legacy_rate);
        desc += sprintf(buff+desc, "fixed rate 0x%X\n",