Use mgf1_bitspender in ntru_poly_create_from_seed
[strongswan.git] / src / libstrongswan / plugins / ntru / ntru_poly.c
index 9cc537f..b5b3898 100644 (file)
@@ -16,8 +16,8 @@
  */
 
 #include "ntru_poly.h"
-#include "ntru_mgf1.h"
 
+#include <crypto/mgf1/mgf1_bitspender.h>
 #include <utils/debug.h>
 #include <utils/test.h>
 
@@ -58,6 +58,11 @@ struct private_ntru_poly_t {
        uint16_t *indices;
 
        /**
+        * Number of indices of the non-zero coefficients
+        */
+       size_t num_indices;
+
+       /**
         * Number of sparse polynomials
         */
        int num_polynomials;
@@ -72,14 +77,7 @@ struct private_ntru_poly_t {
 METHOD(ntru_poly_t, get_size, size_t,
        private_ntru_poly_t *this)
 {
-       int n;
-       size_t size = 0;
-
-       for (n = 0; n < this->num_polynomials; n++)
-       {
-               size += this->indices_len[n].p + this->indices_len[n].m;
-       }
-       return size;
+       return this->num_indices;
 }
 
 METHOD(ntru_poly_t, get_indices, uint16_t*,
@@ -241,29 +239,16 @@ METHOD(ntru_poly_t, destroy, void,
        free(this);
 }
 
-/*
- * Described in header.
+/**
+ * Creates an empty ntru_poly_t object with space allocated for indices
  */
-ntru_poly_t *ntru_poly_create_from_seed(hash_algorithm_t alg, chunk_t seed,
-                                                                               uint8_t c_bits, uint16_t N, uint16_t q,
-                                                                               uint32_t indices_len_p,
-                                                                               uint32_t indices_len_m,
-                                                                               bool is_product_form)
+static private_ntru_poly_t* ntru_poly_create(uint16_t N, uint16_t q,
+                                                                                        uint32_t indices_len_p,
+                                                                                        uint32_t indices_len_m,
+                                                                                        bool is_product_form)
 {
        private_ntru_poly_t *this;
-       size_t hash_len, octet_count = 0, i;
-       uint8_t octets[HASH_SIZE_SHA512], *used, num_left = 0, num_needed;
-       uint16_t index, limit, left = 0;
-       int n, num_indices, index_i = 0;
-       ntru_mgf1_t *mgf1;
-
-       DBG2(DBG_LIB, "MGF1 is seeded with %u bytes", seed.len);
-       mgf1 = ntru_mgf1_create(alg, seed, TRUE);
-       if (!mgf1)
-       {
-           return NULL;
-       }
-       i = hash_len = mgf1->get_hash_size(mgf1);
+       int n;
 
        INIT(this,
                .public = {
@@ -284,6 +269,8 @@ ntru_poly_t *ntru_poly_create_from_seed(hash_algorithm_t alg, chunk_t seed,
                {
                        this->indices_len[n].p = 0xff & indices_len_p;
                        this->indices_len[n].m = 0xff & indices_len_m;
+                       this->num_indices += this->indices_len[n].p +
+                                                                this->indices_len[n].m;
                        indices_len_p >>= 8;
                        indices_len_m >>= 8;
                }
@@ -293,9 +280,34 @@ ntru_poly_t *ntru_poly_create_from_seed(hash_algorithm_t alg, chunk_t seed,
                this->num_polynomials = 1;
                this->indices_len[0].p = indices_len_p;
                this->indices_len[0].m = indices_len_m;
+               this->num_indices = indices_len_p + indices_len_m;
        }
-       this->indices = malloc(sizeof(uint16_t) * get_size(this)),
+       this->indices = malloc(sizeof(uint16_t) * this->num_indices);
+
+       return this;
+}
+
+/*
+ * Described in header.
+ */
+ntru_poly_t *ntru_poly_create_from_seed(hash_algorithm_t alg, chunk_t seed,
+                                                                               uint8_t c_bits, uint16_t N, uint16_t q,
+                                                                               uint32_t indices_len_p,
+                                                                               uint32_t indices_len_m,
+                                                                               bool is_product_form)
+{
+       private_ntru_poly_t *this;
+       int n, num_indices, index_i = 0;
+       uint32_t index, limit;
+       uint8_t *used;
+       mgf1_bitspender_t *bitspender;
 
+       bitspender = mgf1_bitspender_create(alg, seed, TRUE);
+       if (!bitspender)
+       {
+           return NULL;
+       }
+       this = ntru_poly_create(N, q, indices_len_p, indices_len_m, is_product_form);
        used = malloc(N);
        limit = N * ((1 << c_bits) / N);
 
@@ -311,43 +323,13 @@ ntru_poly_t *ntru_poly_create_from_seed(hash_algorithm_t alg, chunk_t seed,
                        /* generate a random candidate index with a size of c_bits */           
                        do
                        {
-                               /* use any leftover bits first */
-                               index = num_left ? left << (c_bits - num_left) : 0;
-
-                               /* get the rest of the bits needed from new octets */
-                               num_needed = c_bits - num_left;
-
-                               while (num_needed)
+                               index = bitspender->get_bits(bitspender, c_bits);
+                               if (index == MGF1_BITSPENDER_ERROR)
                                {
-                                       if (i == hash_len)
-                                       {
-                                               /* get another block from MGF1 */
-                                               if (!mgf1->get_mask(mgf1, hash_len, octets))
-                                               {
-                                                       mgf1->destroy(mgf1);
-                                                       destroy(this);
-                                                       free(used);
-                                                       return NULL;
-                                               }
-                                               octet_count += hash_len;
-                                               i = 0;
-                                       }
-                                       left = octets[i++];
-
-                                       if (num_needed <= 8)
-                                       {
-                                               /* all bits needed to fill the index are in this octet */
-                                               index |= left >> (8 - num_needed);
-                                               num_left = 8 - num_needed;
-                                               num_needed = 0;
-                                               left &= 0xff >> (8 - num_left);
-                                       }
-                                       else
-                                       {
-                                               /* more than one octet will be needed */
-                                               index |= left << (num_needed - 8);
-                                               num_needed -= 8;
-                                       }
+                                       bitspender->destroy(bitspender);
+                                       destroy(this);
+                                       free(used);
+                                       return NULL;
                                }
                        }
                        while (index >= limit);
@@ -363,9 +345,7 @@ ntru_poly_t *ntru_poly_create_from_seed(hash_algorithm_t alg, chunk_t seed,
                }
        }
 
-       DBG2(DBG_LIB, "MGF1 generates %u octets to derive %u indices",
-                                  octet_count, get_size(this));
-       mgf1->destroy(mgf1);
+       bitspender->destroy(bitspender);
        free(used);
 
        return &this->public;
@@ -380,41 +360,11 @@ ntru_poly_t *ntru_poly_create_from_data(uint16_t *data, uint16_t N, uint16_t q,
                                                                                bool is_product_form)
 {
        private_ntru_poly_t *this;
-       int n, i, num_indices;
-
-       INIT(this,
-               .public = {
-                       .get_size = _get_size,
-                       .get_indices = _get_indices,
-                       .get_array = _get_array,
-                       .ring_mult = _ring_mult,
-                       .destroy = _destroy,
-               },
-               .N = N,
-               .q = q,
-       );
+       int i;
 
-       if (is_product_form)
-       {
-               this->num_polynomials = 3;
-               for (n = 0; n < 3; n++)
-               {
-                       this->indices_len[n].p = 0xff & indices_len_p;
-                       this->indices_len[n].m = 0xff & indices_len_m;
-                       indices_len_p >>= 8;
-                       indices_len_m >>= 8;
-               }
-       }
-       else
-       {
-               this->num_polynomials = 1;
-               this->indices_len[0].p = indices_len_p;
-               this->indices_len[0].m = indices_len_m;
-       }
-       num_indices = get_size(this);
+       this = ntru_poly_create(N, q, indices_len_p, indices_len_m, is_product_form);
 
-       this->indices = malloc(sizeof(uint16_t) * num_indices);
-       for (i = 0; i < num_indices; i++)
+       for (i = 0; i < this->num_indices; i++)
        {
                this->indices[i] = data[i];
        }