Optimized initialisation of indices
[strongswan.git] / src / libstrongswan / plugins / ntru / ntru_poly.c
1 /*
2 * Copyright (C) 2014 Andreas Steffen
3 * HSR Hochschule fuer Technik Rapperswil
4 *
5 * Copyright (C) 2009-2013 Security Innovation
6 *
7 * This program is free software; you can redistribute it and/or modify it
8 * under the terms of the GNU General Public License as published by the
9 * Free Software Foundation; either version 2 of the License, or (at your
10 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
11 *
12 * This program is distributed in the hope that it will be useful, but
13 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
14 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15 * for more details.
16 */
17
18 #include "ntru_poly.h"
19 #include "ntru_mgf1.h"
20
21 #include <utils/debug.h>
22 #include <utils/test.h>
23
typedef struct private_ntru_poly_t private_ntru_poly_t;

/**
 * Coefficient counts of one sparse polynomial with entries in {-1, 0, +1}
 */
typedef struct indices_len_t {
	/** number of +1 coefficients */
	int p;
	/** number of -1 coefficients */
	int m;
} indices_len_t;
34
35 /**
36 * Private data of an ntru_poly_t object.
37 */
38 struct private_ntru_poly_t {
39
40 /**
41 * Public ntru_poly_t interface.
42 */
43 ntru_poly_t public;
44
45 /**
46 * Ring dimension equal to the number of polynomial coefficients
47 */
48 uint16_t N;
49
50 /**
51 * Large modulus
52 */
53 uint16_t q;
54
55 /**
56 * Array containing the indices of the non-zero coefficients
57 */
58 uint16_t *indices;
59
60 /**
61 * Number of indices of the non-zero coefficients
62 */
63 size_t num_indices;
64
65 /**
66 * Number of sparse polynomials
67 */
68 int num_polynomials;
69
70 /**
71 * Number of nonzero coefficients for up to 3 sparse polynomials
72 */
73 indices_len_t indices_len[3];
74
75 };
76
77 METHOD(ntru_poly_t, get_size, size_t,
78 private_ntru_poly_t *this)
79 {
80 return this->num_indices;
81 }
82
83 METHOD(ntru_poly_t, get_indices, uint16_t*,
84 private_ntru_poly_t *this)
85 {
86 return this->indices;
87 }
88
89 /**
90 * Multiplication of polynomial a with a sparse polynomial b given by
91 * the indices of its +1 and -1 coefficients results in polynomial c.
92 * This is a convolution operation
93 */
94 static void ring_mult_i(uint16_t *a, indices_len_t len, uint16_t *indices,
95 uint16_t N, uint16_t mod_q_mask, uint16_t *t,
96 uint16_t *c)
97 {
98 int i, j, k;
99
100 /* initialize temporary array t */
101 for (k = 0; k < N; k++)
102 {
103 t[k] = 0;
104 }
105
106 /* t[(i+k)%N] = sum i=0 through N-1 of a[i], for b[k] = -1 */
107 for (j = len.p; j < len.p + len.m; j++)
108 {
109 k = indices[j];
110 for (i = 0; k < N; ++i, ++k)
111 {
112 t[k] += a[i];
113 }
114 for (k = 0; i < N; ++i, ++k)
115 {
116 t[k] += a[i];
117 }
118 }
119
120 /* t[(i+k)%N] = -(sum i=0 through N-1 of a[i] for b[k] = -1) */
121 for (k = 0; k < N; k++)
122 {
123 t[k] = -t[k];
124 }
125
126 /* t[(i+k)%N] += sum i=0 through N-1 of a[i] for b[k] = +1 */
127 for (j = 0; j < len.p; j++)
128 {
129 k = indices[j];
130 for (i = 0; k < N; ++i, ++k)
131 {
132 t[k] += a[i];
133 }
134 for (k = 0; i < N; ++i, ++k)
135 {
136 t[k] += a[i];
137 }
138 }
139
140 /* c = (a * b) mod q */
141 for (k = 0; k < N; k++)
142 {
143 c[k] = t[k] & mod_q_mask;
144 }
145 }
146
147 METHOD(ntru_poly_t, get_array, void,
148 private_ntru_poly_t *this, uint16_t *array)
149 {
150 uint16_t *t, *bi;
151 uint16_t mod_q_mask = this->q - 1;
152 indices_len_t len;
153 int i;
154
155 /* form polynomial F or F1 */
156 memset(array, 0x00, this->N * sizeof(uint16_t));
157 bi = this->indices;
158 len = this->indices_len[0];
159 for (i = 0; i < len.p + len.m; i++)
160 {
161 array[bi[i]] = (i < len.p) ? 1 : mod_q_mask;
162 }
163
164 if (this->num_polynomials == 3)
165 {
166 /* allocate temporary array t */
167 t = malloc(this->N * sizeof(uint16_t));
168
169 /* form F1 * F2 */
170 bi += len.p + len.m;
171 len = this->indices_len[1];
172 ring_mult_i(array, len, bi, this->N, mod_q_mask, t, array);
173
174 /* form (F1 * F2) + F3 */
175 bi += len.p + len.m;
176 len = this->indices_len[2];
177 for (i = 0; i < len.p + len.m; i++)
178 {
179 if (i < len.p)
180 {
181 array[bi[i]] += 1;
182 }
183 else
184 {
185 array[bi[i]] -= 1;
186 }
187 array[bi[i]] &= mod_q_mask;
188 }
189 free(t);
190 }
191 }
192
193 METHOD(ntru_poly_t, ring_mult, void,
194 private_ntru_poly_t *this, uint16_t *a, uint16_t *c)
195 {
196 uint16_t *t1, *t2;
197 uint16_t *bi = this->indices;
198 uint16_t mod_q_mask = this->q - 1;
199 int i;
200
201 /* allocate temporary array t1 */
202 t1 = malloc(this->N * sizeof(uint16_t));
203
204 if (this->num_polynomials == 1)
205 {
206 ring_mult_i(a, this->indices_len[0], bi, this->N, mod_q_mask, t1, c);
207 }
208 else
209 {
210 /* allocate temporary array t2 */
211 t2 = malloc(this->N * sizeof(uint16_t));
212
213 /* t1 = a * b1 */
214 ring_mult_i(a, this->indices_len[0], bi, this->N, mod_q_mask, t1, t1);
215
216 /* t1 = (a * b1) * b2 */
217 bi += this->indices_len[0].p + this->indices_len[0].m;
218 ring_mult_i(t1, this->indices_len[1], bi, this->N, mod_q_mask, t2, t1);
219
220 /* t2 = a * b3 */
221 bi += this->indices_len[1].p + this->indices_len[1].m;
222 ring_mult_i(a, this->indices_len[2], bi, this->N, mod_q_mask, t2, t2);
223
224 /* c = (a * b1 * b2) + (a * b3) */
225 for (i = 0; i < this->N; i++)
226 {
227 c[i] = (t1[i] + t2[i]) & mod_q_mask;
228 }
229 free(t2);
230 }
231 free(t1);
232 }
233
234 METHOD(ntru_poly_t, destroy, void,
235 private_ntru_poly_t *this)
236 {
237 memwipe(this->indices, sizeof(uint16_t) * get_size(this));
238 free(this->indices);
239 free(this);
240 }
241
242 static void init_indices(private_ntru_poly_t *this, bool is_product_form,
243 uint32_t indices_len_p, uint32_t indices_len_m)
244 {
245 int n;
246
247 if (is_product_form)
248 {
249 this->num_polynomials = 3;
250 for (n = 0; n < 3; n++)
251 {
252 this->indices_len[n].p = 0xff & indices_len_p;
253 this->indices_len[n].m = 0xff & indices_len_m;
254 this->num_indices += this->indices_len[n].p +
255 this->indices_len[n].m;
256 indices_len_p >>= 8;
257 indices_len_m >>= 8;
258 }
259 }
260 else
261 {
262 this->num_polynomials = 1;
263 this->indices_len[0].p = indices_len_p;
264 this->indices_len[0].m = indices_len_m;
265 this->num_indices = indices_len_p + indices_len_m;
266 }
267 this->indices = malloc(sizeof(uint16_t) * this->num_indices);
268 }
269
270 /*
271 * Described in header.
272 */
273 ntru_poly_t *ntru_poly_create_from_seed(hash_algorithm_t alg, chunk_t seed,
274 uint8_t c_bits, uint16_t N, uint16_t q,
275 uint32_t indices_len_p,
276 uint32_t indices_len_m,
277 bool is_product_form)
278 {
279 private_ntru_poly_t *this;
280 size_t hash_len, octet_count = 0, i;
281 uint8_t octets[HASH_SIZE_SHA512], *used, num_left = 0, num_needed;
282 uint16_t index, limit, left = 0;
283 int n, num_indices, index_i = 0;
284 ntru_mgf1_t *mgf1;
285
286 DBG2(DBG_LIB, "MGF1 is seeded with %u bytes", seed.len);
287 mgf1 = ntru_mgf1_create(alg, seed, TRUE);
288 if (!mgf1)
289 {
290 return NULL;
291 }
292 i = hash_len = mgf1->get_hash_size(mgf1);
293
294 INIT(this,
295 .public = {
296 .get_size = _get_size,
297 .get_indices = _get_indices,
298 .get_array = _get_array,
299 .ring_mult = _ring_mult,
300 .destroy = _destroy,
301 },
302 .N = N,
303 .q = q,
304 );
305
306 init_indices(this, is_product_form, indices_len_p, indices_len_m);
307 used = malloc(N);
308 limit = N * ((1 << c_bits) / N);
309
310 /* generate indices for all polynomials */
311 for (n = 0; n < this->num_polynomials; n++)
312 {
313 memset(used, 0, N);
314 num_indices = this->indices_len[n].p + this->indices_len[n].m;
315
316 /* generate indices for a single polynomial */
317 while (num_indices)
318 {
319 /* generate a random candidate index with a size of c_bits */
320 do
321 {
322 /* use any leftover bits first */
323 index = num_left ? left << (c_bits - num_left) : 0;
324
325 /* get the rest of the bits needed from new octets */
326 num_needed = c_bits - num_left;
327
328 while (num_needed)
329 {
330 if (i == hash_len)
331 {
332 /* get another block from MGF1 */
333 if (!mgf1->get_mask(mgf1, hash_len, octets))
334 {
335 mgf1->destroy(mgf1);
336 destroy(this);
337 free(used);
338 return NULL;
339 }
340 octet_count += hash_len;
341 i = 0;
342 }
343 left = octets[i++];
344
345 if (num_needed <= 8)
346 {
347 /* all bits needed to fill the index are in this octet */
348 index |= left >> (8 - num_needed);
349 num_left = 8 - num_needed;
350 num_needed = 0;
351 left &= 0xff >> (8 - num_left);
352 }
353 else
354 {
355 /* more than one octet will be needed */
356 index |= left << (num_needed - 8);
357 num_needed -= 8;
358 }
359 }
360 }
361 while (index >= limit);
362
363 /* form index and check if unique */
364 index %= N;
365 if (!used[index])
366 {
367 used[index] = 1;
368 this->indices[index_i++] = index;
369 num_indices--;
370 }
371 }
372 }
373
374 DBG2(DBG_LIB, "MGF1 generates %u octets to derive %u indices",
375 octet_count, this->num_indices);
376 mgf1->destroy(mgf1);
377 free(used);
378
379 return &this->public;
380 }
381
382 /*
383 * Described in header.
384 */
385 ntru_poly_t *ntru_poly_create_from_data(uint16_t *data, uint16_t N, uint16_t q,
386 uint32_t indices_len_p,
387 uint32_t indices_len_m,
388 bool is_product_form)
389 {
390 private_ntru_poly_t *this;
391 int i;
392
393 INIT(this,
394 .public = {
395 .get_size = _get_size,
396 .get_indices = _get_indices,
397 .get_array = _get_array,
398 .ring_mult = _ring_mult,
399 .destroy = _destroy,
400 },
401 .N = N,
402 .q = q,
403 );
404
405 init_indices(this, is_product_form, indices_len_p, indices_len_m);
406 for (i = 0; i < this->num_indices; i++)
407 {
408 this->indices[i] = data[i];
409 }
410
411 return &this->public;
412 }
413
/* expose both constructors to the unit tests (macro from utils/test.h) */
EXPORT_FUNCTION_FOR_TESTS(ntru, ntru_poly_create_from_seed);

EXPORT_FUNCTION_FOR_TESTS(ntru, ntru_poly_create_from_data);