swanctl: Automatically unload removed private keys
authorTobias Brunner <tobias@strongswan.org>
Wed, 9 Nov 2016 11:25:00 +0000 (12:25 +0100)
committerTobias Brunner <tobias@strongswan.org>
Thu, 16 Feb 2017 18:21:12 +0000 (19:21 +0100)
src/swanctl/commands/load_creds.c

index 6278f66..6c084e5 100644 (file)
@@ -1,11 +1,11 @@
 /*
- * Copyright (C) 2014 Martin Willi
- * Copyright (C) 2014 revosec AG
- *
  * Copyright (C) 2016 Tobias Brunner
  * Copyright (C) 2015 Andreas Steffen
  * HSR Hochschule fuer Technik Rapperswil
  *
+ * Copyright (C) 2014 Martin Willi
+ * Copyright (C) 2014 revosec AG
+ *
  * This program is free software; you can redistribute it and/or modify it
  * under the terms of the GNU General Public License as published by the
  * Free Software Foundation; either version 2 of the License, or (at your
 #include <credentials/sets/mem_cred.h>
 #include <credentials/sets/callback_cred.h>
 #include <credentials/containers/pkcs12.h>
+#include <collections/hashtable.h>
 
 #include <vici_cert_info.h>
 
+#define HASH_SIZE_SHA1_HEX (2 * HASH_SIZE_SHA1)
+
+/**
+ * Context used to track loaded secrets
+ */
+typedef struct {
+       /** vici connection */
+       vici_conn_t *conn;
+       /** format options */
+       command_format_options_t format;
+       /** read setting */
+       settings_t *cfg;
+       /** don't prompt user for password */
+       bool noprompt;
+       /** list of key ids of loaded private keys */
+       hashtable_t *keys;
+} load_ctx_t;
+
 /**
  * Load a single certificate over vici
  */
-static bool load_cert(vici_conn_t *conn, command_format_options_t format,
-                                         char *dir, certificate_type_t type, x509_flag_t flag,
-                                         chunk_t data)
+static bool load_cert(load_ctx_t *ctx, char *dir, certificate_type_t type,
+                                         x509_flag_t flag, chunk_t data)
 {
        vici_req_t *req;
        vici_res_t *res;
@@ -53,15 +71,15 @@ static bool load_cert(vici_conn_t *conn, command_format_options_t format,
        }
        vici_add_key_value(req, "data", data.ptr, data.len);
 
-       res = vici_submit(req, conn);
+       res = vici_submit(req, ctx->conn);
        if (!res)
        {
                fprintf(stderr, "load-cert request failed: %s\n", strerror(errno));
                return FALSE;
        }
-       if (format & COMMAND_FORMAT_RAW)
+       if (ctx->format & COMMAND_FORMAT_RAW)
        {
-               vici_dump(res, "load-cert reply", format & COMMAND_FORMAT_PRETTY,
+               vici_dump(res, "load-cert reply", ctx->format & COMMAND_FORMAT_PRETTY,
                                  stdout);
        }
        else if (!streq(vici_find_str(res, "no", "success"), "yes"))
@@ -81,8 +99,7 @@ static bool load_cert(vici_conn_t *conn, command_format_options_t format,
 /**
  * Load certficiates from a directory
  */
-static void load_certs(vici_conn_t *conn, command_format_options_t format,
-                                          char *type_str, char *dir)
+static void load_certs(load_ctx_t *ctx, char *type_str, char *dir)
 {
        enumerator_t *enumerator;
        certificate_type_t type;
@@ -103,7 +120,7 @@ static void load_certs(vici_conn_t *conn, command_format_options_t format,
                                map = chunk_map(path, FALSE);
                                if (map)
                                {
-                                       load_cert(conn, format, path, type, flag, *map);
+                                       load_cert(ctx, path, type, flag, *map);
                                        chunk_unmap(map);
                                }
                                else
@@ -120,8 +137,7 @@ static void load_certs(vici_conn_t *conn, command_format_options_t format,
 /**
  * Load a single private key over vici
  */
-static bool load_key(vici_conn_t *conn, command_format_options_t format,
-                                        char *dir, char *type, chunk_t data)
+static bool load_key(load_ctx_t *ctx, char *dir, char *type, chunk_t data)
 {
        vici_req_t *req;
        vici_res_t *res;
@@ -140,15 +156,15 @@ static bool load_key(vici_conn_t *conn, command_format_options_t format,
        }
        vici_add_key_value(req, "data", data.ptr, data.len);
 
-       res = vici_submit(req, conn);
+       res = vici_submit(req, ctx->conn);
        if (!res)
        {
                fprintf(stderr, "load-key request failed: %s\n", strerror(errno));
                return FALSE;
        }
-       if (format & COMMAND_FORMAT_RAW)
+       if (ctx->format & COMMAND_FORMAT_RAW)
        {
-               vici_dump(res, "load-key reply", format & COMMAND_FORMAT_PRETTY,
+               vici_dump(res, "load-key reply", ctx->format & COMMAND_FORMAT_PRETTY,
                                  stdout);
        }
        else if (!streq(vici_find_str(res, "no", "success"), "yes"))
@@ -168,11 +184,12 @@ static bool load_key(vici_conn_t *conn, command_format_options_t format,
 /**
  * Load a private key of any type to vici
  */
-static bool load_key_anytype(vici_conn_t *conn, command_format_options_t format,
-                                                        char *path, private_key_t *private)
+static bool load_key_anytype(load_ctx_t *ctx, char *path,
+                                                        private_key_t *private)
 {
        bool loaded = FALSE;
-       chunk_t encoding;
+       chunk_t encoding, keyid;
+       char hex[HASH_SIZE_SHA1_HEX + 1];
 
        if (!private->get_encoding(private, PRIVKEY_ASN1_DER, &encoding))
        {
@@ -182,18 +199,25 @@ static bool load_key_anytype(vici_conn_t *conn, command_format_options_t format,
        switch (private->get_type(private))
        {
                case KEY_RSA:
-                       loaded = load_key(conn, format, path, "rsa", encoding);
+                       loaded = load_key(ctx, path, "rsa", encoding);
                        break;
                case KEY_ECDSA:
-                       loaded = load_key(conn, format, path, "ecdsa", encoding);
+                       loaded = load_key(ctx, path, "ecdsa", encoding);
                        break;
                case KEY_BLISS:
-                       loaded = load_key(conn, format, path, "bliss", encoding);
+                       loaded = load_key(ctx, path, "bliss", encoding);
                        break;
                default:
                        fprintf(stderr, "unsupported key type in '%s'\n", path);
                        break;
        }
+
+       if (loaded &&
+               private->get_fingerprint(private, KEYID_PUBKEY_SHA1, &keyid) &&
+               snprintf(hex, sizeof(hex), "%+B", &keyid) == HASH_SIZE_SHA1_HEX)
+       {
+               free(ctx->keys->remove(ctx->keys, hex));
+       }
        chunk_clear(&encoding);
        return loaded;
 }
@@ -312,7 +336,7 @@ static void* decrypt(char *name, char *type, chunk_t encoding)
 /**
  * Try to parse a potentially encrypted credential using configured secret
  */
-static void* decrypt_with_config(settings_t *cfg, char *name, char *type,
+static void* decrypt_with_config(load_ctx_t *ctx, char *name, char *type,
                                                                 chunk_t encoding)
 {
        credential_type_t credtype;
@@ -329,16 +353,16 @@ static void* decrypt_with_config(settings_t *cfg, char *name, char *type,
        }
 
        /* load all secrets for this key type */
-       enumerator = cfg->create_section_enumerator(cfg, "secrets");
+       enumerator = ctx->cfg->create_section_enumerator(ctx->cfg, "secrets");
        while (enumerator->enumerate(enumerator, &section))
        {
                if (strpfx(section, type))
                {
-                       file = cfg->get_str(cfg, "secrets.%s.file", NULL, section);
+                       file = ctx->cfg->get_str(ctx->cfg, "secrets.%s.file", NULL, section);
                        if (file && strcaseeq(file, name))
                        {
                                snprintf(buf, sizeof(buf), "secrets.%s", section);
-                               secrets = cfg->create_key_value_enumerator(cfg, buf);
+                               secrets = ctx->cfg->create_key_value_enumerator(ctx->cfg, buf);
                                while (secrets->enumerate(secrets, &key, &value))
                                {
                                        if (strpfx(key, "secret"))
@@ -382,22 +406,20 @@ static void* decrypt_with_config(settings_t *cfg, char *name, char *type,
 /**
  * Try to decrypt and load a private key
  */
-static bool load_encrypted_key(vici_conn_t *conn,
-                                                          command_format_options_t format, settings_t *cfg,
-                                                          char *rel, char *path, char *type, bool noprompt,
-                                                          chunk_t data)
+static bool load_encrypted_key(load_ctx_t *ctx,  char *rel, char *path,
+                                                          char *type, chunk_t data)
 {
        private_key_t *private;
        bool loaded = FALSE;
 
-       private = decrypt_with_config(cfg, rel, type, data);
-       if (!private && !noprompt)
+       private = decrypt_with_config(ctx, rel, type, data);
+       if (!private && !ctx->noprompt)
        {
                private = decrypt(rel, type, data);
        }
        if (private)
        {
-               loaded = load_key_anytype(conn, format, path, private);
+               loaded = load_key_anytype(ctx, path, private);
                private->destroy(private);
        }
        return loaded;
@@ -406,8 +428,7 @@ static bool load_encrypted_key(vici_conn_t *conn,
 /**
  * Load private keys from a directory
  */
-static void load_keys(vici_conn_t *conn, command_format_options_t format,
-                                         bool noprompt, settings_t *cfg, char *type, char *dir)
+static void load_keys(load_ctx_t *ctx, char *type, char *dir)
 {
        enumerator_t *enumerator;
        struct stat st;
@@ -424,10 +445,9 @@ static void load_keys(vici_conn_t *conn, command_format_options_t format,
                                map = chunk_map(path, FALSE);
                                if (map)
                                {
-                                       if (!load_encrypted_key(conn, format, cfg, rel, path, type,
-                                                                                       noprompt, *map))
+                                       if (!load_encrypted_key(ctx, rel, path, type, *map))
                                        {
-                                               load_key(conn, format, path, type, *map);
+                                               load_key(ctx, path, type, *map);
                                        }
                                        chunk_unmap(map);
                                }
@@ -445,8 +465,7 @@ static void load_keys(vici_conn_t *conn, command_format_options_t format,
 /**
  * Load credentials from a PKCS#12 container over vici
  */
-static bool load_pkcs12(vici_conn_t *conn, command_format_options_t format,
-                                               char *path, pkcs12_t *p12)
+static bool load_pkcs12(load_ctx_t *ctx, char *path, pkcs12_t *p12)
 {
        enumerator_t *enumerator;
        certificate_t *cert;
@@ -460,8 +479,7 @@ static bool load_pkcs12(vici_conn_t *conn, command_format_options_t format,
                loaded = FALSE;
                if (cert->get_encoding(cert, CERT_ASN1_DER, &encoding))
                {
-                       loaded = load_cert(conn, format, path, CERT_X509, X509_NONE,
-                                                          encoding);
+                       loaded = load_cert(ctx, path, CERT_X509, X509_NONE, encoding);
                        if (loaded)
                        {
                                fprintf(stderr, "  %Y\n", cert->get_subject(cert));
@@ -478,7 +496,7 @@ static bool load_pkcs12(vici_conn_t *conn, command_format_options_t format,
        enumerator = p12->create_key_enumerator(p12);
        while (loaded && enumerator->enumerate(enumerator, &private))
        {
-               loaded = load_key_anytype(conn, format, path, private);
+               loaded = load_key_anytype(ctx, path, private);
        }
        enumerator->destroy(enumerator);
 
@@ -488,15 +506,14 @@ static bool load_pkcs12(vici_conn_t *conn, command_format_options_t format,
 /**
  * Try to decrypt and load credentials from a container
  */
-static bool load_encrypted_container(vici_conn_t *conn,
-                                       command_format_options_t format, settings_t *cfg, char *rel,
-                                       char *path, char *type, bool noprompt, chunk_t data)
+static bool load_encrypted_container(load_ctx_t *ctx, char *rel, char *path,
+                                                                        char *type, chunk_t data)
 {
        container_t *container;
        bool loaded = FALSE;
 
-       container = decrypt_with_config(cfg, rel, type, data);
-       if (!container && !noprompt)
+       container = decrypt_with_config(ctx, rel, type, data);
+       if (!container && !ctx->noprompt)
        {
                container = decrypt(rel, type, data);
        }
@@ -505,7 +522,7 @@ static bool load_encrypted_container(vici_conn_t *conn,
                switch (container->get_type(container))
                {
                        case CONTAINER_PKCS12:
-                               loaded = load_pkcs12(conn, format, path, (pkcs12_t*)container);
+                               loaded = load_pkcs12(ctx, path, (pkcs12_t*)container);
                                break;
                        default:
                                break;
@@ -518,8 +535,7 @@ static bool load_encrypted_container(vici_conn_t *conn,
 /**
  * Load credential containers from a directory
  */
-static void load_containers(vici_conn_t *conn, command_format_options_t format,
-                                               bool noprompt, settings_t *cfg, char *type, char *dir)
+static void load_containers(load_ctx_t *ctx, char *type, char *dir)
 {
        enumerator_t *enumerator;
        struct stat st;
@@ -536,8 +552,7 @@ static void load_containers(vici_conn_t *conn, command_format_options_t format,
                                map = chunk_map(path, FALSE);
                                if (map)
                                {
-                                       load_encrypted_container(conn, format, cfg, rel, path,
-                                                                                        type, noprompt, *map);
+                                       load_encrypted_container(ctx, rel, path, type, *map);
                                        chunk_unmap(map);
                                }
                                else
@@ -554,8 +569,7 @@ static void load_containers(vici_conn_t *conn, command_format_options_t format,
 /**
  * Load a single secret over VICI
  */
-static bool load_secret(vici_conn_t *conn, settings_t *cfg,
-                                               char *section, command_format_options_t format)
+static bool load_secret(load_ctx_t *ctx, char *section)
 {
        enumerator_t *enumerator;
        vici_req_t *req;
@@ -594,7 +608,7 @@ static bool load_secret(vici_conn_t *conn, settings_t *cfg,
                return TRUE;
        }
 
-       value = cfg->get_str(cfg, "secrets.%s.secret", NULL, section);
+       value = ctx->cfg->get_str(ctx->cfg, "secrets.%s.secret", NULL, section);
        if (!value)
        {
                fprintf(stderr, "missing secret in '%s', ignored\n", section);
@@ -621,7 +635,7 @@ static bool load_secret(vici_conn_t *conn, settings_t *cfg,
 
        vici_begin_list(req, "owners");
        snprintf(buf, sizeof(buf), "secrets.%s", section);
-       enumerator = cfg->create_key_value_enumerator(cfg, buf);
+       enumerator = ctx->cfg->create_key_value_enumerator(ctx->cfg, buf);
        while (enumerator->enumerate(enumerator, &key, &value))
        {
                if (strpfx(key, "id"))
@@ -632,15 +646,15 @@ static bool load_secret(vici_conn_t *conn, settings_t *cfg,
        enumerator->destroy(enumerator);
        vici_end_list(req);
 
-       res = vici_submit(req, conn);
+       res = vici_submit(req, ctx->conn);
        if (!res)
        {
                fprintf(stderr, "load-shared request failed: %s\n", strerror(errno));
                return FALSE;
        }
-       if (format & COMMAND_FORMAT_RAW)
+       if (ctx->format & COMMAND_FORMAT_RAW)
        {
-               vici_dump(res, "load-shared reply", format & COMMAND_FORMAT_PRETTY,
+               vici_dump(res, "load-shared reply", ctx->format & COMMAND_FORMAT_PRETTY,
                                  stdout);
        }
        else if (!streq(vici_find_str(res, "no", "success"), "yes"))
@@ -657,6 +671,75 @@ static bool load_secret(vici_conn_t *conn, settings_t *cfg,
        return ret;
 }
 
+CALLBACK(get_id, int,
+       hashtable_t *ht, vici_res_t *res, char *name, void *value, int len)
+{
+       if (streq(name, "keys"))
+       {
+               char *str;
+
+               if (asprintf(&str, "%.*s", len, value) != -1)
+               {
+                       free(ht->put(ht, str, str));
+               }
+       }
+       return 0;
+}
+
+/**
+ * Get a list of currently loaded private keys
+ */
+static void get_keys(load_ctx_t *ctx)
+{
+       vici_res_t *res;
+
+       res = vici_submit(vici_begin("get-keys"), ctx->conn);
+       if (res)
+       {
+               if (ctx->format & COMMAND_FORMAT_RAW)
+               {
+                       vici_dump(res, "get-keys reply", ctx->format & COMMAND_FORMAT_PRETTY,
+                                         stdout);
+               }
+               vici_parse_cb(res, NULL, NULL, get_id, ctx->keys);
+               vici_free_res(res);
+       }
+}
+
+/**
+ * Remove a given key
+ */
+static bool unload_key(load_ctx_t *ctx, char *id)
+{
+       vici_req_t *req;
+       vici_res_t *res;
+       bool ret = TRUE;
+
+       req = vici_begin("unload-key");
+
+       vici_add_key_valuef(req, "id", "%s", id);
+
+       res = vici_submit(req, ctx->conn);
+       if (!res)
+       {
+               fprintf(stderr, "unload-key request failed: %s\n", strerror(errno));
+               return FALSE;
+       }
+       if (ctx->format & COMMAND_FORMAT_RAW)
+       {
+               vici_dump(res, "unload-key reply", ctx->format & COMMAND_FORMAT_PRETTY,
+                                 stdout);
+       }
+       else if (!streq(vici_find_str(res, "no", "success"), "yes"))
+       {
+               fprintf(stderr, "unloading key '%s' failed: %s\n",
+                               id, vici_find_str(res, "", "errmsg"));
+               ret = FALSE;
+       }
+       vici_free_res(res);
+       return ret;
+}
+
 /**
  * Clear all currently loaded credentials
  */
@@ -686,7 +769,14 @@ int load_creds_cfg(vici_conn_t *conn, command_format_options_t format,
                                   settings_t *cfg, bool clear, bool noprompt)
 {
        enumerator_t *enumerator;
-       char *section;
+       char *section, *id;
+       load_ctx_t ctx = {
+               .conn = conn,
+               .format = format,
+               .noprompt = noprompt,
+               .cfg = cfg,
+               .keys = hashtable_create(hashtable_hash_str, hashtable_equals_str, 8),
+       };
 
        if (clear)
        {
@@ -696,29 +786,38 @@ int load_creds_cfg(vici_conn_t *conn, command_format_options_t format,
                }
        }
 
-       load_certs(conn, format, "x509",     SWANCTL_X509DIR);
-       load_certs(conn, format, "x509ca",   SWANCTL_X509CADIR);
-       load_certs(conn, format, "x509ocsp", SWANCTL_X509OCSPDIR);
-       load_certs(conn, format, "x509aa",   SWANCTL_X509AADIR);
-       load_certs(conn, format, "x509ac",   SWANCTL_X509ACDIR);
-       load_certs(conn, format, "x509crl",  SWANCTL_X509CRLDIR);
-       load_certs(conn, format, "pubkey",   SWANCTL_PUBKEYDIR);
+       get_keys(&ctx);
 
-       load_keys(conn, format, noprompt, cfg, "private", SWANCTL_PRIVATEDIR);
-       load_keys(conn, format, noprompt, cfg, "rsa",     SWANCTL_RSADIR);
-       load_keys(conn, format, noprompt, cfg, "ecdsa",   SWANCTL_ECDSADIR);
-       load_keys(conn, format, noprompt, cfg, "bliss",   SWANCTL_BLISSDIR);
-       load_keys(conn, format, noprompt, cfg, "pkcs8",   SWANCTL_PKCS8DIR);
+       load_certs(&ctx, "x509",     SWANCTL_X509DIR);
+       load_certs(&ctx, "x509ca",   SWANCTL_X509CADIR);
+       load_certs(&ctx, "x509ocsp", SWANCTL_X509OCSPDIR);
+       load_certs(&ctx, "x509aa",   SWANCTL_X509AADIR);
+       load_certs(&ctx, "x509ac",   SWANCTL_X509ACDIR);
+       load_certs(&ctx, "x509crl",  SWANCTL_X509CRLDIR);
+       load_certs(&ctx, "pubkey",   SWANCTL_PUBKEYDIR);
 
-       load_containers(conn, format, noprompt, cfg, "pkcs12", SWANCTL_PKCS12DIR);
+       load_keys(&ctx, "private", SWANCTL_PRIVATEDIR);
+       load_keys(&ctx, "rsa",     SWANCTL_RSADIR);
+       load_keys(&ctx, "ecdsa",   SWANCTL_ECDSADIR);
+       load_keys(&ctx, "bliss",   SWANCTL_BLISSDIR);
+       load_keys(&ctx, "pkcs8",   SWANCTL_PKCS8DIR);
+
+       load_containers(&ctx, "pkcs12", SWANCTL_PKCS12DIR);
 
        enumerator = cfg->create_section_enumerator(cfg, "secrets");
        while (enumerator->enumerate(enumerator, &section))
        {
-               load_secret(conn, cfg, section, format);
+               load_secret(&ctx, section);
        }
        enumerator->destroy(enumerator);
 
+       enumerator = ctx.keys->create_enumerator(ctx.keys);
+       while (enumerator->enumerate(enumerator, &id, NULL))
+       {
+               unload_key(&ctx, id);
+       }
+       enumerator->destroy(enumerator);
+       ctx.keys->destroy_function(ctx.keys, (void*)free);
        return 0;
 }