Use mgf1_bitspender in ntru_poly_create_from_seed
[strongswan.git] / src / libstrongswan / plugins / ntru / ntru_poly.c
index 2081d03..b5b3898 100644 (file)
@@ -16,8 +16,8 @@
  */
 
 #include "ntru_poly.h"
-#include "ntru_mgf1.h"
 
+#include <crypto/mgf1/mgf1_bitspender.h>
 #include <utils/debug.h>
 #include <utils/test.h>
 
@@ -58,6 +58,11 @@ struct private_ntru_poly_t {
        uint16_t *indices;
 
        /**
+        * Number of indices of the non-zero coefficients
+        */
+       size_t num_indices;
+
+       /**
         * Number of sparse polynomials
         */
        int num_polynomials;
@@ -72,14 +77,7 @@ struct private_ntru_poly_t {
 METHOD(ntru_poly_t, get_size, size_t,
        private_ntru_poly_t *this)
 {
-       int n;
-       size_t size = 0;
-
-       for (n = 0; n < this->num_polynomials; n++)
-       {
-               size += this->indices_len[n].p + this->indices_len[n].m;
-       }
-       return size;
+       return this->num_indices;
 }
 
 METHOD(ntru_poly_t, get_indices, uint16_t*,
@@ -87,19 +85,19 @@ METHOD(ntru_poly_t, get_indices, uint16_t*,
 {
        return this->indices;
 }
+
 /**
   * Multiplication of polynomial a with a sparse polynomial b given by
   * the indices of its +1 and -1 coefficients results in polynomial c.
   * This is a convolution operation
   */
-static void ring_mult_indices(uint16_t *a, indices_len_t len, uint16_t *indices,
-                                                         uint16_t N, uint16_t mod_q_mask, uint16_t *c)
+static void ring_mult_i(uint16_t *a, indices_len_t len, uint16_t *indices,
+                                                         uint16_t N, uint16_t mod_q_mask, uint16_t *t,
+                                                         uint16_t *c)
 {
-       uint16_t *t;
        int i, j, k;
 
-       /* allocate and initialize temporary array t */
-       t = malloc(N * sizeof(uint16_t));
+       /* initialize temporary array t */
        for (k = 0; k < N; k++)
        {
                t[k] = 0;
@@ -144,87 +142,119 @@ static void ring_mult_indices(uint16_t *a, indices_len_t len, uint16_t *indices,
        {
                c[k] = t[k] & mod_q_mask;
        }
+}
+
+METHOD(ntru_poly_t, get_array, void,
+       private_ntru_poly_t *this, uint16_t *array)
+{
+       uint16_t *t, *bi;
+       uint16_t mod_q_mask = this->q - 1;
+       indices_len_t len;
+       int i;
+
+       /* form polynomial F or F1 */
+       memset(array, 0x00, this->N * sizeof(uint16_t));
+       bi = this->indices;
+       len = this->indices_len[0];
+       for (i = 0; i < len.p + len.m; i++)
+       {
+               array[bi[i]] = (i < len.p) ? 1 : mod_q_mask;
+       }
 
-       /* cleanup */
-       free(t);
+       if (this->num_polynomials == 3)
+       {
+               /* allocate temporary array t */
+               t = malloc(this->N * sizeof(uint16_t));
+
+               /* form F1 * F2 */
+               bi += len.p + len.m;
+               len = this->indices_len[1];
+               ring_mult_i(array, len, bi, this->N, mod_q_mask, t, array);
+
+               /* form (F1 * F2) + F3 */
+               bi += len.p + len.m;
+               len = this->indices_len[2];
+               for (i = 0; i < len.p + len.m; i++)
+               {
+                       if (i < len.p)
+                       {
+                               array[bi[i]] += 1;
+                       }
+                       else
+                       {
+                               array[bi[i]] -= 1;
+                       }
+                       array[bi[i]] &= mod_q_mask;
+               }
+               free(t);
+       }
 }
 
 METHOD(ntru_poly_t, ring_mult, void,
        private_ntru_poly_t *this, uint16_t *a, uint16_t *c)
 {
-       uint16_t *bi = this->indices, mod_q_mask = this->q - 1;
+       uint16_t *t1, *t2;
+       uint16_t *bi = this->indices;
+       uint16_t mod_q_mask = this->q - 1;
+       int i;
+
+       /* allocate temporary array t1 */
+       t1 = malloc(this->N * sizeof(uint16_t));
 
        if (this->num_polynomials == 1)
        {
-               ring_mult_indices(a, this->indices_len[0], bi, this->N, mod_q_mask, c);
+               ring_mult_i(a, this->indices_len[0], bi, this->N, mod_q_mask, t1, c);
        }
        else
        {
-               uint16_t *t1, *t2;
-               int i;
-
-               /* allocate temporary arrays */
-               t1 = malloc(this->N * sizeof(uint16_t));
+               /* allocate temporary array t2 */
                t2 = malloc(this->N * sizeof(uint16_t));
 
                /* t1 = a * b1 */
-               ring_mult_indices(a, this->indices_len[0], bi, this->N, mod_q_mask, t1);
+               ring_mult_i(a, this->indices_len[0], bi, this->N, mod_q_mask, t1, t1);
 
                /* t1 = (a * b1) * b2 */
                bi += this->indices_len[0].p + this->indices_len[0].m;
-               ring_mult_indices(t1, this->indices_len[1], bi, this->N, mod_q_mask, t1);
+               ring_mult_i(t1, this->indices_len[1], bi, this->N, mod_q_mask, t2, t1);
 
                /* t2 = a * b3 */
                bi += this->indices_len[1].p + this->indices_len[1].m;
-               ring_mult_indices(a, this->indices_len[2], bi, this->N, mod_q_mask, t2);
+               ring_mult_i(a, this->indices_len[2], bi, this->N, mod_q_mask, t2, t2);
 
                /* c = (a * b1 * b2) + (a * b3) */
                for (i = 0; i < this->N; i++)
                {
                        c[i] = (t1[i] + t2[i]) & mod_q_mask;
                }
-
-               /* cleanup */
-               free(t1);
                free(t2);
        }
+       free(t1);
 }
 
 METHOD(ntru_poly_t, destroy, void,
        private_ntru_poly_t *this)
 {
-       memwipe(this->indices, get_size(this));
+       memwipe(this->indices, sizeof(uint16_t) * get_size(this));
        free(this->indices);
        free(this);
 }
 
-/*
- * Described in header.
+/**
+ * Creates an empty ntru_poly_t object with space allocated for indices
  */
-ntru_poly_t *ntru_poly_create(hash_algorithm_t alg, chunk_t seed,
-                                                         uint8_t c_bits, uint16_t N, uint16_t q,
-                                                         uint32_t indices_len_p, uint32_t indices_len_m,
-                                                         bool is_product_form)
+static private_ntru_poly_t* ntru_poly_create(uint16_t N, uint16_t q,
+                                                                                        uint32_t indices_len_p,
+                                                                                        uint32_t indices_len_m,
+                                                                                        bool is_product_form)
 {
        private_ntru_poly_t *this;
-       size_t hash_len, octet_count = 0, i;
-       uint8_t octets[HASH_SIZE_SHA512], *used, num_left = 0, num_needed;
-       uint16_t index, limit, left = 0;
-       int n, num_indices, index_i = 0;
-       ntru_mgf1_t *mgf1;
-
-       DBG2(DBG_LIB, "MGF1 is seeded with %u bytes", seed.len);
-       mgf1 = ntru_mgf1_create(alg, seed, TRUE);
-       if (!mgf1)
-       {
-           return NULL;
-       }
-       i = hash_len = mgf1->get_hash_size(mgf1);
+       int n;
 
        INIT(this,
                .public = {
                        .get_size = _get_size,
                        .get_indices = _get_indices,
+                       .get_array = _get_array,
                        .ring_mult = _ring_mult,
                        .destroy = _destroy,
                },
@@ -239,6 +269,8 @@ ntru_poly_t *ntru_poly_create(hash_algorithm_t alg, chunk_t seed,
                {
                        this->indices_len[n].p = 0xff & indices_len_p;
                        this->indices_len[n].m = 0xff & indices_len_m;
+                       this->num_indices += this->indices_len[n].p +
+                                                                this->indices_len[n].m;
                        indices_len_p >>= 8;
                        indices_len_m >>= 8;
                }
@@ -248,9 +280,34 @@ ntru_poly_t *ntru_poly_create(hash_algorithm_t alg, chunk_t seed,
                this->num_polynomials = 1;
                this->indices_len[0].p = indices_len_p;
                this->indices_len[0].m = indices_len_m;
+               this->num_indices = indices_len_p + indices_len_m;
        }
-       this->indices = malloc(sizeof(uint16_t) * get_size(this)),
+       this->indices = malloc(sizeof(uint16_t) * this->num_indices);
+
+       return this;
+}
+
+/*
+ * Described in header.
+ */
+ntru_poly_t *ntru_poly_create_from_seed(hash_algorithm_t alg, chunk_t seed,
+                                                                               uint8_t c_bits, uint16_t N, uint16_t q,
+                                                                               uint32_t indices_len_p,
+                                                                               uint32_t indices_len_m,
+                                                                               bool is_product_form)
+{
+       private_ntru_poly_t *this;
+       int n, num_indices, index_i = 0;
+       uint32_t index, limit;
+       uint8_t *used;
+       mgf1_bitspender_t *bitspender;
 
+       bitspender = mgf1_bitspender_create(alg, seed, TRUE);
+       if (!bitspender)
+       {
+           return NULL;
+       }
+       this = ntru_poly_create(N, q, indices_len_p, indices_len_m, is_product_form);
        used = malloc(N);
        limit = N * ((1 << c_bits) / N);
 
@@ -266,43 +323,13 @@ ntru_poly_t *ntru_poly_create(hash_algorithm_t alg, chunk_t seed,
                        /* generate a random candidate index with a size of c_bits */           
                        do
                        {
-                               /* use any leftover bits first */
-                               index = num_left ? left << (c_bits - num_left) : 0;
-
-                               /* get the rest of the bits needed from new octets */
-                               num_needed = c_bits - num_left;
-
-                               while (num_needed)
+                               index = bitspender->get_bits(bitspender, c_bits);
+                               if (index == MGF1_BITSPENDER_ERROR)
                                {
-                                       if (i == hash_len)
-                                       {
-                                               /* get another block from MGF1 */
-                                               if (!mgf1->get_mask(mgf1, hash_len, octets))
-                                               {
-                                                       mgf1->destroy(mgf1);
-                                                       destroy(this);
-                                                       free(used);
-                                                       return NULL;
-                                               }
-                                               octet_count += hash_len;
-                                               i = 0;
-                                       }
-                                       left = octets[i++];
-
-                                       if (num_needed <= 8)
-                                       {
-                                               /* all bits needed to fill the index are in this octet */
-                                               index |= left >> (8 - num_needed);
-                                               num_left = 8 - num_needed;
-                                               num_needed = 0;
-                                               left &= 0xff >> (8 - num_left);
-                                       }
-                                       else
-                                       {
-                                               /* more than one octet will be needed */
-                                               index |= left << (num_needed - 8);
-                                               num_needed -= 8;
-                                       }
+                                       bitspender->destroy(bitspender);
+                                       destroy(this);
+                                       free(used);
+                                       return NULL;
                                }
                        }
                        while (index >= limit);
@@ -318,12 +345,33 @@ ntru_poly_t *ntru_poly_create(hash_algorithm_t alg, chunk_t seed,
                }
        }
 
-       DBG2(DBG_LIB, "MGF1 generates %u octets to derive %u indices",
-                                  octet_count, get_size(this));
-       mgf1->destroy(mgf1);
+       bitspender->destroy(bitspender);
        free(used);
 
        return &this->public;
 }
 
-EXPORT_FUNCTION_FOR_TESTS(ntru, ntru_poly_create);
+/*
+ * Described in header.
+ */
+ntru_poly_t *ntru_poly_create_from_data(uint16_t *data, uint16_t N, uint16_t q,
+                                                                               uint32_t indices_len_p,
+                                                                               uint32_t indices_len_m,
+                                                                               bool is_product_form)
+{
+       private_ntru_poly_t *this;
+       int i;
+
+       this = ntru_poly_create(N, q, indices_len_p, indices_len_m, is_product_form);
+
+       for (i = 0; i < this->num_indices; i++)
+       {
+               this->indices[i] = data[i];
+       }
+
+       return &this->public;
+}
+
+EXPORT_FUNCTION_FOR_TESTS(ntru, ntru_poly_create_from_seed);
+
+EXPORT_FUNCTION_FOR_TESTS(ntru, ntru_poly_create_from_data);