/**
  * struct iwl_mvm_baid_data - BA session data
- * @sta_id: station id
+ * @sta_mask: current station mask for the BAID
  * @tid: tid of the session
  * @baid baid of the session
  * @timeout: the timeout set in the addba request
  */
 struct iwl_mvm_baid_data {
        struct rcu_head rcu_head;
-       u8 sta_id;
+       u32 sta_mask;
        u8 tid;
        u8 baid;
        u16 timeout;
 
        if (expired) {
                struct ieee80211_sta *sta;
                struct iwl_mvm_sta *mvmsta;
-               u8 sta_id = baid_data->sta_id;
+               u8 sta_id = ffs(baid_data->sta_mask) - 1;
 
                rcu_read_lock();
                sta = rcu_dereference(buf->mvm->fw_id_to_mac_id[sta_id]);
        struct ieee80211_sta *sta;
        struct iwl_mvm_reorder_buffer *reorder_buf;
        u8 baid = data->baid;
+       u32 sta_id;
 
        if (WARN_ONCE(baid >= IWL_MAX_BAID, "invalid BAID: %x\n", baid))
                return;
        if (WARN_ON_ONCE(!ba_data))
                goto out;
 
-       sta = rcu_dereference(mvm->fw_id_to_mac_id[ba_data->sta_id]);
+       /* pick any STA ID to find the pointer */
+       sta_id = ffs(ba_data->sta_mask) - 1;
+       sta = rcu_dereference(mvm->fw_id_to_mac_id[sta_id]);
        if (WARN_ON_ONCE(IS_ERR_OR_NULL(sta)))
                goto out;
 
        struct ieee80211_sta *sta;
        struct iwl_mvm_reorder_buffer *reorder_buf;
        struct iwl_mvm_baid_data *ba_data;
+       u32 sta_id;
 
        IWL_DEBUG_HT(mvm, "Frame release notification for BAID %u, NSSN %d\n",
                     baid, nssn);
                goto out;
        }
 
-       sta = rcu_dereference(mvm->fw_id_to_mac_id[ba_data->sta_id]);
+       /* pick any STA ID to find the pointer */
+       sta_id = ffs(ba_data->sta_mask) - 1;
+       sta = rcu_dereference(mvm->fw_id_to_mac_id[sta_id]);
        if (WARN_ON_ONCE(IS_ERR_OR_NULL(sta)))
                goto out;
 
 {
        struct ieee80211_rx_status *rx_status = IEEE80211_SKB_RXCB(skb);
        struct ieee80211_hdr *hdr = (void *)skb_mac_header(skb);
-       struct iwl_mvm_sta *mvm_sta;
        struct iwl_mvm_baid_data *baid_data;
        struct iwl_mvm_reorder_buffer *buffer;
        struct sk_buff *tail;
        u8 sub_frame_idx = desc->amsdu_info &
                           IWL_RX_MPDU_AMSDU_SUBFRAME_IDX_MASK;
        struct iwl_mvm_reorder_buf_entry *entries;
+       u32 sta_mask;
        int index;
        u16 nssn, sn;
        u8 baid;
                      "Got valid BAID without a valid station assigned\n"))
                return false;
 
-       mvm_sta = iwl_mvm_sta_from_mac80211(sta);
-
        /* not a data packet or a bar */
        if (!ieee80211_is_back_req(hdr->frame_control) &&
            (!ieee80211_is_data_qos(hdr->frame_control) ||
                return false;
        }
 
+       rcu_read_lock();
+       sta_mask = iwl_mvm_sta_fw_id_mask(mvm, sta, -1);
+       rcu_read_unlock();
+
        if (WARN(tid != baid_data->tid ||
-                mvm_sta->deflink.sta_id != baid_data->sta_id,
-                "baid 0x%x is mapped to sta:%d tid:%d, but was received for sta:%d tid:%d\n",
-                baid, baid_data->sta_id, baid_data->tid, mvm_sta->deflink.sta_id,
-                tid))
+                !(sta_mask & baid_data->sta_mask),
+                "baid 0x%x is mapped to sta_mask:0x%x tid:%d, but was received for sta_mask:0x%x tid:%d\n",
+                baid, baid_data->sta_mask, baid_data->tid, sta_mask, tid))
                return false;
 
        nssn = reorder & IWL_RX_MPDU_REORDER_NSSN_MASK;
                goto out;
        }
 
-       if (WARN(tid != baid_data->tid || sta_id != baid_data->sta_id,
-                "baid 0x%x is mapped to sta:%d tid:%d, but BAR release received for sta:%d tid:%d\n",
-                baid, baid_data->sta_id, baid_data->tid, sta_id,
+       if (WARN(tid != baid_data->tid || sta_id > IWL_MVM_STATION_COUNT_MAX ||
+                !(baid_data->sta_mask & BIT(sta_id)),
+                "baid 0x%x is mapped to sta_mask:0x%x tid:%d, but BAR release received for sta:%d tid:%d\n",
+                baid, baid_data->sta_mask, baid_data->tid, sta_id,
                 tid))
                goto out;
 
 
        struct ieee80211_sta *sta;
        struct iwl_mvm_sta *mvm_sta;
        unsigned long timeout;
+       unsigned int sta_id;
 
        rcu_read_lock();
 
        }
 
        /* Timer expired */
-       sta = rcu_dereference(ba_data->mvm->fw_id_to_mac_id[ba_data->sta_id]);
+       sta_id = ffs(ba_data->sta_mask) - 1; /* don't care which one */
+       sta = rcu_dereference(ba_data->mvm->fw_id_to_mac_id[sta_id]);
 
        /*
         * sta should be valid unless the following happens:
 }
 
 static int iwl_mvm_fw_baid_op_sta(struct iwl_mvm *mvm,
-                                 struct iwl_mvm_sta *mvm_sta,
+                                 struct ieee80211_sta *sta,
                                  bool start, int tid, u16 ssn,
                                  u16 buf_size)
 {
+       struct iwl_mvm_sta *mvm_sta = iwl_mvm_sta_from_mac80211(sta);
        struct iwl_mvm_add_sta_cmd cmd = {
                .mac_id_n_color = cpu_to_le32(mvm_sta->mac_id_n_color),
                .sta_id = mvm_sta->deflink.sta_id,
 }
 
 static int iwl_mvm_fw_baid_op_cmd(struct iwl_mvm *mvm,
-                                 struct iwl_mvm_sta *mvm_sta,
+                                 struct ieee80211_sta *sta,
                                  bool start, int tid, u16 ssn,
                                  u16 buf_size, int baid)
 {
 
        if (start) {
                cmd.alloc.sta_id_mask =
-                       cpu_to_le32(BIT(mvm_sta->deflink.sta_id));
+                       cpu_to_le32(iwl_mvm_sta_fw_id_mask(mvm, sta, -1));
                cmd.alloc.tid = tid;
                cmd.alloc.ssn = cpu_to_le16(ssn);
                cmd.alloc.win_size = cpu_to_le16(buf_size);
                BUILD_BUG_ON(sizeof(cmd.remove_v1) > sizeof(cmd.remove));
        } else {
                cmd.remove.sta_id_mask =
-                       cpu_to_le32(BIT(mvm_sta->deflink.sta_id));
+                       cpu_to_le32(iwl_mvm_sta_fw_id_mask(mvm, sta, -1));
                cmd.remove.tid = cpu_to_le32(tid);
        }
 
        return baid;
 }
 
-static int iwl_mvm_fw_baid_op(struct iwl_mvm *mvm, struct iwl_mvm_sta *mvm_sta,
+static int iwl_mvm_fw_baid_op(struct iwl_mvm *mvm, struct ieee80211_sta *sta,
                              bool start, int tid, u16 ssn, u16 buf_size,
                              int baid)
 {
        if (fw_has_capa(&mvm->fw->ucode_capa,
                        IWL_UCODE_TLV_CAPA_BAID_ML_SUPPORT))
-               return iwl_mvm_fw_baid_op_cmd(mvm, mvm_sta, start,
+               return iwl_mvm_fw_baid_op_cmd(mvm, sta, start,
                                              tid, ssn, buf_size, baid);
 
-       return iwl_mvm_fw_baid_op_sta(mvm, mvm_sta, start,
+       return iwl_mvm_fw_baid_op_sta(mvm, sta, start,
                                      tid, ssn, buf_size);
 }
 
 
        /* Don't send command to remove (start=0) BAID during restart */
        if (start || !test_bit(IWL_MVM_STATUS_IN_HW_RESTART, &mvm->status))
-               baid = iwl_mvm_fw_baid_op(mvm, mvm_sta, start, tid, ssn, buf_size,
+               baid = iwl_mvm_fw_baid_op(mvm, sta, start, tid, ssn, buf_size,
                                          baid);
 
        if (baid < 0) {
                            iwl_mvm_rx_agg_session_expired, 0);
                baid_data->mvm = mvm;
                baid_data->tid = tid;
-               baid_data->sta_id = mvm_sta->deflink.sta_id;
+               baid_data->sta_mask = iwl_mvm_sta_fw_id_mask(mvm, sta, -1);
 
                mvm_sta->tid_to_baid[tid] = baid;
                if (timeout)