vici: Refactor socket to clean up locking
authorMartin Willi <martin@revosec.ch>
Wed, 12 Feb 2014 16:55:38 +0000 (17:55 +0100)
committerMartin Willi <martin@revosec.ch>
Wed, 7 May 2014 12:13:36 +0000 (14:13 +0200)
Uses separate locks for socket read and write operations. While holding the
socket reader lock, a different thread can still claim the socket write lock.
This allows to asynchronously send event messages while holding the read
lock.

src/libcharon/plugins/vici/vici_socket.c

index b225198..fa44c70 100644 (file)
@@ -17,7 +17,7 @@
 
 #include <daemon.h>
 #include <threading/mutex.h>
-#include <threading/rwlock.h>
+#include <threading/condvar.h>
 #include <threading/thread.h>
 #include <collections/array.h>
 #include <collections/linked_list.h>
@@ -74,9 +74,9 @@ struct private_vici_socket_t {
        linked_list_t *connections;
 
        /**
-        * rwlock for client connection list
+        * mutex for client connections
         */
-       rwlock_t *lock;
+       mutex_t *mutex;
 };
 
 /**
@@ -85,9 +85,9 @@ struct private_vici_socket_t {
 typedef struct {
        /* reference to socket instance */
        private_vici_socket_t *this;
-       /** connection identifier to disconnect */
+       /** connection identifier of entry */
        u_int id;
-} entry_data_t;
+} entry_selector_t;
 
 /**
  * Partially processed message
@@ -109,16 +109,24 @@ typedef struct {
 typedef struct {
        /** reference to socket */
        private_vici_socket_t *this;
-       /** mutex to lock this entry in/out buffers */
-       mutex_t *mutex;
        /** associated stream */
        stream_t *stream;
        /** queued messages to send, as msg_buf_t pointers */
        array_t *out;
        /** input message buffer */
        msg_buf_t in;
+       /** queued input messages to process, as chunk_t */
+       array_t *queue;
+       /** do we have job processing input queue? */
+       bool has_processor;
        /** client connection identifier */
        u_int id;
+       /** any users reading over this connection? */
+       int readers;
+       /** any users writing over this connection? */
+       int writers;
+       /** condvar to wait for usage  */
+       condvar_t *cond;
 } entry_t;
 
 /**
@@ -128,59 +136,148 @@ CALLBACK(destroy_entry, void,
        entry_t *entry)
 {
        msg_buf_t *out;
+       chunk_t chunk;
 
        entry->stream->destroy(entry->stream);
-
        entry->this->disconnect(entry->this->user, entry->id);
+       entry->cond->destroy(entry->cond);
 
-       entry->mutex->destroy(entry->mutex);
        while (array_remove(entry->out, ARRAY_TAIL, &out))
        {
                chunk_clear(&out->buf);
                free(out);
        }
        array_destroy(entry->out);
+       while (array_remove(entry->queue, ARRAY_TAIL, &chunk))
+       {
+               chunk_clear(&chunk);
+       }
+       array_destroy(entry->queue);
        chunk_clear(&entry->in.buf);
        free(entry);
 }
 
 /**
- * Find/remove entry by id, requires proper locking
+ * Find entry by stream (if given) or id, claim use
  */
-static entry_t* find_entry(private_vici_socket_t *this, u_int id, bool remove)
+static entry_t* find_entry(private_vici_socket_t *this, stream_t *stream,
+                                                  u_int id, bool reader, bool writer)
 {
        enumerator_t *enumerator;
        entry_t *entry, *found = NULL;
+       bool candidate = TRUE;
 
-       enumerator = this->connections->create_enumerator(this->connections);
-       while (enumerator->enumerate(enumerator, &entry))
+       this->mutex->lock(this->mutex);
+       while (candidate && !found)
        {
-               if (entry->id == id)
+               candidate = FALSE;
+               enumerator = this->connections->create_enumerator(this->connections);
+               while (enumerator->enumerate(enumerator, &entry))
                {
-                       if (remove)
+                       if (stream)
                        {
-                               this->connections->remove_at(this->connections, enumerator);
+                               if (entry->stream != stream)
+                               {
+                                       continue;
+                               }
+                       }
+                       else
+                       {
+                               if (entry->id != id)
+                               {
+                                       continue;
+                               }
+                       }
+                       candidate = TRUE;
+
+                       if ((reader && entry->readers) ||
+                               (writer && entry->writers))
+                       {
+                               entry->cond->wait(entry->cond, this->mutex);
+                               break;
+                       }
+                       if (reader)
+                       {
+                               entry->readers++;
+                       }
+                       if (writer)
+                       {
+                               entry->writers++;
                        }
                        found = entry;
                        break;
                }
+               enumerator->destroy(enumerator);
        }
-       enumerator->destroy(enumerator);
+       this->mutex->unlock(this->mutex);
 
        return found;
 }
 
 /**
+ * Remove entry by id, claim use
+ */
+static entry_t* remove_entry(private_vici_socket_t *this, u_int id)
+{
+       enumerator_t *enumerator;
+       entry_t *entry, *found = NULL;
+       bool candidate = TRUE;
+
+       this->mutex->lock(this->mutex);
+       while (candidate && !found)
+       {
+               candidate = FALSE;
+               enumerator = this->connections->create_enumerator(this->connections);
+               while (enumerator->enumerate(enumerator, &entry))
+               {
+                       if (entry->id == id)
+                       {
+                               candidate = TRUE;
+                               if (entry->readers || entry->writers)
+                               {
+                                       entry->cond->wait(entry->cond, this->mutex);
+                                       break;
+                               }
+                               this->connections->remove_at(this->connections, enumerator);
+                               found = entry;
+                               break;
+                       }
+               }
+               enumerator->destroy(enumerator);
+       }
+       this->mutex->unlock(this->mutex);
+
+       return found;
+}
+
+/**
+ * Release a claimed entry
+ */
+static void put_entry(private_vici_socket_t *this, entry_t *entry,
+                                         bool reader, bool writer)
+{
+       this->mutex->lock(this->mutex);
+       if (reader)
+       {
+               entry->readers--;
+       }
+       if (writer)
+       {
+               entry->writers--;
+       }
+       entry->cond->signal(entry->cond);
+       this->mutex->unlock(this->mutex);
+}
+
+/**
  * Asynchronous callback to disconnect client
  */
 CALLBACK(disconnect_async, job_requeue_t,
-       entry_data_t *data)
+       entry_selector_t *sel)
 {
        entry_t *entry;
 
-       data->this->lock->write_lock(data->this->lock);
-       entry = find_entry(data->this, data->id, TRUE);
-       data->this->lock->unlock(data->this->lock);
+       entry = remove_entry(sel->this, sel->id);
        if (entry)
        {
                destroy_entry(entry);
@@ -193,15 +290,15 @@ CALLBACK(disconnect_async, job_requeue_t,
  */
 static void disconnect(private_vici_socket_t *this, u_int id)
 {
-       entry_data_t *data;
+       entry_selector_t *sel;
 
-       INIT(data,
+       INIT(sel,
                .this = this,
                .id = id,
        );
 
        lib->processor->queue_job(lib->processor,
-                       (job_t*)callback_job_create(disconnect_async, data, free, NULL));
+                       (job_t*)callback_job_create(disconnect_async, sel, free, NULL));
 }
 
 /**
@@ -271,22 +368,26 @@ static bool do_write(private_vici_socket_t *this, entry_t *entry,
  * Send pending messages
  */
 CALLBACK(on_write, bool,
-       entry_t *entry, stream_t *stream)
+       private_vici_socket_t *this, stream_t *stream)
 {
-       bool ret;
+       entry_t *entry;
+       bool ret = FALSE;
 
-       entry->mutex->lock(entry->mutex);
-       ret = do_write(entry->this, entry, stream);
-       if (ret)
-       {
-               /* unregister if we have no more messages to send */
-               ret = array_count(entry->out) != 0;
-       }
-       else
+       entry = find_entry(this, stream, 0, FALSE, TRUE);
+       if (entry)
        {
-               disconnect(entry->this, entry->id);
+               ret = do_write(this, entry, stream);
+               if (ret)
+               {
+                       /* unregister if we have no more messages to send */
+                       ret = array_count(entry->out) != 0;
+               }
+               else
+               {
+                       disconnect(entry->this, entry->id);
+               }
+               put_entry(this, entry, FALSE, TRUE);
        }
-       entry->mutex->unlock(entry->mutex);
 
        return ret;
 }
@@ -351,33 +452,80 @@ static bool do_read(private_vici_socket_t *this, entry_t *entry,
 }
 
 /**
- * Process incoming messages
+ * Callback processing incoming requestes in strict order
  */
-CALLBACK(on_read, bool,
-       entry_t *entry, stream_t *stream)
+CALLBACK(process_queue, job_requeue_t,
+       entry_selector_t *sel)
 {
-       chunk_t data = chunk_empty;
-       bool ret;
+       entry_t *entry;
+       chunk_t chunk;
+       bool found;
+       u_int id;
 
-       entry->mutex->lock(entry->mutex);
-       ret = do_read(entry->this, entry, stream);
-       if (!ret)
+       while (TRUE)
        {
-               disconnect(entry->this, entry->id);
-       }
-       if (entry->in.buf.len == entry->in.done)
-       {
-               data = entry->in.buf;
-               entry->in.buf = chunk_empty;
-               entry->in.hdrlen = entry->in.done = 0;
+               entry = find_entry(sel->this, NULL, sel->id, TRUE, FALSE);
+               if (!entry)
+               {
+                       break;
+               }
+
+               found = array_remove(entry->queue, ARRAY_HEAD, &chunk);
+               if (!found)
+               {
+                       entry->has_processor = FALSE;
+               }
+               id = entry->id;
+               put_entry(sel->this, entry, TRUE, FALSE);
+               if (!found)
+               {
+                       break;
+               }
+
+               thread_cleanup_push(free, chunk.ptr);
+               sel->this->inbound(sel->this->user, id, chunk);
+               thread_cleanup_pop(TRUE);
        }
-       entry->mutex->unlock(entry->mutex);
+       return JOB_REQUEUE_NONE;
+}
 
-       if (data.len)
+/**
+ * Process incoming messages
+ */
+CALLBACK(on_read, bool,
+       private_vici_socket_t *this, stream_t *stream)
+{
+       entry_selector_t *sel;
+       entry_t *entry;
+       bool ret = FALSE;
+
+       entry = find_entry(this, stream, 0, TRUE, FALSE);
+       if (entry)
        {
-               thread_cleanup_push(free, data.ptr);
-               entry->this->inbound(entry->this->user, entry->id, data);
-               thread_cleanup_pop(TRUE);
+               ret = do_read(this, entry, stream);
+               if (!ret)
+               {
+                       disconnect(this, entry->id);
+               }
+               else if (entry->in.buf.len == entry->in.done)
+               {
+                       array_insert(entry->queue, ARRAY_TAIL, &entry->in.buf);
+                       entry->in.buf = chunk_empty;
+                       entry->in.hdrlen = entry->in.done = 0;
+
+                       if (!entry->has_processor)
+                       {
+                               INIT(sel,
+                                       .this = this,
+                                       .id = entry->id,
+                               );
+                               lib->processor->queue_job(lib->processor,
+                                                       (job_t*)callback_job_create(process_queue,
+                                                                                                               sel, free, NULL));
+                               entry->has_processor = TRUE;
+                       }
+               }
+               put_entry(this, entry, TRUE, FALSE);
        }
 
        return ret;
@@ -386,7 +534,8 @@ CALLBACK(on_read, bool,
 /**
  * Process connection request
  */
-static bool on_accept(private_vici_socket_t *this, stream_t *stream)
+CALLBACK(on_accept, bool,
+       private_vici_socket_t *this, stream_t *stream)
 {
        entry_t *entry;
        u_int id;
@@ -398,13 +547,18 @@ static bool on_accept(private_vici_socket_t *this, stream_t *stream)
                .stream = stream,
                .id = id,
                .out = array_create(0, 0),
-               .mutex = mutex_create(MUTEX_TYPE_RECURSIVE),
+               .queue = array_create(sizeof(chunk_t), 0),
+               .cond = condvar_create(CONDVAR_TYPE_DEFAULT),
+               .readers = 1,
        );
 
-       this->lock->write_lock(this->lock);
+       this->mutex->lock(this->mutex);
        this->connections->insert_last(this->connections, entry);
-       stream->on_read(stream, on_read, entry);
-       this->lock->unlock(this->lock);
+       this->mutex->unlock(this->mutex);
+
+       stream->on_read(stream, on_read, this);
+
+       put_entry(this, entry, TRUE, FALSE);
 
        this->connect(this->user, id);
 
@@ -412,22 +566,19 @@ static bool on_accept(private_vici_socket_t *this, stream_t *stream)
 }
 
 /**
- * Enable on_write callback to send data
+ * Async callback to enable writer
  */
-CALLBACK(on_write_async, job_requeue_t,
-       entry_data_t *data)
+CALLBACK(enable_writer, job_requeue_t,
+       entry_selector_t *sel)
 {
-       private_vici_socket_t *this = data->this;
        entry_t *entry;
 
-       this->lock->read_lock(this->lock);
-       entry = find_entry(this, data->id, FALSE);
+       entry = find_entry(sel->this, NULL, sel->id, FALSE, TRUE);
        if (entry)
        {
-               entry->stream->on_write(entry->stream, on_write, entry);
+               entry->stream->on_write(entry->stream, on_write, sel->this);
+               put_entry(sel->this, entry, FALSE, TRUE);
        }
-       this->lock->unlock(this->lock);
-
        return JOB_REQUEUE_NONE;
 }
 
@@ -436,12 +587,11 @@ METHOD(vici_socket_t, send_, void,
 {
        if (msg.len <= (u_int16_t)~0)
        {
-               entry_data_t *data;
+               entry_selector_t *sel;
                msg_buf_t *out;
                entry_t *entry;
 
-               this->lock->read_lock(this->lock);
-               entry = find_entry(this, id, FALSE);
+               entry = find_entry(this, NULL, id, FALSE, TRUE);
                if (entry)
                {
                        INIT(out,
@@ -449,28 +599,24 @@ METHOD(vici_socket_t, send_, void,
                        );
                        htoun16(out->hdr, msg.len);
 
-                       entry->mutex->lock(entry->mutex);
                        array_insert(entry->out, ARRAY_TAIL, out);
-                       entry->mutex->unlock(entry->mutex);
-
                        if (array_count(entry->out) == 1)
-                       {
-                               INIT(data,
+                       {       /* asynchronously re-enable on_write callback when we get data */
+                               INIT(sel,
                                        .this = this,
                                        .id = entry->id,
                                );
-                               /* asynchronously enable writing, as this might be called
-                                * from the on_read() callback. */
                                lib->processor->queue_job(lib->processor,
-                                                       (job_t*)callback_job_create(on_write_async,
-                                                                                                               data, free, NULL));
+                                                       (job_t*)callback_job_create(enable_writer,
+                                                                                                               sel, free, NULL));
                        }
+                       put_entry(this, entry, FALSE, TRUE);
                }
                else
                {
                        DBG1(DBG_CFG, "vici connection %u unknown", id);
+                       chunk_clear(&msg);
                }
-               this->lock->unlock(this->lock);
        }
        else
        {
@@ -484,7 +630,7 @@ METHOD(vici_socket_t, destroy, void,
 {
        DESTROY_IF(this->service);
        this->connections->destroy_function(this->connections, destroy_entry);
-       this->lock->destroy(this->lock);
+       this->mutex->destroy(this->mutex);
        free(this);
 }
 
@@ -502,7 +648,7 @@ vici_socket_t *vici_socket_create(char *uri, vici_inbound_cb_t inbound,
                        .send = _send_,
                        .destroy = _destroy,
                },
-               .lock = rwlock_create(RWLOCK_TYPE_DEFAULT),
+               .mutex = mutex_create(MUTEX_TYPE_DEFAULT),
                .connections = linked_list_create(),
                .inbound = inbound,
                .connect = connect,
@@ -517,8 +663,8 @@ vici_socket_t *vici_socket_create(char *uri, vici_inbound_cb_t inbound,
                destroy(this);
                return NULL;
        }
-       this->service->on_accept(this->service, (stream_service_cb_t)on_accept,
-                                                        this, JOB_PRIO_CRITICAL, 0);
+       this->service->on_accept(this->service, on_accept, this,
+                                                        JOB_PRIO_CRITICAL, 0);
 
        return &this->public;
 }