kernel-netlink: Support parallel Netlink queries
[strongswan.git] / src / libhydra / plugins / kernel_netlink / kernel_netlink_shared.c
1 /*
2 * Copyright (C) 2014 Martin Willi
3 * Copyright (C) 2014 revosec AG
4 * Copyright (C) 2008 Tobias Brunner
5 * Hochschule fuer Technik Rapperswil
6 *
7 * This program is free software; you can redistribute it and/or modify it
8 * under the terms of the GNU General Public License as published by the
9 * Free Software Foundation; either version 2 of the License, or (at your
10 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
11 *
12 * This program is distributed in the hope that it will be useful, but
13 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
14 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15 * for more details.
16 */
17
18 #include <sys/socket.h>
19 #include <linux/netlink.h>
20 #include <linux/rtnetlink.h>
21 #include <errno.h>
22 #include <unistd.h>
23
24 #include "kernel_netlink_shared.h"
25
26 #include <utils/debug.h>
27 #include <threading/mutex.h>
28 #include <threading/condvar.h>
29 #include <collections/array.h>
30 #include <collections/hashtable.h>
31
32 typedef struct private_netlink_socket_t private_netlink_socket_t;
33
34 /**
35 * Private variables and functions of netlink_socket_t class.
36 */
37 struct private_netlink_socket_t {
38
39 /**
40 * public part of the netlink_socket_t object.
41 */
42 netlink_socket_t public;
43
44 /**
45 * mutex to lock access entries
46 */
47 mutex_t *mutex;
48
49 /**
50 * Netlink request entries currently active, uintptr_t seq => entry_t
51 */
52 hashtable_t *entries;
53
54 /**
55 * Current sequence number for Netlink requests
56 */
57 refcount_t seq;
58
59 /**
60 * netlink socket
61 */
62 int socket;
63
64 /**
65 * Enum names for Netlink messages
66 */
67 enum_name_t *names;
68 };
69
70 /**
71 * Request entry the answer for a waiting thread is collected in
72 */
73 typedef struct {
74 /** Condition variable thread is waiting */
75 condvar_t *condvar;
76 /** Array of hdrs in a multi-message response, as struct nlmsghdr* */
77 array_t *hdrs;
78 /** All response messages received? */
79 bool complete;
80 } entry_t;
81
82 /**
83 * Clean up a thread waiting entry
84 */
85 static void destroy_entry(entry_t *entry)
86 {
87 entry->condvar->destroy(entry->condvar);
88 array_destroy_function(entry->hdrs, (void*)free, NULL);
89 free(entry);
90 }
91
92 /**
93 * Write a Netlink message to socket
94 */
95 static bool write_msg(private_netlink_socket_t *this, struct nlmsghdr *msg)
96 {
97 struct sockaddr_nl addr = {
98 .nl_family = AF_NETLINK,
99 };
100 int len;
101
102 while (TRUE)
103 {
104 len = sendto(this->socket, msg, msg->nlmsg_len, 0,
105 (struct sockaddr*)&addr, sizeof(addr));
106 if (len != msg->nlmsg_len)
107 {
108 if (errno == EINTR)
109 {
110 continue;
111 }
112 DBG1(DBG_KNL, "netlink write error: %s", strerror(errno));
113 return FALSE;
114 }
115 return TRUE;
116 }
117 }
118
119 /**
120 * Read a single Netlink message from socket
121 */
122 static size_t read_msg(private_netlink_socket_t *this,
123 char buf[4096], size_t buflen, bool block)
124 {
125 ssize_t len;
126
127 len = recv(this->socket, buf, buflen, block ? 0 : MSG_DONTWAIT);
128 if (len == buflen)
129 {
130 DBG1(DBG_KNL, "netlink response exceeds buffer size");
131 return 0;
132 }
133 if (len < 0)
134 {
135 if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR)
136 {
137 DBG1(DBG_KNL, "netlink read error: %s", strerror(errno));
138 }
139 return 0;
140 }
141 return len;
142 }
143
144 /**
145 * Queue received response message
146 */
147 static bool queue(private_netlink_socket_t *this, struct nlmsghdr *buf)
148 {
149 struct nlmsghdr *hdr;
150 entry_t *entry;
151 uintptr_t seq;
152
153 seq = (uintptr_t)buf->nlmsg_seq;
154
155 this->mutex->lock(this->mutex);
156 entry = this->entries->get(this->entries, (void*)seq);
157 if (entry)
158 {
159 hdr = malloc(buf->nlmsg_len);
160 memcpy(hdr, buf, buf->nlmsg_len);
161 array_insert(entry->hdrs, ARRAY_TAIL, hdr);
162 if (hdr->nlmsg_type == NLMSG_DONE || !(hdr->nlmsg_flags & NLM_F_MULTI))
163 {
164 entry->complete = TRUE;
165 entry->condvar->signal(entry->condvar);
166 }
167 }
168 else
169 {
170 DBG1(DBG_KNL, "received unknown netlink seq %u, ignored", seq);
171 }
172 this->mutex->unlock(this->mutex);
173
174 return entry != NULL;
175 }
176
177 /**
178 * Read and queue response message, optionally blocking
179 */
180 static void read_and_queue(private_netlink_socket_t *this, bool block)
181 {
182 struct nlmsghdr *hdr;
183 union {
184 struct nlmsghdr hdr;
185 char bytes[4096];
186 } buf;
187 size_t len;
188
189 len = read_msg(this, buf.bytes, sizeof(buf.bytes), block);
190 if (len)
191 {
192 hdr = &buf.hdr;
193 while (NLMSG_OK(hdr, len))
194 {
195 if (!queue(this, hdr))
196 {
197 break;
198 }
199 hdr = NLMSG_NEXT(hdr, len);
200 }
201 }
202 }
203
204 CALLBACK(watch, bool,
205 private_netlink_socket_t *this, int fd, watcher_event_t event)
206 {
207 if (event == WATCHER_READ)
208 {
209 read_and_queue(this, FALSE);
210 }
211 return TRUE;
212 }
213
214 METHOD(netlink_socket_t, netlink_send, status_t,
215 private_netlink_socket_t *this, struct nlmsghdr *in, struct nlmsghdr **out,
216 size_t *out_len)
217 {
218 struct nlmsghdr *hdr;
219 chunk_t result = {};
220 entry_t *entry;
221 uintptr_t seq;
222
223 seq = ref_get(&this->seq);
224 in->nlmsg_seq = seq;
225 in->nlmsg_pid = getpid();
226
227 if (this->names)
228 {
229 DBG3(DBG_KNL, "sending %N %u: %b", this->names, in->nlmsg_type,
230 (u_int)seq, in, in->nlmsg_len);
231 }
232
233 this->mutex->lock(this->mutex);
234 if (!write_msg(this, in))
235 {
236 this->mutex->unlock(this->mutex);
237 return FAILED;
238 }
239
240 INIT(entry,
241 .condvar = condvar_create(CONDVAR_TYPE_DEFAULT),
242 .hdrs = array_create(0, 0),
243 );
244 this->entries->put(this->entries, (void*)seq, entry);
245
246 while (!entry->complete)
247 {
248 if (lib->watcher->get_state(lib->watcher) == WATCHER_RUNNING)
249 {
250 entry->condvar->wait(entry->condvar, this->mutex);
251 }
252 else
253 { /* During (de-)initialization, no watcher thread is active.
254 * collect responses ourselves. */
255 read_and_queue(this, TRUE);
256 }
257 }
258 this->entries->remove(this->entries, (void*)seq);
259
260 this->mutex->unlock(this->mutex);
261
262 while (array_remove(entry->hdrs, ARRAY_HEAD, &hdr))
263 {
264 if (this->names)
265 {
266 DBG3(DBG_KNL, "received %N %u: %b", this->names, hdr->nlmsg_type,
267 hdr->nlmsg_seq, hdr, hdr->nlmsg_len);
268 }
269 result = chunk_cat("mm", result,
270 chunk_create((char*)hdr, hdr->nlmsg_len));
271 }
272 destroy_entry(entry);
273
274 *out_len = result.len;
275 *out = (struct nlmsghdr*)result.ptr;
276
277 return SUCCESS;
278 }
279
280 METHOD(netlink_socket_t, netlink_send_ack, status_t,
281 private_netlink_socket_t *this, struct nlmsghdr *in)
282 {
283 struct nlmsghdr *out, *hdr;
284 size_t len;
285
286 if (netlink_send(this, in, &out, &len) != SUCCESS)
287 {
288 return FAILED;
289 }
290 hdr = out;
291 while (NLMSG_OK(hdr, len))
292 {
293 switch (hdr->nlmsg_type)
294 {
295 case NLMSG_ERROR:
296 {
297 struct nlmsgerr* err = NLMSG_DATA(hdr);
298
299 if (err->error)
300 {
301 if (-err->error == EEXIST)
302 { /* do not report existing routes */
303 free(out);
304 return ALREADY_DONE;
305 }
306 if (-err->error == ESRCH)
307 { /* do not report missing entries */
308 free(out);
309 return NOT_FOUND;
310 }
311 DBG1(DBG_KNL, "received netlink error: %s (%d)",
312 strerror(-err->error), -err->error);
313 free(out);
314 return FAILED;
315 }
316 free(out);
317 return SUCCESS;
318 }
319 default:
320 hdr = NLMSG_NEXT(hdr, len);
321 continue;
322 case NLMSG_DONE:
323 break;
324 }
325 break;
326 }
327 DBG1(DBG_KNL, "netlink request not acknowledged");
328 free(out);
329 return FAILED;
330 }
331
332 METHOD(netlink_socket_t, destroy, void,
333 private_netlink_socket_t *this)
334 {
335 if (this->socket != -1)
336 {
337 lib->watcher->remove(lib->watcher, this->socket);
338 close(this->socket);
339 }
340 this->entries->destroy(this->entries);
341 this->mutex->destroy(this->mutex);
342 free(this);
343 }
344
345 /**
346 * Described in header.
347 */
348 netlink_socket_t *netlink_socket_create(int protocol, enum_name_t *names)
349 {
350 private_netlink_socket_t *this;
351 struct sockaddr_nl addr = {
352 .nl_family = AF_NETLINK,
353 };
354
355 INIT(this,
356 .public = {
357 .send = _netlink_send,
358 .send_ack = _netlink_send_ack,
359 .destroy = _destroy,
360 },
361 .seq = 200,
362 .mutex = mutex_create(MUTEX_TYPE_RECURSIVE),
363 .socket = socket(AF_NETLINK, SOCK_RAW, protocol),
364 .entries = hashtable_create(hashtable_hash_ptr, hashtable_equals_ptr, 4),
365 .names = names,
366 );
367
368 if (this->socket == -1)
369 {
370 DBG1(DBG_KNL, "unable to create netlink socket");
371 destroy(this);
372 return NULL;
373 }
374 if (bind(this->socket, (struct sockaddr*)&addr, sizeof(addr)))
375 {
376 DBG1(DBG_KNL, "unable to bind netlink socket");
377 destroy(this);
378 return NULL;
379 }
380
381 lib->watcher->add(lib->watcher, this->socket, WATCHER_READ, watch, this);
382
383 return &this->public;
384 }
385
386 /**
387 * Described in header.
388 */
389 void netlink_add_attribute(struct nlmsghdr *hdr, int rta_type, chunk_t data,
390 size_t buflen)
391 {
392 struct rtattr *rta;
393
394 if (NLMSG_ALIGN(hdr->nlmsg_len) + RTA_LENGTH(data.len) > buflen)
395 {
396 DBG1(DBG_KNL, "unable to add attribute, buffer too small");
397 return;
398 }
399
400 rta = (struct rtattr*)(((char*)hdr) + NLMSG_ALIGN(hdr->nlmsg_len));
401 rta->rta_type = rta_type;
402 rta->rta_len = RTA_LENGTH(data.len);
403 memcpy(RTA_DATA(rta), data.ptr, data.len);
404 hdr->nlmsg_len = NLMSG_ALIGN(hdr->nlmsg_len) + rta->rta_len;
405 }
406
407 /**
408 * Described in header.
409 */
410 void* netlink_reserve(struct nlmsghdr *hdr, int buflen, int type, int len)
411 {
412 struct rtattr *rta;
413
414 if (NLMSG_ALIGN(hdr->nlmsg_len) + RTA_LENGTH(len) > buflen)
415 {
416 DBG1(DBG_KNL, "unable to add attribute, buffer too small");
417 return NULL;
418 }
419
420 rta = ((void*)hdr) + NLMSG_ALIGN(hdr->nlmsg_len);
421 rta->rta_type = type;
422 rta->rta_len = RTA_LENGTH(len);
423 hdr->nlmsg_len = NLMSG_ALIGN(hdr->nlmsg_len) + rta->rta_len;
424
425 return RTA_DATA(rta);
426 }