kernel-netlink: Support parallel Netlink queries
[strongswan.git] / src / libhydra / plugins / kernel_netlink / kernel_netlink_shared.c
index b4cece7..6e1dd8c 100644 (file)
@@ -1,4 +1,6 @@
 /*
+ * Copyright (C) 2014 Martin Willi
+ * Copyright (C) 2014 revosec AG
  * Copyright (C) 2008 Tobias Brunner
  * Hochschule fuer Technik Rapperswil
  *
@@ -23,6 +25,9 @@
 
 #include <utils/debug.h>
 #include <threading/mutex.h>
+#include <threading/condvar.h>
+#include <collections/array.h>
+#include <collections/hashtable.h>
 
 typedef struct private_netlink_socket_t private_netlink_socket_t;
 
@@ -30,20 +35,26 @@ typedef struct private_netlink_socket_t private_netlink_socket_t;
  * Private variables and functions of netlink_socket_t class.
  */
 struct private_netlink_socket_t {
+
        /**
         * public part of the netlink_socket_t object.
         */
        netlink_socket_t public;
 
        /**
-        * mutex to lock access to netlink socket
+        * mutex to lock access to entries
         */
        mutex_t *mutex;
 
        /**
-        * current sequence number for netlink request
+        * Netlink request entries currently active, uintptr_t seq => entry_t
+        */
+       hashtable_t *entries;
+
+       /**
+        * Current sequence number for Netlink requests
         */
-       int seq;
+       refcount_t seq;
 
        /**
         * netlink socket
@@ -57,110 +68,212 @@ struct private_netlink_socket_t {
 };
 
 /**
- * Imported from kernel_netlink_ipsec.c
+ * Request entry in which the answer for a waiting thread is collected
  */
-extern enum_name_t *xfrm_msg_names;
+typedef struct {
+       /** Condition variable the thread is waiting on */
+       condvar_t *condvar;
+       /** Array of hdrs in a multi-message response, as struct nlmsghdr* */
+       array_t *hdrs;
+       /** All response messages received? */
+       bool complete;
+} entry_t;
 
-METHOD(netlink_socket_t, netlink_send, status_t,
-       private_netlink_socket_t *this, struct nlmsghdr *in, struct nlmsghdr **out,
-       size_t *out_len)
+/**
+ * Clean up the entry of a waiting thread
+ */
+static void destroy_entry(entry_t *entry)
 {
-       union {
-               struct nlmsghdr hdr;
-               u_char bytes[4096];
-       } response;
-       struct sockaddr_nl addr;
-       chunk_t result = chunk_empty;
-       int len;
-
-       this->mutex->lock(this->mutex);
-
-       in->nlmsg_seq = ++this->seq;
-       in->nlmsg_pid = getpid();
+       entry->condvar->destroy(entry->condvar);
+       array_destroy_function(entry->hdrs, (void*)free, NULL);
+       free(entry);
+}
 
-       memset(&addr, 0, sizeof(addr));
-       addr.nl_family = AF_NETLINK;
-       addr.nl_pid = 0;
-       addr.nl_groups = 0;
+/**
+ * Write a Netlink message to socket
+ */
+static bool write_msg(private_netlink_socket_t *this, struct nlmsghdr *msg)
+{
+       struct sockaddr_nl addr = {
+               .nl_family = AF_NETLINK,
+       };
+       int len;
 
-       if (this->names)
-       {
-               DBG3(DBG_KNL, "sending %N: %b",
-                        this->names, in->nlmsg_type, in, in->nlmsg_len);
-       }
        while (TRUE)
        {
-               len = sendto(this->socket, in, in->nlmsg_len, 0,
+               len = sendto(this->socket, msg, msg->nlmsg_len, 0,
                                         (struct sockaddr*)&addr, sizeof(addr));
-
-               if (len != in->nlmsg_len)
+               if (len != msg->nlmsg_len)
                {
                        if (errno == EINTR)
                        {
-                               /* interrupted, try again */
                                continue;
                        }
-                       this->mutex->unlock(this->mutex);
-                       DBG1(DBG_KNL, "error sending to netlink socket: %s", strerror(errno));
-                       return FAILED;
+                       DBG1(DBG_KNL, "netlink write error: %s", strerror(errno));
+                       return FALSE;
                }
-               break;
+               return TRUE;
        }
+}
 
-       while (TRUE)
+/**
+ * Read from the Netlink socket once; the data may hold several messages
+ */
+static size_t read_msg(private_netlink_socket_t *this,
+                                          char buf[4096], size_t buflen, bool block)
+{
+       ssize_t len;
+
+       len = recv(this->socket, buf, buflen, block ? 0 : MSG_DONTWAIT);
+       if (len == buflen)
+       {
+               DBG1(DBG_KNL, "netlink response exceeds buffer size");
+               return 0;
+       }
+       if (len < 0)
        {
-               len = recv(this->socket, &response, sizeof(response), 0);
-               if (len < 0)
+               if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR)
                {
-                       if (errno == EINTR)
-                       {
-                               DBG1(DBG_KNL, "got interrupted");
-                               /* interrupted, try again */
-                               continue;
-                       }
-                       DBG1(DBG_KNL, "error reading from netlink socket: %s", strerror(errno));
-                       this->mutex->unlock(this->mutex);
-                       free(result.ptr);
-                       return FAILED;
+                       DBG1(DBG_KNL, "netlink read error: %s", strerror(errno));
                }
-               if (!NLMSG_OK(&response.hdr, len))
+               return 0;
+       }
+       return len;
+}
+
+/**
+ * Queue received response message
+ */
+static bool queue(private_netlink_socket_t *this, struct nlmsghdr *buf)
+{
+       struct nlmsghdr *hdr;
+       entry_t *entry;
+       uintptr_t seq;
+
+       seq = (uintptr_t)buf->nlmsg_seq;
+
+       this->mutex->lock(this->mutex);
+       entry = this->entries->get(this->entries, (void*)seq);
+       if (entry)
+       {
+               hdr = malloc(buf->nlmsg_len);
+               memcpy(hdr, buf, buf->nlmsg_len);
+               array_insert(entry->hdrs, ARRAY_TAIL, hdr);
+               if (hdr->nlmsg_type == NLMSG_DONE || !(hdr->nlmsg_flags & NLM_F_MULTI))
                {
-                       DBG1(DBG_KNL, "received corrupted netlink message");
-                       this->mutex->unlock(this->mutex);
-                       free(result.ptr);
-                       return FAILED;
+                       entry->complete = TRUE;
+                       entry->condvar->signal(entry->condvar);
                }
-               if (response.hdr.nlmsg_seq != this->seq)
+       }
+       else
+       {
+               DBG1(DBG_KNL, "received unknown netlink seq %u, ignored", seq);
+       }
+       this->mutex->unlock(this->mutex);
+
+       return entry != NULL;
+}
+
+/**
+ * Read and queue response message, optionally blocking
+ */
+static void read_and_queue(private_netlink_socket_t *this, bool block)
+{
+       struct nlmsghdr *hdr;
+       union {
+               struct nlmsghdr hdr;
+               char bytes[4096];
+       } buf;
+       size_t len;
+
+       len = read_msg(this, buf.bytes, sizeof(buf.bytes), block);
+       if (len)
+       {
+               hdr = &buf.hdr;
+               while (NLMSG_OK(hdr, len))
                {
-                       DBG1(DBG_KNL, "received invalid netlink sequence number");
-                       if (response.hdr.nlmsg_seq < this->seq)
+                       if (!queue(this, hdr))
                        {
-                               continue;
+                               break;
                        }
-                       this->mutex->unlock(this->mutex);
-                       free(result.ptr);
-                       return FAILED;
+                       hdr = NLMSG_NEXT(hdr, len);
+               }
+       }
+}
+
+CALLBACK(watch, bool,
+       private_netlink_socket_t *this, int fd, watcher_event_t event)
+{
+       if (event == WATCHER_READ)
+       {
+               read_and_queue(this, FALSE);
+       }
+       return TRUE;
+}
+
+METHOD(netlink_socket_t, netlink_send, status_t,
+       private_netlink_socket_t *this, struct nlmsghdr *in, struct nlmsghdr **out,
+       size_t *out_len)
+{
+       struct nlmsghdr *hdr;
+       chunk_t result = {};
+       entry_t *entry;
+       uintptr_t seq;
+
+       seq = ref_get(&this->seq);
+       in->nlmsg_seq = seq;
+       in->nlmsg_pid = getpid();
+
+       if (this->names)
+       {
+               DBG3(DBG_KNL, "sending %N %u: %b", this->names, in->nlmsg_type,
+                        (u_int)seq, in, in->nlmsg_len);
+       }
+
+       this->mutex->lock(this->mutex);
+       if (!write_msg(this, in))
+       {
+               this->mutex->unlock(this->mutex);
+               return FAILED;
+       }
+
+       INIT(entry,
+               .condvar = condvar_create(CONDVAR_TYPE_DEFAULT),
+               .hdrs = array_create(0, 0),
+       );
+       this->entries->put(this->entries, (void*)seq, entry);
+
+       while (!entry->complete)
+       {
+               if (lib->watcher->get_state(lib->watcher) == WATCHER_RUNNING)
+               {
+                       entry->condvar->wait(entry->condvar, this->mutex);
+               }
+               else
+               {       /* During (de-)initialization, no watcher thread is active.
+                        * Collect responses ourselves. */
+                       read_and_queue(this, TRUE);
                }
+       }
+       this->entries->remove(this->entries, (void*)seq);
 
-               result = chunk_cat("mc", result, chunk_create(response.bytes, len));
+       this->mutex->unlock(this->mutex);
 
-               /* NLM_F_MULTI flag does not seem to be set correctly, we use sequence
-                * numbers to detect multi header messages */
-               len = recv(this->socket, &response.hdr, sizeof(response.hdr),
-                                  MSG_PEEK | MSG_DONTWAIT);
-               if (len == sizeof(response.hdr) && response.hdr.nlmsg_seq == this->seq)
+       while (array_remove(entry->hdrs, ARRAY_HEAD, &hdr))
+       {
+               if (this->names)
                {
-                       /* seems to be multipart */
-                       continue;
+                       DBG3(DBG_KNL, "received %N %u: %b", this->names, hdr->nlmsg_type,
+                                hdr->nlmsg_seq, hdr, hdr->nlmsg_len);
                }
-               break;
+               result = chunk_cat("mm", result,
+                                                  chunk_create((char*)hdr, hdr->nlmsg_len));
        }
+       destroy_entry(entry);
 
        *out_len = result.len;
        *out = (struct nlmsghdr*)result.ptr;
 
-       this->mutex->unlock(this->mutex);
-
        return SUCCESS;
 }
 
@@ -221,8 +334,10 @@ METHOD(netlink_socket_t, destroy, void,
 {
        if (this->socket != -1)
        {
+               lib->watcher->remove(lib->watcher, this->socket);
                close(this->socket);
        }
+       this->entries->destroy(this->entries);
        this->mutex->destroy(this->mutex);
        free(this);
 }
@@ -244,8 +359,9 @@ netlink_socket_t *netlink_socket_create(int protocol, enum_name_t *names)
                        .destroy = _destroy,
                },
                .seq = 200,
-               .mutex = mutex_create(MUTEX_TYPE_DEFAULT),
+               .mutex = mutex_create(MUTEX_TYPE_RECURSIVE),
                .socket = socket(AF_NETLINK, SOCK_RAW, protocol),
+               .entries = hashtable_create(hashtable_hash_ptr, hashtable_equals_ptr, 4),
                .names = names,
        );
 
@@ -262,6 +378,8 @@ netlink_socket_t *netlink_socket_create(int protocol, enum_name_t *names)
                return NULL;
        }
 
+       lib->watcher->add(lib->watcher, this->socket, WATCHER_READ, watch, this);
+
        return &this->public;
 }