bool have_sta = false;
        bool mlo;
        int err;
+       u16 new_links;
 
        if (link_id >= 0) {
                mlo = true;
                if (WARN_ON(!ap_mld_addr))
                        return -EINVAL;
-               err = ieee80211_vif_set_links(sdata, BIT(link_id), 0);
+               new_links = BIT(link_id);
        } else {
                if (WARN_ON(ap_mld_addr))
                        return -EINVAL;
                ap_mld_addr = cbss->bssid;
-               err = ieee80211_vif_set_links(sdata, 0, 0);
+               new_links = 0;
                link_id = 0;
                mlo = false;
        }
 
+       if (assoc) {
+               rcu_read_lock();
+               have_sta = sta_info_get(sdata, ap_mld_addr);
+               rcu_read_unlock();
+       }
+
+       if (mlo && !have_sta &&
+           WARN_ON(sdata->vif.valid_links || sdata->vif.active_links))
+               return -EINVAL;
+
+       err = ieee80211_vif_set_links(sdata, new_links, 0);
        if (err)
                return err;
 
                goto out_err;
        }
 
-       if (assoc) {
-               rcu_read_lock();
-               have_sta = sta_info_get(sdata, ap_mld_addr);
-               rcu_read_unlock();
-       }
-
        if (!have_sta) {
                if (mlo)
                        new_sta = sta_info_alloc_with_link(sdata, ap_mld_addr,