proposal: Keep track of contained transform types
authorTobias Brunner <tobias@strongswan.org>
Fri, 23 Feb 2018 07:36:33 +0000 (08:36 +0100)
committerTobias Brunner <tobias@strongswan.org>
Mon, 5 Mar 2018 11:05:36 +0000 (12:05 +0100)
src/libstrongswan/crypto/proposal/proposal.c

index bb0a02b..be519c2 100644 (file)
@@ -58,6 +58,11 @@ struct private_proposal_t {
        array_t *transforms;
 
        /**
+        * Types of transforms contained, as transform_type_t
+        */
+       array_t *types;
+
+       /**
         * senders SPI
         */
        uint64_t spi;
@@ -69,6 +74,57 @@ struct private_proposal_t {
 };
 
 /**
+ * Sort transform types
+ */
+static int type_sort(const void *a, const void *b, void *user)
+{
+       const transform_type_t *ta = a, *tb = b;
+       return *ta - *tb;
+}
+
+/**
+ * Find a transform type
+ */
+static int type_find(const void *a, const void *b)
+{
+       return type_sort(a, b, NULL);
+}
+
+/**
+ * Check if the given transform type is already in the set
+ */
+static bool contains_type(private_proposal_t *this, transform_type_t type)
+{
+       return array_bsearch(this->types, &type, type_find, NULL) != -1;
+}
+
+/**
+ * Add the given transform type to the set
+ */
+static void add_type(private_proposal_t *this, transform_type_t type)
+{
+       if (!contains_type(this, type))
+       {
+               array_insert(this->types, ARRAY_TAIL, &type);
+               array_sort(this->types, type_sort, NULL);
+       }
+}
+
+/**
+ * Remove the given transform type from the set
+ */
+static void remove_type(private_proposal_t *this, transform_type_t type)
+{
+       int i;
+
+       i = array_bsearch(this->types, &type, type_find, NULL);
+       if (i >= 0)
+       {
+               array_remove(this->types, i, NULL);
+       }
+}
+
+/**
  * Struct used to store different kinds of algorithms.
  */
 typedef struct {
@@ -91,6 +147,7 @@ METHOD(proposal_t, add_algorithm, void,
        };
 
        array_insert(this->transforms, ARRAY_TAIL, &entry);
+       add_type(this, type);
 }
 
 CALLBACK(alg_filter, bool,
@@ -206,17 +263,29 @@ METHOD(proposal_t, strip_dh, void,
 {
        enumerator_t *enumerator;
        entry_t *entry;
+       bool found = FALSE;
 
        enumerator = array_create_enumerator(this->transforms);
        while (enumerator->enumerate(enumerator, &entry))
        {
-               if (entry->type == DIFFIE_HELLMAN_GROUP &&
-                       entry->alg != keep)
+               if (entry->type == DIFFIE_HELLMAN_GROUP)
                {
-                       array_remove_at(this->transforms, enumerator);
+                       if (entry->alg != keep)
+                       {
+                               array_remove_at(this->transforms, enumerator);
+                       }
+                       else
+                       {
+                               found = TRUE;
+                       }
                }
        }
        enumerator->destroy(enumerator);
+
+       if (keep == MODP_NONE || !found)
+       {
+               remove_type(this, DIFFIE_HELLMAN_GROUP);
+       }
 }
 
 /**
@@ -427,6 +496,7 @@ METHOD(proposal_t, clone_, proposal_t*,
        private_proposal_t *clone;
        enumerator_t *enumerator;
        entry_t *entry;
+       transform_type_t *type;
 
        clone = (private_proposal_t*)proposal_create(this->protocol, 0);
 
@@ -436,6 +506,12 @@ METHOD(proposal_t, clone_, proposal_t*,
                array_insert(clone->transforms, ARRAY_TAIL, entry);
        }
        enumerator->destroy(enumerator);
+       enumerator = array_create_enumerator(this->types);
+       while (enumerator->enumerate(enumerator, &type))
+       {
+               array_insert(clone->types, ARRAY_TAIL, type);
+       }
+       enumerator->destroy(enumerator);
 
        clone->spi = this->spi;
        clone->number = this->number;
@@ -479,6 +555,7 @@ static void remove_transform(private_proposal_t *this, transform_type_t type)
                }
        }
        e->destroy(e);
+       remove_type(this, type);
 }
 
 /**
@@ -605,6 +682,7 @@ static bool check_proposal(private_proposal_t *this)
                        }
                }
                e->destroy(e);
+               remove_type(this, ENCRYPTION_ALGORITHM);
 
                if (!get_algorithm(this, INTEGRITY_ALGORITHM, NULL, NULL))
                {
@@ -730,6 +808,7 @@ METHOD(proposal_t, destroy, void,
        private_proposal_t *this)
 {
        array_destroy(this->transforms);
+       array_destroy(this->types);
        free(this);
 }
 
@@ -760,6 +839,7 @@ proposal_t *proposal_create(protocol_id_t protocol, u_int number)
                .protocol = protocol,
                .number = number,
                .transforms = array_create(sizeof(entry_t), 0),
+               .types = array_create(sizeof(transform_type_t), 0),
        );
 
        return &this->public;