linked-list: Remove barely used has_more() method
authorTobias Brunner <tobias@strongswan.org>
Tue, 16 Jul 2013 13:25:51 +0000 (15:25 +0200)
committerTobias Brunner <tobias@strongswan.org>
Wed, 17 Jul 2013 15:42:53 +0000 (17:42 +0200)
This required some refactoring when handling encrypted payloads.

Also changed log messages so that "encrypted payload" is logged instead
of "encryption payload" (even if we internally still call it that) as
that's the name used in RFC 5996.

src/libcharon/encoding/message.c
src/libstrongswan/collections/linked_list.c
src/libstrongswan/collections/linked_list.h
src/libstrongswan/tests/test_linked_list_enumerator.c

index 749c326..9bb8e51 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2006-2011 Tobias Brunner
+ * Copyright (C) 2006-2013 Tobias Brunner
  * Copyright (C) 2005-2010 Martin Willi
  * Copyright (C) 2010 revosec AG
  * Copyright (C) 2006 Daniel Roethlisberger
@@ -1409,7 +1409,7 @@ static encryption_payload_t* wrap_payloads(private_message_t *this)
                }
                if (encrypt || this->is_encrypted)
                {       /* encryption is forced for IKEv1 */
-                       DBG2(DBG_ENC, "insert payload %N to encryption payload",
+                       DBG2(DBG_ENC, "insert payload %N into encrypted payload",
                                 payload_type_names, type);
                        encryption->add_payload(encryption, current);
                }
@@ -1799,15 +1799,15 @@ static status_t parse_payloads(private_message_t *this)
                        return VERIFY_ERROR;
                }
 
-               DBG2(DBG_ENC, "%N payload verified. Adding to payload list",
+               DBG2(DBG_ENC, "%N payload verified, adding to payload list",
                         payload_type_names, type);
                this->payloads->insert_last(this->payloads, payload);
 
-               /* an encryption payload is the last one, so STOP here. decryption is
+               /* an encrypted payload is the last one, so STOP here. decryption is
                 * done later */
                if (type == ENCRYPTED)
                {
-                       DBG2(DBG_ENC, "%N payload found. Stop parsing",
+                       DBG2(DBG_ENC, "%N payload found, stop parsing",
                                 payload_type_names, type);
                        break;
                }
@@ -1817,17 +1817,101 @@ static status_t parse_payloads(private_message_t *this)
 }
 
 /**
+ * Decrypt an encrypted payload and extract all contained payloads.
+ */
+static status_t decrypt_and_extract(private_message_t *this, keymat_t *keymat,
+                                               payload_t *previous, encryption_payload_t *encryption)
+{
+       payload_t *encrypted;
+       payload_type_t type;
+       chunk_t chunk;
+       aead_t *aead;
+       size_t bs;
+       status_t status = SUCCESS;
+
+       if (!keymat)
+       {
+               DBG1(DBG_ENC, "found encrypted payload, but no keymat");
+               return INVALID_ARG;
+       }
+       aead = keymat->get_aead(keymat, TRUE);
+       if (!aead)
+       {
+               DBG1(DBG_ENC, "found encrypted payload, but no transform set");
+               return INVALID_ARG;
+       }
+       bs = aead->get_block_size(aead);
+       encryption->set_transform(encryption, aead);
+       chunk = this->packet->get_data(this->packet);
+       if (chunk.len < encryption->get_length(encryption) ||
+               chunk.len < bs)
+       {
+               DBG1(DBG_ENC, "invalid payload length");
+               return VERIFY_ERROR;
+       }
+       if (keymat->get_version(keymat) == IKEV1)
+       {       /* instead of associated data we provide the IV, we also update
+                * the IV with the last encrypted block */
+               keymat_v1_t *keymat_v1 = (keymat_v1_t*)keymat;
+               chunk_t iv;
+
+               if (keymat_v1->get_iv(keymat_v1, this->message_id, &iv))
+               {
+                       status = encryption->decrypt(encryption, iv);
+                       if (status == SUCCESS)
+                       {
+                               if (!keymat_v1->update_iv(keymat_v1, this->message_id,
+                                               chunk_create(chunk.ptr + chunk.len - bs, bs)))
+                               {
+                                       status = FAILED;
+                               }
+                       }
+               }
+               else
+               {
+                       status = FAILED;
+               }
+       }
+       else
+       {
+               chunk.len -= encryption->get_length(encryption);
+               status = encryption->decrypt(encryption, chunk);
+       }
+       if (status != SUCCESS)
+       {
+               return status;
+       }
+
+       while ((encrypted = encryption->remove_payload(encryption)))
+       {
+               type = encrypted->get_type(encrypted);
+               if (previous)
+               {
+                       previous->set_next_type(previous, type);
+               }
+               else
+               {
+                       this->first_payload = type;
+               }
+               DBG2(DBG_ENC, "insert decrypted payload of type %N at end of list",
+                        payload_type_names, type);
+               this->payloads->insert_last(this->payloads, encrypted);
+               previous = encrypted;
+       }
+       return SUCCESS;
+}
+
+/**
  * Decrypt payload from the encryption payload
  */
 static status_t decrypt_payloads(private_message_t *this, keymat_t *keymat)
 {
-       bool was_encrypted = FALSE;
        payload_t *payload, *previous = NULL;
        enumerator_t *enumerator;
        payload_rule_t *rule;
        payload_type_t type;
-       aead_t *aead;
        status_t status = SUCCESS;
+       bool was_encrypted = FALSE;
 
        enumerator = this->payloads->create_enumerator(this->payloads);
        while (enumerator->enumerate(enumerator, &payload))
@@ -1839,97 +1923,35 @@ static status_t decrypt_payloads(private_message_t *this, keymat_t *keymat)
                if (type == ENCRYPTED || type == ENCRYPTED_V1)
                {
                        encryption_payload_t *encryption;
-                       payload_t *encrypted;
-                       chunk_t chunk;
-                       size_t bs;
-
-                       encryption = (encryption_payload_t*)payload;
 
-                       DBG2(DBG_ENC, "found an encryption payload");
-
-                       if (this->payloads->has_more(this->payloads, enumerator))
+                       if (was_encrypted)
                        {
-                               DBG1(DBG_ENC, "encrypted payload is not last payload");
+                               DBG1(DBG_ENC, "encrypted payload can't contain other payloads "
+                                        "of type %N", payload_type_names, type);
                                status = VERIFY_ERROR;
                                break;
                        }
-                       if (!keymat)
-                       {
-                               DBG1(DBG_ENC, "found encryption payload, but no keymat");
-                               status = INVALID_ARG;
-                               break;
-                       }
-                       aead = keymat->get_aead(keymat, TRUE);
-                       if (!aead)
-                       {
-                               DBG1(DBG_ENC, "found encryption payload, but no transform set");
-                               status = INVALID_ARG;
-                               break;
-                       }
-                       bs = aead->get_block_size(aead);
-                       encryption->set_transform(encryption, aead);
-                       chunk = this->packet->get_data(this->packet);
-                       if (chunk.len < encryption->get_length(encryption) ||
-                               chunk.len < bs)
+
+                       DBG2(DBG_ENC, "found an encrypted payload");
+                       encryption = (encryption_payload_t*)payload;
+                       this->payloads->remove_at(this->payloads, enumerator);
+
+                       if (enumerator->enumerate(enumerator, NULL))
                        {
-                               DBG1(DBG_ENC, "invalid payload length");
+                               DBG1(DBG_ENC, "encrypted payload is not last payload");
+                               encryption->destroy(encryption);
                                status = VERIFY_ERROR;
                                break;
                        }
-                       if (keymat->get_version(keymat) == IKEV1)
-                       {       /* instead of associated data we provide the IV, we also update
-                                * the IV with the last encrypted block */
-                               keymat_v1_t *keymat_v1 = (keymat_v1_t*)keymat;
-                               chunk_t iv;
-
-                               if (keymat_v1->get_iv(keymat_v1, this->message_id, &iv))
-                               {
-                                       status = encryption->decrypt(encryption, iv);
-                                       if (status == SUCCESS)
-                                       {
-                                               if (!keymat_v1->update_iv(keymat_v1, this->message_id,
-                                                               chunk_create(chunk.ptr + chunk.len - bs, bs)))
-                                               {
-                                                       status = FAILED;
-                                               }
-                                       }
-                               }
-                               else
-                               {
-                                       status = FAILED;
-                               }
-                       }
-                       else
-                       {
-                               chunk.len -= encryption->get_length(encryption);
-                               status = encryption->decrypt(encryption, chunk);
-                       }
+                       status = decrypt_and_extract(this, keymat, previous, encryption);
+                       encryption->destroy(encryption);
                        if (status != SUCCESS)
                        {
                                break;
                        }
-
                        was_encrypted = TRUE;
-                       this->payloads->remove_at(this->payloads, enumerator);
-
-                       while ((encrypted = encryption->remove_payload(encryption)))
-                       {
-                               type = encrypted->get_type(encrypted);
-                               if (previous)
-                               {
-                                       previous->set_next_type(previous, type);
-                               }
-                               else
-                               {
-                                       this->first_payload = type;
-                               }
-                               DBG2(DBG_ENC, "insert decrypted payload of type "
-                                        "%N at end of list", payload_type_names, type);
-                               this->payloads->insert_last(this->payloads, encrypted);
-                               previous = encrypted;
-                       }
-                       encryption->destroy(encryption);
                }
+
                if (payload_is_known(type) && !was_encrypted &&
                        !is_connectivity_check(this, payload) &&
                        this->exchange_type != AGGRESSIVE)
index dbbc2a9..a176e5a 100644 (file)
@@ -168,16 +168,6 @@ METHOD(linked_list_t, reset_enumerator, void,
        enumerator->finished = FALSE;
 }
 
-METHOD(linked_list_t, has_more, bool,
-       private_linked_list_t *this, private_enumerator_t *enumerator)
-{
-       if (enumerator->current)
-       {
-               return enumerator->current->next != NULL;
-       }
-       return !enumerator->finished && this->first != NULL;
-}
-
 METHOD(linked_list_t, get_count, int,
        private_linked_list_t *this)
 {
@@ -500,7 +490,6 @@ linked_list_t *linked_list_create()
                        .get_count = _get_count,
                        .create_enumerator = _create_enumerator,
                        .reset_enumerator = (void*)_reset_enumerator,
-                       .has_more = (void*)_has_more,
                        .get_first = _get_first,
                        .get_last = _get_last,
                        .find_first = (void*)_find_first,
index bc77765..abc33c1 100644 (file)
@@ -78,15 +78,6 @@ struct linked_list_t {
        void (*reset_enumerator)(linked_list_t *this, enumerator_t *enumerator);
 
        /**
-        * Checks if there are more elements following after the enumerator's
-        * current position.
-        *
-        * @param enumerator    enumerator to check
-        * @return                              TRUE if more elements follow after the current item
-        */
-       bool (*has_more)(linked_list_t *this, enumerator_t *enumerator);
-
-       /**
         * Inserts a new item at the beginning of the list.
         *
         * @param item          item value to insert in list
index 93d814b..48d6f40 100644 (file)
@@ -97,50 +97,6 @@ START_TEST(test_reset_enumerator)
 }
 END_TEST
 
-START_TEST(test_has_more_empty)
-{
-       enumerator_t *enumerator;
-       intptr_t x;
-
-       list->destroy(list);
-       list = linked_list_create();
-       enumerator = list->create_enumerator(list);
-       ck_assert(!list->has_more(list, enumerator));
-       ck_assert(!enumerator->enumerate(enumerator, &x));
-       ck_assert(!list->has_more(list, enumerator));
-       enumerator->destroy(enumerator);
-}
-END_TEST
-
-START_TEST(test_has_more)
-{
-       enumerator_t *enumerator;
-       intptr_t x;
-       int round;
-
-       round = 1;
-       enumerator = list->create_enumerator(list);
-       while (enumerator->enumerate(enumerator, &x))
-       {
-               ck_assert_int_eq(round, x);
-               round++;
-               if (x == 2)
-               {
-                       break;
-               }
-       }
-       ck_assert(list->has_more(list, enumerator));
-       while (enumerator->enumerate(enumerator, &x))
-       {
-               ck_assert_int_eq(round, x);
-               round++;
-       }
-       ck_assert(!list->has_more(list, enumerator));
-       ck_assert_int_eq(round, 6);
-       enumerator->destroy(enumerator);
-}
-END_TEST
-
 /*******************************************************************************
  * insert before
  */
@@ -202,7 +158,6 @@ START_TEST(test_insert_before_ends)
        ck_assert_int_eq(list->get_count(list), 7);
        ck_assert(list->get_last(list, (void*)&x) == SUCCESS);
        ck_assert_int_eq(x, 6);
-       ck_assert(!list->has_more(list, enumerator));
        ck_assert(!enumerator->enumerate(enumerator, &x));
        enumerator->destroy(enumerator);
 }
@@ -222,10 +177,9 @@ START_TEST(test_insert_before_empty)
        ck_assert_int_eq(x, 1);
        ck_assert(list->get_last(list, (void*)&x) == SUCCESS);
        ck_assert_int_eq(x, 1);
-       ck_assert(list->has_more(list, enumerator));
        ck_assert(enumerator->enumerate(enumerator, &x));
        ck_assert_int_eq(x, 1);
-       ck_assert(!list->has_more(list, enumerator));
+       ck_assert(!enumerator->enumerate(enumerator, NULL));
        enumerator->destroy(enumerator);
 }
 END_TEST
@@ -382,8 +336,6 @@ Suite *linked_list_enumerator_suite_create()
        tcase_add_test(tc, test_enumerate);
        tcase_add_test(tc, test_enumerate_null);
        tcase_add_test(tc, test_reset_enumerator);
-       tcase_add_test(tc, test_has_more_empty);
-       tcase_add_test(tc, test_has_more);
        suite_add_tcase(s, tc);
 
        tc = tcase_create("insert_before()");