some tls_eap optimizations
[strongswan.git] / src / libtls / tls_eap.c
index f74030b..6134318 100644 (file)
@@ -1,3 +1,4 @@
+
 /*
  * Copyright (C) 2010 Martin Willi
  * Copyright (C) 2010 revosec AG
@@ -36,11 +37,16 @@ struct private_tls_eap_t {
        tls_eap_t public;
 
        /**
-        * Type of EAP method, EAP-TLS or EAP-TTLS
+        * Type of EAP method, EAP-TLS, EAP-TTLS, or EAP-TNC
         */
        eap_type_t type;
 
        /**
+        * Current value of EAP identifier
+        */
+       u_int8_t identifier;
+
+       /**
         * TLS stack
         */
        tls_t *tls;
@@ -51,6 +57,13 @@ struct private_tls_eap_t {
        bool is_server;
 
        /**
+        * If FALSE include the total length of an EAP message
+        * in the first fragment of fragmented messages only.
+        * If TRUE also include the length in non-fragmented messages.
+        */
+       bool include_length;
+
+       /**
         * First fragment of a multi-fragment record?
         */
        bool first_fragment;
@@ -59,20 +72,31 @@ struct private_tls_eap_t {
         * Maximum size of an outgoing EAP-TLS fragment
         */
        size_t frag_size;
+
+       /**
+        * Number of EAP messages/fragments processed so far
+        */
+       int processed;
+
+       /**
+        * Maximum number of processed EAP messages/fragments 
+        */
+       int max_msg_count;
 };
 
 /**
- * Flags of an EAP-TLS/TTLS message
+ * Flags of an EAP-TLS/TTLS/TNC message
  */
 typedef enum {
-       EAP_TLS_LENGTH = (1<<7),
-       EAP_TLS_MORE_FRAGS = (1<<6),
-       EAP_TLS_START = (1<<5),
-       EAP_TTLS_VERSION = (0x07),
+       EAP_TLS_LENGTH = (1<<7),                /* shared with EAP-TTLS/TNC/PEAP */
+       EAP_TLS_MORE_FRAGS = (1<<6),    /* shared with EAP-TTLS/TNC/PEAP */
+       EAP_TLS_START = (1<<5),                 /* shared with EAP-TTLS/TNC/PEAP */
+       EAP_TTLS_VERSION = (0x07),              /* shared with EAP-TNC/PEAP      */
 } eap_tls_flags_t;
 
 #define EAP_TTLS_SUPPORTED_VERSION     0
 #define EAP_TNC_SUPPORTED_VERSION      1
+#define EAP_PEAP_SUPPORTED_VERSION     0
 
 /**
  * EAP-TLS/TTLS packet format
@@ -103,18 +127,19 @@ METHOD(tls_eap_t, initiate, status_t,
                        case EAP_TNC:
                                pkt.flags |= EAP_TNC_SUPPORTED_VERSION;
                                break;
+                       case EAP_PEAP:
+                               pkt.flags |= EAP_PEAP_SUPPORTED_VERSION;
+                               break;
                        default:
                                break;
                }
                htoun16(&pkt.length, sizeof(eap_tls_packet_t));
-               do
-               {       /* start with non-zero random identifier */
-                       pkt.identifier = random();
-               }
-               while (!pkt.identifier);
+               pkt.identifier = this->identifier;
 
-               DBG2(DBG_IKE, "sending %N start packet", eap_type_names, this->type);
                *out = chunk_clone(chunk_from_thing(pkt));
+               DBG2(DBG_TLS, "sending %N start packet (%u bytes)",
+                                          eap_type_names, this->type, sizeof(eap_tls_packet_t));
+               DBG3(DBG_TLS, "%B", out);
                return NEED_MORE;
        }
        return FAILED;
@@ -125,10 +150,12 @@ METHOD(tls_eap_t, initiate, status_t,
  */
 static status_t process_pkt(private_tls_eap_t *this, eap_tls_packet_t *pkt)
 {
-       u_int32_t msg_len;
        u_int16_t pkt_len;
+       u_int32_t msg_len;
+       size_t msg_len_offset = 0;
 
        pkt_len = untoh16(&pkt->length);
+
        if (pkt->flags & EAP_TLS_LENGTH)
        {
                if (pkt_len < sizeof(eap_tls_packet_t) + sizeof(msg_len))
@@ -140,31 +167,35 @@ static status_t process_pkt(private_tls_eap_t *this, eap_tls_packet_t *pkt)
                if (msg_len < pkt_len - sizeof(eap_tls_packet_t) - sizeof(msg_len) ||
                        msg_len > MAX_TLS_MESSAGE_LEN)
                {
-                       DBG1(DBG_TLS, "invalid %N packet length", eap_type_names, this->type);
+                       DBG1(DBG_TLS, "invalid %N packet length (%u bytes)", eap_type_names,
+                                this->type, msg_len);
                        return FAILED;
                }
-               return this->tls->process(this->tls, (char*)(pkt + 1) + sizeof(msg_len),
-                                               pkt_len - sizeof(eap_tls_packet_t) - sizeof(msg_len));
+               msg_len_offset = sizeof(msg_len);
        }
-       return this->tls->process(this->tls, (char*)(pkt + 1),
-                                                         pkt_len - sizeof(eap_tls_packet_t));
+
+       return this->tls->process(this->tls, (char*)(pkt + 1) + msg_len_offset,
+                                          pkt_len - sizeof(eap_tls_packet_t) - msg_len_offset);
 }
 
 /**
  * Build a packet to send
  */
-static status_t build_pkt(private_tls_eap_t *this,
-                                                 u_int8_t identifier, chunk_t *out)
+static status_t build_pkt(private_tls_eap_t *this, chunk_t *out)
 {
        char buf[this->frag_size];
        eap_tls_packet_t *pkt;
-       size_t len, reclen;
+       size_t len, reclen, msg_len_offset;
        status_t status;
        char *kind;
 
+       if (this->is_server)
+       {
+               this->identifier++;
+       }
        pkt = (eap_tls_packet_t*)buf;
        pkt->code = this->is_server ? EAP_REQUEST : EAP_RESPONSE;
-       pkt->identifier = this->is_server ? identifier + 1 : identifier;
+       pkt->identifier = this->identifier;
        pkt->type = this->type;
        pkt->flags = 0;
 
@@ -176,23 +207,26 @@ static status_t build_pkt(private_tls_eap_t *this,
                case EAP_TNC:
                        pkt->flags |= EAP_TNC_SUPPORTED_VERSION;
                        break;
+               case EAP_PEAP:
+                       pkt->flags |= EAP_PEAP_SUPPORTED_VERSION;
+                       break;
                default:
                        break;
        }
 
        if (this->first_fragment)
        {
-               pkt->flags |= EAP_TLS_LENGTH;
                len = sizeof(buf) - sizeof(eap_tls_packet_t) - sizeof(u_int32_t);
-               status = this->tls->build(this->tls, buf + sizeof(eap_tls_packet_t) +
-                                                                 sizeof(u_int32_t), &len, &reclen);
+               msg_len_offset = sizeof(u_int32_t);
        }
        else
        {
                len = sizeof(buf) - sizeof(eap_tls_packet_t);
-               status = this->tls->build(this->tls, buf + sizeof(eap_tls_packet_t),
-                                                                 &len, &reclen);
+               msg_len_offset = 0;
        }
+       status = this->tls->build(this->tls, buf + sizeof(eap_tls_packet_t) +
+                                                                                msg_len_offset, &len, &reclen);
+
        switch (status)
        {
                case NEED_MORE:
@@ -200,13 +234,21 @@ static status_t build_pkt(private_tls_eap_t *this,
                        kind = "further fragment";
                        if (this->first_fragment)
                        {
+                       pkt->flags |= EAP_TLS_LENGTH;
                                this->first_fragment = FALSE;
                                kind = "first fragment";
                        }
                        break;
                case ALREADY_DONE:
-                       kind = "packet";
-                       if (!this->first_fragment)
+                       if (this->first_fragment)
+                       {
+                               if (this->include_length)
+                               {
+                                       pkt->flags |= EAP_TLS_LENGTH;
+                               }
+                               kind = "packet";
+                       }
+                       else
                        {
                                this->first_fragment = TRUE;
                                kind = "final fragment";
@@ -215,31 +257,61 @@ static status_t build_pkt(private_tls_eap_t *this,
                default:
                        return status;
        }
-       DBG2(DBG_TLS, "sending %N %s (%u bytes)",
-                eap_type_names, this->type, kind, len);
        if (reclen)
        {
-               htoun32(pkt + 1, reclen);
-               len += sizeof(u_int32_t);
-               pkt->flags |= EAP_TLS_LENGTH;
+               if (pkt->flags & EAP_TLS_LENGTH)
+               { 
+                       htoun32(pkt + 1, reclen);
+                       len += sizeof(u_int32_t);
+                       pkt->flags |= EAP_TLS_LENGTH;
+               }
+               else
+               {
+                       /* get rid of the reserved length field */
+                       memcpy(buf+sizeof(eap_packet_t),
+                                  buf+sizeof(eap_packet_t)+sizeof(u_int32_t), len);    
+               }
        }
        len += sizeof(eap_tls_packet_t);
        htoun16(&pkt->length, len);
        *out = chunk_clone(chunk_create(buf, len));
+       DBG2(DBG_TLS, "sending %N %s (%u bytes)",
+                                  eap_type_names, this->type, kind, len);
+       DBG3(DBG_TLS, "%B", out);
        return NEED_MORE;
 }
 
 /**
  * Send an ack to request next fragment
  */
-static chunk_t create_ack(private_tls_eap_t *this, u_int8_t identifier)
+static chunk_t create_ack(private_tls_eap_t *this)
 {
        eap_tls_packet_t pkt = {
                .code = this->is_server ? EAP_REQUEST : EAP_RESPONSE,
-               .identifier = this->is_server ? identifier + 1 : identifier,
                .type = this->type,
        };
+
+       if (this->is_server)
+       {
+               this->identifier++;
+       }
+       pkt.identifier = this->identifier;
        htoun16(&pkt.length, sizeof(pkt));
+
+       switch (this->type)
+       {
+               case EAP_TTLS:
+                       pkt.flags |= EAP_TTLS_SUPPORTED_VERSION;
+                       break;
+               case EAP_TNC:
+                       pkt.flags |= EAP_TNC_SUPPORTED_VERSION;
+                       break;
+               case EAP_PEAP:
+                       pkt.flags |= EAP_PEAP_SUPPORTED_VERSION;
+                       break;
+               default:
+                       break;
+       }
        DBG2(DBG_TLS, "sending %N acknowledgement packet",
                 eap_type_names, this->type);
        return chunk_clone(chunk_from_thing(pkt));
@@ -251,17 +323,32 @@ METHOD(tls_eap_t, process, status_t,
        eap_tls_packet_t *pkt;
        status_t status;
 
+       if (this->max_msg_count && ++this->processed > this->max_msg_count)
+       {
+               DBG1(DBG_TLS, "%N packet count exceeded (%d > %d)",
+                        eap_type_names, this->type,
+                        this->processed, this->max_msg_count);
+               return FAILED;
+       }
+
        pkt = (eap_tls_packet_t*)in.ptr;
-       if (in.len < sizeof(eap_tls_packet_t) ||
-               untoh16(&pkt->length) != in.len)
+       if (in.len < sizeof(eap_tls_packet_t) || untoh16(&pkt->length) != in.len)
        {
-               DBG1(DBG_IKE, "invalid %N packet length",
-                        eap_type_names, this->type);
+               DBG1(DBG_TLS, "invalid %N packet length", eap_type_names, this->type);
                return FAILED;
        }
+
+       /* update EAP identifier */
+       if (!this->is_server)
+       {
+               this->identifier = pkt->identifier;
+       }
+       DBG3(DBG_TLS, "%N payload %B", eap_type_names, this->type, &in);
+
        if (pkt->flags & EAP_TLS_START)
        {
-               if (this->type == EAP_TTLS || this->type == EAP_TNC)
+               if (this->type == EAP_TTLS || this->type == EAP_TNC ||
+                       this->type == EAP_PEAP)
                {
                        DBG1(DBG_TLS, "%N version is v%u", eap_type_names, this->type,
                                 pkt->flags & EAP_TTLS_VERSION);
@@ -273,30 +360,34 @@ METHOD(tls_eap_t, process, status_t,
                {
                        DBG2(DBG_TLS, "received %N acknowledgement packet",
                                 eap_type_names, this->type);
-                       status = build_pkt(this, pkt->identifier, out);
-                       if (status == INVALID_STATE &&
-                               this->tls->is_complete(this->tls))
+                       status = build_pkt(this, out);
+                       if (status == INVALID_STATE && this->tls->is_complete(this->tls))
                        {
                                return SUCCESS;
                        }
                        return status;
                }
                status = process_pkt(this, pkt);
-               if (status != NEED_MORE)
+               switch (status)
                {
-                       return status;
+                       case NEED_MORE:
+                               break;
+                       case SUCCESS:
+                               return this->tls->is_complete(this->tls) ? SUCCESS : FAILED;
+                       default:
+                               return status;
                }
        }
-       status = build_pkt(this, pkt->identifier, out);
+       status = build_pkt(this, out);
        switch (status)
        {
                case INVALID_STATE:
-                       *out = create_ack(this, pkt->identifier);
+                       *out = create_ack(this);
                        return NEED_MORE;
                case FAILED:
                        if (!this->is_server)
                        {
-                               *out = create_ack(this, pkt->identifier);
+                               *out = create_ack(this);
                                return NEED_MORE;
                        }
                        return FAILED;
@@ -311,6 +402,18 @@ METHOD(tls_eap_t, get_msk, chunk_t,
        return this->tls->get_eap_msk(this->tls);
 }
 
+METHOD(tls_eap_t, get_identifier, u_int8_t,
+       private_tls_eap_t *this)
+{
+       return this->identifier;
+}
+
+METHOD(tls_eap_t, set_identifier, void,
+       private_tls_eap_t *this, u_int8_t identifier)
+{
+       this->identifier = identifier;
+}
+
 METHOD(tls_eap_t, destroy, void,
        private_tls_eap_t *this)
 {
@@ -321,23 +424,42 @@ METHOD(tls_eap_t, destroy, void,
 /**
  * See header
  */
-tls_eap_t *tls_eap_create(eap_type_t type, tls_t *tls, size_t frag_size)
+tls_eap_t *tls_eap_create(eap_type_t type, tls_t *tls, size_t frag_size,
+                                                 int max_msg_count, bool include_length)
 {
        private_tls_eap_t *this;
 
+       if (!tls)
+       {
+               return NULL;
+       }
+
        INIT(this,
                .public = {
                        .initiate = _initiate,
                        .process = _process,
                        .get_msk = _get_msk,
+                       .get_identifier = _get_identifier,
+                       .set_identifier = _set_identifier,
                        .destroy = _destroy,
                },
                .type = type,
                .is_server = tls->is_server(tls),
                .first_fragment = TRUE,
                .frag_size = frag_size,
+               .max_msg_count = max_msg_count,
+               .include_length = include_length,
                .tls = tls,
        );
 
+       if (this->is_server)
+       {
+               do
+               {       /* start with non-zero random identifier */
+                       this->identifier = random();
+               }
+               while (!this->identifier);
+       }
+
        return &this->public;
 }