Fix init message arrival check.
[strongswan.git] / src / libcharon / sa / ike_sa_manager.c
index d695c7f..6b2d173 100644 (file)
@@ -307,72 +307,72 @@ struct private_ike_sa_manager_t {
        /**
         * Public interface of ike_sa_manager_t.
         */
-        ike_sa_manager_t public;
-
-        /**
-         * Hash table with entries for the ike_sa_t objects.
-         */
-        linked_list_t **ike_sa_table;
-
-        /**
-         * The size of the hash table.
-         */
-        u_int table_size;
-
-        /**
-         * Mask to map the hashes to table rows.
-         */
-        u_int table_mask;
-
-        /**
-         * Segments of the hash table.
-         */
-        segment_t *segments;
-
-        /**
-         * The number of segments.
-         */
-        u_int segment_count;
-
-        /**
-         * Mask to map a table row to a segment.
-         */
-        u_int segment_mask;
-
-        /**
-         * Hash table with half_open_t objects.
-         */
-        linked_list_t **half_open_table;
-
-        /**
+       ike_sa_manager_t public;
+
+       /**
+        * Hash table with entries for the ike_sa_t objects.
+        */
+       linked_list_t **ike_sa_table;
+
+       /**
+        * The size of the hash table.
+        */
+       u_int table_size;
+
+       /**
+        * Mask to map the hashes to table rows.
+        */
+       u_int table_mask;
+
+       /**
+        * Segments of the hash table.
+        */
+       segment_t *segments;
+
+       /**
+        * The number of segments.
+        */
+       u_int segment_count;
+
+       /**
+        * Mask to map a table row to a segment.
+        */
+       u_int segment_mask;
+
+       /**
+        * Hash table with half_open_t objects.
+        */
+       linked_list_t **half_open_table;
+
+       /**
          * Segments of the "half-open" hash table.
-         */
-        shareable_segment_t *half_open_segments;
+        */
+       shareable_segment_t *half_open_segments;
 
-        /**
-         * Hash table with connected_peers_t objects.
-         */
-        linked_list_t **connected_peers_table;
+       /**
+        * Hash table with connected_peers_t objects.
+        */
+       linked_list_t **connected_peers_table;
 
-        /**
-         * Segments of the "connected peers" hash table.
-         */
-        shareable_segment_t *connected_peers_segments;
+       /**
+        * Segments of the "connected peers" hash table.
+        */
+       shareable_segment_t *connected_peers_segments;
 
-        /**
-         * RNG to get random SPIs for our side
-         */
-        rng_t *rng;
+       /**
+        * RNG to get random SPIs for our side
+        */
+       rng_t *rng;
 
-        /**
-         * SHA1 hasher for IKE_SA_INIT retransmit detection
-         */
-        hasher_t *hasher;
+       /**
+        * SHA1 hasher for IKE_SA_INIT retransmit detection
+        */
+       hasher_t *hasher;
 
        /**
         * reuse existing IKE_SAs in checkout_by_config
         */
-        bool reuse_ikesa;
+       bool reuse_ikesa;
 };
 
 /**
@@ -970,6 +970,7 @@ METHOD(ike_sa_manager_t, checkout_by_message, ike_sa_t*,
        entry_t *entry;
        ike_sa_t *ike_sa = NULL;
        ike_sa_id_t *id;
+       bool is_init = FALSE;
 
        id = message->get_ike_sa_id(message);
        id = id->clone(id);
@@ -977,11 +978,29 @@ METHOD(ike_sa_manager_t, checkout_by_message, ike_sa_t*,
 
        DBG2(DBG_MGR, "checkout IKE_SA by message");
 
-       if (message->get_request(message) &&
-               message->get_exchange_type(message) == IKE_SA_INIT &&
-               this->hasher)
+       if (id->get_responder_spi(id) == 0)
+       {
+               if (message->get_major_version(message) == IKEV2_MAJOR_VERSION)
+               {
+                       if (message->get_exchange_type(message) == IKE_SA_INIT &&
+                               message->get_request(message))
+                       {
+                               is_init = TRUE;
+                       }
+               }
+               else
+               {
+                       if (message->get_exchange_type(message) == ID_PROT ||
+                               message->get_exchange_type(message) == AGGRESSIVE)
+                       {
+                               is_init = TRUE;
+                       }
+               }
+       }
+
+       if (is_init && this->hasher)
        {
-               /* IKE_SA_INIT request. Check for an IKE_SA with such a message hash. */
+               /* First request. Check for an IKE_SA with such a message hash. */
                chunk_t data, hash;
 
                data = message->get_packet_data(message);
@@ -990,7 +1009,8 @@ METHOD(ike_sa_manager_t, checkout_by_message, ike_sa_t*,
 
                if (get_entry_by_hash(this, id, hash, &entry, &segment) == SUCCESS)
                {
-                       if (entry->message_id == 0)
+                       if (message->get_exchange_type(message) == IKE_SA_INIT &&
+                               entry->message_id == 0)
                        {
                                unlock_single_segment(this, segment);
                                chunk_free(&hash);
@@ -1011,35 +1031,27 @@ METHOD(ike_sa_manager_t, checkout_by_message, ike_sa_t*,
 
                if (ike_sa == NULL)
                {
-                       if (id->get_responder_spi(id) == 0 &&
-                               message->get_exchange_type(message) == IKE_SA_INIT)
-                       {
-                               /* no IKE_SA found, create a new one */
-                               id->set_responder_spi(id, get_spi(this));
-                               entry = entry_create();
-                               entry->ike_sa = ike_sa_create(id);
-                               entry->ike_sa_id = id->clone(id);
+                       /* no IKE_SA found, create a new one */
+                       id->set_responder_spi(id, get_spi(this));
+                       entry = entry_create();
+                       entry->ike_sa = ike_sa_create(id);
+                       entry->ike_sa_id = id->clone(id);
 
-                               segment = put_entry(this, entry);
-                               entry->checked_out = TRUE;
-                               unlock_single_segment(this, segment);
+                       segment = put_entry(this, entry);
+                       entry->checked_out = TRUE;
+                       unlock_single_segment(this, segment);
 
-                               entry->message_id = message->get_message_id(message);
-                               entry->init_hash = hash;
-                               ike_sa = entry->ike_sa;
+                       entry->message_id = message->get_message_id(message);
+                       entry->init_hash = hash;
+                       ike_sa = entry->ike_sa;
 
-                               DBG2(DBG_MGR, "created IKE_SA %s[%u]",
-                                               ike_sa->get_name(ike_sa), ike_sa->get_unique_id(ike_sa));
-                       }
-                       else
-                       {
-                               chunk_free(&hash);
-                               DBG1(DBG_MGR, "ignoring message, no such IKE_SA");
-                       }
+                       DBG2(DBG_MGR, "created IKE_SA %s[%u]",
+                                       ike_sa->get_name(ike_sa), ike_sa->get_unique_id(ike_sa));
                }
                else
                {
                        chunk_free(&hash);
+                       DBG1(DBG_MGR, "ignoring message, no such IKE_SA");
                }
                id->destroy(id);
                charon->bus->set_sa(charon->bus, ike_sa);
@@ -1048,7 +1060,7 @@ METHOD(ike_sa_manager_t, checkout_by_message, ike_sa_t*,
 
        if (get_entry_by_id(this, id, &entry, &segment) == SUCCESS)
        {
-               /* only check out if we are not processing this request */
+               /* only check out in IKEv2 if we are not already processing it */
                if (message->get_request(message) &&
                        message->get_message_id(message) == entry->message_id)
                {
@@ -1057,7 +1069,9 @@ METHOD(ike_sa_manager_t, checkout_by_message, ike_sa_t*,
                }
                else if (wait_for_entry(this, entry, segment))
                {
-                       ike_sa_id_t *ike_id = entry->ike_sa->get_id(entry->ike_sa);
+                       ike_sa_id_t *ike_id;
+
+                       ike_id = entry->ike_sa->get_id(entry->ike_sa);
                        entry->checked_out = TRUE;
                        entry->message_id = message->get_message_id(message);
                        if (ike_id->get_responder_spi(ike_id) == 0)
@@ -1134,8 +1148,7 @@ METHOD(ike_sa_manager_t, checkout_by_config, ike_sa_t*,
 METHOD(ike_sa_manager_t, checkout_by_id, ike_sa_t*,
        private_ike_sa_manager_t *this, u_int32_t id, bool child)
 {
-       enumerator_t *enumerator;
-       iterator_t *children;
+       enumerator_t *enumerator, *children;
        entry_t *entry;
        ike_sa_t *ike_sa = NULL;
        child_sa_t *child_sa;
@@ -1151,8 +1164,8 @@ METHOD(ike_sa_manager_t, checkout_by_id, ike_sa_t*,
                        /* look for a child with such a reqid ... */
                        if (child)
                        {
-                               children = entry->ike_sa->create_child_sa_iterator(entry->ike_sa);
-                               while (children->iterate(children, (void**)&child_sa))
+                               children = entry->ike_sa->create_child_sa_enumerator(entry->ike_sa);
+                               while (children->enumerate(children, (void**)&child_sa))
                                {
                                        if (child_sa->get_reqid(child_sa) == id)
                                        {
@@ -1188,8 +1201,7 @@ METHOD(ike_sa_manager_t, checkout_by_id, ike_sa_t*,
 METHOD(ike_sa_manager_t, checkout_by_name, ike_sa_t*,
        private_ike_sa_manager_t *this, char *name, bool child)
 {
-       enumerator_t *enumerator;
-       iterator_t *children;
+       enumerator_t *enumerator, *children;
        entry_t *entry;
        ike_sa_t *ike_sa = NULL;
        child_sa_t *child_sa;
@@ -1203,8 +1215,8 @@ METHOD(ike_sa_manager_t, checkout_by_name, ike_sa_t*,
                        /* look for a child with such a policy name ... */
                        if (child)
                        {
-                               children = entry->ike_sa->create_child_sa_iterator(entry->ike_sa);
-                               while (children->iterate(children, (void**)&child_sa))
+                               children = entry->ike_sa->create_child_sa_enumerator(entry->ike_sa);
+                               while (children->enumerate(children, (void**)&child_sa))
                                {
                                        if (streq(child_sa->get_name(child_sa), name))
                                        {
@@ -1238,10 +1250,10 @@ METHOD(ike_sa_manager_t, checkout_by_name, ike_sa_t*,
 }
 
 /**
- * enumerator filter function
+ * enumerator filter function, waiting variant
  */
-static bool enumerator_filter(private_ike_sa_manager_t *this,
-                                                         entry_t **in, ike_sa_t **out, u_int *segment)
+static bool enumerator_filter_wait(private_ike_sa_manager_t *this,
+                                                                  entry_t **in, ike_sa_t **out, u_int *segment)
 {
        if (wait_for_entry(this, *in, *segment))
        {
@@ -1251,11 +1263,28 @@ static bool enumerator_filter(private_ike_sa_manager_t *this,
        return FALSE;
 }
 
+/**
+ * enumerator filter function, skipping variant
+ */
+static bool enumerator_filter_skip(private_ike_sa_manager_t *this,
+                                                                  entry_t **in, ike_sa_t **out, u_int *segment)
+{
+       if (!(*in)->driveout_new_threads &&
+               !(*in)->driveout_waiting_threads &&
+               !(*in)->checked_out)
+       {
+               *out = (*in)->ike_sa;
+               return TRUE;
+       }
+       return FALSE;
+}
+
 METHOD(ike_sa_manager_t, create_enumerator, enumerator_t*,
-       private_ike_sa_manager_t* this)
+       private_ike_sa_manager_t* this, bool wait)
 {
        return enumerator_create_filter(create_table_enumerator(this),
-                                                                       (void*)enumerator_filter, this, NULL);
+                       wait ? (void*)enumerator_filter_wait : (void*)enumerator_filter_skip,
+                       this, NULL);
 }
 
 METHOD(ike_sa_manager_t, checkin, void,
@@ -1539,14 +1568,30 @@ METHOD(ike_sa_manager_t, has_contact, bool,
        return found;
 }
 
-METHOD(ike_sa_manager_t, get_half_open_count, int,
+METHOD(ike_sa_manager_t, get_count, u_int,
+       private_ike_sa_manager_t *this)
+{
+       u_int segment, count = 0;
+       mutex_t *mutex;
+
+       for (segment = 0; segment < this->segment_count; segment++)
+       {
+               mutex = this->segments[segment & this->segment_mask].mutex;
+               mutex->lock(mutex);
+               count += this->segments[segment].count;
+               mutex->unlock(mutex);
+       }
+       return count;
+}
+
+METHOD(ike_sa_manager_t, get_half_open_count, u_int,
        private_ike_sa_manager_t *this, host_t *ip)
 {
        linked_list_t *list;
        u_int segment, row;
        rwlock_t *lock;
        chunk_t addr;
-       int count = 0;
+       u_int count = 0;
 
        if (ip)
        {
@@ -1728,6 +1773,7 @@ ike_sa_manager_t *ike_sa_manager_create()
                        .create_enumerator = _create_enumerator,
                        .checkin = _checkin,
                        .checkin_and_destroy = _checkin_and_destroy,
+                       .get_count = _get_count,
                        .get_half_open_count = _get_half_open_count,
                        .flush = _flush,
                        .destroy = _destroy,