Properly initialize cached address map in kernel-pfroute plugin
[strongswan.git] / src / libhydra / plugins / kernel_pfroute / kernel_pfroute_net.c
index ae13709..16a46bb 100644 (file)
@@ -28,6 +28,8 @@
 #include <utils/host.h>
 #include <threading/thread.h>
 #include <threading/mutex.h>
+#include <threading/rwlock.h>
+#include <utils/hashtable.h>
 #include <utils/linked_list.h>
 #include <processing/jobs/callback_job.h>
 
@@ -180,9 +182,9 @@ struct private_kernel_pfroute_net_t
        kernel_pfroute_net_t public;
 
        /**
-        * mutex to lock access to various lists
+        * lock to access lists and maps
         */
-       mutex_t *mutex;
+       rwlock_t *lock;
 
        /**
         * Cached list of interfaces and their addresses (iface_entry_t)
@@ -223,7 +225,7 @@ struct private_kernel_pfroute_net_t
 /**
  * Add an address map entry
  */
-static void addr_map_entry_add(private_kernel_netlink_net_t *this,
+static void addr_map_entry_add(private_kernel_pfroute_net_t *this,
                                                           addr_entry_t *addr, iface_entry_t *iface)
 {
        addr_map_entry_t *entry;
@@ -246,7 +248,7 @@ static void addr_map_entry_add(private_kernel_netlink_net_t *this,
  * it is also used with linked_list_t.invoke_function)
  */
 static void addr_map_entry_remove(addr_entry_t *addr, iface_entry_t *iface,
-                                                                 private_kernel_netlink_net_t *this)
+                                                                 private_kernel_pfroute_net_t *this)
 {
        addr_map_entry_t *entry, lookup = {
                .ip = addr->ip,
@@ -331,7 +333,7 @@ static void process_addr(private_kernel_pfroute_net_t *this,
                return;
        }
 
-       this->mutex->lock(this->mutex);
+       this->lock->write_lock(this->lock);
        ifaces = this->ifaces->create_enumerator(this->ifaces);
        while (ifaces->enumerate(ifaces, &iface))
        {
@@ -386,7 +388,7 @@ static void process_addr(private_kernel_pfroute_net_t *this,
                }
        }
        ifaces->destroy(ifaces);
-       this->mutex->unlock(this->mutex);
+       this->lock->unlock(this->lock);
        host->destroy(host);
 
        if (roam)
@@ -406,7 +408,7 @@ static void process_link(private_kernel_pfroute_net_t *this,
        iface_entry_t *iface;
        bool roam = FALSE;
 
-       this->mutex->lock(this->mutex);
+       this->lock->write_lock(this->lock);
        enumerator = this->ifaces->create_enumerator(this->ifaces);
        while (enumerator->enumerate(enumerator, &iface))
        {
@@ -430,7 +432,7 @@ static void process_link(private_kernel_pfroute_net_t *this,
                }
        }
        enumerator->destroy(enumerator);
-       this->mutex->unlock(this->mutex);
+       this->lock->unlock(this->lock);
 
        if (roam)
        {
@@ -509,12 +511,8 @@ static job_requeue_t receive_events(private_kernel_pfroute_net_t *this)
 /** enumerator over addresses */
 typedef struct {
        private_kernel_pfroute_net_t* this;
-       /** whether to enumerate down interfaces */
-       bool include_down_ifaces;
-       /** whether to enumerate virtual ip addresses */
-       bool include_virtual_ips;
-       /** whether to enumerate loopback interfaces */
-       bool include_loopback;
+       /** which addresses to enumerate */
+       kernel_address_type_t which;
 } address_enumerator_t;
 
 /**
@@ -522,7 +520,7 @@ typedef struct {
  */
 static void address_enumerator_destroy(address_enumerator_t *data)
 {
-       data->this->mutex->unlock(data->this->mutex);
+       data->this->lock->unlock(data->this->lock);
        free(data);
 }
 
@@ -533,7 +531,7 @@ static bool filter_addresses(address_enumerator_t *data,
                                                         addr_entry_t** in, host_t** out)
 {
        host_t *ip;
-       if (!data->include_virtual_ips && (*in)->virtual)
+       if (!(data->which & ADDR_TYPE_VIRTUAL) && (*in)->virtual)
        {   /* skip virtual interfaces added by us */
                return FALSE;
        }
@@ -566,16 +564,16 @@ static enumerator_t *create_iface_enumerator(iface_entry_t *iface,
 static bool filter_interfaces(address_enumerator_t *data, iface_entry_t** in,
                                                          iface_entry_t** out)
 {
-       if (!(*in)->usable)
+       if (!(data->which & ADDR_TYPE_IGNORED) && !(*in)->usable)
        {       /* skip interfaces excluded by config */
                return FALSE;
        }
-       if (!data->include_loopback && ((*in)->flags & IFF_LOOPBACK))
+       if (!(data->which & ADDR_TYPE_LOOPBACK) && ((*in)->flags & IFF_LOOPBACK))
        {       /* ignore loopback devices */
                return FALSE;
        }
-       if (!data->include_down_ifaces && !((*in)->flags & IFF_UP))
-       {   /* skip interfaces not up */
+       if (!(data->which & ADDR_TYPE_DOWN) && !((*in)->flags & IFF_UP))
+       {       /* skip interfaces not up */
                return FALSE;
        }
        *out = *in;
@@ -583,16 +581,13 @@ static bool filter_interfaces(address_enumerator_t *data, iface_entry_t** in,
 }
 
 METHOD(kernel_net_t, create_address_enumerator, enumerator_t*,
-       private_kernel_pfroute_net_t *this,
-       bool include_down_ifaces, bool include_virtual_ips, bool include_loopback)
+       private_kernel_pfroute_net_t *this, kernel_address_type_t which)
 {
        address_enumerator_t *data = malloc_thing(address_enumerator_t);
        data->this = this;
-       data->include_down_ifaces = include_down_ifaces;
-       data->include_virtual_ips = include_virtual_ips;
-       data->include_loopback = include_loopback;
+       data->which = which;
 
-       this->mutex->lock(this->mutex);
+       this->lock->read_lock(this->lock);
        return enumerator_create_nested(
                                enumerator_create_filter(
                                        this->ifaces->create_enumerator(this->ifaces),
@@ -612,7 +607,7 @@ METHOD(kernel_net_t, get_interface_name, bool,
        {
                return FALSE;
        }
-       this->mutex->lock(this->mutex);
+       this->lock->read_lock(this->lock);
        /* first try to find it on an up and usable interface */
        entry = this->addrs->get_match(this->addrs, &lookup,
                                                                  (void*)addr_map_entry_match_up_and_usable);
@@ -623,7 +618,7 @@ METHOD(kernel_net_t, get_interface_name, bool,
                        *name = strdup(entry->iface->ifname);
                        DBG2(DBG_KNL, "%H is on interface %s", ip, *name);
                }
-               this->mutex->unlock(this->mutex);
+               this->lock->unlock(this->lock);
                return TRUE;
        }
        /* maybe it is installed on an ignored interface */
@@ -633,7 +628,7 @@ METHOD(kernel_net_t, get_interface_name, bool,
        {       /* the address does not exist, is on a down interface */
                DBG2(DBG_KNL, "%H is not a local address or the interface is down", ip);
        }
-       this->mutex->unlock(this->mutex);
+       this->lock->unlock(this->lock);
        return FALSE;
 }
 
@@ -736,6 +731,7 @@ static status_t init_address_list(private_kernel_pfroute_net_t *this)
                                        addr->virtual = FALSE;
                                        addr->refcount = 1;
                                        iface->addrs->insert_last(iface->addrs, addr);
+                                       addr_map_entry_add(this, addr, iface);
                                }
                        }
                }
@@ -765,6 +761,7 @@ METHOD(kernel_net_t, destroy, void,
        private_kernel_pfroute_net_t *this)
 {
        enumerator_t *enumerator;
+       addr_entry_t *addr;
 
        if (this->socket > 0)
        {
@@ -782,7 +779,7 @@ METHOD(kernel_net_t, destroy, void,
        enumerator->destroy(enumerator);
        this->addrs->destroy(this->addrs);
        this->ifaces->destroy_function(this->ifaces, (void*)iface_entry_destroy);
-       this->mutex->destroy(this->mutex);
+       this->lock->destroy(this->lock);
        this->mutex_pfroute->destroy(this->mutex_pfroute);
        free(this);
 }
@@ -813,7 +810,7 @@ kernel_pfroute_net_t *kernel_pfroute_net_create()
                .addrs = hashtable_create(
                                                                (hashtable_hash_t)addr_map_entry_hash,
                                                                (hashtable_equals_t)addr_map_entry_equals, 16),
-               .mutex = mutex_create(MUTEX_TYPE_DEFAULT),
+               .lock = rwlock_create(RWLOCK_TYPE_DEFAULT),
                .mutex_pfroute = mutex_create(MUTEX_TYPE_DEFAULT),
        );