some tls_eap optimizations
[strongswan.git] / src / libtls / tls_eap.c
index 226422c..6134318 100644 (file)
@@ -1,3 +1,4 @@
+
 /*
  * Copyright (C) 2010 Martin Willi
  * Copyright (C) 2010 revosec AG
 /*
  * Copyright (C) 2010 Martin Willi
  * Copyright (C) 2010 revosec AG
@@ -56,6 +57,13 @@ struct private_tls_eap_t {
        bool is_server;
 
        /**
        bool is_server;
 
        /**
+        * If FALSE include the total length of an EAP message
+        * in the first fragment of fragmented messages only.
+        * If TRUE also include the length in non-fragmented messages.
+        */
+       bool include_length;
+
+       /**
         * First fragment of a multi-fragment record?
         */
        bool first_fragment;
         * First fragment of a multi-fragment record?
         */
        bool first_fragment;
@@ -128,8 +136,10 @@ METHOD(tls_eap_t, initiate, status_t,
                htoun16(&pkt.length, sizeof(eap_tls_packet_t));
                pkt.identifier = this->identifier;
 
                htoun16(&pkt.length, sizeof(eap_tls_packet_t));
                pkt.identifier = this->identifier;
 
-               DBG2(DBG_IKE, "sending %N start packet", eap_type_names, this->type);
                *out = chunk_clone(chunk_from_thing(pkt));
                *out = chunk_clone(chunk_from_thing(pkt));
+               DBG2(DBG_TLS, "sending %N start packet (%u bytes)",
+                                          eap_type_names, this->type, sizeof(eap_tls_packet_t));
+               DBG3(DBG_TLS, "%B", out);
                return NEED_MORE;
        }
        return FAILED;
                return NEED_MORE;
        }
        return FAILED;
@@ -140,10 +150,12 @@ METHOD(tls_eap_t, initiate, status_t,
  */
 static status_t process_pkt(private_tls_eap_t *this, eap_tls_packet_t *pkt)
 {
  */
 static status_t process_pkt(private_tls_eap_t *this, eap_tls_packet_t *pkt)
 {
-       u_int32_t msg_len;
        u_int16_t pkt_len;
        u_int16_t pkt_len;
+       u_int32_t msg_len;
+       size_t msg_len_offset = 0;
 
        pkt_len = untoh16(&pkt->length);
 
        pkt_len = untoh16(&pkt->length);
+
        if (pkt->flags & EAP_TLS_LENGTH)
        {
                if (pkt_len < sizeof(eap_tls_packet_t) + sizeof(msg_len))
        if (pkt->flags & EAP_TLS_LENGTH)
        {
                if (pkt_len < sizeof(eap_tls_packet_t) + sizeof(msg_len))
@@ -155,14 +167,15 @@ static status_t process_pkt(private_tls_eap_t *this, eap_tls_packet_t *pkt)
                if (msg_len < pkt_len - sizeof(eap_tls_packet_t) - sizeof(msg_len) ||
                        msg_len > MAX_TLS_MESSAGE_LEN)
                {
                if (msg_len < pkt_len - sizeof(eap_tls_packet_t) - sizeof(msg_len) ||
                        msg_len > MAX_TLS_MESSAGE_LEN)
                {
-                       DBG1(DBG_TLS, "invalid %N packet length", eap_type_names, this->type);
+                       DBG1(DBG_TLS, "invalid %N packet length (%u bytes)", eap_type_names,
+                                this->type, msg_len);
                        return FAILED;
                }
                        return FAILED;
                }
-               return this->tls->process(this->tls, (char*)(pkt + 1) + sizeof(msg_len),
-                                               pkt_len - sizeof(eap_tls_packet_t) - sizeof(msg_len));
+               msg_len_offset = sizeof(msg_len);
        }
        }
-       return this->tls->process(this->tls, (char*)(pkt + 1),
-                                                         pkt_len - sizeof(eap_tls_packet_t));
+
+       return this->tls->process(this->tls, (char*)(pkt + 1) + msg_len_offset,
+                                          pkt_len - sizeof(eap_tls_packet_t) - msg_len_offset);
 }
 
 /**
 }
 
 /**
@@ -172,7 +185,7 @@ static status_t build_pkt(private_tls_eap_t *this, chunk_t *out)
 {
        char buf[this->frag_size];
        eap_tls_packet_t *pkt;
 {
        char buf[this->frag_size];
        eap_tls_packet_t *pkt;
-       size_t len, reclen;
+       size_t len, reclen, msg_len_offset;
        status_t status;
        char *kind;
 
        status_t status;
        char *kind;
 
@@ -203,17 +216,17 @@ static status_t build_pkt(private_tls_eap_t *this, chunk_t *out)
 
        if (this->first_fragment)
        {
 
        if (this->first_fragment)
        {
-               pkt->flags |= EAP_TLS_LENGTH;
                len = sizeof(buf) - sizeof(eap_tls_packet_t) - sizeof(u_int32_t);
                len = sizeof(buf) - sizeof(eap_tls_packet_t) - sizeof(u_int32_t);
-               status = this->tls->build(this->tls, buf + sizeof(eap_tls_packet_t) +
-                                                                 sizeof(u_int32_t), &len, &reclen);
+               msg_len_offset = sizeof(u_int32_t);
        }
        else
        {
                len = sizeof(buf) - sizeof(eap_tls_packet_t);
        }
        else
        {
                len = sizeof(buf) - sizeof(eap_tls_packet_t);
-               status = this->tls->build(this->tls, buf + sizeof(eap_tls_packet_t),
-                                                                 &len, &reclen);
+               msg_len_offset = 0;
        }
        }
+       status = this->tls->build(this->tls, buf + sizeof(eap_tls_packet_t) +
+                                                                                msg_len_offset, &len, &reclen);
+
        switch (status)
        {
                case NEED_MORE:
        switch (status)
        {
                case NEED_MORE:
@@ -221,13 +234,21 @@ static status_t build_pkt(private_tls_eap_t *this, chunk_t *out)
                        kind = "further fragment";
                        if (this->first_fragment)
                        {
                        kind = "further fragment";
                        if (this->first_fragment)
                        {
+                       pkt->flags |= EAP_TLS_LENGTH;
                                this->first_fragment = FALSE;
                                kind = "first fragment";
                        }
                        break;
                case ALREADY_DONE:
                                this->first_fragment = FALSE;
                                kind = "first fragment";
                        }
                        break;
                case ALREADY_DONE:
-                       kind = "packet";
-                       if (!this->first_fragment)
+                       if (this->first_fragment)
+                       {
+                               if (this->include_length)
+                               {
+                                       pkt->flags |= EAP_TLS_LENGTH;
+                               }
+                               kind = "packet";
+                       }
+                       else
                        {
                                this->first_fragment = TRUE;
                                kind = "final fragment";
                        {
                                this->first_fragment = TRUE;
                                kind = "final fragment";
@@ -236,17 +257,27 @@ static status_t build_pkt(private_tls_eap_t *this, chunk_t *out)
                default:
                        return status;
        }
                default:
                        return status;
        }
-       DBG2(DBG_TLS, "sending %N %s (%u bytes)",
-                eap_type_names, this->type, kind, len);
        if (reclen)
        {
        if (reclen)
        {
-               htoun32(pkt + 1, reclen);
-               len += sizeof(u_int32_t);
-               pkt->flags |= EAP_TLS_LENGTH;
+               if (pkt->flags & EAP_TLS_LENGTH)
+               { 
+                       htoun32(pkt + 1, reclen);
+                       len += sizeof(u_int32_t);
+                       pkt->flags |= EAP_TLS_LENGTH;
+               }
+               else
+               {
+                       /* get rid of the reserved length field */
+                       memcpy(buf+sizeof(eap_packet_t),
+                                  buf+sizeof(eap_packet_t)+sizeof(u_int32_t), len);    
+               }
        }
        len += sizeof(eap_tls_packet_t);
        htoun16(&pkt->length, len);
        *out = chunk_clone(chunk_create(buf, len));
        }
        len += sizeof(eap_tls_packet_t);
        htoun16(&pkt->length, len);
        *out = chunk_clone(chunk_create(buf, len));
+       DBG2(DBG_TLS, "sending %N %s (%u bytes)",
+                                  eap_type_names, this->type, kind, len);
+       DBG3(DBG_TLS, "%B", out);
        return NEED_MORE;
 }
 
        return NEED_MORE;
 }
 
@@ -292,9 +323,9 @@ METHOD(tls_eap_t, process, status_t,
        eap_tls_packet_t *pkt;
        status_t status;
 
        eap_tls_packet_t *pkt;
        status_t status;
 
-       if (++this->processed > this->max_msg_count)
+       if (this->max_msg_count && ++this->processed > this->max_msg_count)
        {
        {
-               DBG1(DBG_IKE, "%N packet count exceeded (%d > %d)",
+               DBG1(DBG_TLS, "%N packet count exceeded (%d > %d)",
                         eap_type_names, this->type,
                         this->processed, this->max_msg_count);
                return FAILED;
                         eap_type_names, this->type,
                         this->processed, this->max_msg_count);
                return FAILED;
@@ -303,12 +334,16 @@ METHOD(tls_eap_t, process, status_t,
        pkt = (eap_tls_packet_t*)in.ptr;
        if (in.len < sizeof(eap_tls_packet_t) || untoh16(&pkt->length) != in.len)
        {
        pkt = (eap_tls_packet_t*)in.ptr;
        if (in.len < sizeof(eap_tls_packet_t) || untoh16(&pkt->length) != in.len)
        {
-               DBG1(DBG_IKE, "invalid %N packet length", eap_type_names, this->type);
+               DBG1(DBG_TLS, "invalid %N packet length", eap_type_names, this->type);
                return FAILED;
        }
 
        /* update EAP identifier */
                return FAILED;
        }
 
        /* update EAP identifier */
-       this->identifier = pkt->identifier;
+       if (!this->is_server)
+       {
+               this->identifier = pkt->identifier;
+       }
+       DBG3(DBG_TLS, "%N payload %B", eap_type_names, this->type, &in);
 
        if (pkt->flags & EAP_TLS_START)
        {
 
        if (pkt->flags & EAP_TLS_START)
        {
@@ -390,7 +425,7 @@ METHOD(tls_eap_t, destroy, void,
  * See header
  */
 tls_eap_t *tls_eap_create(eap_type_t type, tls_t *tls, size_t frag_size,
  * See header
  */
 tls_eap_t *tls_eap_create(eap_type_t type, tls_t *tls, size_t frag_size,
-                                                 int max_msg_count)
+                                                 int max_msg_count, bool include_length)
 {
        private_tls_eap_t *this;
 
 {
        private_tls_eap_t *this;
 
@@ -413,6 +448,7 @@ tls_eap_t *tls_eap_create(eap_type_t type, tls_t *tls, size_t frag_size,
                .first_fragment = TRUE,
                .frag_size = frag_size,
                .max_msg_count = max_msg_count,
                .first_fragment = TRUE,
                .frag_size = frag_size,
                .max_msg_count = max_msg_count,
+               .include_length = include_length,
                .tls = tls,
        );
 
                .tls = tls,
        );