Merge branch 'stroke-ca-sections'
authorTobias Brunner <tobias@strongswan.org>
Thu, 20 Aug 2015 17:37:09 +0000 (19:37 +0200)
committerTobias Brunner <tobias@strongswan.org>
Thu, 20 Aug 2015 17:38:53 +0000 (19:38 +0200)
This resolves the duplicate CERTREQ issue when certificates in
ipsec.d/cacerts were referenced in ca sections.  It also ensures CA
certificates are reloaded atomically, so there is never a time when
an unchanged CA certificate is not available.

References #842.

src/libcharon/plugins/stroke/stroke_ca.c
src/libcharon/plugins/stroke/stroke_ca.h
src/libcharon/plugins/stroke/stroke_cred.c
src/libcharon/plugins/stroke/stroke_cred.h
src/libcharon/plugins/stroke/stroke_socket.c
src/libstrongswan/credentials/sets/mem_cred.c
src/libstrongswan/credentials/sets/mem_cred.h

index b470b81..13ed41e 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2008 Tobias Brunner
+ * Copyright (C) 2008-2015 Tobias Brunner
  * Copyright (C) 2008 Martin Willi
  * Hochschule fuer Technik Rapperswil
  *
 #include <daemon.h>
 
 typedef struct private_stroke_ca_t private_stroke_ca_t;
+typedef struct ca_section_t ca_section_t;
+typedef struct ca_cert_t ca_cert_t;
+
+/**
+ * Provided by stroke_cred.c
+ */
+certificate_t *stroke_load_ca_cert(char *filename);
 
 /**
  * private data of stroke_ca
@@ -41,17 +48,16 @@ struct private_stroke_ca_t {
        rwlock_t *lock;
 
        /**
-        * list of starters CA sections and its certificates (ca_section_t)
+        * list of CA sections and their certificates (ca_section_t)
         */
        linked_list_t *sections;
 
        /**
-        * stroke credentials, stores our CA certificates
+        * list of all loaded CA certificates (ca_cert_t)
         */
-       stroke_cred_t *cred;
+       linked_list_t *certs;
 };
 
-typedef struct ca_section_t ca_section_t;
 
 /**
  * loaded ipsec.conf CA sections
@@ -64,7 +70,12 @@ struct ca_section_t {
        char *name;
 
        /**
-        * reference to cert in trusted_credential_t
+        * path/name of the certificate
+        */
+       char *path;
+
+       /**
+        * reference to cert
         */
        certificate_t *cert;
 
@@ -90,16 +101,37 @@ struct ca_section_t {
 };
 
 /**
+ * loaded CA certificate
+ */
+struct ca_cert_t {
+
+       /**
+        * reference to cert
+        */
+       certificate_t *cert;
+
+       /**
+        * The number of CA sections referring to this certificate
+        */
+       u_int count;
+
+       /**
+        * TRUE if this certificate was automatically loaded
+        */
+       bool automatic;
+};
+
+/**
  * create a new CA section
  */
-static ca_section_t *ca_section_create(char *name, certificate_t *cert)
+static ca_section_t *ca_section_create(char *name, char *path)
 {
        ca_section_t *ca = malloc_thing(ca_section_t);
 
        ca->name = strdup(name);
+       ca->path = strdup(path);
        ca->crl = linked_list_create();
        ca->ocsp = linked_list_create();
-       ca->cert = cert;
        ca->hashes = linked_list_create();
        ca->certuribase = NULL;
        return ca;
@@ -115,11 +147,21 @@ static void ca_section_destroy(ca_section_t *this)
        this->hashes->destroy_offset(this->hashes, offsetof(identification_t, destroy));
        this->cert->destroy(this->cert);
        free(this->certuribase);
+       free(this->path);
        free(this->name);
        free(this);
 }
 
 /**
+ * Destroy a ca cert entry
+ */
+static void ca_cert_destroy(ca_cert_t *this)
+{
+       this->cert->destroy(this->cert);
+       free(this);
+}
+
+/**
  * Data for the certificate enumerator
  */
 typedef struct {
@@ -141,7 +183,7 @@ static void cert_data_destroy(cert_data_t *data)
 /**
  * filter function for certs enumerator
  */
-static bool certs_filter(cert_data_t *data, ca_section_t **in,
+static bool certs_filter(cert_data_t *data, ca_cert_t **in,
                                                 certificate_t **out)
 {
        public_key_t *public;
@@ -192,7 +234,7 @@ METHOD(credential_set_t, create_cert_enumerator, enumerator_t*,
        );
 
        this->lock->read_lock(this->lock);
-       enumerator = this->sections->create_enumerator(this->sections);
+       enumerator = this->certs->create_enumerator(this->certs);
        return enumerator_create_filter(enumerator, (void*)certs_filter, data,
                                                                        (void*)cert_data_destroy);
 }
@@ -312,6 +354,81 @@ METHOD(credential_set_t, create_cdp_enumerator, enumerator_t*,
                        data, (void*)cdp_data_destroy);
 }
 
+/**
+ * Compare the given certificate to the ca_cert_t items in the list
+ */
+static bool match_cert(ca_cert_t *item, certificate_t *cert)
+{
+       return cert->equals(cert, item->cert);
+}
+
+/**
+ * Match automatically added certificates and remove/destroy them if they are
+ * not referenced by CA sections.
+ */
+static bool remove_auto_certs(ca_cert_t *item, void *not_used)
+{
+       if (item->automatic)
+       {
+               item->automatic = FALSE;
+               if (!item->count)
+               {
+                       ca_cert_destroy(item);
+                       return TRUE;
+               }
+       }
+       return FALSE;
+}
+
+/**
+ * Find the given certificate that was referenced by a section and remove it
+ * unless it was also loaded automatically or is used by other CA sections.
+ */
+static bool remove_cert(ca_cert_t *item, certificate_t *cert)
+{
+       if (item->count && cert->equals(cert, item->cert))
+       {
+               if (--item->count == 0 && !item->automatic)
+               {
+                       ca_cert_destroy(item);
+                       return TRUE;
+               }
+       }
+       return FALSE;
+}
+
+/**
+ * Adds a certificate to the certificate store
+ */
+static certificate_t *add_cert_internal(private_stroke_ca_t *this,
+                                                                               certificate_t *cert, bool automatic)
+{
+       ca_cert_t *found;
+
+       if (this->certs->find_first(this->certs, (linked_list_match_t)match_cert,
+                                                               (void**)&found, cert) == SUCCESS)
+       {
+               cert->destroy(cert);
+               cert = found->cert->get_ref(found->cert);
+       }
+       else
+       {
+               INIT(found,
+                       .cert = cert->get_ref(cert)
+               );
+               this->certs->insert_first(this->certs, found);
+       }
+       if (automatic)
+       {
+               found->automatic = TRUE;
+       }
+       else
+       {
+               found->count++;
+       }
+       return cert;
+}
+
 METHOD(stroke_ca_t, add, void,
        private_stroke_ca_t *this, stroke_msg_t *msg)
 {
@@ -323,10 +440,10 @@ METHOD(stroke_ca_t, add, void,
                DBG1(DBG_CFG, "missing cacert parameter");
                return;
        }
-       cert = this->cred->load_ca(this->cred, msg->add_ca.cacert);
+       cert = stroke_load_ca_cert(msg->add_ca.cacert);
        if (cert)
        {
-               ca = ca_section_create(msg->add_ca.name, cert);
+               ca = ca_section_create(msg->add_ca.name, msg->add_ca.cacert);
                if (msg->add_ca.crluri)
                {
                        ca->crl->insert_last(ca->crl, strdup(msg->add_ca.crluri));
@@ -348,6 +465,7 @@ METHOD(stroke_ca_t, add, void,
                        ca->certuribase = strdup(msg->add_ca.certuribase);
                }
                this->lock->write_lock(this->lock);
+               ca->cert = add_cert_internal(this, cert, FALSE);
                this->sections->insert_last(this->sections, ca);
                this->lock->unlock(this->lock);
                DBG1(DBG_CFG, "added ca '%s'", msg->add_ca.name);
@@ -372,8 +490,12 @@ METHOD(stroke_ca_t, del, void,
                ca = NULL;
        }
        enumerator->destroy(enumerator);
+       if (ca)
+       {
+               this->certs->remove(this->certs, ca->cert, (void*)remove_cert);
+       }
        this->lock->unlock(this->lock);
-       if (ca == NULL)
+       if (!ca)
        {
                DBG1(DBG_CFG, "no ca named '%s' found\n", msg->del_ca.name);
                return;
@@ -383,6 +505,88 @@ METHOD(stroke_ca_t, del, void,
        lib->credmgr->flush_cache(lib->credmgr, CERT_ANY);
 }
 
+METHOD(stroke_ca_t, get_cert_ref, certificate_t*,
+       private_stroke_ca_t *this, certificate_t *cert)
+{
+       ca_cert_t *found;
+
+       this->lock->read_lock(this->lock);
+       if (this->certs->find_first(this->certs, (linked_list_match_t)match_cert,
+                                                               (void**)&found, cert) == SUCCESS)
+       {
+               cert->destroy(cert);
+               cert = found->cert->get_ref(found->cert);
+       }
+       this->lock->unlock(this->lock);
+       return cert;
+}
+
+METHOD(stroke_ca_t, reload_certs, void,
+       private_stroke_ca_t *this)
+{
+       enumerator_t *enumerator;
+       certificate_t *cert;
+       ca_section_t *ca;
+       certificate_type_t type = CERT_X509;
+
+       /* holding the write lock while loading/parsing certificates is not optimal,
+        * however, there usually are not that many ca sections configured */
+       this->lock->write_lock(this->lock);
+       if (this->sections->get_count(this->sections))
+       {
+               DBG1(DBG_CFG, "rereading ca certificates in ca sections");
+       }
+       enumerator = this->sections->create_enumerator(this->sections);
+       while (enumerator->enumerate(enumerator, &ca))
+       {
+               cert = stroke_load_ca_cert(ca->path);
+               if (cert)
+               {
+                       if (cert->equals(cert, ca->cert))
+                       {
+                               cert->destroy(cert);
+                       }
+                       else
+                       {
+                               this->certs->remove(this->certs, ca->cert, (void*)remove_cert);
+                               ca->cert->destroy(ca->cert);
+                               ca->cert = add_cert_internal(this, cert, FALSE);
+                       }
+               }
+               else
+               {
+                       DBG1(DBG_CFG, "failed to reload certificate '%s', removing ca '%s'",
+                                ca->path, ca->name);
+                       this->sections->remove_at(this->sections, enumerator);
+                       this->certs->remove(this->certs, ca->cert, (void*)remove_cert);
+                       ca_section_destroy(ca);
+                       type = CERT_ANY;
+               }
+       }
+       enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
+       lib->credmgr->flush_cache(lib->credmgr, type);
+}
+
+METHOD(stroke_ca_t, replace_certs, void,
+       private_stroke_ca_t *this, mem_cred_t *certs)
+{
+       enumerator_t *enumerator;
+       certificate_t *cert;
+
+       enumerator = certs->set.create_cert_enumerator(&certs->set, CERT_X509,
+                                                                                                  KEY_ANY, NULL, TRUE);
+       this->lock->write_lock(this->lock);
+       this->certs->remove(this->certs, NULL, (void*)remove_auto_certs);
+       while (enumerator->enumerate(enumerator, &cert))
+       {
+               cert = add_cert_internal(this, cert->get_ref(cert), TRUE);
+               cert->destroy(cert);
+       }
+       this->lock->unlock(this->lock);
+       enumerator->destroy(enumerator);
+       lib->credmgr->flush_cache(lib->credmgr, CERT_X509);
+}
 /**
  * list crl or ocsp URIs
  */
@@ -501,6 +705,7 @@ METHOD(stroke_ca_t, destroy, void,
        private_stroke_ca_t *this)
 {
        this->sections->destroy_function(this->sections, (void*)ca_section_destroy);
+       this->certs->destroy_function(this->certs, (void*)ca_cert_destroy);
        this->lock->destroy(this->lock);
        free(this);
 }
@@ -508,7 +713,7 @@ METHOD(stroke_ca_t, destroy, void,
 /*
  * see header file
  */
-stroke_ca_t *stroke_ca_create(stroke_cred_t *cred)
+stroke_ca_t *stroke_ca_create()
 {
        private_stroke_ca_t *this;
 
@@ -524,12 +729,15 @@ stroke_ca_t *stroke_ca_create(stroke_cred_t *cred)
                        .add = _add,
                        .del = _del,
                        .list = _list,
+                       .get_cert_ref = _get_cert_ref,
+                       .reload_certs = _reload_certs,
+                       .replace_certs = _replace_certs,
                        .check_for_hash_and_url = _check_for_hash_and_url,
                        .destroy = _destroy,
                },
                .sections = linked_list_create(),
+               .certs = linked_list_create(),
                .lock = rwlock_create(RWLOCK_TYPE_DEFAULT),
-               .cred = cred,
        );
 
        return &this->public;
index 21af912..2740006 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2008 Tobias Brunner
+ * Copyright (C) 2008-2015 Tobias Brunner
  * Copyright (C) 2008 Martin Willi
  * Hochschule fuer Technik Rapperswil
  *
@@ -23,8 +23,7 @@
 #define STROKE_CA_H_
 
 #include <stroke_msg.h>
-
-#include "stroke_cred.h"
+#include <credentials/sets/mem_cred.h>
 
 typedef struct stroke_ca_t stroke_ca_t;
 
@@ -67,6 +66,29 @@ struct stroke_ca_t {
        void (*check_for_hash_and_url)(stroke_ca_t *this, certificate_t* cert);
 
        /**
+        * Get a reference to a CA certificate if it is already stored,
+        * otherwise returns the same certificate.
+        *
+        * @param cert          certificate to check
+        * @return                      reference to stored CA certifiate, or original
+        */
+       certificate_t *(*get_cert_ref)(stroke_ca_t *this, certificate_t *cert);
+
+       /**
+        * Reload CA certificates referenced in CA sections. Flushes the certificate
+        * cache.
+        */
+       void (*reload_certs)(stroke_ca_t *this);
+
+       /**
+        * Replace automatically loaded CA certificates.  Flushes the certificate
+        * cache.
+        *
+        * @param certs         credential set to take certificates from (not modified)
+        */
+       void (*replace_certs)(stroke_ca_t *this, mem_cred_t *certs);
+
+       /**
         * Destroy a stroke_ca instance.
         */
        void (*destroy)(stroke_ca_t *this);
@@ -75,6 +97,6 @@ struct stroke_ca_t {
 /**
  * Create a stroke_ca instance.
  */
-stroke_ca_t *stroke_ca_create(stroke_cred_t *cred);
+stroke_ca_t *stroke_ca_create();
 
 #endif /** STROKE_CA_H_ @}*/
index 5e423f1..4292888 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2008-2013 Tobias Brunner
+ * Copyright (C) 2008-2015 Tobias Brunner
  * Copyright (C) 2008 Martin Willi
  * Hochschule fuer Technik Rapperswil
  *
@@ -75,11 +75,6 @@ struct private_stroke_cred_t {
        mem_cred_t *creds;
 
        /**
-        * CA certificates
-        */
-       mem_cred_t *cacerts;
-
-       /**
         * Attribute Authority certificates
         */
        mem_cred_t *aacerts;
@@ -94,6 +89,11 @@ struct private_stroke_cred_t {
         * cache CRLs to disk?
         */
        bool cachecrl;
+
+       /**
+        * CA certificate store
+        */
+       stroke_ca_t *ca;
 };
 
 /** Length of smartcard specifier parts (module, keyid) */
@@ -182,70 +182,6 @@ static certificate_t *load_from_smartcard(smartcard_format_t format,
        return cred;
 }
 
-METHOD(stroke_cred_t, load_ca, certificate_t*,
-       private_stroke_cred_t *this, char *filename)
-{
-       certificate_t *cert = NULL;
-       char path[PATH_MAX];
-
-       if (strpfx(filename, "%smartcard"))
-       {
-               smartcard_format_t format;
-               char module[SC_PART_LEN], keyid[SC_PART_LEN];
-               u_int slot;
-
-               format = parse_smartcard(filename, &slot, module, keyid);
-               if (format != SC_FORMAT_INVALID)
-               {
-                       cert = (certificate_t*)load_from_smartcard(format,
-                                                       slot, module, keyid, CRED_CERTIFICATE, CERT_X509);
-               }
-       }
-       else
-       {
-               if (*filename == '/')
-               {
-                       snprintf(path, sizeof(path), "%s", filename);
-               }
-               else
-               {
-                       snprintf(path, sizeof(path), "%s/%s", CA_CERTIFICATE_DIR, filename);
-               }
-
-               if (this->force_ca_cert)
-               {       /* we treat this certificate as a CA certificate even if it has no
-                        * CA basic constraint */
-                       cert = lib->creds->create(lib->creds,
-                                                                 CRED_CERTIFICATE, CERT_X509,
-                                                                 BUILD_FROM_FILE, path, BUILD_X509_FLAG, X509_CA,
-                                                                 BUILD_END);
-               }
-               else
-               {
-                       cert = lib->creds->create(lib->creds,
-                                                                 CRED_CERTIFICATE, CERT_X509,
-                                                                 BUILD_FROM_FILE, path,
-                                                                 BUILD_END);
-               }
-       }
-       if (cert)
-       {
-               x509_t *x509 = (x509_t*)cert;
-
-               if (!(x509->get_flags(x509) & X509_CA))
-               {
-                       DBG1(DBG_CFG, "  ca certificate \"%Y\" misses ca basic constraint, "
-                                "discarded", cert->get_subject(cert));
-                       cert->destroy(cert);
-                       return NULL;
-               }
-               DBG1(DBG_CFG, "  loaded ca certificate \"%Y\" from '%s'",
-                        cert->get_subject(cert), filename);
-               return this->creds->get_cert_ref(this->creds, cert);
-       }
-       return NULL;
-}
-
 METHOD(stroke_cred_t, load_peer, certificate_t*,
        private_stroke_cred_t *this, char *filename)
 {
@@ -384,22 +320,52 @@ METHOD(stroke_cred_t, load_pubkey, certificate_t*,
 }
 
 /**
- * Load a CA certificate  from disk
+ * Load a CA certificate, optionally force it to be one
  */
-static void load_x509_ca(private_stroke_cred_t *this, char *file)
+static certificate_t *load_ca_cert(char *filename, bool force_ca_cert)
 {
-       certificate_t *cert;
+       certificate_t *cert = NULL;
+       char path[PATH_MAX];
+
+       if (strpfx(filename, "%smartcard"))
+       {
+               smartcard_format_t format;
+               char module[SC_PART_LEN], keyid[SC_PART_LEN];
+               u_int slot;
 
-       if (this->force_ca_cert)
-       {       /* treat certificate as CA cert even it has no CA basic constraint */
-               cert = lib->creds->create(lib->creds, CRED_CERTIFICATE, CERT_X509,
-                                                                 BUILD_FROM_FILE, file,
-                                                                 BUILD_X509_FLAG, X509_CA, BUILD_END);
+               format = parse_smartcard(filename, &slot, module, keyid);
+               if (format != SC_FORMAT_INVALID)
+               {
+                       cert = (certificate_t*)load_from_smartcard(format,
+                                                       slot, module, keyid, CRED_CERTIFICATE, CERT_X509);
+               }
        }
        else
        {
-               cert = lib->creds->create(lib->creds, CRED_CERTIFICATE, CERT_X509,
-                                                                 BUILD_FROM_FILE, file, BUILD_END);
+               if (*filename == '/')
+               {
+                       snprintf(path, sizeof(path), "%s", filename);
+               }
+               else
+               {
+                       snprintf(path, sizeof(path), "%s/%s", CA_CERTIFICATE_DIR, filename);
+               }
+
+               if (force_ca_cert)
+               {       /* we treat this certificate as a CA certificate even if it has no
+                        * CA basic constraint */
+                       cert = lib->creds->create(lib->creds,
+                                                                 CRED_CERTIFICATE, CERT_X509,
+                                                                 BUILD_FROM_FILE, path, BUILD_X509_FLAG, X509_CA,
+                                                                 BUILD_END);
+               }
+               else
+               {
+                       cert = lib->creds->create(lib->creds,
+                                                                 CRED_CERTIFICATE, CERT_X509,
+                                                                 BUILD_FROM_FILE, path,
+                                                                 BUILD_END);
+               }
        }
        if (cert)
        {
@@ -410,13 +376,41 @@ static void load_x509_ca(private_stroke_cred_t *this, char *file)
                        DBG1(DBG_CFG, "  ca certificate \"%Y\" lacks ca basic constraint, "
                                 "discarded", cert->get_subject(cert));
                        cert->destroy(cert);
+                       return NULL;
                }
-               else
-               {
-                       DBG1(DBG_CFG, "  loaded ca certificate \"%Y\" from '%s'",
-                                cert->get_subject(cert), file);
-                       this->cacerts->add_cert(this->cacerts, TRUE, cert);
-               }
+               DBG1(DBG_CFG, "  loaded ca certificate \"%Y\" from '%s'",
+                        cert->get_subject(cert), filename);
+               return cert;
+       }
+       return NULL;
+}
+
+/**
+ * Used by stroke_ca.c
+ */
+certificate_t *stroke_load_ca_cert(char *filename)
+{
+       bool force_ca_cert;
+
+       force_ca_cert = lib->settings->get_bool(lib->settings,
+                                               "%s.plugins.stroke.ignore_missing_ca_basic_constraint",
+                                               FALSE, lib->ns);
+       return load_ca_cert(filename, force_ca_cert);
+}
+
+/**
+ * Load a CA certificate from disk
+ */
+static void load_x509_ca(private_stroke_cred_t *this, char *file,
+                                                mem_cred_t *creds)
+{
+       certificate_t *cert;
+
+       cert = load_ca_cert(file, this->force_ca_cert);
+       if (cert)
+       {
+               cert = this->ca->get_cert_ref(this->ca, cert);
+               creds->add_cert(creds, TRUE, cert);
        }
        else
        {
@@ -427,7 +421,8 @@ static void load_x509_ca(private_stroke_cred_t *this, char *file)
 /**
  * Load AA certificate with flags from disk
  */
-static void load_x509_aa(private_stroke_cred_t *this, char *file)
+static void load_x509_aa(private_stroke_cred_t *this,char *file,
+                                                mem_cred_t *creds)
 {
        certificate_t *cert;
 
@@ -438,7 +433,7 @@ static void load_x509_aa(private_stroke_cred_t *this, char *file)
        {
                DBG1(DBG_CFG, "  loaded AA certificate \"%Y\" from '%s'",
                         cert->get_subject(cert), file);
-               this->aacerts->add_cert(this->aacerts, TRUE, cert);
+               creds->add_cert(creds, TRUE, cert);
        }
        else
        {
@@ -449,7 +444,8 @@ static void load_x509_aa(private_stroke_cred_t *this, char *file)
 /**
  * Load a certificate with flags from disk
  */
-static void load_x509(private_stroke_cred_t *this, char *file, x509_flag_t flag)
+static void load_x509(private_stroke_cred_t *this, char *file, x509_flag_t flag,
+                                         mem_cred_t *creds)
 {
        certificate_t *cert;
 
@@ -461,7 +457,7 @@ static void load_x509(private_stroke_cred_t *this, char *file, x509_flag_t flag)
        {
                DBG1(DBG_CFG, "  loaded certificate \"%Y\" from '%s'",
                         cert->get_subject(cert), file);
-               this->creds->add_cert(this->creds, TRUE, cert);
+               creds->add_cert(creds, TRUE, cert);
        }
        else
        {
@@ -472,7 +468,8 @@ static void load_x509(private_stroke_cred_t *this, char *file, x509_flag_t flag)
 /**
  * Load a CRL from a file
  */
-static void load_x509_crl(private_stroke_cred_t *this, char *file)
+static void load_x509_crl(private_stroke_cred_t *this, char *file,
+                                                 mem_cred_t *creds)
 {
        certificate_t *cert;
 
@@ -480,8 +477,8 @@ static void load_x509_crl(private_stroke_cred_t *this, char *file)
                                                          BUILD_FROM_FILE, file, BUILD_END);
        if (cert)
        {
-               this->creds->add_crl(this->creds, (crl_t*)cert);
                DBG1(DBG_CFG, "  loaded crl from '%s'",  file);
+               creds->add_crl(creds, (crl_t*)cert);
        }
        else
        {
@@ -492,7 +489,8 @@ static void load_x509_crl(private_stroke_cred_t *this, char *file)
 /**
  * Load an attribute certificate from a file
  */
-static void load_x509_ac(private_stroke_cred_t *this, char *file)
+static void load_x509_ac(private_stroke_cred_t *this, char *file,
+                                                mem_cred_t *creds)
 {
        certificate_t *cert;
 
@@ -501,7 +499,7 @@ static void load_x509_ac(private_stroke_cred_t *this, char *file)
        if (cert)
        {
                DBG1(DBG_CFG, "  loaded attribute certificate from '%s'", file);
-               this->creds->add_cert(this->creds, FALSE, cert);
+               creds->add_cert(creds, FALSE, cert);
        }
        else
        {
@@ -513,7 +511,8 @@ static void load_x509_ac(private_stroke_cred_t *this, char *file)
  * load trusted certificates from a directory
  */
 static void load_certdir(private_stroke_cred_t *this, char *path,
-                                                certificate_type_t type, x509_flag_t flag)
+                                                certificate_type_t type, x509_flag_t flag,
+                                                mem_cred_t *creds)
 {
        enumerator_t *enumerator;
        struct stat st;
@@ -534,22 +533,22 @@ static void load_certdir(private_stroke_cred_t *this, char *path,
                                case CERT_X509:
                                        if (flag & X509_CA)
                                        {
-                                               load_x509_ca(this, file);
+                                               load_x509_ca(this, file, creds);
                                        }
                                        else if (flag & X509_AA)
                                        {
-                                               load_x509_aa(this, file);
+                                               load_x509_aa(this, file, creds);
                                        }
                                        else
                                        {
-                                               load_x509(this, file, flag);
+                                               load_x509(this, file, flag, creds);
                                        }
                                        break;
                                case CERT_X509_CRL:
-                                       load_x509_crl(this, file);
+                                       load_x509_crl(this, file, creds);
                                        break;
                                case CERT_X509_AC:
-                                       load_x509_ac(this, file);
+                                       load_x509_ac(this, file, creds);
                                        break;
                                default:
                                        break;
@@ -1348,30 +1347,38 @@ static void load_secrets(private_stroke_cred_t *this, mem_cred_t *secrets,
  */
 static void load_certs(private_stroke_cred_t *this)
 {
+       mem_cred_t *creds;
+
        DBG1(DBG_CFG, "loading ca certificates from '%s'",
                 CA_CERTIFICATE_DIR);
-       load_certdir(this, CA_CERTIFICATE_DIR, CERT_X509, X509_CA);
+       creds = mem_cred_create();
+       load_certdir(this, CA_CERTIFICATE_DIR, CERT_X509, X509_CA, creds);
+       this->ca->replace_certs(this->ca, creds);
+       creds->destroy(creds);
 
        DBG1(DBG_CFG, "loading aa certificates from '%s'",
                 AA_CERTIFICATE_DIR);
-       load_certdir(this, AA_CERTIFICATE_DIR, CERT_X509, X509_AA);
+       load_certdir(this, AA_CERTIFICATE_DIR, CERT_X509, X509_AA, this->aacerts);
 
        DBG1(DBG_CFG, "loading ocsp signer certificates from '%s'",
                 OCSP_CERTIFICATE_DIR);
-       load_certdir(this, OCSP_CERTIFICATE_DIR, CERT_X509, X509_OCSP_SIGNER);
+       load_certdir(this, OCSP_CERTIFICATE_DIR, CERT_X509, X509_OCSP_SIGNER,
+                                this->creds);
 
        DBG1(DBG_CFG, "loading attribute certificates from '%s'",
                 ATTR_CERTIFICATE_DIR);
-       load_certdir(this, ATTR_CERTIFICATE_DIR, CERT_X509_AC, 0);
+       load_certdir(this, ATTR_CERTIFICATE_DIR, CERT_X509_AC, 0, this->creds);
 
        DBG1(DBG_CFG, "loading crls from '%s'",
                 CRL_DIR);
-       load_certdir(this, CRL_DIR, CERT_X509_CRL, 0);
+       load_certdir(this, CRL_DIR, CERT_X509_CRL, 0, this->creds);
 }
 
 METHOD(stroke_cred_t, reread, void,
        private_stroke_cred_t *this, stroke_msg_t *msg, FILE *prompt)
 {
+       mem_cred_t *creds;
+
        if (msg->reread.flags & REREAD_SECRETS)
        {
                DBG1(DBG_CFG, "rereading secrets");
@@ -1379,38 +1386,44 @@ METHOD(stroke_cred_t, reread, void,
        }
        if (msg->reread.flags & REREAD_CACERTS)
        {
+               /* first reload certificates in ca sections, so we can refer to them */
+               this->ca->reload_certs(this->ca);
+
                DBG1(DBG_CFG, "rereading ca certificates from '%s'",
                         CA_CERTIFICATE_DIR);
-               this->cacerts->clear(this->cacerts);
+               creds = mem_cred_create();
+               load_certdir(this, CA_CERTIFICATE_DIR, CERT_X509, X509_CA, creds);
+               this->ca->replace_certs(this->ca, creds);
+               creds->destroy(creds);
+       }
+       if (msg->reread.flags & REREAD_AACERTS)
+       {
+               DBG1(DBG_CFG, "rereading aa certificates from '%s'",
+                        AA_CERTIFICATE_DIR);
+               creds = mem_cred_create();
+               load_certdir(this, AA_CERTIFICATE_DIR, CERT_X509, X509_AA, creds);
+               this->aacerts->replace_certs(this->aacerts, creds, FALSE);
+               creds->destroy(creds);
                lib->credmgr->flush_cache(lib->credmgr, CERT_X509);
-               load_certdir(this, CA_CERTIFICATE_DIR, CERT_X509, X509_CA);
        }
        if (msg->reread.flags & REREAD_OCSPCERTS)
        {
                DBG1(DBG_CFG, "rereading ocsp signer certificates from '%s'",
                         OCSP_CERTIFICATE_DIR);
                load_certdir(this, OCSP_CERTIFICATE_DIR, CERT_X509,
-                        X509_OCSP_SIGNER);
-       }
-       if (msg->reread.flags & REREAD_AACERTS)
-       {
-               DBG1(DBG_CFG, "rereading aa certificates from '%s'",
-                        AA_CERTIFICATE_DIR);
-               this->aacerts->clear(this->aacerts);
-               lib->credmgr->flush_cache(lib->credmgr, CERT_X509);
-               load_certdir(this, AA_CERTIFICATE_DIR, CERT_X509, X509_AA);
+                        X509_OCSP_SIGNER, this->creds);
        }
        if (msg->reread.flags & REREAD_ACERTS)
        {
                DBG1(DBG_CFG, "rereading attribute certificates from '%s'",
                         ATTR_CERTIFICATE_DIR);
-               load_certdir(this, ATTR_CERTIFICATE_DIR, CERT_X509_AC, 0);
+               load_certdir(this, ATTR_CERTIFICATE_DIR, CERT_X509_AC, 0, this->creds);
        }
        if (msg->reread.flags & REREAD_CRLS)
        {
                DBG1(DBG_CFG, "rereading crls from '%s'",
                         CRL_DIR);
-               load_certdir(this, CRL_DIR, CERT_X509_CRL, 0);
+               load_certdir(this, CRL_DIR, CERT_X509_CRL, 0, this->creds);
        }
 }
 
@@ -1424,10 +1437,8 @@ METHOD(stroke_cred_t, destroy, void,
        private_stroke_cred_t *this)
 {
        lib->credmgr->remove_set(lib->credmgr, &this->aacerts->set);
-       lib->credmgr->remove_set(lib->credmgr, &this->cacerts->set);
        lib->credmgr->remove_set(lib->credmgr, &this->creds->set);
        this->aacerts->destroy(this->aacerts);
-       this->cacerts->destroy(this->cacerts);
        this->creds->destroy(this->creds);
        free(this);
 }
@@ -1435,7 +1446,7 @@ METHOD(stroke_cred_t, destroy, void,
 /*
  * see header file
  */
-stroke_cred_t *stroke_cred_create()
+stroke_cred_t *stroke_cred_create(stroke_ca_t *ca)
 {
        private_stroke_cred_t *this;
 
@@ -1449,7 +1460,6 @@ stroke_cred_t *stroke_cred_create()
                                .cache_cert = (void*)_cache_cert,
                        },
                        .reread = _reread,
-                       .load_ca = _load_ca,
                        .load_peer = _load_peer,
                        .load_pubkey = _load_pubkey,
                        .add_shared = _add_shared,
@@ -1460,12 +1470,11 @@ stroke_cred_t *stroke_cred_create()
                                                                "%s.plugins.stroke.secrets_file", SECRETS_FILE,
                                                                lib->ns),
                .creds = mem_cred_create(),
-               .cacerts = mem_cred_create(),
                .aacerts = mem_cred_create(),
+               .ca = ca,
        );
 
        lib->credmgr->add_set(lib->credmgr, &this->creds->set);
-       lib->credmgr->add_set(lib->credmgr, &this->cacerts->set);
        lib->credmgr->add_set(lib->credmgr, &this->aacerts->set);
 
        this->force_ca_cert = lib->settings->get_bool(lib->settings,
index 9434629..33a0e35 100644 (file)
@@ -29,6 +29,8 @@
 #include <credentials/certificates/certificate.h>
 #include <collections/linked_list.h>
 
+#include "stroke_ca.h"
+
 typedef struct stroke_cred_t stroke_cred_t;
 
 /**
@@ -50,17 +52,6 @@ struct stroke_cred_t {
        void (*reread)(stroke_cred_t *this, stroke_msg_t *msg, FILE *prompt);
 
        /**
-        * Load a CA certificate.
-        *
-        * This method does not add the loaded CA certificate to the internal
-        * credentail set, but returns it only.
-        *
-        * @param filename              file to load CA cert from
-        * @return                              loaded certificate, or NULL
-        */
-       certificate_t* (*load_ca)(stroke_cred_t *this, char *filename);
-
-       /**
         * Load a peer certificate and serve it through the credential_set.
         *
         * @param filename              file to load peer cert from
@@ -103,6 +94,6 @@ struct stroke_cred_t {
 /**
  * Create a stroke_cred instance.
  */
-stroke_cred_t *stroke_cred_create();
+stroke_cred_t *stroke_cred_create(stroke_ca_t *ca);
 
 #endif /** STROKE_CRED_H_ @}*/
index db7e66f..29563e3 100644 (file)
@@ -779,10 +779,10 @@ stroke_socket_t *stroke_socket_create()
                                "%s.plugins.stroke.prevent_loglevel_changes", FALSE, lib->ns),
        );
 
-       this->cred = stroke_cred_create();
+       this->ca = stroke_ca_create();
+       this->cred = stroke_cred_create(this->ca);
        this->attribute = stroke_attribute_create();
        this->handler = stroke_handler_create();
-       this->ca = stroke_ca_create(this->cred);
        this->config = stroke_config_create(this->ca, this->cred, this->attribute);
        this->control = stroke_control_create();
        this->list = stroke_list_create(this->attribute);
index 7ad011b..4884c4b 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2010-2013 Tobias Brunner
+ * Copyright (C) 2010-2015 Tobias Brunner
  * Hochschule fuer Technik Rapperwsil
  * Copyright (C) 2010 Martin Willi
  * Copyright (C) 2010 revosec AG
@@ -197,7 +197,7 @@ METHOD(mem_cred_t, get_cert_ref, certificate_t*,
 {
        certificate_t *cached;
 
-       this->lock->write_lock(this->lock);
+       this->lock->read_lock(this->lock);
        if (this->untrusted->find_first(this->untrusted,
                                                                        (linked_list_match_t)certificate_equals,
                                                                        (void**)&cached, cert) == SUCCESS)
@@ -643,6 +643,49 @@ METHOD(credential_set_t, create_cdp_enumerator, enumerator_t*,
 
 }
 
+static void reset_certs(private_mem_cred_t *this)
+{
+       this->trusted->destroy_offset(this->trusted,
+                                                                 offsetof(certificate_t, destroy));
+       this->untrusted->destroy_offset(this->untrusted,
+                                                                       offsetof(certificate_t, destroy));
+       this->trusted = linked_list_create();
+       this->untrusted = linked_list_create();
+}
+
+static void copy_certs(linked_list_t *dst, linked_list_t *src, bool clone)
+{
+       enumerator_t *enumerator;
+       certificate_t *cert;
+
+       enumerator = src->create_enumerator(src);
+       while (enumerator->enumerate(enumerator, &cert))
+       {
+               if (clone)
+               {
+                       cert = cert->get_ref(cert);
+               }
+               else
+               {
+                       src->remove_at(src, enumerator);
+               }
+               dst->insert_last(dst, cert);
+       }
+       enumerator->destroy(enumerator);
+}
+
+METHOD(mem_cred_t, replace_certs, void,
+       private_mem_cred_t *this, mem_cred_t *other_set, bool clone)
+{
+       private_mem_cred_t *other = (private_mem_cred_t*)other_set;
+
+       this->lock->write_lock(this->lock);
+       reset_certs(this);
+       copy_certs(this->untrusted, other->untrusted, clone);
+       copy_certs(this->trusted, other->trusted, clone);
+       this->lock->unlock(this->lock);
+}
+
 static void reset_secrets(private_mem_cred_t *this)
 {
        this->keys->destroy_offset(this->keys, offsetof(private_key_t, destroy));
@@ -710,17 +753,11 @@ METHOD(mem_cred_t, clear_, void,
        private_mem_cred_t *this)
 {
        this->lock->write_lock(this->lock);
-       this->trusted->destroy_offset(this->trusted,
-                                                                 offsetof(certificate_t, destroy));
-       this->untrusted->destroy_offset(this->untrusted,
-                                                                       offsetof(certificate_t, destroy));
        this->cdps->destroy_function(this->cdps, (void*)cdp_destroy);
-       this->trusted = linked_list_create();
-       this->untrusted = linked_list_create();
        this->cdps = linked_list_create();
+       reset_certs(this);
+       reset_secrets(this);
        this->lock->unlock(this->lock);
-
-       clear_secrets(this);
 }
 
 METHOD(mem_cred_t, destroy, void,
@@ -760,6 +797,7 @@ mem_cred_t *mem_cred_create()
                        .add_shared = _add_shared,
                        .add_shared_list = _add_shared_list,
                        .add_cdp = _add_cdp,
+                       .replace_certs = _replace_certs,
                        .replace_secrets = _replace_secrets,
                        .clear = _clear_,
                        .clear_secrets = _clear_secrets,
index 3ce815a..51f0b8c 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2010-2013 Tobias Brunner
+ * Copyright (C) 2010-2015 Tobias Brunner
  * Hochschule fuer Technik Rapperswil
  * Copyright (C) 2010 Martin Willi
  * Copyright (C) 2010 revosec AG
@@ -102,6 +102,7 @@ struct mem_cred_t {
         */
        void (*add_shared_list)(mem_cred_t *this, shared_key_t *shared,
                                                        linked_list_t *owners);
+
        /**
         * Add a certificate distribution point to the set.
         *
@@ -113,6 +114,15 @@ struct mem_cred_t {
                                        identification_t *id, char *uri);
 
        /**
+        * Replace all certificates in this credential set with those of another.
+        *
+        * @param other                 credential set to get certificates from
+        * @param clone                 TRUE to clone certs, FALSE to adopt them (they
+        *                                              get removed from the other set)
+        */
+       void (*replace_certs)(mem_cred_t *this, mem_cred_t *other, bool clone);
+
+       /**
         * Replace all secrets (private and shared keys) in this credential set
         * with those of another.
         *