tls-test: Make plugin list configurable via environment variable
[strongswan.git] / scripts / tls_test.c
1 /*
2 * Copyright (C) 2010 Martin Willi
3 * Copyright (C) 2010 revosec AG
4 *
5 * This program is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License as published by the
7 * Free Software Foundation; either version 2 of the License, or (at your
8 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
9 *
10 * This program is distributed in the hope that it will be useful, but
11 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
12 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 * for more details.
14 */
15
#include <errno.h>
#include <getopt.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
23
24 #include <library.h>
25 #include <utils/debug.h>
26 #include <tls_socket.h>
27 #include <networking/host.h>
28 #include <credentials/sets/mem_cred.h>
29
/**
 * Print usage information
 *
 * @param out		stream to print to (stdout for --help, stderr on errors)
 * @param cmd		program name to show in the synopsis
 */
static void usage(FILE *out, char *cmd)
{
	fprintf(out, "usage:\n");
	fprintf(out, " %s --connect <address> --port <port> [--key <key>] [--cert <file>]+ [--times <n>]\n", cmd);
	fprintf(out, " %s --listen <address> --port <port> --key <key> [--cert <file>]+ [--times <n>]\n", cmd);
	fprintf(out, "options:\n");
	fprintf(out, " --times <n>              number of connections, default unlimited\n");
	fprintf(out, " --ipv4 | --ipv6          resolve <address> to an IPv4/IPv6 address only\n");
	fprintf(out, " --min-version <1.0-1.3>  minimum TLS version, default 1.0\n");
	fprintf(out, " --max-version <1.0-1.3>  maximum TLS version, default 1.3\n");
	fprintf(out, " --version <1.0-1.3>      use exactly this TLS version\n");
	fprintf(out, " --debug <level>          TLS debug level, default 1\n");
	fprintf(out, " --help                   show this help\n");
}
39
40 /**
41 * Check, as client, if we have a client certificate with private key
42 */
43 static identification_t *find_client_id()
44 {
45 identification_t *client = NULL, *keyid;
46 enumerator_t *enumerator;
47 certificate_t *cert;
48 public_key_t *pubkey;
49 private_key_t *privkey;
50 chunk_t chunk;
51
52 enumerator = lib->credmgr->create_cert_enumerator(lib->credmgr,
53 CERT_X509, KEY_ANY, NULL, FALSE);
54 while (enumerator->enumerate(enumerator, &cert))
55 {
56 pubkey = cert->get_public_key(cert);
57 if (pubkey)
58 {
59 if (pubkey->get_fingerprint(pubkey, KEYID_PUBKEY_SHA1, &chunk))
60 {
61 keyid = identification_create_from_encoding(ID_KEY_ID, chunk);
62 privkey = lib->credmgr->get_private(lib->credmgr,
63 pubkey->get_type(pubkey), keyid, NULL);
64 keyid->destroy(keyid);
65 if (privkey)
66 {
67 client = cert->get_subject(cert);
68 client = client->clone(client);
69 privkey->destroy(privkey);
70 }
71 }
72 pubkey->destroy(pubkey);
73 }
74 if (client)
75 {
76 break;
77 }
78 }
79 enumerator->destroy(enumerator);
80
81 return client;
82 }
83
84 /**
85 * Client routine
86 */
87 static int run_client(host_t *host, identification_t *server,
88 identification_t *client, int times, tls_cache_t *cache,
89 tls_version_t min_version, tls_version_t max_version)
90 {
91 tls_socket_t *tls;
92 int fd, res;
93
94 while (times == -1 || times-- > 0)
95 {
96 DBG2(DBG_TLS, "connecting to %#H", host);
97 fd = socket(host->get_family(host), SOCK_STREAM, 0);
98 if (fd == -1)
99 {
100 DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
101 return 1;
102 }
103 if (connect(fd, host->get_sockaddr(host),
104 *host->get_sockaddr_len(host)) == -1)
105 {
106 DBG1(DBG_TLS, "connecting to %#H failed: %s", host, strerror(errno));
107 close(fd);
108 return 1;
109 }
110 tls = tls_socket_create(FALSE, server, client, fd, cache, min_version,
111 max_version, TRUE);
112 if (!tls)
113 {
114 close(fd);
115 return 1;
116 }
117 res = tls->splice(tls, 0, 1) ? 0 : 1;
118 tls->destroy(tls);
119 close(fd);
120 if (res)
121 {
122 break;
123 }
124 }
125 return res;
126 }
127
128 /**
129 * Server routine
130 */
131 static int serve(host_t *host, identification_t *server,
132 int times, tls_cache_t *cache, tls_version_t min_version,
133 tls_version_t max_version)
134 {
135 tls_socket_t *tls;
136 int fd, cfd;
137
138 fd = socket(AF_INET, SOCK_STREAM, 0);
139 if (fd == -1)
140 {
141 DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
142 return 1;
143 }
144 if (bind(fd, host->get_sockaddr(host),
145 *host->get_sockaddr_len(host)) == -1)
146 {
147 DBG1(DBG_TLS, "binding to %#H failed: %s", host, strerror(errno));
148 close(fd);
149 return 1;
150 }
151 if (listen(fd, 1) == -1)
152 {
153 DBG1(DBG_TLS, "listen to %#H failed: %m", host, strerror(errno));
154 close(fd);
155 return 1;
156 }
157
158 while (times == -1 || times-- > 0)
159 {
160 cfd = accept(fd, host->get_sockaddr(host), host->get_sockaddr_len(host));
161 if (cfd == -1)
162 {
163 DBG1(DBG_TLS, "accept failed: %s", strerror(errno));
164 close(fd);
165 return 1;
166 }
167 DBG1(DBG_TLS, "%#H connected", host);
168
169 tls = tls_socket_create(TRUE, server, NULL, cfd, cache, min_version,
170 max_version, TRUE);
171 if (!tls)
172 {
173 close(fd);
174 return 1;
175 }
176 tls->splice(tls, 0, 1);
177 DBG1(DBG_TLS, "%#H disconnected", host);
178 tls->destroy(tls);
179 }
180 close(fd);
181
182 return 0;
183 }
184
185 /**
186 * In-Memory credential set
187 */
188 static mem_cred_t *creds;
189
190 /**
191 * Load certificate from file
192 */
193 static bool load_certificate(char *filename)
194 {
195 certificate_t *cert;
196
197 cert = lib->creds->create(lib->creds, CRED_CERTIFICATE, CERT_X509,
198 BUILD_FROM_FILE, filename, BUILD_END);
199 if (!cert)
200 {
201 DBG1(DBG_TLS, "loading certificate from '%s' failed", filename);
202 return FALSE;
203 }
204 creds->add_cert(creds, TRUE, cert);
205 return TRUE;
206 }
207
208 /**
209 * Load private key from file
210 */
211 static bool load_key(char *filename)
212 {
213 private_key_t *key;
214
215 key = lib->creds->create(lib->creds, CRED_PRIVATE_KEY, KEY_RSA,
216 BUILD_FROM_FILE, filename, BUILD_END);
217 if (!key)
218 {
219 DBG1(DBG_TLS, "loading key from '%s' failed", filename);
220 return FALSE;
221 }
222 creds->add_key(creds, key);
223 return TRUE;
224 }
225
226 /**
227 * TLS debug level
228 */
229 static level_t tls_level = 1;
230
231 static void dbg_tls(debug_t group, level_t level, char *fmt, ...)
232 {
233 if ((group == DBG_TLS && level <= tls_level) || level <= 1)
234 {
235 va_list args;
236
237 va_start(args, fmt);
238 vfprintf(stderr, fmt, args);
239 fprintf(stderr, "\n");
240 va_end(args);
241 }
242 }
243
244 /**
245 * Cleanup
246 */
247 static void cleanup()
248 {
249 lib->credmgr->remove_set(lib->credmgr, &creds->set);
250 creds->destroy(creds);
251 library_deinit();
252 }
253
254 /**
255 * Initialize library
256 */
257 static void init()
258 {
259 char *plugins;
260
261 library_init(NULL, "tls_test");
262
263 dbg = dbg_tls;
264
265 plugins = getenv("PLUGINS") ?: PLUGINS;
266 lib->plugins->load(lib->plugins, plugins);
267
268 creds = mem_cred_create();
269 lib->credmgr->add_set(lib->credmgr, &creds->set);
270
271 atexit(cleanup);
272 }
273
274 /**
275 * Used to parse TLS versions
276 */
277 ENUM(numeric_version_names, TLS_1_0, TLS_1_3,
278 "1.0",
279 "1.1",
280 "1.2",
281 "1.3");
282
283 int main(int argc, char *argv[])
284 {
285 char *address = NULL;
286 bool listen = FALSE;
287 int port = 0, times = -1, res, family = AF_UNSPEC;
288 identification_t *server, *client;
289 tls_version_t min_version = TLS_1_0, max_version = TLS_1_3;
290 tls_cache_t *cache;
291 host_t *host;
292
293 init();
294
295 while (TRUE)
296 {
297 struct option long_opts[] = {
298 {"help", no_argument, NULL, 'h' },
299 {"connect", required_argument, NULL, 'c' },
300 {"listen", required_argument, NULL, 'l' },
301 {"port", required_argument, NULL, 'p' },
302 {"cert", required_argument, NULL, 'x' },
303 {"key", required_argument, NULL, 'k' },
304 {"times", required_argument, NULL, 't' },
305 {"ipv4", no_argument, NULL, '4' },
306 {"ipv6", no_argument, NULL, '6' },
307 {"min-version", required_argument, NULL, 'm' },
308 {"max-version", required_argument, NULL, 'M' },
309 {"version", required_argument, NULL, 'v' },
310 {"debug", required_argument, NULL, 'd' },
311 {0,0,0,0 }
312 };
313 switch (getopt_long(argc, argv, "", long_opts, NULL))
314 {
315 case EOF:
316 break;
317 case 'h':
318 usage(stdout, argv[0]);
319 return 0;
320 case 'x':
321 if (!load_certificate(optarg))
322 {
323 return 1;
324 }
325 continue;
326 case 'k':
327 if (!load_key(optarg))
328 {
329 return 1;
330 }
331 continue;
332 case 'l':
333 listen = TRUE;
334 /* fall */
335 case 'c':
336 if (address)
337 {
338 usage(stderr, argv[0]);
339 return 1;
340 }
341 address = optarg;
342 continue;
343 case 'p':
344 port = atoi(optarg);
345 continue;
346 case 't':
347 times = atoi(optarg);
348 continue;
349 case 'd':
350 tls_level = atoi(optarg);
351 continue;
352 case '4':
353 family = AF_INET;
354 continue;
355 case '6':
356 family = AF_INET6;
357 continue;
358 case 'm':
359 if (!enum_from_name(numeric_version_names, optarg, &min_version))
360 {
361 fprintf(stderr, "unknown minimum TLS version: %s\n", optarg);
362 return 1;
363 }
364 continue;
365 case 'M':
366 if (!enum_from_name(numeric_version_names, optarg, &max_version))
367 {
368 fprintf(stderr, "unknown maximum TLS version: %s\n", optarg);
369 return 1;
370 }
371 continue;
372 case 'v':
373 if (!enum_from_name(numeric_version_names, optarg, &min_version))
374 {
375 fprintf(stderr, "unknown TLS version: %s\n", optarg);
376 return 1;
377 }
378 max_version = min_version;
379 continue;
380 default:
381 usage(stderr, argv[0]);
382 return 1;
383 }
384 break;
385 }
386 if (!port || !address)
387 {
388 usage(stderr, argv[0]);
389 return 1;
390 }
391 host = host_create_from_dns(address, family, port);
392 if (!host)
393 {
394 DBG1(DBG_TLS, "resolving hostname %s failed", address);
395 return 1;
396 }
397 server = identification_create_from_string(address);
398 cache = tls_cache_create(100, 30);
399 if (listen)
400 {
401 res = serve(host, server, times, cache, min_version, max_version);
402 }
403 else
404 {
405 client = find_client_id();
406 res = run_client(host, server, client, times, cache, min_version,
407 max_version);
408 DESTROY_IF(client);
409 }
410 cache->destroy(cache);
411 host->destroy(host);
412 server->destroy(server);
413 return res;
414 }