Added a tls_socket_t.splice method to wrap a file descriptor into TLS
authorMartin Willi <martin@revosec.ch>
Sat, 31 Dec 2011 11:37:08 +0000 (12:37 +0100)
committerMartin Willi <martin@revosec.ch>
Sat, 31 Dec 2011 12:14:49 +0000 (13:14 +0100)
src/libtls/tls_socket.c
src/libtls/tls_socket.h

index b6ebdfb..3abff59 100644 (file)
 #include "tls_socket.h"
 
 #include <unistd.h>
+#include <errno.h>
 
 #include <debug.h>
+#include <threading/thread.h>
+
+/**
+ * Buffer size for plain side I/O
+ */
+#define PLAIN_BUF_SIZE 4096
+
+/**
+ * Buffer size for encrypted side I/O
+ */
+#define CRYPTO_BUF_SIZE 4096
 
 typedef struct private_tls_socket_t private_tls_socket_t;
 typedef struct private_tls_application_t private_tls_application_t;
@@ -96,8 +108,8 @@ METHOD(tls_application_t, build, status_t,
  */
 static bool exchange(private_tls_socket_t *this, bool wr)
 {
-       char buf[1024];
-       ssize_t len;
+       char buf[CRYPTO_BUF_SIZE], *pos;
+       ssize_t len, out;
        int round = 0;
 
        for (round = 0; TRUE; round++)
@@ -109,10 +121,18 @@ static bool exchange(private_tls_socket_t *this, bool wr)
                        {
                                case NEED_MORE:
                                case ALREADY_DONE:
-                                       len = write(this->fd, buf, len);
-                                       if (len == -1)
+                                       pos = buf;
+                                       while (len)
                                        {
-                                               return FALSE;
+                                               out = write(this->fd, pos, len);
+                                               if (out == -1)
+                                               {
+                                                       DBG1(DBG_TLS, "TLS crypto write error: %s",
+                                                                strerror(errno));
+                                                       return FALSE;
+                                               }
+                                               len -= out;
+                                               pos += out;
                                        }
                                        continue;
                                case INVALID_STATE:
@@ -175,6 +195,75 @@ METHOD(tls_socket_t, write_, bool,
        return FALSE;
 }
 
+METHOD(tls_socket_t, splice, bool,
+       private_tls_socket_t *this, int rfd, int wfd)
+{
+       char buf[PLAIN_BUF_SIZE], *pos;
+       fd_set set;
+       chunk_t data;
+       ssize_t len;
+       bool old;
+
+       while (TRUE)
+       {
+               FD_ZERO(&set);
+               FD_SET(rfd, &set);
+               FD_SET(this->fd, &set);
+
+               old = thread_cancelability(TRUE);
+               len = select(max(rfd, this->fd) + 1, &set, NULL, NULL, NULL);
+               thread_cancelability(old);
+               if (len == -1)
+               {
+                       DBG1(DBG_TLS, "TLS select error: %s", strerror(errno));
+                       return FALSE;
+               }
+               if (FD_ISSET(this->fd, &set))
+               {
+                       if (!read_(this, &data))
+                       {
+                               DBG2(DBG_TLS, "TLS read error/disconnect");
+                               return TRUE;
+                       }
+                       pos = data.ptr;
+                       while (data.len)
+                       {
+                               len = write(wfd, pos, data.len);
+                               if (len == -1)
+                               {
+                                       free(data.ptr);
+                                       DBG1(DBG_TLS, "TLS plain write error: %s", strerror(errno));
+                                       return FALSE;
+                               }
+                               data.len -= len;
+                               pos += len;
+                       }
+                       free(data.ptr);
+               }
+               if (FD_ISSET(rfd, &set))
+               {
+                       len = read(rfd, buf, sizeof(buf));
+                       if (len > 0)
+                       {
+                               if (!write_(this, chunk_create(buf, len)))
+                               {
+                                       DBG1(DBG_TLS, "TLS write error");
+                                       return FALSE;
+                               }
+                       }
+                       else
+                       {
+                               if (len < 0)
+                               {
+                                       DBG1(DBG_TLS, "TLS plain read error: %s", strerror(errno));
+                                       return FALSE;
+                               }
+                               return TRUE;
+                       }
+               }
+       }
+}
+
 METHOD(tls_socket_t, get_fd, int,
        private_tls_socket_t *this)
 {
@@ -201,6 +290,7 @@ tls_socket_t *tls_socket_create(bool is_server, identification_t *server,
                .public = {
                        .read = _read_,
                        .write = _write_,
+                       .splice = _splice,
                        .get_fd = _get_fd,
                        .destroy = _destroy,
                },
index 9f0e964..edd05fd 100644 (file)
@@ -55,6 +55,18 @@ struct tls_socket_t {
        bool (*write)(tls_socket_t *this, chunk_t data);
 
        /**
+        * Read/write plain data from file descriptor.
+        *
+        * This call is blocking, but a thread cancellation point. Data is
+        * exchanged until one of the sockets gets closed or an error occurs.
+        *
+        * @param rfd           file descriptor to read plain data from
+        * @param wfd           file descriptor to write plain data to
+        * @return                      TRUE if data exchanged successfully
+        */
+       bool (*splice)(tls_socket_t *this, int rfd, int wfd);
+
+       /**
         * Get the underlying file descriptor passed to the constructor.
         *
         * @return                      file descriptor