Implemented a checkout/checkin mechanism for IPsec SAs
authorTobias Brunner <tobias@strongswan.org>
Fri, 13 Jul 2012 12:32:03 +0000 (14:32 +0200)
committerTobias Brunner <tobias@strongswan.org>
Wed, 8 Aug 2012 13:41:03 +0000 (15:41 +0200)
SAs can only be checked out by a single thread and all other threads
block until the SA is checked in again.

src/libipsec/ipsec_sa_mgr.c
src/libipsec/ipsec_sa_mgr.h

index 3a851ba..e42c77a 100644 (file)
@@ -21,6 +21,7 @@
 #include <debug.h>
 #include <library.h>
 #include <processing/jobs/callback_job.h>
+#include <threading/condvar.h>
 #include <threading/mutex.h>
 #include <utils/hashtable.h>
 #include <utils/linked_list.h>
@@ -59,6 +60,38 @@ struct private_ipsec_sa_mgr_t {
 };
 
 /**
+ * Struct to keep track of locked IPsec SAs
+ */
+typedef struct {
+
+       /**
+        * IPsec SA
+        */
+       ipsec_sa_t *sa;
+
+       /**
+        * Set if this SA is currently in use by a thread
+        */
+       bool locked;
+
+       /**
+        * Condvar used by threads to wait for this entry
+        */
+       condvar_t *condvar;
+
+       /**
+        * Number of threads waiting for this entry
+        */
+       u_int waiting_threads;
+
+       /**
+        * Set if this entry is awaiting deletion
+        */
+       bool awaits_deletion;
+
+}  ipsec_sa_entry_t;
+
+/**
  * Helper struct for expiration events
  */
 typedef struct {
@@ -69,9 +102,9 @@ typedef struct {
        private_ipsec_sa_mgr_t *manager;
 
        /**
-        * SA that expired
+        * Entry that expired
         */
-       ipsec_sa_t *sa;
+       ipsec_sa_entry_t *entry;
 
        /**
         * 0 if this is a hard expire, otherwise the offset in s (soft->hard)
@@ -94,21 +127,99 @@ static u_int spi_hash(u_int32_t *spi)
 }
 
 /**
+ * Create an SA entry
+ */
+static ipsec_sa_entry_t *create_entry(ipsec_sa_t *sa)
+{
+       ipsec_sa_entry_t *this;
+
+       INIT(this,
+               .condvar = condvar_create(CONDVAR_TYPE_DEFAULT),
+               .sa = sa,
+       );
+       return this;
+}
+
+/**
+ * Destroy an SA entry
+ */
+static void destroy_entry(ipsec_sa_entry_t *entry)
+{
+       entry->condvar->destroy(entry->condvar);
+       entry->sa->destroy(entry->sa);
+       free(entry);
+}
+
+/**
+ * Makes sure an entry is safe to remove
+ * Must be called with this->mutex held.
+ *
+ * @return                     TRUE if entry can be removed, FALSE if entry is already
+*                                      being removed by another thread
+ */
+static bool wait_remove_entry(private_ipsec_sa_mgr_t *this,
+                                                         ipsec_sa_entry_t *entry)
+{
+       if (entry->awaits_deletion)
+       {
+               /* this will be deleted by another thread already */
+               return FALSE;
+       }
+       entry->awaits_deletion = TRUE;
+       while (entry->locked)
+       {
+               entry->condvar->wait(entry->condvar, this->mutex);
+       }
+       while (entry->waiting_threads > 0)
+       {
+               entry->condvar->broadcast(entry->condvar);
+               entry->condvar->wait(entry->condvar, this->mutex);
+       }
+       return TRUE;
+}
+
+/**
+ * Waits until an is available and then locks it.
+ * Must only be called with this->mutex held
+ */
+static bool wait_for_entry(private_ipsec_sa_mgr_t *this,
+                                                  ipsec_sa_entry_t *entry)
+{
+       while (entry->locked && !entry->awaits_deletion)
+       {
+               entry->waiting_threads++;
+               entry->condvar->wait(entry->condvar, this->mutex);
+               entry->waiting_threads--;
+       }
+       if (entry->awaits_deletion)
+       {
+               /* others may still be waiting, */
+               entry->condvar->signal(entry->condvar);
+               return FALSE;
+       }
+       entry->locked = TRUE;
+       return TRUE;
+}
+
+/**
  * Flushes all entries
  * Must be called with this->mutex held.
  */
 static void flush_entries(private_ipsec_sa_mgr_t *this)
 {
+       ipsec_sa_entry_t *current;
        enumerator_t *enumerator;
-       ipsec_sa_t *current;
 
        DBG2(DBG_ESP, "flushing SAD");
 
        enumerator = this->sas->create_enumerator(this->sas);
        while (enumerator->enumerate(enumerator, (void**)&current))
        {
-               this->sas->remove_at(this->sas, enumerator);
-               current->destroy(current);
+               if (wait_remove_entry(this, current))
+               {
+                       this->sas->remove_at(this->sas, enumerator);
+                       destroy_entry(current);
+               }
        }
        enumerator->destroy(enumerator);
 }
@@ -116,21 +227,65 @@ static void flush_entries(private_ipsec_sa_mgr_t *this)
 /*
  * Different match functions to find SAs in the linked list
  */
-static bool match_entry_by_ptr(ipsec_sa_t *sa, ipsec_sa_t *other)
+static bool match_entry_by_ptr(ipsec_sa_entry_t *item, ipsec_sa_entry_t *entry)
+{
+       return item == entry;
+}
+
+static bool match_entry_by_sa_ptr(ipsec_sa_entry_t *item, ipsec_sa_t *sa)
 {
-       return sa == other;
+       return item->sa == sa;
 }
 
-static bool match_entry_by_spi_inbound(ipsec_sa_t *sa, u_int32_t spi,
+static bool match_entry_by_spi_inbound(ipsec_sa_entry_t *item, u_int32_t spi,
                                                                           bool inbound)
 {
-       return sa->get_spi(sa) == spi && sa->is_inbound(sa) == inbound;
+       return item->sa->get_spi(item->sa) == spi &&
+                  item->sa->is_inbound(item->sa) == inbound;
 }
 
-static bool match_entry_by_spi_src_dst(ipsec_sa_t *sa, u_int32_t spi,
+static bool match_entry_by_spi_src_dst(ipsec_sa_entry_t *item, u_int32_t spi,
                                                                           host_t *src, host_t *dst)
 {
-       return sa->match_by_spi_src_dst(sa, spi, src, dst);
+       return item->sa->match_by_spi_src_dst(item->sa, spi, src, dst);
+}
+
+static bool match_entry_by_reqid_inbound(ipsec_sa_entry_t *item,
+                                                                                u_int32_t reqid, bool inbound)
+{
+       return item->sa->match_by_reqid(item->sa, reqid, inbound);
+}
+
+static bool match_entry_by_spi_dst(ipsec_sa_entry_t *item, u_int32_t spi,
+                                                                  host_t *dst)
+{
+       return item->sa->match_by_spi_dst(item->sa, spi, dst);
+}
+
+/**
+ * Remove an entry
+ */
+static bool remove_entry(private_ipsec_sa_mgr_t *this, ipsec_sa_entry_t *entry)
+{
+       ipsec_sa_entry_t *current;
+       enumerator_t *enumerator;
+       bool removed = FALSE;
+
+       enumerator = this->sas->create_enumerator(this->sas);
+       while (enumerator->enumerate(enumerator, (void**)&current))
+       {
+               if (current == entry)
+               {
+                       if (wait_remove_entry(this, current))
+                       {
+                               this->sas->remove_at(this->sas, enumerator);
+                               removed = TRUE;
+                       }
+                       break;
+               }
+       }
+       enumerator->destroy(enumerator);
+       return removed;
 }
 
 /**
@@ -142,10 +297,10 @@ static job_requeue_t sa_expired(ipsec_sa_expired_t *expired)
 
        this->mutex->lock(this->mutex);
        if (this->sas->find_first(this->sas, (void*)match_entry_by_ptr,
-                                                         NULL, expired->sa) == SUCCESS)
+                                                         NULL, expired->entry) == SUCCESS)
        {
                u_int32_t hard_offset = expired->hard_offset;
-               ipsec_sa_t *sa = expired->sa;
+               ipsec_sa_t *sa = expired->entry->sa;
 
                ipsec->events->expire(ipsec->events, sa->get_reqid(sa),
                                                          sa->get_protocol(sa), sa->get_spi(sa),
@@ -157,8 +312,10 @@ static job_requeue_t sa_expired(ipsec_sa_expired_t *expired)
                        return JOB_RESCHEDULE(hard_offset);
                }
                /* hard limit reached */
-               this->sas->remove(this->sas, sa, NULL);
-               sa->destroy(sa);
+               if (remove_entry(this, expired->entry))
+               {
+                       destroy_entry(expired->entry);
+               }
        }
        this->mutex->unlock(this->mutex);
        return JOB_REQUEUE_NONE;
@@ -168,16 +325,16 @@ static job_requeue_t sa_expired(ipsec_sa_expired_t *expired)
  * Schedule a job to handle IPsec SA expiration
  */
 static void schedule_expiration(private_ipsec_sa_mgr_t *this,
-                                                               ipsec_sa_t *sa)
+                                                               ipsec_sa_entry_t *entry)
 {
-       lifetime_cfg_t *lifetime = sa->get_lifetime(sa);
+       lifetime_cfg_t *lifetime = entry->sa->get_lifetime(entry->sa);
        ipsec_sa_expired_t *expired;
        callback_job_t *job;
        u_int32_t timeout;
 
        INIT(expired,
                .manager = this,
-               .sa = sa,
+               .entry = entry,
        );
 
        /* schedule a rekey first, a hard timeout will be scheduled then, if any */
@@ -284,6 +441,7 @@ METHOD(ipsec_sa_mgr_t, add_sa, status_t,
        u_int16_t cpi, bool encap, bool esn, bool inbound,
        traffic_selector_t *src_ts, traffic_selector_t *dst_ts)
 {
+       ipsec_sa_entry_t *entry;
        ipsec_sa_t *sa_new;
 
        DBG2(DBG_ESP, "adding SAD entry with SPI %.8x and reqid {%u}",
@@ -321,8 +479,9 @@ METHOD(ipsec_sa_mgr_t, add_sa, status_t,
                return FAILED;
        }
 
-       schedule_expiration(this, sa_new);
-       this->sas->insert_last(this->sas, sa_new);
+       entry = create_entry(sa_new);
+       schedule_expiration(this, entry);
+       this->sas->insert_last(this->sas, entry);
 
        this->mutex->unlock(this->mutex);
        return SUCCESS;
@@ -332,7 +491,7 @@ METHOD(ipsec_sa_mgr_t, del_sa, status_t,
        private_ipsec_sa_mgr_t *this, host_t *src, host_t *dst, u_int32_t spi,
        u_int8_t protocol, u_int16_t cpi, mark_t mark)
 {
-       ipsec_sa_t *current, *found = NULL;
+       ipsec_sa_entry_t *current, *found = NULL;
        enumerator_t *enumerator;
 
        this->mutex->lock(this->mutex);
@@ -341,8 +500,11 @@ METHOD(ipsec_sa_mgr_t, del_sa, status_t,
        {
                if (match_entry_by_spi_src_dst(current, spi, src, dst))
                {
-                       this->sas->remove_at(this->sas, enumerator);
-                       found = current;
+                       if (wait_remove_entry(this, current))
+                       {
+                               this->sas->remove_at(this->sas, enumerator);
+                               found = current;
+                       }
                        break;
                }
        }
@@ -352,13 +514,65 @@ METHOD(ipsec_sa_mgr_t, del_sa, status_t,
        if (found)
        {
                DBG2(DBG_ESP, "deleted %sbound SAD entry with SPI %.8x",
-                        found->is_inbound(found) ? "in" : "out", ntohl(spi));
-               found->destroy(found);
+                        found->sa->is_inbound(found->sa) ? "in" : "out", ntohl(spi));
+               destroy_entry(found);
                return SUCCESS;
        }
        return FAILED;
 }
 
+METHOD(ipsec_sa_mgr_t, checkout_by_reqid, ipsec_sa_t*,
+       private_ipsec_sa_mgr_t *this, u_int32_t reqid, bool inbound)
+{
+       ipsec_sa_entry_t *entry;
+       ipsec_sa_t *sa = NULL;
+
+       this->mutex->lock(this->mutex);
+       if (this->sas->find_first(this->sas, (void*)match_entry_by_reqid_inbound,
+                                                        (void**)&entry, reqid, inbound) == SUCCESS &&
+               wait_for_entry(this, entry))
+       {
+               sa = entry->sa;
+       }
+       this->mutex->unlock(this->mutex);
+       return sa;
+}
+
+METHOD(ipsec_sa_mgr_t, checkout_by_spi, ipsec_sa_t*,
+       private_ipsec_sa_mgr_t *this, u_int32_t spi, host_t *dst)
+{
+       ipsec_sa_entry_t *entry;
+       ipsec_sa_t *sa = NULL;
+
+       this->mutex->lock(this->mutex);
+       if (this->sas->find_first(this->sas, (void*)match_entry_by_spi_dst,
+                                                        (void**)&entry, spi, dst) == SUCCESS &&
+               wait_for_entry(this, entry))
+       {
+               sa = entry->sa;
+       }
+       this->mutex->unlock(this->mutex);
+       return sa;
+}
+
+METHOD(ipsec_sa_mgr_t, checkin, void,
+       private_ipsec_sa_mgr_t *this, ipsec_sa_t *sa)
+{
+       ipsec_sa_entry_t *entry;
+
+       this->mutex->lock(this->mutex);
+       if (this->sas->find_first(this->sas, (void*)match_entry_by_sa_ptr,
+                                                        (void**)&entry, sa) == SUCCESS)
+       {
+               if (entry->locked)
+               {
+                       entry->locked = FALSE;
+                       entry->condvar->signal(entry->condvar);
+               }
+       }
+       this->mutex->unlock(this->mutex);
+}
+
 METHOD(ipsec_sa_mgr_t, flush_sas, status_t,
        private_ipsec_sa_mgr_t *this)
 {
@@ -396,6 +610,9 @@ ipsec_sa_mgr_t *ipsec_sa_mgr_create()
                        .get_spi = _get_spi,
                        .add_sa = _add_sa,
                        .del_sa = _del_sa,
+                       .checkout_by_spi = _checkout_by_spi,
+                       .checkout_by_reqid = _checkout_by_reqid,
+                       .checkin = _checkin,
                        .flush_sas = _flush_sas,
                        .destroy = _destroy,
                },
index 0acb0c1..303b36f 100644 (file)
@@ -108,6 +108,49 @@ struct ipsec_sa_mgr_t {
        status_t (*flush_sas)(ipsec_sa_mgr_t *this);
 
        /**
+        * Checkout an installed IPsec SA by SPI and destination address
+        * Can be used to find the correct SA for an inbound packet.
+        *
+        * The matching SA is locked until it is checked in using checkin().
+        * If the matching SA is already checked out, this call blocks until the
+        * SA is checked in.
+        *
+        * Since other threads may be waiting for the checked out SA, it should be
+        * checked in as soon as possible after use.
+        *
+        * @param spi                   SPI (e.g. of an inbound packet)
+        * @param dst                   destination address (e.g. of an inbound packet)
+        * @return                              the matching IPsec SA, or NULL if none is found
+        */
+       ipsec_sa_t *(*checkout_by_spi)(ipsec_sa_mgr_t *this, u_int32_t spi,
+                                                                  host_t *dst);
+
+       /**
+        * Checkout an installed IPsec SA by its reqid and inbound/outbound flag.
+        * Can be used to find the correct SA for an outbound packet.
+        *
+        * The matching SA is locked until it is checked in using checkin().
+        * If the matching SA is already checked out, this call blocks until the
+        * SA is checked in.
+        *
+        * Since other threads may be waiting for a checked out SA, it should be
+        * checked in as soon as possible after use.
+        *
+        * @param reqid                 reqid of the SA
+        * @param inbound               TRUE for an inbound SA, FALSE for an outbound SA
+        * @return                              the matching IPsec SA, or NULL if none is found
+        */
+       ipsec_sa_t *(*checkout_by_reqid)(ipsec_sa_mgr_t *this, u_int32_t reqid,
+                                                                        bool inbound);
+
+       /**
+        * Checkin an SA after use.
+        *
+        * @param sa                    checked out SA
+        */
+       void (*checkin)(ipsec_sa_mgr_t *this, ipsec_sa_t *sa);
+
+       /**
         * Destroy an ipsec_sa_mgr_t
         */
        void (*destroy)(ipsec_sa_mgr_t *this);