Don't parse comma separated pool names in attr-sql
[strongswan.git] / src / libhydra / plugins / attr_sql / sql_attribute.c
index 9a2108e..8055be7 100644 (file)
@@ -38,7 +38,7 @@ struct private_sql_attribute_t {
        database_t *db;
 
        /**
-        * wheter to record lease history in lease table
+        * whether to record lease history in lease table
         */
        bool history;
 };
@@ -74,6 +74,26 @@ static u_int get_identity(private_sql_attribute_t *this, identification_t *id)
 }
 
 /**
+ * Lookup an attribute pool by name
+ */
+static u_int get_attr_pool(private_sql_attribute_t *this, char *name)
+{
+       enumerator_t *e;
+       u_int row = 0;
+
+       e = this->db->query(this->db,
+                                               "SELECT id FROM attribute_pools WHERE name = ?",
+                                               DB_TEXT, name, DB_UINT);
+       if (e)
+       {
+               e->enumerate(e, &row);
+       }
+       DESTROY_IF(e);
+
+       return row;
+}
+
+/**
  * Lookup pool by name
  */
 static u_int get_pool(private_sql_attribute_t *this, char *name, u_int *timeout)
@@ -127,8 +147,8 @@ static host_t* check_lease(private_sql_attribute_t *this, char *name,
                        host = host_create_from_chunk(AF_UNSPEC, address, 0);
                        if (host)
                        {
-                               DBG1("acquired existing lease for address %H in pool '%s'",
-                                        host, name);
+                               DBG1(DBG_CFG, "acquired existing lease for address %H in"
+                                        " pool '%s'", host, name);
                                return host;
                        }
                }
@@ -202,22 +222,19 @@ static host_t* get_lease(private_sql_attribute_t *this, char *name,
                        host = host_create_from_chunk(AF_UNSPEC, address, 0);
                        if (host)
                        {
-                               DBG1("acquired new lease for address %H in pool '%s'",
+                               DBG1(DBG_CFG, "acquired new lease for address %H in pool '%s'",
                                         host, name);
                                return host;
                        }
                }
        }
-       DBG1("no available address found in pool '%s'", name);
+       DBG1(DBG_CFG, "no available address found in pool '%s'", name);
        return NULL;
 }
 
-/**
- * Implementation of attribute_provider_t.acquire_address
- */
-static host_t* acquire_address(private_sql_attribute_t *this,
-                                                          char *names, identification_t *id,
-                                                          host_t *requested)
+METHOD(attribute_provider_t, acquire_address, host_t*,
+       private_sql_attribute_t *this, char *name, identification_t *id,
+       host_t *requested)
 {
        host_t *address = NULL;
        u_int identity, pool, timeout;
@@ -225,128 +242,152 @@ static host_t* acquire_address(private_sql_attribute_t *this,
        identity = get_identity(this, id);
        if (identity)
        {
-               /* check for a single pool first (no concatenation and enumeration) */
-               if (strchr(names, ',') == NULL)
+               pool = get_pool(this, name, &timeout);
+               if (pool)
                {
-                       pool = get_pool(this, names, &timeout);
-                       if (pool)
+                       /* check for an existing lease */
+                       address = check_lease(this, name, pool, identity);
+                       if (address == NULL)
                        {
-                               /* check for an existing lease */
-                               address = check_lease(this, names, pool, identity);
-                               if (address == NULL)
-                               {
-                                       /* get an unallocated address or expired lease */
-                                       address = get_lease(this, names, pool, timeout, identity);
-                               }
+                               /* get an unallocated address or expired lease */
+                               address = get_lease(this, name, pool, timeout, identity);
                        }
                }
-               else
-               {
-                       enumerator_t *enumerator;
-                       char *name;
+       }
+       return address;
+}
 
-                       /* in a first step check for an existing lease over all pools */
-                       enumerator = enumerator_create_token(names, ",", " ");
-                       while (enumerator->enumerate(enumerator, &name))
-                       {
-                               pool = get_pool(this, name, &timeout);
-                               if (pool)
-                               {
-                                       address = check_lease(this, name, pool, identity);
-                                       if (address)
-                                       {
-                                               enumerator->destroy(enumerator);
-                                               return address;
-                                       }
-                               }
-                       }
-                       enumerator->destroy(enumerator);
+METHOD(attribute_provider_t, release_address, bool,
+       private_sql_attribute_t *this, char *name, host_t *address,
+       identification_t *id)
+{
+       u_int pool, timeout;
+       time_t now = time(NULL);
 
-                       /* in a second step get an unallocated address or expired lease */
-                       enumerator = enumerator_create_token(names, ",", " ");
-                       while (enumerator->enumerate(enumerator, &name))
-                       {
-                               pool = get_pool(this, name, &timeout);
-                               if (pool)
-                               {
-                                       address = get_lease(this, name, pool, timeout, identity);
-                                       if (address)
-                                       {
-                                               break;
-                                       }
-                               }
-                       }
-                       enumerator->destroy(enumerator);
+       pool = get_pool(this, name, &timeout);
+       if (pool)
+       {
+               if (this->history)
+               {
+                       this->db->execute(this->db, NULL,
+                               "INSERT INTO leases (address, identity, acquired, released)"
+                               " SELECT id, identity, acquired, ? FROM addresses "
+                               " WHERE pool = ? AND address = ?",
+                               DB_UINT, now, DB_UINT, pool,
+                               DB_BLOB, address->get_address(address));
+               }
+               if (this->db->execute(this->db, NULL,
+                               "UPDATE addresses SET released = ? WHERE "
+                               "pool = ? AND address = ?", DB_UINT, time(NULL),
+                               DB_UINT, pool, DB_BLOB, address->get_address(address)) > 0)
+               {
+                       return TRUE;
                }
        }
-       return address;
+       return FALSE;
 }
 
-/**
- * Implementation of attribute_provider_t.release_address
- */
-static bool release_address(private_sql_attribute_t *this,
-                                                       char *name, host_t *address, identification_t *id)
+METHOD(attribute_provider_t, create_attribute_enumerator, enumerator_t*,
+       private_sql_attribute_t *this, linked_list_t *pools, identification_t *id,
+       linked_list_t *vips)
 {
-       enumerator_t *enumerator;
-       bool found = FALSE;
-       time_t now = time(NULL);
+       enumerator_t *attr_enumerator = NULL;
 
-       enumerator = enumerator_create_token(name, ",", " ");
-       while (enumerator->enumerate(enumerator, &name))
+       if (vips->get_count(vips))
        {
-               u_int pool, timeout;
+               enumerator_t *pool_enumerator;
+               u_int count;
+               char *name;
 
-               pool = get_pool(this, name, &timeout);
-               if (pool)
+               this->db->execute(this->db, NULL, "BEGIN EXCLUSIVE TRANSACTION");
+
+               /* in a first step check for attributes that match name and id */
+               if (id)
                {
-                       if (this->history)
+                       u_int identity = get_identity(this, id);
+
+                       pool_enumerator = pools->create_enumerator(pools);
+                       while (pool_enumerator->enumerate(pool_enumerator, &name))
                        {
-                               this->db->execute(this->db, NULL,
-                                       "INSERT INTO leases (address, identity, acquired, released)"
-                                       " SELECT id, identity, acquired, ? FROM addresses "
-                                       " WHERE pool = ? AND address = ?",
-                                       DB_UINT, now, DB_UINT, pool,
-                                       DB_BLOB, address->get_address(address));
+                               u_int attr_pool = get_attr_pool(this, name);
+                               if (!attr_pool)
+                               {
+                                       continue;
+                               }
+
+                               attr_enumerator = this->db->query(this->db,
+                                                               "SELECT count(*) FROM attributes "
+                                                               "WHERE pool = ? AND identity = ?",
+                                                               DB_UINT, attr_pool, DB_UINT, identity, DB_UINT);
+
+                               if (attr_enumerator &&
+                                       attr_enumerator->enumerate(attr_enumerator, &count) &&
+                                       count != 0)
+                               {
+                                       attr_enumerator->destroy(attr_enumerator);
+                                       attr_enumerator = this->db->query(this->db,
+                                                               "SELECT type, value FROM attributes "
+                                                               "WHERE pool = ? AND identity = ?", DB_UINT,
+                                                               attr_pool, DB_UINT, identity, DB_INT, DB_BLOB);
+                                       break;
+                               }
+                               DESTROY_IF(attr_enumerator);
+                               attr_enumerator = NULL;
                        }
-                       if (this->db->execute(this->db, NULL,
-                                       "UPDATE addresses SET released = ? WHERE "
-                                       "pool = ? AND address = ?", DB_UINT, time(NULL),
-                                       DB_UINT, pool, DB_BLOB, address->get_address(address)) > 0)
+                       pool_enumerator->destroy(pool_enumerator);
+               }
+
+               /* in a second step check for attributes that match name */
+               if (!attr_enumerator)
+               {
+                       pool_enumerator = pools->create_enumerator(pools);
+                       while (pool_enumerator->enumerate(pool_enumerator, &name))
                        {
-                               found = TRUE;
-                               break;
+                               u_int attr_pool = get_attr_pool(this, name);
+                               if (!attr_pool)
+                               {
+                                       continue;
+                               }
+
+                               attr_enumerator = this->db->query(this->db,
+                                                                       "SELECT count(*) FROM attributes "
+                                                                       "WHERE pool = ? AND identity = 0",
+                                                                       DB_UINT, attr_pool, DB_UINT);
+
+                               if (attr_enumerator &&
+                                       attr_enumerator->enumerate(attr_enumerator, &count) &&
+                                       count != 0)
+                               {
+                                       attr_enumerator->destroy(attr_enumerator);
+                                       attr_enumerator = this->db->query(this->db,
+                                                                       "SELECT type, value FROM attributes "
+                                                                       "WHERE pool = ? AND identity = 0",
+                                                                       DB_UINT, attr_pool, DB_INT, DB_BLOB);
+                                       break;
+                               }
+                               DESTROY_IF(attr_enumerator);
+                               attr_enumerator = NULL;
                        }
+                       pool_enumerator->destroy(pool_enumerator);
                }
-       }
-       enumerator->destroy(enumerator);
-       return found;
-}
 
-/**
- * Implementation of sql_attribute_t.create_attribute_enumerator
- */
-static enumerator_t* create_attribute_enumerator(private_sql_attribute_t *this,
-                                                                                       identification_t *id, host_t *vip)
-{
-       if (vip)
-       {
-               enumerator_t *enumerator;
+               this->db->execute(this->db, NULL, "END TRANSACTION");
 
-               enumerator = this->db->query(this->db,
-                                               "SELECT type, value FROM attributes", DB_INT, DB_BLOB);
-               if (enumerator)
+               /* lastly try to find global attributes */
+               if (!attr_enumerator)
                {
-                       return enumerator;
+                       attr_enumerator = this->db->query(this->db,
+                                                                       "SELECT type, value FROM attributes "
+                                                                       "WHERE pool = 0 AND identity = 0",
+                                                                       DB_INT, DB_BLOB);
                }
        }
-       return enumerator_create_empty();
+
+       return (attr_enumerator ? attr_enumerator : enumerator_create_empty());
 }
 
-/**
- * Implementation of sql_attribute_t.destroy
- */
-static void destroy(private_sql_attribute_t *this)
+METHOD(sql_attribute_t, destroy, void,
+       private_sql_attribute_t *this)
 {
        free(this);
 }
@@ -356,17 +397,22 @@ static void destroy(private_sql_attribute_t *this)
  */
 sql_attribute_t *sql_attribute_create(database_t *db)
 {
-       private_sql_attribute_t *this = malloc_thing(private_sql_attribute_t);
+       private_sql_attribute_t *this;
        time_t now = time(NULL);
 
-       this->public.provider.acquire_address = (host_t*(*)(attribute_provider_t *this, char*, identification_t *, host_t *))acquire_address;
-       this->public.provider.release_address = (bool(*)(attribute_provider_t *this, char*,host_t *, identification_t*))release_address;
-       this->public.provider.create_attribute_enumerator = (enumerator_t*(*)(attribute_provider_t*, identification_t *id, host_t *host))create_attribute_enumerator;
-       this->public.destroy = (void(*)(sql_attribute_t*))destroy;
-
-       this->db = db;
-       this->history = lib->settings->get_bool(lib->settings,
-                                               "libhydra.plugins.attr-sql.lease_history", TRUE);
+       INIT(this,
+               .public = {
+                       .provider = {
+                               .acquire_address = _acquire_address,
+                               .release_address = _release_address,
+                               .create_attribute_enumerator = _create_attribute_enumerator,
+                       },
+                       .destroy = _destroy,
+               },
+               .db = db,
+               .history = lib->settings->get_bool(lib->settings,
+                                                       "libhydra.plugins.attr-sql.lease_history", TRUE),
+       );
 
        /* close any "online" leases in the case we crashed */
        if (this->history)