aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
        mem_size = aead_size + (nsg * sizeof(struct scatterlist));
-       mem_size = mem_size + prot->aad_size;
+       mem_size = mem_size + TLS_MAX_AAD_SIZE;
        mem_size = mem_size + MAX_IV_SIZE;
        mem_size = mem_size + prot->tail_size;
 
        sgin = (struct scatterlist *)(mem + aead_size);
        sgout = sgin + n_sgin;
        aad = (u8 *)(sgout + n_sgout);
-       iv = aad + prot->aad_size;
+       iv = aad + TLS_MAX_AAD_SIZE;
        tail = iv + MAX_IV_SIZE;
 
        /* For CCM based ciphers, first byte of nonce+iv is a constant */
                goto free_priv;
        }
 
-       /* Sanity-check the sizes for stack allocations. */
-       if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
-           rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE) {
-               rc = -EINVAL;
-               goto free_priv;
-       }
-
        if (crypto_info->version == TLS_1_3_VERSION) {
                nonce_size = 0;
                prot->aad_size = TLS_HEADER_SIZE;
                prot->tail_size = 0;
        }
 
+       /* Sanity-check the sizes for stack allocations. */
+       if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
+           rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE ||
+           prot->aad_size > TLS_MAX_AAD_SIZE) {
+               rc = -EINVAL;
+               goto free_priv;
+       }
+
        prot->version = crypto_info->version;
        prot->cipher_type = crypto_info->cipher_type;
        prot->prepend_size = TLS_HEADER_SIZE + nonce_size;