kernel-netlink: Check return value of both halves when installing default route in...
[strongswan.git] / src / libtls / tls_socket.c
1 /*
2 * Copyright (C) 2010 Martin Willi
3 * Copyright (C) 2010 revosec AG
4 *
5 * This program is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License as published by the
7 * Free Software Foundation; either version 2 of the License, or (at your
8 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
9 *
10 * This program is distributed in the hope that it will be useful, but
11 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
12 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 * for more details.
14 */
15
16 #include "tls_socket.h"
17
18 #include <unistd.h>
19 #include <errno.h>
20
21 #include <utils/debug.h>
22 #include <threading/thread.h>
23
/**
 * Buffer size for plain side I/O: one maximum-size TLS plaintext fragment.
 */
#define PLAIN_BUF_SIZE (TLS_MAX_FRAGMENT_LEN)

/**
 * Buffer size for encrypted side I/O: one maximum-size plaintext fragment
 * plus 2048 bytes of headroom. The headroom presumably matches the TLS
 * ciphertext expansion allowance (RFC 5246 limits TLSCiphertext.length to
 * 2^14 + 2048) - confirm against TLS_MAX_FRAGMENT_LEN's definition.
 *
 * Fully parenthesized so the macro is safe inside larger expressions.
 */
#define CRYPTO_BUF_SIZE (TLS_MAX_FRAGMENT_LEN + 2048)
33
34 typedef struct private_tls_socket_t private_tls_socket_t;
35 typedef struct private_tls_application_t private_tls_application_t;
36
37 struct private_tls_application_t {
38
39 /**
40 * Implements tls_application layer.
41 */
42 tls_application_t application;
43
44 /**
45 * Output buffer to write to
46 */
47 chunk_t out;
48
49 /**
50 * Number of bytes written to out
51 */
52 size_t out_done;
53
54 /**
55 * Input buffer to read to
56 */
57 chunk_t in;
58
59 /**
60 * Number of bytes read to in
61 */
62 size_t in_done;
63
64 /**
65 * Cached input data
66 */
67 chunk_t cache;
68
69 /**
70 * Bytes consumed in cache
71 */
72 size_t cache_done;
73
74 /**
75 * Close TLS connection?
76 */
77 bool close;
78 };
79
80 /**
81 * Private data of an tls_socket_t object.
82 */
83 struct private_tls_socket_t {
84
85 /**
86 * Public tls_socket_t interface.
87 */
88 tls_socket_t public;
89
90 /**
91 * TLS application implementation
92 */
93 private_tls_application_t app;
94
95 /**
96 * TLS stack
97 */
98 tls_t *tls;
99
100 /**
101 * Underlying OS socket
102 */
103 int fd;
104 };
105
106 METHOD(tls_application_t, process, status_t,
107 private_tls_application_t *this, bio_reader_t *reader)
108 {
109 chunk_t data;
110 size_t len;
111
112 if (this->close)
113 {
114 return SUCCESS;
115 }
116 len = min(reader->remaining(reader), this->in.len - this->in_done);
117 if (len)
118 { /* copy to read buffer as much as fits in */
119 if (!reader->read_data(reader, len, &data))
120 {
121 return FAILED;
122 }
123 memcpy(this->in.ptr + this->in_done, data.ptr, data.len);
124 this->in_done += data.len;
125 }
126 else
127 { /* read buffer is full, cache for next read */
128 if (!reader->read_data(reader, reader->remaining(reader), &data))
129 {
130 return FAILED;
131 }
132 this->cache = chunk_cat("mc", this->cache, data);
133 }
134 return NEED_MORE;
135 }
136
137 METHOD(tls_application_t, build, status_t,
138 private_tls_application_t *this, bio_writer_t *writer)
139 {
140 if (this->close)
141 {
142 return SUCCESS;
143 }
144 if (this->out.len > this->out_done)
145 {
146 writer->write_data(writer, this->out);
147 this->out_done = this->out.len;
148 return NEED_MORE;
149 }
150 return INVALID_STATE;
151 }
152
153 /**
154 * TLS data exchange loop
155 */
156 static bool exchange(private_tls_socket_t *this, bool wr, bool block)
157 {
158 char buf[CRYPTO_BUF_SIZE], *pos;
159 ssize_t in, out;
160 size_t len;
161 int round = 0, flags;
162
163 for (round = 0; TRUE; round++)
164 {
165 while (TRUE)
166 {
167 len = sizeof(buf);
168 switch (this->tls->build(this->tls, buf, &len, NULL))
169 {
170 case NEED_MORE:
171 case ALREADY_DONE:
172 pos = buf;
173 while (len)
174 {
175 out = write(this->fd, pos, len);
176 if (out == -1)
177 {
178 DBG1(DBG_TLS, "TLS crypto write error: %s",
179 strerror(errno));
180 return FALSE;
181 }
182 len -= out;
183 pos += out;
184 }
185 continue;
186 case INVALID_STATE:
187 break;
188 case SUCCESS:
189 return TRUE;
190 default:
191 return FALSE;
192 }
193 break;
194 }
195 if (wr)
196 {
197 if (this->app.out_done == this->app.out.len)
198 { /* all data written */
199 return TRUE;
200 }
201 }
202 else
203 {
204 if (this->app.in_done == this->app.in.len)
205 { /* buffer fully received */
206 return TRUE;
207 }
208 }
209
210 flags = 0;
211 if (this->app.out_done == this->app.out.len)
212 {
213 if (!block || this->app.in_done)
214 {
215 flags |= MSG_DONTWAIT;
216 }
217 }
218 in = recv(this->fd, buf, sizeof(buf), flags);
219 if (in < 0)
220 {
221 if (errno == EAGAIN || errno == EWOULDBLOCK)
222 {
223 if (this->app.in_done == 0)
224 {
225 /* reading, nothing got yet, and call would block */
226 errno = EWOULDBLOCK;
227 this->app.in_done = -1;
228 }
229 return TRUE;
230 }
231 return FALSE;
232 }
233 if (in == 0)
234 { /* EOF */
235 return TRUE;
236 }
237 switch (this->tls->process(this->tls, buf, in))
238 {
239 case NEED_MORE:
240 break;
241 case SUCCESS:
242 return TRUE;
243 default:
244 return FALSE;
245 }
246 }
247 }
248
249 METHOD(tls_socket_t, read_, ssize_t,
250 private_tls_socket_t *this, void *buf, size_t len, bool block)
251 {
252 if (this->app.cache.len)
253 {
254 size_t cache;
255
256 cache = min(len, this->app.cache.len - this->app.cache_done);
257 memcpy(buf, this->app.cache.ptr + this->app.cache_done, cache);
258
259 this->app.cache_done += cache;
260 if (this->app.cache_done == this->app.cache.len)
261 {
262 chunk_free(&this->app.cache);
263 this->app.cache_done = 0;
264 }
265 return cache;
266 }
267 this->app.in.ptr = buf;
268 this->app.in.len = len;
269 this->app.in_done = 0;
270 if (exchange(this, FALSE, block))
271 {
272 return this->app.in_done;
273 }
274 return -1;
275 }
276
277 METHOD(tls_socket_t, write_, ssize_t,
278 private_tls_socket_t *this, void *buf, size_t len)
279 {
280 this->app.out.ptr = buf;
281 this->app.out.len = len;
282 this->app.out_done = 0;
283 if (exchange(this, TRUE, FALSE))
284 {
285 return this->app.out_done;
286 }
287 return -1;
288 }
289
290 METHOD(tls_socket_t, splice, bool,
291 private_tls_socket_t *this, int rfd, int wfd)
292 {
293 char buf[PLAIN_BUF_SIZE], *pos;
294 ssize_t in, out;
295 bool old, plain_eof = FALSE, crypto_eof = FALSE;
296 struct pollfd pfd[] = {
297 { .fd = this->fd, .events = POLLIN, },
298 { .fd = rfd, .events = POLLIN, },
299 };
300
301 while (!plain_eof && !crypto_eof)
302 {
303 old = thread_cancelability(TRUE);
304 in = poll(pfd, countof(pfd), -1);
305 thread_cancelability(old);
306 if (in == -1)
307 {
308 DBG1(DBG_TLS, "TLS select error: %s", strerror(errno));
309 return FALSE;
310 }
311 while (!plain_eof && pfd[0].revents & (POLLIN | POLLHUP | POLLNVAL))
312 {
313 in = read_(this, buf, sizeof(buf), FALSE);
314 switch (in)
315 {
316 case 0:
317 plain_eof = TRUE;
318 break;
319 case -1:
320 if (errno != EWOULDBLOCK)
321 {
322 DBG1(DBG_TLS, "TLS read error: %s", strerror(errno));
323 return FALSE;
324 }
325 break;
326 default:
327 pos = buf;
328 while (in)
329 {
330 out = write(wfd, pos, in);
331 if (out == -1)
332 {
333 DBG1(DBG_TLS, "TLS plain write error: %s",
334 strerror(errno));
335 return FALSE;
336 }
337 in -= out;
338 pos += out;
339 }
340 continue;
341 }
342 break;
343 }
344 if (!crypto_eof && pfd[1].revents & (POLLIN | POLLHUP | POLLNVAL))
345 {
346 in = read(rfd, buf, sizeof(buf));
347 switch (in)
348 {
349 case 0:
350 crypto_eof = TRUE;
351 break;
352 case -1:
353 DBG1(DBG_TLS, "TLS plain read error: %s", strerror(errno));
354 return FALSE;
355 default:
356 pos = buf;
357 while (in)
358 {
359 out = write_(this, pos, in);
360 if (out == -1)
361 {
362 DBG1(DBG_TLS, "TLS write error");
363 return FALSE;
364 }
365 in -= out;
366 pos += out;
367 }
368 break;
369 }
370 }
371 }
372 return TRUE;
373 }
374
375 METHOD(tls_socket_t, get_fd, int,
376 private_tls_socket_t *this)
377 {
378 return this->fd;
379 }
380
381 METHOD(tls_socket_t, get_server_id, identification_t*,
382 private_tls_socket_t *this)
383 {
384 return this->tls->get_server_id(this->tls);
385 }
386
387 METHOD(tls_socket_t, get_peer_id, identification_t*,
388 private_tls_socket_t *this)
389 {
390 return this->tls->get_peer_id(this->tls);
391 }
392
393 METHOD(tls_socket_t, destroy, void,
394 private_tls_socket_t *this)
395 {
396 /* send a TLS close notify if not done yet */
397 this->app.close = TRUE;
398 write_(this, NULL, 0);
399 free(this->app.cache.ptr);
400 this->tls->destroy(this->tls);
401 free(this);
402 }
403
404 /**
405 * See header
406 */
407 tls_socket_t *tls_socket_create(bool is_server, identification_t *server,
408 identification_t *peer, int fd, tls_cache_t *cache,
409 tls_version_t max_version, bool nullok)
410 {
411 private_tls_socket_t *this;
412 tls_purpose_t purpose;
413
414 INIT(this,
415 .public = {
416 .read = _read_,
417 .write = _write_,
418 .splice = _splice,
419 .get_fd = _get_fd,
420 .get_server_id = _get_server_id,
421 .get_peer_id = _get_peer_id,
422 .destroy = _destroy,
423 },
424 .app = {
425 .application = {
426 .build = _build,
427 .process = _process,
428 .destroy = (void*)nop,
429 },
430 },
431 .fd = fd,
432 );
433
434 if (nullok)
435 {
436 purpose = TLS_PURPOSE_GENERIC_NULLOK;
437 }
438 else
439 {
440 purpose = TLS_PURPOSE_GENERIC;
441 }
442
443 this->tls = tls_create(is_server, server, peer, purpose,
444 &this->app.application, cache);
445 if (!this->tls)
446 {
447 free(this);
448 return NULL;
449 }
450 this->tls->set_version(this->tls, max_version);
451
452 return &this->public;
453 }