stream: Make sure no watcher callback is active while changing stream callbacks
[strongswan.git] / src / libstrongswan / networking / streams / stream.c
index bc6bbc2..f6fec0b 100644 (file)
@@ -36,11 +36,6 @@ struct private_stream_t {
        int fd;
 
        /**
-        * FILE* for convenience functions, or NULL
-        */
-       FILE *file;
-
-       /**
         * Callback if data is ready to read
         */
        stream_cb_t read_cb;
@@ -59,8 +54,6 @@ struct private_stream_t {
         * Data for write-ready callback
         */
        void *write_data;
-
-
 };
 
 METHOD(stream_t, read_, ssize_t,
@@ -91,6 +84,29 @@ METHOD(stream_t, read_, ssize_t,
        }
 }
 
+METHOD(stream_t, read_all, bool,
+       private_stream_t *this, void *buf, size_t len)
+{
+       ssize_t ret;
+
+       while (len)
+       {
+               ret = read_(this, buf, len, TRUE);
+               if (ret < 0)
+               {
+                       return FALSE;
+               }
+               if (ret == 0)
+               {
+                       errno = ECONNRESET;
+                       return FALSE;
+               }
+               len -= ret;
+               buf += ret;
+       }
+       return TRUE;
+}
+
 METHOD(stream_t, write_, ssize_t,
        private_stream_t *this, void *buf, size_t len, bool block)
 {
@@ -119,15 +135,27 @@ METHOD(stream_t, write_, ssize_t,
        }
 }
 
-/**
- * Remove a registered watcher
- */
-static void remove_watcher(private_stream_t *this)
+METHOD(stream_t, write_all, bool,
+       private_stream_t *this, void *buf, size_t len)
 {
-       if (this->read_cb || this->write_cb)
+       ssize_t ret;
+
+       while (len)
        {
-               lib->watcher->remove(lib->watcher, this->fd);
+               ret = write_(this, buf, len, TRUE);
+               if (ret < 0)
+               {
+                       return FALSE;
+               }
+               if (ret == 0)
+               {
+                       errno = ECONNRESET;
+                       return FALSE;
+               }
+               len -= ret;
+               buf += ret;
        }
+       return TRUE;
 }
 
 /**
@@ -136,21 +164,26 @@ static void remove_watcher(private_stream_t *this)
 static bool watch(private_stream_t *this, int fd, watcher_event_t event)
 {
        bool keep = FALSE;
+       stream_cb_t cb;
 
        switch (event)
        {
                case WATCHER_READ:
-                       keep = this->read_cb(this->read_data, &this->public);
-                       if (!keep)
+                       cb = this->read_cb;
+                       this->read_cb = NULL;
+                       keep = cb(this->read_data, &this->public);
+                       if (keep)
                        {
-                               this->read_cb = NULL;
+                               this->read_cb = cb;
                        }
                        break;
                case WATCHER_WRITE:
-                       keep = this->write_cb(this->write_data, &this->public);
-                       if (!keep)
+                       cb = this->write_cb;
+                       this->write_cb = NULL;
+                       keep = cb(this->write_data, &this->public);
+                       if (keep)
                        {
-                               this->write_cb = NULL;
+                               this->write_cb = cb;
                        }
                        break;
                case WATCHER_EXCEPT:
@@ -184,7 +217,7 @@ static void add_watcher(private_stream_t *this)
 METHOD(stream_t, on_read, void,
        private_stream_t *this, stream_cb_t cb, void *data)
 {
-       remove_watcher(this);
+       lib->watcher->remove(lib->watcher, this->fd);
 
        this->read_cb = cb;
        this->read_data = data;
@@ -195,7 +228,7 @@ METHOD(stream_t, on_read, void,
 METHOD(stream_t, on_write, void,
        private_stream_t *this, stream_cb_t cb, void *data)
 {
-       remove_watcher(this);
+       lib->watcher->remove(lib->watcher, this->fd);
 
        this->write_cb = cb;
        this->write_data = data;
@@ -203,45 +236,31 @@ METHOD(stream_t, on_write, void,
        add_watcher(this);
 }
 
-METHOD(stream_t, vprint, int,
-       private_stream_t *this, char *format, va_list ap)
+METHOD(stream_t, get_file, FILE*,
+       private_stream_t *this)
 {
-       if (!this->file)
+       FILE *file;
+       int fd;
+
+       /* fclose() closes the FD passed to fdopen(), so dup() it */
+       fd = dup(this->fd);
+       if (fd == -1)
        {
-               this->file = fdopen(this->fd, "w+");
-               if (!this->file)
-               {
-                       return -1;
-               }
+               return NULL;
        }
-       return vfprintf(this->file, format, ap);
-}
-
-METHOD(stream_t, print, int,
-       private_stream_t *this, char *format, ...)
-{
-       va_list ap;
-       int ret;
-
-       va_start(ap, format);
-       ret = vprint(this, format, ap);
-       va_end(ap);
-
-       return ret;
+       file = fdopen(fd, "w+");
+       if (!file)
+       {
+               close(fd);
+       }
+       return file;
 }
 
 METHOD(stream_t, destroy, void,
        private_stream_t *this)
 {
-       remove_watcher(this);
-       if (this->file)
-       {
-               fclose(this->file);
-       }
-       else
-       {
-               close(this->fd);
-       }
+       lib->watcher->remove(lib->watcher, this->fd);
+       close(this->fd);
        free(this);
 }
 
@@ -255,11 +274,12 @@ stream_t *stream_create_from_fd(int fd)
        INIT(this,
                .public = {
                        .read = _read_,
+                       .read_all = _read_all,
                        .on_read = _on_read,
                        .write = _write_,
+                       .write_all = _write_all,
                        .on_write = _on_write,
-                       .print = _print,
-                       .vprint = _vprint,
+                       .get_file = _get_file,
                        .destroy = _destroy,
                },
                .fd = fd,
@@ -282,6 +302,7 @@ int stream_parse_uri_unix(char *uri, struct sockaddr_un *addr)
        memset(addr, 0, sizeof(*addr));
        addr->sun_family = AF_UNIX;
        strncpy(addr->sun_path, uri, sizeof(addr->sun_path));
+       addr->sun_path[sizeof(addr->sun_path)-1] = '\0';
 
        return offsetof(struct sockaddr_un, sun_path) + strlen(addr->sun_path);
 }