eap-radius: Remove cache entries for expired SAs during ike/child_rekey
[strongswan.git] / src / libcharon / plugins / eap_radius / eap_radius_accounting.c
index 8c780e7..b486ba4 100644 (file)
@@ -1,4 +1,7 @@
 /*
 /*
+ * Copyright (C) 2015 Tobias Brunner
+ * Hochschule fuer Technik Rapperswil
+ *
  * Copyright (C) 2012 Martin Willi
  * Copyright (C) 2012 revosec AG
  *
  * Copyright (C) 2012 Martin Willi
  * Copyright (C) 2012 revosec AG
  *
@@ -21,6 +24,7 @@
 #include <radius_message.h>
 #include <radius_client.h>
 #include <daemon.h>
 #include <radius_message.h>
 #include <radius_client.h>
 #include <daemon.h>
+#include <collections/array.h>
 #include <collections/hashtable.h>
 #include <threading/mutex.h>
 #include <processing/jobs/callback_job.h>
 #include <collections/hashtable.h>
 #include <threading/mutex.h>
 #include <processing/jobs/callback_job.h>
@@ -93,18 +97,33 @@ typedef enum {
 } radius_acct_terminate_cause_t;
 
 /**
 } radius_acct_terminate_cause_t;
 
 /**
+ * Usage stats for a cached SAs
+ */
+typedef struct {
+       /** unique CHILD_SA identifier */
+       u_int32_t id;
+       /** usage stats for this SA */
+       struct {
+               u_int64_t sent;
+               u_int64_t received;
+       } bytes, packets;
+} sa_entry_t;
+
+/**
  * Hashtable entry with usage stats
  */
 typedef struct {
        /** IKE_SA identifier this entry is stored under */
        ike_sa_id_t *id;
        /** RADIUS accounting session ID */
  * Hashtable entry with usage stats
  */
 typedef struct {
        /** IKE_SA identifier this entry is stored under */
        ike_sa_id_t *id;
        /** RADIUS accounting session ID */
-       char sid[16];
-       /** number of sent/received octets/packets */
+       char sid[24];
+       /** number of sent/received octets/packets for expired SAs */
        struct {
                u_int64_t sent;
                u_int64_t received;
        } bytes, packets;
        struct {
                u_int64_t sent;
                u_int64_t received;
        } bytes, packets;
+       /** list of cached SAs, sa_entry_t (sorted by their unique ID) */
+       array_t *cached;
        /** session creation time */
        time_t created;
        /** terminate cause */
        /** session creation time */
        time_t created;
        /** terminate cause */
@@ -123,6 +142,7 @@ typedef struct {
  */
 static void destroy_entry(entry_t *this)
 {
  */
 static void destroy_entry(entry_t *this)
 {
+       array_destroy_function(this->cached, (void*)free, NULL);
        this->id->destroy(this->id);
        free(this);
 }
        this->id->destroy(this->id);
        free(this);
 }
@@ -155,6 +175,23 @@ static bool equals(ike_sa_id_t *a, ike_sa_id_t *b)
 }
 
 /**
 }
 
 /**
+ * Sort cached SAs
+ */
+static int sa_sort(const void *a, const void *b, void *user)
+{
+       const sa_entry_t *ra = a, *rb = b;
+       return ra->id - rb->id;
+}
+
+/**
+ * Find a cached SA
+ */
+static int sa_find(const void *a, const void *b)
+{
+       return sa_sort(a, b, NULL);
+}
+
+/**
  * Update usage counter when a CHILD_SA rekeys/goes down
  */
 static void update_usage(private_eap_radius_accounting_t *this,
  * Update usage counter when a CHILD_SA rekeys/goes down
  */
 static void update_usage(private_eap_radius_accounting_t *this,
@@ -162,6 +199,7 @@ static void update_usage(private_eap_radius_accounting_t *this,
 {
        u_int64_t bytes_in, bytes_out, packets_in, packets_out;
        entry_t *entry;
 {
        u_int64_t bytes_in, bytes_out, packets_in, packets_out;
        entry_t *entry;
+       sa_entry_t *sa, lookup;
 
        child_sa->get_usestats(child_sa, FALSE, NULL, &bytes_out, &packets_out);
        child_sa->get_usestats(child_sa, TRUE, NULL, &bytes_in, &packets_in);
 
        child_sa->get_usestats(child_sa, FALSE, NULL, &bytes_out, &packets_out);
        child_sa->get_usestats(child_sa, TRUE, NULL, &bytes_in, &packets_in);
@@ -170,15 +208,66 @@ static void update_usage(private_eap_radius_accounting_t *this,
        entry = this->sessions->get(this->sessions, ike_sa->get_id(ike_sa));
        if (entry)
        {
        entry = this->sessions->get(this->sessions, ike_sa->get_id(ike_sa));
        if (entry)
        {
-               entry->bytes.sent += bytes_out;
-               entry->bytes.received += bytes_in;
-               entry->packets.sent += packets_out;
-               entry->packets.received += packets_in;
+               lookup.id = child_sa->get_unique_id(child_sa);
+               if (array_bsearch(entry->cached, &lookup, sa_find, &sa) == -1)
+               {
+                       INIT(sa,
+                               .id = lookup.id,
+                       );
+                       array_insert_create(&entry->cached, ARRAY_TAIL, sa);
+                       array_sort(entry->cached, sa_sort, NULL);
+               }
+               sa->bytes.sent = bytes_out;
+               sa->bytes.received = bytes_in;
+               sa->packets.sent = packets_out;
+               sa->packets.received = packets_in;
        }
        this->mutex->unlock(this->mutex);
 }
 
 /**
        }
        this->mutex->unlock(this->mutex);
 }
 
 /**
+ * Cleanup cached SAs
+ */
+static void cleanup_sas(private_eap_radius_accounting_t *this, ike_sa_t *ike_sa,
+                                               entry_t *entry)
+{
+       enumerator_t *enumerator;
+       child_sa_t *child_sa;
+       sa_entry_t *sa, *found;
+       array_t *sas;
+
+       sas = array_create(0, 0);
+       enumerator = ike_sa->create_child_sa_enumerator(ike_sa);
+       while (enumerator->enumerate(enumerator, &child_sa))
+       {
+               INIT(sa,
+                       .id = child_sa->get_unique_id(child_sa),
+               );
+               array_insert(sas, ARRAY_TAIL, sa);
+               array_sort(sas, sa_sort, NULL);
+       }
+       enumerator->destroy(enumerator);
+
+       enumerator = array_create_enumerator(entry->cached);
+       while (enumerator->enumerate(enumerator, &sa))
+       {
+               if (array_bsearch(sas, sa, sa_find, &found) == -1)
+               {
+                       /* SA is gone, add its latest stats to the total for this IKE_SA
+                        * and remove the cache entry */
+                       entry->bytes.sent += sa->bytes.sent;
+                       entry->bytes.received += sa->bytes.received;
+                       entry->packets.sent += sa->packets.sent;
+                       entry->packets.received += sa->packets.received;
+                       array_remove_at(entry->cached, enumerator);
+                       free(sa);
+               }
+       }
+       enumerator->destroy(enumerator);
+       array_destroy_function(sas, (void*)free, NULL);
+}
+
+/**
  * Send a RADIUS message, wait for response
  */
 static bool send_message(private_eap_radius_accounting_t *this,
  * Send a RADIUS message, wait for response
  */
 static bool send_message(private_eap_radius_accounting_t *this,
@@ -210,7 +299,7 @@ static void add_ike_sa_parameters(private_eap_radius_accounting_t *this,
 {
        enumerator_t *enumerator;
        host_t *vip, *host;
 {
        enumerator_t *enumerator;
        host_t *vip, *host;
-       char buf[128];
+       char buf[MAX_RADIUS_ATTRIBUTE_SIZE + 1];
        chunk_t data;
        u_int32_t value;
 
        chunk_t data;
        u_int32_t value;
 
@@ -338,19 +427,32 @@ static job_requeue_t send_interim(interim_data_t *data)
        ike_sa_t *ike_sa;
        entry_t *entry;
        u_int32_t value;
        ike_sa_t *ike_sa;
        entry_t *entry;
        u_int32_t value;
+       array_t *stats;
+       sa_entry_t *sa, *found;
 
        ike_sa = charon->ike_sa_manager->checkout(charon->ike_sa_manager, data->id);
        if (!ike_sa)
        {
                return JOB_REQUEUE_NONE;
        }
 
        ike_sa = charon->ike_sa_manager->checkout(charon->ike_sa_manager, data->id);
        if (!ike_sa)
        {
                return JOB_REQUEUE_NONE;
        }
+       stats = array_create(0, 0);
        enumerator = ike_sa->create_child_sa_enumerator(ike_sa);
        while (enumerator->enumerate(enumerator, &child_sa))
        {
        enumerator = ike_sa->create_child_sa_enumerator(ike_sa);
        while (enumerator->enumerate(enumerator, &child_sa))
        {
+               INIT(sa,
+                       .id = child_sa->get_unique_id(child_sa),
+               );
+               array_insert(stats, ARRAY_TAIL, sa);
+               array_sort(stats, sa_sort, NULL);
+
                child_sa->get_usestats(child_sa, FALSE, NULL, &bytes, &packets);
                child_sa->get_usestats(child_sa, FALSE, NULL, &bytes, &packets);
+               sa->bytes.sent = bytes;
+               sa->packets.sent = packets;
                bytes_out += bytes;
                packets_out += packets;
                child_sa->get_usestats(child_sa, TRUE, NULL, &bytes, &packets);
                bytes_out += bytes;
                packets_out += packets;
                child_sa->get_usestats(child_sa, TRUE, NULL, &bytes, &packets);
+               sa->bytes.received = bytes;
+               sa->packets.received = packets;
                bytes_in += bytes;
                packets_in += packets;
        }
                bytes_in += bytes;
                packets_in += packets;
        }
@@ -365,6 +467,30 @@ static job_requeue_t send_interim(interim_data_t *data)
        {
                entry->interim.last = time_monotonic(NULL);
 
        {
                entry->interim.last = time_monotonic(NULL);
 
+               enumerator = array_create_enumerator(entry->cached);
+               while (enumerator->enumerate(enumerator, &sa))
+               {
+                       if (array_bsearch(stats, sa, sa_find, &found) != -1)
+                       {
+                               /* SA is still around, update stats (e.g. for IKEv1 where
+                                * SA might get used even after rekeying) */
+                               sa->bytes = found->bytes;
+                               sa->packets = found->packets;
+                       }
+                       else
+                       {
+                               /* SA is gone, add its latest stats to the total for this IKE_SA
+                                * and remove the cache entry */
+                               entry->bytes.sent += sa->bytes.sent;
+                               entry->bytes.received += sa->bytes.received;
+                               entry->packets.sent += sa->packets.sent;
+                               entry->packets.received += sa->packets.received;
+                               array_remove_at(entry->cached, enumerator);
+                               free(sa);
+                       }
+               }
+               enumerator->destroy(enumerator);
+
                bytes_in += entry->bytes.received;
                bytes_out += entry->bytes.sent;
                packets_in += entry->packets.received;
                bytes_in += entry->bytes.received;
                bytes_out += entry->bytes.sent;
                packets_in += entry->packets.received;
@@ -405,12 +531,18 @@ static job_requeue_t send_interim(interim_data_t *data)
                schedule_interim(this, entry);
        }
        this->mutex->unlock(this->mutex);
                schedule_interim(this, entry);
        }
        this->mutex->unlock(this->mutex);
+       array_destroy_function(stats, (void*)free, NULL);
 
        if (message)
        {
                if (!send_message(this, message))
                {
 
        if (message)
        {
                if (!send_message(this, message))
                {
-                       eap_radius_handle_timeout(data->id);
+                       if (lib->settings->get_bool(lib->settings,
+                                                       "%s.plugins.eap-radius.accounting_close_on_timeout",
+                                                       TRUE, lib->ns))
+                       {
+                               eap_radius_handle_timeout(data->id);
+                       }
                }
                message->destroy(message);
        }
                }
                message->destroy(message);
        }
@@ -483,6 +615,16 @@ static void send_start(private_eap_radius_accounting_t *this, ike_sa_t *ike_sa)
        message->add(message, RAT_ACCT_SESSION_ID,
                                 chunk_create(entry->sid, strlen(entry->sid)));
 
        message->add(message, RAT_ACCT_SESSION_ID,
                                 chunk_create(entry->sid, strlen(entry->sid)));
 
+       if (!entry->interim.interval)
+       {
+               entry->interim.interval = lib->settings->get_time(lib->settings,
+                                       "%s.plugins.eap-radius.accounting_interval", 0, lib->ns);
+               if (entry->interim.interval)
+               {
+                       DBG1(DBG_CFG, "scheduling RADIUS Interim-Updates every %us",
+                                entry->interim.interval);
+               }
+       }
        schedule_interim(this, entry);
        this->mutex->unlock(this->mutex);
 
        schedule_interim(this, entry);
        this->mutex->unlock(this->mutex);
 
@@ -500,7 +642,9 @@ static void send_start(private_eap_radius_accounting_t *this, ike_sa_t *ike_sa)
 static void send_stop(private_eap_radius_accounting_t *this, ike_sa_t *ike_sa)
 {
        radius_message_t *message;
 static void send_stop(private_eap_radius_accounting_t *this, ike_sa_t *ike_sa)
 {
        radius_message_t *message;
+       enumerator_t *enumerator;
        entry_t *entry;
        entry_t *entry;
+       sa_entry_t *sa;
        u_int32_t value;
 
        this->mutex->lock(this->mutex);
        u_int32_t value;
 
        this->mutex->lock(this->mutex);
@@ -513,6 +657,16 @@ static void send_stop(private_eap_radius_accounting_t *this, ike_sa_t *ike_sa)
                        destroy_entry(entry);
                        return;
                }
                        destroy_entry(entry);
                        return;
                }
+               enumerator = array_create_enumerator(entry->cached);
+               while (enumerator->enumerate(enumerator, &sa))
+               {
+                       entry->bytes.sent += sa->bytes.sent;
+                       entry->bytes.received += sa->bytes.received;
+                       entry->packets.sent += sa->packets.sent;
+                       entry->packets.received += sa->packets.received;
+               }
+               enumerator->destroy(enumerator);
+
                message = radius_message_create(RMC_ACCOUNTING_REQUEST);
                value = htonl(ACCT_STATUS_STOP);
                message->add(message, RAT_ACCT_STATUS_TYPE, chunk_from_thing(value));
                message = radius_message_create(RMC_ACCOUNTING_REQUEST);
                value = htonl(ACCT_STATUS_STOP);
                message->add(message, RAT_ACCT_STATUS_TYPE, chunk_from_thing(value));
@@ -645,6 +799,8 @@ METHOD(listener_t, ike_rekey, bool,
                /* fire new interim update job, old gets invalid */
                schedule_interim(this, entry);
 
                /* fire new interim update job, old gets invalid */
                schedule_interim(this, entry);
 
+               cleanup_sas(this, new, entry);
+
                entry = this->sessions->put(this->sessions, entry->id, entry);
                if (entry)
                {
                entry = this->sessions->put(this->sessions, entry->id, entry);
                if (entry)
                {
@@ -660,8 +816,16 @@ METHOD(listener_t, child_rekey, bool,
        private_eap_radius_accounting_t *this, ike_sa_t *ike_sa,
        child_sa_t *old, child_sa_t *new)
 {
        private_eap_radius_accounting_t *this, ike_sa_t *ike_sa,
        child_sa_t *old, child_sa_t *new)
 {
-       update_usage(this, ike_sa, old);
+       entry_t *entry;
 
 
+       update_usage(this, ike_sa, old);
+       this->mutex->lock(this->mutex);
+       entry = this->sessions->get(this->sessions, ike_sa->get_id(ike_sa));
+       if (entry)
+       {
+               cleanup_sas(this, ike_sa, entry);
+       }
+       this->mutex->unlock(this->mutex);
        return TRUE;
 }
 
        return TRUE;
 }