Refactored common used operations into TLS crypto helper
[strongswan.git] / src / charon / plugins / eap_tls / tls / tls_crypto.c
index 829d29a..719003e 100644 (file)
@@ -50,6 +50,11 @@ struct private_tls_crypto_t {
        tls_t *tls;
 
        /**
+        * All handshake data concatentated
+        */
+       chunk_t handshake;
+
+       /**
         * Connection state TLS PRF
         */
        tls_prf_t *prf;
@@ -83,6 +88,11 @@ struct private_tls_crypto_t {
         * IV for output decryption, if < TLSv1.2
         */
        chunk_t iv_out;
+
+       /**
+        * EAP-TLS MSK
+        */
+       chunk_t msk;
 };
 
 typedef struct {
@@ -341,7 +351,128 @@ METHOD(tls_crypto_t, select_cipher_suite, tls_cipher_suite_t,
        return 0;
 }
 
-METHOD(tls_crypto_t, derive_master_secret, void,
+METHOD(tls_crypto_t, append_handshake, void,
+       private_tls_crypto_t *this, tls_handshake_type_t type, chunk_t data)
+{
+       u_int32_t header;
+
+       /* reconstruct handshake header */
+       header = htonl(data.len | (type << 24));
+       this->handshake = chunk_cat("mcc", this->handshake,
+                                                               chunk_from_thing(header), data);
+}
+
+/**
+ * Create a hash of the stored handshake data
+ */
+static bool hash_handshake(private_tls_crypto_t *this, chunk_t *hash)
+{
+       if (this->tls->get_version(this->tls) >= TLS_1_2)
+       {
+               hasher_t *hasher;
+               suite_algs_t *alg;
+
+               alg = find_suite(this->suite);
+               if (!alg)
+               {
+                       return FALSE;
+               }
+               hasher = lib->crypto->create_hasher(lib->crypto, alg->hash);
+               if (!hasher)
+               {
+                       DBG1(DBG_IKE, "%N not supported", hash_algorithm_names, alg->hash);
+                       return FALSE;
+               }
+               hasher->allocate_hash(hasher, this->handshake, hash);
+               hasher->destroy(hasher);
+       }
+       else
+       {
+               hasher_t *md5, *sha1;
+               char buf[HASH_SIZE_MD5 + HASH_SIZE_SHA1];
+
+               md5 = lib->crypto->create_hasher(lib->crypto, HASH_MD5);
+               if (!md5)
+               {
+                       DBG1(DBG_IKE, "%N not supported", hash_algorithm_names, HASH_MD5);
+                       return FALSE;
+               }
+               md5->get_hash(md5, this->handshake, buf);
+               md5->destroy(md5);
+               sha1 = lib->crypto->create_hasher(lib->crypto, HASH_SHA1);
+               if (!sha1)
+               {
+                       DBG1(DBG_IKE, "%N not supported", hash_algorithm_names, HASH_SHA1);
+                       return FALSE;
+               }
+               sha1->get_hash(sha1, this->handshake, buf + HASH_SIZE_MD5);
+               sha1->destroy(sha1);
+
+               *hash = chunk_clone(chunk_from_thing(buf));
+       }
+       return TRUE;
+}
+
+METHOD(tls_crypto_t, sign_handshake, bool,
+       private_tls_crypto_t *this, private_key_t *key, chunk_t *sig)
+{
+       if (this->tls->get_version(this->tls) >= TLS_1_2)
+       {
+               u_int16_t length;
+               u_int8_t hash_alg;
+               u_int8_t sig_alg;
+
+               if (!key->sign(key, SIGN_RSA_EMSA_PKCS1_SHA1, this->handshake, sig))
+               {
+                       return FALSE;
+               }
+               /* TODO: signature scheme to hashsign algorithm mapping */
+               hash_alg = 2; /* sha1 */
+               sig_alg = 1; /* RSA */
+               length = htons(sig->len);
+               *sig = chunk_cat("cccm", chunk_from_thing(hash_alg),
+                                       chunk_from_thing(sig_alg), chunk_from_thing(length), *sig);
+       }
+       else
+       {
+               u_int16_t length;
+               chunk_t hash;
+
+               if (!hash_handshake(this, &hash))
+               {
+                       return FALSE;
+               }
+               if (!key->sign(key, SIGN_RSA_EMSA_PKCS1_NULL, hash, sig))
+               {
+                       free(hash.ptr);
+                       return FALSE;
+               }
+               free(hash.ptr);
+               length = htons(sig->len);
+               *sig = chunk_cat("cm", chunk_from_thing(length), *sig);
+       }
+       return TRUE;
+}
+
+METHOD(tls_crypto_t, calculate_finished, bool,
+       private_tls_crypto_t *this, char *label, char out[12])
+{
+       chunk_t seed;
+
+       if (!this->prf)
+       {
+               return FALSE;
+       }
+       if (!hash_handshake(this, &seed))
+       {
+               return FALSE;
+       }
+       this->prf->get_bytes(this->prf, label, seed, 12, out);
+       free(seed.ptr);
+       return TRUE;
+}
+
+METHOD(tls_crypto_t, derive_secrets, void,
        private_tls_crypto_t *this, chunk_t premaster,
        chunk_t client_random, chunk_t server_random)
 {
@@ -363,7 +494,7 @@ METHOD(tls_crypto_t, derive_master_secret, void,
        if (this->crypter_out)
        {
                eks = this->crypter_out->get_key_size(this->crypter_out);
-               if (this->tls->get_version(this->tls) < TLS_1_2)
+               if (this->tls->get_version(this->tls) < TLS_1_1)
                {
                        ivs = this->crypter_out->get_block_size(this->crypter_out);
                }
@@ -442,10 +573,22 @@ METHOD(tls_crypto_t, change_cipher, void,
        }
 }
 
-METHOD(tls_crypto_t, get_prf, tls_prf_t*,
+METHOD(tls_crypto_t, derive_eap_msk, void,
+       private_tls_crypto_t *this, chunk_t client_random, chunk_t server_random)
+{
+       chunk_t seed;
+
+       seed = chunk_cata("cc", client_random, server_random);
+       free(this->msk.ptr);
+       this->msk = chunk_alloc(64);
+       this->prf->get_bytes(this->prf, "client EAP encryption", seed,
+                                                this->msk.len, this->msk.ptr);
+}
+
+METHOD(tls_crypto_t, get_eap_msk, chunk_t,
        private_tls_crypto_t *this)
 {
-       return this->prf;
+       return this->msk;
 }
 
 METHOD(tls_crypto_t, destroy, void,
@@ -457,6 +600,8 @@ METHOD(tls_crypto_t, destroy, void,
        DESTROY_IF(this->crypter_out);
        free(this->iv_in.ptr);
        free(this->iv_out.ptr);
+       free(this->handshake.ptr);
+       free(this->msk.ptr);
        DESTROY_IF(this->prf);
        free(this->suites);
        free(this);
@@ -473,9 +618,13 @@ tls_crypto_t *tls_crypto_create(tls_t *tls)
                .public = {
                        .get_cipher_suites = _get_cipher_suites,
                        .select_cipher_suite = _select_cipher_suite,
-                       .derive_master_secret = _derive_master_secret,
+                       .append_handshake = _append_handshake,
+                       .sign_handshake = _sign_handshake,
+                       .calculate_finished = _calculate_finished,
+                       .derive_secrets = _derive_secrets,
                        .change_cipher = _change_cipher,
-                       .get_prf = _get_prf,
+                       .derive_eap_msk = _derive_eap_msk,
+                       .get_eap_msk = _get_eap_msk,
                        .destroy = _destroy,
                },
                .tls = tls,