Added a tls_socket_t.splice method to wrap a file descriptor into TLS
[strongswan.git] / src / libtls / tls_socket.c
index b6ebdfb..3abff59 100644 (file)
 #include "tls_socket.h"
 
 #include <unistd.h>
+#include <errno.h>
 
 #include <debug.h>
+#include <threading/thread.h>
+
+/**
+ * Buffer size for plain side I/O
+ */
+#define PLAIN_BUF_SIZE 4096
+
+/**
+ * Buffer size for encrypted side I/O
+ */
+#define CRYPTO_BUF_SIZE 4096
 
 typedef struct private_tls_socket_t private_tls_socket_t;
 typedef struct private_tls_application_t private_tls_application_t;
@@ -96,8 +108,8 @@ METHOD(tls_application_t, build, status_t,
  */
 static bool exchange(private_tls_socket_t *this, bool wr)
 {
-       char buf[1024];
-       ssize_t len;
+       char buf[CRYPTO_BUF_SIZE], *pos;
+       ssize_t len, out;
        int round = 0;
 
        for (round = 0; TRUE; round++)
@@ -109,10 +121,18 @@ static bool exchange(private_tls_socket_t *this, bool wr)
                        {
                                case NEED_MORE:
                                case ALREADY_DONE:
-                                       len = write(this->fd, buf, len);
-                                       if (len == -1)
+                                       pos = buf;
+                                       while (len)
                                        {
-                                               return FALSE;
+                                               out = write(this->fd, pos, len);
+                                               if (out == -1)
+                                               {
+                                                       DBG1(DBG_TLS, "TLS crypto write error: %s",
+                                                                strerror(errno));
+                                                       return FALSE;
+                                               }
+                                               len -= out;
+                                               pos += out;
                                        }
                                        continue;
                                case INVALID_STATE:
@@ -175,6 +195,75 @@ METHOD(tls_socket_t, write_, bool,
        return FALSE;
 }
 
+METHOD(tls_socket_t, splice, bool,
+       private_tls_socket_t *this, int rfd, int wfd)
+{
+       char buf[PLAIN_BUF_SIZE], *pos;
+       fd_set set;
+       chunk_t data;
+       ssize_t len;
+       bool old;
+
+       while (TRUE)
+       {
+               FD_ZERO(&set);
+               FD_SET(rfd, &set);
+               FD_SET(this->fd, &set);
+
+               old = thread_cancelability(TRUE);
+               len = select(max(rfd, this->fd) + 1, &set, NULL, NULL, NULL);
+               thread_cancelability(old);
+               if (len == -1)
+               {
+                       DBG1(DBG_TLS, "TLS select error: %s", strerror(errno));
+                       return FALSE;
+               }
+               if (FD_ISSET(this->fd, &set))
+               {
+                       if (!read_(this, &data))
+                       {
+                               DBG2(DBG_TLS, "TLS read error/disconnect");
+                               return TRUE;
+                       }
+                       pos = data.ptr;
+                       while (data.len)
+                       {
+                               len = write(wfd, pos, data.len);
+                               if (len == -1)
+                               {
+                                       free(data.ptr);
+                                       DBG1(DBG_TLS, "TLS plain write error: %s", strerror(errno));
+                                       return FALSE;
+                               }
+                               data.len -= len;
+                               pos += len;
+                       }
+                       free(data.ptr);
+               }
+               if (FD_ISSET(rfd, &set))
+               {
+                       len = read(rfd, buf, sizeof(buf));
+                       if (len > 0)
+                       {
+                               if (!write_(this, chunk_create(buf, len)))
+                               {
+                                       DBG1(DBG_TLS, "TLS write error");
+                                       return FALSE;
+                               }
+                       }
+                       else
+                       {
+                               if (len < 0)
+                               {
+                                       DBG1(DBG_TLS, "TLS plain read error: %s", strerror(errno));
+                                       return FALSE;
+                               }
+                               return TRUE;
+                       }
+               }
+       }
+}
+
 METHOD(tls_socket_t, get_fd, int,
        private_tls_socket_t *this)
 {
@@ -201,6 +290,7 @@ tls_socket_t *tls_socket_create(bool is_server, identification_t *server,
                .public = {
                        .read = _read_,
                        .write = _write_,
+                       .splice = _splice,
                        .get_fd = _get_fd,
                        .destroy = _destroy,
                },