From: David Woodhouse Date: Mon, 12 Apr 2021 11:02:36 +0000 (+0100) Subject: Add DTLS support to ssl_nonblock_read() / ssl_nonblock_write() X-Git-Tag: v8.20~283 X-Git-Url: https://www.infradead.org/git/?a=commitdiff_plain;h=32dd02ea88256feb1f9ff60ead34b478d48c44f8;p=users%2Fdwmw2%2Fopenconnect.git Add DTLS support to ssl_nonblock_read() / ssl_nonblock_write() We previously lived with #ifdef hacks for OpenSSL vs. GnuTLS in dtls.c because there were only a few call sites and only one each for send/recv where the -EAGAIN handling actually matters. Now that we're going to want to send/receive DTLS frames for PPP too, clean it up a bit. While we're at it, fix up the OpenSSL version of ssl_nonblock_read() so that it only ignores SSL_ERROR_WANT_READ and SSL_ERROR_WANT_WRITE. Previously it ignored *every* error except SSL_ERROR_ZERO_RETURN and SSL_ERROR_SYSCALL. Signed-off-by: David Woodhouse --- diff --git a/cstp.c b/cstp.c index 59841be6..6c1f9373 100644 --- a/cstp.c +++ b/cstp.c @@ -972,7 +972,7 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) } } - len = ssl_nonblock_read(vpninfo, vpninfo->cstp_pkt->cstp.hdr, receive_mtu + 8); + len = ssl_nonblock_read(vpninfo, 0, vpninfo->cstp_pkt->cstp.hdr, receive_mtu + 8); if (!len) break; if (len < 0) @@ -1083,7 +1083,7 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) vpninfo->ssl_times.last_tx = time(NULL); unmonitor_write_fd(vpninfo, ssl); - ret = ssl_nonblock_write(vpninfo, + ret = ssl_nonblock_write(vpninfo, 0, vpninfo->current_ssl_pkt->cstp.hdr, vpninfo->current_ssl_pkt->len + 8); if (ret < 0) @@ -1264,7 +1264,7 @@ int cstp_bye(struct openconnect_info *vpninfo, const char *reason) vpn_progress(vpninfo, PRG_INFO, _("Send BYE packet: %s\n"), reason); - ret = ssl_nonblock_write(vpninfo, bye_pkt, reason_len + 9); + ret = ssl_nonblock_write(vpninfo, 0, bye_pkt, reason_len + 9); if (ret == reason_len + 9) { ret = 0; } else if (ret >= 0) { diff --git a/dtls.c b/dtls.c index 670e558b..d28f0078 100644 --- a/dtls.c +++ b/dtls.c @@ -58,14 +58,6 @@ * their clients use anyway. */ -#if defined(OPENCONNECT_OPENSSL) -#define DTLS_SEND SSL_write -#define DTLS_RECV SSL_read -#elif defined(OPENCONNECT_GNUTLS) -#define DTLS_SEND gnutls_record_send -#define DTLS_RECV gnutls_record_recv -#endif - char *openconnect_bin2hex(const char *prefix, const uint8_t *data, unsigned len) { struct oc_text_buf *buf; @@ -283,7 +275,7 @@ int dtls_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) } buf = vpninfo->dtls_pkt->data - 1; - len = DTLS_RECV(vpninfo->dtls_ssl, buf, len + 1); + len = ssl_nonblock_read(vpninfo, 1, buf, len + 1); if (len <= 0) break; @@ -306,7 +298,7 @@ int dtls_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) /* FIXME: What if the packet doesn't get through? */ magic_pkt = AC_PKT_DPD_RESP; - if (DTLS_SEND(vpninfo->dtls_ssl, &magic_pkt, 1) != 1) + if (ssl_nonblock_write(vpninfo, 1, &magic_pkt, 1) != 1) vpn_progress(vpninfo, PRG_ERR, _("Failed to send DPD response. Expect disconnect\n")); continue; @@ -377,7 +369,7 @@ int dtls_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) vpn_progress(vpninfo, PRG_DEBUG, _("Send DTLS DPD\n")); magic_pkt = AC_PKT_DPD_OUT; - if (DTLS_SEND(vpninfo->dtls_ssl, &magic_pkt, 1) != 1) + if (ssl_nonblock_write(vpninfo, 1, &magic_pkt, 1) != 1) vpn_progress(vpninfo, PRG_ERR, _("Failed to send DPD request. Expect disconnect\n")); @@ -395,7 +387,7 @@ int dtls_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) vpn_progress(vpninfo, PRG_DEBUG, _("Send DTLS Keepalive\n")); magic_pkt = AC_PKT_KEEPALIVE; - if (DTLS_SEND(vpninfo->dtls_ssl, &magic_pkt, 1) != 1) + if (ssl_nonblock_write(vpninfo, 1, &magic_pkt, 1) != 1) vpn_progress(vpninfo, PRG_ERR, _("Failed to send keepalive request. Expect disconnect\n")); time(&vpninfo->dtls_times.last_tx); @@ -431,45 +423,19 @@ int dtls_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) send_pkt->cstp.hdr[7] = AC_PKT_COMPRESSED; } -#ifdef OPENCONNECT_OPENSSL - ret = SSL_write(vpninfo->dtls_ssl, &send_pkt->cstp.hdr[7], send_pkt->len + 1); + ret = ssl_nonblock_write(vpninfo, 1, &send_pkt->cstp.hdr[7], send_pkt->len + 1); if (ret <= 0) { - ret = SSL_get_error(vpninfo->dtls_ssl, ret); - - if (ret == SSL_ERROR_WANT_WRITE) { - monitor_write_fd(vpninfo, dtls); - requeue_packet(&vpninfo->outgoing_queue, this); - } else if (ret != SSL_ERROR_WANT_READ) { - /* If it's a real error, kill the DTLS connection and - requeue the packet to be sent over SSL */ - vpn_progress(vpninfo, PRG_ERR, - _("DTLS got write error %d. Falling back to SSL\n"), - ret); - openconnect_report_ssl_errors(vpninfo); - dtls_reconnect(vpninfo); - requeue_packet(&vpninfo->outgoing_queue, this); - work_done = 1; - } - return work_done; - } -#else /* GnuTLS */ - ret = gnutls_record_send(vpninfo->dtls_ssl, &send_pkt->cstp.hdr[7], send_pkt->len + 1); - if (ret <= 0) { - if (ret != GNUTLS_E_AGAIN && ret != GNUTLS_E_INTERRUPTED) { - vpn_progress(vpninfo, PRG_ERR, - _("DTLS got write error: %s. Falling back to SSL\n"), - gnutls_strerror(ret)); + /* Zero is -EAGAIN; just requeue. dtls_nonblock_write() + * will have added the socket to the poll wfd list. */ + requeue_packet(&vpninfo->outgoing_queue, this); + if (ret < 0) { + /* If it's a real error, kill the DTLS connection so + the requeued packet will be sent over SSL */ dtls_reconnect(vpninfo); work_done = 1; - } else { - /* Wake me up when it becomes writeable */ - monitor_write_fd(vpninfo, dtls); } - - requeue_packet(&vpninfo->outgoing_queue, this); return work_done; } -#endif time(&vpninfo->dtls_times.last_tx); vpn_progress(vpninfo, PRG_TRACE, _("Sent DTLS packet of %d bytes; DTLS send returned %d\n"), diff --git a/gnutls.c b/gnutls.c index 61e5efcd..a124763b 100644 --- a/gnutls.c +++ b/gnutls.c @@ -296,28 +296,43 @@ static int openconnect_gnutls_gets(struct openconnect_info *vpninfo, char *buf, return i ?: ret; } -int ssl_nonblock_read(struct openconnect_info *vpninfo, void *buf, int maxlen) +int ssl_nonblock_read(struct openconnect_info *vpninfo, int dtls, void *buf, int maxlen) { + gnutls_session_t sess = dtls ? vpninfo->dtls_ssl : vpninfo->https_sess; int ret; - ret = gnutls_record_recv(vpninfo->https_sess, buf, maxlen); + if (!sess) { + vpn_progress(vpninfo, PRG_ERR, + _("Attempted to read from non-existent %s session"), + dtls ? "DTLS" : "TLS"); + return -1; + } + + ret = gnutls_record_recv(sess, buf, maxlen); if (ret > 0) return ret; - if (ret != GNUTLS_E_AGAIN && ret != GNUTLS_E_INTERRUPTED) { - vpn_progress(vpninfo, PRG_ERR, - _("SSL read error: %s; reconnecting.\n"), - gnutls_strerror(ret)); - return -EIO; - } - return 0; + if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) + return 0; + + vpn_progress(vpninfo, PRG_ERR, _("Read error on %s session: %s\n"), + dtls ? "DTLS" : "SSL", gnutls_strerror(ret)); + return -1; } -int ssl_nonblock_write(struct openconnect_info *vpninfo, void *buf, int buflen) +int ssl_nonblock_write(struct openconnect_info *vpninfo, int dtls, void *buf, int buflen) { + gnutls_session_t sess = dtls ? vpninfo->dtls_ssl : vpninfo->https_sess; int ret; - ret = gnutls_record_send(vpninfo->https_sess, buf, buflen); + if (!sess) { + vpn_progress(vpninfo, PRG_ERR, + _("Attempted to write to non-existent %s session"), + dtls ? "DTLS" : "TLS"); + return -1; + } + + ret = gnutls_record_send(sess, buf, buflen); if (ret > 0) return ret; @@ -335,15 +350,19 @@ int ssl_nonblock_write(struct openconnect_info *vpninfo, void *buf, int buflen) * that it's waiting for does arrive. */ if (GNUTLS_VERSION_NUMBER < 0x03030d || - gnutls_record_get_direction(vpninfo->https_sess)) { + gnutls_record_get_direction(sess)) { /* Waiting for the socket to become writable — it's probably stalled, and/or the buffers are full */ - monitor_write_fd(vpninfo, ssl); + if (dtls) + monitor_write_fd(vpninfo, dtls); + else + monitor_write_fd(vpninfo, ssl); } return 0; } - vpn_progress(vpninfo, PRG_ERR, _("SSL send failed: %s\n"), - gnutls_strerror(ret)); + + vpn_progress(vpninfo, PRG_ERR, _("Write error on %s session: %s\n"), + dtls ? "DTLS" : "SSL", gnutls_strerror(ret)); return -1; } diff --git a/gpst.c b/gpst.c index 08e9a0d5..6b61ff76 100644 --- a/gpst.c +++ b/gpst.c @@ -1107,7 +1107,7 @@ int gpst_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) } } - len = ssl_nonblock_read(vpninfo, vpninfo->cstp_pkt->gpst.hdr, receive_mtu + 16); + len = ssl_nonblock_read(vpninfo, 0, vpninfo->cstp_pkt->gpst.hdr, receive_mtu + 16); if (!len) break; if (len < 0) { @@ -1186,7 +1186,7 @@ int gpst_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) vpninfo->ssl_times.last_tx = time(NULL); unmonitor_write_fd(vpninfo, ssl); - ret = ssl_nonblock_write(vpninfo, + ret = ssl_nonblock_write(vpninfo, 0, vpninfo->current_ssl_pkt->gpst.hdr, vpninfo->current_ssl_pkt->len + 16); if (ret < 0) diff --git a/oncp.c b/oncp.c index 759327d2..75f2be18 100644 --- a/oncp.c +++ b/oncp.c @@ -771,7 +771,7 @@ static int oncp_record_read(struct openconnect_info *vpninfo, void *buf, int len if (!vpninfo->oncp_rec_size) { unsigned char lenbuf[2]; - ret = ssl_nonblock_read(vpninfo, lenbuf, 2); + ret = ssl_nonblock_read(vpninfo, 0, lenbuf, 2); if (ret <= 0) return ret; if (ret == 1) { @@ -783,7 +783,7 @@ static int oncp_record_read(struct openconnect_info *vpninfo, void *buf, int len } vpninfo->oncp_rec_size = load_le16(lenbuf); if (!vpninfo->oncp_rec_size) { - ret = ssl_nonblock_read(vpninfo, lenbuf, 1); + ret = ssl_nonblock_read(vpninfo, 0, lenbuf, 1); if (ret == 1) { if (lenbuf[0] == 1) { vpn_progress(vpninfo, PRG_ERR, @@ -806,7 +806,7 @@ static int oncp_record_read(struct openconnect_info *vpninfo, void *buf, int len if (len > vpninfo->oncp_rec_size) len = vpninfo->oncp_rec_size; - ret = ssl_nonblock_read(vpninfo, buf, len); + ret = ssl_nonblock_read(vpninfo, 0, buf, len); if (ret > 0) vpninfo->oncp_rec_size -= ret; return ret; @@ -1013,7 +1013,7 @@ int oncp_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) vpninfo->current_ssl_pkt->oncp.rec, vpninfo->current_ssl_pkt->len + 22); - ret = ssl_nonblock_write(vpninfo, + ret = ssl_nonblock_write(vpninfo, 0, vpninfo->current_ssl_pkt->oncp.rec, vpninfo->current_ssl_pkt->len + 22); if (ret < 0) { diff --git a/openconnect-internal.h b/openconnect-internal.h index 3ee725f2..0c87b738 100644 --- a/openconnect-internal.h +++ b/openconnect-internal.h @@ -1123,8 +1123,8 @@ int encrypt_esp_packet(struct openconnect_info *vpninfo, struct pkt *pkt, int cr /* {gnutls,openssl}.c */ const char *openconnect_get_tls_library_version(); int can_enable_insecure_crypto(); -int ssl_nonblock_read(struct openconnect_info *vpninfo, void *buf, int maxlen); -int ssl_nonblock_write(struct openconnect_info *vpninfo, void *buf, int buflen); +int ssl_nonblock_read(struct openconnect_info *vpninfo, int dtls, void *buf, int maxlen); +int ssl_nonblock_write(struct openconnect_info *vpninfo, int dtls, void *buf, int buflen); int openconnect_open_https(struct openconnect_info *vpninfo); void openconnect_close_https(struct openconnect_info *vpninfo, int final); int cstp_handshake(struct openconnect_info *vpninfo, unsigned init); diff --git a/openssl.c b/openssl.c index d7601721..1bd7451a 100644 --- a/openssl.c +++ b/openssl.c @@ -311,43 +311,63 @@ static int openconnect_openssl_gets(struct openconnect_info *vpninfo, char *buf, return i ?: ret; } -int ssl_nonblock_read(struct openconnect_info *vpninfo, void *buf, int maxlen) +int ssl_nonblock_read(struct openconnect_info *vpninfo, int dtls, void *buf, int maxlen) { + SSL *ssl = dtls ? vpninfo->dtls_ssl : vpninfo->https_ssl; int len, ret; - len = SSL_read(vpninfo->https_ssl, buf, maxlen); + if (!ssl) { + vpn_progress(vpninfo, PRG_ERR, + _("Attempted to read from non-existent %s session"), + dtls ? "DTLS" : "TLS"); + return -1; + } + + len = SSL_read(ssl, buf, maxlen); if (len > 0) return len; - ret = SSL_get_error(vpninfo->https_ssl, len); - if (ret == SSL_ERROR_SYSCALL || ret == SSL_ERROR_ZERO_RETURN) { - vpn_progress(vpninfo, PRG_ERR, - _("SSL read error %d (server probably closed connection); reconnecting.\n"), - ret); - return -EIO; - } - return 0; + ret = SSL_get_error(ssl, len); + if (ret == SSL_ERROR_WANT_WRITE || ret == SSL_ERROR_WANT_READ) + return 0; + + vpn_progress(vpninfo, PRG_ERR, _("Read error on %s session: %d\n"), + dtls ? "DTLS" : "TLS", ret); + return -EIO; } -int ssl_nonblock_write(struct openconnect_info *vpninfo, void *buf, int buflen) +int ssl_nonblock_write(struct openconnect_info *vpninfo, int dtls, void *buf, int buflen) { + SSL *ssl = dtls ? vpninfo->dtls_ssl : vpninfo->https_ssl; int ret; - ret = SSL_write(vpninfo->https_ssl, buf, buflen); + if (!ssl) { + vpn_progress(vpninfo, PRG_ERR, + _("Attempted to write to non-existent %s session"), + dtls ? "DTLS" : "TLS"); + return -1; + } + + ret = SSL_write(ssl, buf, buflen); if (ret > 0) return ret; - ret = SSL_get_error(vpninfo->https_ssl, ret); + ret = SSL_get_error(ssl, ret); switch (ret) { case SSL_ERROR_WANT_WRITE: /* Waiting for the socket to become writable -- it's probably stalled, and/or the buffers are full */ - monitor_write_fd(vpninfo, ssl); + if (dtls) + monitor_write_fd(vpninfo, dtls); + else + monitor_write_fd(vpninfo, ssl); + /* Fall through */ case SSL_ERROR_WANT_READ: return 0; default: - vpn_progress(vpninfo, PRG_ERR, _("SSL_write failed: %d\n"), ret); + vpn_progress(vpninfo, PRG_ERR, _("Write error on %s session: %d\n"), + dtls ? "DTLS" : "TLS", ret); openconnect_report_ssl_errors(vpninfo); return -1; } diff --git a/ppp.c b/ppp.c index c116e68f..ba1b9b0e 100644 --- a/ppp.c +++ b/ppp.c @@ -1021,7 +1021,7 @@ int ppp_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) /* Load the encap header to end up with the payload where we expect it */ eh = this->data - rsv_hdr_size; - len = ssl_nonblock_read(vpninfo, eh, receive_mtu + rsv_hdr_size); + len = ssl_nonblock_read(vpninfo, 0, eh, receive_mtu + rsv_hdr_size); if (!len) break; if (len < 0) @@ -1246,7 +1246,7 @@ int ppp_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) vpninfo->ssl_times.last_tx = time(NULL); unmonitor_write_fd(vpninfo, ssl); - ret = ssl_nonblock_write(vpninfo, this->data - this->ppp.hlen, this->len + this->ppp.hlen); + ret = ssl_nonblock_write(vpninfo, 0, this->data - this->ppp.hlen, this->len + this->ppp.hlen); if (ret < 0) goto do_reconnect; else if (!ret) { diff --git a/pulse.c b/pulse.c index 4181896c..82d4fd15 100644 --- a/pulse.c +++ b/pulse.c @@ -2603,7 +2603,7 @@ int pulse_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) } /* Receive packet header, if there's anything there... */ - len = ssl_nonblock_read(vpninfo, &pkt->pulse.vendor, 16); + len = ssl_nonblock_read(vpninfo, 0, &pkt->pulse.vendor, 16); if (!len) break; if (len < 0) @@ -2624,7 +2624,7 @@ int pulse_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) } else len = load_be32(&pkt->pulse.len) - 0x10; - payload_len = ssl_nonblock_read(vpninfo, &pkt->data, len); + payload_len = ssl_nonblock_read(vpninfo, 0, &pkt->data, len); if (payload_len != load_be32(&pkt->pulse.len) - 0x10) { if (payload_len < 0) len = 0x10; @@ -2709,7 +2709,7 @@ int pulse_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) (void *)&vpninfo->current_ssl_pkt->pulse.vendor, vpninfo->current_ssl_pkt->len + 16); - ret = ssl_nonblock_write(vpninfo, + ret = ssl_nonblock_write(vpninfo, 0, &vpninfo->current_ssl_pkt->pulse.vendor, vpninfo->current_ssl_pkt->len + 16); if (ret < 0) {