Implemented TLS session resumption both as client and as server
[strongswan.git] / scripts / tls_test.c
1 /*
2 * Copyright (C) 2010 Martin Willi
3 * Copyright (C) 2010 revosec AG
4 *
5 * This program is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License as published by the
7 * Free Software Foundation; either version 2 of the License, or (at your
8 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
9 *
10 * This program is distributed in the hope that it will be useful, but
11 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
12 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 * for more details.
14 */
15
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <getopt.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <unistd.h>
23
24 #include <library.h>
25 #include <debug.h>
26 #include <tls_socket.h>
27 #include <utils/host.h>
28 #include <credentials/sets/mem_cred.h>
29
/**
 * Print usage information.
 *
 * @param out		stream to print to (stdout for --help, stderr on errors)
 * @param cmd		program name shown in the synopsis
 */
static void usage(FILE *out, char *cmd)
{
	fprintf(out, "usage:\n");
	fprintf(out, " %s --connect <address> --port <port> [--cert <file>]+ [--times <n>] [--debug <level>]\n", cmd);
	fprintf(out, " %s --listen <address> --port <port> --key <key> [--cert <file>]+ [--times <n>] [--debug <level>]\n", cmd);
	fprintf(out, " %s --help\n", cmd);
}
39
40 /**
41 * Stream between stdio and TLS socket
42 */
43 static int stream(int fd, tls_socket_t *tls)
44 {
45 while (TRUE)
46 {
47 fd_set set;
48 chunk_t data;
49
50 FD_ZERO(&set);
51 FD_SET(fd, &set);
52 FD_SET(0, &set);
53
54 if (select(fd + 1, &set, NULL, NULL, NULL) == -1)
55 {
56 return 1;
57 }
58 if (FD_ISSET(fd, &set))
59 {
60 if (!tls->read(tls, &data))
61 {
62 return 0;
63 }
64 if (data.len)
65 {
66 ignore_result(write(1, data.ptr, data.len));
67 free(data.ptr);
68 }
69 }
70 if (FD_ISSET(0, &set))
71 {
72 char buf[1024];
73 ssize_t len;
74
75 len = read(0, buf, sizeof(buf));
76 if (len == 0)
77 {
78 return 0;
79 }
80 if (len > 0)
81 {
82 if (!tls->write(tls, chunk_create(buf, len)))
83 {
84 DBG1(DBG_TLS, "TLS write error");
85 return 1;
86 }
87 }
88 }
89 }
90 }
91
92 /**
93 * Client routine
94 */
95 static int client(host_t *host, identification_t *server,
96 int times, tls_cache_t *cache)
97 {
98 tls_socket_t *tls;
99 int fd, res;
100
101 while (times == -1 || times-- > 0)
102 {
103 fd = socket(AF_INET, SOCK_STREAM, 0);
104 if (fd == -1)
105 {
106 DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
107 return 1;
108 }
109 if (connect(fd, host->get_sockaddr(host),
110 *host->get_sockaddr_len(host)) == -1)
111 {
112 DBG1(DBG_TLS, "connecting to %#H failed: %s", host, strerror(errno));
113 close(fd);
114 return 1;
115 }
116 tls = tls_socket_create(FALSE, server, NULL, fd, cache);
117 if (!tls)
118 {
119 close(fd);
120 return 1;
121 }
122 res = stream(fd, tls);
123 tls->destroy(tls);
124 close(fd);
125 if (res)
126 {
127 break;
128 }
129 }
130 return res;
131 }
132
133 /**
134 * Server routine
135 */
136 static int serve(host_t *host, identification_t *server,
137 int times, tls_cache_t *cache)
138 {
139 tls_socket_t *tls;
140 int fd, cfd;
141
142 fd = socket(AF_INET, SOCK_STREAM, 0);
143 if (fd == -1)
144 {
145 DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
146 return 1;
147 }
148 if (bind(fd, host->get_sockaddr(host),
149 *host->get_sockaddr_len(host)) == -1)
150 {
151 DBG1(DBG_TLS, "binding to %#H failed: %s", host, strerror(errno));
152 close(fd);
153 return 1;
154 }
155 if (listen(fd, 1) == -1)
156 {
157 DBG1(DBG_TLS, "listen to %#H failed: %m", host, strerror(errno));
158 close(fd);
159 return 1;
160 }
161
162 while (times == -1 || times-- > 0)
163 {
164 cfd = accept(fd, host->get_sockaddr(host), host->get_sockaddr_len(host));
165 if (cfd == -1)
166 {
167 DBG1(DBG_TLS, "accept failed: %s", strerror(errno));
168 close(fd);
169 return 1;
170 }
171 DBG1(DBG_TLS, "%#H connected", host);
172
173 tls = tls_socket_create(TRUE, server, NULL, cfd, cache);
174 if (!tls)
175 {
176 close(fd);
177 return 1;
178 }
179 stream(cfd, tls);
180 DBG1(DBG_TLS, "%#H disconnected", host);
181 tls->destroy(tls);
182 }
183 close(fd);
184
185 return 0;
186 }
187
188 /**
189 * In-Memory credential set
190 */
191 static mem_cred_t *creds;
192
193 /**
194 * Load certificate from file
195 */
196 static bool load_certificate(char *filename)
197 {
198 certificate_t *cert;
199
200 cert = lib->creds->create(lib->creds, CRED_CERTIFICATE, CERT_X509,
201 BUILD_FROM_FILE, filename, BUILD_END);
202 if (!cert)
203 {
204 DBG1(DBG_TLS, "loading certificate from '%s' failed", filename);
205 return FALSE;
206 }
207 creds->add_cert(creds, TRUE, cert);
208 return TRUE;
209 }
210
211 /**
212 * Load private key from file
213 */
214 static bool load_key(char *filename)
215 {
216 private_key_t *key;
217
218 key = lib->creds->create(lib->creds, CRED_PRIVATE_KEY, KEY_RSA,
219 BUILD_FROM_FILE, filename, BUILD_END);
220 if (!key)
221 {
222 DBG1(DBG_TLS, "loading key from '%s' failed", filename);
223 return FALSE;
224 }
225 creds->add_key(creds, key);
226 return TRUE;
227 }
228
229 /**
230 * TLS debug level
231 */
232 static level_t tls_level = 1;
233
234 static void dbg_tls(debug_t group, level_t level, char *fmt, ...)
235 {
236 if ((group == DBG_TLS && level <= tls_level) || level <= 1)
237 {
238 va_list args;
239
240 va_start(args, fmt);
241 vfprintf(stderr, fmt, args);
242 fprintf(stderr, "\n");
243 va_end(args);
244 }
245 }
246
247 /**
248 * Cleanup
249 */
250 static void cleanup()
251 {
252 lib->credmgr->remove_set(lib->credmgr, &creds->set);
253 creds->destroy(creds);
254 library_deinit();
255 }
256
257 /**
258 * Initialize library
259 */
260 static void init()
261 {
262 library_init(NULL);
263
264 dbg = dbg_tls;
265
266 lib->plugins->load(lib->plugins, NULL, PLUGINS);
267
268 creds = mem_cred_create();
269 lib->credmgr->add_set(lib->credmgr, &creds->set);
270
271 atexit(cleanup);
272 }
273
274 int main(int argc, char *argv[])
275 {
276 char *address = NULL;
277 bool listen = FALSE;
278 int port = 0, times = -1, res;
279 identification_t *server;
280 tls_cache_t *cache;
281 host_t *host;
282
283 init();
284
285 while (TRUE)
286 {
287 struct option long_opts[] = {
288 {"help", no_argument, NULL, 'h' },
289 {"connect", required_argument, NULL, 'c' },
290 {"listen", required_argument, NULL, 'l' },
291 {"port", required_argument, NULL, 'p' },
292 {"cert", required_argument, NULL, 'x' },
293 {"key", required_argument, NULL, 'k' },
294 {"times", required_argument, NULL, 't' },
295 {"debug", required_argument, NULL, 'd' },
296 {0,0,0,0 }
297 };
298 switch (getopt_long(argc, argv, "", long_opts, NULL))
299 {
300 case EOF:
301 break;
302 case 'h':
303 usage(stdout, argv[0]);
304 return 0;
305 case 'x':
306 if (!load_certificate(optarg))
307 {
308 return 1;
309 }
310 continue;
311 case 'k':
312 if (!load_key(optarg))
313 {
314 return 1;
315 }
316 continue;
317 case 'l':
318 listen = TRUE;
319 /* fall */
320 case 'c':
321 if (address)
322 {
323 usage(stderr, argv[0]);
324 return 1;
325 }
326 address = optarg;
327 continue;
328 case 'p':
329 port = atoi(optarg);
330 continue;
331 case 't':
332 times = atoi(optarg);
333 continue;
334 case 'd':
335 tls_level = atoi(optarg);
336 continue;
337 default:
338 usage(stderr, argv[0]);
339 return 1;
340 }
341 break;
342 }
343 if (!port || !address)
344 {
345 usage(stderr, argv[0]);
346 return 1;
347 }
348 host = host_create_from_dns(address, 0, port);
349 if (!host)
350 {
351 DBG1(DBG_TLS, "resolving hostname %s failed", address);
352 return 1;
353 }
354 server = identification_create_from_string(address);
355 cache = tls_cache_create(100, 30);
356 if (listen)
357 {
358 res = serve(host, server, times, cache);
359 }
360 else
361 {
362 res = client(host, server, times, cache);
363 }
364 cache->destroy(cache);
365 host->destroy(host);
366 server->destroy(server);
367 return res;
368 }
369