gmp: Fix compatibility with older libgmp releases
[strongswan.git] / src / libstrongswan / plugins / gmp / gmp_rsa_private_key.c
1 /*
2 * Copyright (C) 2017 Tobias Brunner
3 * Copyright (C) 2005 Jan Hutter
4 * Copyright (C) 2005-2009 Martin Willi
5 * Copyright (C) 2012 Andreas Steffen
6 * HSR Hochschule fuer Technik Rapperswil
7 *
8 * This program is free software; you can redistribute it and/or modify it
9 * under the terms of the GNU General Public License as published by the
10 * Free Software Foundation; either version 2 of the License, or (at your
11 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
12 *
13 * This program is distributed in the hope that it will be useful, but
14 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
15 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
16 * for more details.
17 */
18
19 #include <gmp.h>
20 #include <sys/stat.h>
21 #include <unistd.h>
22 #include <string.h>
23
24 #include "gmp_rsa_private_key.h"
25 #include "gmp_rsa_public_key.h"
26
27 #include <utils/debug.h>
28 #include <asn1/oid.h>
29 #include <asn1/asn1.h>
30 #include <asn1/asn1_parser.h>
31 #include <credentials/keys/signature_params.h>
32
/*
 * When the build detected mpz_powm_sec(), GMP's timing-attack resistant
 * modular exponentiation, route every mpz_powm() call in this file to it.
 * Note that mpz_powm_sec() requires a positive exponent and an odd modulus.
 */
#if defined(HAVE_MPZ_POWM_SEC)
#	undef mpz_powm
#	define mpz_powm mpz_powm_sec
#endif /* HAVE_MPZ_POWM_SEC */

/**
 * Public exponent e = 65537 (2^16 + 1) assigned to newly generated keys.
 */
#define PUBLIC_EXPONENT 0x10001
42
43 typedef struct private_gmp_rsa_private_key_t private_gmp_rsa_private_key_t;
44
45 /**
46 * Private data of a gmp_rsa_private_key_t object.
47 */
48 struct private_gmp_rsa_private_key_t {
49 /**
50 * Public interface for this signer.
51 */
52 gmp_rsa_private_key_t public;
53
54 /**
55 * Public modulus.
56 */
57 mpz_t n;
58
59 /**
60 * Public exponent.
61 */
62 mpz_t e;
63
64 /**
65 * Private prime 1.
66 */
67 mpz_t p;
68
69 /**
70 * Private Prime 2.
71 */
72 mpz_t q;
73
74 /**
75 * Carmichael function m = lambda(n) = lcm(p-1,q-1).
76 */
77 mpz_t m;
78
79 /**
80 * Private exponent and optional secret sharing polynomial coefficients.
81 */
82 mpz_t *d;
83
84 /**
85 * Private exponent 1.
86 */
87 mpz_t exp1;
88
89 /**
90 * Private exponent 2.
91 */
92 mpz_t exp2;
93
94 /**
95 * Private coefficient.
96 */
97 mpz_t coeff;
98
99 /**
100 * Total number of private key shares
101 */
102 u_int shares;
103
104 /**
105 * Secret sharing threshold
106 */
107 u_int threshold;
108
109 /**
110 * Optional verification key (threshold > 1).
111 */
112 mpz_t v;
113
114 /**
115 * Keysize in bytes.
116 */
117 size_t k;
118
119 /**
120 * reference count
121 */
122 refcount_t ref;
123 };
124
125 /**
126 * Convert a MP integer into a chunk_t
127 */
128 chunk_t gmp_mpz_to_chunk(const mpz_t value)
129 {
130 chunk_t n;
131
132 n.len = 1 + mpz_sizeinbase(value, 2) / BITS_PER_BYTE;
133 n.ptr = mpz_export(NULL, NULL, 1, n.len, 1, 0, value);
134 if (n.ptr == NULL)
135 { /* if we have zero in "value", gmp returns NULL */
136 n.len = 0;
137 }
138 return n;
139 }
140
141 /**
142 * Auxiliary function overwriting private key material with zero bytes
143 */
144 static void mpz_clear_sensitive(mpz_t z)
145 {
146 size_t len = mpz_size(z) * GMP_LIMB_BITS / BITS_PER_BYTE;
147 uint8_t *zeros = alloca(len);
148
149 memset(zeros, 0, len);
150 /* overwrite mpz_t with zero bytes before clearing it */
151 mpz_import(z, len, 1, 1, 1, 0, zeros);
152 mpz_clear(z);
153 }
154
155 /**
156 * Create a mpz prime of at least prime_size
157 */
158 static status_t compute_prime(size_t prime_size, bool safe, mpz_t *p, mpz_t *q)
159 {
160 rng_t *rng;
161 chunk_t random_bytes;
162 int count = 0;
163
164 rng = lib->crypto->create_rng(lib->crypto, RNG_TRUE);
165 if (!rng)
166 {
167 DBG1(DBG_LIB, "no RNG of quality %N found", rng_quality_names,
168 RNG_TRUE);
169 return FAILED;
170 }
171
172 mpz_init(*p);
173 mpz_init(*q);
174
175 do
176 {
177 if (!rng->allocate_bytes(rng, prime_size, &random_bytes))
178 {
179 DBG1(DBG_LIB, "failed to allocate random prime");
180 mpz_clear(*p);
181 mpz_clear(*q);
182 rng->destroy(rng);
183 return FAILED;
184 }
185
186 /* make sure the two most significant bits are set */
187 if (safe)
188 {
189 random_bytes.ptr[0] &= 0x7F;
190 random_bytes.ptr[0] |= 0x60;
191 mpz_import(*q, random_bytes.len, 1, 1, 1, 0, random_bytes.ptr);
192 do
193 {
194 count++;
195 mpz_nextprime (*q, *q);
196 mpz_mul_ui(*p, *q, 2);
197 mpz_add_ui(*p, *p, 1);
198 }
199 while (mpz_probab_prime_p(*p, 10) == 0);
200 DBG2(DBG_LIB, "safe prime found after %d iterations", count);
201 }
202 else
203 {
204 random_bytes.ptr[0] |= 0xC0;
205 mpz_import(*p, random_bytes.len, 1, 1, 1, 0, random_bytes.ptr);
206 mpz_nextprime (*p, *p);
207 }
208 chunk_clear(&random_bytes);
209 }
210
211 /* check if the prime isn't too large */
212 while (((mpz_sizeinbase(*p, 2) + 7) / 8) > prime_size);
213
214 rng->destroy(rng);
215
216 /* additionally return p-1 */
217 mpz_sub_ui(*q, *p, 1);
218
219 return SUCCESS;
220 }
221
222 /**
223 * PKCS#1 RSADP function
224 */
225 static chunk_t rsadp(private_gmp_rsa_private_key_t *this, chunk_t data)
226 {
227 mpz_t t1, t2;
228 chunk_t decrypted;
229
230 mpz_init(t1);
231 mpz_init(t2);
232
233 mpz_import(t1, data.len, 1, 1, 1, 0, data.ptr);
234
235 mpz_powm(t2, t1, this->exp1, this->p); /* m1 = c^dP mod p */
236 mpz_powm(t1, t1, this->exp2, this->q); /* m2 = c^dQ mod Q */
237 mpz_sub(t2, t2, t1); /* h = qInv (m1 - m2) mod p */
238 mpz_mod(t2, t2, this->p);
239 mpz_mul(t2, t2, this->coeff);
240 mpz_mod(t2, t2, this->p);
241
242 mpz_mul(t2, t2, this->q); /* m = m2 + h q */
243 mpz_add(t1, t1, t2);
244
245 decrypted.len = this->k;
246 decrypted.ptr = mpz_export(NULL, NULL, 1, decrypted.len, 1, 0, t1);
247 if (decrypted.ptr == NULL)
248 {
249 decrypted.len = 0;
250 }
251
252 mpz_clear_sensitive(t1);
253 mpz_clear_sensitive(t2);
254
255 return decrypted;
256 }
257
258 /**
259 * PKCS#1 RSASP1 function
260 */
261 static chunk_t rsasp1(private_gmp_rsa_private_key_t *this, chunk_t data)
262 {
263 return rsadp(this, data);
264 }
265
266 /**
267 * Build a signature using the PKCS#1 EMSA scheme
268 */
269 static bool build_emsa_pkcs1_signature(private_gmp_rsa_private_key_t *this,
270 hash_algorithm_t hash_algorithm,
271 chunk_t data, chunk_t *signature)
272 {
273 chunk_t digestInfo = chunk_empty;
274 chunk_t em;
275
276 if (hash_algorithm != HASH_UNKNOWN)
277 {
278 hasher_t *hasher;
279 chunk_t hash;
280 int hash_oid = hasher_algorithm_to_oid(hash_algorithm);
281
282 if (hash_oid == OID_UNKNOWN)
283 {
284 return FALSE;
285 }
286
287 hasher = lib->crypto->create_hasher(lib->crypto, hash_algorithm);
288 if (!hasher || !hasher->allocate_hash(hasher, data, &hash))
289 {
290 DESTROY_IF(hasher);
291 return FALSE;
292 }
293 hasher->destroy(hasher);
294
295 /* build DER-encoded digestInfo */
296 digestInfo = asn1_wrap(ASN1_SEQUENCE, "mm",
297 asn1_algorithmIdentifier(hash_oid),
298 asn1_simple_object(ASN1_OCTET_STRING, hash)
299 );
300 chunk_free(&hash);
301 data = digestInfo;
302 }
303
304 if (data.len > this->k - 3)
305 {
306 free(digestInfo.ptr);
307 DBG1(DBG_LIB, "unable to sign %d bytes using a %dbit key", data.len,
308 mpz_sizeinbase(this->n, 2));
309 return FALSE;
310 }
311
312 /* build chunk to rsa-decrypt:
313 * EM = 0x00 || 0x01 || PS || 0x00 || T.
314 * PS = 0xFF padding, with length to fill em
315 * T = encoded_hash
316 */
317 em.len = this->k;
318 em.ptr = malloc(em.len);
319
320 /* fill em with padding */
321 memset(em.ptr, 0xFF, em.len);
322 /* set magic bytes */
323 *(em.ptr) = 0x00;
324 *(em.ptr+1) = 0x01;
325 *(em.ptr + em.len - data.len - 1) = 0x00;
326 /* set DER-encoded hash */
327 memcpy(em.ptr + em.len - data.len, data.ptr, data.len);
328
329 /* build signature */
330 *signature = rsasp1(this, em);
331
332 free(digestInfo.ptr);
333 free(em.ptr);
334
335 return TRUE;
336 }
337
338 /**
339 * Build a signature using the PKCS#1 EMSA PSS scheme
340 */
341 static bool build_emsa_pss_signature(private_gmp_rsa_private_key_t *this,
342 rsa_pss_params_t *params, chunk_t data,
343 chunk_t *signature)
344 {
345 ext_out_function_t xof;
346 hasher_t *hasher = NULL;
347 rng_t *rng = NULL;
348 xof_t *mgf = NULL;
349 chunk_t hash, salt = chunk_empty, m, ps, db, dbmask, em;
350 size_t embits, emlen, maskbits;
351 bool success = FALSE;
352
353 if (!params)
354 {
355 return FALSE;
356 }
357 xof = xof_mgf1_from_hash_algorithm(params->mgf1_hash);
358 if (xof == XOF_UNDEFINED)
359 {
360 DBG1(DBG_LIB, "%N is not supported for MGF1", hash_algorithm_names,
361 params->mgf1_hash);
362 return FALSE;
363 }
364 /* emBits = modBits - 1 */
365 embits = mpz_sizeinbase(this->n, 2) - 1;
366 /* emLen = ceil(emBits/8) */
367 emlen = (embits + 7) / BITS_PER_BYTE;
368 /* mHash = Hash(M) */
369 hasher = lib->crypto->create_hasher(lib->crypto, params->hash);
370 if (!hasher)
371 {
372 DBG1(DBG_LIB, "hash algorithm %N not supported",
373 hash_algorithm_names, params->hash);
374 return FALSE;
375 }
376 hash = chunk_alloca(hasher->get_hash_size(hasher));
377 if (!hasher->get_hash(hasher, data, hash.ptr))
378 {
379 goto error;
380 }
381
382 salt.len = hash.len;
383 if (params->salt.len)
384 {
385 salt = params->salt;
386 }
387 else if (params->salt_len > RSA_PSS_SALT_LEN_DEFAULT)
388 {
389 salt.len = params->salt_len;
390 }
391 if (emlen < (hash.len + salt.len + 2))
392 { /* too long */
393 goto error;
394 }
395 if (salt.len && !params->salt.len)
396 {
397 salt = chunk_alloca(salt.len);
398 rng = lib->crypto->create_rng(lib->crypto, RNG_STRONG);
399 if (!rng || !rng->get_bytes(rng, salt.len, salt.ptr))
400 {
401 goto error;
402 }
403 }
404 /* M' = 0x0000000000000000 | mHash | salt */
405 m = chunk_cata("ccc",
406 chunk_from_chars(0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00),
407 hash, salt);
408 /* H = Hash(M') */
409 if (!hasher->get_hash(hasher, m, hash.ptr))
410 {
411 goto error;
412 }
413 /* PS = 00...<padding depending on hash and salt length> */
414 ps = chunk_alloca(emlen - salt.len - hash.len - 2);
415 memset(ps.ptr, 0, ps.len);
416 /* DB = PS | 0x01 | salt */
417 db = chunk_cata("ccc", ps, chunk_from_chars(0x01), salt);
418 /* dbMask = MGF(H, emLen - hLen - 1) */
419 mgf = lib->crypto->create_xof(lib->crypto, xof);
420 dbmask = chunk_alloca(db.len);
421 if (!mgf)
422 {
423 DBG1(DBG_LIB, "%N not supported", ext_out_function_names, xof);
424 goto error;
425 }
426 if (!mgf->set_seed(mgf, hash) ||
427 !mgf->get_bytes(mgf, dbmask.len, dbmask.ptr))
428 {
429 goto error;
430 }
431 /* maskedDB = DB xor dbMask */
432 memxor(db.ptr, dbmask.ptr, db.len);
433 /* zero out unused bits */
434 maskbits = (8 * emlen) - embits;
435 if (maskbits)
436 {
437 db.ptr[0] &= (0xff >> maskbits);
438 }
439 /* EM = maskedDB | H | 0xbc */
440 em = chunk_cata("ccc", db, hash, chunk_from_chars(0xbc));
441 /* S = RSASP1(K, EM) */
442 *signature = rsasp1(this, em);
443 success = TRUE;
444
445 error:
446 DESTROY_IF(hasher);
447 DESTROY_IF(rng);
448 DESTROY_IF(mgf);
449 return success;
450 }
451
452 METHOD(private_key_t, get_type, key_type_t,
453 private_gmp_rsa_private_key_t *this)
454 {
455 return KEY_RSA;
456 }
457
458 METHOD(private_key_t, sign, bool,
459 private_gmp_rsa_private_key_t *this, signature_scheme_t scheme,
460 void *params, chunk_t data, chunk_t *signature)
461 {
462 switch (scheme)
463 {
464 case SIGN_RSA_EMSA_PKCS1_NULL:
465 return build_emsa_pkcs1_signature(this, HASH_UNKNOWN, data, signature);
466 case SIGN_RSA_EMSA_PKCS1_SHA2_224:
467 return build_emsa_pkcs1_signature(this, HASH_SHA224, data, signature);
468 case SIGN_RSA_EMSA_PKCS1_SHA2_256:
469 return build_emsa_pkcs1_signature(this, HASH_SHA256, data, signature);
470 case SIGN_RSA_EMSA_PKCS1_SHA2_384:
471 return build_emsa_pkcs1_signature(this, HASH_SHA384, data, signature);
472 case SIGN_RSA_EMSA_PKCS1_SHA2_512:
473 return build_emsa_pkcs1_signature(this, HASH_SHA512, data, signature);
474 case SIGN_RSA_EMSA_PKCS1_SHA3_224:
475 return build_emsa_pkcs1_signature(this, HASH_SHA3_224, data, signature);
476 case SIGN_RSA_EMSA_PKCS1_SHA3_256:
477 return build_emsa_pkcs1_signature(this, HASH_SHA3_256, data, signature);
478 case SIGN_RSA_EMSA_PKCS1_SHA3_384:
479 return build_emsa_pkcs1_signature(this, HASH_SHA3_384, data, signature);
480 case SIGN_RSA_EMSA_PKCS1_SHA3_512:
481 return build_emsa_pkcs1_signature(this, HASH_SHA3_512, data, signature);
482 case SIGN_RSA_EMSA_PKCS1_SHA1:
483 return build_emsa_pkcs1_signature(this, HASH_SHA1, data, signature);
484 case SIGN_RSA_EMSA_PKCS1_MD5:
485 return build_emsa_pkcs1_signature(this, HASH_MD5, data, signature);
486 case SIGN_RSA_EMSA_PSS:
487 return build_emsa_pss_signature(this, params, data, signature);
488 default:
489 DBG1(DBG_LIB, "signature scheme %N not supported in RSA",
490 signature_scheme_names, scheme);
491 return FALSE;
492 }
493 }
494
495 METHOD(private_key_t, decrypt, bool,
496 private_gmp_rsa_private_key_t *this, encryption_scheme_t scheme,
497 chunk_t crypto, chunk_t *plain)
498 {
499 chunk_t em, stripped;
500 bool success = FALSE;
501
502 if (scheme != ENCRYPT_RSA_PKCS1)
503 {
504 DBG1(DBG_LIB, "encryption scheme %N not supported",
505 encryption_scheme_names, scheme);
506 return FALSE;
507 }
508 /* rsa decryption using PKCS#1 RSADP */
509 stripped = em = rsadp(this, crypto);
510
511 /* PKCS#1 v1.5 8.1 encryption-block formatting (EB = 00 || 02 || PS || 00 || D) */
512
513 /* check for hex pattern 00 02 in decrypted message */
514 if ((*stripped.ptr++ != 0x00) || (*(stripped.ptr++) != 0x02))
515 {
516 DBG1(DBG_LIB, "incorrect padding - probably wrong rsa key");
517 goto end;
518 }
519 stripped.len -= 2;
520
521 /* the plaintext data starts after first 0x00 byte */
522 while (stripped.len-- > 0 && *stripped.ptr++ != 0x00)
523
524 if (stripped.len == 0)
525 {
526 DBG1(DBG_LIB, "no plaintext data");
527 goto end;
528 }
529
530 *plain = chunk_clone(stripped);
531 success = TRUE;
532
533 end:
534 chunk_clear(&em);
535 return success;
536 }
537
538 METHOD(private_key_t, get_keysize, int,
539 private_gmp_rsa_private_key_t *this)
540 {
541 return mpz_sizeinbase(this->n, 2);
542 }
543
544 METHOD(private_key_t, get_public_key, public_key_t*,
545 private_gmp_rsa_private_key_t *this)
546 {
547 chunk_t n, e;
548 public_key_t *public;
549
550 n = gmp_mpz_to_chunk(this->n);
551 e = gmp_mpz_to_chunk(this->e);
552
553 public = lib->creds->create(lib->creds, CRED_PUBLIC_KEY, KEY_RSA,
554 BUILD_RSA_MODULUS, n, BUILD_RSA_PUB_EXP, e, BUILD_END);
555 chunk_free(&n);
556 chunk_free(&e);
557
558 return public;
559 }
560
561 METHOD(private_key_t, get_encoding, bool,
562 private_gmp_rsa_private_key_t *this, cred_encoding_type_t type,
563 chunk_t *encoding)
564 {
565 chunk_t n, e, d, p, q, exp1, exp2, coeff;
566 bool success;
567
568 n = gmp_mpz_to_chunk(this->n);
569 e = gmp_mpz_to_chunk(this->e);
570 d = gmp_mpz_to_chunk(*this->d);
571 p = gmp_mpz_to_chunk(this->p);
572 q = gmp_mpz_to_chunk(this->q);
573 exp1 = gmp_mpz_to_chunk(this->exp1);
574 exp2 = gmp_mpz_to_chunk(this->exp2);
575 coeff = gmp_mpz_to_chunk(this->coeff);
576
577 success = lib->encoding->encode(lib->encoding,
578 type, NULL, encoding, CRED_PART_RSA_MODULUS, n,
579 CRED_PART_RSA_PUB_EXP, e, CRED_PART_RSA_PRIV_EXP, d,
580 CRED_PART_RSA_PRIME1, p, CRED_PART_RSA_PRIME2, q,
581 CRED_PART_RSA_EXP1, exp1, CRED_PART_RSA_EXP2, exp2,
582 CRED_PART_RSA_COEFF, coeff, CRED_PART_END);
583 chunk_free(&n);
584 chunk_free(&e);
585 chunk_clear(&d);
586 chunk_clear(&p);
587 chunk_clear(&q);
588 chunk_clear(&exp1);
589 chunk_clear(&exp2);
590 chunk_clear(&coeff);
591
592 return success;
593 }
594
595 METHOD(private_key_t, get_fingerprint, bool,
596 private_gmp_rsa_private_key_t *this, cred_encoding_type_t type, chunk_t *fp)
597 {
598 chunk_t n, e;
599 bool success;
600
601 if (lib->encoding->get_cache(lib->encoding, type, this, fp))
602 {
603 return TRUE;
604 }
605 n = gmp_mpz_to_chunk(this->n);
606 e = gmp_mpz_to_chunk(this->e);
607
608 success = lib->encoding->encode(lib->encoding, type, this, fp,
609 CRED_PART_RSA_MODULUS, n, CRED_PART_RSA_PUB_EXP, e, CRED_PART_END);
610 chunk_free(&n);
611 chunk_free(&e);
612
613 return success;
614 }
615
616 METHOD(private_key_t, get_ref, private_key_t*,
617 private_gmp_rsa_private_key_t *this)
618 {
619 ref_get(&this->ref);
620 return &this->public.key;
621 }
622
623 METHOD(private_key_t, destroy, void,
624 private_gmp_rsa_private_key_t *this)
625 {
626 if (ref_put(&this->ref))
627 {
628 int i;
629
630 mpz_clear(this->n);
631 mpz_clear(this->e);
632 mpz_clear(this->v);
633 mpz_clear_sensitive(this->p);
634 mpz_clear_sensitive(this->q);
635 mpz_clear_sensitive(this->m);
636 mpz_clear_sensitive(this->exp1);
637 mpz_clear_sensitive(this->exp2);
638 mpz_clear_sensitive(this->coeff);
639
640 for (i = 0; i < this->threshold; i++)
641 {
642 mpz_clear_sensitive(*this->d + i);
643 }
644 free(this->d);
645
646 lib->encoding->clear_cache(lib->encoding, this);
647 free(this);
648 }
649 }
650
651 /**
652 * Check the loaded key if it is valid and usable
653 */
654 static status_t check(private_gmp_rsa_private_key_t *this)
655 {
656 mpz_t u, p1, q1;
657 status_t status = SUCCESS;
658
659 /* PKCS#1 1.5 section 6 requires modulus to have at least 12 octets.
660 * We actually require more (for security).
661 */
662 if (this->k < 512 / BITS_PER_BYTE)
663 {
664 DBG1(DBG_LIB, "key shorter than 512 bits");
665 return FAILED;
666 }
667
668 /* we picked a max modulus size to simplify buffer allocation */
669 if (this->k > 8192 / BITS_PER_BYTE)
670 {
671 DBG1(DBG_LIB, "key larger than 8192 bits");
672 return FAILED;
673 }
674
675 mpz_init(u);
676 mpz_init(p1);
677 mpz_init(q1);
678
679 /* precompute p1 = p-1 and q1 = q-1 */
680 mpz_sub_ui(p1, this->p, 1);
681 mpz_sub_ui(q1, this->q, 1);
682
683 /* check that n == p * q */
684 mpz_mul(u, this->p, this->q);
685 if (mpz_cmp(u, this->n) != 0)
686 {
687 status = FAILED;
688 }
689
690 /* check that e divides neither p-1 nor q-1 */
691 mpz_mod(u, p1, this->e);
692 if (mpz_cmp_ui(u, 0) == 0)
693 {
694 status = FAILED;
695 }
696
697 mpz_mod(u, q1, this->e);
698 if (mpz_cmp_ui(u, 0) == 0)
699 {
700 status = FAILED;
701 }
702
703 /* check that d is e^-1 (mod lcm(p-1, q-1)) */
704 /* see PKCS#1v2, aka RFC 2437, for the "lcm" */
705 mpz_lcm(this->m, p1, q1);
706 mpz_mul(u, *this->d, this->e);
707 mpz_mod(u, u, this->m);
708 if (mpz_cmp_ui(u, 1) != 0)
709 {
710 status = FAILED;
711 }
712
713 /* check that exp1 is d mod (p-1) */
714 mpz_mod(u, *this->d, p1);
715 if (mpz_cmp(u, this->exp1) != 0)
716 {
717 status = FAILED;
718 }
719
720 /* check that exp2 is d mod (q-1) */
721 mpz_mod(u, *this->d, q1);
722 if (mpz_cmp(u, this->exp2) != 0)
723 {
724 status = FAILED;
725 }
726
727 /* check that coeff is (q^-1) mod p */
728 mpz_mul(u, this->coeff, this->q);
729 mpz_mod(u, u, this->p);
730 if (mpz_cmp_ui(u, 1) != 0)
731 {
732 status = FAILED;
733 }
734
735 mpz_clear_sensitive(u);
736 mpz_clear_sensitive(p1);
737 mpz_clear_sensitive(q1);
738
739 if (status != SUCCESS)
740 {
741 DBG1(DBG_LIB, "key integrity tests failed");
742 }
743 return status;
744 }
745
746 /**
747 * Internal generic constructor
748 */
749 static private_gmp_rsa_private_key_t *gmp_rsa_private_key_create_empty(void)
750 {
751 private_gmp_rsa_private_key_t *this;
752
753 INIT(this,
754 .public = {
755 .key = {
756 .get_type = _get_type,
757 .sign = _sign,
758 .decrypt = _decrypt,
759 .get_keysize = _get_keysize,
760 .get_public_key = _get_public_key,
761 .equals = private_key_equals,
762 .belongs_to = private_key_belongs_to,
763 .get_fingerprint = _get_fingerprint,
764 .has_fingerprint = private_key_has_fingerprint,
765 .get_encoding = _get_encoding,
766 .get_ref = _get_ref,
767 .destroy = _destroy,
768 },
769 },
770 .threshold = 1,
771 .ref = 1,
772 );
773 return this;
774 }
775
776 /**
777 * See header.
778 */
779 gmp_rsa_private_key_t *gmp_rsa_private_key_gen(key_type_t type, va_list args)
780 {
781 private_gmp_rsa_private_key_t *this;
782 u_int key_size = 0, shares = 0, threshold = 1;
783 bool safe_prime = FALSE, rng_failed = FALSE, invert_failed = FALSE;
784 mpz_t p, q, p1, q1, d;
785 ;
786
787 while (TRUE)
788 {
789 switch (va_arg(args, builder_part_t))
790 {
791 case BUILD_KEY_SIZE:
792 key_size = va_arg(args, u_int);
793 continue;
794 case BUILD_SAFE_PRIMES:
795 safe_prime = TRUE;
796 continue;
797 case BUILD_SHARES:
798 shares = va_arg(args, u_int);
799 continue;
800 case BUILD_THRESHOLD:
801 threshold = va_arg(args, u_int);
802 continue;
803 case BUILD_END:
804 break;
805 default:
806 return NULL;
807 }
808 break;
809 }
810 if (!key_size)
811 {
812 return NULL;
813 }
814 key_size = key_size / BITS_PER_BYTE;
815
816 /* Get values of primes p and q */
817 if (compute_prime(key_size/2, safe_prime, &p, &p1) != SUCCESS)
818 {
819 return NULL;
820 }
821 if (compute_prime(key_size/2, safe_prime, &q, &q1) != SUCCESS)
822 {
823 mpz_clear(p);
824 mpz_clear(p1);
825 return NULL;
826 }
827
828 /* Swapping Primes so p is larger then q */
829 if (mpz_cmp(p, q) < 0)
830 {
831 mpz_swap(p, q);
832 mpz_swap(p1, q1);
833 }
834
835 /* Create and initialize RSA private key object */
836 this = gmp_rsa_private_key_create_empty();
837 this->shares = shares;
838 this->threshold = threshold;
839 this->d = malloc(threshold * sizeof(mpz_t));
840 *this->p = *p;
841 *this->q = *q;
842
843 mpz_init_set_ui(this->e, PUBLIC_EXPONENT);
844 mpz_init(this->n);
845 mpz_init(this->m);
846 mpz_init(this->exp1);
847 mpz_init(this->exp2);
848 mpz_init(this->coeff);
849 mpz_init(this->v);
850 mpz_init(d);
851
852 mpz_mul(this->n, p, q); /* n = p*q */
853 mpz_lcm(this->m, p1, q1); /* m = lcm(p-1,q-1) */
854 mpz_invert(d, this->e, this->m); /* e has an inverse mod m */
855 mpz_mod(this->exp1, d, p1); /* exp1 = d mod p-1 */
856 mpz_mod(this->exp2, d, q1); /* exp2 = d mod q-1 */
857 mpz_invert(this->coeff, q, p); /* coeff = q^-1 mod p */
858
859 invert_failed = mpz_cmp_ui(this->m, 0) == 0 ||
860 mpz_cmp_ui(this->coeff, 0) == 0;
861
862 /* store secret exponent d */
863 (*this->d)[0] = *d;
864
865 /* generate and store random coefficients of secret sharing polynomial */
866 if (threshold > 1)
867 {
868 rng_t *rng;
869 chunk_t random_bytes;
870 mpz_t u;
871 int i;
872
873 rng = lib->crypto->create_rng(lib->crypto, RNG_TRUE);
874 mpz_init(u);
875
876 for (i = 1; i < threshold; i++)
877 {
878 mpz_init(d);
879
880 if (!rng->allocate_bytes(rng, key_size, &random_bytes))
881 {
882 rng_failed = TRUE;
883 continue;
884 }
885 mpz_import(d, random_bytes.len, 1, 1, 1, 0, random_bytes.ptr);
886 mpz_mod(d, d, this->m);
887 (*this->d)[i] = *d;
888 chunk_clear(&random_bytes);
889 }
890
891 /* generate verification key v as a square number */
892 do
893 {
894 if (!rng->allocate_bytes(rng, key_size, &random_bytes))
895 {
896 rng_failed = TRUE;
897 break;
898 }
899 mpz_import(this->v, random_bytes.len, 1, 1, 1, 0, random_bytes.ptr);
900 mpz_mul(this->v, this->v, this->v);
901 mpz_mod(this->v, this->v, this->n);
902 mpz_gcd(u, this->v, this->n);
903 chunk_free(&random_bytes);
904 }
905 while (mpz_cmp_ui(u, 1) != 0);
906
907 mpz_clear(u);
908 rng->destroy(rng);
909 }
910
911 mpz_clear_sensitive(p1);
912 mpz_clear_sensitive(q1);
913
914 if (rng_failed || invert_failed)
915 {
916 DBG1(DBG_LIB, "rsa key generation failed");
917 destroy(this);
918 return NULL;
919 }
920
921 /* set key size in bytes */
922 this->k = key_size;
923
924 return &this->public;
925 }
926
927 /**
928 * Recover the primes from n, e and d using the algorithm described in
929 * Appendix C of NIST SP 800-56B.
930 */
931 static bool calculate_pq(private_gmp_rsa_private_key_t *this)
932 {
933 gmp_randstate_t rstate;
934 mpz_t k, r, g, y, n1, x;
935 int i, t, j;
936 bool success = FALSE;
937
938 gmp_randinit_default(rstate);
939 mpz_init(k);
940 mpz_init(r);
941 mpz_init(g);
942 mpz_init(y);
943 mpz_init(n1);
944 mpz_init(x);
945 /* k = (d * e) - 1 */
946 mpz_mul(k, *this->d, this->e);
947 mpz_sub_ui(k, k, 1);
948 if (mpz_odd_p(k))
949 {
950 goto error;
951 }
952 /* k = 2^t * r, where r is the largest odd integer dividing k, and t >= 1 */
953 mpz_set(r, k);
954 for (t = 0; !mpz_odd_p(r); t++)
955 { /* r = r/2 */
956 mpz_divexact_ui(r, r, 2);
957 }
958 /* we need n-1 below */
959 mpz_sub_ui(n1, this->n, 1);
960 for (i = 0; i < 100; i++)
961 { /* generate random integer g in [0, n-1] */
962 mpz_urandomm(g, rstate, this->n);
963 /* y = g^r mod n */
964 mpz_powm(y, g, r, this->n);
965 /* try again if y == 1 or y == n-1 */
966 if (mpz_cmp_ui(y, 1) == 0 || mpz_cmp(y, n1) == 0)
967 {
968 continue;
969 }
970 for (j = 0; j < t; j++)
971 { /* x = y^2 mod n */
972 mpz_powm_ui(x, y, 2, this->n);
973 /* stop if x == 1 */
974 if (mpz_cmp_ui(x, 1) == 0)
975 {
976 goto done;
977 }
978 /* retry with new g if x = n-1 */
979 if (mpz_cmp(x, n1) == 0)
980 {
981 break;
982 }
983 /* y = x */
984 mpz_set(y, x);
985 }
986 }
987 goto error;
988
989 done:
990 /* p = gcd(y-1, n) */
991 mpz_sub_ui(y, y, 1);
992 mpz_gcd(this->p, y, this->n);
993 /* q = n/p */
994 mpz_divexact(this->q, this->n, this->p);
995 success = TRUE;
996
997 error:
998 mpz_clear_sensitive(k);
999 mpz_clear_sensitive(r);
1000 mpz_clear_sensitive(g);
1001 mpz_clear_sensitive(y);
1002 mpz_clear_sensitive(x);
1003 mpz_clear(n1);
1004 gmp_randclear(rstate);
1005 return success;
1006 }
1007
1008 /**
1009 * See header.
1010 */
1011 gmp_rsa_private_key_t *gmp_rsa_private_key_load(key_type_t type, va_list args)
1012 {
1013 private_gmp_rsa_private_key_t *this;
1014 chunk_t n, e, d, p, q, exp1, exp2, coeff;
1015
1016 n = e = d = p = q = exp1 = exp2 = coeff = chunk_empty;
1017 while (TRUE)
1018 {
1019 switch (va_arg(args, builder_part_t))
1020 {
1021 case BUILD_RSA_MODULUS:
1022 n = va_arg(args, chunk_t);
1023 continue;
1024 case BUILD_RSA_PUB_EXP:
1025 e = va_arg(args, chunk_t);
1026 continue;
1027 case BUILD_RSA_PRIV_EXP:
1028 d = va_arg(args, chunk_t);
1029 continue;
1030 case BUILD_RSA_PRIME1:
1031 p = va_arg(args, chunk_t);
1032 continue;
1033 case BUILD_RSA_PRIME2:
1034 q = va_arg(args, chunk_t);
1035 continue;
1036 case BUILD_RSA_EXP1:
1037 exp1 = va_arg(args, chunk_t);
1038 continue;
1039 case BUILD_RSA_EXP2:
1040 exp2 = va_arg(args, chunk_t);
1041 continue;
1042 case BUILD_RSA_COEFF:
1043 coeff = va_arg(args, chunk_t);
1044 continue;
1045 case BUILD_END:
1046 break;
1047 default:
1048 return NULL;
1049 }
1050 break;
1051 }
1052
1053 this = gmp_rsa_private_key_create_empty();
1054
1055 this->d = malloc(sizeof(mpz_t));
1056 mpz_init(this->n);
1057 mpz_init(this->e);
1058 mpz_init(*this->d);
1059 mpz_init(this->p);
1060 mpz_init(this->q);
1061 mpz_init(this->m);
1062 mpz_init(this->exp1);
1063 mpz_init(this->exp2);
1064 mpz_init(this->coeff);
1065 mpz_init(this->v);
1066
1067 mpz_import(this->n, n.len, 1, 1, 1, 0, n.ptr);
1068 mpz_import(this->e, e.len, 1, 1, 1, 0, e.ptr);
1069 mpz_import(*this->d, d.len, 1, 1, 1, 0, d.ptr);
1070 if (p.len)
1071 {
1072 mpz_import(this->p, p.len, 1, 1, 1, 0, p.ptr);
1073 }
1074 if (q.len)
1075 {
1076 mpz_import(this->q, q.len, 1, 1, 1, 0, q.ptr);
1077 }
1078 if (!p.len && !q.len)
1079 { /* p and q missing in key, recalculate from n, e and d */
1080 if (!calculate_pq(this))
1081 {
1082 destroy(this);
1083 return NULL;
1084 }
1085 }
1086 else if (!p.len)
1087 { /* p missing in key, recalculate: p = n / q */
1088 mpz_divexact(this->p, this->n, this->q);
1089 }
1090 else if (!q.len)
1091 { /* q missing in key, recalculate: q = n / p */
1092 mpz_divexact(this->q, this->n, this->p);
1093 }
1094 if (!exp1.len)
1095 { /* exp1 missing in key, recalculate: exp1 = d mod (p-1) */
1096 mpz_sub_ui(this->exp1, this->p, 1);
1097 mpz_mod(this->exp1, *this->d, this->exp1);
1098 }
1099 else
1100 {
1101 mpz_import(this->exp1, exp1.len, 1, 1, 1, 0, exp1.ptr);
1102 }
1103 if (!exp2.len)
1104 { /* exp2 missing in key, recalculate: exp2 = d mod (q-1) */
1105 mpz_sub_ui(this->exp2, this->q, 1);
1106 mpz_mod(this->exp2, *this->d, this->exp2);
1107 }
1108 else
1109 {
1110 mpz_import(this->exp2, exp2.len, 1, 1, 1, 0, exp2.ptr);
1111 }
1112 if (!coeff.len)
1113 { /* coeff missing in key, recalculate: coeff = q^-1 mod p */
1114 mpz_invert(this->coeff, this->q, this->p);
1115 }
1116 else
1117 {
1118 mpz_import(this->coeff, coeff.len, 1, 1, 1, 0, coeff.ptr);
1119 }
1120 this->k = (mpz_sizeinbase(this->n, 2) + 7) / BITS_PER_BYTE;
1121 if (check(this) != SUCCESS)
1122 {
1123 destroy(this);
1124 return NULL;
1125 }
1126 return &this->public;
1127 }