Ignore SQL-based IP address pools if their address family does not match
authorTobias Brunner <tobias@strongswan.org>
Mon, 18 Mar 2013 17:45:29 +0000 (18:45 +0100)
committerTobias Brunner <tobias@strongswan.org>
Tue, 19 Mar 2013 15:33:07 +0000 (16:33 +0100)
src/libhydra/plugins/attr_sql/sql_attribute.c

index 1a4ee7a..e91e1ed 100644 (file)
@@ -94,19 +94,26 @@ static u_int get_attr_pool(private_sql_attribute_t *this, char *name)
 }
 
 /**
- * Lookup pool by name
+ * Lookup pool by name and address family
  */
-static u_int get_pool(private_sql_attribute_t *this, char *name, u_int *timeout)
+static u_int get_pool(private_sql_attribute_t *this, char *name, int family,
+                                         u_int *timeout)
 {
        enumerator_t *e;
+       chunk_t start;
        u_int pool;
 
-       e = this->db->query(this->db, "SELECT id, timeout FROM pools WHERE name = ?",
-                                               DB_TEXT, name, DB_UINT, DB_UINT);
-       if (e && e->enumerate(e, &pool, timeout))
+       e = this->db->query(this->db,
+                                               "SELECT id, start, timeout FROM pools WHERE name = ?",
+                                               DB_TEXT, name, DB_UINT, DB_BLOB, DB_UINT);
+       if (e && e->enumerate(e, &pool, &start, timeout))
        {
-               e->destroy(e);
-               return pool;
+               if ((family == AF_INET  && start.len == 4) ||
+                       (family == AF_INET6 && start.len == 16))
+               {
+                       e->destroy(e);
+                       return pool;
+               }
        }
        DESTROY_IF(e);
        return 0;
@@ -240,15 +247,17 @@ METHOD(attribute_provider_t, acquire_address, host_t*,
        host_t *address = NULL;
        u_int identity, pool, timeout;
        char *name;
+       int family;
 
        identity = get_identity(this, id);
        if (identity)
        {
+               family = requested->get_family(requested);
                /* check for an existing lease in all pools */
                enumerator = pools->create_enumerator(pools);
                while (enumerator->enumerate(enumerator, &name))
                {
-                       pool = get_pool(this, name, &timeout);
+                       pool = get_pool(this, name, family, &timeout);
                        if (pool)
                        {
                                address = check_lease(this, name, pool, identity);
@@ -266,7 +275,7 @@ METHOD(attribute_provider_t, acquire_address, host_t*,
                        enumerator = pools->create_enumerator(pools);
                        while (enumerator->enumerate(enumerator, &name))
                        {
-                               pool = get_pool(this, name, &timeout);
+                               pool = get_pool(this, name, family, &timeout);
                                if (pool)
                                {
                                        address = get_lease(this, name, pool, timeout, identity);
@@ -291,11 +300,13 @@ METHOD(attribute_provider_t, release_address, bool,
        time_t now = time(NULL);
        bool found = FALSE;
        char *name;
+       int family;
 
+       family = address->get_family(address);
        enumerator = pools->create_enumerator(pools);
        while (enumerator->enumerate(enumerator, &name))
        {
-               pool = get_pool(this, name, &timeout);
+               pool = get_pool(this, name, family, &timeout);
                if (!pool)
                {
                        continue;