child-rekey: Suppress updown event when deleting redundant CHILD_SAs
[strongswan.git] / scripts / tls_test.c
index b4d11e6..84a32f9 100644 (file)
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <getopt.h>
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <getopt.h>
+#include <errno.h>
+#include <string.h>
 
 #include <library.h>
 
 #include <library.h>
-#include <debug.h>
+#include <utils/debug.h>
 #include <tls_socket.h>
 #include <tls_socket.h>
-#include <utils/host.h>
+#include <networking/host.h>
 #include <credentials/sets/mem_cred.h>
 
 /**
 #include <credentials/sets/mem_cred.h>
 
 /**
 static void usage(FILE *out, char *cmd)
 {
        fprintf(out, "usage:\n");
 static void usage(FILE *out, char *cmd)
 {
        fprintf(out, "usage:\n");
-       fprintf(out, "  %s --connect <address> --port <port> [--cert <file>]+\n", cmd);
-       fprintf(out, "  %s --listen <address> --port <port> --key <key> [--cert <file>]+ --oneshot\n", cmd);
+       fprintf(out, "  %s --connect <address> --port <port> [--key <key] [--cert <file>]+ [--times <n>]\n", cmd);
+       fprintf(out, "  %s --listen <address> --port <port> --key <key> [--cert <file>]+ [--times <n>]\n", cmd);
 }
 
 /**
 }
 
 /**
- * Stream between stdio and TLS socket
+ * Check, as client, if we have a client certificate with private key
  */
  */
-static int stream(int fd, tls_socket_t *tls)
+static identification_t *find_client_id()
 {
 {
-       while (TRUE)
-       {
-               fd_set set;
-               chunk_t data;
-
-               FD_ZERO(&set);
-               FD_SET(fd, &set);
-               FD_SET(0, &set);
+       identification_t *client = NULL, *keyid;
+       enumerator_t *enumerator;
+       certificate_t *cert;
+       public_key_t *pubkey;
+       private_key_t *privkey;
+       chunk_t chunk;
 
 
-               if (select(fd + 1, &set, NULL, NULL, NULL) == -1)
-               {
-                       return 1;
-               }
-               if (FD_ISSET(fd, &set))
-               {
-                       if (!tls->read(tls, &data))
-                       {
-                               DBG1(DBG_TLS, "TLS read error/end\n");
-                               return 1;
-                       }
-                       if (data.len)
-                       {
-                               ignore_result(write(1, data.ptr, data.len));
-                               free(data.ptr);
-                       }
-               }
-               if (FD_ISSET(0, &set))
+       enumerator = lib->credmgr->create_cert_enumerator(lib->credmgr,
+                                                                                       CERT_X509, KEY_ANY, NULL, FALSE);
+       while (enumerator->enumerate(enumerator, &cert))
+       {
+               pubkey = cert->get_public_key(cert);
+               if (pubkey)
                {
                {
-                       char buf[1024];
-                       ssize_t len;
-
-                       len = read(0, buf, sizeof(buf));
-                       if (len == 0)
+                       if (pubkey->get_fingerprint(pubkey, KEYID_PUBKEY_SHA1, &chunk))
                        {
                        {
-                               return 0;
-                       }
-                       if (len > 0)
-                       {
-                               if (!tls->write(tls, chunk_create(buf, len)))
+                               keyid = identification_create_from_encoding(ID_KEY_ID, chunk);
+                               privkey = lib->credmgr->get_private(lib->credmgr,
+                                                                       pubkey->get_type(pubkey), keyid, NULL);
+                               keyid->destroy(keyid);
+                               if (privkey)
                                {
                                {
-                                       DBG1(DBG_TLS, "TLS write error\n");
-                                       return 1;
+                                       client = cert->get_subject(cert);
+                                       client = client->clone(client);
+                                       privkey->destroy(privkey);
                                }
                        }
                                }
                        }
+                       pubkey->destroy(pubkey);
+               }
+               if (client)
+               {
+                       break;
                }
        }
                }
        }
+       enumerator->destroy(enumerator);
+
+       return client;
 }
 
 /**
  * Client routine
  */
 }
 
 /**
  * Client routine
  */
-static int client(int fd, host_t *host, identification_t *server)
+static int run_client(host_t *host, identification_t *server,
+                                         identification_t *client, int times, tls_cache_t *cache)
 {
        tls_socket_t *tls;
 {
        tls_socket_t *tls;
-       int res;
+       int fd, res;
 
 
-       if (connect(fd, host->get_sockaddr(host),
-                               *host->get_sockaddr_len(host)) == -1)
+       while (times == -1 || times-- > 0)
        {
        {
-               DBG1(DBG_TLS, "connecting to %#H failed: %m\n", host);
-               return 1;
-       }
-       tls = tls_socket_create(FALSE, server, NULL, fd);
-       if (!tls)
-       {
-               return 1;
+               fd = socket(AF_INET, SOCK_STREAM, 0);
+               if (fd == -1)
+               {
+                       DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
+                       return 1;
+               }
+               if (connect(fd, host->get_sockaddr(host),
+                                       *host->get_sockaddr_len(host)) == -1)
+               {
+                       DBG1(DBG_TLS, "connecting to %#H failed: %s", host, strerror(errno));
+                       close(fd);
+                       return 1;
+               }
+               tls = tls_socket_create(FALSE, server, client, fd, cache, TLS_1_2, TRUE);
+               if (!tls)
+               {
+                       close(fd);
+                       return 1;
+               }
+               res = tls->splice(tls, 0, 1) ? 0 : 1;
+               tls->destroy(tls);
+               close(fd);
+               if (res)
+               {
+                       break;
+               }
        }
        }
-       res = stream(fd, tls);
-       tls->destroy(tls);
        return res;
 }
 
 /**
  * Server routine
  */
        return res;
 }
 
 /**
  * Server routine
  */
-static int serve(int fd, host_t *host, identification_t *server, bool oneshot)
+static int serve(host_t *host, identification_t *server,
+                                int times, tls_cache_t *cache)
 {
        tls_socket_t *tls;
 {
        tls_socket_t *tls;
-       int cfd;
+       int fd, cfd;
 
 
+       fd = socket(AF_INET, SOCK_STREAM, 0);
+       if (fd == -1)
+       {
+               DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
+               return 1;
+       }
        if (bind(fd, host->get_sockaddr(host),
                         *host->get_sockaddr_len(host)) == -1)
        {
        if (bind(fd, host->get_sockaddr(host),
                         *host->get_sockaddr_len(host)) == -1)
        {
-               DBG1(DBG_TLS, "binding to %#H failed: %m\n", host);
+               DBG1(DBG_TLS, "binding to %#H failed: %s", host, strerror(errno));
+               close(fd);
                return 1;
        }
        if (listen(fd, 1) == -1)
        {
                return 1;
        }
        if (listen(fd, 1) == -1)
        {
-               DBG1(DBG_TLS, "listen to %#H failed: %m\n", host);
+               DBG1(DBG_TLS, "listen to %#H failed: %m", host, strerror(errno));
+               close(fd);
                return 1;
        }
 
                return 1;
        }
 
-       do
+       while (times == -1 || times-- > 0)
        {
                cfd = accept(fd, host->get_sockaddr(host), host->get_sockaddr_len(host));
                if (cfd == -1)
                {
        {
                cfd = accept(fd, host->get_sockaddr(host), host->get_sockaddr_len(host));
                if (cfd == -1)
                {
-                       DBG1(DBG_TLS, "accept failed: %m\n");
+                       DBG1(DBG_TLS, "accept failed: %s", strerror(errno));
+                       close(fd);
                        return 1;
                }
                        return 1;
                }
-               DBG1(DBG_TLS, "%#H connected\n", host);
+               DBG1(DBG_TLS, "%#H connected", host);
 
 
-               tls = tls_socket_create(TRUE, server, NULL, cfd);
+               tls = tls_socket_create(TRUE, server, NULL, cfd, cache, TLS_1_2, TRUE);
                if (!tls)
                {
                if (!tls)
                {
+                       close(fd);
                        return 1;
                }
                        return 1;
                }
-               stream(cfd, tls);
-               DBG1(DBG_TLS, "%#H disconnected\n", host);
+               tls->splice(tls, 0, 1);
+               DBG1(DBG_TLS, "%#H disconnected", host);
                tls->destroy(tls);
        }
                tls->destroy(tls);
        }
-       while (!oneshot);
+       close(fd);
 
        return 0;
 }
 
        return 0;
 }
@@ -172,7 +193,7 @@ static bool load_certificate(char *filename)
                                                          BUILD_FROM_FILE, filename, BUILD_END);
        if (!cert)
        {
                                                          BUILD_FROM_FILE, filename, BUILD_END);
        if (!cert)
        {
-               DBG1(DBG_TLS, "loading certificate from '%s' failed\n", filename);
+               DBG1(DBG_TLS, "loading certificate from '%s' failed", filename);
                return FALSE;
        }
        creds->add_cert(creds, TRUE, cert);
                return FALSE;
        }
        creds->add_cert(creds, TRUE, cert);
@@ -190,7 +211,7 @@ static bool load_key(char *filename)
                                                          BUILD_FROM_FILE, filename, BUILD_END);
        if (!key)
        {
                                                          BUILD_FROM_FILE, filename, BUILD_END);
        if (!key)
        {
-               DBG1(DBG_TLS, "loading key from '%s' failed\n", filename);
+               DBG1(DBG_TLS, "loading key from '%s' failed", filename);
                return FALSE;
        }
        creds->add_key(creds, key);
                return FALSE;
        }
        creds->add_key(creds, key);
@@ -230,11 +251,11 @@ static void cleanup()
  */
 static void init()
 {
  */
 static void init()
 {
-       library_init(NULL);
+       library_init(NULL, "tls_test");
 
        dbg = dbg_tls;
 
 
        dbg = dbg_tls;
 
-       lib->plugins->load(lib->plugins, NULL, PLUGINS);
+       lib->plugins->load(lib->plugins, PLUGINS);
 
        creds = mem_cred_create();
        lib->credmgr->add_set(lib->credmgr, &creds->set);
 
        creds = mem_cred_create();
        lib->credmgr->add_set(lib->credmgr, &creds->set);
@@ -245,9 +266,10 @@ static void init()
 int main(int argc, char *argv[])
 {
        char *address = NULL;
 int main(int argc, char *argv[])
 {
        char *address = NULL;
-       bool listen = FALSE, oneshot = FALSE;
-       int port = 0, fd, res;
-       identification_t *server;
+       bool listen = FALSE;
+       int port = 0, times = -1, res;
+       identification_t *server, *client;
+       tls_cache_t *cache;
        host_t *host;
 
        init();
        host_t *host;
 
        init();
@@ -261,7 +283,7 @@ int main(int argc, char *argv[])
                        {"port",                required_argument,              NULL,           'p' },
                        {"cert",                required_argument,              NULL,           'x' },
                        {"key",                 required_argument,              NULL,           'k' },
                        {"port",                required_argument,              NULL,           'p' },
                        {"cert",                required_argument,              NULL,           'x' },
                        {"key",                 required_argument,              NULL,           'k' },
-                       {"oneshot",             no_argument,                    NULL,           'o' },
+                       {"times",               required_argument,              NULL,           't' },
                        {"debug",               required_argument,              NULL,           'd' },
                        {0,0,0,0 }
                };
                        {"debug",               required_argument,              NULL,           'd' },
                        {0,0,0,0 }
                };
@@ -298,8 +320,8 @@ int main(int argc, char *argv[])
                        case 'p':
                                port = atoi(optarg);
                                continue;
                        case 'p':
                                port = atoi(optarg);
                                continue;
-                       case 'o':
-                               oneshot = TRUE;
+                       case 't':
+                               times = atoi(optarg);
                                continue;
                        case 'd':
                                tls_level = atoi(optarg);
                                continue;
                        case 'd':
                                tls_level = atoi(optarg);
@@ -315,37 +337,26 @@ int main(int argc, char *argv[])
                usage(stderr, argv[0]);
                return 1;
        }
                usage(stderr, argv[0]);
                return 1;
        }
-       if (oneshot && !listen)
-       {
-               usage(stderr, argv[0]);
-               return 1;
-       }
-
-       fd = socket(AF_INET, SOCK_STREAM, 0);
-       if (fd == -1)
-       {
-               DBG1(DBG_TLS, "opening socket failed: %m\n");
-               return 1;
-       }
        host = host_create_from_dns(address, 0, port);
        if (!host)
        {
        host = host_create_from_dns(address, 0, port);
        if (!host)
        {
-               DBG1(DBG_TLS, "resolving hostname %s failed\n", address);
-               close(fd);
+               DBG1(DBG_TLS, "resolving hostname %s failed", address);
                return 1;
        }
        server = identification_create_from_string(address);
                return 1;
        }
        server = identification_create_from_string(address);
+       cache = tls_cache_create(100, 30);
        if (listen)
        {
        if (listen)
        {
-               res = serve(fd, host, server, oneshot);
+               res = serve(host, server, times, cache);
        }
        else
        {
        }
        else
        {
-               res = client(fd, host, server);
+               client = find_client_id();
+               res = run_client(host, server, client, times, cache);
+               DESTROY_IF(client);
        }
        }
-       close(fd);
+       cache->destroy(cache);
        host->destroy(host);
        server->destroy(server);
        return res;
 }
        host->destroy(host);
        server->destroy(server);
        return res;
 }
-