stream: Make sure no watcher callback is active while changing stream callbacks
[strongswan.git] / src / libstrongswan / networking / streams / stream.c
index 9a4a3d3..f6fec0b 100644 (file)
@@ -54,8 +54,6 @@ struct private_stream_t {
         * Data for write-ready callback
         */
        void *write_data;
-
-
 };
 
 METHOD(stream_t, read_, ssize_t,
@@ -86,6 +84,29 @@ METHOD(stream_t, read_, ssize_t,
        }
 }
 
+METHOD(stream_t, read_all, bool,
+       private_stream_t *this, void *buf, size_t len)
+{
+       ssize_t ret;
+
+       while (len)
+       {
+               ret = read_(this, buf, len, TRUE);
+               if (ret < 0)
+               {
+                       return FALSE;
+               }
+               if (ret == 0)
+               {
+                       errno = ECONNRESET;
+                       return FALSE;
+               }
+               len -= ret;
+               buf += ret;
+       }
+       return TRUE;
+}
+
 METHOD(stream_t, write_, ssize_t,
        private_stream_t *this, void *buf, size_t len, bool block)
 {
@@ -114,15 +135,27 @@ METHOD(stream_t, write_, ssize_t,
        }
 }
 
-/**
- * Remove a registered watcher
- */
-static void remove_watcher(private_stream_t *this)
+METHOD(stream_t, write_all, bool,
+       private_stream_t *this, void *buf, size_t len)
 {
-       if (this->read_cb || this->write_cb)
+       ssize_t ret;
+
+       while (len)
        {
-               lib->watcher->remove(lib->watcher, this->fd);
+               ret = write_(this, buf, len, TRUE);
+               if (ret < 0)
+               {
+                       return FALSE;
+               }
+               if (ret == 0)
+               {
+                       errno = ECONNRESET;
+                       return FALSE;
+               }
+               len -= ret;
+               buf += ret;
        }
+       return TRUE;
 }
 
 /**
@@ -131,21 +164,26 @@ static void remove_watcher(private_stream_t *this)
 static bool watch(private_stream_t *this, int fd, watcher_event_t event)
 {
        bool keep = FALSE;
+       stream_cb_t cb;
 
        switch (event)
        {
                case WATCHER_READ:
-                       keep = this->read_cb(this->read_data, &this->public);
-                       if (!keep)
+                       cb = this->read_cb;
+                       this->read_cb = NULL;
+                       keep = cb(this->read_data, &this->public);
+                       if (keep)
                        {
-                               this->read_cb = NULL;
+                               this->read_cb = cb;
                        }
                        break;
                case WATCHER_WRITE:
-                       keep = this->write_cb(this->write_data, &this->public);
-                       if (!keep)
+                       cb = this->write_cb;
+                       this->write_cb = NULL;
+                       keep = cb(this->write_data, &this->public);
+                       if (keep)
                        {
-                               this->write_cb = NULL;
+                               this->write_cb = cb;
                        }
                        break;
                case WATCHER_EXCEPT:
@@ -179,7 +217,7 @@ static void add_watcher(private_stream_t *this)
 METHOD(stream_t, on_read, void,
        private_stream_t *this, stream_cb_t cb, void *data)
 {
-       remove_watcher(this);
+       lib->watcher->remove(lib->watcher, this->fd);
 
        this->read_cb = cb;
        this->read_data = data;
@@ -190,7 +228,7 @@ METHOD(stream_t, on_read, void,
 METHOD(stream_t, on_write, void,
        private_stream_t *this, stream_cb_t cb, void *data)
 {
-       remove_watcher(this);
+       lib->watcher->remove(lib->watcher, this->fd);
 
        this->write_cb = cb;
        this->write_data = data;
@@ -221,7 +259,7 @@ METHOD(stream_t, get_file, FILE*,
 METHOD(stream_t, destroy, void,
        private_stream_t *this)
 {
-       remove_watcher(this);
+       lib->watcher->remove(lib->watcher, this->fd);
        close(this->fd);
        free(this);
 }
@@ -236,8 +274,10 @@ stream_t *stream_create_from_fd(int fd)
        INIT(this,
                .public = {
                        .read = _read_,
+                       .read_all = _read_all,
                        .on_read = _on_read,
                        .write = _write_,
+                       .write_all = _write_all,
                        .on_write = _on_write,
                        .get_file = _get_file,
                        .destroy = _destroy,
@@ -262,6 +302,7 @@ int stream_parse_uri_unix(char *uri, struct sockaddr_un *addr)
        memset(addr, 0, sizeof(*addr));
        addr->sun_family = AF_UNIX;
        strncpy(addr->sun_path, uri, sizeof(addr->sun_path));
+       addr->sun_path[sizeof(addr->sun_path)-1] = '\0';
 
        return offsetof(struct sockaddr_un, sun_path) + strlen(addr->sun_path);
 }