Properly initialize cached address map in kernel-pfroute plugin
[strongswan.git] / src / libhydra / plugins / kernel_pfroute / kernel_pfroute_net.c
index 0095c66..16a46bb 100644 (file)
@@ -28,6 +28,8 @@
 #include <utils/host.h>
 #include <threading/thread.h>
 #include <threading/mutex.h>
+#include <threading/rwlock.h>
+#include <utils/hashtable.h>
 #include <utils/linked_list.h>
 #include <processing/jobs/callback_job.h>
 
@@ -100,13 +102,72 @@ static void iface_entry_destroy(iface_entry_t *this)
 }
 
 /**
+ * check if an interface is up
+ */
+static inline bool iface_entry_up(iface_entry_t *iface)
+{
+       return (iface->flags & IFF_UP) == IFF_UP;
+}
+
+/**
  * check if an interface is up and usable
  */
 static inline bool iface_entry_up_and_usable(iface_entry_t *iface)
 {
-       return iface->usable && (iface->flags & IFF_UP) == IFF_UP;
+       return iface->usable && iface_entry_up(iface);
+}
+
+typedef struct addr_map_entry_t addr_map_entry_t;
+
+/**
+ * Entry that maps an IP address to an interface entry
+ */
+struct addr_map_entry_t {
+       /** The IP address */
+       host_t *ip;
+
+       /** The interface this address is installed on */
+       iface_entry_t *iface;
+};
+
+/**
+ * Hash a addr_map_entry_t object, all entries with the same IP address
+ * are stored in the same bucket
+ */
+static u_int addr_map_entry_hash(addr_map_entry_t *this)
+{
+       return chunk_hash(this->ip->get_address(this->ip));
+}
+
+/**
+ * Compare two addr_map_entry_t objects, two entries are equal if they are
+ * installed on the same interface
+ */
+static bool addr_map_entry_equals(addr_map_entry_t *a, addr_map_entry_t *b)
+{
+       return a->iface->ifindex == b->iface->ifindex &&
+                  a->ip->ip_equals(a->ip, b->ip);
 }
 
+/**
+ * Used with get_match this finds an address entry if it is installed on
+ * an up and usable interface
+ */
+static bool addr_map_entry_match_up_and_usable(addr_map_entry_t *a,
+                                                                                          addr_map_entry_t *b)
+{
+       return iface_entry_up_and_usable(b->iface) &&
+                  a->ip->ip_equals(a->ip, b->ip);
+}
+
+/**
+ * Used with get_match this finds an address entry if it is installed on
+ * any active local interface
+ */
+static bool addr_map_entry_match_up(addr_map_entry_t *a, addr_map_entry_t *b)
+{
+       return iface_entry_up(b->iface) && a->ip->ip_equals(a->ip, b->ip);
+}
 
 typedef struct private_kernel_pfroute_net_t private_kernel_pfroute_net_t;
 
@@ -121,9 +182,9 @@ struct private_kernel_pfroute_net_t
        kernel_pfroute_net_t public;
 
        /**
-        * mutex to lock access to various lists
+        * lock to access lists and maps
         */
-       mutex_t *mutex;
+       rwlock_t *lock;
 
        /**
         * Cached list of interfaces and their addresses (iface_entry_t)
@@ -131,6 +192,11 @@ struct private_kernel_pfroute_net_t
        linked_list_t *ifaces;
 
        /**
+        * Map for IP addresses to iface_entry_t objects (addr_map_entry_t)
+        */
+       hashtable_t *addrs;
+
+       /**
         * mutex to lock access to the PF_ROUTE socket
         */
        mutex_t *mutex_pfroute;
@@ -157,6 +223,48 @@ struct private_kernel_pfroute_net_t
 };
 
 /**
+ * Add an address map entry
+ */
+static void addr_map_entry_add(private_kernel_pfroute_net_t *this,
+                                                          addr_entry_t *addr, iface_entry_t *iface)
+{
+       addr_map_entry_t *entry;
+
+       if (addr->virtual)
+       {       /* don't map virtual IPs */
+               return;
+       }
+
+       INIT(entry,
+               .ip = addr->ip,
+               .iface = iface,
+       );
+       entry = this->addrs->put(this->addrs, entry, entry);
+       free(entry);
+}
+
+/**
+ * Remove an address map entry (the argument order is a bit strange because
+ * it is also used with linked_list_t.invoke_function)
+ */
+static void addr_map_entry_remove(addr_entry_t *addr, iface_entry_t *iface,
+                                                                 private_kernel_pfroute_net_t *this)
+{
+       addr_map_entry_t *entry, lookup = {
+               .ip = addr->ip,
+               .iface = iface,
+       };
+
+       if (addr->virtual)
+       {       /* these are never mapped, but this check avoid problems if a virtual IP
+                * equals a regular one */
+               return;
+       }
+       entry = this->addrs->remove(this->addrs, &lookup);
+       free(entry);
+}
+
+/**
  * callback function that raises the delayed roam event
  */
 static job_requeue_t roam_event(uintptr_t address)
@@ -225,7 +333,7 @@ static void process_addr(private_kernel_pfroute_net_t *this,
                return;
        }
 
-       this->mutex->lock(this->mutex);
+       this->lock->write_lock(this->lock);
        ifaces = this->ifaces->create_enumerator(this->ifaces);
        while (ifaces->enumerate(ifaces, &iface))
        {
@@ -246,6 +354,7 @@ static void process_addr(private_kernel_pfroute_net_t *this,
                                                        DBG1(DBG_KNL, "%H disappeared from %s",
                                                                 host, iface->ifname);
                                                }
+                                               addr_map_entry_remove(addr, iface, this);
                                                addr_entry_destroy(addr);
                                        }
                                        else if (ifa->ifam_type == RTM_NEWADDR && addr->virtual)
@@ -264,6 +373,7 @@ static void process_addr(private_kernel_pfroute_net_t *this,
                                addr->virtual = FALSE;
                                addr->refcount = 1;
                                iface->addrs->insert_last(iface->addrs, addr);
+                               addr_map_entry_add(this, addr, iface);
                                if (iface->usable)
                                {
                                        DBG1(DBG_KNL, "%H appeared on %s", host, iface->ifname);
@@ -278,7 +388,7 @@ static void process_addr(private_kernel_pfroute_net_t *this,
                }
        }
        ifaces->destroy(ifaces);
-       this->mutex->unlock(this->mutex);
+       this->lock->unlock(this->lock);
        host->destroy(host);
 
        if (roam)
@@ -298,7 +408,7 @@ static void process_link(private_kernel_pfroute_net_t *this,
        iface_entry_t *iface;
        bool roam = FALSE;
 
-       this->mutex->lock(this->mutex);
+       this->lock->write_lock(this->lock);
        enumerator = this->ifaces->create_enumerator(this->ifaces);
        while (enumerator->enumerate(enumerator, &iface))
        {
@@ -322,7 +432,7 @@ static void process_link(private_kernel_pfroute_net_t *this,
                }
        }
        enumerator->destroy(enumerator);
-       this->mutex->unlock(this->mutex);
+       this->lock->unlock(this->lock);
 
        if (roam)
        {
@@ -401,12 +511,8 @@ static job_requeue_t receive_events(private_kernel_pfroute_net_t *this)
 /** enumerator over addresses */
 typedef struct {
        private_kernel_pfroute_net_t* this;
-       /** whether to enumerate down interfaces */
-       bool include_down_ifaces;
-       /** whether to enumerate virtual ip addresses */
-       bool include_virtual_ips;
-       /** whether to enumerate loopback interfaces */
-       bool include_loopback;
+       /** which addresses to enumerate */
+       kernel_address_type_t which;
 } address_enumerator_t;
 
 /**
@@ -414,7 +520,7 @@ typedef struct {
  */
 static void address_enumerator_destroy(address_enumerator_t *data)
 {
-       data->this->mutex->unlock(data->this->mutex);
+       data->this->lock->unlock(data->this->lock);
        free(data);
 }
 
@@ -425,7 +531,7 @@ static bool filter_addresses(address_enumerator_t *data,
                                                         addr_entry_t** in, host_t** out)
 {
        host_t *ip;
-       if (!data->include_virtual_ips && (*in)->virtual)
+       if (!(data->which & ADDR_TYPE_VIRTUAL) && (*in)->virtual)
        {   /* skip virtual interfaces added by us */
                return FALSE;
        }
@@ -458,16 +564,16 @@ static enumerator_t *create_iface_enumerator(iface_entry_t *iface,
 static bool filter_interfaces(address_enumerator_t *data, iface_entry_t** in,
                                                          iface_entry_t** out)
 {
-       if (!(*in)->usable)
+       if (!(data->which & ADDR_TYPE_IGNORED) && !(*in)->usable)
        {       /* skip interfaces excluded by config */
                return FALSE;
        }
-       if (!data->include_loopback && ((*in)->flags & IFF_LOOPBACK))
+       if (!(data->which & ADDR_TYPE_LOOPBACK) && ((*in)->flags & IFF_LOOPBACK))
        {       /* ignore loopback devices */
                return FALSE;
        }
-       if (!data->include_down_ifaces && !((*in)->flags & IFF_UP))
-       {   /* skip interfaces not up */
+       if (!(data->which & ADDR_TYPE_DOWN) && !((*in)->flags & IFF_UP))
+       {       /* skip interfaces not up */
                return FALSE;
        }
        *out = *in;
@@ -475,16 +581,13 @@ static bool filter_interfaces(address_enumerator_t *data, iface_entry_t** in,
 }
 
 METHOD(kernel_net_t, create_address_enumerator, enumerator_t*,
-       private_kernel_pfroute_net_t *this,
-       bool include_down_ifaces, bool include_virtual_ips, bool include_loopback)
+       private_kernel_pfroute_net_t *this, kernel_address_type_t which)
 {
        address_enumerator_t *data = malloc_thing(address_enumerator_t);
        data->this = this;
-       data->include_down_ifaces = include_down_ifaces;
-       data->include_virtual_ips = include_virtual_ips;
-       data->include_loopback = include_loopback;
+       data->which = which;
 
-       this->mutex->lock(this->mutex);
+       this->lock->read_lock(this->lock);
        return enumerator_create_nested(
                                enumerator_create_filter(
                                        this->ifaces->create_enumerator(this->ifaces),
@@ -496,59 +599,37 @@ METHOD(kernel_net_t, create_address_enumerator, enumerator_t*,
 METHOD(kernel_net_t, get_interface_name, bool,
        private_kernel_pfroute_net_t *this, host_t* ip, char **name)
 {
-       enumerator_t *ifaces, *addrs;
-       iface_entry_t *iface;
-       addr_entry_t *addr;
-       bool found = FALSE, ignored = FALSE;
+       addr_map_entry_t *entry, lookup = {
+               .ip = ip,
+       };
 
        if (ip->is_anyaddr(ip))
        {
                return FALSE;
        }
-
-       this->mutex->lock(this->mutex);
-       ifaces = this->ifaces->create_enumerator(this->ifaces);
-       while (ifaces->enumerate(ifaces, &iface))
+       this->lock->read_lock(this->lock);
+       /* first try to find it on an up and usable interface */
+       entry = this->addrs->get_match(this->addrs, &lookup,
+                                                                 (void*)addr_map_entry_match_up_and_usable);
+       if (entry)
        {
-               addrs = iface->addrs->create_enumerator(iface->addrs);
-               while (addrs->enumerate(addrs, &addr))
-               {
-                       if (ip->ip_equals(ip, addr->ip))
-                       {
-                               found = TRUE;
-                               if (!iface->usable)
-                               {
-                                       ignored = TRUE;
-                                       break;
-                               }
-                               if (name)
-                               {
-                                       *name = strdup(iface->ifname);
-                               }
-                               break;
-                       }
-               }
-               addrs->destroy(addrs);
-               if (found)
-               {
-                       break;
-               }
-       }
-       ifaces->destroy(ifaces);
-       this->mutex->unlock(this->mutex);
-
-       if (!ignored)
-       {
-               if (!found)
-               {
-                       DBG2(DBG_KNL, "%H is not a local address", ip);
-               }
-               else if (name)
+               if (name)
                {
+                       *name = strdup(entry->iface->ifname);
                        DBG2(DBG_KNL, "%H is on interface %s", ip, *name);
                }
+               this->lock->unlock(this->lock);
+               return TRUE;
        }
-       return found && !ignored;
+       /* maybe it is installed on an ignored interface */
+       entry = this->addrs->get_match(this->addrs, &lookup,
+                                                                 (void*)addr_map_entry_match_up);
+       if (!entry)
+       {       /* the address does not exist, is on a down interface */
+               DBG2(DBG_KNL, "%H is not a local address or the interface is down", ip);
+       }
+       this->lock->unlock(this->lock);
+       return FALSE;
 }
 
 METHOD(kernel_net_t, get_source_addr, host_t*,
@@ -650,6 +731,7 @@ static status_t init_address_list(private_kernel_pfroute_net_t *this)
                                        addr->virtual = FALSE;
                                        addr->refcount = 1;
                                        iface->addrs->insert_last(iface->addrs, addr);
+                                       addr_map_entry_add(this, addr, iface);
                                }
                        }
                }
@@ -678,6 +760,9 @@ static status_t init_address_list(private_kernel_pfroute_net_t *this)
 METHOD(kernel_net_t, destroy, void,
        private_kernel_pfroute_net_t *this)
 {
+       enumerator_t *enumerator;
+       addr_entry_t *addr;
+
        if (this->socket > 0)
        {
                close(this->socket);
@@ -686,8 +771,15 @@ METHOD(kernel_net_t, destroy, void,
        {
                close(this->socket_events);
        }
+       enumerator = this->addrs->create_enumerator(this->addrs);
+       while (enumerator->enumerate(enumerator, NULL, (void**)&addr))
+       {
+               free(addr);
+       }
+       enumerator->destroy(enumerator);
+       this->addrs->destroy(this->addrs);
        this->ifaces->destroy_function(this->ifaces, (void*)iface_entry_destroy);
-       this->mutex->destroy(this->mutex);
+       this->lock->destroy(this->lock);
        this->mutex_pfroute->destroy(this->mutex_pfroute);
        free(this);
 }
@@ -715,7 +807,10 @@ kernel_pfroute_net_t *kernel_pfroute_net_create()
                        },
                },
                .ifaces = linked_list_create(),
-               .mutex = mutex_create(MUTEX_TYPE_DEFAULT),
+               .addrs = hashtable_create(
+                                                               (hashtable_hash_t)addr_map_entry_hash,
+                                                               (hashtable_equals_t)addr_map_entry_equals, 16),
+               .lock = rwlock_create(RWLOCK_TYPE_DEFAULT),
                .mutex_pfroute = mutex_create(MUTEX_TYPE_DEFAULT),
        );