Wrap tls_t.get_{server,peer}_id methods in tls_socket_t
[strongswan.git] / src / libtls / tls_socket.c
1 /*
2 * Copyright (C) 2010 Martin Willi
3 * Copyright (C) 2010 revosec AG
4 *
5 * This program is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License as published by the
7 * Free Software Foundation; either version 2 of the License, or (at your
8 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
9 *
10 * This program is distributed in the hope that it will be useful, but
11 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
12 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 * for more details.
14 */
15
#include "tls_socket.h"

#include <unistd.h>
#include <errno.h>
#include <poll.h>

#include <utils/debug.h>
#include <threading/thread.h>
23
/**
 * Stack buffer sizes used for socket I/O.
 */
enum {
	/** size of the buffer used on the plain (unencrypted) side */
	PLAIN_BUF_SIZE = 4096,
	/** size of the buffer used on the crypto (encrypted) side */
	CRYPTO_BUF_SIZE = 4096,
};
33
34 typedef struct private_tls_socket_t private_tls_socket_t;
35 typedef struct private_tls_application_t private_tls_application_t;
36
37 struct private_tls_application_t {
38
39 /**
40 * Implements tls_application layer.
41 */
42 tls_application_t application;
43
44 /**
45 * Output buffer to write to
46 */
47 chunk_t out;
48
49 /**
50 * Number of bytes written to out
51 */
52 size_t out_done;
53
54 /**
55 * Input buffer to read to
56 */
57 chunk_t in;
58
59 /**
60 * Number of bytes read to in
61 */
62 size_t in_done;
63
64 /**
65 * Cached input data
66 */
67 chunk_t cache;
68
69 /**
70 * Bytes cosnumed in cache
71 */
72 size_t cache_done;
73
74 /**
75 * Close TLS connection?
76 */
77 bool close;
78 };
79
80 /**
81 * Private data of an tls_socket_t object.
82 */
83 struct private_tls_socket_t {
84
85 /**
86 * Public tls_socket_t interface.
87 */
88 tls_socket_t public;
89
90 /**
91 * TLS application implementation
92 */
93 private_tls_application_t app;
94
95 /**
96 * TLS stack
97 */
98 tls_t *tls;
99
100 /**
101 * Underlying OS socket
102 */
103 int fd;
104 };
105
106 METHOD(tls_application_t, process, status_t,
107 private_tls_application_t *this, bio_reader_t *reader)
108 {
109 chunk_t data;
110 size_t len;
111
112 if (this->close)
113 {
114 return SUCCESS;
115 }
116 len = min(reader->remaining(reader), this->in.len - this->in_done);
117 if (len)
118 { /* copy to read buffer as much as fits in */
119 if (!reader->read_data(reader, len, &data))
120 {
121 return FAILED;
122 }
123 memcpy(this->in.ptr + this->in_done, data.ptr, data.len);
124 this->in_done += data.len;
125 }
126 else
127 { /* read buffer is full, cache for next read */
128 if (!reader->read_data(reader, reader->remaining(reader), &data))
129 {
130 return FAILED;
131 }
132 this->cache = chunk_cat("mc", this->cache, data);
133 }
134 return NEED_MORE;
135 }
136
137 METHOD(tls_application_t, build, status_t,
138 private_tls_application_t *this, bio_writer_t *writer)
139 {
140 if (this->close)
141 {
142 return SUCCESS;
143 }
144 if (this->out.len > this->out_done)
145 {
146 writer->write_data(writer, this->out);
147 this->out_done = this->out.len;
148 return NEED_MORE;
149 }
150 return INVALID_STATE;
151 }
152
153 /**
154 * TLS data exchange loop
155 */
156 static bool exchange(private_tls_socket_t *this, bool wr, bool block)
157 {
158 char buf[CRYPTO_BUF_SIZE], *pos;
159 ssize_t len, out;
160 int round = 0, flags;
161
162 for (round = 0; TRUE; round++)
163 {
164 while (TRUE)
165 {
166 len = sizeof(buf);
167 switch (this->tls->build(this->tls, buf, &len, NULL))
168 {
169 case NEED_MORE:
170 case ALREADY_DONE:
171 pos = buf;
172 while (len)
173 {
174 out = write(this->fd, pos, len);
175 if (out == -1)
176 {
177 DBG1(DBG_TLS, "TLS crypto write error: %s",
178 strerror(errno));
179 return FALSE;
180 }
181 len -= out;
182 pos += out;
183 }
184 continue;
185 case INVALID_STATE:
186 break;
187 case SUCCESS:
188 return TRUE;
189 default:
190 return FALSE;
191 }
192 break;
193 }
194 if (wr)
195 {
196 if (this->app.out_done == this->app.out.len)
197 { /* all data written */
198 return TRUE;
199 }
200 }
201 else
202 {
203 if (this->app.in_done == this->app.in.len)
204 { /* buffer fully received */
205 return TRUE;
206 }
207 }
208
209 flags = 0;
210 if (this->app.out_done == this->app.out.len)
211 {
212 if (!block || this->app.in_done)
213 {
214 flags |= MSG_DONTWAIT;
215 }
216 }
217 len = recv(this->fd, buf, sizeof(buf), flags);
218 if (len < 0)
219 {
220 if (errno == EAGAIN || errno == EWOULDBLOCK)
221 {
222 if (this->app.in_done == 0)
223 {
224 /* reading, nothing got yet, and call would block */
225 errno = EWOULDBLOCK;
226 this->app.in_done = -1;
227 }
228 return TRUE;
229 }
230 return FALSE;
231 }
232 if (len == 0)
233 { /* EOF */
234 return TRUE;
235 }
236 switch (this->tls->process(this->tls, buf, len))
237 {
238 case NEED_MORE:
239 break;
240 case SUCCESS:
241 return TRUE;
242 default:
243 return FALSE;
244 }
245 }
246 }
247
248 METHOD(tls_socket_t, read_, ssize_t,
249 private_tls_socket_t *this, void *buf, size_t len, bool block)
250 {
251 if (this->app.cache.len)
252 {
253 size_t cache;
254
255 cache = min(len, this->app.cache.len - this->app.cache_done);
256 memcpy(buf, this->app.cache.ptr + this->app.cache_done, cache);
257
258 this->app.cache_done += cache;
259 if (this->app.cache_done == this->app.cache.len)
260 {
261 chunk_free(&this->app.cache);
262 this->app.cache_done = 0;
263 }
264 return cache;
265 }
266 this->app.in.ptr = buf;
267 this->app.in.len = len;
268 this->app.in_done = 0;
269 if (exchange(this, FALSE, block))
270 {
271 return this->app.in_done;
272 }
273 return -1;
274 }
275
276 METHOD(tls_socket_t, write_, ssize_t,
277 private_tls_socket_t *this, void *buf, size_t len)
278 {
279 this->app.out.ptr = buf;
280 this->app.out.len = len;
281 this->app.out_done = 0;
282 if (exchange(this, TRUE, FALSE))
283 {
284 return this->app.out_done;
285 }
286 return -1;
287 }
288
289 METHOD(tls_socket_t, splice, bool,
290 private_tls_socket_t *this, int rfd, int wfd)
291 {
292 char buf[PLAIN_BUF_SIZE], *pos;
293 fd_set set;
294 ssize_t in, out;
295 bool old, plain_eof = FALSE, crypto_eof = FALSE;
296
297 while (!plain_eof && !crypto_eof)
298 {
299 FD_ZERO(&set);
300 FD_SET(rfd, &set);
301 FD_SET(this->fd, &set);
302
303 old = thread_cancelability(TRUE);
304 in = select(max(rfd, this->fd) + 1, &set, NULL, NULL, NULL);
305 thread_cancelability(old);
306 if (in == -1)
307 {
308 DBG1(DBG_TLS, "TLS select error: %s", strerror(errno));
309 return FALSE;
310 }
311 while (!plain_eof && FD_ISSET(this->fd, &set))
312 {
313 in = read_(this, buf, sizeof(buf), FALSE);
314 switch (in)
315 {
316 case 0:
317 plain_eof = TRUE;
318 break;
319 case -1:
320 if (errno != EWOULDBLOCK)
321 {
322 DBG1(DBG_TLS, "TLS read error: %s", strerror(errno));
323 return FALSE;
324 }
325 break;
326 default:
327 pos = buf;
328 while (in)
329 {
330 out = write(wfd, pos, in);
331 if (out == -1)
332 {
333 DBG1(DBG_TLS, "TLS plain write error: %s",
334 strerror(errno));
335 return FALSE;
336 }
337 in -= out;
338 pos += out;
339 }
340 continue;
341 }
342 break;
343 }
344 if (!crypto_eof && FD_ISSET(rfd, &set))
345 {
346 in = read(rfd, buf, sizeof(buf));
347 switch (in)
348 {
349 case 0:
350 crypto_eof = TRUE;
351 break;
352 case -1:
353 DBG1(DBG_TLS, "TLS plain read error: %s", strerror(errno));
354 return FALSE;
355 default:
356 pos = buf;
357 while (in)
358 {
359 out = write_(this, pos, in);
360 if (out == -1)
361 {
362 DBG1(DBG_TLS, "TLS write error");
363 return FALSE;
364 }
365 in -= out;
366 pos += out;
367 }
368 break;
369 }
370 }
371 }
372 return TRUE;
373 }
374
375 METHOD(tls_socket_t, get_fd, int,
376 private_tls_socket_t *this)
377 {
378 return this->fd;
379 }
380
381 METHOD(tls_socket_t, get_server_id, identification_t*,
382 private_tls_socket_t *this)
383 {
384 return this->tls->get_server_id(this->tls);
385 }
386
387 METHOD(tls_socket_t, get_peer_id, identification_t*,
388 private_tls_socket_t *this)
389 {
390 return this->tls->get_peer_id(this->tls);
391 }
392
393 METHOD(tls_socket_t, destroy, void,
394 private_tls_socket_t *this)
395 {
396 /* send a TLS close notify if not done yet */
397 this->app.close = TRUE;
398 write_(this, NULL, 0);
399 free(this->app.cache.ptr);
400 this->tls->destroy(this->tls);
401 free(this);
402 }
403
404 /**
405 * See header
406 */
407 tls_socket_t *tls_socket_create(bool is_server, identification_t *server,
408 identification_t *peer, int fd, tls_cache_t *cache)
409 {
410 private_tls_socket_t *this;
411
412 INIT(this,
413 .public = {
414 .read = _read_,
415 .write = _write_,
416 .splice = _splice,
417 .get_fd = _get_fd,
418 .get_server_id = _get_server_id,
419 .get_peer_id = _get_peer_id,
420 .destroy = _destroy,
421 },
422 .app = {
423 .application = {
424 .build = _build,
425 .process = _process,
426 .destroy = (void*)nop,
427 },
428 },
429 .fd = fd,
430 );
431
432 this->tls = tls_create(is_server, server, peer, TLS_PURPOSE_GENERIC,
433 &this->app.application, cache);
434 if (!this->tls)
435 {
436 free(this);
437 return NULL;
438 }
439
440 return &this->public;
441 }