Implement a SASL PLAIN mechanism using shared secrets
[strongswan.git] / scripts / tls_test.c
index b4d11e6..d0d259e 100644 (file)
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <getopt.h>
+#include <errno.h>
+#include <string.h>
 
 #include <library.h>
-#include <debug.h>
+#include <utils/debug.h>
 #include <tls_socket.h>
-#include <utils/host.h>
+#include <networking/host.h>
 #include <credentials/sets/mem_cred.h>
 
 /**
 static void usage(FILE *out, char *cmd)
 {
        fprintf(out, "usage:\n");
-       fprintf(out, "  %s --connect <address> --port <port> [--cert <file>]+\n", cmd);
-       fprintf(out, "  %s --listen <address> --port <port> --key <key> [--cert <file>]+ --oneshot\n", cmd);
+       fprintf(out, "  %s --connect <address> --port <port> [--cert <file>]+ [--times <n>]\n", cmd);
+       fprintf(out, "  %s --listen <address> --port <port> --key <key> [--cert <file>]+ [--times <n>]\n", cmd);
 }
 
 /**
- * Stream between stdio and TLS socket
+ * Client routine
  */
-static int stream(int fd, tls_socket_t *tls)
+static int client(host_t *host, identification_t *server,
+                                 int times, tls_cache_t *cache)
 {
-       while (TRUE)
-       {
-               fd_set set;
-               chunk_t data;
-
-               FD_ZERO(&set);
-               FD_SET(fd, &set);
-               FD_SET(0, &set);
+       tls_socket_t *tls;
+       int fd, res;
 
-               if (select(fd + 1, &set, NULL, NULL, NULL) == -1)
+       while (times == -1 || times-- > 0)
+       {
+               fd = socket(AF_INET, SOCK_STREAM, 0);
+               if (fd == -1)
                {
+                       DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
                        return 1;
                }
-               if (FD_ISSET(fd, &set))
+               if (connect(fd, host->get_sockaddr(host),
+                                       *host->get_sockaddr_len(host)) == -1)
                {
-                       if (!tls->read(tls, &data))
-                       {
-                               DBG1(DBG_TLS, "TLS read error/end\n");
-                               return 1;
-                       }
-                       if (data.len)
-                       {
-                               ignore_result(write(1, data.ptr, data.len));
-                               free(data.ptr);
-                       }
+                       DBG1(DBG_TLS, "connecting to %#H failed: %s", host, strerror(errno));
+                       close(fd);
+                       return 1;
                }
-               if (FD_ISSET(0, &set))
+               tls = tls_socket_create(FALSE, server, NULL, fd, cache);
+               if (!tls)
                {
-                       char buf[1024];
-                       ssize_t len;
-
-                       len = read(0, buf, sizeof(buf));
-                       if (len == 0)
-                       {
-                               return 0;
-                       }
-                       if (len > 0)
-                       {
-                               if (!tls->write(tls, chunk_create(buf, len)))
-                               {
-                                       DBG1(DBG_TLS, "TLS write error\n");
-                                       return 1;
-                               }
-                       }
+                       close(fd);
+                       return 1;
+               }
+               res = tls->splice(tls, 0, 1) ? 0 : 1;
+               tls->destroy(tls);
+               close(fd);
+               if (res)
+               {
+                       break;
                }
        }
+       return res;
 }
 
 /**
- * Client routine
+ * Server routine
  */
-static int client(int fd, host_t *host, identification_t *server)
+static int serve(host_t *host, identification_t *server,
+                                int times, tls_cache_t *cache)
 {
        tls_socket_t *tls;
-       int res;
+       int fd, cfd;
 
-       if (connect(fd, host->get_sockaddr(host),
-                               *host->get_sockaddr_len(host)) == -1)
-       {
-               DBG1(DBG_TLS, "connecting to %#H failed: %m\n", host);
-               return 1;
-       }
-       tls = tls_socket_create(FALSE, server, NULL, fd);
-       if (!tls)
+       fd = socket(AF_INET, SOCK_STREAM, 0);
+       if (fd == -1)
        {
+               DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
                return 1;
        }
-       res = stream(fd, tls);
-       tls->destroy(tls);
-       return res;
-}
-
-/**
- * Server routine
- */
-static int serve(int fd, host_t *host, identification_t *server, bool oneshot)
-{
-       tls_socket_t *tls;
-       int cfd;
-
        if (bind(fd, host->get_sockaddr(host),
                         *host->get_sockaddr_len(host)) == -1)
        {
-               DBG1(DBG_TLS, "binding to %#H failed: %m\n", host);
+               DBG1(DBG_TLS, "binding to %#H failed: %s", host, strerror(errno));
+               close(fd);
                return 1;
        }
        if (listen(fd, 1) == -1)
        {
-               DBG1(DBG_TLS, "listen to %#H failed: %m\n", host);
+               DBG1(DBG_TLS, "listen to %#H failed: %m", host, strerror(errno));
+               close(fd);
                return 1;
        }
 
-       do
+       while (times == -1 || times-- > 0)
        {
                cfd = accept(fd, host->get_sockaddr(host), host->get_sockaddr_len(host));
                if (cfd == -1)
                {
-                       DBG1(DBG_TLS, "accept failed: %m\n");
+                       DBG1(DBG_TLS, "accept failed: %s", strerror(errno));
+                       close(fd);
                        return 1;
                }
-               DBG1(DBG_TLS, "%#H connected\n", host);
+               DBG1(DBG_TLS, "%#H connected", host);
 
-               tls = tls_socket_create(TRUE, server, NULL, cfd);
+               tls = tls_socket_create(TRUE, server, NULL, cfd, cache);
                if (!tls)
                {
+                       close(fd);
                        return 1;
                }
-               stream(cfd, tls);
-               DBG1(DBG_TLS, "%#H disconnected\n", host);
+               tls->splice(tls, 0, 1);
+               DBG1(DBG_TLS, "%#H disconnected", host);
                tls->destroy(tls);
        }
-       while (!oneshot);
+       close(fd);
 
        return 0;
 }
@@ -172,7 +149,7 @@ static bool load_certificate(char *filename)
                                                          BUILD_FROM_FILE, filename, BUILD_END);
        if (!cert)
        {
-               DBG1(DBG_TLS, "loading certificate from '%s' failed\n", filename);
+               DBG1(DBG_TLS, "loading certificate from '%s' failed", filename);
                return FALSE;
        }
        creds->add_cert(creds, TRUE, cert);
@@ -190,7 +167,7 @@ static bool load_key(char *filename)
                                                          BUILD_FROM_FILE, filename, BUILD_END);
        if (!key)
        {
-               DBG1(DBG_TLS, "loading key from '%s' failed\n", filename);
+               DBG1(DBG_TLS, "loading key from '%s' failed", filename);
                return FALSE;
        }
        creds->add_key(creds, key);
@@ -245,9 +222,10 @@ static void init()
 int main(int argc, char *argv[])
 {
        char *address = NULL;
-       bool listen = FALSE, oneshot = FALSE;
-       int port = 0, fd, res;
+       bool listen = FALSE;
+       int port = 0, times = -1, res;
        identification_t *server;
+       tls_cache_t *cache;
        host_t *host;
 
        init();
@@ -261,7 +239,7 @@ int main(int argc, char *argv[])
                        {"port",                required_argument,              NULL,           'p' },
                        {"cert",                required_argument,              NULL,           'x' },
                        {"key",                 required_argument,              NULL,           'k' },
-                       {"oneshot",             no_argument,                    NULL,           'o' },
+                       {"times",               required_argument,              NULL,           't' },
                        {"debug",               required_argument,              NULL,           'd' },
                        {0,0,0,0 }
                };
@@ -298,8 +276,8 @@ int main(int argc, char *argv[])
                        case 'p':
                                port = atoi(optarg);
                                continue;
-                       case 'o':
-                               oneshot = TRUE;
+                       case 't':
+                               times = atoi(optarg);
                                continue;
                        case 'd':
                                tls_level = atoi(optarg);
@@ -315,35 +293,23 @@ int main(int argc, char *argv[])
                usage(stderr, argv[0]);
                return 1;
        }
-       if (oneshot && !listen)
-       {
-               usage(stderr, argv[0]);
-               return 1;
-       }
-
-       fd = socket(AF_INET, SOCK_STREAM, 0);
-       if (fd == -1)
-       {
-               DBG1(DBG_TLS, "opening socket failed: %m\n");
-               return 1;
-       }
        host = host_create_from_dns(address, 0, port);
        if (!host)
        {
-               DBG1(DBG_TLS, "resolving hostname %s failed\n", address);
-               close(fd);
+               DBG1(DBG_TLS, "resolving hostname %s failed", address);
                return 1;
        }
        server = identification_create_from_string(address);
+       cache = tls_cache_create(100, 30);
        if (listen)
        {
-               res = serve(fd, host, server, oneshot);
+               res = serve(host, server, times, cache);
        }
        else
        {
-               res = client(fd, host, server);
+               res = client(host, server, times, cache);
        }
-       close(fd);
+       cache->destroy(cache);
        host->destroy(host);
        server->destroy(server);
        return res;