struct mt76_vif_link *mlink)
 {
        struct mt7996_vif_link *link = container_of(mlink, struct mt7996_vif_link, mt76);
+       struct mt7996_vif *mvif = (struct mt7996_vif *)vif->drv_priv;
        struct mt7996_sta_link *msta_link = &link->msta_link;
        struct mt7996_phy *phy = mphy->priv;
        struct mt7996_dev *dev = phy->dev;
 
        ieee80211_iter_keys(mphy->hw, vif, mt7996_key_iter, NULL);
 
+       if (mvif->mt76.deflink_id == IEEE80211_LINK_UNSPECIFIED)
+               mvif->mt76.deflink_id = link_conf->link_id;
+
        return 0;
 }
 
                            struct mt76_vif_link *mlink)
 {
        struct mt7996_vif_link *link = container_of(mlink, struct mt7996_vif_link, mt76);
+       struct mt7996_vif *mvif = (struct mt7996_vif *)vif->drv_priv;
        struct mt7996_sta_link *msta_link = &link->msta_link;
        struct mt7996_phy *phy = mphy->priv;
        struct mt7996_dev *dev = phy->dev;
 
        rcu_assign_pointer(dev->mt76.wcid[idx], NULL);
 
+       if (mvif->mt76.deflink_id == link_conf->link_id) {
+               struct ieee80211_bss_conf *iter;
+               unsigned int link_id;
+
+               mvif->mt76.deflink_id = IEEE80211_LINK_UNSPECIFIED;
+               for_each_vif_active_link(vif, iter, link_id) {
+                       if (link_id != IEEE80211_LINK_UNSPECIFIED) {
+                               mvif->mt76.deflink_id = link_id;
+                               break;
+                       }
+               }
+       }
+
        dev->mt76.vif_mask &= ~BIT_ULL(mlink->idx);
        phy->omac_mask &= ~BIT_ULL(mlink->omac_idx);
 
        mt76_vif_init(vif, &mvif->mt76);
 
        vif->offload_flags |= IEEE80211_OFFLOAD_ENCAP_4ADDR;
+       mvif->mt76.deflink_id = IEEE80211_LINK_UNSPECIFIED;
 
 out:
        mutex_unlock(&dev->mt76.mutex);