Implement ring multiplication method
authorAndreas Steffen <andreas.steffen@strongswan.org>
Wed, 26 Feb 2014 22:36:09 +0000 (23:36 +0100)
committerAndreas Steffen <andreas.steffen@strongswan.org>
Thu, 27 Feb 2014 14:22:58 +0000 (15:22 +0100)
src/libstrongswan/plugins/ntru/ntru_crypto/ntru_crypto_ntru_encrypt.c
src/libstrongswan/plugins/ntru/ntru_poly.c
src/libstrongswan/plugins/ntru/ntru_poly.h
src/libstrongswan/tests/suites/test_ntru.c

index 7fae373..af218d6 100644 (file)
@@ -111,7 +111,6 @@ ntru_crypto_ntru_encrypt(
        uint8_t                *mask_trits;
        chunk_t                 seed;
        ntru_poly_t                             *r_poly;
-       uint16_t                                *r_indices;
 
     /* check for bad parameters */
 
@@ -230,8 +229,8 @@ ntru_crypto_ntru_encrypt(
 
                        seed = chunk_create(tmp_buf, ptr - tmp_buf);
                        r_poly = ntru_poly_create(hash_algid, seed, params->c_bits,
-                                                                         params->N, 2 * params->dF_r,
-                                                                         params->is_product_form);
+                                                                         params->N, params->q, params->dF_r,
+                                                                         params->dF_r, params->is_product_form);
                        if (!r_poly)
                        {
                           result = NTRU_MGF1_FAIL;
@@ -249,21 +248,7 @@ ntru_crypto_ntru_encrypt(
                                                                   params->q_bits, ringel_buf);
 
                        /* form R = h * r */
-                       r_indices = r_poly->get_indices(r_poly);
-
-                       if (params->is_product_form)
-                       {
-                               ntru_ring_mult_product_indices(ringel_buf, (uint16_t)dr1,
-                                                                                          (uint16_t)dr2, (uint16_t)dr3,
-                                                                                          r_indices, params->N, params->q,
-                                                                                          scratch_buf, ringel_buf);
-                       }
-                       else
-                       {
-                               ntru_ring_mult_indices(ringel_buf, (uint16_t)dr, (uint16_t)dr,
-                                                                          r_indices, params->N, params->q,
-                                                                          scratch_buf, ringel_buf);
-                       }
+                       r_poly->ring_mult(r_poly, ringel_buf, ringel_buf);
                        r_poly->destroy(r_poly);
 
                        /* form R mod 4 */
@@ -459,7 +444,6 @@ ntru_crypto_ntru_decrypt(
        uint8_t                *mask_trits;
        chunk_t                 seed;
        ntru_poly_t                        *r_poly;
-       uint16_t                           *r_indices;
 
        /* check for bad parameters */
        if (!privkey_blob || !ct || !pt_len)
@@ -582,29 +566,41 @@ ntru_crypto_ntru_decrypt(
                                        (uint16_t)dF_r2, (uint16_t)dF_r3,
                                        i_buf, params->N, params->q,
                                        scratch_buf, ringel_buf1);
-        for (i = 0; i < cmprime_len; i++) {
-            ringel_buf1[i] = (ringel_buf2[i] + 3 * ringel_buf1[i]) & mod_q_mask;
-            if (ringel_buf1[i] >= (params->q >> 1))
-                    ringel_buf1[i] = ringel_buf1[i] - q_mod_p;
-            Mtrin_buf[i] = (uint8_t)(ringel_buf1[i] % 3);
-            if (Mtrin_buf[i] == 1)
-                ++m1;
-            else if (Mtrin_buf[i] == 2)
-                --m1;
-        }
-    }
+
+               for (i = 0; i < cmprime_len; i++)
+               {
+                       ringel_buf1[i] = (ringel_buf2[i] + 3 * ringel_buf1[i]) & mod_q_mask;
+                       if (ringel_buf1[i] >= (params->q >> 1))
+                       {
+                               ringel_buf1[i] = ringel_buf1[i] - q_mod_p;
+                       }
+                       Mtrin_buf[i] = (uint8_t)(ringel_buf1[i] % 3);
+                       if (Mtrin_buf[i] == 1)
+                       {
+                               ++m1;
+                       }
+                       else if (Mtrin_buf[i] == 2)
+                       {
+                               --m1;
+                       }
+               }
+       }
        else
        {
         ntru_ring_mult_indices(ringel_buf2, (uint16_t)dF_r, (uint16_t)dF_r,
                                i_buf, params->N, params->q,
                                scratch_buf, ringel_buf1);
-        for (i = 0; i < cmprime_len; i++) {
-            ringel_buf1[i] = (ringel_buf2[i] + 3 * ringel_buf1[i]) & mod_q_mask;
-            if (ringel_buf1[i] >= (params->q >> 1))
-                    ringel_buf1[i] = ringel_buf1[i] - q_mod_p;
-            Mtrin_buf[i] = (uint8_t)(ringel_buf1[i] % 3);
-        }
-    }
+
+               for (i = 0; i < cmprime_len; i++)
+               {
+                       ringel_buf1[i] = (ringel_buf2[i] + 3 * ringel_buf1[i]) & mod_q_mask;
+                       if (ringel_buf1[i] >= (params->q >> 1))
+                       {
+                               ringel_buf1[i] = ringel_buf1[i] - q_mod_p;
+                       }
+                       Mtrin_buf[i] = (uint8_t)(ringel_buf1[i] % 3);
+               }
+}
 
     /* check that the candidate message representative meets minimum weight
      * requirements
@@ -712,8 +708,8 @@ ntru_crypto_ntru_decrypt(
 
                seed = chunk_create(tmp_buf, ptr - tmp_buf);
                r_poly = ntru_poly_create(hash_algid, seed, params->c_bits,
-                                                                 params->N, 2 * params->dF_r,
-                                                                 params->is_product_form);
+                                                                 params->N, params->q, params->dF_r,
+                                                                 params->dF_r, params->is_product_form);
                if (!r_poly)
                {
                   result = NTRU_MGF1_FAIL;
@@ -733,20 +729,7 @@ ntru_crypto_ntru_decrypt(
                }
 
                /* form cR' = h * cr */
-               r_indices = r_poly->get_indices(r_poly);
-               if (params->is_product_form)
-               {
-                       ntru_ring_mult_product_indices(ringel_buf1, (uint16_t)dF_r1,
-                                                                                  (uint16_t)dF_r2, (uint16_t)dF_r3,
-                                                                                  r_indices, params->N, params->q,
-                                                                                  scratch_buf, ringel_buf1);
-               }
-               else
-               {
-                       ntru_ring_mult_indices(ringel_buf1, (uint16_t)dF_r, (uint16_t)dF_r,
-                                                                  r_indices, params->N, params->q,
-                                                                  scratch_buf, ringel_buf1);
-               }
+               r_poly->ring_mult(r_poly, ringel_buf1, ringel_buf1);
                r_poly->destroy(r_poly);
 
                /* compare cR' to cR */
@@ -857,7 +840,7 @@ ntru_crypto_ntru_encrypt_keygen(
     uint32_t                result = NTRU_OK;
        ntru_poly_t                        *F_poly = NULL;
        ntru_poly_t            *g_poly = NULL;
-       uint16_t                           *F_indices, *g_indices;
+       uint16_t                           *F_indices;
 
     /* get a pointer to the parameter-set parameters */
 
@@ -959,8 +942,8 @@ ntru_crypto_ntru_encrypt_keygen(
 
                seed = chunk_create(tmp_buf, seed_len);
                F_poly = ntru_poly_create(hash_algid, seed, params->c_bits,
-                                                                 params->N, 2 * params->dF_r,
-                                                                 params->is_product_form);
+                                                                 params->N, params->q, params->dF_r,
+                                                                 params->dF_r, params->is_product_form);
                if (!F_poly)
                {
                   result = NTRU_MGF1_FAIL;
@@ -1055,7 +1038,8 @@ ntru_crypto_ntru_encrypt_keygen(
 
                seed = chunk_create(tmp_buf, seed_len);
                g_poly = ntru_poly_create(hash_algid, seed, params->c_bits,
-                                                                 params->N, 2*params->dg + 1, FALSE);
+                                                                 params->N, params->q, params->dg + 1,
+                                                                 params->dg, FALSE);
                if (!g_poly)
                {
                   result = NTRU_MGF1_FAIL;
@@ -1067,10 +1051,7 @@ ntru_crypto_ntru_encrypt_keygen(
                uint16_t i;
 
                /* compute h = p * (f^-1 * g) mod q */
-               g_indices = g_poly->get_indices(g_poly);
-               ntru_ring_mult_indices(ringel_buf2, params->dg + 1, params->dg,
-                                                          g_indices, params->N, params->q, scratch_buf,
-                                                          ringel_buf2);
+               g_poly->ring_mult(g_poly, ringel_buf2, ringel_buf2);
                g_poly->destroy(g_poly);
 
                for (i = 0; i < params->N; i++)
index f893d4d..2081d03 100644 (file)
 #include <utils/test.h>
 
 typedef struct private_ntru_poly_t private_ntru_poly_t;
+typedef struct indices_len_t indices_len_t;
+
+/**
+ * Stores number of +1 and -1 coefficients
+ */
+struct indices_len_t {
+       int p;
+       int m;
+};
 
 /**
  * Private data of an ntru_poly_t object.
@@ -34,21 +43,43 @@ struct private_ntru_poly_t {
        ntru_poly_t public;
 
        /**
+        * Ring dimension equal to the number of polynomial coefficients
+        */
+       uint16_t N;
+
+       /**
+        * Large modulus
+        */
+       uint16_t q;
+
+       /**
         * Array containing the indices of the non-zero coefficients
         */
        uint16_t *indices;
 
        /**
-        * Number of non-zero coefficients
+        * Number of sparse polynomials
         */
-       uint32_t indices_len;
+       int num_polynomials;
+
+       /**
+        * Number of nonzero coefficients for up to 3 sparse polynomials
+        */
+       indices_len_t indices_len[3];
 
 };
 
 METHOD(ntru_poly_t, get_size, size_t,
        private_ntru_poly_t *this)
 {
-       return this->indices_len;
+       int n;
+       size_t size = 0;
+
+       for (n = 0; n < this->num_polynomials; n++)
+       {
+               size += this->indices_len[n].p + this->indices_len[n].m;
+       }
+       return size;
 }
 
 METHOD(ntru_poly_t, get_indices, uint16_t*,
@@ -56,11 +87,113 @@ METHOD(ntru_poly_t, get_indices, uint16_t*,
 {
        return this->indices;
 }
+/**
+  * Multiplication of polynomial a with a sparse polynomial b given by
+  * the indices of its +1 and -1 coefficients results in polynomial c.
+  * This is a convolution operation
+  */
+static void ring_mult_indices(uint16_t *a, indices_len_t len, uint16_t *indices,
+                                                         uint16_t N, uint16_t mod_q_mask, uint16_t *c)
+{
+       uint16_t *t;
+       int i, j, k;
+
+       /* allocate and initialize temporary array t */
+       t = malloc(N * sizeof(uint16_t));
+       for (k = 0; k < N; k++)
+       {
+               t[k] = 0;
+       }
+
+       /* t[(i+k)%N] = sum i=0 through N-1 of a[i], for b[k] = -1 */
+       for (j = len.p; j < len.p + len.m; j++)
+       {
+               k = indices[j];
+               for (i = 0; k < N; ++i, ++k)
+               {
+                       t[k] += a[i];
+               }
+               for (k = 0; i < N; ++i, ++k)
+               {
+                       t[k] += a[i];
+               }
+       }
+
+       /* t[(i+k)%N] = -(sum i=0 through N-1 of a[i] for b[k] = -1) */
+       for (k = 0; k < N; k++)
+       {
+               t[k] = -t[k];
+       }
+
+       /* t[(i+k)%N] += sum i=0 through N-1 of a[i] for b[k] = +1 */
+       for (j = 0; j < len.p; j++)
+       {
+               k = indices[j];
+               for (i = 0; k < N; ++i, ++k)
+               {
+                       t[k] += a[i];
+               }
+               for (k = 0; i < N; ++i, ++k)
+               {
+                       t[k] += a[i];
+               }
+       }
+
+       /* c = (a * b) mod q */
+       for (k = 0; k < N; k++)
+       {
+               c[k] = t[k] & mod_q_mask;
+       }
+
+       /* cleanup */
+       free(t);
+}
+
+METHOD(ntru_poly_t, ring_mult, void,
+       private_ntru_poly_t *this, uint16_t *a, uint16_t *c)
+{
+       uint16_t *bi = this->indices, mod_q_mask = this->q - 1;
+
+       if (this->num_polynomials == 1)
+       {
+               ring_mult_indices(a, this->indices_len[0], bi, this->N, mod_q_mask, c);
+       }
+       else
+       {
+               uint16_t *t1, *t2;
+               int i;
+
+               /* allocate temporary arrays */
+               t1 = malloc(this->N * sizeof(uint16_t));
+               t2 = malloc(this->N * sizeof(uint16_t));
+
+               /* t1 = a * b1 */
+               ring_mult_indices(a, this->indices_len[0], bi, this->N, mod_q_mask, t1);
+
+               /* t1 = (a * b1) * b2 */
+               bi += this->indices_len[0].p + this->indices_len[0].m;
+               ring_mult_indices(t1, this->indices_len[1], bi, this->N, mod_q_mask, t1);
+
+               /* t2 = a * b3 */
+               bi += this->indices_len[1].p + this->indices_len[1].m;
+               ring_mult_indices(a, this->indices_len[2], bi, this->N, mod_q_mask, t2);
+
+               /* c = (a * b1 * b2) + (a * b3) */
+               for (i = 0; i < this->N; i++)
+               {
+                       c[i] = (t1[i] + t2[i]) & mod_q_mask;
+               }
+
+               /* cleanup */
+               free(t1);
+               free(t2);
+       }
+}
 
 METHOD(ntru_poly_t, destroy, void,
        private_ntru_poly_t *this)
 {
-       memwipe(this->indices, this->indices_len);
+       memwipe(this->indices, get_size(this));
        free(this->indices);
        free(this);
 }
@@ -69,14 +202,15 @@ METHOD(ntru_poly_t, destroy, void,
  * Described in header.
  */
 ntru_poly_t *ntru_poly_create(hash_algorithm_t alg, chunk_t seed,
-                                                         uint8_t c_bits, uint16_t poly_len,
-                                                         uint32_t indices_count, bool is_product_form)
+                                                         uint8_t c_bits, uint16_t N, uint16_t q,
+                                                         uint32_t indices_len_p, uint32_t indices_len_m,
+                                                         bool is_product_form)
 {
        private_ntru_poly_t *this;
-       size_t hash_len, octet_count = 0, i, num_polys, num_indices[3], indices_len;
+       size_t hash_len, octet_count = 0, i;
        uint8_t octets[HASH_SIZE_SHA512], *used, num_left = 0, num_needed;
        uint16_t index, limit, left = 0;
-       int poly_i = 0, index_i = 0;
+       int n, num_indices, index_i = 0;
        ntru_mgf1_t *mgf1;
 
        DBG2(DBG_LIB, "MGF1 is seeded with %u bytes", seed.len);
@@ -87,40 +221,47 @@ ntru_poly_t *ntru_poly_create(hash_algorithm_t alg, chunk_t seed,
        }
        i = hash_len = mgf1->get_hash_size(mgf1);
 
-       if (is_product_form)
-       {
-               num_polys = 3;
-               num_indices[0] = 0xff &  indices_count;
-               num_indices[1] = 0xff & (indices_count >> 8);
-               num_indices[2] = 0xff & (indices_count >> 16);
-               indices_len = num_indices[0] + num_indices[1] + num_indices[2];
-       }
-       else
-       {
-               num_polys = 1;
-               num_indices[0] = indices_count;
-               indices_len = indices_count;
-       }
-       used = malloc(poly_len);
-       limit = poly_len * ((1 << c_bits) / poly_len);
-
        INIT(this,
                .public = {
                        .get_size = _get_size,
                        .get_indices = _get_indices,
+                       .ring_mult = _ring_mult,
                        .destroy = _destroy,
                },
-               .indices_len = indices_len,
-               .indices = malloc(indices_len * sizeof(uint16_t)),
+               .N = N,
+               .q = q,
        );
 
+       if (is_product_form)
+       {
+               this->num_polynomials = 3;
+               for (n = 0; n < 3; n++)
+               {
+                       this->indices_len[n].p = 0xff & indices_len_p;
+                       this->indices_len[n].m = 0xff & indices_len_m;
+                       indices_len_p >>= 8;
+                       indices_len_m >>= 8;
+               }
+       }
+       else
+       {
+               this->num_polynomials = 1;
+               this->indices_len[0].p = indices_len_p;
+               this->indices_len[0].m = indices_len_m;
+       }
+       this->indices = malloc(sizeof(uint16_t) * get_size(this)),
+
+       used = malloc(N);
+       limit = N * ((1 << c_bits) / N);
+
        /* generate indices for all polynomials */
-       while (poly_i < num_polys)
+       for (n = 0; n < this->num_polynomials; n++)
        {
-               memset(used, 0, poly_len);
+               memset(used, 0, N);
+               num_indices = this->indices_len[n].p + this->indices_len[n].m;
 
                /* generate indices for a single polynomial */
-               while (num_indices[poly_i])
+               while (num_indices)
                {
                        /* generate a random candidate index with a size of c_bits */           
                        do
@@ -167,19 +308,18 @@ ntru_poly_t *ntru_poly_create(hash_algorithm_t alg, chunk_t seed,
                        while (index >= limit);
 
                        /* form index and check if unique */
-                       index %= poly_len;
+                       index %= N;
                        if (!used[index])
                        {
                                used[index] = 1;
                                this->indices[index_i++] = index;
-                               num_indices[poly_i]--;
+                               num_indices--;
                        }
                }
-               poly_i++;
        }
 
        DBG2(DBG_LIB, "MGF1 generates %u octets to derive %u indices",
-                                  octet_count, this->indices_len);
+                                  octet_count, get_size(this));
        mgf1->destroy(mgf1);
        free(used);
 
index 92becb1..5367478 100644 (file)
@@ -43,6 +43,11 @@ struct ntru_poly_t {
        uint16_t* (*get_indices)(ntru_poly_t *this);
 
        /**
+        * @return              array containing the indices of the non-zero coefficients
+        */
+       void (*ring_mult)(ntru_poly_t *this, uint16_t *a, uint16_t *c);
+
+       /**
         * Destroy ntru_poly_t object
         */
        void (*destroy)(ntru_poly_t *this);
@@ -53,14 +58,17 @@ struct ntru_poly_t {
  *
  * @param alg                          hash algorithm to be used by MGF1
  * @param seed                         seed used by MGF1 to generate trits from
- * @param poly_len                     size of the trits polynomial
+ * @param N                                    ring dimension, number of polynomial coefficients
+ * @param q                                    large modulus
  * @param c_bits                       number of bits for candidate index
- * @param indices_count                number of non-zero indices
+ * @param indices_len_p                number of indices for +1 coefficients
+ * @param indices_len_m                number of indices for -1 coefficients
  * @param is_product_form      generate multiple polynomials
  */
 ntru_poly_t *ntru_poly_create(hash_algorithm_t alg, chunk_t seed,
-                                                         uint8_t c_bits, uint16_t poly_len,
-                                                         uint32_t indices_count, bool is_product_form);
+                                                         uint8_t c_bits, uint16_t N, uint16_t q,
+                                                         uint32_t indices_len_p, uint32_t indices_len_m,
+                                                         bool is_product_form);
 
 #endif /** NTRU_POLY_H_ @}*/
 
index e42bb8e..3bb7851 100644 (file)
@@ -33,8 +33,8 @@ IMPORT_FUNCTION_FOR_TESTS(ntru, ntru_trits_create, ntru_trits_t*,
 
 IMPORT_FUNCTION_FOR_TESTS(ntru, ntru_poly_create, ntru_poly_t*,
                                                  hash_algorithm_t alg, chunk_t seed, uint8_t c_bits,
-                                                 uint16_t poly_len, uint32_t indices_count,
-                                                 bool is_product_form)
+                                                 uint16_t N, uint16_t q, uint32_t indices_len_p,
+                                                 uint32_t indices_len_m, bool is_product_form)
 
 /**
  * NTRU parameter sets to test
@@ -302,10 +302,11 @@ END_TEST
 
 typedef struct {
        uint8_t c_bits;
-       uint16_t poly_len;
+       uint16_t N;
+       uint16_t q;
        bool is_product_form;
-       uint32_t indices_count;
        uint32_t indices_len;
+       uint32_t indices_size;
        uint16_t *indices;
 } poly_test_t;
 
@@ -427,10 +428,10 @@ mgf1_test_t mgf1_tests[] = {
                                0, 1, 1, 2, 0,  2, 2, 0, 0, 0,  1, 1, 0, 1, 0,  1, 1, 0, 1, 1,
                                0, 1, 2, 0, 1,  1, 0, 1, 2, 0,  0, 1, 2, 2, 0,  0, 2, 1, 2),
                {
-                       {       9, 439, TRUE, 2*(9 + (8 << 8) + (5 << 16)),
+                       {       9, 439, 2048, TRUE, 9 + (8 << 8) + (5 << 16),
                                countof(indices_ees439ep1), indices_ees439ep1
                        },
-                       {       11, 613, FALSE, 2*55,
+                       {       11, 613, 2048, FALSE, 55,
                                countof(indices_ees613ep1), indices_ees613ep1
                        }
                }
@@ -514,10 +515,10 @@ mgf1_test_t mgf1_tests[] = {
                                1, 0, 1, 0, 2,  2, 1, 0, 2, 2,  2, 2, 2, 1, 0,  2, 2, 2, 1, 2,
                                0, 2, 0, 0, 0,  0, 0, 1, 2, 0,  1, 0, 1),
                {
-                       {       13, 743, TRUE, 2*(11 + (11 << 8) + (15 << 16)),
+                       {       13, 743, 2048, TRUE, 11 + (11 << 8) + (15 << 16),
                                countof(indices_ees743ep1), indices_ees743ep1
                        },
-                       {       12, 1171, FALSE, 2*106,
+                       {       12, 1171, 2048, FALSE, 106,
                                countof(indices_ees1171ep1), indices_ees1171ep1
                        }
                }
@@ -632,19 +633,21 @@ START_TEST(test_ntru_poly)
        seed.len = mgf1_tests[_i].seed_len;
 
        p = &mgf1_tests[_i].poly_test[0];
-       poly = ntru_poly_create(HASH_UNKNOWN, seed, p->c_bits, p->poly_len,
-                                                       p->indices_count, p->is_product_form);
+       poly = ntru_poly_create(HASH_UNKNOWN, seed, p->c_bits, p->N, p->q,
+                                                       p->indices_len, p->indices_len,
+                                                       p->is_product_form);
        ck_assert(poly == NULL);
 
        for (n = 0; n < 2; n++)
        {
                p = &mgf1_tests[_i].poly_test[n];
-               poly = ntru_poly_create(mgf1_tests[_i].alg, seed, p->c_bits, p->poly_len,
-                                                               p->indices_count, p->is_product_form);
-               ck_assert(poly != NULL && poly->get_size(poly) == p->indices_len);
+               poly = ntru_poly_create(mgf1_tests[_i].alg, seed, p->c_bits, p->N, p->q,
+                                                               p->indices_len, p->indices_len,
+                                                               p->is_product_form);
+               ck_assert(poly != NULL && poly->get_size(poly) == p->indices_size);
 
                indices = poly->get_indices(poly);
-               for (j = 0; j < p->indices_len; j++)
+               for (j = 0; j < p->indices_size; j++)
                {
                        ck_assert(indices[j] == p->indices[j]);
                }