Add a return value to tls_prf_t.get_bytes()
authorMartin Willi <martin@revosec.ch>
Fri, 6 Jul 2012 07:49:16 +0000 (09:49 +0200)
committerMartin Willi <martin@revosec.ch>
Mon, 16 Jul 2012 12:53:33 +0000 (14:53 +0200)
src/libtls/tls_crypto.c
src/libtls/tls_prf.c
src/libtls/tls_prf.h

index b8df3de..5f7002c 100644 (file)
@@ -1462,7 +1462,11 @@ METHOD(tls_crypto_t, calculate_finished, bool,
        {
                return FALSE;
        }
-       this->prf->get_bytes(this->prf, label, seed, 12, out);
+       if (!this->prf->get_bytes(this->prf, label, seed, 12, out))
+       {
+               free(seed.ptr);
+               return FALSE;
+       }
        free(seed.ptr);
        return TRUE;
 }
@@ -1470,7 +1474,7 @@ METHOD(tls_crypto_t, calculate_finished, bool,
 /**
  * Derive master secret from premaster, optionally save session
  */
-static void derive_master(private_tls_crypto_t *this, chunk_t premaster,
+static bool derive_master(private_tls_crypto_t *this, chunk_t premaster,
                                                  chunk_t session, identification_t *id,
                                                  chunk_t client_random, chunk_t server_random)
 {
@@ -1480,16 +1484,20 @@ static void derive_master(private_tls_crypto_t *this, chunk_t premaster,
        /* derive master secret */
        seed = chunk_cata("cc", client_random, server_random);
        this->prf->set_key(this->prf, premaster);
-       this->prf->get_bytes(this->prf, "master secret", seed,
-                                                sizeof(master), master);
-
+       if (!this->prf->get_bytes(this->prf, "master secret", seed,
+                                                sizeof(master), master))
+       {
+               return FALSE;
+       }
        this->prf->set_key(this->prf, chunk_from_thing(master));
+
        if (this->cache && session.len)
        {
                this->cache->create(this->cache, session, id, chunk_from_thing(master),
                                                        this->suite);
        }
        memwipe(master, sizeof(master));
+       return TRUE;
 }
 
 /**
@@ -1513,7 +1521,11 @@ static bool expand_keys(private_tls_crypto_t *this,
        }
        seed = chunk_cata("cc", server_random, client_random);
        block = chunk_alloca((mks + eks + ivs) * 2);
-       this->prf->get_bytes(this->prf, "key expansion", seed, block.len, block.ptr);
+       if (!this->prf->get_bytes(this->prf, "key expansion", seed,
+                                                         block.len, block.ptr))
+       {
+               return FALSE;
+       }
 
        /* signer keys */
        client_write = chunk_create(block.ptr, mks);
@@ -1580,8 +1592,11 @@ static bool expand_keys(private_tls_crypto_t *this,
        {
                seed = chunk_cata("cc", client_random, server_random);
                this->msk = chunk_alloc(64);
-               this->prf->get_bytes(this->prf, this->msk_label, seed,
-                                                        this->msk.len, this->msk.ptr);
+               if (!this->prf->get_bytes(this->prf, this->msk_label, seed,
+                                                                 this->msk.len, this->msk.ptr))
+               {
+                       return FALSE;
+               }
        }
        return TRUE;
 }
@@ -1590,8 +1605,9 @@ METHOD(tls_crypto_t, derive_secrets, bool,
        private_tls_crypto_t *this, chunk_t premaster, chunk_t session,
        identification_t *id, chunk_t client_random, chunk_t server_random)
 {
-       derive_master(this, premaster, session, id, client_random, server_random);
-       return expand_keys(this, client_random, server_random);
+       return derive_master(this, premaster, session, id,
+                                                client_random, server_random) &&
+                  expand_keys(this, client_random, server_random);
 }
 
 METHOD(tls_crypto_t, resume_session, tls_cipher_suite_t,
index f181d01..0ef4418 100644 (file)
@@ -42,7 +42,7 @@ METHOD(tls_prf_t, set_key12, void,
 /**
  * The P_hash function as in TLS 1.0/1.2
  */
-static void p_hash(prf_t *prf, char *label, chunk_t seed, size_t block_size,
+static bool p_hash(prf_t *prf, char *label, chunk_t seed, size_t block_size,
                                   size_t bytes, char *out)
 {
        char buf[block_size], abuf[block_size];
@@ -71,14 +71,15 @@ static void p_hash(prf_t *prf, char *label, chunk_t seed, size_t block_size,
                out += block_size;
                bytes -= block_size;
        }
+       return TRUE;
 }
 
-METHOD(tls_prf_t, get_bytes12, void,
+METHOD(tls_prf_t, get_bytes12, bool,
        private_tls_prf12_t *this, char *label, chunk_t seed,
        size_t bytes, char *out)
 {
-       p_hash(this->prf, label, seed, this->prf->get_block_size(this->prf),
-                  bytes, out);
+       return p_hash(this->prf, label, seed, this->prf->get_block_size(this->prf),
+                                 bytes, out);
 }
 
 METHOD(tls_prf_t, destroy12, void,
@@ -144,17 +145,21 @@ METHOD(tls_prf_t, set_key10, void,
        this->sha1->set_key(this->sha1, chunk_create(key.ptr + key.len - len, len));
 }
 
-METHOD(tls_prf_t, get_bytes10, void,
+METHOD(tls_prf_t, get_bytes10, bool,
        private_tls_prf10_t *this, char *label, chunk_t seed,
        size_t bytes, char *out)
 {
        char buf[bytes];
 
-       p_hash(this->md5, label, seed, this->md5->get_block_size(this->md5),
-                  bytes, out);
-       p_hash(this->sha1, label, seed, this->sha1->get_block_size(this->sha1),
-                  bytes, buf);
+       if (!p_hash(this->md5, label, seed, this->md5->get_block_size(this->md5),
+                               bytes, out) ||
+               !p_hash(this->sha1, label, seed, this->sha1->get_block_size(this->sha1),
+                               bytes, buf))
+       {
+               return FALSE;
+       }
        memxor(out, buf, bytes);
+       return TRUE;
 }
 
 METHOD(tls_prf_t, destroy10, void,
index 9fb9bc2..c78842e 100644 (file)
@@ -44,8 +44,9 @@ struct tls_prf_t {
         * @param seed          seed input value
         * @param bytes         number of bytes to get
         * @param out           buffer receiving bytes
+        * @return                      TRUE if bytes generated successfully
         */
-       void (*get_bytes)(tls_prf_t *this, char *label, chunk_t seed,
+       bool (*get_bytes)(tls_prf_t *this, char *label, chunk_t seed,
                                          size_t bytes, char *out);
 
        /**