From: David Woodhouse Date: Wed, 12 May 2021 15:19:07 +0000 (+0100) Subject: Implement RSA-PSS padding for TPMv2 X-Git-Tag: v8.20~198 X-Git-Url: https://www.infradead.org/git/?a=commitdiff_plain;h=ff367965fcc13f6c1ba7fbda7a49d1467f1b39de;p=users%2Fdwmw2%2Fopenconnect.git Implement RSA-PSS padding for TPMv2 Now I can connect using TLSv1.3 using a TPMv2 RSA key. And my list of "stuff I should never have had to do for myself in the application, just to ask the crypto library to use the key that the user pointed it at" has *really* jumped the shark now. Signed-off-by: David Woodhouse --- diff --git a/gnutls.h b/gnutls.h index d6f75dc6..3ce7f877 100644 --- a/gnutls.h +++ b/gnutls.h @@ -42,9 +42,10 @@ int tpm2_rsa_sign_hash_fn(gnutls_privkey_t key, gnutls_sign_algorithm_t algo, int tpm2_ec_sign_hash_fn(gnutls_privkey_t key, gnutls_sign_algorithm_t algo, void *_certinfo, unsigned int flags, const gnutls_datum_t *data, gnutls_datum_t *sig); -int oc_pkcs1_pad(struct openconnect_info *vpninfo, - unsigned char *buf, int size, const gnutls_datum_t *data); +int oc_pad_rsasig(struct openconnect_info *vpninfo, gnutls_sign_algorithm_t algo, + unsigned char *buf, int size, const gnutls_datum_t *data, int keybits); uint16_t tpm2_key_curve(struct openconnect_info *vpninfo, struct cert_info *certinfo); +int tpm2_rsa_key_bits(struct openconnect_info *vpninfo, struct cert_info *certinfo); /* GnuTLS 3.6.0+ provides this. We have our own for older GnuTLS. There is * also _gnutls_encode_ber_rs_raw() in some older versions, but there were diff --git a/gnutls_tpm2.c b/gnutls_tpm2.c index 839e1b83..1f0f0946 100644 --- a/gnutls_tpm2.c +++ b/gnutls_tpm2.c @@ -102,9 +102,15 @@ static int tpm2_ec_sign_fn(gnutls_privkey_t key, void *_certinfo, #if GNUTLS_VERSION_NUMBER >= 0x030600 static int rsa_key_info(gnutls_privkey_t key, unsigned int flags, void *_certinfo) { + struct cert_info *certinfo = _certinfo; + struct openconnect_info *vpninfo = certinfo->vpninfo; + if (flags & GNUTLS_PRIVKEY_INFO_PK_ALGO) return GNUTLS_PK_RSA; + if (flags & GNUTLS_PRIVKEY_INFO_PK_ALGO_BITS) + return tpm2_rsa_key_bits(vpninfo, certinfo); + if (flags & GNUTLS_PRIVKEY_INFO_HAVE_SIGN_ALGO) { gnutls_sign_algorithm_t algo = GNUTLS_FLAGS_TO_SIGN_ALGO(flags); switch (algo) { @@ -115,7 +121,18 @@ static int rsa_key_info(gnutls_privkey_t key, unsigned int flags, void *_certinf case GNUTLS_SIGN_RSA_SHA512: return 1; + case GNUTLS_SIGN_RSA_PSS_SHA256: + case GNUTLS_SIGN_RSA_PSS_RSAE_SHA256: + case GNUTLS_SIGN_RSA_PSS_SHA384: + case GNUTLS_SIGN_RSA_PSS_RSAE_SHA384: + case GNUTLS_SIGN_RSA_PSS_SHA512: + case GNUTLS_SIGN_RSA_PSS_RSAE_SHA512: + return 1; + default: + vpn_progress(vpninfo, PRG_DEBUG, + _("Not supporting EC sign algo %s\n"), + gnutls_sign_get_name(algo)); return 0; } } @@ -130,16 +147,17 @@ static int rsa_key_info(gnutls_privkey_t key, unsigned int flags, void *_certinf #if GNUTLS_VERSION_NUMBER >= 0x030400 static int ec_key_info(gnutls_privkey_t key, unsigned int flags, void *_certinfo) { - struct cert_info *certinfo = _certinfo; - struct openconnect_info *vpninfo = certinfo->vpninfo; - if (flags & GNUTLS_PRIVKEY_INFO_PK_ALGO) return GNUTLS_PK_EC; #ifdef GNUTLS_PRIVKEY_INFO_HAVE_SIGN_ALGO if (flags & GNUTLS_PRIVKEY_INFO_HAVE_SIGN_ALGO) { + struct cert_info *certinfo = _certinfo; + struct openconnect_info *vpninfo = certinfo->vpninfo; + uint16_t tpm2_curve = tpm2_key_curve(vpninfo, certinfo); gnutls_sign_algorithm_t algo = GNUTLS_FLAGS_TO_SIGN_ALGO(flags); + switch (algo) { case GNUTLS_SIGN_ECDSA_SHA1: case GNUTLS_SIGN_ECDSA_SHA256: @@ -393,8 +411,8 @@ int oc_gnutls_encode_rs_value(gnutls_datum_t *sig, const gnutls_datum_t *sig_r, /* EMSA-PKCS1-v1_5 padding in accordance with RFC3447 §9.2 */ #define PKCS1_PAD_OVERHEAD 11 -int oc_pkcs1_pad(struct openconnect_info *vpninfo, - unsigned char *buf, int size, const gnutls_datum_t *data) +static int oc_pkcs1_pad(struct openconnect_info *vpninfo, + unsigned char *buf, int size, const gnutls_datum_t *data) { if (data->size + PKCS1_PAD_OVERHEAD > size) { vpn_progress(vpninfo, PRG_ERR, @@ -411,4 +429,160 @@ int oc_pkcs1_pad(struct openconnect_info *vpninfo, return 0; } + +#if GNUTLS_VERSION_NUMBER >= 0x030600 +/* EMSA-PSS encoding in accordance with RFC3447 §9.1 */ +static int oc_pss_mgf1_pad(struct openconnect_info *vpninfo, gnutls_digest_algorithm_t dig, + unsigned char *emBuf, int emLen, const gnutls_datum_t *mHash, int keybits) +{ + gnutls_hash_hd_t hashctx = NULL; + int err = GNUTLS_E_PK_SIGN_FAILED; + + /* The emBits for EMSA-PSS encoding is actually one *fewer* bit than + * the RSA modulus. As RFC3447 §8.1.1 points out, "the octet length + * of EM will be one less than k if modBits - 1 is divisible by 8 + * and equal to k otherwise". Where k is the input emLen, which we + * thus need to adjust before using it as emLen for the following + * operations. Not that it matters much since I don't think the TPM + * can cope with RSA keys whose modulus isn't a multiple of 8 bits + * anyway. */ + int msbits = (keybits - 1) & 7; + if (!msbits) { + *(emBuf++) = 0; + emLen--; + } + + /* GnuTLS gives us a predigested mHash from which we create M' and + * continue the process. Can we infer all the PSS parameters from + * the digest size, including the salt size? Or does GnuTLS need + * a gnutls_privkey_import_ext5() which lets us have the params too? + * Better still, could GnuTLS just do this all for us and we only + * do a raw signature — really raw, unlike GNUTLS_SIGN_RSA_RAW + * which AIUI is actually padded. */ + if (mHash->size > emLen - 2) { + vpn_progress(vpninfo, PRG_ERR, + _("PSS encoding failed; hash size %d too large for RSA key %d\n"), + mHash->size, emLen); + return GNUTLS_E_PK_SIGN_FAILED; + } + + int sLen = mHash->size; + if (sLen + mHash->size > emLen - 2) + sLen = emLen - 2 - mHash->size; + + char salt[SHA512_SIZE]; + if (sLen) { + err = gnutls_rnd(GNUTLS_RND_NONCE, salt, sLen); + if (err) + goto out; + } + + /* Hash M' (8 zeroes || mHash || salt) into its place in EM */ + if ((err = gnutls_hash_init(&hashctx, dig)) || + (err = gnutls_hash(hashctx, "\0\0\0\0\0\0\0\0", 8)) || + (err = gnutls_hash(hashctx, mHash->data, mHash->size)) || + (sLen && (err = gnutls_hash(hashctx, salt, sLen)))) + goto out; + + int maskedDBLen = emLen - mHash->size - 1; + gnutls_hash_output(hashctx, emBuf + maskedDBLen); + + emBuf[emLen - 1] = 0xbc; + + /* Now the MGF1 function as definsed in RFC3447 Appendix B, although + * it's somewhat easier to read in NIST SP 800-56B §7.2.2.2. + * + * We repeatedly hash (M' || C) where C is an incrementing 32-bit + * counter, so hash M' first and then use gnutls_hash_copy() each + * time to add C to the copy. */ + err = gnutls_hash(hashctx, emBuf + maskedDBLen, mHash->size); + if (err) + goto out; + + int mgflen = 0, mgf_count = 0; + while (mgflen < maskedDBLen) { + gnutls_hash_hd_t ctx2 = gnutls_hash_copy(hashctx); + if (!ctx2) { + err = GNUTLS_E_PK_SIGN_FAILED; + goto out; + } + uint32_t be_count = htonl(mgf_count++); + err = gnutls_hash(ctx2, &be_count, sizeof(be_count)); + if (err) { + gnutls_hash_deinit(ctx2, NULL); + goto out; + } + if (mgflen + mHash->size <= maskedDBLen) { + gnutls_hash_deinit(ctx2, emBuf + mgflen); + mgflen += mHash->size; + } else { + char md[SHA512_SIZE]; + gnutls_hash_deinit(ctx2, md); + memcpy(emBuf + mgflen, md, maskedDBLen - mgflen); + mgflen = maskedDBLen; + } + } + + /* Back to EMSA-PSS-ENCODE step 10. The MGF result was directly placed + * into emBuf, so now XOR with DB, which is (zeroes || 0x01 || salt) */ + int dst = maskedDBLen - 1; + while (sLen--) + emBuf[dst--] ^= salt[sLen]; + emBuf[dst] ^= 0x01; + + /* Now mask out the high bits. In the case where msbits is zero, we + * skipped the entire first byte so do nothing. */ + if (msbits) + emBuf[0] &= 0xFF >> (8 - msbits); + + err = 0; + out: + if (hashctx) + gnutls_hash_deinit(hashctx, NULL); + + return err; +} +#endif + +int oc_pad_rsasig(struct openconnect_info *vpninfo, gnutls_sign_algorithm_t algo, + unsigned char *buf, int size, const gnutls_datum_t *data, int keybits) +{ + switch(algo) { + case GNUTLS_SIGN_UNKNOWN: + case GNUTLS_SIGN_RSA_SHA1: + case GNUTLS_SIGN_RSA_SHA256: + case GNUTLS_SIGN_RSA_SHA384: + case GNUTLS_SIGN_RSA_SHA512: + return oc_pkcs1_pad(vpninfo, buf, size, data); + +#if GNUTLS_VERSION_NUMBER >= 0x030600 + /* Really PKCS#1.5 padding, yes. */ + case GNUTLS_SIGN_RSA_RAW: + return oc_pkcs1_pad(vpninfo, buf, size, data); + + case GNUTLS_SIGN_RSA_PSS_SHA256: + case GNUTLS_SIGN_RSA_PSS_RSAE_SHA256: + if (data->size != SHA256_SIZE) + return GNUTLS_E_PK_SIGN_FAILED; + return oc_pss_mgf1_pad(vpninfo, GNUTLS_DIG_SHA256, buf, size, data, keybits); + + case GNUTLS_SIGN_RSA_PSS_SHA384: + case GNUTLS_SIGN_RSA_PSS_RSAE_SHA384: + if (data->size != SHA384_SIZE) + return GNUTLS_E_PK_SIGN_FAILED; + return oc_pss_mgf1_pad(vpninfo, GNUTLS_DIG_SHA384, buf, size, data, keybits); + + case GNUTLS_SIGN_RSA_PSS_SHA512: + case GNUTLS_SIGN_RSA_PSS_RSAE_SHA512: + if (data->size != SHA512_SIZE) + return GNUTLS_E_PK_SIGN_FAILED; + return oc_pss_mgf1_pad(vpninfo, GNUTLS_DIG_SHA512, buf, size, data, keybits); +#endif /* 3.6.0+ */ + default: + vpn_progress(vpninfo, PRG_ERR, + _("TPMv2 RSA sign called for unknown algorithm %s\n"), + gnutls_sign_get_name(algo)); + return GNUTLS_E_PK_SIGN_FAILED; + } +} #endif /* HAVE_TSS2 */ diff --git a/gnutls_tpm2_esys.c b/gnutls_tpm2_esys.c index 26e9b05d..8fa71f64 100644 --- a/gnutls_tpm2_esys.c +++ b/gnutls_tpm2_esys.c @@ -409,12 +409,13 @@ int tpm2_rsa_sign_hash_fn(gnutls_privkey_t key, gnutls_sign_algorithm_t algo, TSS2_RC r; vpn_progress(vpninfo, PRG_DEBUG, - _("TPM2 RSA sign function called for %d bytes.\n"), - data->size); + _("TPM2 RSA sign function called for %d bytes, algo %s\n"), + data->size, gnutls_sign_get_name(algo)); digest.size = certinfo->tpm2->pub.publicArea.unique.rsa.size; - if (oc_pkcs1_pad(vpninfo, digest.buffer, digest.size, data)) + if (oc_pad_rsasig(vpninfo, algo, digest.buffer, digest.size, data, + certinfo->tpm2->pub.publicArea.parameters.rsaDetail.keyBits)) return GNUTLS_E_PK_SIGN_FAILED; if (init_tpm2_key(&ectx, &key_handle, vpninfo, certinfo)) @@ -608,6 +609,11 @@ uint16_t tpm2_key_curve(struct openconnect_info *vpninfo, struct cert_info *cert return certinfo->tpm2->pub.publicArea.parameters.eccDetail.curveID; } +int tpm2_rsa_key_bits(struct openconnect_info *vpninfo, struct cert_info *certinfo) +{ + return certinfo->tpm2->pub.publicArea.parameters.rsaDetail.keyBits; +} + void release_tpm2_ctx(struct openconnect_info *vpninfo, struct cert_info *certinfo) { if (certinfo->tpm2) { diff --git a/gnutls_tpm2_ibm.c b/gnutls_tpm2_ibm.c index 16365dff..7ff6670c 100644 --- a/gnutls_tpm2_ibm.c +++ b/gnutls_tpm2_ibm.c @@ -344,7 +344,8 @@ int tpm2_rsa_sign_hash_fn(gnutls_privkey_t key, gnutls_sign_algorithm_t algo, in.cipherText.t.size = certinfo->tpm2->pub.publicArea.unique.rsa.t.size; - if (oc_pkcs1_pad(vpninfo, in.cipherText.t.buffer, in.cipherText.t.size, data)) + if (oc_pad_rsasig(vpninfo, algo, in.cipherText.t.buffer, in.cipherText.t.size, data, + certinfo->tpm2->pub.publicArea.parameters.rsaDetail.keyBits)) return GNUTLS_E_PK_SIGN_FAILED; in.inScheme.scheme = TPM_ALG_NULL; @@ -558,6 +559,11 @@ uint16_t tpm2_key_curve(struct openconnect_info *vpninfo, struct cert_info *cert return certinfo->tpm2->pub.publicArea.parameters.eccDetail.curveID; } +int tpm2_rsa_key_bits(struct openconnect_info *vpninfo, struct cert_info *certinfo) +{ + return certinfo->tpm2->pub.publicArea.parameters.rsaDetail.keyBits; +} + void release_tpm2_ctx(struct openconnect_info *vpninfo, struct cert_info *certinfo) { if (certinfo->tpm2) {