Refactored traffic selector handling in quick mode
authorMartin Willi <martin@revosec.ch>
Thu, 24 Nov 2011 09:20:59 +0000 (10:20 +0100)
committerMartin Willi <martin@revosec.ch>
Tue, 20 Mar 2012 16:30:51 +0000 (17:30 +0100)
src/libcharon/sa/tasks/quick_mode.c

index 70bde90..0675fe4 100644 (file)
@@ -241,6 +241,137 @@ static bool get_nonce(private_quick_mode_t *this, chunk_t *nonce,
        return TRUE;
 }
 
+/**
+ * Select a traffic selector from configuration
+ */
+static traffic_selector_t* select_ts(private_quick_mode_t *this, bool initiator)
+{
+       traffic_selector_t *ts;
+       linked_list_t *list;
+       host_t *host;
+
+       if (initiator)
+       {
+               host = this->ike_sa->get_my_host(this->ike_sa);
+       }
+       else
+       {
+               host = this->ike_sa->get_other_host(this->ike_sa);
+       }
+       list = this->config->get_traffic_selectors(this->config, initiator,
+                                                                                          NULL, host);
+       if (list->get_first(list, (void**)&ts) == SUCCESS)
+       {
+               if (list->get_count(list) > 1)
+               {
+                       DBG1(DBG_IKE, "configuration has more than one %s traffic selector,"
+                                " using first IKEv1", initiator ? "initiator" : "responder");
+               }
+               ts = ts->clone(ts);
+       }
+       else
+       {
+               DBG1(DBG_IKE, "%s traffic selector missing in configuration",
+                        initiator ? "initiator" : "responder");
+               ts = NULL;
+       }
+       list->destroy_offset(list, offsetof(traffic_selector_t, destroy));
+       return ts;
+}
+
+/**
+ * Add selected traffic selectors to message
+ */
+static void add_ts(private_quick_mode_t *this, message_t *message)
+{
+       id_payload_t *id_payload;
+
+       id_payload = id_payload_create_from_ts(this->tsi);
+       message->add_payload(message, &id_payload->payload_interface);
+       id_payload = id_payload_create_from_ts(this->tsr);
+       message->add_payload(message, &id_payload->payload_interface);
+}
+
+/**
+ * Get traffic selectors from received message
+ */
+static bool get_ts(private_quick_mode_t *this, message_t *message,
+                                  bool initiator)
+{
+       traffic_selector_t *tsi = NULL, *tsr = NULL;
+       enumerator_t *enumerator;
+       id_payload_t *id_payload;
+       payload_t *payload;
+       host_t *hsi, *hsr;
+       bool first = TRUE;
+
+       enumerator = message->create_payload_enumerator(message);
+       while (enumerator->enumerate(enumerator, &payload))
+       {
+               if (payload->get_type(payload) == ID_V1)
+               {
+                       id_payload = (id_payload_t*)payload;
+
+                       if (first)
+                       {
+                               tsi = id_payload->get_ts(id_payload);
+                               first = FALSE;
+                       }
+                       else
+                       {
+                               tsr = id_payload->get_ts(id_payload);
+                               break;
+                       }
+               }
+       }
+       enumerator->destroy(enumerator);
+
+       /* create host2host selectors if ID payloads missing */
+       if (initiator)
+       {
+               hsi = this->ike_sa->get_my_host(this->ike_sa);
+               hsr = this->ike_sa->get_other_host(this->ike_sa);
+       }
+       else
+       {
+               hsr = this->ike_sa->get_my_host(this->ike_sa);
+               hsi = this->ike_sa->get_other_host(this->ike_sa);
+       }
+       if (!tsi)
+       {
+               tsi = traffic_selector_create_from_subnet(hsi->clone(hsi),
+                                                       hsi->get_family(hsi) == AF_INET ? 32 : 128, 0, 0);
+       }
+       if (!tsr)
+       {
+               tsr = traffic_selector_create_from_subnet(hsr->clone(hsr),
+                                                       hsr->get_family(hsr) == AF_INET ? 32 : 128, 0, 0);
+       }
+       if (initiator)
+       {
+               /* check if peer selection valid */
+               if (!tsr->is_contained_in(tsr, this->tsr) ||
+                       !tsi->is_contained_in(tsi, this->tsi))
+               {
+                       DBG1(DBG_IKE, "peer selected invalid traffic selectors: ",
+                                "%R for %R, %R for %R", tsi, this->tsi, tsr, this->tsr);
+                       tsi->destroy(tsi);
+                       tsr->destroy(tsr);
+                       return FALSE;
+               }
+               this->tsi->destroy(this->tsi);
+               this->tsr->destroy(this->tsr);
+               this->tsi = tsi;
+               this->tsr = tsr;
+       }
+       else
+       {
+               this->tsi = tsi;
+               this->tsr = tsr;
+       }
+       return TRUE;
+}
+
 METHOD(task_t, build_i, status_t,
        private_quick_mode_t *this, message_t *message)
 {
@@ -250,8 +381,6 @@ METHOD(task_t, build_i, status_t,
                {
                        enumerator_t *enumerator;
                        sa_payload_t *sa_payload;
-                       id_payload_t *id_payload;
-                       traffic_selector_t *ts;
                        linked_list_t *list;
                        proposal_t *proposal;
 
@@ -284,33 +413,13 @@ METHOD(task_t, build_i, status_t,
                        {
                                return FAILED;
                        }
-
-                       list = this->config->get_traffic_selectors(this->config, TRUE, NULL,
-                                                                       this->ike_sa->get_my_host(this->ike_sa));
-                       if (list->get_first(list, (void**)&ts) != SUCCESS)
-                       {
-                               list->destroy_offset(list, offsetof(traffic_selector_t, destroy));
-                               DBG1(DBG_IKE, "traffic selector missing");
-                               return FAILED;
-                       }
-                       id_payload = id_payload_create_from_ts(ts);
-                       this->tsi = ts->clone(ts);
-                       list->destroy_offset(list, offsetof(traffic_selector_t, destroy));
-                       message->add_payload(message, &id_payload->payload_interface);
-
-                       list = this->config->get_traffic_selectors(this->config, FALSE, NULL,
-                                                                       this->ike_sa->get_other_host(this->ike_sa));
-                       if (list->get_first(list, (void**)&ts) != SUCCESS)
+                       this->tsi = select_ts(this, TRUE);
+                       this->tsr = select_ts(this, FALSE);
+                       if (!this->tsi || !this->tsr)
                        {
-                               list->destroy_offset(list, offsetof(traffic_selector_t, destroy));
-                               DBG1(DBG_IKE, "traffic selector missing");
                                return FAILED;
                        }
-                       id_payload = id_payload_create_from_ts(ts);
-                       this->tsr = ts->clone(ts);
-                       list->destroy_offset(list, offsetof(traffic_selector_t, destroy));
-                       message->add_payload(message, &id_payload->payload_interface);
-
+                       add_ts(this, message);
                        return NEED_MORE;
                }
                case QM_NEGOTIATED:
@@ -330,48 +439,14 @@ METHOD(task_t, process_r, status_t,
                case QM_INIT:
                {
                        sa_payload_t *sa_payload;
-                       id_payload_t *id_payload;
-                       payload_t *payload;
                        linked_list_t *tsi, *tsr, *list;
                        peer_cfg_t *peer_cfg;
-                       host_t *me, *other, *host;
-                       enumerator_t *enumerator;
-                       bool first = TRUE;
-
-                       enumerator = message->create_payload_enumerator(message);
-                       while (enumerator->enumerate(enumerator, &payload))
-                       {
-                               if (payload->get_type(payload) == ID_V1)
-                               {
-                                       id_payload = (id_payload_t*)payload;
-
-                                       if (first)
-                                       {
-                                               this->tsi = id_payload->get_ts(id_payload);
-                                               first = FALSE;
-                                       }
-                                       else
-                                       {
-                                               this->tsr = id_payload->get_ts(id_payload);
-                                               break;
-                                       }
-                               }
-                       }
-                       enumerator->destroy(enumerator);
+                       host_t *me, *other;
 
-                       if (!this->tsi)
+                       if (!get_ts(this, message, FALSE))
                        {
-                               host = this->ike_sa->get_other_host(this->ike_sa);
-                               this->tsi = traffic_selector_create_from_subnet(host->clone(host),
-                                               host->get_family(host) == AF_INET ? 32 : 128, 0, 0);
-                       }
-                       if (!this->tsr)
-                       {
-                               host = this->ike_sa->get_my_host(this->ike_sa);
-                               this->tsr = traffic_selector_create_from_subnet(host->clone(host),
-                                               host->get_family(host) == AF_INET ? 32 : 128, 0, 0);
+                               return FAILED;
                        }
-
                        me = this->ike_sa->get_virtual_ip(this->ike_sa, TRUE);
                        if (!me)
                        {
@@ -446,7 +521,6 @@ METHOD(task_t, build_r, status_t,
                case QM_INIT:
                {
                        sa_payload_t *sa_payload;
-                       id_payload_t *id_payload;
 
                        this->spi_r = this->child_sa->alloc_spi(this->child_sa, PROTO_ESP);
                        if (!this->spi_r)
@@ -464,11 +538,7 @@ METHOD(task_t, build_r, status_t,
                        {
                                return FAILED;
                        }
-
-                       id_payload = id_payload_create_from_ts(this->tsi);
-                       message->add_payload(message, &id_payload->payload_interface);
-                       id_payload = id_payload_create_from_ts(this->tsr);
-                       message->add_payload(message, &id_payload->payload_interface);
+                       add_ts(this, message);
 
                        this->state = QM_NEGOTIATED;
                        return NEED_MORE;
@@ -486,60 +556,7 @@ METHOD(task_t, process_i, status_t,
                case QM_INIT:
                {
                        sa_payload_t *sa_payload;
-                       id_payload_t *id_payload;
-                       payload_t *payload;
-                       traffic_selector_t *tsi = NULL, *tsr = NULL;
                        linked_list_t *list;
-                       enumerator_t *enumerator;
-                       host_t *host;
-                       bool first = TRUE;
-
-                       enumerator = message->create_payload_enumerator(message);
-                       while (enumerator->enumerate(enumerator, &payload))
-                       {
-                               if (payload->get_type(payload) == ID_V1)
-                               {
-                                       id_payload = (id_payload_t*)payload;
-
-                                       if (first)
-                                       {
-                                               tsi = id_payload->get_ts(id_payload);
-                                               first = FALSE;
-                                       }
-                                       else
-                                       {
-                                               tsr = id_payload->get_ts(id_payload);
-                                               break;
-                                       }
-                               }
-                       }
-                       enumerator->destroy(enumerator);
-
-                       if (!tsr)
-                       {
-                               host = this->ike_sa->get_other_host(this->ike_sa);
-                               tsr = traffic_selector_create_from_subnet(host->clone(host),
-                                               host->get_family(host) == AF_INET ? 32 : 128, 0, 0);
-                       }
-                       if (!tsi)
-                       {
-                               host = this->ike_sa->get_my_host(this->ike_sa);
-                               tsi = traffic_selector_create_from_subnet(host->clone(host),
-                                               host->get_family(host) == AF_INET ? 32 : 128, 0, 0);
-                       }
-
-                       if (!tsr->is_contained_in(tsr, this->tsr) ||
-                               !tsi->is_contained_in(tsi, this->tsi))
-                       {
-                               tsi->destroy(tsi);
-                               tsr->destroy(tsr);
-                               DBG1(DBG_IKE, "TS mismatch");
-                               return FAILED;
-                       }
-                       this->tsi->destroy(this->tsi);
-                       this->tsr->destroy(this->tsr);
-                       this->tsi = tsi;
-                       this->tsr = tsr;
 
                        sa_payload = (sa_payload_t*)message->get_payload(message,
                                                                                                        SECURITY_ASSOCIATION_V1);
@@ -563,6 +580,10 @@ METHOD(task_t, process_i, status_t,
                        {
                                return FAILED;
                        }
+                       if (!get_ts(this, message, TRUE))
+                       {
+                               return FAILED;
+                       }
                        if (!install(this))
                        {
                                return FAILED;