pt-tls-client: Support for TPM keyids
[strongswan.git] / src / pt-tls-client / pt-tls-client.c
index 4e108ad..a29d37a 100644 (file)
@@ -1,6 +1,7 @@
 /*
  * Copyright (C) 2010-2013 Martin Willi, revosec AG
- * Copyright (C) 2013 Andreas Steffen, HSR Hochschule für Technik Rapperswil
+ * Copyright (C) 2013-2015 Andreas Steffen
+ * HSR Hochschule für Technik Rapperswil
  *
  * This program is free software; you can redistribute it and/or modify it
  * under the terms of the GNU General Public License as published by the
 #include <unistd.h>
 #include <stdio.h>
 #include <sys/types.h>
-#include <sys/socket.h>
 #include <getopt.h>
 #include <errno.h>
 #include <string.h>
 #include <stdlib.h>
+#ifdef HAVE_SYSLOG
+#include <syslog.h>
+#endif
 
 #include <pt_tls.h>
 #include <pt_tls_client.h>
 /**
  * Print usage information
  */
-static void usage(FILE *out, char *cmd)
+static void usage(FILE *out)
 {
-       fprintf(out, "usage:\n");
-       fprintf(out, "  %s --connect <address> [--port <port>] [--cert <file>]+\n", cmd);
-       fprintf(out, "               [--client <client-id>] [--secret <password>]\n");
-       fprintf(out, "               [--optionsfrom <filename>]\n");
+       fprintf(out,
+               "Usage: pt-tls  --connect <hostname|address> [--port <port>]\n"
+               "              [--cert <file>]+ [--keyid <hex>|--key <file>]\n"
+               "              [--key-type rsa|ecdsa] [--client <client-id>]\n"
+               "              [--secret <password>] [--optionsfrom <filename>]\n"
+               "              [--quiet] [--debug <level>]\n");
 }
 
 /**
  * Client routine
  */
-static int client(char *address, u_int16_t port, char *identity)
+static int client(char *address, uint16_t port, char *identity)
 {
        pt_tls_client_t *assessment;
        tls_t *tnccs;
-       identification_t *server, *client;
-       host_t *host;
+       identification_t *server_id, *client_id;
+       host_t *server_ip, *client_ip;
        status_t status;
 
-       host = host_create_from_dns(address, AF_UNSPEC, port);
-       if (!host)
+       server_ip = host_create_from_dns(address, AF_UNSPEC, port);
+       if (!server_ip)
        {
                return 1;
        }
-       server = identification_create_from_string(address);
-       client = identification_create_from_string(identity);
+
+       client_ip = host_create_any(server_ip->get_family(server_ip));
+       if (!client_ip)
+       {
+               server_ip->destroy(server_ip);
+               return 1;
+       }
+       server_id = identification_create_from_string(address);
+       client_id = identification_create_from_string(identity);
+
        tnccs = (tls_t*)tnc->tnccs->create_instance(tnc->tnccs, TNCCS_2_0, FALSE,
-                                                               server, client, TNC_IFT_TLS_2_0, NULL);
+                                                               server_id, client_id, server_ip, client_ip,
+                                                               TNC_IFT_TLS_2_0, NULL);
+       client_ip->destroy(client_ip);
+
        if (!tnccs)
        {
                fprintf(stderr, "loading TNCCS failed: %s\n", PLUGINS);
-               host->destroy(host);
-               server->destroy(server);
-               client->destroy(client);
+               server_ip->destroy(server_ip);
+               server_id->destroy(server_id);
+               client_id->destroy(client_id);
                return 1;
        }
-       assessment = pt_tls_client_create(host, server, client);
+       assessment = pt_tls_client_create(server_ip, server_id, client_id);
        status = assessment->run_assessment(assessment, (tnccs_t*)tnccs);
        assessment->destroy(assessment);
        tnccs->destroy(tnccs);
-       return status;
+
+       return (status != SUCCESS);
 }
 
 
@@ -105,15 +122,26 @@ static bool load_certificate(char *filename)
 /**
  * Load private key from file
  */
-static bool load_key(char *filename)
+static bool load_key(char *keyid, char *filename, key_type_t type)
 {
        private_key_t *key;
+       chunk_t chunk;
 
-       key = lib->creds->create(lib->creds, CRED_PRIVATE_KEY, KEY_RSA,
-                                                        BUILD_FROM_FILE, filename, BUILD_END);
+       if (keyid)
+       {
+               chunk = chunk_from_hex(chunk_create(keyid, strlen(keyid)), NULL);
+               key = lib->creds->create(lib->creds, CRED_PRIVATE_KEY, KEY_ANY,
+                                                                BUILD_PKCS11_KEYID, chunk, BUILD_END);
+               chunk_free(&chunk);
+       }
+       else
+       {
+               key = lib->creds->create(lib->creds, CRED_PRIVATE_KEY, type,
+                                                                BUILD_FROM_FILE, filename, BUILD_END);
+       }
        if (!key)
        {
-               DBG1(DBG_TLS, "loading key from '%s' failed", filename);
+               DBG1(DBG_TLS, "loading key from '%s' failed", keyid ? keyid : filename);
                return FALSE;
        }
        creds->add_key(creds, key);
@@ -121,21 +149,71 @@ static bool load_key(char *filename)
 }
 
 /**
- * Debug level
+ * Logging and debug level
  */
-static level_t pt_tls_level = 1;
+static bool log_to_stderr = TRUE;
+#ifdef HAVE_SYSLOG
+static bool log_to_syslog = TRUE;
+#endif /* HAVE_SYSLOG */
+static level_t default_loglevel = 1;
 
 static void dbg_pt_tls(debug_t group, level_t level, char *fmt, ...)
 {
-       if (level <= pt_tls_level)
+       va_list args;
+
+       if (level <= default_loglevel)
        {
-               va_list args;
+               if (log_to_stderr)
+               {
+                       va_start(args, fmt);
+                       vfprintf(stderr, fmt, args);
+                       va_end(args);
+                       fprintf(stderr, "\n");
+               }
+#ifdef HAVE_SYSLOG
+               if (log_to_syslog)
+               {
+                       char buffer[8192];
+                       char *current = buffer, *next;
+
+                       /* write in memory buffer first */
+                       va_start(args, fmt);
+                       vsnprintf(buffer, sizeof(buffer), fmt, args);
+                       va_end(args);
+
+                       /* do a syslog with every line */
+                       while (current)
+                       {
+                               next = strchr(current, '\n');
+                               if (next)
+                               {
+                                       *(next++) = '\0';
+                               }
+                               syslog(LOG_INFO, "%s\n", current);
+                               current = next;
+                       }
+               }
+#endif /* HAVE_SYSLOG */
+       }
+}
 
-               va_start(args, fmt);
-               vfprintf(stderr, fmt, args);
-               fprintf(stderr, "\n");
-               va_end(args);
+/**
+ * Initialize logging to stderr/syslog
+ */
+static void init_log(const char *program)
+{
+       dbg = dbg_pt_tls;
+
+       if (log_to_stderr)
+       {
+               setbuf(stderr, NULL);
        }
+#ifdef HAVE_SYSLOG
+       if (log_to_syslog)
+       {
+               openlog(program, LOG_CONS | LOG_NDELAY | LOG_PID, LOG_AUTHPRIV);
+       }
+#endif /* HAVE_SYSLOG */
 }
 
 /**
@@ -166,19 +244,20 @@ static void init()
                        PLUGIN_PROVIDE(CUSTOM, "pt-tls-client"),
                                PLUGIN_DEPENDS(CUSTOM, "tnccs-manager"),
        };
-       library_init(NULL);
+       library_init(NULL, "pt-tls-client");
        libtnccs_init();
 
-       dbg = dbg_pt_tls;
+       init_log("pt-tls-client");
        options = options_create();
 
        lib->plugins->add_static_features(lib->plugins, "pt-tls-client", features,
-                                                                         countof(features), TRUE);
+                                                                         countof(features), TRUE, NULL, NULL);
        if (!lib->plugins->load(lib->plugins,
                        lib->settings->get_str(lib->settings, "pt-tls-client.load", PLUGINS)))
        {
                exit(SS_RC_INITIALIZATION_FAILED);
        }
+       lib->plugins->status(lib->plugins, LEVEL_CTRL);
 
        creds = mem_cred_create();
        lib->credmgr->add_set(lib->credmgr, &creds->set);
@@ -189,6 +268,8 @@ static void init()
 int main(int argc, char *argv[])
 {
        char *address = NULL, *identity = "%any", *secret = NULL;
+       char *keyid = NULL, *key_file = NULL;
+       key_type_t key_type = KEY_RSA;
        int port = PT_TLS_PORT;
 
        init();
@@ -202,7 +283,11 @@ int main(int argc, char *argv[])
                        {"secret",              required_argument,              NULL,           's' },
                        {"port",                required_argument,              NULL,           'p' },
                        {"cert",                required_argument,              NULL,           'x' },
+                       {"keyid",               required_argument,              NULL,           'K' },
                        {"key",                 required_argument,              NULL,           'k' },
+                       {"key-type",    required_argument,              NULL,           't' },
+                       {"mutual",              no_argument,                    NULL,           'm' },
+                       {"quiet",               no_argument,                    NULL,           'q' },
                        {"debug",               required_argument,              NULL,           'd' },
                        {"optionsfrom", required_argument,              NULL,           '+' },
                        {0,0,0,0 }
@@ -211,56 +296,81 @@ int main(int argc, char *argv[])
                {
                        case EOF:
                                break;
-                       case 'h':
-                               usage(stdout, argv[0]);
+                       case 'h':                       /* --help */
+                               usage(stdout);
                                return 0;
-                       case 'x':
+                       case 'x':                       /* --cert <file> */
                                if (!load_certificate(optarg))
                                {
                                        return 1;
                                }
                                continue;
-                       case 'k':
-                               if (!load_key(optarg))
+                       case 'K':                       /* --keyid <hex> */
+                               keyid = optarg;
+                               continue;
+                       case 'k':                       /* --key <file> */
+                               key_file = optarg;
+                               continue;
+                       case 't':                       /* --key-type <type> */
+                               if (strcaseeq(optarg, "ecdsa"))
                                {
-                                       return 1;
+                                       key_type = KEY_ECDSA;
+                               }
+                               else if (strcaseeq(optarg, "rsa"))
+                               {
+                                       key_type = KEY_RSA;
+                               }
+                               else
+                               {
+                                       key_type = KEY_ANY;
                                }
                                continue;
-                       case 'c':
+                       case 'c':                       /* --connect <hostname|address> */
                                if (address)
                                {
-                                       usage(stderr, argv[0]);
+                                       usage(stderr);
                                        return 1;
                                }
                                address = optarg;
                                continue;
-                       case 'i':
+                       case 'i':                       /* --client <client-id> */
                                identity = optarg;
                                continue;
-                       case 's':
+                       case 's':                       /* --secret <password> */
                                secret = optarg;
                                continue;
-                       case 'p':
+                       case 'p':                       /* --port <port> */
                                port = atoi(optarg);
                                continue;
-                       case 'd':
-                               pt_tls_level = atoi(optarg);
+                       case 'm':                       /* --mutual */
+                               lib->settings->set_bool(lib->settings,
+                                                               "%s.plugins.tnccs-20.mutual", TRUE, lib->ns);
+                               continue;
+                       case 'q':               /* --quiet */
+                               log_to_stderr = FALSE;
                                continue;
-                       case '+':       /* --optionsfrom <filename> */
+                       case 'd':                       /* --debug <level> */
+                               default_loglevel = atoi(optarg);
+                               continue;
+                       case '+':                       /* --optionsfrom <filename> */
                                if (!options->from(options, optarg, &argc, &argv, optind))
                                {
                                        return 1;
                                }
                                continue;
                        default:
-                               usage(stderr, argv[0]);
+                               usage(stderr);
                                return 1;
                }
                break;
        }
        if (!address)
        {
-               usage(stderr, argv[0]);
+               usage(stderr);
+               return 1;
+       }
+       if ((keyid || key_file) && !load_key(keyid, key_file, key_type))
+       {
                return 1;
        }
        if (secret)
@@ -269,6 +379,5 @@ int main(int argc, char *argv[])
                                                                                chunk_clone(chunk_from_str(secret))),
                                                        identification_create_from_string(identity), NULL);
        }
-
        return client(address, port, identity);
 }