if (unlikely(!keypair))
                return NULL;
+       spin_lock_init(&keypair->receiving_counter.lock);
        keypair->internal_id = atomic64_inc_return(&keypair_counter);
        keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
        keypair->entry.peer = peer;
        memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
 }
 
-static void symmetric_key_init(struct noise_symmetric_key *key)
-{
-       spin_lock_init(&key->counter.receive.lock);
-       atomic64_set(&key->counter.counter, 0);
-       memset(key->counter.receive.backtrack, 0,
-              sizeof(key->counter.receive.backtrack));
-       key->birthdate = ktime_get_coarse_boottime_ns();
-       key->is_valid = true;
-}
-
 static void derive_keys(struct noise_symmetric_key *first_dst,
                        struct noise_symmetric_key *second_dst,
                        const u8 chaining_key[NOISE_HASH_LEN])
 {
+       u64 birthdate = ktime_get_coarse_boottime_ns();
        kdf(first_dst->key, second_dst->key, NULL, NULL,
            NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
            chaining_key);
-       symmetric_key_init(first_dst);
-       symmetric_key_init(second_dst);
+       first_dst->birthdate = second_dst->birthdate = birthdate;
+       first_dst->is_valid = second_dst->is_valid = true;
 }
 
 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
 
 #include <linux/mutex.h>
 #include <linux/kref.h>
 
-union noise_counter {
-       struct {
-               u64 counter;
-               unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
-               spinlock_t lock;
-       } receive;
-       atomic64_t counter;
+struct noise_replay_counter {
+       u64 counter;
+       spinlock_t lock;
+       unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
 };
 
 struct noise_symmetric_key {
        u8 key[NOISE_SYMMETRIC_KEY_LEN];
-       union noise_counter counter;
        u64 birthdate;
        bool is_valid;
 };
 struct noise_keypair {
        struct index_hashtable_entry entry;
        struct noise_symmetric_key sending;
+       atomic64_t sending_counter;
        struct noise_symmetric_key receiving;
+       struct noise_replay_counter receiving_counter;
        __le32 remote_index;
        bool i_am_the_initiator;
        struct kref refcount;
 
        }
 }
 
-static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
+static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
 {
        struct scatterlist sg[MAX_SKB_FRAGS + 8];
        struct sk_buff *trailer;
        unsigned int offset;
        int num_frags;
 
-       if (unlikely(!key))
+       if (unlikely(!keypair))
                return false;
 
-       if (unlikely(!READ_ONCE(key->is_valid) ||
-                 wg_birthdate_has_expired(key->birthdate, REJECT_AFTER_TIME) ||
-                 key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
-               WRITE_ONCE(key->is_valid, false);
+       if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
+                 wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
+                 keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) {
+               WRITE_ONCE(keypair->receiving.is_valid, false);
                return false;
        }
 
 
        if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
                                                 PACKET_CB(skb)->nonce,
-                                                key->key))
+                                                keypair->receiving.key))
                return false;
 
        /* Another ugly situation of pushing and pulling the header so as to
 }
 
 /* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
-static bool counter_validate(union noise_counter *counter, u64 their_counter)
+static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter)
 {
        unsigned long index, index_current, top, i;
        bool ret = false;
 
-       spin_lock_bh(&counter->receive.lock);
+       spin_lock_bh(&counter->lock);
 
-       if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 ||
+       if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 ||
                     their_counter >= REJECT_AFTER_MESSAGES))
                goto out;
 
        ++their_counter;
 
        if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
-                    counter->receive.counter))
+                    counter->counter))
                goto out;
 
        index = their_counter >> ilog2(BITS_PER_LONG);
 
-       if (likely(their_counter > counter->receive.counter)) {
-               index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
+       if (likely(their_counter > counter->counter)) {
+               index_current = counter->counter >> ilog2(BITS_PER_LONG);
                top = min_t(unsigned long, index - index_current,
                            COUNTER_BITS_TOTAL / BITS_PER_LONG);
                for (i = 1; i <= top; ++i)
-                       counter->receive.backtrack[(i + index_current) &
+                       counter->backtrack[(i + index_current) &
                                ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
-               counter->receive.counter = their_counter;
+               counter->counter = their_counter;
        }
 
        index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
        ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
-                               &counter->receive.backtrack[index]);
+                               &counter->backtrack[index]);
 
 out:
-       spin_unlock_bh(&counter->receive.lock);
+       spin_unlock_bh(&counter->lock);
        return ret;
 }
 
                if (unlikely(state != PACKET_STATE_CRYPTED))
                        goto next;
 
-               if (unlikely(!counter_validate(&keypair->receiving.counter,
+               if (unlikely(!counter_validate(&keypair->receiving_counter,
                                               PACKET_CB(skb)->nonce))) {
                        net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
                                            peer->device->dev->name,
                                            PACKET_CB(skb)->nonce,
-                                           keypair->receiving.counter.receive.counter);
+                                           keypair->receiving_counter.counter);
                        goto next;
                }
 
        struct sk_buff *skb;
 
        while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
-               enum packet_state state = likely(decrypt_packet(skb,
-                               &PACKET_CB(skb)->keypair->receiving)) ?
+               enum packet_state state =
+                       likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ?
                                PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
                wg_queue_enqueue_per_peer_napi(skb, state);
                if (need_resched())
 
 #ifdef DEBUG
 bool __init wg_packet_counter_selftest(void)
 {
+       struct noise_replay_counter *counter;
        unsigned int test_num = 0, i;
-       union noise_counter counter;
        bool success = true;
 
-#define T_INIT do {                                               \
-               memset(&counter, 0, sizeof(union noise_counter)); \
-               spin_lock_init(&counter.receive.lock);            \
+       counter = kmalloc(sizeof(*counter), GFP_KERNEL);
+       if (unlikely(!counter)) {
+               pr_err("nonce counter self-test malloc: FAIL\n");
+               return false;
+       }
+
+#define T_INIT do {                                    \
+               memset(counter, 0, sizeof(*counter));  \
+               spin_lock_init(&counter->lock);        \
        } while (0)
 #define T_LIM (COUNTER_WINDOW_SIZE + 1)
 #define T(n, v) do {                                                  \
                ++test_num;                                           \
-               if (counter_validate(&counter, n) != (v)) {           \
+               if (counter_validate(counter, n) != (v)) {            \
                        pr_err("nonce counter self-test %u: FAIL\n",  \
                               test_num);                             \
                        success = false;                              \
 
        if (success)
                pr_info("nonce counter self-tests: pass\n");
+       kfree(counter);
        return success;
 }
 #endif
 
        rcu_read_lock_bh();
        keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
        send = keypair && READ_ONCE(keypair->sending.is_valid) &&
-              (atomic64_read(&keypair->sending.counter.counter) > REKEY_AFTER_MESSAGES ||
+              (atomic64_read(&keypair->sending_counter) > REKEY_AFTER_MESSAGES ||
                (keypair->i_am_the_initiator &&
                 wg_birthdate_has_expired(keypair->sending.birthdate, REKEY_AFTER_TIME)));
        rcu_read_unlock_bh();
 
 void wg_packet_send_staged_packets(struct wg_peer *peer)
 {
-       struct noise_symmetric_key *key;
        struct noise_keypair *keypair;
        struct sk_buff_head packets;
        struct sk_buff *skb;
        rcu_read_unlock_bh();
        if (unlikely(!keypair))
                goto out_nokey;
-       key = &keypair->sending;
-       if (unlikely(!READ_ONCE(key->is_valid)))
+       if (unlikely(!READ_ONCE(keypair->sending.is_valid)))
                goto out_nokey;
-       if (unlikely(wg_birthdate_has_expired(key->birthdate,
+       if (unlikely(wg_birthdate_has_expired(keypair->sending.birthdate,
                                              REJECT_AFTER_TIME)))
                goto out_invalid;
 
                 */
                PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0, ip_hdr(skb), skb);
                PACKET_CB(skb)->nonce =
-                               atomic64_inc_return(&key->counter.counter) - 1;
+                               atomic64_inc_return(&keypair->sending_counter) - 1;
                if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES))
                        goto out_invalid;
        }
        return;
 
 out_invalid:
-       WRITE_ONCE(key->is_valid, false);
+       WRITE_ONCE(keypair->sending.is_valid, false);
 out_nokey:
        wg_noise_keypair_put(keypair, false);