]> www.infradead.org Git - users/dwmw2/openconnect.git/commitdiff
Implement RSA-PSS padding for TPMv2
authorDavid Woodhouse <dwmw2@infradead.org>
Wed, 12 May 2021 15:19:07 +0000 (16:19 +0100)
committerDavid Woodhouse <dwmw2@infradead.org>
Wed, 12 May 2021 22:01:09 +0000 (23:01 +0100)
Now I can connect using TLSv1.3 using a TPMv2 RSA key. And my list of
"stuff I should never have had to do for myself in the application,
just to ask the crypto library to use the key that the user pointed
it at" has *really* jumped the shark now.

Signed-off-by: David Woodhouse <dwmw2@infradead.org>
gnutls.h
gnutls_tpm2.c
gnutls_tpm2_esys.c
gnutls_tpm2_ibm.c

index d6f75dc6ecf3f444a253b851d4bc99d092d6a9fa..3ce7f8777f6e05682f832727391199d24289dff2 100644 (file)
--- a/gnutls.h
+++ b/gnutls.h
@@ -42,9 +42,10 @@ int tpm2_rsa_sign_hash_fn(gnutls_privkey_t key, gnutls_sign_algorithm_t algo,
 int tpm2_ec_sign_hash_fn(gnutls_privkey_t key, gnutls_sign_algorithm_t algo,
                         void *_certinfo, unsigned int flags,
                         const gnutls_datum_t *data, gnutls_datum_t *sig);
-int oc_pkcs1_pad(struct openconnect_info *vpninfo,
-                unsigned char *buf, int size, const gnutls_datum_t *data);
+int oc_pad_rsasig(struct openconnect_info *vpninfo, gnutls_sign_algorithm_t algo,
+                 unsigned char *buf, int size, const gnutls_datum_t *data, int keybits);
 uint16_t tpm2_key_curve(struct openconnect_info *vpninfo, struct cert_info *certinfo);
+int tpm2_rsa_key_bits(struct openconnect_info *vpninfo, struct cert_info *certinfo);
 
 /* GnuTLS 3.6.0+ provides this. We have our own for older GnuTLS. There is
  * also _gnutls_encode_ber_rs_raw() in some older versions, but there were
index 839e1b833542eaf6d458f351206392961be479ce..1f0f09462b097615242116a34465d67d8bba6198 100644 (file)
@@ -102,9 +102,15 @@ static int tpm2_ec_sign_fn(gnutls_privkey_t key, void *_certinfo,
 #if GNUTLS_VERSION_NUMBER >= 0x030600
 static int rsa_key_info(gnutls_privkey_t key, unsigned int flags, void *_certinfo)
 {
+       struct cert_info *certinfo = _certinfo;
+       struct openconnect_info *vpninfo = certinfo->vpninfo;
+
        if (flags & GNUTLS_PRIVKEY_INFO_PK_ALGO)
                return GNUTLS_PK_RSA;
 
+       if (flags & GNUTLS_PRIVKEY_INFO_PK_ALGO_BITS)
+               return tpm2_rsa_key_bits(vpninfo, certinfo);
+
        if (flags & GNUTLS_PRIVKEY_INFO_HAVE_SIGN_ALGO) {
                gnutls_sign_algorithm_t algo = GNUTLS_FLAGS_TO_SIGN_ALGO(flags);
                switch (algo) {
@@ -115,7 +121,18 @@ static int rsa_key_info(gnutls_privkey_t key, unsigned int flags, void *_certinf
                case GNUTLS_SIGN_RSA_SHA512:
                        return 1;
 
+               case GNUTLS_SIGN_RSA_PSS_SHA256:
+               case GNUTLS_SIGN_RSA_PSS_RSAE_SHA256:
+               case GNUTLS_SIGN_RSA_PSS_SHA384:
+               case GNUTLS_SIGN_RSA_PSS_RSAE_SHA384:
+               case GNUTLS_SIGN_RSA_PSS_SHA512:
+               case GNUTLS_SIGN_RSA_PSS_RSAE_SHA512:
+                       return 1;
+
                default:
+                       vpn_progress(vpninfo, PRG_DEBUG,
+                                    _("Not supporting EC sign algo %s\n"),
+                                    gnutls_sign_get_name(algo));
                        return 0;
                }
        }
@@ -130,16 +147,17 @@ static int rsa_key_info(gnutls_privkey_t key, unsigned int flags, void *_certinf
 #if GNUTLS_VERSION_NUMBER >= 0x030400
 static int ec_key_info(gnutls_privkey_t key, unsigned int flags, void *_certinfo)
 {
-       struct cert_info *certinfo = _certinfo;
-       struct openconnect_info *vpninfo = certinfo->vpninfo;
-
        if (flags & GNUTLS_PRIVKEY_INFO_PK_ALGO)
                return GNUTLS_PK_EC;
 
 #ifdef GNUTLS_PRIVKEY_INFO_HAVE_SIGN_ALGO
        if (flags & GNUTLS_PRIVKEY_INFO_HAVE_SIGN_ALGO) {
+               struct cert_info *certinfo = _certinfo;
+               struct openconnect_info *vpninfo = certinfo->vpninfo;
+
                uint16_t tpm2_curve = tpm2_key_curve(vpninfo, certinfo);
                gnutls_sign_algorithm_t algo = GNUTLS_FLAGS_TO_SIGN_ALGO(flags);
+
                switch (algo) {
                case GNUTLS_SIGN_ECDSA_SHA1:
                case GNUTLS_SIGN_ECDSA_SHA256:
@@ -393,8 +411,8 @@ int oc_gnutls_encode_rs_value(gnutls_datum_t *sig, const gnutls_datum_t *sig_r,
 
 /* EMSA-PKCS1-v1_5 padding in accordance with RFC3447 §9.2 */
 #define PKCS1_PAD_OVERHEAD 11
-int oc_pkcs1_pad(struct openconnect_info *vpninfo,
-                unsigned char *buf, int size, const gnutls_datum_t *data)
+static int oc_pkcs1_pad(struct openconnect_info *vpninfo,
+                       unsigned char *buf, int size, const gnutls_datum_t *data)
 {
        if (data->size + PKCS1_PAD_OVERHEAD > size) {
                vpn_progress(vpninfo, PRG_ERR,
@@ -411,4 +429,160 @@ int oc_pkcs1_pad(struct openconnect_info *vpninfo,
 
        return 0;
 }
+
+#if GNUTLS_VERSION_NUMBER >= 0x030600
+/* EMSA-PSS encoding in accordance with RFC3447 §9.1 */
+static int oc_pss_mgf1_pad(struct openconnect_info *vpninfo, gnutls_digest_algorithm_t dig,
+                          unsigned char *emBuf, int emLen, const gnutls_datum_t *mHash, int keybits)
+{
+       gnutls_hash_hd_t hashctx = NULL;
+       int err = GNUTLS_E_PK_SIGN_FAILED;
+
+       /* The emBits for EMSA-PSS encoding is actually one *fewer* bit than
+        * the RSA modulus. As RFC3447 §8.1.1 points out, "the octet length
+        * of EM will be one less than k if modBits - 1 is divisible by 8
+        * and equal to k otherwise". Where k is the input emLen, which we
+        * thus need to adjust before using it as emLen for the following
+        * operations. Not that it matters much since I don't think the TPM
+        * can cope with RSA keys whose modulus isn't a multiple of 8 bits
+        * anyway. */
+       int msbits = (keybits - 1) & 7;
+       if (!msbits) {
+               *(emBuf++) = 0;
+               emLen--;
+       }
+
+       /* GnuTLS gives us a predigested mHash from which we create M' and
+        * continue the process. Can we infer all the PSS parameters from
+        * the digest size, including the salt size? Or does GnuTLS need
+        * a gnutls_privkey_import_ext5() which lets us have the params too?
+        * Better still, could GnuTLS just do this all for us and we only
+        * do a raw signature — really raw, unlike GNUTLS_SIGN_RSA_RAW
+        * which AIUI is actually padded. */
+       if (mHash->size > emLen - 2) {
+               vpn_progress(vpninfo, PRG_ERR,
+                            _("PSS encoding failed; hash size %d too large for RSA key %d\n"),
+                            mHash->size, emLen);
+               return GNUTLS_E_PK_SIGN_FAILED;
+       }
+
+       int sLen = mHash->size;
+       if (sLen + mHash->size > emLen - 2)
+               sLen = emLen - 2 - mHash->size;
+
+       char salt[SHA512_SIZE];
+       if (sLen) {
+               err = gnutls_rnd(GNUTLS_RND_NONCE, salt, sLen);
+               if (err)
+                       goto out;
+       }
+
+       /* Hash M' (8 zeroes || mHash || salt) into its place in EM */
+       if ((err = gnutls_hash_init(&hashctx, dig)) ||
+           (err = gnutls_hash(hashctx, "\0\0\0\0\0\0\0\0", 8)) ||
+           (err = gnutls_hash(hashctx, mHash->data, mHash->size)) ||
+           (sLen && (err = gnutls_hash(hashctx, salt, sLen))))
+               goto out;
+
+       int maskedDBLen = emLen - mHash->size - 1;
+       gnutls_hash_output(hashctx, emBuf + maskedDBLen);
+
+       emBuf[emLen - 1] = 0xbc;
+
+       /* Now the MGF1 function as definsed in RFC3447 Appendix B, although
+        * it's somewhat easier to read in NIST SP 800-56B §7.2.2.2.
+        *
+        * We repeatedly hash (M' || C) where C is an incrementing 32-bit
+        * counter, so hash M' first and then use gnutls_hash_copy() each
+        * time to add C to the copy. */
+       err = gnutls_hash(hashctx, emBuf + maskedDBLen, mHash->size);
+       if (err)
+               goto out;
+
+       int mgflen = 0, mgf_count = 0;
+       while (mgflen < maskedDBLen) {
+               gnutls_hash_hd_t ctx2 = gnutls_hash_copy(hashctx);
+               if (!ctx2) {
+                       err = GNUTLS_E_PK_SIGN_FAILED;
+                       goto out;
+               }
+               uint32_t be_count = htonl(mgf_count++);
+               err = gnutls_hash(ctx2, &be_count, sizeof(be_count));
+               if (err) {
+                       gnutls_hash_deinit(ctx2, NULL);
+                       goto out;
+               }
+               if (mgflen + mHash->size <= maskedDBLen) {
+                       gnutls_hash_deinit(ctx2, emBuf + mgflen);
+                       mgflen += mHash->size;
+               } else {
+                       char md[SHA512_SIZE];
+                       gnutls_hash_deinit(ctx2, md);
+                       memcpy(emBuf + mgflen, md, maskedDBLen - mgflen);
+                       mgflen = maskedDBLen;
+               }
+       }
+
+       /* Back to EMSA-PSS-ENCODE step 10. The MGF result was directly placed
+        * into emBuf, so now XOR with DB, which is (zeroes || 0x01 || salt) */
+       int dst = maskedDBLen - 1;
+       while (sLen--)
+               emBuf[dst--] ^= salt[sLen];
+       emBuf[dst] ^= 0x01;
+
+       /* Now mask out the high bits. In the case where msbits is zero, we
+        * skipped the entire first byte so do nothing. */
+       if (msbits)
+               emBuf[0] &= 0xFF >> (8 - msbits);
+
+       err = 0;
+ out:
+       if (hashctx)
+               gnutls_hash_deinit(hashctx, NULL);
+
+       return err;
+}
+#endif
+
+int oc_pad_rsasig(struct openconnect_info *vpninfo, gnutls_sign_algorithm_t algo,
+                 unsigned char *buf, int size, const gnutls_datum_t *data, int keybits)
+{
+       switch(algo) {
+       case GNUTLS_SIGN_UNKNOWN:
+       case GNUTLS_SIGN_RSA_SHA1:
+       case GNUTLS_SIGN_RSA_SHA256:
+       case GNUTLS_SIGN_RSA_SHA384:
+       case GNUTLS_SIGN_RSA_SHA512:
+               return oc_pkcs1_pad(vpninfo, buf, size, data);
+
+#if GNUTLS_VERSION_NUMBER >= 0x030600
+               /* Really PKCS#1.5 padding, yes. */
+       case GNUTLS_SIGN_RSA_RAW:
+               return oc_pkcs1_pad(vpninfo, buf, size, data);
+
+       case GNUTLS_SIGN_RSA_PSS_SHA256:
+       case GNUTLS_SIGN_RSA_PSS_RSAE_SHA256:
+               if (data->size != SHA256_SIZE)
+                       return GNUTLS_E_PK_SIGN_FAILED;
+               return oc_pss_mgf1_pad(vpninfo, GNUTLS_DIG_SHA256, buf, size, data, keybits);
+
+       case GNUTLS_SIGN_RSA_PSS_SHA384:
+       case GNUTLS_SIGN_RSA_PSS_RSAE_SHA384:
+               if (data->size != SHA384_SIZE)
+                       return GNUTLS_E_PK_SIGN_FAILED;
+               return oc_pss_mgf1_pad(vpninfo, GNUTLS_DIG_SHA384, buf, size, data, keybits);
+
+       case GNUTLS_SIGN_RSA_PSS_SHA512:
+       case GNUTLS_SIGN_RSA_PSS_RSAE_SHA512:
+               if (data->size != SHA512_SIZE)
+                       return GNUTLS_E_PK_SIGN_FAILED;
+               return oc_pss_mgf1_pad(vpninfo, GNUTLS_DIG_SHA512, buf, size, data, keybits);
+#endif /* 3.6.0+ */
+       default:
+               vpn_progress(vpninfo, PRG_ERR,
+                            _("TPMv2 RSA sign called for unknown algorithm %s\n"),
+                            gnutls_sign_get_name(algo));
+               return GNUTLS_E_PK_SIGN_FAILED;
+       }
+}
 #endif /* HAVE_TSS2 */
index 26e9b05d9bc08ba7c7e1b5b7fa3877272c575630..8fa71f641e093dbf7b7814fe8359e10f07763366 100644 (file)
@@ -409,12 +409,13 @@ int tpm2_rsa_sign_hash_fn(gnutls_privkey_t key, gnutls_sign_algorithm_t algo,
        TSS2_RC r;
 
        vpn_progress(vpninfo, PRG_DEBUG,
-                    _("TPM2 RSA sign function called for %d bytes.\n"),
-                    data->size);
+                    _("TPM2 RSA sign function called for %d bytes, algo %s\n"),
+                    data->size, gnutls_sign_get_name(algo));
 
        digest.size = certinfo->tpm2->pub.publicArea.unique.rsa.size;
 
-       if (oc_pkcs1_pad(vpninfo, digest.buffer, digest.size, data))
+       if (oc_pad_rsasig(vpninfo, algo, digest.buffer, digest.size, data,
+                         certinfo->tpm2->pub.publicArea.parameters.rsaDetail.keyBits))
                return GNUTLS_E_PK_SIGN_FAILED;
 
        if (init_tpm2_key(&ectx, &key_handle, vpninfo, certinfo))
@@ -608,6 +609,11 @@ uint16_t tpm2_key_curve(struct openconnect_info *vpninfo, struct cert_info *cert
        return certinfo->tpm2->pub.publicArea.parameters.eccDetail.curveID;
 }
 
+int tpm2_rsa_key_bits(struct openconnect_info *vpninfo, struct cert_info *certinfo)
+{
+       return certinfo->tpm2->pub.publicArea.parameters.rsaDetail.keyBits;
+}
+
 void release_tpm2_ctx(struct openconnect_info *vpninfo, struct cert_info *certinfo)
 {
        if (certinfo->tpm2) {
index 16365dff734f9b853cbe2c0a34e90bac7798daa6..7ff6670c9a66bb0aea1cf2611e075154f7fe859d 100644 (file)
@@ -344,7 +344,8 @@ int tpm2_rsa_sign_hash_fn(gnutls_privkey_t key, gnutls_sign_algorithm_t algo,
 
        in.cipherText.t.size = certinfo->tpm2->pub.publicArea.unique.rsa.t.size;
 
-       if (oc_pkcs1_pad(vpninfo, in.cipherText.t.buffer, in.cipherText.t.size, data))
+       if (oc_pad_rsasig(vpninfo, algo, in.cipherText.t.buffer, in.cipherText.t.size, data,
+                         certinfo->tpm2->pub.publicArea.parameters.rsaDetail.keyBits))
                return GNUTLS_E_PK_SIGN_FAILED;
 
        in.inScheme.scheme = TPM_ALG_NULL;
@@ -558,6 +559,11 @@ uint16_t tpm2_key_curve(struct openconnect_info *vpninfo, struct cert_info *cert
        return certinfo->tpm2->pub.publicArea.parameters.eccDetail.curveID;
 }
 
+int tpm2_rsa_key_bits(struct openconnect_info *vpninfo, struct cert_info *certinfo)
+{
+       return certinfo->tpm2->pub.publicArea.parameters.rsaDetail.keyBits;
+}
+
 void release_tpm2_ctx(struct openconnect_info *vpninfo, struct cert_info *certinfo)
 {
        if (certinfo->tpm2) {