return ret;
 }
 
-static void
-mt76_sta_remove(struct mt76_dev *dev, struct ieee80211_vif *vif,
-               struct ieee80211_sta *sta)
+void __mt76_sta_remove(struct mt76_dev *dev, struct ieee80211_vif *vif,
+                      struct ieee80211_sta *sta)
 {
        struct mt76_wcid *wcid = (struct mt76_wcid *)sta->drv_priv;
-       int idx = wcid->idx;
-       int i;
+       int i, idx = wcid->idx;
 
        rcu_assign_pointer(dev->wcid[idx], NULL);
        synchronize_rcu();
 
-       mutex_lock(&dev->mutex);
-
        if (dev->drv->sta_remove)
                dev->drv->sta_remove(dev, vif, sta);
 
        for (i = 0; i < ARRAY_SIZE(sta->txq); i++)
                mt76_txq_remove(dev, sta->txq[i]);
        mt76_wcid_free(dev->wcid_mask, idx);
+}
+EXPORT_SYMBOL_GPL(__mt76_sta_remove);
 
+static void
+mt76_sta_remove(struct mt76_dev *dev, struct ieee80211_vif *vif,
+               struct ieee80211_sta *sta)
+{
+       mutex_lock(&dev->mutex);
+       __mt76_sta_remove(dev, vif, sta);
        mutex_unlock(&dev->mutex);
 }
 
 
                   struct ieee80211_sta *sta,
                   enum ieee80211_sta_state old_state,
                   enum ieee80211_sta_state new_state);
+void __mt76_sta_remove(struct mt76_dev *dev, struct ieee80211_vif *vif,
+                      struct ieee80211_sta *sta);
 
 struct ieee80211_sta *mt76_rx_convert(struct sk_buff *skb);
 
 
 {
        int i;
 
+       lockdep_assert_held(&dev->mt76.mutex);
+
        clear_bit(MT76_STATE_RUNNING, &dev->mt76.state);
 
        rcu_read_lock();
-
        ieee80211_iter_keys_rcu(dev->mt76.hw, NULL, mt76x02_key_sync, NULL);
+       rcu_read_unlock();
 
        for (i = 0; i < ARRAY_SIZE(dev->mt76.wcid); i++) {
-               struct mt76_wcid *wcid = rcu_dereference(dev->mt76.wcid[i]);
-               struct mt76x02_sta *msta;
                struct ieee80211_sta *sta;
                struct ieee80211_vif *vif;
+               struct mt76x02_sta *msta;
+               struct mt76_wcid *wcid;
                void *priv;
 
+               wcid = rcu_dereference_protected(dev->mt76.wcid[i],
+                                       lockdep_is_held(&dev->mt76.mutex));
                if (!wcid)
                        continue;
 
                priv = msta->vif;
                vif = container_of(priv, struct ieee80211_vif, drv_priv);
 
-               mt76_sta_state(dev->mt76.hw, vif, sta,
-                              IEEE80211_STA_NONE, IEEE80211_STA_NOTEXIST);
+               __mt76_sta_remove(&dev->mt76, vif, sta);
                memset(msta, 0, sizeof(*msta));
        }
 
-       rcu_read_unlock();
-
        dev->vif_mask = 0;
        dev->beacon_mask = 0;
 }
        for (i = 0; i < ARRAY_SIZE(dev->mt76.napi); i++)
                napi_disable(&dev->mt76.napi[i]);
 
+       mutex_lock(&dev->mt76.mutex);
+
        if (restart)
                mt76x02_reset_state(dev);
 
-       mutex_lock(&dev->mt76.mutex);
-
        if (dev->beacon_mask)
                mt76_clear(dev, MT_BEACON_TIME_CFG,
                           MT_BEACON_TIME_CFG_BEACON_TX |