stream: Make sure no watcher callback is active while changing stream callbacks
[strongswan.git] / src / libstrongswan / networking / streams / stream.c
index 43a6bd4..f6fec0b 100644 (file)
@@ -16,6 +16,7 @@
 #include <library.h>
 #include <errno.h>
 #include <unistd.h>
+#include <limits.h>
 
 typedef struct private_stream_t private_stream_t;
 
@@ -33,6 +34,26 @@ struct private_stream_t {
         * Underlying socket
         */
        int fd;
+
+       /**
+        * Callback if data is ready to read
+        */
+       stream_cb_t read_cb;
+
+       /**
+        * Data for read-ready callback
+        */
+       void *read_data;
+
+       /**
+        * Callback if write is non-blocking
+        */
+       stream_cb_t write_cb;
+
+       /**
+        * Data for write-ready callback
+        */
+       void *write_data;
 };
 
 METHOD(stream_t, read_, ssize_t,
@@ -63,6 +84,29 @@ METHOD(stream_t, read_, ssize_t,
        }
 }
 
+METHOD(stream_t, read_all, bool,
+       private_stream_t *this, void *buf, size_t len)
+{
+       ssize_t ret;
+
+       while (len)
+       {
+               ret = read_(this, buf, len, TRUE);
+               if (ret < 0)
+               {
+                       return FALSE;
+               }
+               if (ret == 0)
+               {
+                       errno = ECONNRESET;
+                       return FALSE;
+               }
+               len -= ret;
+               buf += ret;
+       }
+       return TRUE;
+}
+
 METHOD(stream_t, write_, ssize_t,
        private_stream_t *this, void *buf, size_t len, bool block)
 {
@@ -91,9 +135,131 @@ METHOD(stream_t, write_, ssize_t,
        }
 }
 
+METHOD(stream_t, write_all, bool,
+       private_stream_t *this, void *buf, size_t len)
+{
+       ssize_t ret;
+
+       while (len)
+       {
+               ret = write_(this, buf, len, TRUE);
+               if (ret < 0)
+               {
+                       return FALSE;
+               }
+               if (ret == 0)
+               {
+                       errno = ECONNRESET;
+                       return FALSE;
+               }
+               len -= ret;
+               buf += ret;
+       }
+       return TRUE;
+}
+
+/**
+ * Watcher callback
+ */
+static bool watch(private_stream_t *this, int fd, watcher_event_t event)
+{
+       bool keep = FALSE;
+       stream_cb_t cb;
+
+       switch (event)
+       {
+               case WATCHER_READ:
+                       cb = this->read_cb;
+                       this->read_cb = NULL;
+                       keep = cb(this->read_data, &this->public);
+                       if (keep)
+                       {
+                               this->read_cb = cb;
+                       }
+                       break;
+               case WATCHER_WRITE:
+                       cb = this->write_cb;
+                       this->write_cb = NULL;
+                       keep = cb(this->write_data, &this->public);
+                       if (keep)
+                       {
+                               this->write_cb = cb;
+                       }
+                       break;
+               case WATCHER_EXCEPT:
+                       break;
+       }
+       return keep;
+}
+
+/**
+ * Register watcher for stream callbacks
+ */
+static void add_watcher(private_stream_t *this)
+{
+       watcher_event_t events = 0;
+
+       if (this->read_cb)
+       {
+               events |= WATCHER_READ;
+       }
+       if (this->write_cb)
+       {
+               events |= WATCHER_WRITE;
+       }
+       if (events)
+       {
+               lib->watcher->add(lib->watcher, this->fd, events,
+                                                 (watcher_cb_t)watch, this);
+       }
+}
+
+METHOD(stream_t, on_read, void,
+       private_stream_t *this, stream_cb_t cb, void *data)
+{
+       lib->watcher->remove(lib->watcher, this->fd);
+
+       this->read_cb = cb;
+       this->read_data = data;
+
+       add_watcher(this);
+}
+
+METHOD(stream_t, on_write, void,
+       private_stream_t *this, stream_cb_t cb, void *data)
+{
+       lib->watcher->remove(lib->watcher, this->fd);
+
+       this->write_cb = cb;
+       this->write_data = data;
+
+       add_watcher(this);
+}
+
+METHOD(stream_t, get_file, FILE*,
+       private_stream_t *this)
+{
+       FILE *file;
+       int fd;
+
+       /* fclose() closes the FD passed to fdopen(), so dup() it */
+       fd = dup(this->fd);
+       if (fd == -1)
+       {
+               return NULL;
+       }
+       file = fdopen(fd, "w+");
+       if (!file)
+       {
+               close(fd);
+       }
+       return file;
+}
+
 METHOD(stream_t, destroy, void,
        private_stream_t *this)
 {
+       lib->watcher->remove(lib->watcher, this->fd);
        close(this->fd);
        free(this);
 }
@@ -108,7 +274,12 @@ stream_t *stream_create_from_fd(int fd)
        INIT(this,
                .public = {
                        .read = _read_,
+                       .read_all = _read_all,
+                       .on_read = _on_read,
                        .write = _write_,
+                       .write_all = _write_all,
+                       .on_write = _on_write,
+                       .get_file = _get_file,
                        .destroy = _destroy,
                },
                .fd = fd,
@@ -116,3 +287,129 @@ stream_t *stream_create_from_fd(int fd)
 
        return &this->public;
 }
+
+/**
+ * See header
+ */
+int stream_parse_uri_unix(char *uri, struct sockaddr_un *addr)
+{
+       if (!strpfx(uri, "unix://"))
+       {
+               return -1;
+       }
+       uri += strlen("unix://");
+
+       memset(addr, 0, sizeof(*addr));
+       addr->sun_family = AF_UNIX;
+       strncpy(addr->sun_path, uri, sizeof(addr->sun_path));
+       addr->sun_path[sizeof(addr->sun_path)-1] = '\0';
+
+       return offsetof(struct sockaddr_un, sun_path) + strlen(addr->sun_path);
+}
+
+/**
+ * See header
+ */
+stream_t *stream_create_unix(char *uri)
+{
+       struct sockaddr_un addr;
+       int len, fd;
+
+       len = stream_parse_uri_unix(uri, &addr);
+       if (len == -1)
+       {
+               DBG1(DBG_NET, "invalid stream URI: '%s'", uri);
+               return NULL;
+       }
+       fd = socket(AF_UNIX, SOCK_STREAM, 0);
+       if (fd < 0)
+       {
+               DBG1(DBG_NET, "opening socket '%s' failed: %s", uri, strerror(errno));
+               return NULL;
+       }
+       if (connect(fd, (struct sockaddr*)&addr, len) < 0)
+       {
+               DBG1(DBG_NET, "connecting to '%s' failed: %s", uri, strerror(errno));
+               close(fd);
+               return NULL;
+       }
+       return stream_create_from_fd(fd);
+}
+
+/**
+ * See header.
+ */
+int stream_parse_uri_tcp(char *uri, struct sockaddr *addr)
+{
+       char *pos, buf[128];
+       host_t *host;
+       u_long port;
+       int len;
+
+       if (!strpfx(uri, "tcp://"))
+       {
+               return -1;
+       }
+       uri += strlen("tcp://");
+       pos = strrchr(uri, ':');
+       if (!pos)
+       {
+               return -1;
+       }
+       if (*uri == '[' && pos > uri && *(pos - 1) == ']')
+       {
+               /* IPv6 URI */
+               snprintf(buf, sizeof(buf), "%.*s", (int)(pos - uri - 2), uri + 1);
+       }
+       else
+       {
+               snprintf(buf, sizeof(buf), "%.*s", (int)(pos - uri), uri);
+       }
+       port = strtoul(pos + 1, &pos, 10);
+       if (port == ULONG_MAX || *pos || port > 65535)
+       {
+               return -1;
+       }
+       host = host_create_from_dns(buf, AF_UNSPEC, port);
+       if (!host)
+       {
+               return -1;
+       }
+       len = *host->get_sockaddr_len(host);
+       memcpy(addr, host->get_sockaddr(host), len);
+       host->destroy(host);
+       return len;
+}
+
+/**
+ * See header
+ */
+stream_t *stream_create_tcp(char *uri)
+{
+       union {
+               struct sockaddr_in in;
+               struct sockaddr_in6 in6;
+               struct sockaddr sa;
+       } addr;
+       int fd, len;
+
+       len = stream_parse_uri_tcp(uri, &addr.sa);
+       if (len == -1)
+       {
+               DBG1(DBG_NET, "invalid stream URI: '%s'", uri);
+               return NULL;
+       }
+       fd = socket(addr.sa.sa_family, SOCK_STREAM, 0);
+       if (fd < 0)
+       {
+               DBG1(DBG_NET, "opening socket '%s' failed: %s", uri, strerror(errno));
+               return NULL;
+       }
+       if (connect(fd, &addr.sa, len))
+       {
+               DBG1(DBG_NET, "connecting to '%s' failed: %s", uri, strerror(errno));
+               close(fd);
+               return NULL;
+       }
+       return stream_create_from_fd(fd);
+}