wolfssl: Fixes, code style changes and some refactorings
[strongswan.git] / src / libstrongswan / plugins / wolfssl / wolfssl_aead.c
index 1ed0c71..2ea7c94 100644 (file)
 #define CCM_SALT_LEN   3
 #define CCM_NONCE_LEN  (CCM_SALT_LEN + IV_LEN)
 
-#if !defined(NO_AES) && defined(HAVE_AESGCM)
-#define MAX_NONCE_LEN  GCM_NONCE_LEN
-#define MAX_SALT_LEN   GCM_SALT_LEN
-#elif defined(HAVE_CHACHA) && defined(HAVE_POLY1305)
-#define MAX_NONCE_LEN  12
-#define MAX_SALT_LEN   4
-#elif !defined(NO_AES) && defined(HAVE_AESCCM)
-#define MAX_NONCE_LEN  CCM_NONCE_LEN
-#define MAX_SALT_LEN   GCM_SALT_LEN
-#endif
-
 typedef struct private_aead_t private_aead_t;
 
 /**
@@ -71,12 +60,7 @@ struct private_aead_t {
        /**
         * Salt value
         */
-       char salt[MAX_SALT_LEN];
-
-       /**
-        * Length of the salt
-        */
-       size_t salt_len;
+       chunk_t salt;
 
        /**
         * Size of the integrity check value
@@ -84,11 +68,6 @@ struct private_aead_t {
        size_t icv_size;
 
        /**
-        * Size of the IV
-        */
-       size_t iv_size;
-
-       /**
         * IV generator
         */
        iv_gen_t *iv_gen;
@@ -96,8 +75,7 @@ struct private_aead_t {
        /**
         * The cipher to use
         */
-       union
-       {
+       union {
 #if !defined(NO_AES) && (defined(HAVE_AESGCM) || defined(HAVE_AESCCM))
                Aes aes;
 #endif
@@ -109,15 +87,14 @@ struct private_aead_t {
        encryption_algorithm_t alg;
 };
 
-
 METHOD(aead_t, encrypt, bool,
        private_aead_t *this, chunk_t plain, chunk_t assoc, chunk_t iv,
        chunk_t *encrypted)
 {
-       bool success = FALSE;
-       int ret = 0;
+       chunk_t nonce;
        u_char *out;
-       u_char nonce[MAX_NONCE_LEN];
+       bool success = FALSE;
+       int ret;
 
        out = plain.ptr;
        if (encrypted)
@@ -126,8 +103,7 @@ METHOD(aead_t, encrypt, bool,
                out = encrypted->ptr;
        }
 
-       memcpy(nonce, this->salt, this->salt_len);
-       memcpy(nonce + this->salt_len, iv.ptr, IV_LEN);
+       nonce = chunk_cata("cc", this->salt, iv);
 
        switch (this->alg)
        {
@@ -140,8 +116,8 @@ METHOD(aead_t, encrypt, bool,
                        if (ret == 0)
                        {
                                ret = wc_AesGcmEncrypt(&this->cipher.aes, out, plain.ptr,
-                                       plain.len, nonce, GCM_NONCE_LEN, out + plain.len,
-                                       this->icv_size, assoc.ptr, assoc.len);
+                                               plain.len, nonce.ptr, GCM_NONCE_LEN, out + plain.len,
+                                               this->icv_size, assoc.ptr, assoc.len);
                        }
                        success = (ret == 0);
                        break;
@@ -150,23 +126,27 @@ METHOD(aead_t, encrypt, bool,
                case ENCR_AES_CCM_ICV8:
                case ENCR_AES_CCM_ICV12:
                case ENCR_AES_CCM_ICV16:
-                       if (plain.ptr == NULL && plain.len == 0)
-                               plain.ptr = nonce;
+                       /* wc_AesCcmEncrypt fails if the pointer is NULL */
+                       if (!plain.ptr && !plain.len)
+                       {
+                               plain.ptr = nonce.ptr;
+                       }
                        ret = wc_AesCcmSetKey(&this->cipher.aes, this->key.ptr,
                                                                  this->key.len);
                        if (ret == 0)
                        {
                                ret = wc_AesCcmEncrypt(&this->cipher.aes, out, plain.ptr,
-                                       plain.len, nonce, CCM_NONCE_LEN, out + plain.len,
-                                       this->icv_size, assoc.ptr, assoc.len);
+                                               plain.len, nonce.ptr, CCM_NONCE_LEN, out + plain.len,
+                                               this->icv_size, assoc.ptr, assoc.len);
                        }
                        success = (ret == 0);
                        break;
 #endif
 #if defined(HAVE_CHACHA) && defined(HAVE_POLY1305)
                case ENCR_CHACHA20_POLY1305:
-                       ret = wc_ChaCha20Poly1305_Encrypt(this->key.ptr, nonce, assoc.ptr,
-                                       assoc.len, plain.ptr, plain.len, out, out + plain.len);
+                       ret = wc_ChaCha20Poly1305_Encrypt(this->key.ptr, nonce.ptr,
+                                               assoc.ptr, assoc.len, plain.ptr, plain.len, out,
+                                               out + plain.len);
                        success = (ret == 0);
                        break;
 #endif
@@ -174,6 +154,7 @@ METHOD(aead_t, encrypt, bool,
                        break;
        }
 
+       memwipe(nonce.ptr, nonce.len);
        return success;
 }
 
@@ -181,10 +162,10 @@ METHOD(aead_t, decrypt, bool,
        private_aead_t *this, chunk_t encrypted, chunk_t assoc, chunk_t iv,
        chunk_t *plain)
 {
+       chunk_t nonce;
+       u_char *out;
        bool success = FALSE;
        int ret = 0;
-       u_char *out;
-       u_char nonce[MAX_NONCE_LEN];
 
        if (encrypted.len < this->icv_size)
        {
@@ -199,8 +180,7 @@ METHOD(aead_t, decrypt, bool,
                out = plain->ptr;
        }
 
-       memcpy(nonce, this->salt, this->salt_len);
-       memcpy(nonce + this->salt_len, iv.ptr, IV_LEN);
+       nonce = chunk_cata("cc", this->salt, iv);
 
        switch (this->alg)
        {
@@ -209,13 +189,13 @@ METHOD(aead_t, decrypt, bool,
                case ENCR_AES_GCM_ICV12:
                case ENCR_AES_GCM_ICV16:
                        ret = wc_AesGcmSetKey(&this->cipher.aes, this->key.ptr,
-                                 this->key.len);
+                                                                 this->key.len);
                        if (ret == 0)
                        {
                                ret = wc_AesGcmDecrypt(&this->cipher.aes, out, encrypted.ptr,
-                                       encrypted.len, nonce, GCM_NONCE_LEN,
-                                       encrypted.ptr + encrypted.len, this->icv_size, assoc.ptr,
-                                       assoc.len);
+                                                       encrypted.len, nonce.ptr, GCM_NONCE_LEN,
+                                                       encrypted.ptr + encrypted.len, this->icv_size,
+                                                       assoc.ptr, assoc.len);
                        }
                        success = (ret == 0);
                        break;
@@ -224,27 +204,32 @@ METHOD(aead_t, decrypt, bool,
                case ENCR_AES_CCM_ICV8:
                case ENCR_AES_CCM_ICV12:
                case ENCR_AES_CCM_ICV16:
-                       if (encrypted.ptr == NULL && encrypted.len == 0)
-                               encrypted.ptr = nonce;
-                       if (out == NULL && encrypted.len == 0)
-                               out = nonce;
+                       /* wc_AesCcmDecrypt() fails if the pointers are NULL */
+                       if (!encrypted.ptr && !encrypted.len)
+                       {
+                               encrypted.ptr = nonce.ptr;
+                       }
+                       if (!out && !encrypted.len)
+                       {
+                               out = nonce.ptr;
+                       }
                        ret = wc_AesCcmSetKey(&this->cipher.aes, this->key.ptr,
-                                 this->key.len);
+                                                                 this->key.len);
                        if (ret == 0)
                        {
                                ret = wc_AesCcmDecrypt(&this->cipher.aes, out, encrypted.ptr,
-                                       encrypted.len, nonce, CCM_NONCE_LEN,
-                                       encrypted.ptr + encrypted.len, this->icv_size, assoc.ptr,
-                                       assoc.len);
+                                                       encrypted.len, nonce.ptr, CCM_NONCE_LEN,
+                                                       encrypted.ptr + encrypted.len, this->icv_size,
+                                                       assoc.ptr, assoc.len);
                        }
                        success = (ret == 0);
                        break;
 #endif
 #if defined(HAVE_CHACHA) && defined(HAVE_POLY1305)
                case ENCR_CHACHA20_POLY1305:
-                       ret = wc_ChaCha20Poly1305_Decrypt(this->key.ptr, nonce, assoc.ptr,
-                                       assoc.len, encrypted.ptr, encrypted.len,
-                                       encrypted.ptr + encrypted.len, out);
+                       ret = wc_ChaCha20Poly1305_Decrypt(this->key.ptr, nonce.ptr,
+                                                       assoc.ptr, assoc.len, encrypted.ptr, encrypted.len,
+                                                       encrypted.ptr + encrypted.len, out);
                        success = (ret == 0);
                        break;
 #endif
@@ -252,13 +237,14 @@ METHOD(aead_t, decrypt, bool,
                        break;
        }
 
+       memwipe(nonce.ptr, nonce.len);
        return success;
 }
 
 METHOD(aead_t, get_block_size, size_t,
        private_aead_t *this)
 {
-       /* All AEAD algorithms are streaming. */
+       /* all AEAD algorithms are streaming */
        return 1;
 }
 
@@ -283,7 +269,7 @@ METHOD(aead_t, get_iv_gen, iv_gen_t*,
 METHOD(aead_t, get_key_size, size_t,
        private_aead_t *this)
 {
-       return this->key.len + this->salt_len;
+       return this->key.len + this->salt.len;
 }
 
 METHOD(aead_t, set_key, bool,
@@ -293,7 +279,7 @@ METHOD(aead_t, set_key, bool,
        {
                return FALSE;
        }
-       memcpy(this->salt, key.ptr + key.len - this->salt_len, this->salt_len);
+       memcpy(this->salt.ptr, key.ptr + key.len - this->salt.len, this->salt.len);
        memcpy(this->key.ptr, key.ptr, this->key.len);
        return TRUE;
 }
@@ -302,6 +288,7 @@ METHOD(aead_t, destroy, void,
        private_aead_t *this)
 {
        chunk_clear(&this->key);
+       chunk_clear(&this->salt);
        switch (this->alg)
        {
 #if !defined(NO_AES) && defined(HAVE_AESGCM)
@@ -318,10 +305,6 @@ METHOD(aead_t, destroy, void,
                        wc_AesFree(&this->cipher.aes);
                        break;
 #endif
-#if defined(HAVE_CHACHA) && defined(HAVE_POLY1305)
-               case ENCR_CHACHA20_POLY1305:
-                       break;
-#endif
                default:
                        break;
        }
@@ -336,6 +319,7 @@ aead_t *wolfssl_aead_create(encryption_algorithm_t algo,
                                                        size_t key_size, size_t salt_size)
 {
        private_aead_t *this;
+       size_t expected_salt_size;
 
        INIT(this,
                .public = {
@@ -404,8 +388,7 @@ aead_t *wolfssl_aead_create(encryption_algorithm_t algo,
                                case 16:
                                case 24:
                                case 32:
-                                       this->iv_size = GCM_NONCE_LEN;
-                                       this->salt_len = GCM_SALT_LEN;
+                                       expected_salt_size = GCM_SALT_LEN;
                                        if (wc_AesInit(&this->cipher.aes, NULL, INVALID_DEVID) != 0)
                                        {
                                                DBG1(DBG_LIB, "AES Init failed, aead create failed");
@@ -431,8 +414,7 @@ aead_t *wolfssl_aead_create(encryption_algorithm_t algo,
                                case 16:
                                case 24:
                                case 32:
-                                       this->iv_size = CCM_NONCE_LEN;
-                                       this->salt_len = CCM_SALT_LEN;
+                                       expected_salt_size = CCM_SALT_LEN;
                                        if (wc_AesInit(&this->cipher.aes, NULL, INVALID_DEVID) != 0)
                                        {
                                                DBG1(DBG_LIB, "AES Init failed, aead create failed");
@@ -454,8 +436,7 @@ aead_t *wolfssl_aead_create(encryption_algorithm_t algo,
                                        key_size = 32;
                                        /* FALL */
                                case 32:
-                                       this->iv_size = CHACHA_IV_BYTES;
-                                       this->salt_len = 4;
+                                       expected_salt_size = 4;
                                        break;
                                default:
                                        free(this);
@@ -468,7 +449,7 @@ aead_t *wolfssl_aead_create(encryption_algorithm_t algo,
                        return NULL;
        }
 
-       if (salt_size && salt_size != this->salt_len)
+       if (salt_size && salt_size != expected_salt_size)
        {
                /* currently not supported */
                free(this);
@@ -476,6 +457,7 @@ aead_t *wolfssl_aead_create(encryption_algorithm_t algo,
        }
 
        this->key = chunk_alloc(key_size);
+       this->salt = chunk_alloc(expected_salt_size);
        this->iv_gen = iv_gen_seq_create();
 
        return &this->public;