648771e7512b8a6aca3e189091ac34a9f2cca601
[strongswan.git] / src / libtls / tls_socket.c
1 /*
2 * Copyright (C) 2010 Martin Willi
3 * Copyright (C) 2010 revosec AG
4 *
5 * This program is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License as published by the
7 * Free Software Foundation; either version 2 of the License, or (at your
8 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
9 *
10 * This program is distributed in the hope that it will be useful, but
11 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
12 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 * for more details.
14 */
15
16 #include "tls_socket.h"
17
18 #include <unistd.h>
19 #include <errno.h>
20
21 #include <utils/debug.h>
22 #include <threading/thread.h>
23
/**
 * Buffer size for plain side I/O: one maximum sized TLS fragment.
 */
#define PLAIN_BUF_SIZE (TLS_MAX_FRAGMENT_LEN)

/**
 * Buffer size for encrypted side I/O: a maximum sized fragment plus 2048
 * bytes of headroom (presumably for per-record overhead; TODO confirm).
 * Fully parenthesized so the macro is safe inside larger expressions.
 */
#define CRYPTO_BUF_SIZE (TLS_MAX_FRAGMENT_LEN + 2048)

typedef struct private_tls_socket_t private_tls_socket_t;
typedef struct private_tls_application_t private_tls_application_t;
36
37 struct private_tls_application_t {
38
39 /**
40 * Implements tls_application layer.
41 */
42 tls_application_t application;
43
44 /**
45 * Output buffer to write to
46 */
47 chunk_t out;
48
49 /**
50 * Number of bytes written to out
51 */
52 size_t out_done;
53
54 /**
55 * Input buffer to read to
56 */
57 chunk_t in;
58
59 /**
60 * Number of bytes read to in
61 */
62 size_t in_done;
63
64 /**
65 * Cached input data
66 */
67 chunk_t cache;
68
69 /**
70 * Bytes consumed in cache
71 */
72 size_t cache_done;
73
74 /**
75 * Close TLS connection?
76 */
77 bool close;
78 };
79
80 /**
81 * Private data of an tls_socket_t object.
82 */
83 struct private_tls_socket_t {
84
85 /**
86 * Public tls_socket_t interface.
87 */
88 tls_socket_t public;
89
90 /**
91 * TLS application implementation
92 */
93 private_tls_application_t app;
94
95 /**
96 * TLS stack
97 */
98 tls_t *tls;
99
100 /**
101 * Underlying OS socket
102 */
103 int fd;
104 };
105
106 METHOD(tls_application_t, process, status_t,
107 private_tls_application_t *this, bio_reader_t *reader)
108 {
109 chunk_t data;
110 size_t len;
111
112 if (this->close)
113 {
114 return SUCCESS;
115 }
116 len = min(reader->remaining(reader), this->in.len - this->in_done);
117 if (len)
118 { /* copy to read buffer as much as fits in */
119 if (!reader->read_data(reader, len, &data))
120 {
121 return FAILED;
122 }
123 memcpy(this->in.ptr + this->in_done, data.ptr, data.len);
124 this->in_done += data.len;
125 }
126 else
127 { /* read buffer is full, cache for next read */
128 if (!reader->read_data(reader, reader->remaining(reader), &data))
129 {
130 return FAILED;
131 }
132 this->cache = chunk_cat("mc", this->cache, data);
133 }
134 return NEED_MORE;
135 }
136
137 METHOD(tls_application_t, build, status_t,
138 private_tls_application_t *this, bio_writer_t *writer)
139 {
140 if (this->close)
141 {
142 return SUCCESS;
143 }
144 if (this->out.len > this->out_done)
145 {
146 writer->write_data(writer, this->out);
147 this->out_done = this->out.len;
148 return NEED_MORE;
149 }
150 return INVALID_STATE;
151 }
152
153 /**
154 * TLS data exchange loop
155 */
156 static bool exchange(private_tls_socket_t *this, bool wr, bool block)
157 {
158 char buf[CRYPTO_BUF_SIZE], *pos;
159 ssize_t in, out;
160 size_t len;
161 int round = 0, flags;
162
163 for (round = 0; TRUE; round++)
164 {
165 while (TRUE)
166 {
167 len = sizeof(buf);
168 switch (this->tls->build(this->tls, buf, &len, NULL))
169 {
170 case NEED_MORE:
171 case ALREADY_DONE:
172 pos = buf;
173 while (len)
174 {
175 out = write(this->fd, pos, len);
176 if (out == -1)
177 {
178 DBG1(DBG_TLS, "TLS crypto write error: %s",
179 strerror(errno));
180 return FALSE;
181 }
182 len -= out;
183 pos += out;
184 }
185 continue;
186 case INVALID_STATE:
187 break;
188 case SUCCESS:
189 return TRUE;
190 default:
191 return FALSE;
192 }
193 break;
194 }
195 if (wr)
196 {
197 if (this->app.out_done == this->app.out.len)
198 { /* all data written */
199 return TRUE;
200 }
201 }
202 else
203 {
204 if (this->app.in_done == this->app.in.len)
205 { /* buffer fully received */
206 return TRUE;
207 }
208 }
209
210 flags = 0;
211 if (this->app.out_done == this->app.out.len)
212 {
213 if (!block || this->app.in_done)
214 {
215 flags |= MSG_DONTWAIT;
216 }
217 }
218 in = recv(this->fd, buf, sizeof(buf), flags);
219 if (in < 0)
220 {
221 if (errno == EAGAIN || errno == EWOULDBLOCK)
222 {
223 if (this->app.in_done == 0)
224 {
225 /* reading, nothing got yet, and call would block */
226 errno = EWOULDBLOCK;
227 this->app.in_done = -1;
228 }
229 return TRUE;
230 }
231 return FALSE;
232 }
233 if (in == 0)
234 { /* EOF */
235 return TRUE;
236 }
237 switch (this->tls->process(this->tls, buf, in))
238 {
239 case NEED_MORE:
240 break;
241 case SUCCESS:
242 return TRUE;
243 default:
244 return FALSE;
245 }
246 }
247 }
248
249 METHOD(tls_socket_t, read_, ssize_t,
250 private_tls_socket_t *this, void *buf, size_t len, bool block)
251 {
252 if (this->app.cache.len)
253 {
254 size_t cache;
255
256 cache = min(len, this->app.cache.len - this->app.cache_done);
257 memcpy(buf, this->app.cache.ptr + this->app.cache_done, cache);
258
259 this->app.cache_done += cache;
260 if (this->app.cache_done == this->app.cache.len)
261 {
262 chunk_free(&this->app.cache);
263 this->app.cache_done = 0;
264 }
265 return cache;
266 }
267 this->app.in.ptr = buf;
268 this->app.in.len = len;
269 this->app.in_done = 0;
270 if (exchange(this, FALSE, block))
271 {
272 return this->app.in_done;
273 }
274 return -1;
275 }
276
277 METHOD(tls_socket_t, write_, ssize_t,
278 private_tls_socket_t *this, void *buf, size_t len)
279 {
280 this->app.out.ptr = buf;
281 this->app.out.len = len;
282 this->app.out_done = 0;
283 if (exchange(this, TRUE, FALSE))
284 {
285 return this->app.out_done;
286 }
287 return -1;
288 }
289
290 METHOD(tls_socket_t, splice, bool,
291 private_tls_socket_t *this, int rfd, int wfd)
292 {
293 char buf[PLAIN_BUF_SIZE], *pos;
294 fd_set set;
295 ssize_t in, out;
296 bool old, plain_eof = FALSE, crypto_eof = FALSE;
297
298 while (!plain_eof && !crypto_eof)
299 {
300 FD_ZERO(&set);
301 FD_SET(rfd, &set);
302 FD_SET(this->fd, &set);
303
304 old = thread_cancelability(TRUE);
305 in = select(max(rfd, this->fd) + 1, &set, NULL, NULL, NULL);
306 thread_cancelability(old);
307 if (in == -1)
308 {
309 DBG1(DBG_TLS, "TLS select error: %s", strerror(errno));
310 return FALSE;
311 }
312 while (!plain_eof && FD_ISSET(this->fd, &set))
313 {
314 in = read_(this, buf, sizeof(buf), FALSE);
315 switch (in)
316 {
317 case 0:
318 plain_eof = TRUE;
319 break;
320 case -1:
321 if (errno != EWOULDBLOCK)
322 {
323 DBG1(DBG_TLS, "TLS read error: %s", strerror(errno));
324 return FALSE;
325 }
326 break;
327 default:
328 pos = buf;
329 while (in)
330 {
331 out = write(wfd, pos, in);
332 if (out == -1)
333 {
334 DBG1(DBG_TLS, "TLS plain write error: %s",
335 strerror(errno));
336 return FALSE;
337 }
338 in -= out;
339 pos += out;
340 }
341 continue;
342 }
343 break;
344 }
345 if (!crypto_eof && FD_ISSET(rfd, &set))
346 {
347 in = read(rfd, buf, sizeof(buf));
348 switch (in)
349 {
350 case 0:
351 crypto_eof = TRUE;
352 break;
353 case -1:
354 DBG1(DBG_TLS, "TLS plain read error: %s", strerror(errno));
355 return FALSE;
356 default:
357 pos = buf;
358 while (in)
359 {
360 out = write_(this, pos, in);
361 if (out == -1)
362 {
363 DBG1(DBG_TLS, "TLS write error");
364 return FALSE;
365 }
366 in -= out;
367 pos += out;
368 }
369 break;
370 }
371 }
372 }
373 return TRUE;
374 }
375
376 METHOD(tls_socket_t, get_fd, int,
377 private_tls_socket_t *this)
378 {
379 return this->fd;
380 }
381
382 METHOD(tls_socket_t, get_server_id, identification_t*,
383 private_tls_socket_t *this)
384 {
385 return this->tls->get_server_id(this->tls);
386 }
387
388 METHOD(tls_socket_t, get_peer_id, identification_t*,
389 private_tls_socket_t *this)
390 {
391 return this->tls->get_peer_id(this->tls);
392 }
393
394 METHOD(tls_socket_t, destroy, void,
395 private_tls_socket_t *this)
396 {
397 /* send a TLS close notify if not done yet */
398 this->app.close = TRUE;
399 write_(this, NULL, 0);
400 free(this->app.cache.ptr);
401 this->tls->destroy(this->tls);
402 free(this);
403 }
404
405 /**
406 * See header
407 */
408 tls_socket_t *tls_socket_create(bool is_server, identification_t *server,
409 identification_t *peer, int fd, tls_cache_t *cache,
410 tls_version_t max_version, bool nullok)
411 {
412 private_tls_socket_t *this;
413 tls_purpose_t purpose;
414
415 INIT(this,
416 .public = {
417 .read = _read_,
418 .write = _write_,
419 .splice = _splice,
420 .get_fd = _get_fd,
421 .get_server_id = _get_server_id,
422 .get_peer_id = _get_peer_id,
423 .destroy = _destroy,
424 },
425 .app = {
426 .application = {
427 .build = _build,
428 .process = _process,
429 .destroy = (void*)nop,
430 },
431 },
432 .fd = fd,
433 );
434
435 if (nullok)
436 {
437 purpose = TLS_PURPOSE_GENERIC_NULLOK;
438 }
439 else
440 {
441 purpose = TLS_PURPOSE_GENERIC;
442 }
443
444 this->tls = tls_create(is_server, server, peer, purpose,
445 &this->app.application, cache);
446 if (!this->tls)
447 {
448 free(this);
449 return NULL;
450 }
451 this->tls->set_version(this->tls, max_version);
452
453 return &this->public;
454 }