Use mgf1_bitspender in ntru_poly_create_from_seed
[strongswan.git] / src / libstrongswan / plugins / ntru / ntru_poly.c
index 3f754f2..b5b3898 100644 (file)
@@ -16,8 +16,8 @@
  */
 
 #include "ntru_poly.h"
-#include "ntru_mgf1.h"
 
+#include <crypto/mgf1/mgf1_bitspender.h>
 #include <utils/debug.h>
 #include <utils/test.h>
 
@@ -239,11 +239,29 @@ METHOD(ntru_poly_t, destroy, void,
        free(this);
 }
 
-static void init_indices(private_ntru_poly_t *this, bool is_product_form,
-                                                uint32_t indices_len_p, uint32_t indices_len_m)
+/**
+ * Creates an empty ntru_poly_t object with space allocated for indices
+ */
+static private_ntru_poly_t* ntru_poly_create(uint16_t N, uint16_t q,
+                                                                                        uint32_t indices_len_p,
+                                                                                        uint32_t indices_len_m,
+                                                                                        bool is_product_form)
 {
+       private_ntru_poly_t *this;
        int n;
 
+       INIT(this,
+               .public = {
+                       .get_size = _get_size,
+                       .get_indices = _get_indices,
+                       .get_array = _get_array,
+                       .ring_mult = _ring_mult,
+                       .destroy = _destroy,
+               },
+               .N = N,
+               .q = q,
+       );
+
        if (is_product_form)
        {
                this->num_polynomials = 3;
@@ -265,6 +283,8 @@ static void init_indices(private_ntru_poly_t *this, bool is_product_form,
                this->num_indices = indices_len_p + indices_len_m;
        }
        this->indices = malloc(sizeof(uint16_t) * this->num_indices);
+
+       return this;
 }
 
 /*
@@ -277,33 +297,17 @@ ntru_poly_t *ntru_poly_create_from_seed(hash_algorithm_t alg, chunk_t seed,
                                                                                bool is_product_form)
 {
        private_ntru_poly_t *this;
-       size_t hash_len, octet_count = 0, i;
-       uint8_t octets[HASH_SIZE_SHA512], *used, num_left = 0, num_needed;
-       uint16_t index, limit, left = 0;
        int n, num_indices, index_i = 0;
-       ntru_mgf1_t *mgf1;
+       uint32_t index, limit;
+       uint8_t *used;
+       mgf1_bitspender_t *bitspender;
 
-       DBG2(DBG_LIB, "MGF1 is seeded with %u bytes", seed.len);
-       mgf1 = ntru_mgf1_create(alg, seed, TRUE);
-       if (!mgf1)
+       bitspender = mgf1_bitspender_create(alg, seed, TRUE);
+       if (!bitspender)
        {
            return NULL;
        }
-       i = hash_len = mgf1->get_hash_size(mgf1);
-
-       INIT(this,
-               .public = {
-                       .get_size = _get_size,
-                       .get_indices = _get_indices,
-                       .get_array = _get_array,
-                       .ring_mult = _ring_mult,
-                       .destroy = _destroy,
-               },
-               .N = N,
-               .q = q,
-       );
-
-       init_indices(this, is_product_form, indices_len_p, indices_len_m);
+       this = ntru_poly_create(N, q, indices_len_p, indices_len_m, is_product_form);
        used = malloc(N);
        limit = N * ((1 << c_bits) / N);
 
@@ -319,43 +323,13 @@ ntru_poly_t *ntru_poly_create_from_seed(hash_algorithm_t alg, chunk_t seed,
                        /* generate a random candidate index with a size of c_bits */           
                        do
                        {
-                               /* use any leftover bits first */
-                               index = num_left ? left << (c_bits - num_left) : 0;
-
-                               /* get the rest of the bits needed from new octets */
-                               num_needed = c_bits - num_left;
-
-                               while (num_needed)
+                               index = bitspender->get_bits(bitspender, c_bits);
+                               if (index == MGF1_BITSPENDER_ERROR)
                                {
-                                       if (i == hash_len)
-                                       {
-                                               /* get another block from MGF1 */
-                                               if (!mgf1->get_mask(mgf1, hash_len, octets))
-                                               {
-                                                       mgf1->destroy(mgf1);
-                                                       destroy(this);
-                                                       free(used);
-                                                       return NULL;
-                                               }
-                                               octet_count += hash_len;
-                                               i = 0;
-                                       }
-                                       left = octets[i++];
-
-                                       if (num_needed <= 8)
-                                       {
-                                               /* all bits needed to fill the index are in this octet */
-                                               index |= left >> (8 - num_needed);
-                                               num_left = 8 - num_needed;
-                                               num_needed = 0;
-                                               left &= 0xff >> (8 - num_left);
-                                       }
-                                       else
-                                       {
-                                               /* more than one octet will be needed */
-                                               index |= left << (num_needed - 8);
-                                               num_needed -= 8;
-                                       }
+                                       bitspender->destroy(bitspender);
+                                       destroy(this);
+                                       free(used);
+                                       return NULL;
                                }
                        }
                        while (index >= limit);
@@ -371,9 +345,7 @@ ntru_poly_t *ntru_poly_create_from_seed(hash_algorithm_t alg, chunk_t seed,
                }
        }
 
-       DBG2(DBG_LIB, "MGF1 generates %u octets to derive %u indices",
-                                  octet_count, this->num_indices);
-       mgf1->destroy(mgf1);
+       bitspender->destroy(bitspender);
        free(used);
 
        return &this->public;
@@ -390,19 +362,8 @@ ntru_poly_t *ntru_poly_create_from_data(uint16_t *data, uint16_t N, uint16_t q,
        private_ntru_poly_t *this;
        int i;
 
-       INIT(this,
-               .public = {
-                       .get_size = _get_size,
-                       .get_indices = _get_indices,
-                       .get_array = _get_array,
-                       .ring_mult = _ring_mult,
-                       .destroy = _destroy,
-               },
-               .N = N,
-               .q = q,
-       );
+       this = ntru_poly_create(N, q, indices_len_p, indices_len_m, is_product_form);
 
-       init_indices(this, is_product_form, indices_len_p, indices_len_m);
        for (i = 0; i < this->num_indices; i++)
        {
                this->indices[i] = data[i];