Alternative to mem_cred_t.add_cert added, which returns the certificate.
[strongswan.git] / src / libstrongswan / credentials / sets / mem_cred.c
index a64f431..08a1e71 100644 (file)
@@ -1,4 +1,6 @@
 /*
 /*
+ * Copyright (C) 2010 Tobias Brunner
+ * Hochschule fuer Technik Rapperwsil
  * Copyright (C) 2010 Martin Willi
  * Copyright (C) 2010 revosec AG
  *
  * Copyright (C) 2010 Martin Willi
  * Copyright (C) 2010 revosec AG
  *
@@ -36,12 +38,217 @@ struct private_mem_cred_t {
        rwlock_t *lock;
 
        /**
        rwlock_t *lock;
 
        /**
+        * List of trusted certificates, certificate_t
+        */
+       linked_list_t *trusted;
+
+       /**
+        * List of trusted and untrusted certificates, certificate_t
+        */
+       linked_list_t *untrusted;
+
+       /**
+        * List of private keys, private_key_t
+        */
+       linked_list_t *keys;
+
+       /**
         * List of shared keys, as shared_entry_t
         */
        linked_list_t *shared;
 };
 
 /**
         * List of shared keys, as shared_entry_t
         */
        linked_list_t *shared;
 };
 
 /**
+ * Data for the certificate enumerator
+ */
+typedef struct {
+       rwlock_t *lock;
+       certificate_type_t cert;
+       key_type_t key;
+       identification_t *id;
+} cert_data_t;
+
+/**
+ * destroy cert_data
+ */
+static void cert_data_destroy(cert_data_t *data)
+{
+       data->lock->unlock(data->lock);
+       free(data);
+}
+
+/**
+ * filter function for certs enumerator
+ */
+static bool certs_filter(cert_data_t *data, certificate_t **in, certificate_t **out)
+{
+       public_key_t *public;
+       certificate_t *cert = *in;
+
+       if (data->cert == CERT_ANY || data->cert == cert->get_type(cert))
+       {
+               public = cert->get_public_key(cert);
+               if (public)
+               {
+                       if (data->key == KEY_ANY || data->key == public->get_type(public))
+                       {
+                               if (data->id && public->has_fingerprint(public,
+                                                                                       data->id->get_encoding(data->id)))
+                               {
+                                       public->destroy(public);
+                                       *out = *in;
+                                       return TRUE;
+                               }
+                       }
+                       public->destroy(public);
+               }
+               else if (data->key != KEY_ANY)
+               {
+                       return FALSE;
+               }
+               if (data->id == NULL || cert->has_subject(cert, data->id))
+               {
+                       *out = *in;
+                       return TRUE;
+               }
+       }
+       return FALSE;
+}
+
+METHOD(credential_set_t, create_cert_enumerator, enumerator_t*,
+       private_mem_cred_t *this, certificate_type_t cert, key_type_t key,
+       identification_t *id, bool trusted)
+{
+       cert_data_t *data;
+       enumerator_t *enumerator;
+
+       INIT(data,
+               .lock = this->lock,
+               .cert = cert,
+               .key = key,
+               .id = id,
+       );
+       this->lock->read_lock(this->lock);
+       if (trusted)
+       {
+               enumerator = this->trusted->create_enumerator(this->trusted);
+       }
+       else
+       {
+               enumerator = this->untrusted->create_enumerator(this->untrusted);
+       }
+       return enumerator_create_filter(enumerator, (void*)certs_filter, data,
+                                                                       (void*)cert_data_destroy);
+}
+
+static bool certificate_equals(certificate_t *item, certificate_t *cert)
+{
+       return item->equals(item, cert);
+}
+
+/**
+ * Add a certificate the the cache. Returns a reference to "cert" or a
+ * previously cached certificate that equals "cert".
+ */
+static certificate_t *add_cert_internal(private_mem_cred_t *this, bool trusted,
+                                                                               certificate_t *cert)
+{
+       certificate_t *cached;
+       this->lock->write_lock(this->lock);
+       if (this->untrusted->find_last(this->untrusted,
+                                                                  (linked_list_match_t)certificate_equals,
+                                                                  (void**)&cached, cert) == SUCCESS)
+       {
+               cert->destroy(cert);
+               cert = cached->get_ref(cached);
+       }
+       else
+       {
+               if (trusted)
+               {
+                       this->trusted->insert_last(this->trusted, cert->get_ref(cert));
+               }
+               this->untrusted->insert_last(this->untrusted, cert->get_ref(cert));
+       }
+       this->lock->unlock(this->lock);
+       return cert;
+}
+
+METHOD(mem_cred_t, add_cert, void,
+       private_mem_cred_t *this, bool trusted, certificate_t *cert)
+{
+       certificate_t *cached = add_cert_internal(this, trusted, cert);
+       cached->destroy(cached);
+}
+
+METHOD(mem_cred_t, add_cert_ref, certificate_t*,
+       private_mem_cred_t *this, bool trusted, certificate_t *cert)
+{
+       return add_cert_internal(this, trusted, cert);
+}
+
+/**
+ * Data for key enumerator
+ */
+typedef struct {
+       rwlock_t *lock;
+       key_type_t type;
+       identification_t *id;
+} key_data_t;
+
+/**
+ * Destroy key enumerator data
+ */
+static void key_data_destroy(key_data_t *data)
+{
+       data->lock->unlock(data->lock);
+       free(data);
+}
+
+/**
+ * filter function for private key enumerator
+ */
+static bool key_filter(key_data_t *data, private_key_t **in, private_key_t **out)
+{
+       private_key_t *key;
+
+       key = *in;
+       if (data->type == KEY_ANY || data->type == key->get_type(key))
+       {
+               if (data->id == NULL ||
+                       key->has_fingerprint(key, data->id->get_encoding(data->id)))
+               {
+                       *out = key;
+                       return TRUE;
+               }
+       }
+       return FALSE;
+}
+
+METHOD(credential_set_t, create_private_enumerator, enumerator_t*,
+       private_mem_cred_t *this, key_type_t type, identification_t *id)
+{
+       key_data_t *data;
+
+       INIT(data,
+               .lock = this->lock,
+               .type = type,
+               .id = id,
+       );
+       this->lock->read_lock(this->lock);
+       return enumerator_create_filter(this->keys->create_enumerator(this->keys),
+                                                       (void*)key_filter, data, (void*)key_data_destroy);
+}
+
+METHOD(mem_cred_t, add_key, void,
+       private_mem_cred_t *this, private_key_t *key)
+{
+       this->lock->write_lock(this->lock);
+       this->keys->insert_last(this->keys, key);
+       this->lock->unlock(this->lock);
+}
+
+/**
  * Shared key entry
  */
 typedef struct {
  * Shared key entry
  */
 typedef struct {
@@ -161,40 +368,68 @@ METHOD(credential_set_t, create_shared_enumerator, enumerator_t*,
                                                (void*)shared_filter, data, (void*)shared_data_destroy);
 }
 
                                                (void*)shared_filter, data, (void*)shared_data_destroy);
 }
 
-METHOD(mem_cred_t, add_shared, void,
-       private_mem_cred_t *this, shared_key_t *shared, ...)
+METHOD(mem_cred_t, add_shared_list, void,
+       private_mem_cred_t *this, shared_key_t *shared, linked_list_t* owners)
 {
        shared_entry_t *entry;
 {
        shared_entry_t *entry;
-       identification_t *id;
-       va_list args;
 
        INIT(entry,
                .shared = shared,
 
        INIT(entry,
                .shared = shared,
-               .owners = linked_list_create(),
+               .owners = owners,
        );
 
        );
 
+       this->lock->write_lock(this->lock);
+       this->shared->insert_last(this->shared, entry);
+       this->lock->unlock(this->lock);
+}
+
+METHOD(mem_cred_t, add_shared, void,
+       private_mem_cred_t *this, shared_key_t *shared, ...)
+{
+       identification_t *id;
+       linked_list_t *owners = linked_list_create();
+       va_list args;
+
        va_start(args, shared);
        do
        {
                id = va_arg(args, identification_t*);
                if (id)
                {
        va_start(args, shared);
        do
        {
                id = va_arg(args, identification_t*);
                if (id)
                {
-                       entry->owners->insert_last(entry->owners, id);
+                       owners->insert_last(owners, id);
                }
        }
        while (id);
        va_end(args);
 
                }
        }
        while (id);
        va_end(args);
 
+       add_shared_list(this, shared, owners);
+}
+
+METHOD(mem_cred_t, clear_, void,
+       private_mem_cred_t *this)
+{
        this->lock->write_lock(this->lock);
        this->lock->write_lock(this->lock);
-       this->shared->insert_last(this->shared, entry);
+       this->trusted->destroy_offset(this->trusted,
+                                                                 offsetof(certificate_t, destroy));
+       this->untrusted->destroy_offset(this->untrusted,
+                                                                       offsetof(certificate_t, destroy));
+       this->keys->destroy_offset(this->keys, offsetof(private_key_t, destroy));
+       this->shared->destroy_function(this->shared, (void*)shared_entry_destroy);
+       this->trusted = linked_list_create();
+       this->untrusted = linked_list_create();
+       this->keys = linked_list_create();
+       this->shared = linked_list_create();
        this->lock->unlock(this->lock);
 }
 
        this->lock->unlock(this->lock);
 }
 
-
 METHOD(mem_cred_t, destroy, void,
        private_mem_cred_t *this)
 {
 METHOD(mem_cred_t, destroy, void,
        private_mem_cred_t *this)
 {
-       this->shared->destroy_function(this->shared, (void*)shared_entry_destroy);
+       clear_(this);
+       this->trusted->destroy(this->trusted);
+       this->untrusted->destroy(this->untrusted);
+       this->keys->destroy(this->keys);
+       this->shared->destroy(this->shared);
        this->lock->destroy(this->lock);
        free(this);
 }
        this->lock->destroy(this->lock);
        free(this);
 }
@@ -210,14 +445,22 @@ mem_cred_t *mem_cred_create()
                .public = {
                        .set = {
                                .create_shared_enumerator = _create_shared_enumerator,
                .public = {
                        .set = {
                                .create_shared_enumerator = _create_shared_enumerator,
-                               .create_private_enumerator = (void*)return_null,
-                               .create_cert_enumerator = (void*)return_null,
+                               .create_private_enumerator = _create_private_enumerator,
+                               .create_cert_enumerator = _create_cert_enumerator,
                                .create_cdp_enumerator  = (void*)return_null,
                                .cache_cert = (void*)nop,
                        },
                                .create_cdp_enumerator  = (void*)return_null,
                                .cache_cert = (void*)nop,
                        },
+                       .add_cert = _add_cert,
+                       .add_cert_ref = _add_cert_ref,
+                       .add_key = _add_key,
                        .add_shared = _add_shared,
                        .add_shared = _add_shared,
+                       .add_shared_list = _add_shared_list,
+                       .clear = _clear_,
                        .destroy = _destroy,
                },
                        .destroy = _destroy,
                },
+               .trusted = linked_list_create(),
+               .untrusted = linked_list_create(),
+               .keys = linked_list_create(),
                .shared = linked_list_create(),
                .lock = rwlock_create(RWLOCK_TYPE_DEFAULT),
        );
                .shared = linked_list_create(),
                .lock = rwlock_create(RWLOCK_TYPE_DEFAULT),
        );