Implemented TLS Alert handling
[strongswan.git] / src / libtls / tls_protection.c
index 90b30f9..574e691 100644 (file)
@@ -30,9 +30,9 @@ struct private_tls_protection_t {
        tls_protection_t public;
 
        /**
-        * TLS context
+        * negotiated TLS version
         */
-       tls_t *tls;
+       tls_version_t version;
 
        /**
         * Upper layer, TLS record compression
@@ -40,6 +40,11 @@ struct private_tls_protection_t {
        tls_compression_t *compression;
 
        /**
+        * TLS alert handler
+        */
+       tls_alert_t *alert;
+
+       /**
         * RNG if we generate IVs ourself
         */
        rng_t *rng;
@@ -106,6 +111,11 @@ static chunk_t sigheader(u_int32_t seq, u_int8_t type,
 METHOD(tls_protection_t, process, status_t,
        private_tls_protection_t *this, tls_content_type_t type, chunk_t data)
 {
+       if (this->alert->fatal(this->alert))
+       {       /* don't accept more input, fatal error ocurred */
+               return NEED_MORE;
+       }
+
        if (this->crypter_in)
        {
                chunk_t iv, next_iv = chunk_empty;
@@ -117,7 +127,8 @@ METHOD(tls_protection_t, process, status_t,
                        if (data.len < bs || data.len % bs)
                        {
                                DBG1(DBG_TLS, "encrypted TLS record length invalid");
-                               return FAILED;
+                               this->alert->add(this->alert, TLS_FATAL, TLS_BAD_RECORD_MAC);
+                               return NEED_MORE;
                        }
                        iv = this->iv_in;
                        next_iv = chunk_clone(chunk_create(data.ptr + data.len - bs, bs));
@@ -130,7 +141,8 @@ METHOD(tls_protection_t, process, status_t,
                        if (data.len < bs || data.len % bs)
                        {
                                DBG1(DBG_TLS, "encrypted TLS record length invalid");
-                               return FAILED;
+                               this->alert->add(this->alert, TLS_FATAL, TLS_BAD_RECORD_MAC);
+                               return NEED_MORE;
                        }
                }
                this->crypter_in->decrypt(this->crypter_in, data, iv, NULL);
@@ -145,7 +157,8 @@ METHOD(tls_protection_t, process, status_t,
                if (padding_length >= data.len)
                {
                        DBG1(DBG_TLS, "invalid TLS record padding");
-                       return FAILED;
+                       this->alert->add(this->alert, TLS_FATAL, TLS_BAD_RECORD_MAC);
+                       return NEED_MORE;
                }
                data.len -= padding_length + 1;
        }
@@ -158,19 +171,20 @@ METHOD(tls_protection_t, process, status_t,
                if (data.len <= bs)
                {
                        DBG1(DBG_TLS, "TLS record too short to verify MAC");
-                       return FAILED;
+                       this->alert->add(this->alert, TLS_FATAL, TLS_BAD_RECORD_MAC);
+                       return NEED_MORE;
                }
                mac = chunk_skip(data, data.len - bs);
                data.len -= bs;
 
-               header = sigheader(this->seq_in, type,
-                                                  this->tls->get_version(this->tls), data.len);
+               header = sigheader(this->seq_in, type, this->version, data.len);
                macdata = chunk_cat("mc", header, data);
                if (!this->signer_in->verify_signature(this->signer_in, macdata, mac))
                {
                        DBG1(DBG_TLS, "TLS record MAC verification failed");
                        free(macdata.ptr);
-                       return FAILED;
+                       this->alert->add(this->alert, TLS_FATAL, TLS_BAD_RECORD_MAC);
+                       return NEED_MORE;
                }
                free(macdata.ptr);
        }
@@ -204,8 +218,7 @@ METHOD(tls_protection_t, build, status_t,
                {
                        chunk_t mac, header;
 
-                       header = sigheader(this->seq_out, *type,
-                                                          this->tls->get_version(this->tls), data->len);
+                       header = sigheader(this->seq_out, *type, this->version, data->len);
                        this->signer_out->get_signature(this->signer_out, header, NULL);
                        free(header.ptr);
                        this->signer_out->allocate_signature(this->signer_out, *data, &mac);
@@ -283,6 +296,12 @@ METHOD(tls_protection_t, set_cipher, void,
        }
 }
 
+METHOD(tls_protection_t, set_version, void,
+       private_tls_protection_t *this, tls_version_t version)
+{
+       this->version = version;
+}
+
 METHOD(tls_protection_t, destroy, void,
        private_tls_protection_t *this)
 {
@@ -293,8 +312,8 @@ METHOD(tls_protection_t, destroy, void,
 /**
  * See header
  */
-tls_protection_t *tls_protection_create(tls_t *tls,
-                                                                               tls_compression_t *compression)
+tls_protection_t *tls_protection_create(tls_compression_t *compression,
+                                                                               tls_alert_t *alert)
 {
        private_tls_protection_t *this;
 
@@ -303,9 +322,10 @@ tls_protection_t *tls_protection_create(tls_t *tls,
                        .process = _process,
                        .build = _build,
                        .set_cipher = _set_cipher,
+                       .set_version = _set_version,
                        .destroy = _destroy,
                },
-               .tls = tls,
+               .alert = alert,
                .compression = compression,
        );