kernel-netlink: Consider RTA_SRC when looking for a source address
authorTobias Brunner <tobias@strongswan.org>
Tue, 23 Aug 2016 10:48:37 +0000 (12:48 +0200)
committerTobias Brunner <tobias@strongswan.org>
Wed, 5 Oct 2016 09:44:53 +0000 (11:44 +0200)
src/libcharon/plugins/kernel_netlink/kernel_netlink_net.c

index 93c2ccc..b9d3269 100644 (file)
@@ -702,6 +702,54 @@ static void addr_map_entry_remove(hashtable_t *map, addr_entry_t *addr,
 }
 
 /**
+ * Check if an address or net (addr with prefix net bits) is in
+ * subnet (net with net_len net bits)
+ */
+static bool addr_in_subnet(chunk_t addr, int prefix, chunk_t net, int net_len)
+{
+       static const u_char mask[] = { 0x00, 0x80, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc, 0xfe };
+       int byte = 0;
+
+       if (net_len == 0)
+       {       /* any address matches a /0 network */
+               return TRUE;
+       }
+       if (addr.len != net.len || net_len > 8 * net.len || prefix < net_len)
+       {
+               return FALSE;
+       }
+       /* scan through all bytes in network order */
+       while (net_len > 0)
+       {
+               if (net_len < 8)
+               {
+                       return (mask[net_len] & addr.ptr[byte]) == (mask[net_len] & net.ptr[byte]);
+               }
+               else
+               {
+                       if (addr.ptr[byte] != net.ptr[byte])
+                       {
+                               return FALSE;
+                       }
+                       byte++;
+                       net_len -= 8;
+               }
+       }
+       return TRUE;
+}
+
+/**
+ * Check if the given address is in subnet (net with net_len net bits)
+ */
+static bool host_in_subnet(host_t *host, chunk_t net, int net_len)
+{
+       chunk_t addr;
+
+       addr = host->get_address(host);
+       return addr_in_subnet(addr, addr.len * 8, net, net_len);
+}
+
+/**
  * Determine the type or scope of the given unicast IP address.  This is not
  * the same thing returned in rtm_scope/ifa_scope.
  *
@@ -837,7 +885,8 @@ static bool is_address_better(private_kernel_netlink_net_t *this,
 }
 
 /**
- * Get a non-virtual IP address on the given interface.
+ * Get a non-virtual IP address on the given interfaces and optionally in a
+ * given subnet.
  *
  * If a candidate address is given, we first search for that address and if not
  * found return the address as above.
@@ -845,19 +894,21 @@ static bool is_address_better(private_kernel_netlink_net_t *this,
  *
  * this->lock must be held when calling this function.
  */
-static host_t *get_interface_address(private_kernel_netlink_net_t *this,
-                                                                        int ifindex, int family, host_t *dest,
-                                                                        host_t *candidate)
+static host_t *get_matching_address(private_kernel_netlink_net_t *this,
+                                                                       int *ifindex, int family, chunk_t net,
+                                                                       uint8_t mask, host_t *dest,
+                                                                       host_t *candidate)
 {
+       enumerator_t *ifaces, *addrs;
        iface_entry_t *iface;
-       enumerator_t *addrs;
        addr_entry_t *addr, *best = NULL;
+       bool candidate_matched = FALSE;
 
-       if (this->ifaces->find_first(this->ifaces, (void*)iface_entry_by_index,
-                                                                (void**)&iface, &ifindex) == SUCCESS)
+       ifaces = this->ifaces->create_enumerator(this->ifaces);
+       while (ifaces->enumerate(ifaces, &iface))
        {
-               if (iface->usable)
-               {       /* only use interfaces not excluded by config */
+               if (iface->usable && (!ifindex || iface->ifindex == *ifindex))
+               {       /* only use matching interfaces not excluded by config */
                        addrs = iface->addrs->create_enumerator(iface->addrs);
                        while (addrs->enumerate(addrs, &addr))
                        {
@@ -866,9 +917,14 @@ static host_t *get_interface_address(private_kernel_netlink_net_t *this,
                                {       /* ignore virtual IP addresses and ensure family matches */
                                        continue;
                                }
+                               if (net.ptr && !host_in_subnet(addr->ip, net, mask))
+                               {       /* optionally match a subnet */
+                                       continue;
+                               }
                                if (candidate && candidate->ip_equals(candidate, addr->ip))
                                {       /* stop if we find the candidate */
                                        best = addr;
+                                       candidate_matched = TRUE;
                                        break;
                                }
                                else if (!best || is_address_better(this, best, addr, dest))
@@ -877,12 +933,50 @@ static host_t *get_interface_address(private_kernel_netlink_net_t *this,
                                }
                        }
                        addrs->destroy(addrs);
+                       if (ifindex || candidate_matched)
+                       {
+                               break;
+                       }
                }
        }
+       ifaces->destroy(ifaces);
        return best ? best->ip->clone(best->ip) : NULL;
 }
 
 /**
+ * Get a non-virtual IP address on the given interface.
+ *
+ * If a candidate address is given, we first search for that address and if not
+ * found return the address as above.
+ * Returned host is a clone, has to be freed by caller.
+ *
+ * this->lock must be held when calling this function.
+ */
+static host_t *get_interface_address(private_kernel_netlink_net_t *this,
+                                                                        int ifindex, int family, host_t *dest,
+                                                                        host_t *candidate)
+{
+       return get_matching_address(this, &ifindex, family, chunk_empty, 0, dest,
+                                                               candidate);
+}
+
+/**
+ * Get a non-virtual IP address in the given subnet.
+ *
+ * If a candidate address is given, we first search for that address and if not
+ * found return the address as above.
+ * Returned host is a clone, has to be freed by caller.
+ *
+ * this->lock must be held when calling this function.
+ */
+static host_t *get_subnet_address(private_kernel_netlink_net_t *this,
+                                                                 int family, chunk_t net, uint8_t mask,
+                                                                 host_t *dest, host_t *candidate)
+{
+       return get_matching_address(this, NULL, family, net, mask, dest, candidate);
+}
+
+/**
  * callback function that raises the delayed roam event
  */
 static job_requeue_t roam_event(private_kernel_netlink_net_t *this)
@@ -1528,51 +1622,16 @@ static char *get_interface_name_by_index(private_kernel_netlink_net_t *this,
 }
 
 /**
- * check if an address or net (addr with prefix net bits) is in
- * subnet (net with net_len net bits)
- */
-static bool addr_in_subnet(chunk_t addr, int prefix, chunk_t net, int net_len)
-{
-       static const u_char mask[] = { 0x00, 0x80, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc, 0xfe };
-       int byte = 0;
-
-       if (net_len == 0)
-       {       /* any address matches a /0 network */
-               return TRUE;
-       }
-       if (addr.len != net.len || net_len > 8 * net.len || prefix < net_len)
-       {
-               return FALSE;
-       }
-       /* scan through all bytes in network order */
-       while (net_len > 0)
-       {
-               if (net_len < 8)
-               {
-                       return (mask[net_len] & addr.ptr[byte]) == (mask[net_len] & net.ptr[byte]);
-               }
-               else
-               {
-                       if (addr.ptr[byte] != net.ptr[byte])
-                       {
-                               return FALSE;
-                       }
-                       byte++;
-                       net_len -= 8;
-               }
-       }
-       return TRUE;
-}
-
-/**
  * Store information about a route retrieved via RTNETLINK
  */
 typedef struct {
        chunk_t gtw;
-       chunk_t src;
+       chunk_t pref_src;
        chunk_t dst;
+       chunk_t src;
        host_t *src_host;
        uint8_t dst_len;
+       uint8_t src_len;
        uint32_t table;
        uint32_t oif;
        uint32_t priority;
@@ -1626,9 +1685,11 @@ static rt_entry_t *parse_route(struct nlmsghdr *hdr, rt_entry_t *route)
        if (route)
        {
                route->gtw = chunk_empty;
-               route->src = chunk_empty;
+               route->pref_src = chunk_empty;
                route->dst = chunk_empty;
                route->dst_len = msg->rtm_dst_len;
+               route->src = chunk_empty;
+               route->src_len = msg->rtm_src_len;
                route->table = msg->rtm_table;
                route->oif = 0;
                route->priority = 0;
@@ -1637,6 +1698,7 @@ static rt_entry_t *parse_route(struct nlmsghdr *hdr, rt_entry_t *route)
        {
                INIT(route,
                        .dst_len = msg->rtm_dst_len,
+                       .src_len = msg->rtm_src_len,
                        .table = msg->rtm_table,
                );
        }
@@ -1646,7 +1708,7 @@ static rt_entry_t *parse_route(struct nlmsghdr *hdr, rt_entry_t *route)
                switch (rta->rta_type)
                {
                        case RTA_PREFSRC:
-                               route->src = chunk_create(RTA_DATA(rta), RTA_PAYLOAD(rta));
+                               route->pref_src = chunk_create(RTA_DATA(rta), RTA_PAYLOAD(rta));
                                break;
                        case RTA_GATEWAY:
                                route->gtw = chunk_create(RTA_DATA(rta), RTA_PAYLOAD(rta));
@@ -1654,6 +1716,9 @@ static rt_entry_t *parse_route(struct nlmsghdr *hdr, rt_entry_t *route)
                        case RTA_DST:
                                route->dst = chunk_create(RTA_DATA(rta), RTA_PAYLOAD(rta));
                                break;
+                       case RTA_SRC:
+                               route->src = chunk_create(RTA_DATA(rta), RTA_PAYLOAD(rta));
+                               break;
                        case RTA_OIF:
                                if (RTA_PAYLOAD(rta) == sizeof(route->oif))
                                {
@@ -1790,10 +1855,10 @@ static host_t *get_route(private_kernel_netlink_net_t *this, host_t *dest,
                                {       /* route destination does not contain dest */
                                        continue;
                                }
-                               if (route->src.ptr)
+                               if (route->pref_src.ptr)
                                {       /* verify source address, if any */
                                        host_t *src = host_create_from_chunk(msg->rtm_family,
-                                                                                                                route->src, 0);
+                                                                                                                route->pref_src, 0);
                                        if (src && is_known_vip(this, src))
                                        {       /* ignore routes installed by us */
                                                src->destroy(src);
@@ -1863,12 +1928,29 @@ static host_t *get_route(private_kernel_netlink_net_t *this, host_t *dest,
                        best = best ?: route;
                        continue;
                }
+               if (route->src.ptr)
+               {       /* no src, but a source selector, try to find a matching address */
+                       route->src_host = get_subnet_address(this, msg->rtm_family,
+                                                                                       route->src, route->src_len, dest,
+                                                                                       candidate);
+                       if (route->src_host)
+                       {       /* we handle this address the same as the one above */
+                               if (!candidate ||
+                                        candidate->ip_equals(candidate, route->src_host))
+                               {
+                                       best = route;
+                                       break;
+                               }
+                               best = best ?: route;
+                               continue;
+                       }
+               }
                if (route->oif)
                {       /* no src, but an interface - get address from it */
                        route->src_host = get_interface_address(this, route->oif,
                                                                                        msg->rtm_family, dest, candidate);
                        if (route->src_host)
-                       {       /* we handle this address the same as the one above */
+                       {       /* more of the same */
                                if (!candidate ||
                                         candidate->ip_equals(candidate, route->src_host))
                                {