proposal: use a single list to store all transforms
authorMartin Willi <martin@revosec.ch>
Wed, 10 Jul 2013 12:16:46 +0000 (14:16 +0200)
committerMartin Willi <martin@revosec.ch>
Wed, 17 Jul 2013 15:20:17 +0000 (17:20 +0200)
Beside that it makes the code actually simpler, it reduces the number of lists
stored by each IKE_SA and each CHILD_SA by 4, which can be up to 1KB per SA.

src/libcharon/config/proposal.c

index 4803c7b..ce3bded 100644 (file)
@@ -36,7 +36,6 @@ ENUM(protocol_id_names, PROTO_NONE, PROTO_IPCOMP,
 );
 
 typedef struct private_proposal_t private_proposal_t;
-typedef struct algorithm_t algorithm_t;
 
 /**
  * Private data of an proposal_t object
@@ -54,29 +53,9 @@ struct private_proposal_t {
        protocol_id_t protocol;
 
        /**
-        * priority ordered list of encryption algorithms
+        * Priority ordered list of transforms, as entry_t
         */
-       linked_list_t *encryption_algos;
-
-       /**
-        * priority ordered list of integrity algorithms
-        */
-       linked_list_t *integrity_algos;
-
-       /**
-        * priority ordered list of pseudo random functions
-        */
-       linked_list_t *prf_algos;
-
-       /**
-        * priority ordered list of dh groups
-        */
-       linked_list_t *dh_groups;
-
-       /**
-        * priority ordered list of extended sequence number flags
-        */
-       linked_list_t *esns;
+       linked_list_t *transforms;
 
        /**
         * senders SPI
@@ -92,68 +71,49 @@ struct private_proposal_t {
 /**
  * Struct used to store different kinds of algorithms.
  */
-struct algorithm_t {
-       /**
-        * Value from an encryption_algorithm_t/integrity_algorithm_t/...
-        */
-       u_int16_t algorithm;
-
-       /**
-        * the associated key size in bits, or zero if not needed
-        */
+typedef struct {
+       /** Type of the transform */
+       transform_type_t type;
+       /** algorithm identifier */
+       u_int16_t alg;
+       /** key size in bits, or zero if not needed */
        u_int16_t key_size;
-};
-
-/**
- * Add algorithm/keysize to a algorithm list
- */
-static void add_algo(linked_list_t *list, u_int16_t algo, u_int16_t key_size)
-{
-       algorithm_t *algo_key;
-
-       algo_key = malloc_thing(algorithm_t);
-       algo_key->algorithm = algo;
-       algo_key->key_size = key_size;
-       list->insert_last(list, (void*)algo_key);
-}
+} entry_t;
 
 METHOD(proposal_t, add_algorithm, void,
        private_proposal_t *this, transform_type_t type,
-       u_int16_t algo, u_int16_t key_size)
+       u_int16_t alg, u_int16_t key_size)
 {
-       switch (type)
-       {
-               case ENCRYPTION_ALGORITHM:
-                       add_algo(this->encryption_algos, algo, key_size);
-                       break;
-               case INTEGRITY_ALGORITHM:
-                       add_algo(this->integrity_algos, algo, key_size);
-                       break;
-               case PSEUDO_RANDOM_FUNCTION:
-                       add_algo(this->prf_algos, algo, key_size);
-                       break;
-               case DIFFIE_HELLMAN_GROUP:
-                       add_algo(this->dh_groups, algo, 0);
-                       break;
-               case EXTENDED_SEQUENCE_NUMBERS:
-                       add_algo(this->esns, algo, 0);
-                       break;
-               default:
-                       break;
-       }
+       entry_t *entry;
+
+       INIT(entry,
+               .type = type,
+               .alg = alg,
+               .key_size = key_size,
+       );
+
+       this->transforms->insert_last(this->transforms, entry);
 }
 
 /**
  * filter function for peer configs
  */
-static bool alg_filter(void *null, algorithm_t **in, u_int16_t *alg,
+static bool alg_filter(uintptr_t type, entry_t **in, u_int16_t *alg,
                                           void **unused, u_int16_t *key_size)
 {
-       algorithm_t *algo = *in;
-       *alg = algo->algorithm;
+       entry_t *entry = *in;
+
+       if (entry->type != type)
+       {
+               return FALSE;
+       }
+       if (alg)
+       {
+               *alg = entry->alg;
+       }
        if (key_size)
        {
-               *key_size = algo->key_size;
+               *key_size = entry->key_size;
        }
        return TRUE;
 }
@@ -161,30 +121,9 @@ static bool alg_filter(void *null, algorithm_t **in, u_int16_t *alg,
 METHOD(proposal_t, create_enumerator, enumerator_t*,
        private_proposal_t *this, transform_type_t type)
 {
-       linked_list_t *list;
-
-       switch (type)
-       {
-               case ENCRYPTION_ALGORITHM:
-                       list = this->encryption_algos;
-                       break;
-               case INTEGRITY_ALGORITHM:
-                       list = this->integrity_algos;
-                       break;
-               case PSEUDO_RANDOM_FUNCTION:
-                       list = this->prf_algos;
-                       break;
-               case DIFFIE_HELLMAN_GROUP:
-                       list = this->dh_groups;
-                       break;
-               case EXTENDED_SEQUENCE_NUMBERS:
-                       list = this->esns;
-                       break;
-               default:
-                       return NULL;
-       }
-       return enumerator_create_filter(list->create_enumerator(list),
-                                                                       (void*)alg_filter, NULL, NULL);
+       return enumerator_create_filter(
+                                               this->transforms->create_enumerator(this->transforms),
+                                               (void*)alg_filter, (void*)(uintptr_t)type, NULL);
 }
 
 METHOD(proposal_t, get_algorithm, bool,
@@ -200,84 +139,92 @@ METHOD(proposal_t, get_algorithm, bool,
                found = TRUE;
        }
        enumerator->destroy(enumerator);
+
        return found;
 }
 
 METHOD(proposal_t, has_dh_group, bool,
        private_proposal_t *this, diffie_hellman_group_t group)
 {
-       bool result = FALSE;
+       bool found = FALSE, any = FALSE;
+       enumerator_t *enumerator;
+       u_int16_t current;
 
-       if (this->dh_groups->get_count(this->dh_groups))
+       enumerator = create_enumerator(this, DIFFIE_HELLMAN_GROUP);
+       while (enumerator->enumerate(enumerator, &current, NULL))
        {
-               algorithm_t *current;
-               enumerator_t *enumerator;
-
-               enumerator = this->dh_groups->create_enumerator(this->dh_groups);
-               while (enumerator->enumerate(enumerator, (void**)&current))
+               any = TRUE;
+               if (current == group)
                {
-                       if (current->algorithm == group)
-                       {
-                               result = TRUE;
-                               break;
-                       }
+                       found = TRUE;
+                       break;
                }
-               enumerator->destroy(enumerator);
        }
-       else if (group == MODP_NONE)
+       enumerator->destroy(enumerator);
+
+       if (!any && group == MODP_NONE)
        {
-               result = TRUE;
+               found = TRUE;
        }
-       return result;
+       return found;
 }
 
 METHOD(proposal_t, strip_dh, void,
        private_proposal_t *this, diffie_hellman_group_t keep)
 {
        enumerator_t *enumerator;
-       algorithm_t *alg;
+       entry_t *entry;
 
-       enumerator = this->dh_groups->create_enumerator(this->dh_groups);
-       while (enumerator->enumerate(enumerator, (void**)&alg))
+       enumerator = this->transforms->create_enumerator(this->transforms);
+       while (enumerator->enumerate(enumerator, &entry))
        {
-               if (alg->algorithm != keep)
+               if (entry->type == DIFFIE_HELLMAN_GROUP &&
+                       entry->alg != keep)
                {
-                       this->dh_groups->remove_at(this->dh_groups, enumerator);
-                       free(alg);
+                       this->transforms->remove_at(this->transforms, enumerator);
+                       free(entry);
                }
        }
        enumerator->destroy(enumerator);
 }
 
 /**
- * Find a matching alg/keysize in two linked lists
+ * Select a matching proposal from this and other, insert into selected.
  */
-static bool select_algo(linked_list_t *first, linked_list_t *second, bool priv,
-                                               bool *add, u_int16_t *alg, size_t *key_size)
+static bool select_algo(private_proposal_t *this, proposal_t *other,
+                                               proposal_t *selected, transform_type_t type, bool priv)
 {
        enumerator_t *e1, *e2;
-       algorithm_t *alg1, *alg2;
+       u_int16_t alg1, alg2, ks1, ks2;
+       bool found = FALSE;
 
-       /* if in both are zero algorithms specified, we HAVE a match */
-       if (first->get_count(first) == 0 && second->get_count(second) == 0)
+       if (type == INTEGRITY_ALGORITHM &&
+               selected->get_algorithm(selected, ENCRYPTION_ALGORITHM, &alg1, NULL) &&
+               encryption_algorithm_is_aead(alg1))
        {
-               *add = FALSE;
+               /* no integrity algorithm required, we have an AEAD */
                return TRUE;
        }
 
-       e1 = first->create_enumerator(first);
-       e2 = second->create_enumerator(second);
+       e1 = create_enumerator(this, type);
+       e2 = other->create_enumerator(other, type);
+       if (!e1->enumerate(e1, NULL, NULL) && !e2->enumerate(e2, NULL, NULL))
+       {
+               found = TRUE;
+       }
+
+       e1->destroy(e1);
+       e1 = create_enumerator(this, type);
        /* compare algs, order of algs in "first" is preferred */
-       while (e1->enumerate(e1, &alg1))
+       while (!found && e1->enumerate(e1, &alg1, &ks1))
        {
                e2->destroy(e2);
-               e2 = second->create_enumerator(second);
-               while (e2->enumerate(e2, &alg2))
+               e2 = other->create_enumerator(other, type);
+               while (e2->enumerate(e2, &alg2, &ks2))
                {
-                       if (alg1->algorithm == alg2->algorithm &&
-                               alg1->key_size == alg2->key_size)
+                       if (alg1 == alg2 && ks1 == ks2)
                        {
-                               if (!priv && alg1->algorithm >= 1024)
+                               if (!priv && alg1 >= 1024)
                                {
                                        /* accept private use algorithms only if requested */
                                        DBG1(DBG_CFG, "an algorithm from private space would match, "
@@ -285,132 +232,52 @@ static bool select_algo(linked_list_t *first, linked_list_t *second, bool priv,
                                        continue;
                                }
                                /* ok, we have an algorithm */
-                               *alg = alg1->algorithm;
-                               *key_size = alg1->key_size;
-                               *add = TRUE;
-                               e1->destroy(e1);
-                               e2->destroy(e2);
-                               return TRUE;
+                               selected->add_algorithm(selected, type, alg1, ks1);
+                               found = TRUE;
+                               break;
                        }
                }
        }
        /* no match in all comparisons */
        e1->destroy(e1);
        e2->destroy(e2);
-       return FALSE;
+
+       if (!found)
+       {
+               DBG2(DBG_CFG, "  no acceptable %N found", transform_type_names, type);
+       }
+       return found;
 }
 
 METHOD(proposal_t, select_proposal, proposal_t*,
-       private_proposal_t *this, proposal_t *other_pub, bool private)
+       private_proposal_t *this, proposal_t *other, bool private)
 {
-       private_proposal_t *other = (private_proposal_t*)other_pub;
        proposal_t *selected;
-       u_int16_t algo;
-       size_t key_size;
-       bool add;
 
        DBG2(DBG_CFG, "selecting proposal:");
 
-       /* check protocol */
-       if (this->protocol != other->protocol)
+       if (this->protocol != other->get_protocol(other))
        {
                DBG2(DBG_CFG, "  protocol mismatch, skipping");
                return NULL;
        }
 
-       selected = proposal_create(this->protocol, other->number);
+       selected = proposal_create(this->protocol, other->get_number(other));
 
-       /* select encryption algorithm */
-       if (select_algo(this->encryption_algos, other->encryption_algos, private,
-                                       &add, &algo, &key_size))
-       {
-               if (add)
-               {
-                       selected->add_algorithm(selected, ENCRYPTION_ALGORITHM,
-                                                                       algo, key_size);
-               }
-       }
-       else
-       {
-               selected->destroy(selected);
-               DBG2(DBG_CFG, "  no acceptable %N found",
-                        transform_type_names, ENCRYPTION_ALGORITHM);
-               return NULL;
-       }
-       /* select integrity algorithm */
-       if (!encryption_algorithm_is_aead(algo))
-       {
-               if (select_algo(this->integrity_algos, other->integrity_algos, private,
-                                               &add, &algo, &key_size))
-               {
-                       if (add)
-                       {
-                               selected->add_algorithm(selected, INTEGRITY_ALGORITHM,
-                                                                               algo, key_size);
-                       }
-               }
-               else
-               {
-                       selected->destroy(selected);
-                       DBG2(DBG_CFG, "  no acceptable %N found",
-                                transform_type_names, INTEGRITY_ALGORITHM);
-                       return NULL;
-               }
-       }
-       /* select prf algorithm */
-       if (select_algo(this->prf_algos, other->prf_algos, private,
-                                       &add, &algo, &key_size))
-       {
-               if (add)
-               {
-                       selected->add_algorithm(selected, PSEUDO_RANDOM_FUNCTION,
-                                                                       algo, key_size);
-               }
-       }
-       else
-       {
-               selected->destroy(selected);
-               DBG2(DBG_CFG, "  no acceptable %N found",
-                        transform_type_names, PSEUDO_RANDOM_FUNCTION);
-               return NULL;
-       }
-       /* select a DH-group */
-       if (select_algo(this->dh_groups, other->dh_groups, private,
-                                       &add, &algo, &key_size))
-       {
-               if (add)
-               {
-                       selected->add_algorithm(selected, DIFFIE_HELLMAN_GROUP, algo, 0);
-               }
-       }
-       else
+       if (!select_algo(this, other, selected, ENCRYPTION_ALGORITHM, private) ||
+               !select_algo(this, other, selected, PSEUDO_RANDOM_FUNCTION, private) ||
+               !select_algo(this, other, selected, INTEGRITY_ALGORITHM, private) ||
+               !select_algo(this, other, selected, DIFFIE_HELLMAN_GROUP, private) ||
+               !select_algo(this, other, selected, EXTENDED_SEQUENCE_NUMBERS, private))
        {
                selected->destroy(selected);
-               DBG2(DBG_CFG, "  no acceptable %N found",
-                        transform_type_names, DIFFIE_HELLMAN_GROUP);
-               return NULL;
-       }
-       /* select if we use ESNs (has no private use space) */
-       if (select_algo(this->esns, other->esns, TRUE, &add, &algo, &key_size))
-       {
-               if (add)
-               {
-                       selected->add_algorithm(selected, EXTENDED_SEQUENCE_NUMBERS, algo, 0);
-               }
-       }
-       else
-       {
-               selected->destroy(selected);
-               DBG2(DBG_CFG, "  no acceptable %N found",
-                        transform_type_names, EXTENDED_SEQUENCE_NUMBERS);
                return NULL;
        }
+
        DBG2(DBG_CFG, "  proposal matches");
 
-       /* apply SPI from "other" */
-       selected->set_spi(selected, other->spi);
+       selected->set_spi(selected, other->get_spi(other));
 
-       /* everything matched, return new proposal */
        return selected;
 }
 
@@ -433,50 +300,39 @@ METHOD(proposal_t, get_spi, u_int64_t,
 }
 
 /**
- * Clone a algorithm list
- */
-static void clone_algo_list(linked_list_t *list, linked_list_t *clone_list)
-{
-       algorithm_t *algo, *clone_algo;
-       enumerator_t *enumerator;
-
-       enumerator = list->create_enumerator(list);
-       while (enumerator->enumerate(enumerator, &algo))
-       {
-               clone_algo = malloc_thing(algorithm_t);
-               memcpy(clone_algo, algo, sizeof(algorithm_t));
-               clone_list->insert_last(clone_list, (void*)clone_algo);
-       }
-       enumerator->destroy(enumerator);
-}
-
-/**
- * check if an algorithm list equals
+ * Check if two proposals have the same algorithms for a given transform type
  */
-static bool algo_list_equals(linked_list_t *l1, linked_list_t *l2)
+static bool algo_list_equals(private_proposal_t *this, proposal_t *other,
+                                                        transform_type_t type)
 {
        enumerator_t *e1, *e2;
-       algorithm_t *alg1, *alg2;
+       u_int16_t alg1, alg2, ks1, ks2;
        bool equals = TRUE;
 
-       if (l1->get_count(l1) != l2->get_count(l2))
-       {
-               return FALSE;
-       }
-
-       e1 = l1->create_enumerator(l1);
-       e2 = l2->create_enumerator(l2);
-       while (e1->enumerate(e1, &alg1) && e2->enumerate(e2, &alg2))
+       e1 = create_enumerator(this, type);
+       e2 = other->create_enumerator(other, type);
+       while (e1->enumerate(e1, &alg1, &ks1))
        {
-               if (alg1->algorithm != alg2->algorithm ||
-                       alg1->key_size != alg2->key_size)
+               if (!e2->enumerate(e2, &alg2, &ks2))
+               {
+                       /* this has more algs */
+                       equals = FALSE;
+                       break;
+               }
+               if (alg1 != alg2 || ks1 != ks2)
                {
                        equals = FALSE;
                        break;
                }
        }
+       if (e2->enumerate(e2, &alg2, ks2))
+       {
+               /* other has more algs */
+               equals = FALSE;
+       }
        e1->destroy(e1);
        e2->destroy(e2);
+
        return equals;
 }
 
@@ -487,33 +343,40 @@ METHOD(proposal_t, get_number, u_int,
 }
 
 METHOD(proposal_t, equals, bool,
-       private_proposal_t *this, proposal_t *other_pub)
+       private_proposal_t *this, proposal_t *other)
 {
-       private_proposal_t *other = (private_proposal_t*)other_pub;
-
-       if (this == other)
+       if (&this->public == other)
        {
                return TRUE;
        }
        return (
-               algo_list_equals(this->encryption_algos, other->encryption_algos) &&
-               algo_list_equals(this->integrity_algos, other->integrity_algos) &&
-               algo_list_equals(this->prf_algos, other->prf_algos) &&
-               algo_list_equals(this->dh_groups, other->dh_groups) &&
-               algo_list_equals(this->esns, other->esns));
+               algo_list_equals(this, other, ENCRYPTION_ALGORITHM) &&
+               algo_list_equals(this, other, INTEGRITY_ALGORITHM) &&
+               algo_list_equals(this, other, PSEUDO_RANDOM_FUNCTION) &&
+               algo_list_equals(this, other, DIFFIE_HELLMAN_GROUP) &&
+               algo_list_equals(this, other, EXTENDED_SEQUENCE_NUMBERS));
 }
 
 METHOD(proposal_t, clone_, proposal_t*,
        private_proposal_t *this)
 {
        private_proposal_t *clone;
+       enumerator_t *enumerator;
+       entry_t *current, *entry;
 
        clone = (private_proposal_t*)proposal_create(this->protocol, 0);
-       clone_algo_list(this->encryption_algos, clone->encryption_algos);
-       clone_algo_list(this->integrity_algos, clone->integrity_algos);
-       clone_algo_list(this->prf_algos, clone->prf_algos);
-       clone_algo_list(this->dh_groups, clone->dh_groups);
-       clone_algo_list(this->esns, clone->esns);
+
+       enumerator = this->transforms->create_enumerator(this->transforms);
+       while (enumerator->enumerate(enumerator, &current))
+       {
+               INIT(entry,
+                       .type = current->type,
+                       .alg = current->alg,
+                       .key_size = current->key_size,
+               );
+               clone->transforms->insert_last(clone->transforms, entry);
+       }
+       enumerator->destroy(enumerator);
 
        clone->spi = this->spi;
        clone->number = this->number;
@@ -544,34 +407,40 @@ static const struct {
 static void check_proposal(private_proposal_t *this)
 {
        enumerator_t *e;
-       algorithm_t *alg;
+       entry_t *entry;
+       u_int16_t alg, ks;
        bool all_aead = TRUE;
        int i;
 
-       if (this->protocol == PROTO_IKE &&
-               this->prf_algos->get_count(this->prf_algos) == 0)
-       {       /* No explicit PRF found. We assume the same algorithm as used
-                * for integrity checking */
-               e = this->integrity_algos->create_enumerator(this->integrity_algos);
-               while (e->enumerate(e, &alg))
+       if (this->protocol == PROTO_IKE)
+       {
+               e = create_enumerator(this, PSEUDO_RANDOM_FUNCTION);
+               if (!e->enumerate(e, &alg, &ks))
                {
-                       for (i = 0; i < countof(integ_prf_map); i++)
+                       /* No explicit PRF found. We assume the same algorithm as used
+                        * for integrity checking */
+                       e->destroy(e);
+                       e = create_enumerator(this, INTEGRITY_ALGORITHM);
+                       while (e->enumerate(e, &alg, &ks))
                        {
-                               if (alg->algorithm == integ_prf_map[i].integ)
+                               for (i = 0; i < countof(integ_prf_map); i++)
                                {
-                                       add_algorithm(this, PSEUDO_RANDOM_FUNCTION,
-                                                                 integ_prf_map[i].prf, 0);
-                                       break;
+                                       if (alg == integ_prf_map[i].integ)
+                                       {
+                                               add_algorithm(this, PSEUDO_RANDOM_FUNCTION,
+                                                                         integ_prf_map[i].prf, 0);
+                                               break;
+                                       }
                                }
                        }
                }
                e->destroy(e);
        }
 
-       e = this->encryption_algos->create_enumerator(this->encryption_algos);
-       while (e->enumerate(e, &alg))
+       e = create_enumerator(this, ENCRYPTION_ALGORITHM);
+       while (e->enumerate(e, &alg, &ks))
        {
-               if (!encryption_algorithm_is_aead(alg->algorithm))
+               if (!encryption_algorithm_is_aead(alg))
                {
                        all_aead = FALSE;
                        break;
@@ -581,19 +450,24 @@ static void check_proposal(private_proposal_t *this)
 
        if (all_aead)
        {
-               /* if all encryption algorithms in the proposal are authenticated encryption
-                * algorithms we MUST NOT propose any integrity algorithms */
-               while (this->integrity_algos->remove_last(this->integrity_algos,
-                                                                                                 (void**)&alg) == SUCCESS)
+               /* if all encryption algorithms in the proposal are AEADs,
+                * we MUST NOT propose any integrity algorithms */
+               e = this->transforms->create_enumerator(this->transforms);
+               while (e->enumerate(e, &entry))
                {
-                       free(alg);
+                       if (entry->type == INTEGRITY_ALGORITHM)
+                       {
+                               this->transforms->remove_at(this->transforms, e);
+                               free(entry);
+                       }
                }
+               e->destroy(e);
        }
 
        if (this->protocol == PROTO_AH || this->protocol == PROTO_ESP)
        {
-               e = this->esns->create_enumerator(this->esns);
-               if (!e->enumerate(e, &alg))
+               e = create_enumerator(this, EXTENDED_SEQUENCE_NUMBERS);
+               if (!e->enumerate(e, NULL, NULL))
                {       /* ESN not specified, assume not supported */
                        add_algorithm(this, EXTENDED_SEQUENCE_NUMBERS, NO_EXT_SEQ_NUMBERS, 0);
                }
@@ -704,11 +578,7 @@ int proposal_printf_hook(printf_hook_data_t *data, printf_hook_spec_t *spec,
 METHOD(proposal_t, destroy, void,
        private_proposal_t *this)
 {
-       this->encryption_algos->destroy_function(this->encryption_algos, free);
-       this->integrity_algos->destroy_function(this->integrity_algos, free);
-       this->prf_algos->destroy_function(this->prf_algos, free);
-       this->dh_groups->destroy_function(this->dh_groups, free);
-       this->esns->destroy_function(this->esns, free);
+       this->transforms->destroy_function(this->transforms, free);
        free(this);
 }
 
@@ -737,11 +607,7 @@ proposal_t *proposal_create(protocol_id_t protocol, u_int number)
                },
                .protocol = protocol,
                .number = number,
-               .encryption_algos = linked_list_create(),
-               .integrity_algos = linked_list_create(),
-               .prf_algos = linked_list_create(),
-               .dh_groups = linked_list_create(),
-               .esns = linked_list_create(),
+               .transforms = linked_list_create(),
        );
 
        return &this->public;