ikev1: keep vendor ID task alive during full Main/Aggressive Mode
authorMartin Willi <martin@revosec.ch>
Thu, 6 Jun 2013 13:38:38 +0000 (15:38 +0200)
committerMartin Willi <martin@revosec.ch>
Tue, 11 Jun 2013 13:54:27 +0000 (15:54 +0200)
Fixes DPD with Cisco IOS sending the DPD vendor ID not in the first message.

src/libcharon/sa/ikev1/tasks/isakmp_vendor.c

index 2ff2b55..11155b2 100644 (file)
@@ -67,6 +67,11 @@ struct private_isakmp_vendor_t {
         * Index of best nat traversal VID found
         */
        int best_natt_ext;
+
+       /**
+        * Number of times we have been invoked
+        */
+       int count;
 };
 
 /**
@@ -175,8 +180,10 @@ static bool fragmentation_supported(chunk_t data, int i)
        return FALSE;
 }
 
-METHOD(task_t, build, status_t,
-       private_isakmp_vendor_t *this, message_t *message)
+/**
+ * Add supported vendor ID payloads
+ */
+static void build(private_isakmp_vendor_t *this, message_t *message)
 {
        vendor_id_payload_t *vid_payload;
        bool strongswan, cisco_unity, fragmentation;
@@ -219,11 +226,12 @@ METHOD(task_t, build, status_t,
                        message->add_payload(message, &vid_payload->payload_interface);
                }
        }
-       return this->initiator ? NEED_MORE : SUCCESS;
 }
 
-METHOD(task_t, process, status_t,
-       private_isakmp_vendor_t *this, message_t *message)
+/**
+ * Process vendor ID payloads
+ */
+static void process(private_isakmp_vendor_t *this, message_t *message)
 {
        enumerator_t *enumerator;
        payload_t *payload;
@@ -289,14 +297,64 @@ METHOD(task_t, process, status_t,
                this->ike_sa->enable_extension(this->ike_sa,
                                                                vendor_natt_ids[this->best_natt_ext].extension);
        }
+}
 
-       return this->initiator ? SUCCESS : NEED_MORE;
+METHOD(task_t, build_i, status_t,
+       private_isakmp_vendor_t *this, message_t *message)
+{
+       if (this->count++ == 0)
+       {
+               build(this, message);
+       }
+       if (message->get_exchange_type(message) == AGGRESSIVE && this->count > 1)
+       {
+               return SUCCESS;
+       }
+       return NEED_MORE;
+}
+
+METHOD(task_t, process_r, status_t,
+       private_isakmp_vendor_t *this, message_t *message)
+{
+       this->count++;
+       process(this, message);
+       if (message->get_exchange_type(message) == AGGRESSIVE && this->count > 1)
+       {
+               return SUCCESS;
+       }
+       return NEED_MORE;
+}
+
+METHOD(task_t, build_r, status_t,
+       private_isakmp_vendor_t *this, message_t *message)
+{
+       if (this->count == 1)
+       {
+               build(this, message);
+       }
+       if (message->get_exchange_type(message) == ID_PROT && this->count > 2)
+       {
+               return SUCCESS;
+       }
+       return NEED_MORE;
+}
+
+METHOD(task_t, process_i, status_t,
+       private_isakmp_vendor_t *this, message_t *message)
+{
+       process(this, message);
+       if (message->get_exchange_type(message) == ID_PROT && this->count > 2)
+       {
+               return SUCCESS;
+       }
+       return NEED_MORE;
 }
 
 METHOD(task_t, migrate, void,
        private_isakmp_vendor_t *this, ike_sa_t *ike_sa)
 {
        this->ike_sa = ike_sa;
+       this->count = 0;
 }
 
 METHOD(task_t, get_type, task_type_t,
@@ -321,8 +379,6 @@ isakmp_vendor_t *isakmp_vendor_create(ike_sa_t *ike_sa, bool initiator)
        INIT(this,
                .public = {
                        .task = {
-                               .build = _build,
-                               .process = _process,
                                .migrate = _migrate,
                                .get_type = _get_type,
                                .destroy = _destroy,
@@ -333,5 +389,16 @@ isakmp_vendor_t *isakmp_vendor_create(ike_sa_t *ike_sa, bool initiator)
                .best_natt_ext = -1,
        );
 
+       if (initiator)
+       {
+               this->public.task.build = _build_i;
+               this->public.task.process = _process_i;
+       }
+       else
+       {
+               this->public.task.build = _build_r;
+               this->public.task.process = _process_r;
+       }
+
        return &this->public;
 }