testing: Bump guest kernel to Linux 5.11
[strongswan.git] / scripts / tls_test.c
1 /*
2 * Copyright (C) 2020 Pascal Knecht
3 * Copyright (C) 2020 Tobias Brunner
4 * HSR Hochschule fuer Technik Rapperswil
5 *
6 * Copyright (C) 2010 Martin Willi
7 * Copyright (C) 2010 revosec AG
8 *
9 * This program is free software; you can redistribute it and/or modify it
10 * under the terms of the GNU General Public License as published by the
11 * Free Software Foundation; either version 2 of the License, or (at your
12 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
13 *
14 * This program is distributed in the hope that it will be useful, but
15 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
16 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
17 * for more details.
18 */
19
20 #include <unistd.h>
21 #include <stdio.h>
22 #include <sys/types.h>
23 #include <sys/socket.h>
24 #include <getopt.h>
25 #include <errno.h>
26 #include <string.h>
27
28 #include <library.h>
29 #include <utils/debug.h>
30 #include <tls_socket.h>
31 #include <networking/host.h>
32 #include <credentials/sets/mem_cred.h>
33
/**
 * Print usage information.
 *
 * @param out		stream the usage text is written to (stdout or stderr)
 * @param cmd		program name shown in the synopsis lines
 */
static void usage(FILE *out, char *cmd)
{
	fprintf(out, "usage:\n");
	/* fixed: the optional --key argument was printed as "<key" (missing ">") */
	fprintf(out, " %s --connect <address> --port <port> [--key <key>] [--cert <file>] [--cacert <file>]+ [--times <n>]\n", cmd);
	fprintf(out, " %s --listen <address> --port <port> --key <key> --cert <file> [--cacert <file>]+ [--auth-optional] [--times <n>]\n", cmd);
	fprintf(out, "\n");
	fprintf(out, "options:\n");
	fprintf(out, " --help print help and exit\n");
	fprintf(out, " --connect <address> connect to a server on dns name or ip address\n");
	fprintf(out, " --listen <address> listen on dns name or ip address\n");
	fprintf(out, " --port <port> specify the port to use\n");
	fprintf(out, " --cert <file> certificate to authenticate itself\n");
	fprintf(out, " --key <file> private key to authenticate itself\n");
	fprintf(out, " --cacert <file> certificate to verify other peer\n");
	fprintf(out, " --auth-optional don't enforce client authentication\n");
	fprintf(out, " --times <n> specify the amount of repeated connection establishments\n");
	fprintf(out, " --ipv4 use IPv4\n");
	fprintf(out, " --ipv6 use IPv6\n");
	fprintf(out, " --min-version <version> specify the minimum TLS version, supported versions:\n");
	fprintf(out, " 1.0 (default), 1.1, 1.2 and 1.3\n");
	fprintf(out, " --max-version <version> specify the maximum TLS version, supported versions:\n");
	fprintf(out, " 1.0, 1.1, 1.2 and 1.3 (default)\n");
	fprintf(out, " --version <version> set one specific TLS version to use, supported versions:\n");
	fprintf(out, " 1.0, 1.1, 1.2 and 1.3\n");
	fprintf(out, " --debug <debug level> set debug level, default is 1\n");
}
63
64 /**
65 * Check, as client, if we have a client certificate with private key
66 */
67 static identification_t *find_client_id()
68 {
69 identification_t *client = NULL, *keyid;
70 enumerator_t *enumerator;
71 certificate_t *cert;
72 public_key_t *pubkey;
73 private_key_t *privkey;
74 chunk_t chunk;
75
76 enumerator = lib->credmgr->create_cert_enumerator(lib->credmgr,
77 CERT_X509, KEY_ANY, NULL, FALSE);
78 while (enumerator->enumerate(enumerator, &cert))
79 {
80 pubkey = cert->get_public_key(cert);
81 if (pubkey)
82 {
83 if (pubkey->get_fingerprint(pubkey, KEYID_PUBKEY_SHA1, &chunk))
84 {
85 keyid = identification_create_from_encoding(ID_KEY_ID, chunk);
86 privkey = lib->credmgr->get_private(lib->credmgr,
87 pubkey->get_type(pubkey), keyid, NULL);
88 keyid->destroy(keyid);
89 if (privkey)
90 {
91 client = cert->get_subject(cert);
92 client = client->clone(client);
93 privkey->destroy(privkey);
94 }
95 }
96 pubkey->destroy(pubkey);
97 }
98 if (client)
99 {
100 break;
101 }
102 }
103 enumerator->destroy(enumerator);
104
105 return client;
106 }
107
108 /**
109 * Client routine
110 */
111 static int run_client(host_t *host, identification_t *server,
112 identification_t *client, int times, tls_cache_t *cache,
113 tls_version_t min_version, tls_version_t max_version,
114 tls_flag_t flags)
115 {
116 tls_socket_t *tls;
117 int fd, res;
118
119 while (times == -1 || times-- > 0)
120 {
121 DBG2(DBG_TLS, "connecting to %#H", host);
122 fd = socket(host->get_family(host), SOCK_STREAM, 0);
123 if (fd == -1)
124 {
125 DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
126 return 1;
127 }
128 if (connect(fd, host->get_sockaddr(host),
129 *host->get_sockaddr_len(host)) == -1)
130 {
131 DBG1(DBG_TLS, "connecting to %#H failed: %s", host, strerror(errno));
132 close(fd);
133 return 1;
134 }
135 tls = tls_socket_create(FALSE, server, client, fd, cache, min_version,
136 max_version, flags);
137 if (!tls)
138 {
139 close(fd);
140 return 1;
141 }
142 res = tls->splice(tls, 0, 1) ? 0 : 1;
143 tls->destroy(tls);
144 close(fd);
145 if (res)
146 {
147 break;
148 }
149 }
150 return res;
151 }
152
153 /**
154 * Server routine
155 */
156 static int serve(host_t *host, identification_t *server, identification_t *client,
157 int times, tls_cache_t *cache, tls_version_t min_version,
158 tls_version_t max_version, tls_flag_t flags)
159 {
160 tls_socket_t *tls;
161 int fd, cfd;
162
163 fd = socket(AF_INET, SOCK_STREAM, 0);
164 if (fd == -1)
165 {
166 DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
167 return 1;
168 }
169 if (bind(fd, host->get_sockaddr(host),
170 *host->get_sockaddr_len(host)) == -1)
171 {
172 DBG1(DBG_TLS, "binding to %#H failed: %s", host, strerror(errno));
173 close(fd);
174 return 1;
175 }
176 if (listen(fd, 1) == -1)
177 {
178 DBG1(DBG_TLS, "listen to %#H failed: %m", host, strerror(errno));
179 close(fd);
180 return 1;
181 }
182
183 while (times == -1 || times-- > 0)
184 {
185 cfd = accept(fd, host->get_sockaddr(host), host->get_sockaddr_len(host));
186 if (cfd == -1)
187 {
188 DBG1(DBG_TLS, "accept failed: %s", strerror(errno));
189 close(fd);
190 return 1;
191 }
192 DBG1(DBG_TLS, "%#H connected", host);
193
194 tls = tls_socket_create(TRUE, server, client, cfd, cache, min_version,
195 max_version, flags);
196 if (!tls)
197 {
198 close(fd);
199 return 1;
200 }
201 tls->splice(tls, 0, 1);
202 DBG1(DBG_TLS, "%#H disconnected", host);
203 tls->destroy(tls);
204 }
205 close(fd);
206
207 return 0;
208 }
209
210 /**
211 * In-Memory credential set
212 */
213 static mem_cred_t *creds;
214
215 /**
216 * Load certificate from file
217 */
218 static bool load_certificate(char *filename)
219 {
220 certificate_t *cert;
221
222 cert = lib->creds->create(lib->creds, CRED_CERTIFICATE, CERT_X509,
223 BUILD_FROM_FILE, filename, BUILD_END);
224 if (!cert)
225 {
226 DBG1(DBG_TLS, "loading certificate from '%s' failed", filename);
227 return FALSE;
228 }
229 creds->add_cert(creds, TRUE, cert);
230 return TRUE;
231 }
232
233 /**
234 * Load private key from file
235 */
236 static bool load_key(char *filename)
237 {
238 private_key_t *key;
239
240 key = lib->creds->create(lib->creds, CRED_PRIVATE_KEY, KEY_ANY,
241 BUILD_FROM_FILE, filename, BUILD_END);
242 if (!key)
243 {
244 DBG1(DBG_TLS, "loading key from '%s' failed", filename);
245 return FALSE;
246 }
247 creds->add_key(creds, key);
248 return TRUE;
249 }
250
251 /**
252 * TLS debug level
253 */
254 static level_t tls_level = 1;
255
256 static void dbg_tls(debug_t group, level_t level, char *fmt, ...)
257 {
258 if ((group == DBG_TLS && level <= tls_level) || level <= 1)
259 {
260 va_list args;
261
262 va_start(args, fmt);
263 vfprintf(stderr, fmt, args);
264 fprintf(stderr, "\n");
265 va_end(args);
266 }
267 }
268
269 /**
270 * Cleanup
271 */
272 static void cleanup()
273 {
274 lib->credmgr->remove_set(lib->credmgr, &creds->set);
275 creds->destroy(creds);
276 library_deinit();
277 }
278
279 /**
280 * Initialize library
281 */
282 static void init()
283 {
284 char *plugins;
285
286 library_init(NULL, "tls_test");
287
288 dbg = dbg_tls;
289
290 plugins = getenv("PLUGINS") ?: PLUGINS;
291 lib->plugins->load(lib->plugins, plugins);
292
293 creds = mem_cred_create();
294 lib->credmgr->add_set(lib->credmgr, &creds->set);
295
296 atexit(cleanup);
297 }
298
299 int main(int argc, char *argv[])
300 {
301 char *address = NULL;
302 bool listen = FALSE;
303 int port = 0, times = -1, res, family = AF_UNSPEC;
304 identification_t *server, *client = NULL;
305 tls_version_t min_version = TLS_SUPPORTED_MIN, max_version = TLS_SUPPORTED_MAX;
306 tls_flag_t flags = TLS_FLAG_ENCRYPTION_OPTIONAL;
307 tls_cache_t *cache;
308 host_t *host;
309
310 init();
311
312 while (TRUE)
313 {
314 struct option long_opts[] = {
315 {"help", no_argument, NULL, 'h' },
316 {"connect", required_argument, NULL, 'c' },
317 {"listen", required_argument, NULL, 'l' },
318 {"port", required_argument, NULL, 'p' },
319 {"cert", required_argument, NULL, 'x' },
320 {"key", required_argument, NULL, 'k' },
321 {"cacert", required_argument, NULL, 'f' },
322 {"times", required_argument, NULL, 't' },
323 {"ipv4", no_argument, NULL, '4' },
324 {"ipv6", no_argument, NULL, '6' },
325 {"min-version", required_argument, NULL, 'm' },
326 {"max-version", required_argument, NULL, 'M' },
327 {"version", required_argument, NULL, 'v' },
328 {"auth-optional", no_argument, NULL, 'n' },
329 {"debug", required_argument, NULL, 'd' },
330 {0,0,0,0 }
331 };
332 switch (getopt_long(argc, argv, "", long_opts, NULL))
333 {
334 case EOF:
335 break;
336 case 'h':
337 usage(stdout, argv[0]);
338 return 0;
339 case 'x':
340 if (!load_certificate(optarg))
341 {
342 return 1;
343 }
344 continue;
345 case 'k':
346 if (!load_key(optarg))
347 {
348 return 1;
349 }
350 continue;
351 case 'f':
352 if (!load_certificate(optarg))
353 {
354 return 1;
355 }
356 client = identification_create_from_encoding(ID_ANY, chunk_empty);
357 continue;
358 case 'l':
359 listen = TRUE;
360 /* fall */
361 case 'c':
362 if (address)
363 {
364 usage(stderr, argv[0]);
365 return 1;
366 }
367 address = optarg;
368 continue;
369 case 'p':
370 port = atoi(optarg);
371 continue;
372 case 't':
373 times = atoi(optarg);
374 continue;
375 case 'd':
376 tls_level = atoi(optarg);
377 continue;
378 case '4':
379 family = AF_INET;
380 continue;
381 case '6':
382 family = AF_INET6;
383 continue;
384 case 'm':
385 if (!enum_from_name(tls_numeric_version_names, optarg,
386 &min_version))
387 {
388 fprintf(stderr, "unknown minimum TLS version: %s\n", optarg);
389 return 1;
390 }
391 continue;
392 case 'M':
393 if (!enum_from_name(tls_numeric_version_names, optarg,
394 &max_version))
395 {
396 fprintf(stderr, "unknown maximum TLS version: %s\n", optarg);
397 return 1;
398 }
399 continue;
400 case 'v':
401 if (!enum_from_name(tls_numeric_version_names, optarg,
402 &min_version))
403 {
404 fprintf(stderr, "unknown TLS version: %s\n", optarg);
405 return 1;
406 }
407 max_version = min_version;
408 continue;
409 case 'n':
410 flags |= TLS_FLAG_CLIENT_AUTH_OPTIONAL;
411 continue;
412 default:
413 usage(stderr, argv[0]);
414 return 1;
415 }
416 break;
417 }
418 if (!port || !address)
419 {
420 usage(stderr, argv[0]);
421 return 1;
422 }
423 host = host_create_from_dns(address, family, port);
424 if (!host)
425 {
426 DBG1(DBG_TLS, "resolving hostname %s failed", address);
427 return 1;
428 }
429 server = identification_create_from_string(address);
430 cache = tls_cache_create(100, 30);
431 if (listen)
432 {
433 res = serve(host, server, client, times, cache, min_version,
434 max_version, flags);
435 }
436 else
437 {
438 DESTROY_IF(client);
439 client = find_client_id();
440 res = run_client(host, server, client, times, cache, min_version,
441 max_version, flags);
442 DESTROY_IF(client);
443 }
444 cache->destroy(cache);
445 host->destroy(host);
446 server->destroy(server);
447 return res;
448 }