kernel-netlink: Support parallel Netlink queries
authorMartin Willi <martin@revosec.ch>
Thu, 10 Jul 2014 14:28:44 +0000 (16:28 +0200)
committerMartin Willi <martin@revosec.ch>
Fri, 21 Nov 2014 09:55:45 +0000 (10:55 +0100)
Instead of locking the socket exclusively to wait for replies, use watcher
to wait for and read in responses asynchronously. This allows multiple parallel
Netlink queries, which can significantly improve performance if the kernel
Netlink layer has longer latencies and supports parallel queries.

For vanilla Linux, parallel queries don't make much sense, as it usually returns
EBUSY for the relevant dump requests. This requires a retry, and in the end
makes queries more expensive under high load.

Instead of checking the Netlink message sequence number to detect multi-part
messages, this code now relies on the NLM_F_MULTI flag to detect them. This
has previously been avoided (by 1d51abb7). It is unclear if the flag did not
work correctly on very old Linux kernels, or if the flag was not used
appropriately by strongSwan. The flag seems to work just fine back to 2.6.18,
which is a kernel still in use by RedHat/CentOS 5.

src/libhydra/plugins/kernel_netlink/kernel_netlink_shared.c

index b4cece7..6e1dd8c 100644 (file)
@@ -1,4 +1,6 @@
 /*
+ * Copyright (C) 2014 Martin Willi
+ * Copyright (C) 2014 revosec AG
  * Copyright (C) 2008 Tobias Brunner
  * Hochschule fuer Technik Rapperswil
  *
@@ -23,6 +25,9 @@
 
 #include <utils/debug.h>
 #include <threading/mutex.h>
+#include <threading/condvar.h>
+#include <collections/array.h>
+#include <collections/hashtable.h>
 
 typedef struct private_netlink_socket_t private_netlink_socket_t;
 
@@ -30,20 +35,26 @@ typedef struct private_netlink_socket_t private_netlink_socket_t;
  * Private variables and functions of netlink_socket_t class.
  */
 struct private_netlink_socket_t {
+
        /**
         * public part of the netlink_socket_t object.
         */
        netlink_socket_t public;
 
        /**
-        * mutex to lock access to netlink socket
+        * mutex to lock access entries
         */
        mutex_t *mutex;
 
        /**
-        * current sequence number for netlink request
+        * Netlink request entries currently active, uintptr_t seq => entry_t
+        */
+       hashtable_t *entries;
+
+       /**
+        * Current sequence number for Netlink requests
         */
-       int seq;
+       refcount_t seq;
 
        /**
         * netlink socket
@@ -57,110 +68,212 @@ struct private_netlink_socket_t {
 };
 
 /**
- * Imported from kernel_netlink_ipsec.c
+ * Request entry the answer for a waiting thread is collected in
  */
-extern enum_name_t *xfrm_msg_names;
+typedef struct {
+       /** Condition variable thread is waiting */
+       condvar_t *condvar;
+       /** Array of hdrs in a multi-message response, as struct nlmsghdr* */
+       array_t *hdrs;
+       /** All response messages received? */
+       bool complete;
+} entry_t;
 
-METHOD(netlink_socket_t, netlink_send, status_t,
-       private_netlink_socket_t *this, struct nlmsghdr *in, struct nlmsghdr **out,
-       size_t *out_len)
+/**
+ * Clean up a thread waiting entry
+ */
+static void destroy_entry(entry_t *entry)
 {
-       union {
-               struct nlmsghdr hdr;
-               u_char bytes[4096];
-       } response;
-       struct sockaddr_nl addr;
-       chunk_t result = chunk_empty;
-       int len;
-
-       this->mutex->lock(this->mutex);
-
-       in->nlmsg_seq = ++this->seq;
-       in->nlmsg_pid = getpid();
+       entry->condvar->destroy(entry->condvar);
+       array_destroy_function(entry->hdrs, (void*)free, NULL);
+       free(entry);
+}
 
-       memset(&addr, 0, sizeof(addr));
-       addr.nl_family = AF_NETLINK;
-       addr.nl_pid = 0;
-       addr.nl_groups = 0;
+/**
+ * Write a Netlink message to socket
+ */
+static bool write_msg(private_netlink_socket_t *this, struct nlmsghdr *msg)
+{
+       struct sockaddr_nl addr = {
+               .nl_family = AF_NETLINK,
+       };
+       int len;
 
-       if (this->names)
-       {
-               DBG3(DBG_KNL, "sending %N: %b",
-                        this->names, in->nlmsg_type, in, in->nlmsg_len);
-       }
        while (TRUE)
        {
-               len = sendto(this->socket, in, in->nlmsg_len, 0,
+               len = sendto(this->socket, msg, msg->nlmsg_len, 0,
                                         (struct sockaddr*)&addr, sizeof(addr));
-
-               if (len != in->nlmsg_len)
+               if (len != msg->nlmsg_len)
                {
                        if (errno == EINTR)
                        {
-                               /* interrupted, try again */
                                continue;
                        }
-                       this->mutex->unlock(this->mutex);
-                       DBG1(DBG_KNL, "error sending to netlink socket: %s", strerror(errno));
-                       return FAILED;
+                       DBG1(DBG_KNL, "netlink write error: %s", strerror(errno));
+                       return FALSE;
                }
-               break;
+               return TRUE;
        }
+}
 
-       while (TRUE)
+/**
+ * Read a single Netlink message from socket
+ */
+static size_t read_msg(private_netlink_socket_t *this,
+                                          char buf[4096], size_t buflen, bool block)
+{
+       ssize_t len;
+
+       len = recv(this->socket, buf, buflen, block ? 0 : MSG_DONTWAIT);
+       if (len == buflen)
+       {
+               DBG1(DBG_KNL, "netlink response exceeds buffer size");
+               return 0;
+       }
+       if (len < 0)
        {
-               len = recv(this->socket, &response, sizeof(response), 0);
-               if (len < 0)
+               if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR)
                {
-                       if (errno == EINTR)
-                       {
-                               DBG1(DBG_KNL, "got interrupted");
-                               /* interrupted, try again */
-                               continue;
-                       }
-                       DBG1(DBG_KNL, "error reading from netlink socket: %s", strerror(errno));
-                       this->mutex->unlock(this->mutex);
-                       free(result.ptr);
-                       return FAILED;
+                       DBG1(DBG_KNL, "netlink read error: %s", strerror(errno));
                }
-               if (!NLMSG_OK(&response.hdr, len))
+               return 0;
+       }
+       return len;
+}
+
+/**
+ * Queue received response message
+ */
+static bool queue(private_netlink_socket_t *this, struct nlmsghdr *buf)
+{
+       struct nlmsghdr *hdr;
+       entry_t *entry;
+       uintptr_t seq;
+
+       seq = (uintptr_t)buf->nlmsg_seq;
+
+       this->mutex->lock(this->mutex);
+       entry = this->entries->get(this->entries, (void*)seq);
+       if (entry)
+       {
+               hdr = malloc(buf->nlmsg_len);
+               memcpy(hdr, buf, buf->nlmsg_len);
+               array_insert(entry->hdrs, ARRAY_TAIL, hdr);
+               if (hdr->nlmsg_type == NLMSG_DONE || !(hdr->nlmsg_flags & NLM_F_MULTI))
                {
-                       DBG1(DBG_KNL, "received corrupted netlink message");
-                       this->mutex->unlock(this->mutex);
-                       free(result.ptr);
-                       return FAILED;
+                       entry->complete = TRUE;
+                       entry->condvar->signal(entry->condvar);
                }
-               if (response.hdr.nlmsg_seq != this->seq)
+       }
+       else
+       {
+               DBG1(DBG_KNL, "received unknown netlink seq %u, ignored", seq);
+       }
+       this->mutex->unlock(this->mutex);
+
+       return entry != NULL;
+}
+
+/**
+ * Read and queue response message, optionally blocking
+ */
+static void read_and_queue(private_netlink_socket_t *this, bool block)
+{
+       struct nlmsghdr *hdr;
+       union {
+               struct nlmsghdr hdr;
+               char bytes[4096];
+       } buf;
+       size_t len;
+
+       len = read_msg(this, buf.bytes, sizeof(buf.bytes), block);
+       if (len)
+       {
+               hdr = &buf.hdr;
+               while (NLMSG_OK(hdr, len))
                {
-                       DBG1(DBG_KNL, "received invalid netlink sequence number");
-                       if (response.hdr.nlmsg_seq < this->seq)
+                       if (!queue(this, hdr))
                        {
-                               continue;
+                               break;
                        }
-                       this->mutex->unlock(this->mutex);
-                       free(result.ptr);
-                       return FAILED;
+                       hdr = NLMSG_NEXT(hdr, len);
+               }
+       }
+}
+
+CALLBACK(watch, bool,
+       private_netlink_socket_t *this, int fd, watcher_event_t event)
+{
+       if (event == WATCHER_READ)
+       {
+               read_and_queue(this, FALSE);
+       }
+       return TRUE;
+}
+
+METHOD(netlink_socket_t, netlink_send, status_t,
+       private_netlink_socket_t *this, struct nlmsghdr *in, struct nlmsghdr **out,
+       size_t *out_len)
+{
+       struct nlmsghdr *hdr;
+       chunk_t result = {};
+       entry_t *entry;
+       uintptr_t seq;
+
+       seq = ref_get(&this->seq);
+       in->nlmsg_seq = seq;
+       in->nlmsg_pid = getpid();
+
+       if (this->names)
+       {
+               DBG3(DBG_KNL, "sending %N %u: %b", this->names, in->nlmsg_type,
+                        (u_int)seq, in, in->nlmsg_len);
+       }
+
+       this->mutex->lock(this->mutex);
+       if (!write_msg(this, in))
+       {
+               this->mutex->unlock(this->mutex);
+               return FAILED;
+       }
+
+       INIT(entry,
+               .condvar = condvar_create(CONDVAR_TYPE_DEFAULT),
+               .hdrs = array_create(0, 0),
+       );
+       this->entries->put(this->entries, (void*)seq, entry);
+
+       while (!entry->complete)
+       {
+               if (lib->watcher->get_state(lib->watcher) == WATCHER_RUNNING)
+               {
+                       entry->condvar->wait(entry->condvar, this->mutex);
+               }
+               else
+               {       /* During (de-)initialization, no watcher thread is active.
+                        * collect responses ourselves. */
+                       read_and_queue(this, TRUE);
                }
+       }
+       this->entries->remove(this->entries, (void*)seq);
 
-               result = chunk_cat("mc", result, chunk_create(response.bytes, len));
+       this->mutex->unlock(this->mutex);
 
-               /* NLM_F_MULTI flag does not seem to be set correctly, we use sequence
-                * numbers to detect multi header messages */
-               len = recv(this->socket, &response.hdr, sizeof(response.hdr),
-                                  MSG_PEEK | MSG_DONTWAIT);
-               if (len == sizeof(response.hdr) && response.hdr.nlmsg_seq == this->seq)
+       while (array_remove(entry->hdrs, ARRAY_HEAD, &hdr))
+       {
+               if (this->names)
                {
-                       /* seems to be multipart */
-                       continue;
+                       DBG3(DBG_KNL, "received %N %u: %b", this->names, hdr->nlmsg_type,
+                                hdr->nlmsg_seq, hdr, hdr->nlmsg_len);
                }
-               break;
+               result = chunk_cat("mm", result,
+                                                  chunk_create((char*)hdr, hdr->nlmsg_len));
        }
+       destroy_entry(entry);
 
        *out_len = result.len;
        *out = (struct nlmsghdr*)result.ptr;
 
-       this->mutex->unlock(this->mutex);
-
        return SUCCESS;
 }
 
@@ -221,8 +334,10 @@ METHOD(netlink_socket_t, destroy, void,
 {
        if (this->socket != -1)
        {
+               lib->watcher->remove(lib->watcher, this->socket);
                close(this->socket);
        }
+       this->entries->destroy(this->entries);
        this->mutex->destroy(this->mutex);
        free(this);
 }
@@ -244,8 +359,9 @@ netlink_socket_t *netlink_socket_create(int protocol, enum_name_t *names)
                        .destroy = _destroy,
                },
                .seq = 200,
-               .mutex = mutex_create(MUTEX_TYPE_DEFAULT),
+               .mutex = mutex_create(MUTEX_TYPE_RECURSIVE),
                .socket = socket(AF_NETLINK, SOCK_RAW, protocol),
+               .entries = hashtable_create(hashtable_hash_ptr, hashtable_equals_ptr, 4),
                .names = names,
        );
 
@@ -262,6 +378,8 @@ netlink_socket_t *netlink_socket_create(int protocol, enum_name_t *names)
                return NULL;
        }
 
+       lib->watcher->add(lib->watcher, this->socket, WATCHER_READ, watch, this);
+
        return &this->public;
 }