Implemented TLS Alert handling
[strongswan.git] / src / libtls / tls_fragmentation.c
index 06e1bcb..d69ef39 100644 (file)
 typedef struct private_tls_fragmentation_t private_tls_fragmentation_t;
 
 /**
+ * Alert state
+ */
+typedef enum {
+       /* no alert received/sent */
+       ALERT_NONE,
+       /* currently sending an alert */
+       ALERT_SENDING,
+       /* alert sent and out */
+       ALERT_SENT,
+} alert_state_t;
+
+/**
  * Private data of an tls_fragmentation_t object.
  */
 struct private_tls_fragmentation_t {
@@ -37,6 +49,16 @@ struct private_tls_fragmentation_t {
        tls_handshake_t *handshake;
 
        /**
+        * TLS alert handler
+        */
+       tls_alert_t *alert;
+
+       /**
+        * State of alert handling
+        */
+       alert_state_t state;
+
+       /**
         * Handshake input buffer
         */
        chunk_t input;
@@ -73,6 +95,23 @@ struct private_tls_fragmentation_t {
 #define MAX_TLS_HANDSHAKE_LEN 65536
 
 /**
+ * Process a TLS alert
+ */
+static status_t process_alert(private_tls_fragmentation_t *this,
+                                                         tls_reader_t *reader)
+{
+       u_int8_t level, description;
+
+       if (!reader->read_uint8(reader, &level) ||
+               !reader->read_uint8(reader, &description))
+       {
+               this->alert->add(this->alert, TLS_FATAL, TLS_DECODE_ERROR);
+               return NEED_MORE;
+       }
+       return this->alert->process(this->alert, level, description);
+}
+
+/**
  * Process TLS handshake protocol data
  */
 static status_t process_handshake(private_tls_fragmentation_t *this,
@@ -89,7 +128,8 @@ static status_t process_handshake(private_tls_fragmentation_t *this,
                if (reader->remaining(reader) > MAX_TLS_FRAGMENT_LEN)
                {
                        DBG1(DBG_TLS, "TLS fragment has invalid length");
-                       return FAILED;
+                       this->alert->add(this->alert, TLS_FATAL, TLS_DECODE_ERROR);
+                       return NEED_MORE;
                }
 
                if (this->input.len == 0)
@@ -97,13 +137,16 @@ static status_t process_handshake(private_tls_fragmentation_t *this,
                        if (!reader->read_uint8(reader, &type) ||
                                !reader->read_uint24(reader, &len))
                        {
-                               return FAILED;
+                               DBG1(DBG_TLS, "TLS handshake header invalid");
+                               this->alert->add(this->alert, TLS_FATAL, TLS_DECODE_ERROR);
+                               return NEED_MORE;
                        }
                        this->type = type;
                        if (len > MAX_TLS_HANDSHAKE_LEN)
                        {
                                DBG1(DBG_TLS, "TLS handshake message exceeds maximum length");
-                               return FAILED;
+                               this->alert->add(this->alert, TLS_FATAL, TLS_DECODE_ERROR);
+                               return NEED_MORE;
                        }
                        chunk_free(&this->input);
                        this->inpos = 0;
@@ -116,7 +159,9 @@ static status_t process_handshake(private_tls_fragmentation_t *this,
                len = min(this->input.len - this->inpos, reader->remaining(reader));
                if (!reader->read_data(reader, len, &data))
                {
-                       return FAILED;
+                       DBG1(DBG_TLS, "TLS fragment has invalid length");
+                       this->alert->add(this->alert, TLS_FATAL, TLS_DECODE_ERROR);
+                       return NEED_MORE;
                }
                memcpy(this->input.ptr + this->inpos, data.ptr, len);
                this->inpos += len;
@@ -151,12 +196,14 @@ static status_t process_application(private_tls_fragmentation_t *this,
                if (reader->remaining(reader) > MAX_TLS_FRAGMENT_LEN)
                {
                        DBG1(DBG_TLS, "TLS fragment has invalid length");
-                       return FAILED;
+                       this->alert->add(this->alert, TLS_FATAL, TLS_DECODE_ERROR);
+                       return NEED_MORE;
                }
                status = this->application->process(this->application, reader);
                if (status != NEED_MORE)
                {
-                       return status;
+                       this->alert->add(this->alert, TLS_FATAL, TLS_CLOSE_NOTIFY);
+                       return NEED_MORE;
                }
        }
        return NEED_MORE;
@@ -168,6 +215,15 @@ METHOD(tls_fragmentation_t, process, status_t,
        tls_reader_t *reader;
        status_t status;
 
+       switch (this->state)
+       {
+               case ALERT_SENDING:
+               case ALERT_SENT:
+                       /* don't accept more input, fatal error ocurred */
+                       return NEED_MORE;
+               case ALERT_NONE:
+                       break;
+       }
        reader = tls_reader_create(data);
        switch (type)
        {
@@ -180,8 +236,7 @@ METHOD(tls_fragmentation_t, process, status_t,
                        status = FAILED;
                        break;
                case TLS_ALERT:
-                       /* TODO: handle Alert */
-                       status = FAILED;
+                       status = process_alert(this, reader);
                        break;
                case TLS_HANDSHAKE:
                        status = process_handshake(this, reader);
@@ -198,6 +253,29 @@ METHOD(tls_fragmentation_t, process, status_t,
        return status;
 }
 
+/**
+ * Check if alerts are pending
+ */
+static bool check_alerts(private_tls_fragmentation_t *this, chunk_t *data)
+{
+       tls_alert_level_t level;
+       tls_alert_desc_t desc;
+       tls_writer_t *writer;
+
+       if (this->alert->get(this->alert, &level, &desc))
+       {
+               writer = tls_writer_create(2);
+
+               writer->write_uint8(writer, level);
+               writer->write_uint8(writer, desc);
+
+               *data = chunk_clone(writer->get_buf(writer));
+               writer->destroy(writer);
+               return TRUE;
+       }
+       return FALSE;
+}
+
 METHOD(tls_fragmentation_t, build, status_t,
        private_tls_fragmentation_t *this, tls_content_type_t *type, chunk_t *data)
 {
@@ -206,6 +284,22 @@ METHOD(tls_fragmentation_t, build, status_t,
        tls_writer_t *writer, *msg;
        status_t status = INVALID_STATE;
 
+       switch (this->state)
+       {
+               case ALERT_SENDING:
+                       this->state = ALERT_SENT;
+                       return INVALID_STATE;
+               case ALERT_SENT:
+                       return FAILED;
+               case ALERT_NONE:
+                       break;
+       }
+       if (check_alerts(this, data))
+       {
+               this->state = ALERT_SENDING;
+               *type = TLS_ALERT;
+               return NEED_MORE;
+       }
        if (this->handshake->cipherspec_changed(this->handshake))
        {
                *type = TLS_CHANGE_CIPHER_SPEC;
@@ -227,6 +321,16 @@ METHOD(tls_fragmentation_t, build, status_t,
                                        *type = TLS_APPLICATION_DATA;
                                        this->output = chunk_clone(msg->get_buf(msg));
                                }
+                               else if (status != NEED_MORE)
+                               {
+                                       this->alert->add(this->alert, TLS_FATAL, TLS_CLOSE_NOTIFY);
+                                       if (check_alerts(this, data))
+                                       {
+                                               this->state = ALERT_SENDING;
+                                               *type = TLS_ALERT;
+                                               return NEED_MORE;
+                                       }
+                               }
                        }
                }
                else
@@ -290,7 +394,7 @@ METHOD(tls_fragmentation_t, destroy, void,
  * See header
  */
 tls_fragmentation_t *tls_fragmentation_create(tls_handshake_t *handshake,
-                                                                                         tls_application_t *application)
+                                                       tls_alert_t *alert, tls_application_t *application)
 {
        private_tls_fragmentation_t *this;
 
@@ -301,6 +405,8 @@ tls_fragmentation_t *tls_fragmentation_create(tls_handshake_t *handshake,
                        .destroy = _destroy,
                },
                .handshake = handshake,
+               .alert = alert,
+               .state = ALERT_NONE,
                .application = application,
        );