#include "wme.h"
 
 static struct ieee80211_link_data *
-ieee80211_link_or_deflink(struct ieee80211_sub_if_data *sdata, int link_id)
+ieee80211_link_or_deflink(struct ieee80211_sub_if_data *sdata, int link_id,
+                         bool require_valid)
 {
        struct ieee80211_link_data *link;
 
        if (link_id < 0) {
-               if (sdata->vif.valid_links)
+               /*
+                * For keys, if sdata is not an MLD, we might not use
+                * the return value at all (if it's not a pairwise key),
+                * so in that case (require_valid==false) don't error.
+                */
+               if (require_valid && sdata->vif.valid_links)
                        return ERR_PTR(-EINVAL);
 
                return &sdata->deflink;
                             const u8 *mac_addr, struct key_params *params)
 {
        struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev);
+       struct ieee80211_link_data *link =
+               ieee80211_link_or_deflink(sdata, link_id, false);
        struct ieee80211_local *local = sdata->local;
        struct sta_info *sta = NULL;
        struct ieee80211_key *key;
        if (!ieee80211_sdata_running(sdata))
                return -ENETDOWN;
 
+       if (IS_ERR(link))
+               return PTR_ERR(link);
+
        if (pairwise && params->mode == NL80211_KEY_SET_TX)
                return ieee80211_set_tx(sdata, mac_addr, key_idx);
 
        case WLAN_CIPHER_SUITE_WEP40:
        case WLAN_CIPHER_SUITE_TKIP:
        case WLAN_CIPHER_SUITE_WEP104:
+               if (link_id >= 0)
+                       return -EINVAL;
                if (WARN_ON_ONCE(fips_enabled))
                        return -EINVAL;
                break;
        if (IS_ERR(key))
                return PTR_ERR(key);
 
+       key->conf.link_id = link_id;
+
        if (pairwise)
                key->conf.flags |= IEEE80211_KEY_FLAG_PAIRWISE;
 
                break;
        }
 
-       err = ieee80211_key_link(key, sdata, sta);
+       err = ieee80211_key_link(key, link, sta);
 
  out_unlock:
        mutex_unlock(&local->sta_mtx);
 }
 
 static struct ieee80211_key *
-ieee80211_lookup_key(struct ieee80211_sub_if_data *sdata,
+ieee80211_lookup_key(struct ieee80211_sub_if_data *sdata, int link_id,
                     u8 key_idx, bool pairwise, const u8 *mac_addr)
 {
        struct ieee80211_local *local = sdata->local;
+       struct ieee80211_link_data *link = &sdata->deflink;
        struct ieee80211_key *key;
-       struct sta_info *sta;
+
+       if (link_id >= 0) {
+               link = rcu_dereference_check(sdata->link[link_id],
+                                            lockdep_is_held(&sdata->wdev.mtx));
+               if (!link)
+                       return NULL;
+       }
 
        if (mac_addr) {
+               struct sta_info *sta;
+               struct link_sta_info *link_sta;
+
                sta = sta_info_get_bss(sdata, mac_addr);
                if (!sta)
                        return NULL;
 
+               if (link_id >= 0) {
+                       link_sta = rcu_dereference_check(sta->link[link_id],
+                                                        lockdep_is_held(&local->sta_mtx));
+                       if (!link_sta)
+                               return NULL;
+               } else {
+                       link_sta = &sta->deflink;
+               }
+
                if (pairwise && key_idx < NUM_DEFAULT_KEYS)
                        return rcu_dereference_check_key_mtx(local,
                                                             sta->ptk[key_idx]);
                              NUM_DEFAULT_MGMT_KEYS +
                              NUM_DEFAULT_BEACON_KEYS)
                        return rcu_dereference_check_key_mtx(local,
-                                                            sta->deflink.gtk[key_idx]);
+                                                            link_sta->gtk[key_idx]);
 
                return NULL;
        }
                return rcu_dereference_check_key_mtx(local,
                                                     sdata->keys[key_idx]);
 
-       key = rcu_dereference_check_key_mtx(local, sdata->deflink.gtk[key_idx]);
+       key = rcu_dereference_check_key_mtx(local, link->gtk[key_idx]);
        if (key)
                return key;
 
        mutex_lock(&local->sta_mtx);
        mutex_lock(&local->key_mtx);
 
-       key = ieee80211_lookup_key(sdata, key_idx, pairwise, mac_addr);
+       key = ieee80211_lookup_key(sdata, link_id, key_idx, pairwise, mac_addr);
        if (!key) {
                ret = -ENOENT;
                goto out_unlock;
 
        rcu_read_lock();
 
-       key = ieee80211_lookup_key(sdata, key_idx, pairwise, mac_addr);
+       key = ieee80211_lookup_key(sdata, link_id, key_idx, pairwise, mac_addr);
        if (!key)
                goto out;
 
                                        bool multi)
 {
        struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev);
+       struct ieee80211_link_data *link =
+               ieee80211_link_or_deflink(sdata, link_id, false);
 
-       ieee80211_set_default_key(sdata, key_idx, uni, multi);
+       if (IS_ERR(link))
+               return PTR_ERR(link);
+
+       ieee80211_set_default_key(link, key_idx, uni, multi);
 
        return 0;
 }
                                             int link_id, u8 key_idx)
 {
        struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev);
+       struct ieee80211_link_data *link =
+               ieee80211_link_or_deflink(sdata, link_id, true);
 
-       ieee80211_set_default_mgmt_key(sdata, key_idx);
+       if (IS_ERR(link))
+               return PTR_ERR(link);
+
+       ieee80211_set_default_mgmt_key(link, key_idx);
 
        return 0;
 }
                                               int link_id, u8 key_idx)
 {
        struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev);
+       struct ieee80211_link_data *link =
+               ieee80211_link_or_deflink(sdata, link_id, true);
+
+       if (IS_ERR(link))
+               return PTR_ERR(link);
 
-       ieee80211_set_default_beacon_key(sdata, key_idx);
+       ieee80211_set_default_beacon_key(link, key_idx);
 
        return 0;
 }
        struct ieee80211_local *local = wiphy_priv(wiphy);
        struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev);
        struct ieee80211_link_data *link =
-               ieee80211_link_or_deflink(sdata, params->link_id);
+               ieee80211_link_or_deflink(sdata, params->link_id, true);
        struct ieee80211_tx_queue_params p;
 
        if (!local->ops->conf_tx)
 
        }
 }
 
-static void __ieee80211_set_default_key(struct ieee80211_sub_if_data *sdata,
+static void __ieee80211_set_default_key(struct ieee80211_link_data *link,
                                        int idx, bool uni, bool multi)
 {
+       struct ieee80211_sub_if_data *sdata = link->sdata;
        struct ieee80211_key *key = NULL;
 
        assert_key_lock(sdata->local);
        if (idx >= 0 && idx < NUM_DEFAULT_KEYS) {
                key = key_mtx_dereference(sdata->local, sdata->keys[idx]);
                if (!key)
-                       key = key_mtx_dereference(sdata->local, sdata->deflink.gtk[idx]);
+                       key = key_mtx_dereference(sdata->local, link->gtk[idx]);
        }
 
        if (uni) {
        }
 
        if (multi)
-               rcu_assign_pointer(sdata->deflink.default_multicast_key, key);
+               rcu_assign_pointer(link->default_multicast_key, key);
 
        ieee80211_debugfs_key_update_default(sdata);
 }
 
-void ieee80211_set_default_key(struct ieee80211_sub_if_data *sdata, int idx,
+void ieee80211_set_default_key(struct ieee80211_link_data *link, int idx,
                               bool uni, bool multi)
 {
-       mutex_lock(&sdata->local->key_mtx);
-       __ieee80211_set_default_key(sdata, idx, uni, multi);
-       mutex_unlock(&sdata->local->key_mtx);
+       mutex_lock(&link->sdata->local->key_mtx);
+       __ieee80211_set_default_key(link, idx, uni, multi);
+       mutex_unlock(&link->sdata->local->key_mtx);
 }
 
 static void
-__ieee80211_set_default_mgmt_key(struct ieee80211_sub_if_data *sdata, int idx)
+__ieee80211_set_default_mgmt_key(struct ieee80211_link_data *link, int idx)
 {
+       struct ieee80211_sub_if_data *sdata = link->sdata;
        struct ieee80211_key *key = NULL;
 
        assert_key_lock(sdata->local);
 
        if (idx >= NUM_DEFAULT_KEYS &&
            idx < NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS)
-               key = key_mtx_dereference(sdata->local,
-                                         sdata->deflink.gtk[idx]);
+               key = key_mtx_dereference(sdata->local, link->gtk[idx]);
 
-       rcu_assign_pointer(sdata->deflink.default_mgmt_key, key);
+       rcu_assign_pointer(link->default_mgmt_key, key);
 
        ieee80211_debugfs_key_update_default(sdata);
 }
 
-void ieee80211_set_default_mgmt_key(struct ieee80211_sub_if_data *sdata,
+void ieee80211_set_default_mgmt_key(struct ieee80211_link_data *link,
                                    int idx)
 {
-       mutex_lock(&sdata->local->key_mtx);
-       __ieee80211_set_default_mgmt_key(sdata, idx);
-       mutex_unlock(&sdata->local->key_mtx);
+       mutex_lock(&link->sdata->local->key_mtx);
+       __ieee80211_set_default_mgmt_key(link, idx);
+       mutex_unlock(&link->sdata->local->key_mtx);
 }
 
 static void
-__ieee80211_set_default_beacon_key(struct ieee80211_sub_if_data *sdata, int idx)
+__ieee80211_set_default_beacon_key(struct ieee80211_link_data *link, int idx)
 {
+       struct ieee80211_sub_if_data *sdata = link->sdata;
        struct ieee80211_key *key = NULL;
 
        assert_key_lock(sdata->local);
        if (idx >= NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS &&
            idx < NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS +
            NUM_DEFAULT_BEACON_KEYS)
-               key = key_mtx_dereference(sdata->local,
-                                         sdata->deflink.gtk[idx]);
+               key = key_mtx_dereference(sdata->local, link->gtk[idx]);
 
-       rcu_assign_pointer(sdata->deflink.default_beacon_key, key);
+       rcu_assign_pointer(link->default_beacon_key, key);
 
        ieee80211_debugfs_key_update_default(sdata);
 }
 
-void ieee80211_set_default_beacon_key(struct ieee80211_sub_if_data *sdata,
+void ieee80211_set_default_beacon_key(struct ieee80211_link_data *link,
                                      int idx)
 {
-       mutex_lock(&sdata->local->key_mtx);
-       __ieee80211_set_default_beacon_key(sdata, idx);
-       mutex_unlock(&sdata->local->key_mtx);
+       mutex_lock(&link->sdata->local->key_mtx);
+       __ieee80211_set_default_beacon_key(link, idx);
+       mutex_unlock(&link->sdata->local->key_mtx);
 }
 
 static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
-                                 struct sta_info *sta,
-                                 bool pairwise,
-                                 struct ieee80211_key *old,
-                                 struct ieee80211_key *new)
+                                struct ieee80211_link_data *link,
+                                struct sta_info *sta,
+                                bool pairwise,
+                                struct ieee80211_key *old,
+                                struct ieee80211_key *new)
 {
+       struct link_sta_info *link_sta = sta ? &sta->deflink : NULL;
+       int link_id;
        int idx;
        int ret = 0;
        bool defunikey, defmultikey, defmgmtkey, defbeaconkey;
 
        if (new) {
                idx = new->conf.keyidx;
-               list_add_tail_rcu(&new->list, &sdata->key_list);
                is_wep = new->conf.cipher == WLAN_CIPHER_SUITE_WEP40 ||
                         new->conf.cipher == WLAN_CIPHER_SUITE_WEP104;
+               link_id = new->conf.link_id;
        } else {
                idx = old->conf.keyidx;
                is_wep = old->conf.cipher == WLAN_CIPHER_SUITE_WEP40 ||
                         old->conf.cipher == WLAN_CIPHER_SUITE_WEP104;
+               link_id = old->conf.link_id;
+       }
+
+       if (WARN(old && old->conf.link_id != link_id,
+                "old link ID %d doesn't match new link ID %d\n",
+                old->conf.link_id, link_id))
+               return -EINVAL;
+
+       if (link_id >= 0) {
+               if (!link) {
+                       link = sdata_dereference(sdata->link[link_id], sdata);
+                       if (!link)
+                               return -ENOLINK;
+               }
+
+               if (sta) {
+                       link_sta = rcu_dereference_protected(sta->link[link_id],
+                                                            lockdep_is_held(&sta->local->sta_mtx));
+                       if (!link_sta)
+                               return -ENOLINK;
+               }
+       } else {
+               link = &sdata->deflink;
        }
 
        if ((is_wep || pairwise) && idx >= NUM_DEFAULT_KEYS)
        if (ret)
                return ret;
 
+       if (new)
+               list_add_tail_rcu(&new->list, &sdata->key_list);
+
        if (sta) {
                if (pairwise) {
                        rcu_assign_pointer(sta->ptk[idx], new);
                            !(new->conf.flags & IEEE80211_KEY_FLAG_NO_AUTO_TX))
                                _ieee80211_set_tx_key(new, true);
                } else {
-                       rcu_assign_pointer(sta->deflink.gtk[idx], new);
+                       rcu_assign_pointer(link_sta->gtk[idx], new);
                }
                /* Only needed for transition from no key -> key.
                 * Still triggers unnecessary when using Extended Key ID
                                                sdata->default_unicast_key);
                defmultikey = old &&
                        old == key_mtx_dereference(sdata->local,
-                                               sdata->deflink.default_multicast_key);
+                                                  link->default_multicast_key);
                defmgmtkey = old &&
                        old == key_mtx_dereference(sdata->local,
-                                               sdata->deflink.default_mgmt_key);
+                                                  link->default_mgmt_key);
                defbeaconkey = old &&
                        old == key_mtx_dereference(sdata->local,
-                                                  sdata->deflink.default_beacon_key);
+                                                  link->default_beacon_key);
 
                if (defunikey && !new)
-                       __ieee80211_set_default_key(sdata, -1, true, false);
+                       __ieee80211_set_default_key(link, -1, true, false);
                if (defmultikey && !new)
-                       __ieee80211_set_default_key(sdata, -1, false, true);
+                       __ieee80211_set_default_key(link, -1, false, true);
                if (defmgmtkey && !new)
-                       __ieee80211_set_default_mgmt_key(sdata, -1);
+                       __ieee80211_set_default_mgmt_key(link, -1);
                if (defbeaconkey && !new)
-                       __ieee80211_set_default_beacon_key(sdata, -1);
+                       __ieee80211_set_default_beacon_key(link, -1);
 
                if (is_wep || pairwise)
                        rcu_assign_pointer(sdata->keys[idx], new);
                else
-                       rcu_assign_pointer(sdata->deflink.gtk[idx], new);
+                       rcu_assign_pointer(link->gtk[idx], new);
 
                if (defunikey && new)
-                       __ieee80211_set_default_key(sdata, new->conf.keyidx,
+                       __ieee80211_set_default_key(link, new->conf.keyidx,
                                                    true, false);
                if (defmultikey && new)
-                       __ieee80211_set_default_key(sdata, new->conf.keyidx,
+                       __ieee80211_set_default_key(link, new->conf.keyidx,
                                                    false, true);
                if (defmgmtkey && new)
-                       __ieee80211_set_default_mgmt_key(sdata,
+                       __ieee80211_set_default_mgmt_key(link,
                                                         new->conf.keyidx);
                if (defbeaconkey && new)
-                       __ieee80211_set_default_beacon_key(sdata,
+                       __ieee80211_set_default_beacon_key(link,
                                                           new->conf.keyidx);
        }
 
        key->conf.flags = 0;
        key->flags = 0;
 
+       key->conf.link_id = -1;
        key->conf.cipher = cipher;
        key->conf.keyidx = idx;
        key->conf.keylen = key_len;
 }
 
 int ieee80211_key_link(struct ieee80211_key *key,
-                      struct ieee80211_sub_if_data *sdata,
+                      struct ieee80211_link_data *link,
                       struct sta_info *sta)
 {
+       struct ieee80211_sub_if_data *sdata = link->sdata;
        static atomic_t key_color = ATOMIC_INIT(0);
        struct ieee80211_key *old_key = NULL;
        int idx = key->conf.keyidx;
                    (old_key && old_key->conf.cipher != key->conf.cipher))
                        goto out;
        } else if (sta) {
-               old_key = key_mtx_dereference(sdata->local,
-                                             sta->deflink.gtk[idx]);
+               struct link_sta_info *link_sta = &sta->deflink;
+               int link_id = key->conf.link_id;
+
+               if (link_id >= 0) {
+                       link_sta = rcu_dereference_protected(sta->link[link_id],
+                                                            lockdep_is_held(&sta->local->sta_mtx));
+                       if (!link_sta)
+                               return -ENOLINK;
+               }
+
+               old_key = key_mtx_dereference(sdata->local, link_sta->gtk[idx]);
        } else {
                if (idx < NUM_DEFAULT_KEYS)
                        old_key = key_mtx_dereference(sdata->local,
                                                      sdata->keys[idx]);
                if (!old_key)
                        old_key = key_mtx_dereference(sdata->local,
-                                                     sdata->deflink.gtk[idx]);
+                                                     link->gtk[idx]);
        }
 
        /* Non-pairwise keys must also not switch the cipher on rekey */
 
        increment_tailroom_need_count(sdata);
 
-       ret = ieee80211_key_replace(sdata, sta, pairwise, old_key, key);
+       ret = ieee80211_key_replace(sdata, link, sta, pairwise, old_key, key);
 
        if (!ret) {
                ieee80211_debugfs_key_add(key);
         * Replace key with nothingness if it was ever used.
         */
        if (key->sdata)
-               ieee80211_key_replace(key->sdata, key->sta,
-                               key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
-                               key, NULL);
+               ieee80211_key_replace(key->sdata, NULL, key->sta,
+                                     key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
+                                     key, NULL);
        ieee80211_key_destroy(key, delay_tailroom);
 }
 
        ieee80211_debugfs_key_remove_beacon_default(sdata);
 
        list_for_each_entry_safe(key, tmp, &sdata->key_list, list) {
-               ieee80211_key_replace(key->sdata, key->sta,
-                               key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
-                               key, NULL);
+               ieee80211_key_replace(key->sdata, NULL, key->sta,
+                                     key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
+                                     key, NULL);
                list_add_tail(&key->list, keys);
        }
 
        ieee80211_debugfs_key_update_default(sdata);
 }
 
+void ieee80211_remove_link_keys(struct ieee80211_link_data *link,
+                               struct list_head *keys)
+{
+       struct ieee80211_sub_if_data *sdata = link->sdata;
+       struct ieee80211_local *local = sdata->local;
+       struct ieee80211_key *key, *tmp;
+
+       mutex_lock(&local->key_mtx);
+       list_for_each_entry_safe(key, tmp, &sdata->key_list, list) {
+               if (key->conf.link_id != link->link_id)
+                       continue;
+               ieee80211_key_replace(key->sdata, link, key->sta,
+                                     key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
+                                     key, NULL);
+               list_add_tail(&key->list, keys);
+       }
+       mutex_unlock(&local->key_mtx);
+}
+
+void ieee80211_free_key_list(struct ieee80211_local *local,
+                            struct list_head *keys)
+{
+       struct ieee80211_key *key, *tmp;
+
+       mutex_lock(&local->key_mtx);
+       list_for_each_entry_safe(key, tmp, keys, list)
+               __ieee80211_key_destroy(key, false);
+       mutex_unlock(&local->key_mtx);
+}
+
 void ieee80211_free_keys(struct ieee80211_sub_if_data *sdata,
                         bool force_synchronize)
 {
                key = key_mtx_dereference(local, sta->deflink.gtk[i]);
                if (!key)
                        continue;
-               ieee80211_key_replace(key->sdata, key->sta,
-                               key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
-                               key, NULL);
+               ieee80211_key_replace(key->sdata, NULL, key->sta,
+                                     key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
+                                     key, NULL);
                __ieee80211_key_destroy(key, key->sdata->vif.type ==
                                        NL80211_IFTYPE_STATION);
        }
                key = key_mtx_dereference(local, sta->ptk[i]);
                if (!key)
                        continue;
-               ieee80211_key_replace(key->sdata, key->sta,
-                               key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
-                               key, NULL);
+               ieee80211_key_replace(key->sdata, NULL, key->sta,
+                                     key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
+                                     key, NULL);
                __ieee80211_key_destroy(key, key->sdata->vif.type ==
                                        NL80211_IFTYPE_STATION);
        }
        if (sdata->u.mgd.mfp != IEEE80211_MFP_DISABLED)
                key->conf.flags |= IEEE80211_KEY_FLAG_RX_MGMT;
 
-       err = ieee80211_key_link(key, sdata, NULL);
+       /* FIXME: this function needs to get a link ID */
+       err = ieee80211_key_link(key, &sdata->deflink, NULL);
        if (err)
                return ERR_PTR(err);
 
 
 ieee80211_rx_get_bigtk(struct ieee80211_rx_data *rx, int idx)
 {
        struct ieee80211_key *key = NULL;
-       struct ieee80211_sub_if_data *sdata = rx->sdata;
        int idx2;
 
        /* Make sure key gets set if either BIGTK key index is set so that
                        idx2 = idx - 1;
        }
 
-       if (rx->sta)
-               key = rcu_dereference(rx->sta->deflink.gtk[idx]);
+       if (rx->link_sta)
+               key = rcu_dereference(rx->link_sta->gtk[idx]);
        if (!key)
-               key = rcu_dereference(sdata->deflink.gtk[idx]);
-       if (!key && rx->sta)
-               key = rcu_dereference(rx->sta->deflink.gtk[idx2]);
+               key = rcu_dereference(rx->link->gtk[idx]);
+       if (!key && rx->link_sta)
+               key = rcu_dereference(rx->link_sta->gtk[idx2]);
        if (!key)
-               key = rcu_dereference(sdata->deflink.gtk[idx2]);
+               key = rcu_dereference(rx->link->gtk[idx2]);
 
        return key;
 }
                if (mmie_keyidx < NUM_DEFAULT_KEYS ||
                    mmie_keyidx >= NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS)
                        return RX_DROP_MONITOR; /* unexpected BIP keyidx */
-               if (rx->sta) {
+               if (rx->link_sta) {
                        if (ieee80211_is_group_privacy_action(skb) &&
                            test_sta_flag(rx->sta, WLAN_STA_MFP))
                                return RX_DROP_MONITOR;
 
-                       rx->key = rcu_dereference(rx->sta->deflink.gtk[mmie_keyidx]);
+                       rx->key = rcu_dereference(rx->link_sta->gtk[mmie_keyidx]);
                }
                if (!rx->key)
-                       rx->key = rcu_dereference(rx->sdata->deflink.gtk[mmie_keyidx]);
+                       rx->key = rcu_dereference(rx->link->gtk[mmie_keyidx]);
        } else if (!ieee80211_has_protected(fc)) {
                /*
                 * The frame was not protected, so skip decryption. However, we
                 * have been expected.
                 */
                struct ieee80211_key *key = NULL;
-               struct ieee80211_sub_if_data *sdata = rx->sdata;
                int i;
 
                if (ieee80211_is_beacon(fc)) {
                        key = ieee80211_rx_get_bigtk(rx, -1);
                } else if (ieee80211_is_mgmt(fc) &&
                           is_multicast_ether_addr(hdr->addr1)) {
-                       key = rcu_dereference(rx->sdata->deflink.default_mgmt_key);
+                       key = rcu_dereference(rx->link->default_mgmt_key);
                } else {
-                       if (rx->sta) {
+                       if (rx->link_sta) {
                                for (i = 0; i < NUM_DEFAULT_KEYS; i++) {
-                                       key = rcu_dereference(rx->sta->deflink.gtk[i]);
+                                       key = rcu_dereference(rx->link_sta->gtk[i]);
                                        if (key)
                                                break;
                                }
                        }
                        if (!key) {
                                for (i = 0; i < NUM_DEFAULT_KEYS; i++) {
-                                       key = rcu_dereference(sdata->deflink.gtk[i]);
+                                       key = rcu_dereference(rx->link->gtk[i]);
                                        if (key)
                                                break;
                                }
                        return RX_DROP_UNUSABLE;
 
                /* check per-station GTK first, if multicast packet */
-               if (is_multicast_ether_addr(hdr->addr1) && rx->sta)
-                       rx->key = rcu_dereference(rx->sta->deflink.gtk[keyidx]);
+               if (is_multicast_ether_addr(hdr->addr1) && rx->link_sta)
+                       rx->key = rcu_dereference(rx->link_sta->gtk[keyidx]);
 
                /* if not found, try default key */
                if (!rx->key) {
                        if (is_multicast_ether_addr(hdr->addr1))
-                               rx->key = rcu_dereference(rx->sdata->deflink.gtk[keyidx]);
+                               rx->key = rcu_dereference(rx->link->gtk[keyidx]);
                        if (!rx->key)
                                rx->key = rcu_dereference(rx->sdata->keys[keyidx]);
 
                if (!link)
                        return true;
                rx->link = link;
+
+               if (rx->sta) {
+                       rx->link_sta =
+                               rcu_dereference(rx->sta->link[rx->link_id]);
+                       if (!rx->link_sta)
+                               return true;
+               }
        } else {
+               if (rx->sta)
+                       rx->link_sta = &rx->sta->deflink;
+
                rx->link = &sdata->deflink;
        }