Use a helper function to add XFRM_MARK attribute
authorMartin Willi <martin@revosec.ch>
Fri, 15 Mar 2013 14:17:13 +0000 (15:17 +0100)
committerMartin Willi <martin@revosec.ch>
Fri, 15 Mar 2013 15:02:01 +0000 (16:02 +0100)
src/libhydra/plugins/kernel_netlink/kernel_netlink_ipsec.c

index 1c4b603..485a4d9 100644 (file)
@@ -1145,6 +1145,26 @@ METHOD(kernel_ipsec_t, get_cpi, status_t,
        return SUCCESS;
 }
 
+/**
+ * Add a XFRM mark to message if required
+ */
+static bool add_mark(struct nlmsghdr *hdr, int buflen, mark_t mark)
+{
+       if (mark.value)
+       {
+               struct xfrm_mark *xmrk;
+
+               xmrk = netlink_reserve(hdr, buflen, XFRMA_MARK, sizeof(*xmrk));
+               if (!xmrk)
+               {
+                       return FALSE;
+               }
+               xmrk->v = mark.value;
+               xmrk->m = mark.mask;
+       }
+       return TRUE;
+}
+
 METHOD(kernel_ipsec_t, add_sa, status_t,
        private_kernel_netlink_ipsec_t *this, host_t *src, host_t *dst,
        u_int32_t spi, u_int8_t protocol, u_int32_t reqid, mark_t mark,
@@ -1402,17 +1422,9 @@ METHOD(kernel_ipsec_t, add_sa, status_t,
                 * checks it marks them "checksum ok" so OA isn't needed. */
        }
 
-       if (mark.value)
+       if (!add_mark(hdr, sizeof(request), mark))
        {
-               struct xfrm_mark *mrk;
-
-               mrk = netlink_reserve(hdr, sizeof(request), XFRMA_MARK, sizeof(*mrk));
-               if (!mrk)
-               {
-                       goto failed;
-               }
-               mrk->v = mark.value;
-               mrk->m = mark.mask;
+               goto failed;
        }
 
        if (tfc)
@@ -1519,17 +1531,9 @@ static void get_replay_state(private_kernel_netlink_ipsec_t *this,
        aevent_id->sa_id.proto = protocol;
        aevent_id->sa_id.family = dst->get_family(dst);
 
-       if (mark.value)
+       if (!add_mark(hdr, sizeof(request), mark))
        {
-               struct xfrm_mark *mrk;
-
-               mrk = netlink_reserve(hdr, sizeof(request), XFRMA_MARK, sizeof(*mrk));
-               if (!mrk)
-               {
-                       return;
-               }
-               mrk->v = mark.value;
-               mrk->m = mark.mask;
+               return;
        }
 
        if (this->socket_xfrm->send(this->socket_xfrm, hdr, &out, &len) == SUCCESS)
@@ -1615,17 +1619,9 @@ METHOD(kernel_ipsec_t, query_sa, status_t,
        sa_id->proto = protocol;
        sa_id->family = dst->get_family(dst);
 
-       if (mark.value)
+       if (!add_mark(hdr, sizeof(request), mark))
        {
-               struct xfrm_mark *mrk;
-
-               mrk = netlink_reserve(hdr, sizeof(request), XFRMA_MARK, sizeof(*mrk));
-               if (!mrk)
-               {
-                       return FAILED;
-               }
-               mrk->v = mark.value;
-               mrk->m = mark.mask;
+               return FAILED;
        }
 
        if (this->socket_xfrm->send(this->socket_xfrm, hdr, &out, &len) == SUCCESS)
@@ -1713,17 +1709,9 @@ METHOD(kernel_ipsec_t, del_sa, status_t,
        sa_id->proto = protocol;
        sa_id->family = dst->get_family(dst);
 
-       if (mark.value)
+       if (!add_mark(hdr, sizeof(request), mark))
        {
-               struct xfrm_mark *mrk;
-
-               mrk = netlink_reserve(hdr, sizeof(request), XFRMA_MARK, sizeof(*mrk));
-               if (!mrk)
-               {
-                       return FAILED;
-               }
-               mrk->v = mark.value;
-               mrk->m = mark.mask;
+               return FAILED;
        }
 
        switch (this->socket_xfrm->send_ack(this->socket_xfrm, hdr))
@@ -1790,17 +1778,9 @@ METHOD(kernel_ipsec_t, update_sa, status_t,
        sa_id->proto = protocol;
        sa_id->family = dst->get_family(dst);
 
-       if (mark.value)
+       if (!add_mark(hdr, sizeof(request), mark))
        {
-               struct xfrm_mark *mrk;
-
-               mrk = netlink_reserve(hdr, sizeof(request), XFRMA_MARK, sizeof(*mrk));
-               if (!mrk)
-               {
-                       return FAILED;
-               }
-               mrk->v = mark.value;
-               mrk->m = mark.mask;
+               return FAILED;
        }
 
        if (this->socket_xfrm->send(this->socket_xfrm, hdr, &out, &len) == SUCCESS)
@@ -2077,18 +2057,10 @@ static status_t add_policy_internal(private_kernel_netlink_ipsec_t *this,
                }
        }
 
-       if (ipsec->mark.value)
+       if (!add_mark(hdr, sizeof(request), ipsec->mark))
        {
-               struct xfrm_mark *mrk;
-
-               mrk = netlink_reserve(hdr, sizeof(request), XFRMA_MARK, sizeof(*mrk));
-               if (!mrk)
-               {
-                       this->mutex->unlock(this->mutex);
-                       return FAILED;
-               }
-               mrk->v = ipsec->mark.value;
-               mrk->m = ipsec->mark.mask;
+               this->mutex->unlock(this->mutex);
+               return FAILED;
        }
        this->mutex->unlock(this->mutex);
 
@@ -2315,17 +2287,9 @@ METHOD(kernel_ipsec_t, query_policy, status_t,
        policy_id->sel = ts2selector(src_ts, dst_ts);
        policy_id->dir = direction;
 
-       if (mark.value)
+       if (!add_mark(hdr, sizeof(request), mark))
        {
-               struct xfrm_mark *mrk;
-
-               mrk = netlink_reserve(hdr, sizeof(request), XFRMA_MARK, sizeof(*mrk));
-               if (!mrk)
-               {
-                       return FAILED;
-               }
-               mrk->v = mark.value;
-               mrk->m = mark.mask;
+               return FAILED;
        }
 
        if (this->socket_xfrm->send(this->socket_xfrm, hdr, &out, &len) == SUCCESS)
@@ -2481,17 +2445,9 @@ METHOD(kernel_ipsec_t, del_policy, status_t,
        policy_id->sel = current->sel;
        policy_id->dir = direction;
 
-       if (mark.value)
+       if (!add_mark(hdr, sizeof(request), mark))
        {
-               struct xfrm_mark *mrk;
-
-               mrk = netlink_reserve(hdr, sizeof(request), XFRMA_MARK, sizeof(*mrk));
-               if (!mrk)
-               {
-                       return FAILED;
-               }
-               mrk->v = mark.value;
-               mrk->m = mark.mask;
+               return FAILED;
        }
 
        if (current->route)