]> www.infradead.org Git - users/dwmw2/openconnect.git/commitdiff
Add DTLS support to ssl_nonblock_read() / ssl_nonblock_write()
authorDavid Woodhouse <dwmw2@infradead.org>
Mon, 12 Apr 2021 11:02:36 +0000 (12:02 +0100)
committerDavid Woodhouse <dwmw2@infradead.org>
Fri, 16 Apr 2021 15:06:36 +0000 (16:06 +0100)
We previously lived with #ifdef hacks for OpenSSL vs. GnuTLS in dtls.c
because there were only a few call sites and only one each for send/recv
where the -EAGAIN handling actually matters.

Now that we're going to want to send/receive DTLS frames for PPP too,
clean it up a bit.

While we're at it, fix up the OpenSSL version of ssl_nonblock_read()
so that it only ignores SSL_ERROR_WANT_READ and SSL_ERROR_WANT_WRITE.
Previously it ignored *every* error except SSL_ERROR_ZERO_RETURN and
SSL_ERROR_SYSCALL.

Signed-off-by: David Woodhouse <dwmw2@infradead.org>
cstp.c
dtls.c
gnutls.c
gpst.c
oncp.c
openconnect-internal.h
openssl.c
ppp.c
pulse.c

diff --git a/cstp.c b/cstp.c
index 59841be657304dff2207037fe07fd9c7086ffd39..6c1f9373047c9d0b34d8e366e443ff258020ba2a 100644 (file)
--- a/cstp.c
+++ b/cstp.c
@@ -972,7 +972,7 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
                        }
                }
 
-               len = ssl_nonblock_read(vpninfo, vpninfo->cstp_pkt->cstp.hdr, receive_mtu + 8);
+               len = ssl_nonblock_read(vpninfo, 0, vpninfo->cstp_pkt->cstp.hdr, receive_mtu + 8);
                if (!len)
                        break;
                if (len < 0)
@@ -1083,7 +1083,7 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
                vpninfo->ssl_times.last_tx = time(NULL);
                unmonitor_write_fd(vpninfo, ssl);
 
-               ret = ssl_nonblock_write(vpninfo,
+               ret = ssl_nonblock_write(vpninfo, 0,
                                         vpninfo->current_ssl_pkt->cstp.hdr,
                                         vpninfo->current_ssl_pkt->len + 8);
                if (ret < 0)
@@ -1264,7 +1264,7 @@ int cstp_bye(struct openconnect_info *vpninfo, const char *reason)
        vpn_progress(vpninfo, PRG_INFO,
                     _("Send BYE packet: %s\n"), reason);
 
-       ret = ssl_nonblock_write(vpninfo, bye_pkt, reason_len + 9);
+       ret = ssl_nonblock_write(vpninfo, 0, bye_pkt, reason_len + 9);
        if (ret == reason_len + 9) {
                ret = 0;
        } else if (ret >= 0) {
diff --git a/dtls.c b/dtls.c
index 670e558b3014fd61f83bc0fe6f78f54de374c0b4..d28f0078400933c847675d0e140248a2a1dca83f 100644 (file)
--- a/dtls.c
+++ b/dtls.c
  * their clients use anyway.
  */
 
-#if defined(OPENCONNECT_OPENSSL)
-#define DTLS_SEND SSL_write
-#define DTLS_RECV SSL_read
-#elif defined(OPENCONNECT_GNUTLS)
-#define DTLS_SEND gnutls_record_send
-#define DTLS_RECV gnutls_record_recv
-#endif
-
 char *openconnect_bin2hex(const char *prefix, const uint8_t *data, unsigned len)
 {
        struct oc_text_buf *buf;
@@ -283,7 +275,7 @@ int dtls_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
                }
 
                buf = vpninfo->dtls_pkt->data - 1;
-               len = DTLS_RECV(vpninfo->dtls_ssl, buf, len + 1);
+               len = ssl_nonblock_read(vpninfo, 1, buf, len + 1);
                if (len <= 0)
                        break;
 
@@ -306,7 +298,7 @@ int dtls_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
 
                        /* FIXME: What if the packet doesn't get through? */
                        magic_pkt = AC_PKT_DPD_RESP;
-                       if (DTLS_SEND(vpninfo->dtls_ssl, &magic_pkt, 1) != 1)
+                       if (ssl_nonblock_write(vpninfo, 1, &magic_pkt, 1) != 1)
                                vpn_progress(vpninfo, PRG_ERR,
                                             _("Failed to send DPD response. Expect disconnect\n"));
                        continue;
@@ -377,7 +369,7 @@ int dtls_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
                vpn_progress(vpninfo, PRG_DEBUG, _("Send DTLS DPD\n"));
 
                magic_pkt = AC_PKT_DPD_OUT;
-               if (DTLS_SEND(vpninfo->dtls_ssl, &magic_pkt, 1) != 1)
+               if (ssl_nonblock_write(vpninfo, 1, &magic_pkt, 1) != 1)
                        vpn_progress(vpninfo, PRG_ERR,
                                     _("Failed to send DPD request. Expect disconnect\n"));
 
@@ -395,7 +387,7 @@ int dtls_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
                vpn_progress(vpninfo, PRG_DEBUG, _("Send DTLS Keepalive\n"));
 
                magic_pkt = AC_PKT_KEEPALIVE;
-               if (DTLS_SEND(vpninfo->dtls_ssl, &magic_pkt, 1) != 1)
+               if (ssl_nonblock_write(vpninfo, 1, &magic_pkt, 1) != 1)
                        vpn_progress(vpninfo, PRG_ERR,
                                     _("Failed to send keepalive request. Expect disconnect\n"));
                time(&vpninfo->dtls_times.last_tx);
@@ -431,45 +423,19 @@ int dtls_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
                                send_pkt->cstp.hdr[7] = AC_PKT_COMPRESSED;
                }
 
-#ifdef OPENCONNECT_OPENSSL
-               ret = SSL_write(vpninfo->dtls_ssl, &send_pkt->cstp.hdr[7], send_pkt->len + 1);
+               ret = ssl_nonblock_write(vpninfo, 1, &send_pkt->cstp.hdr[7], send_pkt->len + 1);
                if (ret <= 0) {
-                       ret = SSL_get_error(vpninfo->dtls_ssl, ret);
-
-                       if (ret == SSL_ERROR_WANT_WRITE) {
-                               monitor_write_fd(vpninfo, dtls);
-                               requeue_packet(&vpninfo->outgoing_queue, this);
-                       } else if (ret != SSL_ERROR_WANT_READ) {
-                               /* If it's a real error, kill the DTLS connection and
-                                  requeue the packet to be sent over SSL */
-                               vpn_progress(vpninfo, PRG_ERR,
-                                            _("DTLS got write error %d. Falling back to SSL\n"),
-                                            ret);
-                               openconnect_report_ssl_errors(vpninfo);
-                               dtls_reconnect(vpninfo);
-                               requeue_packet(&vpninfo->outgoing_queue, this);
-                               work_done = 1;
-                       }
-                       return work_done;
-               }
-#else /* GnuTLS */
-               ret = gnutls_record_send(vpninfo->dtls_ssl, &send_pkt->cstp.hdr[7], send_pkt->len + 1);
-               if (ret <= 0) {
-                       if (ret != GNUTLS_E_AGAIN && ret != GNUTLS_E_INTERRUPTED) {
-                               vpn_progress(vpninfo, PRG_ERR,
-                                            _("DTLS got write error: %s. Falling back to SSL\n"),
-                                            gnutls_strerror(ret));
+                       /* Zero is -EAGAIN; just requeue. dtls_nonblock_write()
+                        * will have added the socket to the poll wfd list. */
+                       requeue_packet(&vpninfo->outgoing_queue, this);
+                       if (ret < 0) {
+                               /* If it's a real error, kill the DTLS connection so
+                                  the requeued packet will be sent over SSL */
                                dtls_reconnect(vpninfo);
                                work_done = 1;
-                       } else {
-                               /* Wake me up when it becomes writeable */
-                               monitor_write_fd(vpninfo, dtls);
                        }
-
-                       requeue_packet(&vpninfo->outgoing_queue, this);
                        return work_done;
                }
-#endif
                time(&vpninfo->dtls_times.last_tx);
                vpn_progress(vpninfo, PRG_TRACE,
                             _("Sent DTLS packet of %d bytes; DTLS send returned %d\n"),
index 61e5efcd03e0568f308f35f751d35576fb24fa91..a124763bb9e2bda4123d0fbdf85bc2334b36a051 100644 (file)
--- a/gnutls.c
+++ b/gnutls.c
@@ -296,28 +296,43 @@ static int openconnect_gnutls_gets(struct openconnect_info *vpninfo, char *buf,
        return i ?: ret;
 }
 
-int ssl_nonblock_read(struct openconnect_info *vpninfo, void *buf, int maxlen)
+int ssl_nonblock_read(struct openconnect_info *vpninfo, int dtls, void *buf, int maxlen)
 {
+       gnutls_session_t sess = dtls ? vpninfo->dtls_ssl : vpninfo->https_sess;
        int ret;
 
-       ret = gnutls_record_recv(vpninfo->https_sess, buf, maxlen);
+       if (!sess) {
+               vpn_progress(vpninfo, PRG_ERR,
+                            _("Attempted to read from non-existent %s session"),
+                            dtls ? "DTLS" : "TLS");
+               return -1;
+       }
+
+       ret = gnutls_record_recv(sess, buf, maxlen);
        if (ret > 0)
                return ret;
 
-       if (ret != GNUTLS_E_AGAIN && ret != GNUTLS_E_INTERRUPTED) {
-               vpn_progress(vpninfo, PRG_ERR,
-                            _("SSL read error: %s; reconnecting.\n"),
-                            gnutls_strerror(ret));
-               return -EIO;
-       }
-       return 0;
+       if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
+               return 0;
+
+       vpn_progress(vpninfo, PRG_ERR, _("Read error on %s session: %s\n"),
+                    dtls ? "DTLS" : "SSL", gnutls_strerror(ret));
+       return -1;
 }
 
-int ssl_nonblock_write(struct openconnect_info *vpninfo, void *buf, int buflen)
+int ssl_nonblock_write(struct openconnect_info *vpninfo, int dtls, void *buf, int buflen)
 {
+       gnutls_session_t sess = dtls ? vpninfo->dtls_ssl : vpninfo->https_sess;
        int ret;
 
-       ret = gnutls_record_send(vpninfo->https_sess, buf, buflen);
+       if (!sess) {
+               vpn_progress(vpninfo, PRG_ERR,
+                            _("Attempted to write to non-existent %s session"),
+                            dtls ? "DTLS" : "TLS");
+               return -1;
+       }
+
+       ret = gnutls_record_send(sess, buf, buflen);
        if (ret > 0)
                return ret;
 
@@ -335,15 +350,19 @@ int ssl_nonblock_write(struct openconnect_info *vpninfo, void *buf, int buflen)
                 * that it's waiting for does arrive.
                 */
                if (GNUTLS_VERSION_NUMBER < 0x03030d ||
-                   gnutls_record_get_direction(vpninfo->https_sess)) {
+                   gnutls_record_get_direction(sess)) {
                        /* Waiting for the socket to become writable — it's
                           probably stalled, and/or the buffers are full */
-                       monitor_write_fd(vpninfo, ssl);
+                       if (dtls)
+                               monitor_write_fd(vpninfo, dtls);
+                       else
+                               monitor_write_fd(vpninfo, ssl);
                }
                return 0;
        }
-       vpn_progress(vpninfo, PRG_ERR, _("SSL send failed: %s\n"),
-                    gnutls_strerror(ret));
+
+       vpn_progress(vpninfo, PRG_ERR, _("Write error on %s session: %s\n"),
+                    dtls ? "DTLS" : "SSL", gnutls_strerror(ret));
        return -1;
 }
 
diff --git a/gpst.c b/gpst.c
index 08e9a0d5a7e90feebfeffc925583c5c650a0cf65..6b61ff76342c8f78c82ab1c1c7686c74bc99ad02 100644 (file)
--- a/gpst.c
+++ b/gpst.c
@@ -1107,7 +1107,7 @@ int gpst_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
                        }
                }
 
-               len = ssl_nonblock_read(vpninfo, vpninfo->cstp_pkt->gpst.hdr, receive_mtu + 16);
+               len = ssl_nonblock_read(vpninfo, 0, vpninfo->cstp_pkt->gpst.hdr, receive_mtu + 16);
                if (!len)
                        break;
                if (len < 0) {
@@ -1186,7 +1186,7 @@ int gpst_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
                vpninfo->ssl_times.last_tx = time(NULL);
                unmonitor_write_fd(vpninfo, ssl);
 
-               ret = ssl_nonblock_write(vpninfo,
+               ret = ssl_nonblock_write(vpninfo, 0,
                                         vpninfo->current_ssl_pkt->gpst.hdr,
                                         vpninfo->current_ssl_pkt->len + 16);
                if (ret < 0)
diff --git a/oncp.c b/oncp.c
index 759327d2c49cce22340e00bfbeb63d81b83993ee..75f2be182219755e387a0f9183b5fbb14825b943 100644 (file)
--- a/oncp.c
+++ b/oncp.c
@@ -771,7 +771,7 @@ static int oncp_record_read(struct openconnect_info *vpninfo, void *buf, int len
        if (!vpninfo->oncp_rec_size) {
                unsigned char lenbuf[2];
 
-               ret = ssl_nonblock_read(vpninfo, lenbuf, 2);
+               ret = ssl_nonblock_read(vpninfo, 0, lenbuf, 2);
                if (ret <= 0)
                        return ret;
                if (ret == 1) {
@@ -783,7 +783,7 @@ static int oncp_record_read(struct openconnect_info *vpninfo, void *buf, int len
                }
                vpninfo->oncp_rec_size = load_le16(lenbuf);
                if (!vpninfo->oncp_rec_size) {
-                       ret = ssl_nonblock_read(vpninfo, lenbuf, 1);
+                       ret = ssl_nonblock_read(vpninfo, 0, lenbuf, 1);
                        if (ret == 1) {
                                if (lenbuf[0] == 1) {
                                        vpn_progress(vpninfo, PRG_ERR,
@@ -806,7 +806,7 @@ static int oncp_record_read(struct openconnect_info *vpninfo, void *buf, int len
        if (len > vpninfo->oncp_rec_size)
                len = vpninfo->oncp_rec_size;
 
-       ret = ssl_nonblock_read(vpninfo, buf, len);
+       ret = ssl_nonblock_read(vpninfo, 0, buf, len);
        if (ret > 0)
                vpninfo->oncp_rec_size -= ret;
        return ret;
@@ -1013,7 +1013,7 @@ int oncp_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
                             vpninfo->current_ssl_pkt->oncp.rec,
                             vpninfo->current_ssl_pkt->len + 22);
 
-               ret = ssl_nonblock_write(vpninfo,
+               ret = ssl_nonblock_write(vpninfo, 0,
                                         vpninfo->current_ssl_pkt->oncp.rec,
                                         vpninfo->current_ssl_pkt->len + 22);
                if (ret < 0) {
index 3ee725f2d0008a99a2a27178cfcc22d78f1eeec0..0c87b738fe2e0ebfc9f52859c0b8d7b059601c97 100644 (file)
@@ -1123,8 +1123,8 @@ int encrypt_esp_packet(struct openconnect_info *vpninfo, struct pkt *pkt, int cr
 /* {gnutls,openssl}.c */
 const char *openconnect_get_tls_library_version();
 int can_enable_insecure_crypto();
-int ssl_nonblock_read(struct openconnect_info *vpninfo, void *buf, int maxlen);
-int ssl_nonblock_write(struct openconnect_info *vpninfo, void *buf, int buflen);
+int ssl_nonblock_read(struct openconnect_info *vpninfo, int dtls, void *buf, int maxlen);
+int ssl_nonblock_write(struct openconnect_info *vpninfo, int dtls, void *buf, int buflen);
 int openconnect_open_https(struct openconnect_info *vpninfo);
 void openconnect_close_https(struct openconnect_info *vpninfo, int final);
 int cstp_handshake(struct openconnect_info *vpninfo, unsigned init);
index d760172154f681174a2ba8b56f3d9fd14b56bcc9..1bd7451a3e25f3086c80b60a78b6c79395f5fbcd 100644 (file)
--- a/openssl.c
+++ b/openssl.c
@@ -311,43 +311,63 @@ static int openconnect_openssl_gets(struct openconnect_info *vpninfo, char *buf,
        return i ?: ret;
 }
 
-int ssl_nonblock_read(struct openconnect_info *vpninfo, void *buf, int maxlen)
+int ssl_nonblock_read(struct openconnect_info *vpninfo, int dtls, void *buf, int maxlen)
 {
+       SSL *ssl = dtls ? vpninfo->dtls_ssl : vpninfo->https_ssl;
        int len, ret;
 
-       len = SSL_read(vpninfo->https_ssl, buf, maxlen);
+       if (!ssl) {
+               vpn_progress(vpninfo, PRG_ERR,
+                            _("Attempted to read from non-existent %s session"),
+                            dtls ? "DTLS" : "TLS");
+               return -1;
+       }
+
+       len = SSL_read(ssl, buf, maxlen);
        if (len > 0)
                return len;
 
-       ret = SSL_get_error(vpninfo->https_ssl, len);
-       if (ret == SSL_ERROR_SYSCALL || ret == SSL_ERROR_ZERO_RETURN) {
-               vpn_progress(vpninfo, PRG_ERR,
-                            _("SSL read error %d (server probably closed connection); reconnecting.\n"),
-                            ret);
-               return -EIO;
-       }
-       return 0;
+       ret = SSL_get_error(ssl, len);
+       if (ret == SSL_ERROR_WANT_WRITE || ret == SSL_ERROR_WANT_READ)
+               return 0;
+
+       vpn_progress(vpninfo, PRG_ERR, _("Read error on %s session: %d\n"),
+                      dtls ? "DTLS" : "TLS", ret);
+       return -EIO;
 }
 
-int ssl_nonblock_write(struct openconnect_info *vpninfo, void *buf, int buflen)
+int ssl_nonblock_write(struct openconnect_info *vpninfo, int dtls, void *buf, int buflen)
 {
+       SSL *ssl = dtls ? vpninfo->dtls_ssl : vpninfo->https_ssl;
        int ret;
 
-       ret = SSL_write(vpninfo->https_ssl, buf, buflen);
+       if (!ssl) {
+               vpn_progress(vpninfo, PRG_ERR,
+                            _("Attempted to write to non-existent %s session"),
+                            dtls ? "DTLS" : "TLS");
+               return -1;
+       }
+
+       ret = SSL_write(ssl, buf, buflen);
        if (ret > 0)
                return ret;
 
-       ret = SSL_get_error(vpninfo->https_ssl, ret);
+       ret = SSL_get_error(ssl, ret);
        switch (ret) {
        case SSL_ERROR_WANT_WRITE:
                /* Waiting for the socket to become writable -- it's
                   probably stalled, and/or the buffers are full */
-               monitor_write_fd(vpninfo, ssl);
+               if (dtls)
+                       monitor_write_fd(vpninfo, dtls);
+               else
+                       monitor_write_fd(vpninfo, ssl);
+               /* Fall through */
        case SSL_ERROR_WANT_READ:
                return 0;
 
        default:
-               vpn_progress(vpninfo, PRG_ERR, _("SSL_write failed: %d\n"), ret);
+               vpn_progress(vpninfo, PRG_ERR, _("Write error on %s session: %d\n"),
+                            dtls ? "DTLS" : "TLS", ret);
                openconnect_report_ssl_errors(vpninfo);
                return -1;
        }
diff --git a/ppp.c b/ppp.c
index c116e68f38c4fea84be4e578c3b83345b2bd8770..ba1b9b0e036c0de6d5c223eea98b8aeeecca2e8d 100644 (file)
--- a/ppp.c
+++ b/ppp.c
@@ -1021,7 +1021,7 @@ int ppp_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
 
                /* Load the encap header to end up with the payload where we expect it */
                eh = this->data - rsv_hdr_size;
-               len = ssl_nonblock_read(vpninfo, eh, receive_mtu + rsv_hdr_size);
+               len = ssl_nonblock_read(vpninfo, 0, eh, receive_mtu + rsv_hdr_size);
                if (!len)
                        break;
                if (len < 0)
@@ -1246,7 +1246,7 @@ int ppp_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
                vpninfo->ssl_times.last_tx = time(NULL);
                unmonitor_write_fd(vpninfo, ssl);
 
-               ret = ssl_nonblock_write(vpninfo, this->data - this->ppp.hlen, this->len + this->ppp.hlen);
+               ret = ssl_nonblock_write(vpninfo, 0, this->data - this->ppp.hlen, this->len + this->ppp.hlen);
                if (ret < 0)
                        goto do_reconnect;
                else if (!ret) {
diff --git a/pulse.c b/pulse.c
index 4181896cf86abd34c55c2d3d6ee2b10ba451afe2..82d4fd15a50e0626395d533e53760c3367d2f2f3 100644 (file)
--- a/pulse.c
+++ b/pulse.c
@@ -2603,7 +2603,7 @@ int pulse_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
                }
 
                /* Receive packet header, if there's anything there... */
-               len = ssl_nonblock_read(vpninfo, &pkt->pulse.vendor, 16);
+               len = ssl_nonblock_read(vpninfo, 0, &pkt->pulse.vendor, 16);
                if (!len)
                        break;
                if (len < 0)
@@ -2624,7 +2624,7 @@ int pulse_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
                } else
                        len = load_be32(&pkt->pulse.len) - 0x10;
 
-               payload_len = ssl_nonblock_read(vpninfo, &pkt->data, len);
+               payload_len = ssl_nonblock_read(vpninfo, 0, &pkt->data, len);
                if (payload_len != load_be32(&pkt->pulse.len) - 0x10) {
                        if (payload_len < 0)
                                len = 0x10;
@@ -2709,7 +2709,7 @@ int pulse_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable)
                             (void *)&vpninfo->current_ssl_pkt->pulse.vendor,
                             vpninfo->current_ssl_pkt->len + 16);
 
-               ret = ssl_nonblock_write(vpninfo,
+               ret = ssl_nonblock_write(vpninfo, 0,
                                         &vpninfo->current_ssl_pkt->pulse.vendor,
                                         vpninfo->current_ssl_pkt->len + 16);
                if (ret < 0) {