d0d259e60041ff411c69175e1e83567c01371633
[strongswan.git] / scripts / tls_test.c
1 /*
2 * Copyright (C) 2010 Martin Willi
3 * Copyright (C) 2010 revosec AG
4 *
5 * This program is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License as published by the
7 * Free Software Foundation; either version 2 of the License, or (at your
8 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
9 *
10 * This program is distributed in the hope that it will be useful, but
11 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
12 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 * for more details.
14 */
15
#include <unistd.h>
#include <stdio.h>
#include <stdlib.h>
#include <limits.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <getopt.h>
#include <errno.h>
#include <string.h>

#include <library.h>
#include <utils/debug.h>
#include <tls_socket.h>
#include <networking/host.h>
#include <credentials/sets/mem_cred.h>
29
/**
 * Print usage information
 *
 * @param out		stream the help text is written to
 * @param cmd		program name, inserted at the start of both synopsis lines
 */
static void usage(FILE *out, char *cmd)
{
	/* header and both synopsis lines are emitted by a single call */
	fprintf(out,
			"usage:\n"
			" %s --connect <address> --port <port> [--cert <file>]+ [--times <n>]\n"
			" %s --listen <address> --port <port> --key <key> [--cert <file>]+ [--times <n>]\n",
			cmd, cmd);
}
39
40 /**
41 * Client routine
42 */
43 static int client(host_t *host, identification_t *server,
44 int times, tls_cache_t *cache)
45 {
46 tls_socket_t *tls;
47 int fd, res;
48
49 while (times == -1 || times-- > 0)
50 {
51 fd = socket(AF_INET, SOCK_STREAM, 0);
52 if (fd == -1)
53 {
54 DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
55 return 1;
56 }
57 if (connect(fd, host->get_sockaddr(host),
58 *host->get_sockaddr_len(host)) == -1)
59 {
60 DBG1(DBG_TLS, "connecting to %#H failed: %s", host, strerror(errno));
61 close(fd);
62 return 1;
63 }
64 tls = tls_socket_create(FALSE, server, NULL, fd, cache);
65 if (!tls)
66 {
67 close(fd);
68 return 1;
69 }
70 res = tls->splice(tls, 0, 1) ? 0 : 1;
71 tls->destroy(tls);
72 close(fd);
73 if (res)
74 {
75 break;
76 }
77 }
78 return res;
79 }
80
81 /**
82 * Server routine
83 */
84 static int serve(host_t *host, identification_t *server,
85 int times, tls_cache_t *cache)
86 {
87 tls_socket_t *tls;
88 int fd, cfd;
89
90 fd = socket(AF_INET, SOCK_STREAM, 0);
91 if (fd == -1)
92 {
93 DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
94 return 1;
95 }
96 if (bind(fd, host->get_sockaddr(host),
97 *host->get_sockaddr_len(host)) == -1)
98 {
99 DBG1(DBG_TLS, "binding to %#H failed: %s", host, strerror(errno));
100 close(fd);
101 return 1;
102 }
103 if (listen(fd, 1) == -1)
104 {
105 DBG1(DBG_TLS, "listen to %#H failed: %m", host, strerror(errno));
106 close(fd);
107 return 1;
108 }
109
110 while (times == -1 || times-- > 0)
111 {
112 cfd = accept(fd, host->get_sockaddr(host), host->get_sockaddr_len(host));
113 if (cfd == -1)
114 {
115 DBG1(DBG_TLS, "accept failed: %s", strerror(errno));
116 close(fd);
117 return 1;
118 }
119 DBG1(DBG_TLS, "%#H connected", host);
120
121 tls = tls_socket_create(TRUE, server, NULL, cfd, cache);
122 if (!tls)
123 {
124 close(fd);
125 return 1;
126 }
127 tls->splice(tls, 0, 1);
128 DBG1(DBG_TLS, "%#H disconnected", host);
129 tls->destroy(tls);
130 }
131 close(fd);
132
133 return 0;
134 }
135
136 /**
137 * In-Memory credential set
138 */
139 static mem_cred_t *creds;
140
141 /**
142 * Load certificate from file
143 */
144 static bool load_certificate(char *filename)
145 {
146 certificate_t *cert;
147
148 cert = lib->creds->create(lib->creds, CRED_CERTIFICATE, CERT_X509,
149 BUILD_FROM_FILE, filename, BUILD_END);
150 if (!cert)
151 {
152 DBG1(DBG_TLS, "loading certificate from '%s' failed", filename);
153 return FALSE;
154 }
155 creds->add_cert(creds, TRUE, cert);
156 return TRUE;
157 }
158
159 /**
160 * Load private key from file
161 */
162 static bool load_key(char *filename)
163 {
164 private_key_t *key;
165
166 key = lib->creds->create(lib->creds, CRED_PRIVATE_KEY, KEY_RSA,
167 BUILD_FROM_FILE, filename, BUILD_END);
168 if (!key)
169 {
170 DBG1(DBG_TLS, "loading key from '%s' failed", filename);
171 return FALSE;
172 }
173 creds->add_key(creds, key);
174 return TRUE;
175 }
176
177 /**
178 * TLS debug level
179 */
180 static level_t tls_level = 1;
181
182 static void dbg_tls(debug_t group, level_t level, char *fmt, ...)
183 {
184 if ((group == DBG_TLS && level <= tls_level) || level <= 1)
185 {
186 va_list args;
187
188 va_start(args, fmt);
189 vfprintf(stderr, fmt, args);
190 fprintf(stderr, "\n");
191 va_end(args);
192 }
193 }
194
195 /**
196 * Cleanup
197 */
198 static void cleanup()
199 {
200 lib->credmgr->remove_set(lib->credmgr, &creds->set);
201 creds->destroy(creds);
202 library_deinit();
203 }
204
205 /**
206 * Initialize library
207 */
208 static void init()
209 {
210 library_init(NULL);
211
212 dbg = dbg_tls;
213
214 lib->plugins->load(lib->plugins, NULL, PLUGINS);
215
216 creds = mem_cred_create();
217 lib->credmgr->add_set(lib->credmgr, &creds->set);
218
219 atexit(cleanup);
220 }
221
222 int main(int argc, char *argv[])
223 {
224 char *address = NULL;
225 bool listen = FALSE;
226 int port = 0, times = -1, res;
227 identification_t *server;
228 tls_cache_t *cache;
229 host_t *host;
230
231 init();
232
233 while (TRUE)
234 {
235 struct option long_opts[] = {
236 {"help", no_argument, NULL, 'h' },
237 {"connect", required_argument, NULL, 'c' },
238 {"listen", required_argument, NULL, 'l' },
239 {"port", required_argument, NULL, 'p' },
240 {"cert", required_argument, NULL, 'x' },
241 {"key", required_argument, NULL, 'k' },
242 {"times", required_argument, NULL, 't' },
243 {"debug", required_argument, NULL, 'd' },
244 {0,0,0,0 }
245 };
246 switch (getopt_long(argc, argv, "", long_opts, NULL))
247 {
248 case EOF:
249 break;
250 case 'h':
251 usage(stdout, argv[0]);
252 return 0;
253 case 'x':
254 if (!load_certificate(optarg))
255 {
256 return 1;
257 }
258 continue;
259 case 'k':
260 if (!load_key(optarg))
261 {
262 return 1;
263 }
264 continue;
265 case 'l':
266 listen = TRUE;
267 /* fall */
268 case 'c':
269 if (address)
270 {
271 usage(stderr, argv[0]);
272 return 1;
273 }
274 address = optarg;
275 continue;
276 case 'p':
277 port = atoi(optarg);
278 continue;
279 case 't':
280 times = atoi(optarg);
281 continue;
282 case 'd':
283 tls_level = atoi(optarg);
284 continue;
285 default:
286 usage(stderr, argv[0]);
287 return 1;
288 }
289 break;
290 }
291 if (!port || !address)
292 {
293 usage(stderr, argv[0]);
294 return 1;
295 }
296 host = host_create_from_dns(address, 0, port);
297 if (!host)
298 {
299 DBG1(DBG_TLS, "resolving hostname %s failed", address);
300 return 1;
301 }
302 server = identification_create_from_string(address);
303 cache = tls_cache_create(100, 30);
304 if (listen)
305 {
306 res = serve(host, server, times, cache);
307 }
308 else
309 {
310 res = client(host, server, times, cache);
311 }
312 cache->destroy(cache);
313 host->destroy(host);
314 server->destroy(server);
315 return res;
316 }
317