8ecb89fc96faef0ff16fefaeaacb3bb148d21af9
[strongswan.git] / src / libstrongswan / networking / streams / stream.c
1 /*
2 * Copyright (C) 2013 Martin Willi
3 * Copyright (C) 2013 revosec AG
4 *
5 * This program is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License as published by the
7 * Free Software Foundation; either version 2 of the License, or (at your
8 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
9 *
10 * This program is distributed in the hope that it will be useful, but
11 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
12 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 * for more details.
14 */
15
16 #include <library.h>
17 #include <errno.h>
18 #include <unistd.h>
19 #include <limits.h>
20
21 typedef struct private_stream_t private_stream_t;
22
23 /**
24 * Private data of an stream_t object.
25 */
26 struct private_stream_t {
27
28 /**
29 * Public stream_t interface.
30 */
31 stream_t public;
32
33 /**
34 * Underlying socket
35 */
36 int fd;
37
38 /**
39 * Callback if data is ready to read
40 */
41 stream_cb_t read_cb;
42
43 /**
44 * Data for read-ready callback
45 */
46 void *read_data;
47
48 /**
49 * Callback if write is non-blocking
50 */
51 stream_cb_t write_cb;
52
53 /**
54 * Data for write-ready callback
55 */
56 void *write_data;
57 };
58
59 METHOD(stream_t, read_, ssize_t,
60 private_stream_t *this, void *buf, size_t len, bool block)
61 {
62 while (TRUE)
63 {
64 ssize_t ret;
65
66 if (block)
67 {
68 ret = read(this->fd, buf, len);
69 }
70 else
71 {
72 ret = recv(this->fd, buf, len, MSG_DONTWAIT);
73 if (ret == -1 && errno == EAGAIN)
74 {
75 /* unify EGAIN and EWOULDBLOCK */
76 errno = EWOULDBLOCK;
77 }
78 }
79 if (ret == -1 && errno == EINTR)
80 { /* interrupted, try again */
81 continue;
82 }
83 return ret;
84 }
85 }
86
87 METHOD(stream_t, read_all, bool,
88 private_stream_t *this, void *buf, size_t len)
89 {
90 ssize_t ret;
91
92 while (len)
93 {
94 ret = read_(this, buf, len, TRUE);
95 if (ret < 0)
96 {
97 return FALSE;
98 }
99 if (ret == 0)
100 {
101 errno = ECONNRESET;
102 return FALSE;
103 }
104 len -= ret;
105 buf += ret;
106 }
107 return TRUE;
108 }
109
110 METHOD(stream_t, write_, ssize_t,
111 private_stream_t *this, void *buf, size_t len, bool block)
112 {
113 ssize_t ret;
114
115 while (TRUE)
116 {
117 if (block)
118 {
119 ret = write(this->fd, buf, len);
120 }
121 else
122 {
123 ret = send(this->fd, buf, len, MSG_DONTWAIT);
124 if (ret == -1 && errno == EAGAIN)
125 {
126 /* unify EGAIN and EWOULDBLOCK */
127 errno = EWOULDBLOCK;
128 }
129 }
130 if (ret == -1 && errno == EINTR)
131 { /* interrupted, try again */
132 continue;
133 }
134 return ret;
135 }
136 }
137
138 METHOD(stream_t, write_all, bool,
139 private_stream_t *this, void *buf, size_t len)
140 {
141 ssize_t ret;
142
143 while (len)
144 {
145 ret = write_(this, buf, len, TRUE);
146 if (ret < 0)
147 {
148 return FALSE;
149 }
150 if (ret == 0)
151 {
152 errno = ECONNRESET;
153 return FALSE;
154 }
155 len -= ret;
156 buf += ret;
157 }
158 return TRUE;
159 }
160
161 /**
162 * Remove a registered watcher
163 */
164 static void remove_watcher(private_stream_t *this)
165 {
166 if (this->read_cb || this->write_cb)
167 {
168 lib->watcher->remove(lib->watcher, this->fd);
169 }
170 }
171
172 /**
173 * Watcher callback
174 */
175 static bool watch(private_stream_t *this, int fd, watcher_event_t event)
176 {
177 bool keep = FALSE;
178 stream_cb_t cb;
179
180 switch (event)
181 {
182 case WATCHER_READ:
183 cb = this->read_cb;
184 this->read_cb = NULL;
185 keep = cb(this->read_data, &this->public);
186 if (keep)
187 {
188 this->read_cb = cb;
189 }
190 break;
191 case WATCHER_WRITE:
192 cb = this->write_cb;
193 this->write_cb = NULL;
194 keep = cb(this->write_data, &this->public);
195 if (keep)
196 {
197 this->write_cb = cb;
198 }
199 break;
200 case WATCHER_EXCEPT:
201 break;
202 }
203 return keep;
204 }
205
206 /**
207 * Register watcher for stream callbacks
208 */
209 static void add_watcher(private_stream_t *this)
210 {
211 watcher_event_t events = 0;
212
213 if (this->read_cb)
214 {
215 events |= WATCHER_READ;
216 }
217 if (this->write_cb)
218 {
219 events |= WATCHER_WRITE;
220 }
221 if (events)
222 {
223 lib->watcher->add(lib->watcher, this->fd, events,
224 (watcher_cb_t)watch, this);
225 }
226 }
227
228 METHOD(stream_t, on_read, void,
229 private_stream_t *this, stream_cb_t cb, void *data)
230 {
231 remove_watcher(this);
232
233 this->read_cb = cb;
234 this->read_data = data;
235
236 add_watcher(this);
237 }
238
239 METHOD(stream_t, on_write, void,
240 private_stream_t *this, stream_cb_t cb, void *data)
241 {
242 remove_watcher(this);
243
244 this->write_cb = cb;
245 this->write_data = data;
246
247 add_watcher(this);
248 }
249
250 METHOD(stream_t, get_file, FILE*,
251 private_stream_t *this)
252 {
253 FILE *file;
254 int fd;
255
256 /* fclose() closes the FD passed to fdopen(), so dup() it */
257 fd = dup(this->fd);
258 if (fd == -1)
259 {
260 return NULL;
261 }
262 file = fdopen(fd, "w+");
263 if (!file)
264 {
265 close(fd);
266 }
267 return file;
268 }
269
270 METHOD(stream_t, destroy, void,
271 private_stream_t *this)
272 {
273 remove_watcher(this);
274 close(this->fd);
275 free(this);
276 }
277
278 /**
279 * See header
280 */
281 stream_t *stream_create_from_fd(int fd)
282 {
283 private_stream_t *this;
284
285 INIT(this,
286 .public = {
287 .read = _read_,
288 .read_all = _read_all,
289 .on_read = _on_read,
290 .write = _write_,
291 .write_all = _write_all,
292 .on_write = _on_write,
293 .get_file = _get_file,
294 .destroy = _destroy,
295 },
296 .fd = fd,
297 );
298
299 return &this->public;
300 }
301
/**
 * See header
 *
 * Parse a "unix://<path>" URI into a UNIX socket address.
 *
 * On success addr is zeroed, filled with AF_UNIX and the NUL-terminated
 * path, and the address length (offset of sun_path plus path length) is
 * returned. Returns -1, leaving addr untouched, if the URI lacks the
 * "unix://" prefix or the path does not fit into sun_path.
 */
int stream_parse_uri_unix(char *uri, struct sockaddr_un *addr)
{
	size_t len;

	if (!strpfx(uri, "unix://"))
	{
		return -1;
	}
	uri += strlen("unix://");

	len = strlen(uri);
	if (len >= sizeof(addr->sun_path))
	{
		/* a truncated path would silently refer to a different socket */
		return -1;
	}

	/* zeroing the whole address also provides the path's terminator */
	memset(addr, 0, sizeof(*addr));
	addr->sun_family = AF_UNIX;
	memcpy(addr->sun_path, uri, len);

	return offsetof(struct sockaddr_un, sun_path) + len;
}
320
321 /**
322 * See header
323 */
324 stream_t *stream_create_unix(char *uri)
325 {
326 struct sockaddr_un addr;
327 int len, fd;
328
329 len = stream_parse_uri_unix(uri, &addr);
330 if (len == -1)
331 {
332 DBG1(DBG_NET, "invalid stream URI: '%s'", uri);
333 return NULL;
334 }
335 fd = socket(AF_UNIX, SOCK_STREAM, 0);
336 if (fd < 0)
337 {
338 DBG1(DBG_NET, "opening socket '%s' failed: %s", uri, strerror(errno));
339 return NULL;
340 }
341 if (connect(fd, (struct sockaddr*)&addr, len) < 0)
342 {
343 DBG1(DBG_NET, "connecting to '%s' failed: %s", uri, strerror(errno));
344 close(fd);
345 return NULL;
346 }
347 return stream_create_from_fd(fd);
348 }
349
350 /**
351 * See header.
352 */
353 int stream_parse_uri_tcp(char *uri, struct sockaddr *addr)
354 {
355 char *pos, buf[128];
356 host_t *host;
357 u_long port;
358 int len;
359
360 if (!strpfx(uri, "tcp://"))
361 {
362 return -1;
363 }
364 uri += strlen("tcp://");
365 pos = strrchr(uri, ':');
366 if (!pos)
367 {
368 return -1;
369 }
370 if (*uri == '[' && pos > uri && *(pos - 1) == ']')
371 {
372 /* IPv6 URI */
373 snprintf(buf, sizeof(buf), "%.*s", (int)(pos - uri - 2), uri + 1);
374 }
375 else
376 {
377 snprintf(buf, sizeof(buf), "%.*s", (int)(pos - uri), uri);
378 }
379 port = strtoul(pos + 1, &pos, 10);
380 if (port == ULONG_MAX || *pos || port > 65535)
381 {
382 return -1;
383 }
384 host = host_create_from_dns(buf, AF_UNSPEC, port);
385 if (!host)
386 {
387 return -1;
388 }
389 len = *host->get_sockaddr_len(host);
390 memcpy(addr, host->get_sockaddr(host), len);
391 host->destroy(host);
392 return len;
393 }
394
395 /**
396 * See header
397 */
398 stream_t *stream_create_tcp(char *uri)
399 {
400 union {
401 struct sockaddr_in in;
402 struct sockaddr_in6 in6;
403 struct sockaddr sa;
404 } addr;
405 int fd, len;
406
407 len = stream_parse_uri_tcp(uri, &addr.sa);
408 if (len == -1)
409 {
410 DBG1(DBG_NET, "invalid stream URI: '%s'", uri);
411 return NULL;
412 }
413 fd = socket(addr.sa.sa_family, SOCK_STREAM, 0);
414 if (fd < 0)
415 {
416 DBG1(DBG_NET, "opening socket '%s' failed: %s", uri, strerror(errno));
417 return NULL;
418 }
419 if (connect(fd, &addr.sa, len))
420 {
421 DBG1(DBG_NET, "connecting to '%s' failed: %s", uri, strerror(errno));
422 close(fd);
423 return NULL;
424 }
425 return stream_create_from_fd(fd);
426 }