Use mgf1_bitspender in ntru_poly_create_from_seed
[strongswan.git] / src / libstrongswan / plugins / ntru / ntru_poly.c
index a021ace..b5b3898 100644 (file)
  */
 
 #include "ntru_poly.h"
-#include "ntru_mgf1.h"
 
+#include <crypto/mgf1/mgf1_bitspender.h>
 #include <utils/debug.h>
 #include <utils/test.h>
 
 typedef struct private_ntru_poly_t private_ntru_poly_t;
+typedef struct indices_len_t indices_len_t;
+
+/**
+ * Stores number of +1 and -1 coefficients
+ */
+struct indices_len_t {
+       int p;
+       int m;
+};
 
 /**
  * Private data of an ntru_poly_t object.
@@ -34,21 +43,41 @@ struct private_ntru_poly_t {
        ntru_poly_t public;
 
        /**
+        * Ring dimension equal to the number of polynomial coefficients
+        */
+       uint16_t N;
+
+       /**
+        * Large modulus
+        */
+       uint16_t q;
+
+       /**
         * Array containing the indices of the non-zero coefficients
         */
        uint16_t *indices;
 
        /**
-        * Number of non-zero coefficients
+        * Number of indices of the non-zero coefficients
+        */
+       size_t num_indices;
+
+       /**
+        * Number of sparse polynomials
+        */
+       int num_polynomials;
+
+       /**
+        * Number of nonzero coefficients for up to 3 sparse polynomials
         */
-       uint32_t indices_len;
+       indices_len_t indices_len[3];
 
 };
 
 METHOD(ntru_poly_t, get_size, size_t,
        private_ntru_poly_t *this)
 {
-       return this->indices_len;
+       return this->num_indices;
 }
 
 METHOD(ntru_poly_t, get_indices, uint16_t*,
@@ -57,133 +86,292 @@ METHOD(ntru_poly_t, get_indices, uint16_t*,
        return this->indices;
 }
 
-METHOD(ntru_poly_t, destroy, void,
-       private_ntru_poly_t *this)
+/**
+  * Multiplication of polynomial a with a sparse polynomial b given by
+  * the indices of its +1 and -1 coefficients results in polynomial c.
+  * This is a convolution operation
+  */
+static void ring_mult_i(uint16_t *a, indices_len_t len, uint16_t *indices,
+                                                         uint16_t N, uint16_t mod_q_mask, uint16_t *t,
+                                                         uint16_t *c)
 {
-       memwipe(this->indices, this->indices_len);
-       free(this->indices);
-       free(this);
+       int i, j, k;
+
+       /* initialize temporary array t */
+       for (k = 0; k < N; k++)
+       {
+               t[k] = 0;
+       }
+
+       /* t[(i+k)%N] = sum i=0 through N-1 of a[i], for b[k] = -1 */
+       for (j = len.p; j < len.p + len.m; j++)
+       {
+               k = indices[j];
+               for (i = 0; k < N; ++i, ++k)
+               {
+                       t[k] += a[i];
+               }
+               for (k = 0; i < N; ++i, ++k)
+               {
+                       t[k] += a[i];
+               }
+       }
+
+       /* t[(i+k)%N] = -(sum i=0 through N-1 of a[i] for b[k] = -1) */
+       for (k = 0; k < N; k++)
+       {
+               t[k] = -t[k];
+       }
+
+       /* t[(i+k)%N] += sum i=0 through N-1 of a[i] for b[k] = +1 */
+       for (j = 0; j < len.p; j++)
+       {
+               k = indices[j];
+               for (i = 0; k < N; ++i, ++k)
+               {
+                       t[k] += a[i];
+               }
+               for (k = 0; i < N; ++i, ++k)
+               {
+                       t[k] += a[i];
+               }
+       }
+
+       /* c = (a * b) mod q */
+       for (k = 0; k < N; k++)
+       {
+               c[k] = t[k] & mod_q_mask;
+       }
 }
 
-/*
- * Described in header.
- */
-ntru_poly_t *ntru_poly_create(hash_algorithm_t alg, chunk_t seed,
-                                                         uint8_t c_bits, uint16_t limit, 
-                                                 uint16_t poly_len, uint32_t indices_count,
-                                                         bool is_product_form)
+METHOD(ntru_poly_t, get_array, void,
+       private_ntru_poly_t *this, uint16_t *array)
 {
-       private_ntru_poly_t *this;
-       size_t hash_len, octet_count = 0, i, num_polys, num_indices[3], indices_len;
-       uint8_t octets[HASH_SIZE_SHA512], *used, num_left = 0, num_needed;
-       uint16_t index, left = 0;
-       int poly_i = 0, index_i = 0;
-       ntru_mgf1_t *mgf1;
-
-       DBG2(DBG_LIB, "MGF1 is seeded with %u bytes", seed.len);
-       mgf1 = ntru_mgf1_create(alg, seed, TRUE);
-       if (!mgf1)
+       uint16_t *t, *bi;
+       uint16_t mod_q_mask = this->q - 1;
+       indices_len_t len;
+       int i;
+
+       /* form polynomial F or F1 */
+       memset(array, 0x00, this->N * sizeof(uint16_t));
+       bi = this->indices;
+       len = this->indices_len[0];
+       for (i = 0; i < len.p + len.m; i++)
        {
-           return NULL;
+               array[bi[i]] = (i < len.p) ? 1 : mod_q_mask;
        }
-       i = hash_len = mgf1->get_hash_size(mgf1);
 
-       if (is_product_form)
+       if (this->num_polynomials == 3)
+       {
+               /* allocate temporary array t */
+               t = malloc(this->N * sizeof(uint16_t));
+
+               /* form F1 * F2 */
+               bi += len.p + len.m;
+               len = this->indices_len[1];
+               ring_mult_i(array, len, bi, this->N, mod_q_mask, t, array);
+
+               /* form (F1 * F2) + F3 */
+               bi += len.p + len.m;
+               len = this->indices_len[2];
+               for (i = 0; i < len.p + len.m; i++)
+               {
+                       if (i < len.p)
+                       {
+                               array[bi[i]] += 1;
+                       }
+                       else
+                       {
+                               array[bi[i]] -= 1;
+                       }
+                       array[bi[i]] &= mod_q_mask;
+               }
+               free(t);
+       }
+}
+
+METHOD(ntru_poly_t, ring_mult, void,
+       private_ntru_poly_t *this, uint16_t *a, uint16_t *c)
+{
+       uint16_t *t1, *t2;
+       uint16_t *bi = this->indices;
+       uint16_t mod_q_mask = this->q - 1;
+       int i;
+
+       /* allocate temporary array t1 */
+       t1 = malloc(this->N * sizeof(uint16_t));
+
+       if (this->num_polynomials == 1)
        {
-               num_polys = 3;
-               num_indices[0] = 0xff &  indices_count;
-               num_indices[1] = 0xff & (indices_count >> 8);
-               num_indices[2] = 0xff & (indices_count >> 16);
-               indices_len = num_indices[0] + num_indices[1] + num_indices[2];
+               ring_mult_i(a, this->indices_len[0], bi, this->N, mod_q_mask, t1, c);
        }
        else
        {
-               num_polys = 1;
-               num_indices[0] = indices_count;
-               indices_len = indices_count;
+               /* allocate temporary array t2 */
+               t2 = malloc(this->N * sizeof(uint16_t));
+
+               /* t1 = a * b1 */
+               ring_mult_i(a, this->indices_len[0], bi, this->N, mod_q_mask, t1, t1);
+
+               /* t1 = (a * b1) * b2 */
+               bi += this->indices_len[0].p + this->indices_len[0].m;
+               ring_mult_i(t1, this->indices_len[1], bi, this->N, mod_q_mask, t2, t1);
+
+               /* t2 = a * b3 */
+               bi += this->indices_len[1].p + this->indices_len[1].m;
+               ring_mult_i(a, this->indices_len[2], bi, this->N, mod_q_mask, t2, t2);
+
+               /* c = (a * b1 * b2) + (a * b3) */
+               for (i = 0; i < this->N; i++)
+               {
+                       c[i] = (t1[i] + t2[i]) & mod_q_mask;
+               }
+               free(t2);
        }
-       used = malloc(poly_len);
+       free(t1);
+}
+
+METHOD(ntru_poly_t, destroy, void,
+       private_ntru_poly_t *this)
+{
+       memwipe(this->indices, sizeof(uint16_t) * get_size(this));
+       free(this->indices);
+       free(this);
+}
+
+/**
+ * Creates an empty ntru_poly_t object with space allocated for indices
+ */
+static private_ntru_poly_t* ntru_poly_create(uint16_t N, uint16_t q,
+                                                                                        uint32_t indices_len_p,
+                                                                                        uint32_t indices_len_m,
+                                                                                        bool is_product_form)
+{
+       private_ntru_poly_t *this;
+       int n;
 
        INIT(this,
                .public = {
                        .get_size = _get_size,
                        .get_indices = _get_indices,
+                       .get_array = _get_array,
+                       .ring_mult = _ring_mult,
                        .destroy = _destroy,
                },
-               .indices_len = indices_len,
-               .indices = malloc(indices_len * sizeof(uint16_t)),
+               .N = N,
+               .q = q,
        );
 
+       if (is_product_form)
+       {
+               this->num_polynomials = 3;
+               for (n = 0; n < 3; n++)
+               {
+                       this->indices_len[n].p = 0xff & indices_len_p;
+                       this->indices_len[n].m = 0xff & indices_len_m;
+                       this->num_indices += this->indices_len[n].p +
+                                                                this->indices_len[n].m;
+                       indices_len_p >>= 8;
+                       indices_len_m >>= 8;
+               }
+       }
+       else
+       {
+               this->num_polynomials = 1;
+               this->indices_len[0].p = indices_len_p;
+               this->indices_len[0].m = indices_len_m;
+               this->num_indices = indices_len_p + indices_len_m;
+       }
+       this->indices = malloc(sizeof(uint16_t) * this->num_indices);
+
+       return this;
+}
+
+/*
+ * Described in header.
+ */
+ntru_poly_t *ntru_poly_create_from_seed(hash_algorithm_t alg, chunk_t seed,
+                                                                               uint8_t c_bits, uint16_t N, uint16_t q,
+                                                                               uint32_t indices_len_p,
+                                                                               uint32_t indices_len_m,
+                                                                               bool is_product_form)
+{
+       private_ntru_poly_t *this;
+       int n, num_indices, index_i = 0;
+       uint32_t index, limit;
+       uint8_t *used;
+       mgf1_bitspender_t *bitspender;
+
+       bitspender = mgf1_bitspender_create(alg, seed, TRUE);
+       if (!bitspender)
+       {
+           return NULL;
+       }
+       this = ntru_poly_create(N, q, indices_len_p, indices_len_m, is_product_form);
+       used = malloc(N);
+       limit = N * ((1 << c_bits) / N);
+
        /* generate indices for all polynomials */
-       while (poly_i < num_polys)
+       for (n = 0; n < this->num_polynomials; n++)
        {
-               memset(used, 0, poly_len);
+               memset(used, 0, N);
+               num_indices = this->indices_len[n].p + this->indices_len[n].m;
 
                /* generate indices for a single polynomial */
-               while (num_indices[poly_i])
+               while (num_indices)
                {
                        /* generate a random candidate index with a size of c_bits */           
                        do
                        {
-                               /* use any leftover bits first */
-                               index = num_left ? left << (c_bits - num_left) : 0;
-
-                               /* get the rest of the bits needed from new octets */
-                               num_needed = c_bits - num_left;
-
-                               while (num_needed)
+                               index = bitspender->get_bits(bitspender, c_bits);
+                               if (index == MGF1_BITSPENDER_ERROR)
                                {
-                                       if (i == hash_len)
-                                       {
-                                               /* get another block from MGF1 */
-                                               if (!mgf1->get_mask(mgf1, hash_len, octets))
-                                               {
-                                                       mgf1->destroy(mgf1);
-                                                       destroy(this);
-                                                       free(used);
-                                                       return NULL;
-                                               }
-                                               octet_count += hash_len;
-                                               i = 0;
-                                       }
-                                       left = octets[i++];
-
-                                       if (num_needed <= 8)
-                                       {
-                                               /* all bits needed to fill the index are in this octet */
-                                               index |= left >> (8 - num_needed);
-                                               num_left = 8 - num_needed;
-                                               num_needed = 0;
-                                               left &= 0xff >> (8 - num_left);
-                                       }
-                                       else
-                                       {
-                                               /* more than one octet will be needed */
-                                               index |= left << (num_needed - 8);
-                                               num_needed -= 8;
-                                       }
+                                       bitspender->destroy(bitspender);
+                                       destroy(this);
+                                       free(used);
+                                       return NULL;
                                }
                        }
                        while (index >= limit);
 
                        /* form index and check if unique */
-                       index %= poly_len;
+                       index %= N;
                        if (!used[index])
                        {
                                used[index] = 1;
                                this->indices[index_i++] = index;
-                               num_indices[poly_i]--;
+                               num_indices--;
                        }
                }
-               poly_i++;
        }
 
-       DBG2(DBG_LIB, "MGF1 generates %u octets to derive %u indices",
-                                  octet_count, this->indices_len);
-       mgf1->destroy(mgf1);
+       bitspender->destroy(bitspender);
        free(used);
 
        return &this->public;
 }
 
-EXPORT_FUNCTION_FOR_TESTS(ntru, ntru_poly_create);
+/*
+ * Described in header.
+ */
+ntru_poly_t *ntru_poly_create_from_data(uint16_t *data, uint16_t N, uint16_t q,
+                                                                               uint32_t indices_len_p,
+                                                                               uint32_t indices_len_m,
+                                                                               bool is_product_form)
+{
+       private_ntru_poly_t *this;
+       int i;
+
+       this = ntru_poly_create(N, q, indices_len_p, indices_len_m, is_product_form);
+
+       for (i = 0; i < this->num_indices; i++)
+       {
+               this->indices[i] = data[i];
+       }
+
+       return &this->public;
+}
+
+EXPORT_FUNCTION_FOR_TESTS(ntru, ntru_poly_create_from_seed);
+
+EXPORT_FUNCTION_FOR_TESTS(ntru, ntru_poly_create_from_data);