]> diplodocus.org Git - nmh/commitdiff
I think I've written all of these functions; now we just need to
authorKen Hornstein <kenh@pobox.com>
Sun, 18 Sep 2016 05:15:20 +0000 (01:15 -0400)
committerKen Hornstein <kenh@pobox.com>
Sun, 18 Sep 2016 05:15:20 +0000 (01:15 -0400)
have everyone else use them.

h/netsec.h
sbr/netsec.c

index f07c09f03be29284fa5700ea26366829839ade52..ad50d9f7464fb9ab38063d900ffc8b540bad758e 100644 (file)
@@ -73,6 +73,17 @@ int netsec_get_snoop(netsec_context *ns_context);
 
 void netsec_set_snoop(netsec_context *ns_context, int snoop);
 
+/*
+ * Set the read timeout for this connection.
+ *
+ * Arguments:
+ *
+ * ns_context  - Network security context
+ * timeout     - Read timeout, in seconds.
+ */
+
+void netsec_set_timeout(netsec_context *ns_context, int timeout);
+
 /*
  * Read a "line" from the network.  This reads one CR/LF terminated line.
  * Returns a pointer to a NUL-terminated string.  This memory is valid
index 6a09d14015f994c1104f301f3360ef38b03fe59a..cb297cc1b9714c61f16026c6c7956aada8b3cde4 100644 (file)
@@ -12,6 +12,7 @@
 #include <h/utils.h>
 #include <h/netsec.h>
 #include <stdarg.h>
+#include <sys/select.h>
 
 #ifdef CYRUS_SASL
 #include <sasl/sasl.h>
@@ -54,6 +55,7 @@ static SSL_CTX *sslctx = NULL;                /* SSL Context */
 struct _netsec_context {
     int ns_fd;                 /* Descriptor for network connection */
     int ns_snoop;              /* If true, display network data */
+    int ns_timeout;            /* Network read timeout, in seconds */
     char *ns_userid;           /* Userid for authentication */
     unsigned char *ns_inbuffer;        /* Our read input buffer */
     unsigned char *ns_inptr;   /* Our read buffer input pointer */
@@ -74,7 +76,8 @@ struct _netsec_context {
     sasl_secret_t *sasl_secret;        /* SASL password structure */
     char *sasl_chosen_mech;    /* Mechanism chosen by SASL */
     int sasl_seclayer;         /* If true, SASL security layer is enabled */
-    size_t sasl_maxoutsize;    /* Negotiated maximum size of output messages */
+    char *sasl_tmpbuf;         /* Temporary read buffer for decodes */
+    size_t sasl_maxbufsize;    /* Maximum negotiated SASL buffer size */
 #endif /* CYRUS_SASL */
 #ifdef TLS_SUPPORT
     BIO *ssl_io;               /* BIO used for connection I/O */
@@ -88,6 +91,12 @@ struct _netsec_context {
 
 static void netsec_err(char **errstr, const char *format, ...);
 
+/*
+ * Function to read data from the actual network socket
+ */
+
+static int netsec_fillread(netsec_context *ns_context, char **errstr);
+
 /*
  * How this code works, in general.
  *
@@ -116,6 +125,7 @@ netsec_init(void)
     nsc->ns_fd = -1;
     nsc->ns_snoop = 0;
     nsc->ns_userid = NULL;
+    nsc->ns_timeout = 60;      /* Our default */
     nsc->ns_inbufsize = NETSEC_BUFSIZE;
     nsc->ns_inbuffer = mh_xmalloc(nsc->ns_inbufsize);
     nsc->ns_inptr = nsc->ns_inbuffer;
@@ -132,7 +142,10 @@ netsec_init(void)
     nsc->sasl_creds = NULL;
     nsc->sasl_secret = NULL;
     nsc->sasl_chosen_mech = NULL;
-    nsc->sasl_maxoutsize = nsc->sasl_ssf = 0;
+    nsc->sasl_ssf = 0;
+    nsc->sasl_seclayer = 0;
+    nsc->sasl_tmpbuf = NULL;
+    nsc->sasl_maxbufsize = 0;
 #endif /* CYRUS_SASL */
 #ifdef TLS_SUPPORT
     nsc->ssl_io = NULL;
@@ -178,6 +191,8 @@ netsec_shutdown(netsec_context *nsc)
     }
     if (nsc->sasl_chosen_mech)
        free(nsc->sasl_chosen_mech);
+    if (nsc->sasl_tmpbuf)
+       free(nsc->sasl_tmpbuf);
 #endif /* CYRUS_SASL */
 #ifdef TLS_SUPPORT
     if (nsc->ssl_io)
@@ -232,9 +247,508 @@ netsec_set_snoop(netsec_context *nsc, int snoop)
 }
 
 /*
- * Write data to our network connection
+ * Set the read timeout for this connection
  */
 
+void
+netsec_set_timeout(netsec_context *nsc, int timeout)
+{
+    nsc->ns_timeout = timeout;
+}
+
+/*
+ * Read data from the network.  Basically, return anything in our buffer,
+ * otherwise fill from the network.
+ */
+
+ssize_t
+netsec_read(netsec_context *nsc, void *buffer, size_t size, char **errstr)
+{
+    int retlen;
+
+    /*
+     * If our buffer is empty, then we should fill it now
+     */
+
+    if (nsc->ns_inbuflen == 0) {
+       if (netsec_fillread(nsc, errstr) != OK)
+           return NOTOK;
+    }
+
+    /*
+     * netsec_fillread only returns if the buffer is full, so we can
+     * assume here that this has something in it.
+     */
+
+    retlen = size > nsc->ns_inbuflen ? nsc->ns_inbuflen : size;
+
+    memcpy(buffer, nsc->ns_inptr, retlen);
+
+    if (retlen == (int) nsc->ns_inbuflen) {
+       /*
+        * We've emptied our buffer, so reset everything.
+        */
+       nsc->ns_inptr = nsc->ns_inbuffer;
+       nsc->ns_inbuflen = 0;
+    } else {
+       nsc->ns_inptr += size;
+       nsc->ns_inbuflen -= size;
+    }
+
+    return OK;
+}
+
+/*
+ * Get a "line" (CR/LF) terminated from the network.
+ *
+ * Okay, we play some games here, so pay attention:
+ *
+ * - Unlike every other function, we return a pointer to the
+ *   existing buffer.  This pointer is valid until you call another
+ *   read functiona again.
+ * - We NUL-terminated the buffer right at the end, before the terminator.
+ * - Technically we look for a LF; if we find a CR right before it, then
+ *   we back up one.
+ * - If your data may contain embedded NULs, this won't work.
+ */
+
+char *
+netsec_readline(netsec_context *nsc, char **errstr)
+{
+    unsigned char *ptr = nsc->ns_inptr;
+    size_t count = 0, offset;
+
+retry:
+    /*
+     * Search through our existing buffer for a LF
+     */
+
+    while (count < nsc->ns_inbuflen) {
+       count++;
+       if (*ptr++ == '\n') {
+           char *sptr = nsc->ns_inptr;
+           if (count > 1 && *(ptr - 2) == '\r')
+               ptr--;
+           *--ptr = '\0';
+           nsc->ns_inptr += count;
+           nsc->ns_inbuflen -= count;
+           return sptr;
+       }
+    }
+
+    /*
+     * Hm, we didn't find a \n.  If we've already searched half of the input
+     * buffer, return an error.
+     */
+
+    if (count >= nsc->ns_inbufsize / 2) {
+       netsec_err(errstr, "Unable to find a line terminator after %d bytes",
+                  count);
+       return NULL;
+    }
+
+    /*
+     * Okay, get some more network data.  This may move inptr, so regenerate
+     * our ptr value;
+     */
+
+    offset = ptr - nsc->ns_inptr;
+
+    if (netsec_fillread(nsc, errstr) != OK)
+       return NOTOK;
+
+    ptr = nsc->ns_inptr + offset;
+
+    goto retry;
+
+    return NULL;       /* Should never reach this */
+}
+
+/*
+ * Fill our read buffer with some data from the network.
+ */
+
+static int
+netsec_fillread(netsec_context *nsc, char **errstr)
+{
+    unsigned char *end;
+    char *readbuf;
+    size_t readbufsize, remaining, startoffset;
+    int rc;
+
+    /*
+     * If inbuflen is zero, that means the buffer has been emptied
+     * completely.  In that case move inptr back to the start.
+     */
+
+    if (nsc->ns_inbuflen == 0) {
+       nsc->ns_inptr = nsc->ns_inbuffer;
+    }
+
+retry:
+    /*
+     * If we are using TLS and there's anything pending, then skip the
+     * select call
+     */
+#ifdef TLS_SUPPORT
+    if (!nsc->tls_active || BIO_pending(nsc->ssl_io) == 0)
+#endif /* TLS_SUPPORT */
+    {
+       struct timeval tv;
+       fd_set rfds;
+
+       FD_ZERO(&rfds);
+       FD_SET(nsc->ns_fd, &rfds);
+
+       tv.tv_sec = nsc->ns_timeout;
+       tv.tv_usec = 0;
+
+       rc = select(nsc->ns_fd + 1, &rfds, NULL, NULL, &tv);
+
+       if (rc == -1) {
+           netsec_err(errstr, "select() while reading failed: %s",
+                      strerror(errno));
+           return NOTOK;
+       }
+
+       if (rc == 0) {
+           netsec_err(errstr, "read() timed out after %d seconds",
+                      nsc->ns_timeout);
+           return NOTOK;
+       }
+
+       /*
+        * At this point, we know that rc is 1, so there's not even any
+        * point to check to see if our descriptor is set in rfds.
+        */
+    }
+
+    startoffset = nsc->ns_inptr - nsc->ns_inbuffer;
+    remaining = nsc->ns_inbufsize - (startoffset + nsc->ns_inbuflen);
+    end = nsc->ns_inptr + nsc->ns_inbuflen;
+
+    /*
+     * If we are using TLS, then just read via the BIO.  But we still
+     * use our local buffer.
+     */
+#ifdef TLS_SUPPORT
+    if (nsc->tls_active) {
+       rc = BIO_read(nsc->ssl_io, end, remaining);
+       if (rc == 0) {
+           /*
+            * Either EOF, or possibly an error.  Either way, it was probably
+            * unexpected, so treat as error.
+            */
+           netsec_err(errstr, "TLS peer aborted connection");
+           return NOTOK;
+       } else if (rc < 0) {
+           /* Definitely an error */
+           netsec_err(errstr, "Read on TLS connection failed: %s",
+                      ERR_error_string(ERR_get_error(), NULL));
+           return NOTOK;
+       }
+
+       nsc->ns_inbuflen += rc;
+
+       return OK;
+    }
+#endif /* TLS_SUPPORT */
+
+    /*
+     * Okay, time to read some data.  Either we're just doing it straight
+     * or we're passing it through sasl_decode() first.
+     */
+
+#ifdef CYRUS_SASL
+    if (nsc->sasl_seclayer) {
+       readbuf = nsc->sasl_tmpbuf;
+       readbufsize = nsc->sasl_maxbufsize;
+    } else
+#endif /* CYRUS_SASL */
+    {
+       readbuf = (char *) end;
+       readbufsize = remaining;
+    }
+
+    /*
+     * At this point, we should have active data on the connection (see
+     * select() above) so this read SHOULDN'T block.  Hopefully.
+     */
+
+    rc = read(nsc->ns_fd, readbuf, readbufsize);
+
+    if (rc == 0) {
+       netsec_err(errstr, "Received EOF on network read");
+       return NOTOK;
+    }
+
+    if (rc < 0) {
+       netsec_err(errstr, "Network read failed: %s", strerror(errno));
+       return NOTOK;
+    }
+
+    /*
+     * Okay, so we've had a successful read.  If we are doing SASL security
+     * layers, pass this through sasl_decode().  sasl_decode() can return
+     * 0 bytes decoded; if that happens, jump back to the beginning.  Otherwise
+     * we can just update our length pointer.
+     */
+
+#ifdef CYRUS_SASL
+    if (nsc->sasl_seclayer) {
+       const char *tmpout;
+       unsigned int tmpoutlen;
+
+       rc = sasl_decode(nsc->sasl_conn, nsc->sasl_tmpbuf, rc,
+                        &tmpout, &tmpoutlen);
+
+       if (rc != SASL_OK) {
+           netsec_err(errstr, "Unable to decode SASL network data: %s",
+                      sasl_errdetail(nsc->sasl_conn));
+           return NOTOK;
+       }
+
+       if (tmpoutlen == 0)
+           goto retry;
+
+       /*
+        * Just in case ...
+        */
+
+       if (tmpoutlen > remaining) {
+           netsec_err(errstr, "Internal error: SASL decode buffer overflow!");
+           return NOTOK;
+       }
+
+       memcpy(end, tmpout, tmpoutlen);
+
+       nsc->ns_inbuflen += tmpoutlen;
+    } else
+#endif /* CYRUS_SASL */
+       nsc->ns_inbuflen += rc;
+
+    /*
+     * If we're past the halfway point in our read buffers, shuffle everything
+     * back to the beginning.
+     */
+
+    if (startoffset > nsc->ns_inbufsize / 2) {
+       memmove(nsc->ns_inbuffer, nsc->ns_inptr, nsc->ns_inbuflen);
+       nsc->ns_inptr = nsc->ns_inbuffer;
+    }
+
+    return OK;
+}
+
+/*
+ * Write data to our network connection.  Really, fill up the buffer as
+ * much as we can, and flush it out if necessary.  netsec_flush() does
+ * the real work.
+ */
+
+int
+netsec_write(netsec_context *nsc, const void *buffer, size_t size,
+            char **errstr)
+{
+    const unsigned char *bufptr = buffer;
+    int rc, remaining;
+
+    /*
+     * If TLS is active, then bypass all of our buffering logic; just
+     * write it directly to our BIO.  We have a buffering BIO first in
+     * our stack, so buffering will take place there.
+     */
+#ifdef TLS_SUPPORT
+    if (nsc->tls_active) {
+       rc = BIO_write(nsc->ssl_io, buffer, size);
+
+       if (rc <= 0) {
+           netsec_err(errstr, "Error writing to TLS connection: %s",
+                      ERR_error_string(ERR_get_error(), NULL));
+           return NOTOK;
+       }
+
+       return OK;
+    }
+#endif /* TLS_SUPPORT */
+
+    /*
+     * Run a loop copying in data to our local buffer; when we're done with
+     * any buffer overflows then just copy any remaining data in.
+     */
+
+    while ((int) size >= (remaining = nsc->ns_outbufsize - nsc->ns_outbuflen)) {
+       memcpy(nsc->ns_outptr, bufptr, remaining);
+
+       /*
+        * In theory I should increment outptr, but netsec_flush just resets
+        * it anyway.
+        */
+       nsc->ns_outbuflen = nsc->ns_outbufsize;
+
+       rc = netsec_flush(nsc, errstr);
+
+       if (rc != OK)
+           return NOTOK;
+
+       bufptr += remaining;
+       size -= remaining;
+    }
+
+    /*
+     * Copy any leftover data into the buffer.
+     */
+
+    if (size > 0) {
+       memcpy(nsc->ns_outptr, bufptr, size);
+       nsc->ns_outptr += size;
+       nsc->ns_outbuflen += size;
+    }
+
+    return OK;
+}
+
+/*
+ * Write bytes to the network using printf()-style formatting.
+ *
+ * Again, for the most part copy stuff into our buffer to be flushed
+ * out later.
+ */
+
+int
+netsec_printf(netsec_context *nsc, char **errstr, const char *format, ...)
+{
+    va_list ap;
+    int rc;
+
+    /*
+     * Again, if we're using TLS, then bypass our local buffering
+     */
+#ifdef TLS_SUPPORT
+    if (nsc->tls_active) {
+       va_start(ap, format);
+       rc = BIO_vprintf(nsc->ssl_io, format, ap);
+       va_end(ap);
+
+       if (rc <= 0) {
+           netsec_err(errstr, "Error writing to TLS connection: %s",
+                      ERR_error_string(ERR_get_error(), NULL));
+           return NOTOK;
+       }
+
+       return OK;
+    }
+#endif /* TLS_SUPPORT */
+
+    /*
+     * Cheat a little.  If we can fit the data into our outgoing buffer,
+     * great!  If not, generate a flush and retry once.
+     */
+
+retry:
+    va_start(ap, format);
+    rc = vsnprintf((char *) nsc->ns_outptr,
+                  nsc->ns_outbufsize - nsc->ns_outbuflen, format, ap);
+    va_end(ap);
+
+    if (rc >= (int) (nsc->ns_outbufsize - nsc->ns_outbuflen)) {
+       /*
+        * This means we have an overflow.  Note that we don't actually
+        * make use of the terminating NUL, but according to the spec
+        * vsnprintf() won't write to the last byte in the string; that's
+        * why we have to use >= in the comparison above.
+        */
+       if (nsc->ns_outbuffer == nsc->ns_outptr) {
+           /*
+            * Whoops, if the buffer pointer was the same as the start of the
+            * buffer, that means we overflowed the internal buffer.
+            * At that point, just give up.
+            */
+           netsec_err(errstr, "Internal error: wanted to printf() a total of "
+                      "%d bytes, but our buffer size was only %d bytes",
+                      rc, nsc->ns_outbufsize);
+           return NOTOK;
+       } else {
+           /*
+            * Generate a flush (which may be inefficient, but hopefully
+            * it isn't) and then try again.
+            */
+           if (netsec_flush(nsc, errstr) != OK)
+               return NOTOK;
+           /*
+            * After this, outbuffer should == outptr, so we shouldn't
+            * hit this next time around.
+            */
+           goto retry;
+       }
+    }
+
+    nsc->ns_outptr += rc;
+    nsc->ns_outbuflen += rc;
+
+    return OK;
+}
+
+/*
+ * Flush out any buffered data in our output buffers.  This routine is
+ * actually where the real network writes take place.
+ */
+
+int
+netsec_flush(netsec_context *nsc, char **errstr)
+{
+    const char *netoutbuf = (const char *) nsc->ns_outbuffer;
+    unsigned int netoutlen = nsc->ns_outbuflen;
+    int rc;
+
+    /*
+     * For TLS connections, just call BIO_flush(); we'll let TLS handle
+     * all of our output buffering.
+     */
+#ifdef TLS_SUPPORT
+    if (nsc->tls_active) {
+       rc = BIO_flush(nsc->ssl_io);
+
+       if (rc <= 0) {
+           netsec_err(errstr, "Error flushing TLS connection: %s",
+                      ERR_error_string(ERR_get_error(), NULL));
+           return NOTOK;
+       }
+
+       return OK;
+    }
+#endif /* TLS_SUPPORT */
+
+    /*
+     * If SASL security layers are in effect, run the data through
+     * sasl_encode() first and then write it.
+     */
+#ifdef CYRUS_SASL
+    if (nsc->sasl_seclayer) {
+       rc = sasl_encode(nsc->sasl_conn, (const char *) nsc->ns_outbuffer,
+                        nsc->ns_outbuflen, &netoutbuf, &netoutlen);
+
+       if (rc != SASL_OK) {
+           netsec_err(errstr, "SASL data encoding failed: %s",
+                      sasl_errdetail(nsc->sasl_conn));
+           return NOTOK;
+       }
+
+    }
+#endif /* CYRUS_SASL */
+    rc = write(nsc->ns_fd, netoutbuf, netoutlen);
+
+    if (rc < 0) {
+       netsec_err(errstr, "write() failed: %s", strerror(errno));
+       return NOTOK;
+    }
+
+    nsc->ns_outptr = nsc->ns_outbuffer;
+    nsc->ns_outbuflen = 0;
+
+    return OK;
+}
+
 /*
  * Set various SASL protocol parameters
  */
@@ -583,13 +1097,49 @@ netsec_negotiate_sasl(netsec_context *nsc, const char *mechlist, char **errstr)
            return NOTOK;
        }
 
-       nsc->sasl_maxoutsize = *outbufmax;
+       /*
+        * If our output buffer isn't the same size as the input buffer,
+        * reallocate it and set the new size (since we won't encode any
+        * data larger than that).
+        */
+
+       nsc->sasl_maxbufsize = *outbufmax;
 
-       if (nsc->sasl_maxoutsize > nsc->ns_outbufsize) {
-           nsc->ns_outbufsize = nsc->sasl_maxoutsize;
+       if (nsc->ns_outbufsize != nsc->sasl_maxbufsize) {
+           nsc->ns_outbufsize = nsc->sasl_maxbufsize;
            nsc->ns_outbuffer = mh_xrealloc(nsc->ns_outbuffer,
                                            nsc->ns_outbufsize);
+           /*
+            * There shouldn't be any data in the buffer, but for
+            * consistency's sake discard it.
+            */
+           nsc->ns_outptr = nsc->ns_outbuffer;
+           nsc->ns_outbuflen = 0;
        }
+
+       /*
+        * Allocate a buffer to do temporary reads into, before we
+        * call sasl_decode()
+        */
+
+       nsc->sasl_tmpbuf = mh_xmalloc(nsc->sasl_maxbufsize);
+
+       /*
+        * Okay, this is a bit weird.  Make sure that the input buffer
+        * is at least TWICE the size of the max buffer size.  That's
+        * because if we're consuming data but want to extend the current
+        * buffer, we want to be sure there's room for another full buffer's
+        * worth of data.
+        */
+
+       if (nsc->ns_inbufsize < nsc->sasl_maxbufsize * 2) {
+           size_t offset = nsc->ns_inptr - nsc->ns_inbuffer;
+           nsc->ns_inbufsize = nsc->sasl_maxbufsize * 2;
+           nsc->ns_inbuffer = mh_xrealloc(nsc->ns_inbuffer, nsc->ns_inbufsize);
+           nsc->ns_inptr = nsc->ns_inbuffer + offset;
+       }
+
+       nsc->sasl_seclayer = 1;
     }
 
     return OK;
@@ -668,6 +1218,12 @@ netsec_set_tls(netsec_context *nsc, int tls, char **errstr)
            return NOTOK;
        }
 
+       /*
+        * Never bother us, since we are using blocking sockets.
+        */
+
+       SSL_set_mode(ssl, SSL_MODE_AUTO_RETRY);
+
        /*
         * This is a bit weird, so pay attention.
         *
@@ -767,6 +1323,16 @@ netsec_negotiate_tls(netsec_context *nsc, char **errstr)
 
     nsc->tls_active = 1;
 
+    /*
+     * At this point, TLS has been activated; we're not going to use
+     * the output buffer, so free it now to save a little bit of memory.
+     */
+
+    if (nsc->ns_outbuffer) {
+       free(nsc->ns_outbuffer);
+       nsc->ns_outbuffer = NULL;
+    }
+
     return OK;
 #else /* TLS_SUPPORT */
     netsec_err(errstr, "TLS not supported");