Added a tls_socket_t.splice method to wrap a file descriptor into TLS
[strongswan.git] / src / libtls / tls_socket.c
1 /*
2 * Copyright (C) 2010 Martin Willi
3 * Copyright (C) 2010 revosec AG
4 *
5 * This program is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License as published by the
7 * Free Software Foundation; either version 2 of the License, or (at your
8 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
9 *
10 * This program is distributed in the hope that it will be useful, but
11 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
12 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 * for more details.
14 */
15
16 #include "tls_socket.h"
17
18 #include <unistd.h>
19 #include <errno.h>
20
21 #include <debug.h>
22 #include <threading/thread.h>
23
typedef struct private_tls_socket_t private_tls_socket_t;
typedef struct private_tls_application_t private_tls_application_t;

/**
 * Size of the stack buffer used for plain data I/O on the spliced descriptors
 */
#define PLAIN_BUF_SIZE 4096

/**
 * Size of the stack buffer used for encrypted record I/O on the TLS socket
 */
#define CRYPTO_BUF_SIZE 4096
36
/**
 * tls_application_t implementation buffering the plain data exchanged
 * through the tls_socket_t read/write API.
 */
struct private_tls_application_t {

	/**
	 * Implements tls_application layer. Must remain the first member: the
	 * process/build methods receive this interface pointer and use it as
	 * their private_tls_application_t instance.
	 */
	tls_application_t application;

	/**
	 * Chunk of data to send. Not owned: it references the buffer passed to
	 * write() and is reset to chunk_empty once build() has consumed it.
	 */
	chunk_t out;

	/**
	 * Chunk of data received. Allocated and extended by process(), handed
	 * over to the caller by read(), freed by destroy() if still pending.
	 */
	chunk_t in;
};
54
/**
 * Private data of a tls_socket_t object.
 */
struct private_tls_socket_t {

	/**
	 * Public tls_socket_t interface. Must remain the first member, as the
	 * public pointer handed out is used as the private instance.
	 */
	tls_socket_t public;

	/**
	 * TLS application implementation, buffering plain data in both directions
	 */
	private_tls_application_t app;

	/**
	 * TLS stack, created in tls_socket_create() and destroyed with this object
	 */
	tls_t *tls;

	/**
	 * Underlying OS socket carrying the TLS records (not closed by destroy())
	 */
	int fd;
};
80
81 METHOD(tls_application_t, process, status_t,
82 private_tls_application_t *this, bio_reader_t *reader)
83 {
84 chunk_t data;
85
86 if (!reader->read_data(reader, reader->remaining(reader), &data))
87 {
88 return FAILED;
89 }
90 this->in = chunk_cat("mc", this->in, data);
91 return NEED_MORE;
92 }
93
94 METHOD(tls_application_t, build, status_t,
95 private_tls_application_t *this, bio_writer_t *writer)
96 {
97 if (this->out.len)
98 {
99 writer->write_data(writer, this->out);
100 this->out = chunk_empty;
101 return NEED_MORE;
102 }
103 return INVALID_STATE;
104 }
105
106 /**
107 * TLS data exchange loop
108 */
109 static bool exchange(private_tls_socket_t *this, bool wr)
110 {
111 char buf[CRYPTO_BUF_SIZE], *pos;
112 ssize_t len, out;
113 int round = 0;
114
115 for (round = 0; TRUE; round++)
116 {
117 while (TRUE)
118 {
119 len = sizeof(buf);
120 switch (this->tls->build(this->tls, buf, &len, NULL))
121 {
122 case NEED_MORE:
123 case ALREADY_DONE:
124 pos = buf;
125 while (len)
126 {
127 out = write(this->fd, pos, len);
128 if (out == -1)
129 {
130 DBG1(DBG_TLS, "TLS crypto write error: %s",
131 strerror(errno));
132 return FALSE;
133 }
134 len -= out;
135 pos += out;
136 }
137 continue;
138 case INVALID_STATE:
139 break;
140 default:
141 return FALSE;
142 }
143 break;
144 }
145 if (wr)
146 {
147 if (this->app.out.len == 0)
148 { /* all data written */
149 return TRUE;
150 }
151 }
152 else
153 {
154 if (this->app.in.len)
155 { /* some data received */
156 return TRUE;
157 }
158 if (round > 0)
159 { /* did some handshaking, return empty chunk to not block */
160 return TRUE;
161 }
162 }
163 len = read(this->fd, buf, sizeof(buf));
164 if (len <= 0)
165 {
166 return FALSE;
167 }
168 if (this->tls->process(this->tls, buf, len) != NEED_MORE)
169 {
170 return FALSE;
171 }
172 }
173 }
174
175 METHOD(tls_socket_t, read_, bool,
176 private_tls_socket_t *this, chunk_t *buf)
177 {
178 if (exchange(this, FALSE))
179 {
180 *buf = this->app.in;
181 this->app.in = chunk_empty;
182 return TRUE;
183 }
184 return FALSE;
185 }
186
187 METHOD(tls_socket_t, write_, bool,
188 private_tls_socket_t *this, chunk_t buf)
189 {
190 this->app.out = buf;
191 if (exchange(this, TRUE))
192 {
193 return TRUE;
194 }
195 return FALSE;
196 }
197
198 METHOD(tls_socket_t, splice, bool,
199 private_tls_socket_t *this, int rfd, int wfd)
200 {
201 char buf[PLAIN_BUF_SIZE], *pos;
202 fd_set set;
203 chunk_t data;
204 ssize_t len;
205 bool old;
206
207 while (TRUE)
208 {
209 FD_ZERO(&set);
210 FD_SET(rfd, &set);
211 FD_SET(this->fd, &set);
212
213 old = thread_cancelability(TRUE);
214 len = select(max(rfd, this->fd) + 1, &set, NULL, NULL, NULL);
215 thread_cancelability(old);
216 if (len == -1)
217 {
218 DBG1(DBG_TLS, "TLS select error: %s", strerror(errno));
219 return FALSE;
220 }
221 if (FD_ISSET(this->fd, &set))
222 {
223 if (!read_(this, &data))
224 {
225 DBG2(DBG_TLS, "TLS read error/disconnect");
226 return TRUE;
227 }
228 pos = data.ptr;
229 while (data.len)
230 {
231 len = write(wfd, pos, data.len);
232 if (len == -1)
233 {
234 free(data.ptr);
235 DBG1(DBG_TLS, "TLS plain write error: %s", strerror(errno));
236 return FALSE;
237 }
238 data.len -= len;
239 pos += len;
240 }
241 free(data.ptr);
242 }
243 if (FD_ISSET(rfd, &set))
244 {
245 len = read(rfd, buf, sizeof(buf));
246 if (len > 0)
247 {
248 if (!write_(this, chunk_create(buf, len)))
249 {
250 DBG1(DBG_TLS, "TLS write error");
251 return FALSE;
252 }
253 }
254 else
255 {
256 if (len < 0)
257 {
258 DBG1(DBG_TLS, "TLS plain read error: %s", strerror(errno));
259 return FALSE;
260 }
261 return TRUE;
262 }
263 }
264 }
265 }
266
267 METHOD(tls_socket_t, get_fd, int,
268 private_tls_socket_t *this)
269 {
270 return this->fd;
271 }
272
273 METHOD(tls_socket_t, destroy, void,
274 private_tls_socket_t *this)
275 {
276 this->tls->destroy(this->tls);
277 free(this->app.in.ptr);
278 free(this);
279 }
280
281 /**
282 * See header
283 */
284 tls_socket_t *tls_socket_create(bool is_server, identification_t *server,
285 identification_t *peer, int fd, tls_cache_t *cache)
286 {
287 private_tls_socket_t *this;
288
289 INIT(this,
290 .public = {
291 .read = _read_,
292 .write = _write_,
293 .splice = _splice,
294 .get_fd = _get_fd,
295 .destroy = _destroy,
296 },
297 .app = {
298 .application = {
299 .build = _build,
300 .process = _process,
301 .destroy = (void*)nop,
302 },
303 },
304 .fd = fd,
305 );
306
307 this->tls = tls_create(is_server, server, peer, TLS_PURPOSE_GENERIC,
308 &this->app.application, cache);
309 if (!this->tls)
310 {
311 free(this);
312 return NULL;
313 }
314
315 return &this->public;
316 }