Slightly refactor traffic_selector_t.get_subset()
[strongswan.git] / src / libstrongswan / selectors / traffic_selector.c
index 8b862a8..964e5a5 100644 (file)
@@ -22,9 +22,9 @@
 
 #include "traffic_selector.h"
 
-#include <utils/linked_list.h>
+#include <collections/linked_list.h>
 #include <utils/identification.h>
-#include <debug.h>
+#include <utils/debug.h>
 
 #define NON_SUBNET_ADDRESS_RANGE       255
 
@@ -174,13 +174,14 @@ static u_int8_t calc_netbits(private_traffic_selector_t *this)
 /**
  * internal generic constructor
  */
-static private_traffic_selector_t *traffic_selector_create(u_int8_t protocol, ts_type_t type, u_int16_t from_port, u_int16_t to_port);
+static private_traffic_selector_t *traffic_selector_create(u_int8_t protocol,
+                                               ts_type_t type, u_int16_t from_port, u_int16_t to_port);
 
 /**
  * Described in header.
  */
-int traffic_selector_printf_hook(char *dst, size_t len, printf_hook_spec_t *spec,
-                                                                const void *const *args)
+int traffic_selector_printf_hook(printf_hook_data_t *data,
+                                                       printf_hook_spec_t *spec, const void *const *args)
 {
        private_traffic_selector_t *this = *((private_traffic_selector_t**)(args[0]));
        linked_list_t *list = *((linked_list_t**)(args[0]));
@@ -195,7 +196,7 @@ int traffic_selector_printf_hook(char *dst, size_t len, printf_hook_spec_t *spec
 
        if (this == NULL)
        {
-               return print_in_hook(dst, len, "(null)");
+               return print_in_hook(data, "(null)");
        }
 
        if (spec->hash)
@@ -204,7 +205,7 @@ int traffic_selector_printf_hook(char *dst, size_t len, printf_hook_spec_t *spec
                while (enumerator->enumerate(enumerator, (void**)&this))
                {
                        /* call recursivly */
-                       written += print_in_hook(dst, len, "%R ", this);
+                       written += print_in_hook(data, "%R ", this);
                }
                enumerator->destroy(enumerator);
                return written;
@@ -216,7 +217,7 @@ int traffic_selector_printf_hook(char *dst, size_t len, printf_hook_spec_t *spec
                memeq(this->from, from, this->type == TS_IPV4_ADDR_RANGE ? 4 : 16) &&
                memeq(this->to, to, this->type == TS_IPV4_ADDR_RANGE ? 4 : 16))
        {
-               written += print_in_hook(dst, len, "dynamic");
+               written += print_in_hook(data, "dynamic");
        }
        else
        {
@@ -238,11 +239,11 @@ int traffic_selector_printf_hook(char *dst, size_t len, printf_hook_spec_t *spec
                        {
                                inet_ntop(AF_INET6, &this->to6, to_str, sizeof(to_str));
                        }
-                       written += print_in_hook(dst, len, "%s..%s", from_str, to_str);
+                       written += print_in_hook(data, "%s..%s", from_str, to_str);
                }
                else
                {
-                       written += print_in_hook(dst, len, "%s/%d", from_str, this->netbits);
+                       written += print_in_hook(data, "%s/%d", from_str, this->netbits);
                }
        }
 
@@ -255,7 +256,7 @@ int traffic_selector_printf_hook(char *dst, size_t len, printf_hook_spec_t *spec
                return written;
        }
 
-       written += print_in_hook(dst, len, "[");
+       written += print_in_hook(data, "[");
 
        /* build protocol string */
        if (has_proto)
@@ -264,18 +265,18 @@ int traffic_selector_printf_hook(char *dst, size_t len, printf_hook_spec_t *spec
 
                if (proto)
                {
-                       written += print_in_hook(dst, len, "%s", proto->p_name);
+                       written += print_in_hook(data, "%s", proto->p_name);
                        serv_proto = proto->p_name;
                }
                else
                {
-                       written += print_in_hook(dst, len, "%d", this->protocol);
+                       written += print_in_hook(data, "%d", this->protocol);
                }
        }
 
        if (has_proto && has_ports)
        {
-               written += print_in_hook(dst, len, "/");
+               written += print_in_hook(data, "/");
        }
 
        /* build port string */
@@ -283,104 +284,116 @@ int traffic_selector_printf_hook(char *dst, size_t len, printf_hook_spec_t *spec
        {
                if (this->from_port == this->to_port)
                {
-                       struct servent *serv = getservbyport(htons(this->from_port), serv_proto);
+                       struct servent *serv;
 
+                       serv = getservbyport(htons(this->from_port), serv_proto);
                        if (serv)
                        {
-                               written += print_in_hook(dst, len, "%s", serv->s_name);
+                               written += print_in_hook(data, "%s", serv->s_name);
                        }
                        else
                        {
-                               written += print_in_hook(dst, len, "%d", this->from_port);
+                               written += print_in_hook(data, "%d", this->from_port);
                        }
                }
                else
                {
-                       written += print_in_hook(dst, len, "%d-%d", this->from_port, this->to_port);
+                       written += print_in_hook(data, "%d-%d",
+                                                                        this->from_port, this->to_port);
                }
        }
 
-       written += print_in_hook(dst, len, "]");
+       written += print_in_hook(data, "]");
 
        return written;
 }
 
-/**
- * Implements traffic_selector_t.get_subset
- */
-static traffic_selector_t *get_subset(private_traffic_selector_t *this, private_traffic_selector_t *other)
+METHOD(traffic_selector_t, get_subset, traffic_selector_t*,
+       private_traffic_selector_t *this, traffic_selector_t *other_public)
 {
-       if (this->type == other->type && (this->protocol == other->protocol ||
-                                                               this->protocol == 0 || other->protocol == 0))
-       {
-               u_int16_t from_port, to_port;
-               u_char *from, *to;
-               u_int8_t protocol;
-               size_t size;
-               private_traffic_selector_t *new_ts;
-
-               /* calculate the maximum port range allowed for both */
-               from_port = max(this->from_port, other->from_port);
-               to_port = min(this->to_port, other->to_port);
-               if (from_port > to_port)
-               {
-                       return NULL;
-               }
-               /* select protocol, which is not zero */
-               protocol = max(this->protocol, other->protocol);
+       private_traffic_selector_t *other, *subset;
+       u_int16_t from_port, to_port;
+       u_char *from, *to;
+       u_int8_t protocol;
+       size_t size;
 
-               switch (this->type)
-               {
-                       case TS_IPV4_ADDR_RANGE:
-                               size = sizeof(this->from4);
-                               break;
-                       case TS_IPV6_ADDR_RANGE:
-                               size = sizeof(this->from6);
-                               break;
-                       default:
-                               return NULL;
-               }
+       other = (private_traffic_selector_t*)other_public;
 
-               /* get higher from-address */
-               if (memcmp(this->from, other->from, size) > 0)
-               {
-                       from = this->from;
-               }
-               else
-               {
-                       from = other->from;
-               }
-               /* get lower to-address */
-               if (memcmp(this->to, other->to, size) > 0)
-               {
-                       to = other->to;
-               }
-               else
-               {
-                       to = this->to;
-               }
-               /* if "from" > "to", we don't have a match */
-               if (memcmp(from, to, size) > 0)
-               {
+       if (this->dynamic || other->dynamic)
+       {       /* no set_address() applied, TS has no subset */
+               return NULL;
+       }
+
+       if (this->type != other->type)
+       {
+               return NULL;
+       }
+       switch (this->type)
+       {
+               case TS_IPV4_ADDR_RANGE:
+                       size = sizeof(this->from4);
+                       break;
+               case TS_IPV6_ADDR_RANGE:
+                       size = sizeof(this->from6);
+                       break;
+               default:
                        return NULL;
-               }
+       }
+
+       if (this->protocol != other->protocol &&
+               this->protocol != 0 && other->protocol != 0)
+       {
+               return NULL;
+       }
+       /* select protocol, which is not zero */
+       protocol = max(this->protocol, other->protocol);
 
-               /* we have a match in protocol, port, and address: return it... */
-               new_ts = traffic_selector_create(protocol, this->type, from_port, to_port);
-               new_ts->dynamic = this->dynamic || other->dynamic;
-               memcpy(new_ts->from, from, size);
-               memcpy(new_ts->to, to, size);
-               calc_netbits(new_ts);
-               return &new_ts->public;
+       /* calculate the maximum port range allowed for both */
+       from_port = max(this->from_port, other->from_port);
+       to_port = min(this->to_port, other->to_port);
+       if (from_port > to_port)
+       {
+               return NULL;
        }
-       return NULL;
+       /* get higher from-address */
+       if (memcmp(this->from, other->from, size) > 0)
+       {
+               from = this->from;
+       }
+       else
+       {
+               from = other->from;
+       }
+       /* get lower to-address */
+       if (memcmp(this->to, other->to, size) > 0)
+       {
+               to = other->to;
+       }
+       else
+       {
+               to = this->to;
+       }
+       /* if "from" > "to", we don't have a match */
+       if (memcmp(from, to, size) > 0)
+       {
+               return NULL;
+       }
+
+       /* we have a match in protocol, port, and address: return it... */
+       subset = traffic_selector_create(protocol, this->type, from_port, to_port);
+       memcpy(subset->from, from, size);
+       memcpy(subset->to, to, size);
+       calc_netbits(subset);
+
+       return &subset->public;
 }
 
-/**
- * Implements traffic_selector_t.equals
- */
-static bool equals(private_traffic_selector_t *this, private_traffic_selector_t *other)
+METHOD(traffic_selector_t, equals, bool,
+       private_traffic_selector_t *this, traffic_selector_t *other_public)
 {
+       private_traffic_selector_t *other;
+
+       other = (private_traffic_selector_t*)other_public;
        if (this->type != other->type)
        {
                return FALSE;
@@ -510,7 +523,7 @@ METHOD(traffic_selector_t, is_dynamic, bool,
 METHOD(traffic_selector_t, set_address, void,
        private_traffic_selector_t *this, host_t *host)
 {
-       if (this->dynamic)
+       if (is_host(this, NULL))
        {
                this->type = host->get_family(host) == AF_INET ?
                                TS_IPV4_ADDR_RANGE : TS_IPV6_ADDR_RANGE;
@@ -528,14 +541,12 @@ METHOD(traffic_selector_t, set_address, void,
                        memcpy(this->to, from.ptr, from.len);
                        this->netbits = from.len * 8;
                }
+               this->dynamic = FALSE;
        }
 }
 
-/**
- * Implements traffic_selector_t.is_contained_in.
- */
-static bool is_contained_in(private_traffic_selector_t *this,
-                                                       private_traffic_selector_t *other)
+METHOD(traffic_selector_t, is_contained_in, bool,
+       private_traffic_selector_t *this, traffic_selector_t *other)
 {
        private_traffic_selector_t *subset;
        bool contained_in = FALSE;
@@ -544,7 +555,7 @@ static bool is_contained_in(private_traffic_selector_t *this,
 
        if (subset)
        {
-               if (equals(subset, this))
+               if (equals(subset, &this->public))
                {
                        contained_in = TRUE;
                }
@@ -737,66 +748,36 @@ traffic_selector_t *traffic_selector_create_from_rfc3779_format(ts_type_t type,
 traffic_selector_t *traffic_selector_create_from_subnet(host_t *net,
                                                        u_int8_t netbits, u_int8_t protocol, u_int16_t port)
 {
-       private_traffic_selector_t *this = traffic_selector_create(protocol, 0, 0, 65535);
+       private_traffic_selector_t *this;
+       chunk_t from;
+
+       this = traffic_selector_create(protocol, 0, 0, 65535);
 
        switch (net->get_family(net))
        {
                case AF_INET:
-               {
-                       chunk_t from;
-
                        this->type = TS_IPV4_ADDR_RANGE;
-                       from = net->get_address(net);
-                       memcpy(this->from, from.ptr, from.len);
-                       if (this->from4[0] == 0)
-                       {
-                               /* use /0 for 0.0.0.0 */
-                               this->to4[0] = ~0;
-                               this->netbits = 0;
-                       }
-                       else
-                       {
-                               calc_range(this, netbits);
-                       }
                        break;
-               }
                case AF_INET6:
-               {
-                       chunk_t from;
-
                        this->type = TS_IPV6_ADDR_RANGE;
-                       from = net->get_address(net);
-                       memcpy(this->from, from.ptr, from.len);
-                       if (this->from6[0] == 0 && this->from6[1] == 0 &&
-                               this->from6[2] == 0 && this->from6[3] == 0)
-                       {
-                               /* use /0 for ::0 */
-                               this->to6[0] = ~0;
-                               this->to6[1] = ~0;
-                               this->to6[2] = ~0;
-                               this->to6[3] = ~0;
-                               this->netbits = 0;
-                       }
-                       else
-                       {
-                               calc_range(this, netbits);
-                       }
                        break;
-               }
                default:
-               {
                        net->destroy(net);
                        free(this);
                        return NULL;
-               }
        }
+       from = net->get_address(net);
+       memcpy(this->from, from.ptr, from.len);
+       netbits = min(netbits, this->type == TS_IPV4_ADDR_RANGE ? 32 : 128);
+       calc_range(this, netbits);
        if (port)
        {
                this->from_port = port;
                this->to_port = port;
        }
        net->destroy(net);
-       return (&this->public);
+
+       return &this->public;
 }
 
 /*
@@ -844,6 +825,23 @@ traffic_selector_t *traffic_selector_create_from_string(
 /*
  * see header
  */
+traffic_selector_t *traffic_selector_create_from_cidr(char *string,
+                                                                       u_int8_t protocol, u_int16_t port)
+{
+       host_t *net;
+       int bits;
+
+       net = host_create_from_subnet(string, &bits);
+       if (net)
+       {
+               return traffic_selector_create_from_subnet(net, bits, protocol, port);
+       }
+       return NULL;
+}
+
+/*
+ * see header
+ */
 traffic_selector_t *traffic_selector_create_dynamic(u_int8_t protocol,
                                                                        u_int16_t from_port, u_int16_t to_port)
 {
@@ -868,8 +866,8 @@ static private_traffic_selector_t *traffic_selector_create(u_int8_t protocol,
 
        INIT(this,
                .public = {
-                       .get_subset = (traffic_selector_t*(*)(traffic_selector_t*,traffic_selector_t*))get_subset,
-                       .equals = (bool(*)(traffic_selector_t*,traffic_selector_t*))equals,
+                       .get_subset = _get_subset,
+                       .equals = _equals,
                        .get_from_address = _get_from_address,
                        .get_to_address = _get_to_address,
                        .get_from_port = _get_from_port,
@@ -878,7 +876,7 @@ static private_traffic_selector_t *traffic_selector_create(u_int8_t protocol,
                        .get_protocol = _get_protocol,
                        .is_host = _is_host,
                        .is_dynamic = _is_dynamic,
-                       .is_contained_in = (bool(*)(traffic_selector_t*,traffic_selector_t*))is_contained_in,
+                       .is_contained_in = _is_contained_in,
                        .includes = _includes,
                        .set_address = _set_address,
                        .to_subnet = _to_subnet,
@@ -893,4 +891,3 @@ static private_traffic_selector_t *traffic_selector_create(u_int8_t protocol,
 
        return this;
 }
-