Implemented ntru_trits class
[strongswan.git] / src / libstrongswan / plugins / ntru / ntru_crypto / ntru_crypto_ntru_encrypt.c
index 5271d7c..0654609 100644 (file)
@@ -40,7 +40,8 @@
 #include "ntru_crypto_ntru_encrypt_key.h"
 #include "ntru_crypto_ntru_convert.h"
 #include "ntru_crypto_ntru_poly.h"
-#include "ntru_crypto_ntru_mgftp1.h"
+#
+#include "ntru_trits.h"
 
 /* ntru_crypto_ntru_encrypt
  *
@@ -106,6 +107,9 @@ ntru_crypto_ntru_encrypt(
     uint16_t                mprime_len = 0;
     uint16_t                mod_q_mask;
     uint32_t                result = NTRU_OK;
+       ntru_trits_t           *mask;
+       uint8_t                *mask_trits;
+       chunk_t                 seed;
 
     /* check for bad parameters */
 
@@ -198,7 +202,7 @@ ntru_crypto_ntru_encrypt(
         uint8_t *ptr = tmp_buf;
 
         /* get b */
-        if (drbg->generate(drbg, params->sec_strength_len << 3,
+        if (drbg->generate(drbg, params->sec_strength_len * BITS_PER_BYTE,
                                  params->sec_strength_len, b_buf))
                {
                        result = NTRU_OK;
@@ -255,14 +259,18 @@ ntru_crypto_ntru_encrypt(
                                        r_buf, params->N, params->q,
                                        scratch_buf, ringel_buf);
 
-            /* form R mod 4 */
-            ntru_coeffs_mod4_2_octets(params->N, ringel_buf, tmp_buf);
+                       /* form R mod 4 */
+                       ntru_coeffs_mod4_2_octets(params->N, ringel_buf, tmp_buf);
+
+                       /* form mask */
+                       seed = chunk_create(tmp_buf, (params->N + 3)/4);
+                       mask = ntru_trits_create(params->N, hash_algid, seed);
+                       if (!mask)
+                       {
+                               result = NTRU_MGF1_FAIL;
+                       }
+               }
 
-            /* form mask */
-            result = ntru_mgftp1(hash_algid, params->min_MGF_hash_calls,
-                                 (params->N + 3) / 4, tmp_buf,
-                                 tmp_buf + params->N, params->N, tmp_buf);
-        }
                if (result == NTRU_OK)
                {
             uint8_t  *Mtrin_buf = tmp_buf + params->N;
@@ -296,26 +304,40 @@ ntru_crypto_ntru_encrypt(
                        }
 
             ntru_bits_2_trits(M_buf, mprime_len, Mtrin_buf);
+                       mask_trits = mask->get_trits(mask);
 
-            /* form the msg representative m' by adding Mtrin to mask, mod p */
-
-            if (params->is_product_form) {
-                for (i = 0; i < mprime_len; i++) {
-                    tmp_buf[i] = tmp_buf[i] + Mtrin_buf[i];
-                    if (tmp_buf[i] >= 3)
-                        tmp_buf[i] -= 3;
-                    if (tmp_buf[i] == 1)
-                        ++m1;
-                    else if (tmp_buf[i] == 2)
-                        --m1;
-                }
-            } else {
-                for (i = 0; i < mprime_len; i++) {
-                    tmp_buf[i] = tmp_buf[i] + Mtrin_buf[i];
-                    if (tmp_buf[i] >= 3)
-                        tmp_buf[i] -= 3;
-                }
-            }
+                       /* form the msg representative m' by adding Mtrin to mask, mod p */
+                       if (params->is_product_form)
+                       {
+                               for (i = 0; i < mprime_len; i++)
+                               {
+                                       tmp_buf[i] = mask_trits[i] + Mtrin_buf[i];
+                                       if (tmp_buf[i] >= 3)
+                                       {
+                                               tmp_buf[i] -= 3;
+                                       }
+                                       if (tmp_buf[i] == 1)
+                                       {
+                                               ++m1;
+                                       }
+                                       else if (tmp_buf[i] == 2)
+                                       {
+                                               --m1;
+                                       }
+                               }
+                       }
+                       else
+                       {
+                               for (i = 0; i < mprime_len; i++)
+                               {
+                                       tmp_buf[i] = mask_trits[i] + Mtrin_buf[i];
+                                       if (tmp_buf[i] >= 3)
+                                       {
+                                               tmp_buf[i] -= 3;
+                                       }
+                               }
+                       }
+                       mask->destroy(mask);
 
             /* check that message representative meets minimum weight
              * requirements
@@ -426,9 +448,11 @@ ntru_crypto_ntru_decrypt(
     uint16_t                i;
     bool                    decryption_ok = TRUE;
     uint32_t                result = NTRU_OK;
+       ntru_trits_t           *mask;
+       uint8_t                *mask_trits;
+       chunk_t                 seed;
 
-    /* check for bad parameters */
-
+       /* check for bad parameters */
        if (!privkey_blob || !ct || !pt_len)
        {
                return NTRU_BAD_PARAMETER;
@@ -588,38 +612,48 @@ ntru_crypto_ntru_decrypt(
                                                                                                   params->min_msg_rep_wt);
        }
 
-    /* form cR = e - cm' mod q */
-
-    for (i = 0; i < cmprime_len; i++) {
-        if (Mtrin_buf[i] == 1)
-            ringel_buf2[i] = (ringel_buf2[i] - 1) & mod_q_mask;
-        else if (Mtrin_buf[i] == 2)
-            ringel_buf2[i] = (ringel_buf2[i] + 1) & mod_q_mask;
-    }
-    if (params->is_product_form)
-        ringel_buf2[i] = (ringel_buf2[i] + m1) & mod_q_mask;
-
-
-    /* form cR mod 4 */
-
-    ntru_coeffs_mod4_2_octets(params->N, ringel_buf2, tmp_buf);
-
-    /* form mask */
+       /* form cR = e - cm' mod q */
+       for (i = 0; i < cmprime_len; i++)
+       {
+               if (Mtrin_buf[i] == 1)
+               {
+                       ringel_buf2[i] = (ringel_buf2[i] - 1) & mod_q_mask;
+               }
+               else if (Mtrin_buf[i] == 2)
+               {
+                       ringel_buf2[i] = (ringel_buf2[i] + 1) & mod_q_mask;
+               }
+       }
+       if (params->is_product_form)
+       {
+               ringel_buf2[i] = (ringel_buf2[i] + m1) & mod_q_mask;
+       }
 
-    result = ntru_mgftp1(hash_algid, params->min_MGF_hash_calls,
-                         (params->N + 3) / 4, tmp_buf,
-                         tmp_buf + params->N, params->N, tmp_buf);
+       /* form cR mod 4 */
+       ntru_coeffs_mod4_2_octets(params->N, ringel_buf2, tmp_buf);
 
-       if (result == NTRU_OK)
+       /* form mask */
+       seed = chunk_create(tmp_buf, (params->N + 3)/4);
+       mask = ntru_trits_create(params->N, hash_algid, seed);
+       if (!mask)
        {
+               result = NTRU_MGF1_FAIL;
+       }
+       else
+       {
+               mask_trits = mask->get_trits(mask);
 
-        /* form cMtrin by subtracting mask from cm', mod p */
+               /* form cMtrin by subtracting mask from cm', mod p */
+               for (i = 0; i < cmprime_len; i++)
+               {
+                       Mtrin_buf[i] = Mtrin_buf[i] - mask_trits[i];
+                       if (Mtrin_buf[i] >= 3)
+                       {
+                               Mtrin_buf[i] += 3;
+                       }
+               }
+               mask->destroy(mask);
 
-        for (i = 0; i < cmprime_len; i++) {
-            Mtrin_buf[i] = Mtrin_buf[i] - tmp_buf[i];
-            if (Mtrin_buf[i] >= 3)
-                Mtrin_buf[i] += 3;
-        }
         if (params->is_product_form)
 
             /* set the last trit to zero since that's what it was, and
@@ -897,7 +931,8 @@ ntru_crypto_ntru_encrypt_keygen(
      * as a list of indices
      */
 
-    if (drbg->generate(drbg, params->sec_strength_len << 3, seed_len, tmp_buf))
+    if (drbg->generate(drbg, params->sec_strength_len * BITS_PER_BYTE,
+                                                        seed_len, tmp_buf))
        {
                result = NTRU_OK;
        }
@@ -985,8 +1020,8 @@ ntru_crypto_ntru_encrypt_keygen(
         /* get random bytes for seed for generating trinary g
          * as a list of indices
          */
-        if (!drbg->generate(drbg, params->sec_strength_len << 3, seed_len,
-                                                                 tmp_buf))
+        if (!drbg->generate(drbg, params->sec_strength_len * BITS_PER_BYTE,
+                                                                 seed_len, tmp_buf))
                {
                        result = NTRU_DRBG_FAIL;
                }