socket-default: Add an option to force the sending interface via IP_PKTINFO
[strongswan.git] / src / libcharon / plugins / socket_default / socket_default_socket.c
index ba22b0c..109b3fe 100644 (file)
@@ -142,6 +142,11 @@ struct private_socket_default_socket_t {
        bool set_source;
 
        /**
+        * TRUE to force sending source interface on outbound packetrs
+        */
+       bool set_sourceif;
+
+       /**
         * A counter to implement round-robin selection of read sockets
         */
        u_int rr_counter;
@@ -362,12 +367,33 @@ static ssize_t send_msg_generic(int skt, struct msghdr *msg)
        return sendmsg(skt, msg, 0);
 }
 
+#if defined(IP_PKTINFO) || defined(HAVE_IN6_PKTINFO)
+
+/**
+ * Find the interface index a source address is installed on
+ */
+static int find_srcif(host_t *src)
+{
+       char *ifname;
+       int idx = 0;
+
+       if (charon->kernel->get_interface(charon->kernel, src, &ifname))
+       {
+               idx = if_nametoindex(ifname);
+               free(ifname);
+       }
+       return idx;
+}
+
+#endif /* IP_PKTINFO || HAVE_IN6_PKTINFO */
+
 /**
  * Send a message with the IPv4 source address set, if possible.
  */
 #ifdef IP_PKTINFO
 
-static ssize_t send_msg_v4(int skt, struct msghdr *msg, host_t *src)
+static ssize_t send_msg_v4(private_socket_default_socket_t *this, int skt,
+                                                  struct msghdr *msg, host_t *src)
 {
        char buf[CMSG_SPACE(sizeof(struct in_pktinfo))] = {};
        struct cmsghdr *cmsg;
@@ -383,6 +409,10 @@ static ssize_t send_msg_v4(int skt, struct msghdr *msg, host_t *src)
        cmsg->cmsg_len = CMSG_LEN(sizeof(struct in_pktinfo));
 
        pktinfo = (struct in_pktinfo*)CMSG_DATA(cmsg);
+       if (this->set_sourceif)
+       {
+               pktinfo->ipi_ifindex = find_srcif(src);
+       }
        addr = &pktinfo->ipi_spec_dst;
 
        sin = (struct sockaddr_in*)src->get_sockaddr(src);
@@ -392,7 +422,8 @@ static ssize_t send_msg_v4(int skt, struct msghdr *msg, host_t *src)
 
 #elif defined(IP_SENDSRCADDR)
 
-static ssize_t send_msg_v4(int skt, struct msghdr *msg, host_t *src)
+static ssize_t send_msg_v4(private_socket_default_socket_t *this, int skt,
+                                                  struct msghdr *msg, host_t *src)
 {
        char buf[CMSG_SPACE(sizeof(struct in_addr))] = {};
        struct cmsghdr *cmsg;
@@ -415,7 +446,8 @@ static ssize_t send_msg_v4(int skt, struct msghdr *msg, host_t *src)
 
 #else /* IP_PKTINFO || IP_RECVDSTADDR */
 
-static ssize_t send_msg_v4(int skt, struct msghdr *msg, host_t *src)
+static ssize_t send_msg_v4(private_socket_default_socket_t *this,
+                                                  int skt, struct msghdr *msg, host_t *src)
 {
        return send_msg_generic(skt, msg);
 }
@@ -427,7 +459,8 @@ static ssize_t send_msg_v4(int skt, struct msghdr *msg, host_t *src)
  */
 #ifdef HAVE_IN6_PKTINFO
 
-static ssize_t send_msg_v6(int skt, struct msghdr *msg, host_t *src)
+static ssize_t send_msg_v6(private_socket_default_socket_t *this, int skt,
+                                                  struct msghdr *msg, host_t *src)
 {
        char buf[CMSG_SPACE(sizeof(struct in6_pktinfo))] = {};
        struct cmsghdr *cmsg;
@@ -441,6 +474,10 @@ static ssize_t send_msg_v6(int skt, struct msghdr *msg, host_t *src)
        cmsg->cmsg_type = IPV6_PKTINFO;
        cmsg->cmsg_len = CMSG_LEN(sizeof(struct in6_pktinfo));
        pktinfo = (struct in6_pktinfo*)CMSG_DATA(cmsg);
+       if (this->set_sourceif)
+       {
+               pktinfo->ipi6_ifindex = find_srcif(src);
+       }
        sin = (struct sockaddr_in6*)src->get_sockaddr(src);
        memcpy(&pktinfo->ipi6_addr, &sin->sin6_addr, sizeof(struct in6_addr));
        return send_msg_generic(skt, msg);
@@ -448,7 +485,8 @@ static ssize_t send_msg_v6(int skt, struct msghdr *msg, host_t *src)
 
 #else /* HAVE_IN6_PKTINFO */
 
-static ssize_t send_msg_v6(int skt, struct msghdr *msg, host_t *src)
+static ssize_t send_msg_v6(private_socket_default_socket_t *this,
+                                                  int skt, struct msghdr *msg, host_t *src)
 {
        return send_msg_generic(skt, msg);
 }
@@ -564,11 +602,11 @@ METHOD(socket_t, sender, status_t,
        {
                if (family == AF_INET)
                {
-                       bytes_sent = send_msg_v4(skt, &msg, src);
+                       bytes_sent = send_msg_v4(this, skt, &msg, src);
                }
                else
                {
-                       bytes_sent = send_msg_v6(skt, &msg, src);
+                       bytes_sent = send_msg_v6(this, skt, &msg, src);
                }
        }
        else
@@ -831,6 +869,9 @@ socket_default_socket_t *socket_default_socket_create()
                .set_source = lib->settings->get_bool(lib->settings,
                                                        "%s.plugins.socket-default.set_source", TRUE,
                                                        lib->ns),
+               .set_sourceif = lib->settings->get_bool(lib->settings,
+                                                       "%s.plugins.socket-default.set_sourceif", FALSE,
+                                                       lib->ns),
        );
 
        if (this->port && this->port == this->natt)