implemented get_iface() using RTM_GETADDR
authorMartin Willi <martin@strongswan.org>
Fri, 2 Mar 2007 17:08:38 +0000 (17:08 -0000)
committerMartin Willi <martin@strongswan.org>
Fri, 2 Mar 2007 17:08:38 +0000 (17:08 -0000)
added support for multi-header netlink messages
really ugly now, need a lot of refactoring

src/charon/sa/ike_sa.c
src/charon/threads/kernel_interface.c

index aac786c..3bee151 100644 (file)
@@ -1620,11 +1620,11 @@ static void set_virtual_ip(private_ike_sa_t *this, bool local, host_t *ip)
                        DBG1(DBG_IKE, "removing old virtual IP %H", this->my_virtual_ip);
                        charon->kernel_interface->del_ip(charon->kernel_interface,
                                                                                         this->my_virtual_ip,
-                                                                                        this->other_host);
+                                                                                        this->my_host);
                        this->my_virtual_ip->destroy(this->my_virtual_ip);
                }
                if (charon->kernel_interface->add_ip(charon->kernel_interface, ip,
-                                                                                        this->other_host) == SUCCESS)
+                                                                                        this->my_host) == SUCCESS)
                {
                        this->my_virtual_ip = ip->clone(ip);
                }
@@ -1848,7 +1848,7 @@ static void destroy(private_ike_sa_t *this)
        if (this->my_virtual_ip)
        {
                charon->kernel_interface->del_ip(charon->kernel_interface,
-                                                                                this->my_virtual_ip, this->other_host);
+                                                                                this->my_virtual_ip, this->my_host);
                this->my_virtual_ip->destroy(this->my_virtual_ip);
        }
        DESTROY_IF(this->other_virtual_ip);
index 1da82c0..b7460ed 100644 (file)
@@ -308,30 +308,24 @@ static status_t send_message(private_kernel_interface_t *this,
        addr.nl_pid = 0;
        addr.nl_groups = 0;
 
-       /*
-       // set timeout to 10 secs
-       struct timespec tm;
-       tm.tv_sec = 10;
-        */
-       
        length = sendto(socket,(void *)request, request->nlmsg_len, 0, 
                                        (struct sockaddr *)&addr, sizeof(addr));
-       DBG2(DBG_IKE, "%d bytes sent to kernel", length);
        
        if (length < 0)
        {
-               DBG1(DBG_IKE,"0 byte could be sent");
                return FAILED;
        }
        else if (length != request->nlmsg_len)
        {
-               DBG1(DBG_IKE,"Request length %d does not match the sent bytes %d",
-                               request->nlmsg_len, length);
+               DBG1(DBG_KNL, "error sending to netlink socket: %m");
                return FAILED;
        }
        
        pthread_mutex_lock(&(this->rep_mutex));
        
+       DBG3(DBG_KNL, "waiting for netlink message with seq: %d",
+                        request->nlmsg_seq);
+       
        while (TRUE)
        {
                iterator_t *iterator;
@@ -358,8 +352,6 @@ static status_t send_message(private_kernel_interface_t *this,
                        break;
                }
                /* TODO: we should time out, if something goes wrong!??? */
-               //if(pthread_cond_timedwait(&(this->condvar), &(this->rep_mutex), &tm) == ETIMEDOUT)
-               //      return FAILED;
                pthread_cond_wait(&(this->condvar), &(this->rep_mutex));
        }
        
@@ -368,6 +360,8 @@ static status_t send_message(private_kernel_interface_t *this,
        return SUCCESS;
 }
 
+static int supersocket;
+
 /**
  * Reads from a netlink socket and returns the message in a buffer.
  */
@@ -379,13 +373,14 @@ static void netlink_package_receiver(int socket, unsigned char *response, int re
                socklen_t addr_length;
                size_t length;
                addr_length = sizeof(addr);
-               
+       
                length = recvfrom(socket, response, response_size, 0, (struct sockaddr*)&addr, &addr_length);
                if (length < 0)
                {
                        if (errno == EINTR)
                        {
                                /* interrupted, try again */
+                               DBG1(DBG_IKE, "wtf1");
                                continue;
                        }
                        charon->kill(charon, "receiving from netlink socket failed\n");
@@ -498,9 +493,9 @@ static void receive_xfrm_messages(private_kernel_interface_t *this)
                 * updating SAs.
                 * XFRM_MSG_NEWPOLICY is returned when we query a policy.
                 */
-               else if (hdr->nlmsg_type == NLMSG_ERROR
-                                       || hdr->nlmsg_type == XFRM_MSG_NEWSA
-                                       || hdr->nlmsg_type == XFRM_MSG_NEWPOLICY)
+               else if (hdr->nlmsg_type == NLMSG_ERROR || 
+                                hdr->nlmsg_type == XFRM_MSG_NEWSA || 
+                                hdr->nlmsg_type == XFRM_MSG_NEWPOLICY)
                {
                        add_to_package_list(this, response);
                }
@@ -518,16 +513,19 @@ static void receive_rt_messages(private_kernel_interface_t *this)
 {
        while(TRUE)
        {
-               unsigned char response[BUFFER_SIZE];
+               unsigned char response[BUFFER_SIZE*3];
                struct nlmsghdr *hdr;
-               netlink_package_receiver(this->rt_socket,response,BUFFER_SIZE);
+               supersocket = this->rt_socket;
+               netlink_package_receiver(this->rt_socket,response, sizeof(response));
                
                hdr = (struct nlmsghdr*)response;
                /* NLMSG_ERROR is sent back for acknowledge (or on error).
                 * RTM_NEWROUTE is returned when we add a route.
                 */
                if (hdr->nlmsg_type == NLMSG_ERROR ||
-                       hdr->nlmsg_type == RTM_NEWROUTE)
+                       hdr->nlmsg_type == RTM_NEWROUTE ||
+                       hdr->nlmsg_type == RTM_NEWLINK ||
+                       hdr->nlmsg_type == RTM_NEWADDR)
                {
                        add_to_package_list(this, response);
                }
@@ -1338,7 +1336,7 @@ static status_t add_policy(private_kernel_interface_t *this,
                policy->route = malloc_thing(rt_refcount_t);
                if (find_addr_by_ts(dst_ts, &policy->route->src_ip) == SUCCESS)
                {
-                       policy->route->if_index = get_iface(this, src);
+                       policy->route->if_index = get_iface(this, dst);
                        policy->route->dst_net = chunk_alloc(policy->sel.family == AF_INET ? 4 : 16);
                        memcpy(policy->route->dst_net.ptr, &policy->sel.saddr, policy->route->dst_net.len);
                        policy->route->prefixlen = policy->sel.prefixlen_s;
@@ -1610,65 +1608,173 @@ static status_t manage_ipaddr(private_kernel_interface_t *this, int nlmsg_type,
        return send_rtrequest(this, hdr);
 }
 
-static int get_iface(private_kernel_interface_t *this, host_t* ip)
+/**
+ * send a netlink message and wait for a reply
+ */
+static status_t netlink_send(int socket, struct nlmsghdr *in,
+                                                        struct nlmsghdr **out, size_t *out_len)
 {
-       unsigned char request[BUFFER_SIZE];
-       struct nlmsghdr *hdr;
-       struct rtmsg *msg;
-       struct rtattr* rta;
-       chunk_t chunk;
-       int ifindex = 0;
-
-       DBG2(DBG_KNL, "getting interface for %H", ip);
+       int len, addr_len;
+       struct sockaddr_nl addr;
+       chunk_t result = chunk_empty, tmp;
+       struct nlmsghdr *msg, peek;
        
-       memset(&request, 0, sizeof(request));
+       static int seq = 200;
+       static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
        
-       chunk = ip->get_address(ip);
-    
-    hdr = (struct nlmsghdr*)request;
-       hdr->nlmsg_flags = NLM_F_REQUEST;
-       hdr->nlmsg_type = RTM_GETROUTE; 
-       hdr->nlmsg_len = NLMSG_LENGTH(sizeof(struct rtmsg));
        
-       msg = (struct rtmsg*)NLMSG_DATA(hdr);
-       msg->rtm_family = ip->get_family(ip);
-       msg->rtm_table = 0;
-       msg->rtm_protocol = 0;
-       msg->rtm_scope = 0;
-       msg->rtm_type = 0;
-       msg->rtm_src_len = 0;
-       msg->rtm_dst_len = 8 * chunk.len;
-       msg->rtm_tos = 0;
-       msg->rtm_flags = RT_TABLE_UNSPEC | RTPROT_UNSPEC;
+       pthread_mutex_lock(&mutex);
+       
+       in->nlmsg_seq = ++seq;
+       in->nlmsg_pid = getpid();
+       
+       memset(&addr, 0, sizeof(addr));
+       addr.nl_family = AF_NETLINK;
+       addr.nl_pid = 0;
+       addr.nl_groups = 0;
 
-       if (add_rtattr(hdr, sizeof(request), RTA_DST,
-                                  chunk.ptr, chunk.len) != SUCCESS)
+       while (TRUE)
        {
-               return 0;
+               len = sendto(socket, in, in->nlmsg_len, 0, 
+                                        (struct sockaddr*)&addr, sizeof(addr));
+               
+               if (len != in->nlmsg_len)
+               {       
+                       if (errno == EINTR)
+                       {
+                               /* interrupted, try again */
+                               continue;
+                       }
+                       pthread_mutex_unlock(&mutex);
+                       DBG1(DBG_KNL, "error sending to netlink socket: %m");
+                       return FAILED;
+               }
+               break;
        }
-
-       if(send_message(this, hdr, &hdr, this->rt_socket) != SUCCESS)
-       {
-               return 0;
+       
+       for(;;)
+       {       
+               tmp = chunk_alloca(2048);
+               msg = (struct nlmsghdr*)tmp.ptr;
+       
+               len = recvfrom(socket, tmp.ptr, tmp.len, 0,
+                                          (struct sockaddr*)&addr, &addr_len);
+               if (len < 0)
+               {
+                       if (errno == EINTR)
+                       {
+                               /* interrupted, try again */
+                               continue;
+                       }
+                       DBG1(DBG_IKE, "error reading from netlink socket: %m");
+                       pthread_mutex_unlock(&mutex);
+                       return FAILED;
+               }
+               if (!NLMSG_OK(msg, len))
+               {
+                       DBG1(DBG_IKE, "received corrupted netlink message");
+                       pthread_mutex_unlock(&mutex);
+                       return FAILED;
+               }
+               if (msg->nlmsg_seq != seq)
+               {
+                       DBG1(DBG_IKE, "received invalid netlink sequence number");
+                       if (msg->nlmsg_seq < seq)
+                       {
+                               continue;
+                       }
+                       pthread_mutex_unlock(&mutex);
+                       return FAILED;
+               }
+               
+               tmp.len = len;
+               result = chunk_cata("cc", result, tmp);
+               
+               /* NLM_F_MULTI flag does not seem to be set correctly, we use sequence
+                * numbers to detect multi header messages */
+               len = recvfrom(socket, &peek, sizeof(peek), MSG_PEEK | MSG_DONTWAIT,
+                                          (struct sockaddr*)&addr, &addr_len);
+               
+               if (len == sizeof(peek) && peek.nlmsg_seq == seq)
+               {
+                       /* seems to be multipart */
+                       continue;
+               }
+               break;
        }
-       rta = (struct rtattr*)(NLMSG_DATA(hdr) + NLMSG_LENGTH(sizeof(struct rtmsg)));
+       
+       *out_len = result.len;
+       *out = (struct nlmsghdr*)clalloc(result.ptr, result.len);
+       
+       pthread_mutex_unlock(&mutex);
+       
+       return SUCCESS;
+}
+
+static int get_iface(private_kernel_interface_t *this, host_t* ip)
+{
+       unsigned char request[BUFFER_SIZE];
+       struct nlmsghdr *hdr, *tofree;
+       struct rtgenmsg *msg;
+       int ifindex = 0;
+       size_t len;
+       chunk_t target, current;
+       
+       memset(&request, 0, sizeof(request));
 
-       DBG1(DBG_KNL, "listing attributes:");
-       while(RTA_OK(rta, hdr->nlmsg_len))
+    hdr = (struct nlmsghdr*)request;
+       hdr->nlmsg_len = NLMSG_LENGTH(sizeof(struct rtgenmsg));
+       hdr->nlmsg_type = RTM_GETADDR;
+       hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_MATCH | NLM_F_ROOT;
+       
+       msg = (struct rtgenmsg*)NLMSG_DATA(hdr);
+       msg->rtgen_family = AF_UNSPEC;
+       
+       target = ip->get_address(ip);
+               
+       if (netlink_send(this->rt_socket, hdr, &hdr, &len) == SUCCESS)
        {
-               DBG1(DBG_KNL, "  found rtattr: %d, data %b", rta->rta_type,
-                        RTA_DATA(rta), rta->rta_len - 4);
-               if(rta->rta_type == RTA_OIF)
+               tofree = hdr;
+               while (NLMSG_OK(hdr, len))
                {
-                       ifindex = *((int*)RTA_DATA(rta));
+                       switch (hdr->nlmsg_type)
+                       {
+                               case RTM_NEWADDR:
+                               {
+                                       struct ifaddrmsg* msg = (struct ifaddrmsg*)(NLMSG_DATA(hdr));
+                               struct rtattr *rta = IFA_RTA(msg);
+                               size_t rtasize = IFA_PAYLOAD (hdr);
+                               
+                                       while(RTA_OK(rta, rtasize))
+                                       {
+                                               if (rta->rta_type == IFA_ADDRESS)
+                                               {
+                                                       current.len = rta->rta_len - 4;
+                                                       current.ptr = RTA_DATA(rta);
+                                                       if (chunk_equals(current, target))
+                                                       {
+                                                               ifindex = msg->ifa_index;
+                                                               break;
+                                                       }
+                                               }
+                                               rta = RTA_NEXT(rta, rtasize);
+                                       }
+                                       hdr = NLMSG_NEXT(hdr, len);
+                                       continue;
+                               }
+                               default:
+                                       hdr = NLMSG_NEXT(hdr, len);
+                                       continue;
+                               case NLMSG_DONE:
+                                       break;
+                       }
                        break;
                }
-               rta = RTA_NEXT(rta, hdr->nlmsg_len);
+               free(tofree);
        }
-       free(hdr);
-       if (ifindex == 0)
+       else
        {
-               DBG1(DBG_KNL, "address %H not reachable, unable to get interface", ip);
+               DBG1(DBG_IKE, "unable to get interface address for %H", ip);
        }
        return ifindex;
 }
@@ -1765,7 +1871,7 @@ static void vip_refcount_destroy(vip_refcount_t *this)
  * Implementation of kernel_interface_t.add_ip.
  */
 static status_t add_ip(private_kernel_interface_t *this, 
-                                               host_t *virtual_ip, host_t *dst_ip)
+                                               host_t *virtual_ip, host_t *iface_ip)
 {
        int targetif;
        vip_refcount_t *listed;
@@ -1773,7 +1879,7 @@ static status_t add_ip(private_kernel_interface_t *this,
 
        DBG2(DBG_KNL, "adding ip addr: %H", virtual_ip);
 
-       targetif = get_iface(this, dst_ip);
+       targetif = get_iface(this, iface_ip);
        if (targetif == 0)
        {
                return FAILED;
@@ -1811,7 +1917,7 @@ static status_t add_ip(private_kernel_interface_t *this,
  * Implementation of kernel_interface_t.del_ip.
  */
 static status_t del_ip(private_kernel_interface_t *this,
-                                               host_t *virtual_ip, host_t *dst_ip)
+                                               host_t *virtual_ip, host_t *iface_ip)
 {
        int targetif;
        vip_refcount_t *listed;
@@ -1819,7 +1925,7 @@ static status_t del_ip(private_kernel_interface_t *this,
 
        DBG2(DBG_KNL, "deleting ip addr: %H", virtual_ip);
 
-       targetif = get_iface(this, dst_ip);
+       targetif = get_iface(this, iface_ip);
        if (targetif == 0)
        {
                return FAILED;
@@ -1909,7 +2015,7 @@ kernel_interface_t *kernel_interface_create()
        
        /* bind the xfrm socket and reqister for ACQUIRE & EXPIRE */
        addr_xfrm.nl_family = AF_NETLINK;
-       addr_xfrm.nl_pid = getpid();
+       addr_xfrm.nl_pid = 0;
        addr_xfrm.nl_groups = XFRMGRP_ACQUIRE | XFRMGRP_EXPIRE;
        if (bind(this->xfrm_socket, (struct sockaddr*)&addr_xfrm, sizeof(addr_xfrm)))
        {
@@ -1934,7 +2040,7 @@ kernel_interface_t *kernel_interface_create()
 
        /* bind the socket_rt */
        addr_rt.nl_family = AF_NETLINK;
-       addr_rt.nl_pid = getpid();
+       addr_rt.nl_pid = 0;
        addr_rt.nl_groups = 0;
        if (bind(this->rt_socket, (struct sockaddr*)&addr_rt, sizeof(addr_rt)))
        {
@@ -1942,7 +2048,7 @@ kernel_interface_t *kernel_interface_create()
                goto kill_rt;
        }
        
-       if (pthread_create(&this->rt_thread, NULL, 
+       if (pthread_create(&this->rt_thread, NULL,
                                           (void*(*)(void*))receive_rt_messages, this))
        {
                DBG1(DBG_KNL, "Unable to create rt netlink thread");