static void
 mt7996_mcu_sta_mld_setup_tlv(struct mt7996_dev *dev, struct sk_buff *skb,
+                            struct ieee80211_vif *vif,
                             struct ieee80211_sta *sta)
 {
        struct mt7996_sta *msta = (struct mt7996_sta *)sta->drv_priv;
-       unsigned long links = sta->valid_links;
-       unsigned int nlinks = hweight16(links);
+       unsigned int nlinks = hweight16(sta->valid_links);
        struct mld_setup_link *mld_setup_link;
+       struct ieee80211_link_sta *link_sta;
        struct sta_rec_mld_setup *mld_setup;
        struct mt7996_sta_link *msta_link;
-       struct ieee80211_vif *vif;
        unsigned int link_id;
        struct tlv *tlv;
 
        mld_setup->primary_id = cpu_to_le16(msta_link->wcid.idx);
 
        if (nlinks > 1) {
-               link_id = __ffs(links & ~BIT(msta->deflink_id));
+               link_id = __ffs(sta->valid_links & ~BIT(msta->deflink_id));
                msta_link = mt76_dereference(msta->link[link_id], &dev->mt76);
                if (!msta_link)
                        return;
        mld_setup->seconed_id = cpu_to_le16(msta_link->wcid.idx);
        mld_setup->link_num = nlinks;
 
-       vif = container_of((void *)msta->vif, struct ieee80211_vif, drv_priv);
        mld_setup_link = (struct mld_setup_link *)mld_setup->link_info;
-       for_each_set_bit(link_id, &links, IEEE80211_MLD_MAX_NUM_LINKS) {
+       for_each_sta_active_link(vif, sta, link_sta, link_id) {
                struct mt7996_vif_link *link;
 
                msta_link = mt76_dereference(msta->link[link_id], &dev->mt76);
                mt7996_mcu_sta_muru_tlv(dev, skb, link_conf, link_sta);
 
                if (sta->mlo) {
-                       mt7996_mcu_sta_mld_setup_tlv(dev, skb, sta);
+                       mt7996_mcu_sta_mld_setup_tlv(dev, skb, link_conf->vif,
+                                                    sta);
                        mt7996_mcu_sta_eht_mld_tlv(dev, skb, sta);
                }
        }