Merge branch 'debian-testing'
[strongswan.git] / src / conftest / hooks / rebuild_auth.c
index 30de2c7..b7e6f22 100644 (file)
@@ -15,6 +15,7 @@
 
 #include "hook.h"
 
+#include <sa/ikev2/keymat_v2.h>
 #include <encoding/generator.h>
 #include <encoding/payloads/nonce_payload.h>
 #include <encoding/payloads/auth_payload.h>
@@ -41,6 +42,11 @@ struct private_rebuild_auth_t {
         * Received NONCE, required to rebuild AUTH
         */
        chunk_t nonce;
+
+       /**
+        * ID to use for key lookup, if not from IDi
+        */
+       identification_t *id;
 };
 
 /**
@@ -52,12 +58,11 @@ static bool rebuild_auth(private_rebuild_auth_t *this, ike_sa_t *ike_sa,
        enumerator_t *enumerator;
        chunk_t octets, auth_data;
        private_key_t *private;
-       auth_cfg_t *auth;
        payload_t *payload;
        auth_payload_t *auth_payload;
        auth_method_t auth_method;
        signature_scheme_t scheme;
-       keymat_t *keymat;
+       keymat_v2_t *keymat;
        identification_t *id;
        char reserved[3];
        generator_t *generator;
@@ -85,12 +90,12 @@ static bool rebuild_auth(private_rebuild_auth_t *this, ike_sa_t *ike_sa,
        id = identification_create_from_encoding(data.ptr[4], chunk_skip(data, 8));
        generator->destroy(generator);
 
-       auth = auth_cfg_create();
-       private = lib->credmgr->get_private(lib->credmgr, KEY_ANY, id, auth);
-       auth->destroy(auth);
+       private = lib->credmgr->get_private(lib->credmgr, KEY_ANY,
+                                                                               this->id ?: id, NULL);
        if (private == NULL)
        {
-               DBG1(DBG_CFG, "no private key found for '%Y' to rebuild AUTH", id);
+               DBG1(DBG_CFG, "no private key found for '%Y' to rebuild AUTH",
+                        this->id ?: id);
                id->destroy(id);
                return FALSE;
        }
@@ -130,9 +135,14 @@ static bool rebuild_auth(private_rebuild_auth_t *this, ike_sa_t *ike_sa,
                        id->destroy(id);
                        return FALSE;
        }
-       keymat = ike_sa->get_keymat(ike_sa);
-       octets = keymat->get_auth_octets(keymat, FALSE, this->ike_init,
-                                                                        this->nonce, id, reserved);
+       keymat = (keymat_v2_t*)ike_sa->get_keymat(ike_sa);
+       if (!keymat->get_auth_octets(keymat, FALSE, this->ike_init,
+                                                                this->nonce, id, reserved, &octets))
+       {
+               private->destroy(private);
+               id->destroy(id);
+               return FALSE;
+       }
        if (!private->sign(private, scheme, octets, &auth_data))
        {
                chunk_free(&octets);
@@ -167,34 +177,37 @@ static bool rebuild_auth(private_rebuild_auth_t *this, ike_sa_t *ike_sa,
 
 METHOD(listener_t, message, bool,
        private_rebuild_auth_t *this, ike_sa_t *ike_sa, message_t *message,
-       bool incoming)
+       bool incoming, bool plain)
 {
-       if (!incoming && message->get_message_id(message) == 1)
+       if (plain)
        {
-               rebuild_auth(this, ike_sa, message);
-       }
-       if (message->get_exchange_type(message) == IKE_SA_INIT)
-       {
-               if (incoming)
+               if (!incoming && message->get_message_id(message) == 1)
                {
-                       nonce_payload_t *nonce;
-
-                       nonce = (nonce_payload_t*)message->get_payload(message, NONCE);
-                       if (nonce)
-                       {
-                               free(this->nonce.ptr);
-                               this->nonce = nonce->get_nonce(nonce);
-                       }
+                       rebuild_auth(this, ike_sa, message);
                }
-               else
+               if (message->get_exchange_type(message) == IKE_SA_INIT)
                {
-                       packet_t *packet;
-
-                       if (message->generate(message, NULL, &packet) == SUCCESS)
+                       if (incoming)
                        {
-                               free(this->ike_init.ptr);
-                               this->ike_init = chunk_clone(packet->get_data(packet));
-                               packet->destroy(packet);
+                               nonce_payload_t *nonce;
+
+                               nonce = (nonce_payload_t*)message->get_payload(message, NONCE);
+                               if (nonce)
+                               {
+                                       free(this->nonce.ptr);
+                                       this->nonce = nonce->get_nonce(nonce);
+                               }
+                       }
+                       else
+                       {
+                               packet_t *packet;
+
+                               if (message->generate(message, NULL, &packet) == SUCCESS)
+                               {
+                                       free(this->ike_init.ptr);
+                                       this->ike_init = chunk_clone(packet->get_data(packet));
+                                       packet->destroy(packet);
+                               }
                        }
                }
        }
@@ -206,6 +219,7 @@ METHOD(hook_t, destroy, void,
 {
        free(this->ike_init.ptr);
        free(this->nonce.ptr);
+       DESTROY_IF(this->id);
        free(this);
 }
 
@@ -215,6 +229,7 @@ METHOD(hook_t, destroy, void,
 hook_t *rebuild_auth_hook_create(char *name)
 {
        private_rebuild_auth_t *this;
+       char *id;
 
        INIT(this,
                .hook = {
@@ -224,6 +239,11 @@ hook_t *rebuild_auth_hook_create(char *name)
                        .destroy = _destroy,
                },
        );
+       id = conftest->test->get_str(conftest->test, "hooks.%s.key", NULL, name);
+       if (id)
+       {
+               this->id = identification_create_from_string(id);
+       }
 
        return &this->hook;
 }