stream: Make sure no watcher callback is active while changing stream callbacks
[strongswan.git] / src / libstrongswan / networking / streams / stream.c
index 3c782cc..f6fec0b 100644 (file)
@@ -16,8 +16,7 @@
 #include <library.h>
 #include <errno.h>
 #include <unistd.h>
-#include <sys/socket.h>
-#include <sys/un.h>
+#include <limits.h>
 
 typedef struct private_stream_t private_stream_t;
 
@@ -37,11 +36,6 @@ struct private_stream_t {
        int fd;
 
        /**
-        * FILE* for convenience functions, or NULL
-        */
-       FILE *file;
-
-       /**
         * Callback if data is ready to read
         */
        stream_cb_t read_cb;
@@ -60,8 +54,6 @@ struct private_stream_t {
         * Data for write-ready callback
         */
        void *write_data;
-
-
 };
 
 METHOD(stream_t, read_, ssize_t,
@@ -92,6 +84,29 @@ METHOD(stream_t, read_, ssize_t,
        }
 }
 
+METHOD(stream_t, read_all, bool,
+       private_stream_t *this, void *buf, size_t len)
+{
+       ssize_t ret;
+
+       while (len)
+       {
+               ret = read_(this, buf, len, TRUE);
+               if (ret < 0)
+               {
+                       return FALSE;
+               }
+               if (ret == 0)
+               {
+                       errno = ECONNRESET;
+                       return FALSE;
+               }
+               len -= ret;
+               buf += ret;
+       }
+       return TRUE;
+}
+
 METHOD(stream_t, write_, ssize_t,
        private_stream_t *this, void *buf, size_t len, bool block)
 {
@@ -120,15 +135,27 @@ METHOD(stream_t, write_, ssize_t,
        }
 }
 
-/**
- * Remove a registered watcher
- */
-static void remove_watcher(private_stream_t *this)
+METHOD(stream_t, write_all, bool,
+       private_stream_t *this, void *buf, size_t len)
 {
-       if (this->read_cb || this->write_cb)
+       ssize_t ret;
+
+       while (len)
        {
-               lib->watcher->remove(lib->watcher, this->fd);
+               ret = write_(this, buf, len, TRUE);
+               if (ret < 0)
+               {
+                       return FALSE;
+               }
+               if (ret == 0)
+               {
+                       errno = ECONNRESET;
+                       return FALSE;
+               }
+               len -= ret;
+               buf += ret;
        }
+       return TRUE;
 }
 
 /**
@@ -137,21 +164,26 @@ static void remove_watcher(private_stream_t *this)
 static bool watch(private_stream_t *this, int fd, watcher_event_t event)
 {
        bool keep = FALSE;
+       stream_cb_t cb;
 
        switch (event)
        {
                case WATCHER_READ:
-                       keep = this->read_cb(this->read_data, &this->public);
-                       if (!keep)
+                       cb = this->read_cb;
+                       this->read_cb = NULL;
+                       keep = cb(this->read_data, &this->public);
+                       if (keep)
                        {
-                               this->read_cb = NULL;
+                               this->read_cb = cb;
                        }
                        break;
                case WATCHER_WRITE:
-                       keep = this->write_cb(this->write_data, &this->public);
-                       if (!keep)
+                       cb = this->write_cb;
+                       this->write_cb = NULL;
+                       keep = cb(this->write_data, &this->public);
+                       if (keep)
                        {
-                               this->write_cb = NULL;
+                               this->write_cb = cb;
                        }
                        break;
                case WATCHER_EXCEPT:
@@ -185,7 +217,7 @@ static void add_watcher(private_stream_t *this)
 METHOD(stream_t, on_read, void,
        private_stream_t *this, stream_cb_t cb, void *data)
 {
-       remove_watcher(this);
+       lib->watcher->remove(lib->watcher, this->fd);
 
        this->read_cb = cb;
        this->read_data = data;
@@ -196,7 +228,7 @@ METHOD(stream_t, on_read, void,
 METHOD(stream_t, on_write, void,
        private_stream_t *this, stream_cb_t cb, void *data)
 {
-       remove_watcher(this);
+       lib->watcher->remove(lib->watcher, this->fd);
 
        this->write_cb = cb;
        this->write_data = data;
@@ -204,45 +236,31 @@ METHOD(stream_t, on_write, void,
        add_watcher(this);
 }
 
-METHOD(stream_t, vprint, int,
-       private_stream_t *this, char *format, va_list ap)
+METHOD(stream_t, get_file, FILE*,
+       private_stream_t *this)
 {
-       if (!this->file)
+       FILE *file;
+       int fd;
+
+       /* fclose() closes the FD passed to fdopen(), so dup() it */
+       fd = dup(this->fd);
+       if (fd == -1)
        {
-               this->file = fdopen(this->fd, "w+");
-               if (!this->file)
-               {
-                       return -1;
-               }
+               return NULL;
        }
-       return vfprintf(this->file, format, ap);
-}
-
-METHOD(stream_t, print, int,
-       private_stream_t *this, char *format, ...)
-{
-       va_list ap;
-       int ret;
-
-       va_start(ap, format);
-       ret = vprint(this, format, ap);
-       va_end(ap);
-
-       return ret;
+       file = fdopen(fd, "w+");
+       if (!file)
+       {
+               close(fd);
+       }
+       return file;
 }
 
 METHOD(stream_t, destroy, void,
        private_stream_t *this)
 {
-       remove_watcher(this);
-       if (this->file)
-       {
-               fclose(this->file);
-       }
-       else
-       {
-               close(this->fd);
-       }
+       lib->watcher->remove(lib->watcher, this->fd);
+       close(this->fd);
        free(this);
 }
 
@@ -256,11 +274,12 @@ stream_t *stream_create_from_fd(int fd)
        INIT(this,
                .public = {
                        .read = _read_,
+                       .read_all = _read_all,
                        .on_read = _on_read,
                        .write = _write_,
+                       .write_all = _write_all,
                        .on_write = _on_write,
-                       .print = _print,
-                       .vprint = _vprint,
+                       .get_file = _get_file,
                        .destroy = _destroy,
                },
                .fd = fd,
@@ -283,6 +302,7 @@ int stream_parse_uri_unix(char *uri, struct sockaddr_un *addr)
        memset(addr, 0, sizeof(*addr));
        addr->sun_family = AF_UNIX;
        strncpy(addr->sun_path, uri, sizeof(addr->sun_path));
+       addr->sun_path[sizeof(addr->sun_path)-1] = '\0';
 
        return offsetof(struct sockaddr_un, sun_path) + strlen(addr->sun_path);
 }
@@ -315,3 +335,81 @@ stream_t *stream_create_unix(char *uri)
        }
        return stream_create_from_fd(fd);
 }
+
+/**
+ * See header.
+ */
+int stream_parse_uri_tcp(char *uri, struct sockaddr *addr)
+{
+       char *pos, buf[128];
+       host_t *host;
+       u_long port;
+       int len;
+
+       if (!strpfx(uri, "tcp://"))
+       {
+               return -1;
+       }
+       uri += strlen("tcp://");
+       pos = strrchr(uri, ':');
+       if (!pos)
+       {
+               return -1;
+       }
+       if (*uri == '[' && pos > uri && *(pos - 1) == ']')
+       {
+               /* IPv6 URI */
+               snprintf(buf, sizeof(buf), "%.*s", (int)(pos - uri - 2), uri + 1);
+       }
+       else
+       {
+               snprintf(buf, sizeof(buf), "%.*s", (int)(pos - uri), uri);
+       }
+       port = strtoul(pos + 1, &pos, 10);
+       if (port == ULONG_MAX || *pos || port > 65535)
+       {
+               return -1;
+       }
+       host = host_create_from_dns(buf, AF_UNSPEC, port);
+       if (!host)
+       {
+               return -1;
+       }
+       len = *host->get_sockaddr_len(host);
+       memcpy(addr, host->get_sockaddr(host), len);
+       host->destroy(host);
+       return len;
+}
+
+/**
+ * See header
+ */
+stream_t *stream_create_tcp(char *uri)
+{
+       union {
+               struct sockaddr_in in;
+               struct sockaddr_in6 in6;
+               struct sockaddr sa;
+       } addr;
+       int fd, len;
+
+       len = stream_parse_uri_tcp(uri, &addr.sa);
+       if (len == -1)
+       {
+               DBG1(DBG_NET, "invalid stream URI: '%s'", uri);
+               return NULL;
+       }
+       fd = socket(addr.sa.sa_family, SOCK_STREAM, 0);
+       if (fd < 0)
+       {
+               DBG1(DBG_NET, "opening socket '%s' failed: %s", uri, strerror(errno));
+               return NULL;
+       }
+       if (connect(fd, &addr.sa, len))
+       {
+               DBG1(DBG_NET, "connecting to '%s' failed: %s", uri, strerror(errno));
+               close(fd);
+               return NULL;
+       }
+       return stream_create_from_fd(fd);
+}