mvmvif->csa_countdown = false;
 
+               /* Set CS bit on all the stations */
+               iwl_mvm_modify_all_sta_disable_tx(mvm, mvmvif, true);
+
                /* Save blocked iface, the timeout is set on the next beacon */
                rcu_assign_pointer(mvm->csa_tx_blocked_vif, vif);
 
        if (vif->type == NL80211_IFTYPE_MONITOR)
                iwl_mvm_rm_snif_sta(mvm, vif);
 
-       if (vif->type == NL80211_IFTYPE_AP)
-               /* Set CS bit on all the stations */
-               iwl_mvm_modify_all_sta_disable_tx(mvm, mvmvif, true);
 
        if (vif->type == NL80211_IFTYPE_STATION && switching_chanctx) {
                disabled_vif = vif;
 
        if (vif->type == NL80211_IFTYPE_MONITOR)
                iwl_mvm_mld_rm_snif_sta(mvm, vif);
 
-       if (vif->type == NL80211_IFTYPE_AP)
-               /* Set CS bit on all the stations */
-               iwl_mvm_mld_modify_all_sta_disable_tx(mvm, mvmvif, true);
-
        /* Link needs to be deactivated before removal */
        iwl_mvm_link_changed(mvm, vif, LINK_CONTEXT_MODIFY_ACTIVE, false);
        iwl_mvm_remove_link(mvm, vif);
 
        return ret;
 }
 
-static void iwl_mvm_mld_sta_modify_disable_tx(struct iwl_mvm *mvm,
-                                             struct ieee80211_sta *sta,
-                                             bool disable)
+void iwl_mvm_mld_sta_modify_disable_tx(struct iwl_mvm *mvm,
+                                      struct iwl_mvm_sta *mvmsta,
+                                      bool disable)
 {
-       struct iwl_mvm_sta *mvm_sta = iwl_mvm_sta_from_mac80211(sta);
        struct iwl_mvm_sta_disable_tx_cmd cmd;
        int ret;
 
-       spin_lock_bh(&mvm_sta->lock);
-
-       if (mvm_sta->disable_tx == disable) {
-               spin_unlock_bh(&mvm_sta->lock);
-               return;
-       }
-
-       mvm_sta->disable_tx = disable;
-
-       cmd.sta_id = cpu_to_le32(mvm_sta->deflink.sta_id);
+       cmd.sta_id = cpu_to_le32(mvmsta->deflink.sta_id);
        cmd.disable = cpu_to_le32(disable);
 
        ret = iwl_mvm_send_cmd_pdu(mvm,
                IWL_ERR(mvm,
                        "Failed to send STA_DISABLE_TX_CMD command (%d)\n",
                        ret);
+}
+
+void iwl_mvm_mld_sta_modify_disable_tx_ap(struct iwl_mvm *mvm,
+                                         struct ieee80211_sta *sta,
+                                         bool disable)
+{
+       struct iwl_mvm_sta *mvm_sta = iwl_mvm_sta_from_mac80211(sta);
+
+       spin_lock_bh(&mvm_sta->lock);
+
+       if (mvm_sta->disable_tx == disable) {
+               spin_unlock_bh(&mvm_sta->lock);
+               return;
+       }
+
+       iwl_mvm_mld_sta_modify_disable_tx(mvm, mvm_sta, disable);
 
        spin_unlock_bh(&mvm_sta->lock);
 }
                    FW_CMD_ID_AND_COLOR(mvmvif->id, mvmvif->color))
                        continue;
 
-               iwl_mvm_mld_sta_modify_disable_tx(mvm, sta, disable);
+               iwl_mvm_mld_sta_modify_disable_tx(mvm, mvm_sta, disable);
        }
 
        rcu_read_unlock();
 
 }
 
 void iwl_mvm_sta_modify_disable_tx(struct iwl_mvm *mvm,
-                                  struct iwl_mvm_sta *mvmsta, bool disable)
+                                  struct iwl_mvm_sta *mvmsta,
+                                  bool disable)
 {
        struct iwl_mvm_add_sta_cmd cmd = {
                .add_modify = STA_MODE_MODIFY,
        };
        int ret;
 
+       if (mvm->mld_api_is_used) {
+               iwl_mvm_mld_sta_modify_disable_tx(mvm, mvmsta, disable);
+               return;
+       }
+
        ret = iwl_mvm_send_cmd_pdu(mvm, ADD_STA, CMD_ASYNC,
                                   iwl_mvm_add_sta_cmd_size(mvm), &cmd);
        if (ret)
 {
        struct iwl_mvm_sta *mvm_sta = iwl_mvm_sta_from_mac80211(sta);
 
+       if (mvm->mld_api_is_used) {
+               iwl_mvm_mld_sta_modify_disable_tx_ap(mvm, sta, disable);
+               return;
+       }
+
        spin_lock_bh(&mvm_sta->lock);
 
        if (mvm_sta->disable_tx == disable) {
        struct iwl_mvm_sta *mvm_sta;
        int i;
 
+       if (mvm->mld_api_is_used) {
+               iwl_mvm_mld_modify_all_sta_disable_tx(mvm, mvmvif, disable);
+               return;
+       }
+
        rcu_read_lock();
 
        /* Block/unblock all the stations of the given mvmvif */
 
 void iwl_mvm_mld_modify_all_sta_disable_tx(struct iwl_mvm *mvm,
                                           struct iwl_mvm_vif *mvmvif,
                                           bool disable);
+void iwl_mvm_mld_sta_modify_disable_tx(struct iwl_mvm *mvm,
+                                      struct iwl_mvm_sta *mvm_sta,
+                                      bool disable);
+void iwl_mvm_mld_sta_modify_disable_tx_ap(struct iwl_mvm *mvm,
+                                         struct ieee80211_sta *sta,
+                                         bool disable);
 #endif /* __sta_h__ */