return;
 
        wcid->rx_check_pn = true;
+
+       /* data frame */
        for (i = 0; i < IEEE80211_NUM_TIDS; i++) {
                ieee80211_get_key_rx_seq(key, i, &seq);
                memcpy(wcid->rx_key_pn[i], seq.ccmp.pn, sizeof(seq.ccmp.pn));
        }
+
+       /* robust management frame */
+       ieee80211_get_key_rx_seq(key, -1, &seq);
+       memcpy(wcid->rx_key_pn[i], seq.ccmp.pn, sizeof(seq.ccmp.pn));
+
 }
 EXPORT_SYMBOL(mt76_wcid_key_setup);
 
        struct mt76_rx_status *status = (struct mt76_rx_status *)skb->cb;
        struct mt76_wcid *wcid = status->wcid;
        struct ieee80211_hdr *hdr;
-       u8 tidno = status->qos_ctl & IEEE80211_QOS_CTL_TID_MASK;
+       int security_idx;
        int ret;
 
        if (!(status->flag & RX_FLAG_DECRYPTED))
        if (!wcid || !wcid->rx_check_pn)
                return 0;
 
+       hdr = mt76_skb_get_hdr(skb);
        if (!(status->flag & RX_FLAG_IV_STRIPPED)) {
                /*
                 * Validate the first fragment both here and in mac80211
                 * All further fragments will be validated by mac80211 only.
                 */
-               hdr = mt76_skb_get_hdr(skb);
                if (ieee80211_is_frag(hdr) &&
                    !ieee80211_is_first_frag(hdr->frame_control))
                        return 0;
        }
 
+       /* IEEE 802.11-2020, 12.5.3.4.4 "PN and replay detection" c):
+        *
+        * the recipient shall maintain a single replay counter for received
+        * individually addressed robust Management frames that are received
+        * with the To DS subfield equal to 0, [...]
+        */
+       if (ieee80211_is_mgmt(hdr->frame_control) &&
+           !ieee80211_has_tods(hdr->frame_control))
+               security_idx = IEEE80211_NUM_TIDS;
+       else
+               security_idx = status->qos_ctl & IEEE80211_QOS_CTL_TID_MASK;
+
        BUILD_BUG_ON(sizeof(status->iv) != sizeof(wcid->rx_key_pn[0]));
-       ret = memcmp(status->iv, wcid->rx_key_pn[tidno],
+       ret = memcmp(status->iv, wcid->rx_key_pn[security_idx],
                     sizeof(status->iv));
        if (ret <= 0)
                return -EINVAL; /* replay */
 
-       memcpy(wcid->rx_key_pn[tidno], status->iv, sizeof(status->iv));
+       memcpy(wcid->rx_key_pn[security_idx], status->iv, sizeof(status->iv));
 
        if (status->flag & RX_FLAG_IV_STRIPPED)
                status->flag |= RX_FLAG_PN_VALIDATED;