openssl: Add functions to determine missing RSA private key parameters
authorTobias Brunner <tobias@strongswan.org>
Fri, 22 Sep 2017 13:47:11 +0000 (15:47 +0200)
committerTobias Brunner <tobias@strongswan.org>
Wed, 8 Nov 2017 15:48:10 +0000 (16:48 +0100)
We only need n, e, and d.  The parameters for the Chinese remainder
algorithm and even p and q can be determined from these.

src/libstrongswan/plugins/openssl/openssl_rsa_private_key.c

index f2c320f..234f559 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2008-2016 Tobias Brunner
+ * Copyright (C) 2008-2017 Tobias Brunner
  * Copyright (C) 2009 Martin Willi
  * HSR Hochschule fuer Technik Rapperswil
  *
@@ -37,6 +37,7 @@
 OPENSSL_KEY_FALLBACK(RSA, key, n, e, d)
 OPENSSL_KEY_FALLBACK(RSA, factors, p, q)
 OPENSSL_KEY_FALLBACK(RSA, crt_params, dmp1, dmq1, iqmp)
+#define BN_secure_new() BN_new()
 #endif
 
 typedef struct private_openssl_rsa_private_key_t private_openssl_rsa_private_key_t;
@@ -400,6 +401,195 @@ private_key_t *openssl_rsa_private_key_create(EVP_PKEY *key, bool engine)
        return &this->public.key;
 }
 
+/**
+ * Recover the primes from n, e and d using the algorithm described in
+ * Appendix C of NIST SP 800-56B.
+ */
+static bool calculate_pq(BIGNUM *n, BIGNUM *e, BIGNUM *d,
+                                                BIGNUM **p, BIGNUM **q)
+{
+       BN_CTX *ctx;
+       BIGNUM *k, *r, *g, *y, *n1, *x;
+       int i, t, j;
+       bool success = FALSE;
+
+       ctx = BN_CTX_new();
+       if (!ctx)
+       {
+               return FALSE;
+       }
+       BN_CTX_start(ctx);
+       k = BN_CTX_get(ctx);
+       r = BN_CTX_get(ctx);
+       g = BN_CTX_get(ctx);
+       y = BN_CTX_get(ctx);
+       n1 = BN_CTX_get(ctx);
+       x = BN_CTX_get(ctx);
+       if (!x)
+       {
+               goto error;
+       }
+       /* k = (d * e) - 1 */
+       if (!BN_mul(k, d, e, ctx) || !BN_sub(k, k, BN_value_one()))
+       {
+               goto error;
+       }
+       /* k must be even */
+       if (BN_is_odd(k))
+       {
+               goto error;
+       }
+       /* k = 2^t * r, where r is the largest odd integer dividing k, and t >= 1 */
+       if (!BN_copy(r, k))
+       {
+               goto error;
+       }
+       for (t = 0; !BN_is_odd(r); t++)
+       {       /* r = r/2 */
+               if (!BN_rshift(r, r, 1))
+               {
+                       goto error;
+               }
+       }
+       /* we need n-1 below */
+       if (!BN_sub(n1, n, BN_value_one()))
+       {
+               goto error;
+       }
+       for (i = 0; i < 100; i++)
+       {       /* generate random integer g in [0, n-1] */
+               if (!BN_pseudo_rand_range(g, n))
+               {
+                       goto error;
+               }
+               /* y = g^r mod n */
+               if (!BN_mod_exp(y, g, r, n, ctx))
+               {
+                       goto error;
+               }
+               /* try again if y == 1 or y == n-1 */
+               if (BN_is_one(y) || BN_cmp(y, n1) == 0)
+               {
+                       continue;
+               }
+               for (j = 0; j < t; j++)
+               {       /* x = y^2 mod n */
+                       if (!BN_mod_sqr(x, y, n, ctx))
+                       {
+                               goto error;
+                       }
+                       /* stop if x == 1 */
+                       if (BN_is_one(x))
+                       {
+                               goto done;
+                       }
+                       /* retry with new g if x = n-1 */
+                       if (BN_cmp(x, n1) == 0)
+                       {
+                               break;
+                       }
+                       /* y = x */
+                       if (!BN_copy(y, x))
+                       {
+                               goto error;
+                       }
+               }
+       }
+       goto error;
+
+done:
+       /* p = gcd(y-1, n) */
+       if (!BN_sub(y, y, BN_value_one()))
+       {
+               goto error;
+       }
+       *p = BN_secure_new();
+       if (!BN_gcd(*p, y, n, ctx))
+       {
+               BN_clear_free(*p);
+               goto error;
+       }
+       /* q = n/p */
+       *q = BN_secure_new();
+       if (!BN_div(*q, NULL, n, *p, ctx))
+       {
+               BN_clear_free(*p);
+               BN_clear_free(*q);
+               goto error;
+       }
+       success = TRUE;
+
+error:
+       BN_CTX_end(ctx);
+       BN_CTX_free(ctx);
+       return success;
+}
+
+/**
+ * Calculates dp = d (mod p-1) or dq = d (mod q-1) for the Chinese remainder
+ * algorithm.
+ */
+static BIGNUM *dmodpq1(BIGNUM *d, BIGNUM *pq)
+{
+       BN_CTX *ctx;
+       BIGNUM *res = NULL, *pq1;
+
+       ctx = BN_CTX_new();
+       if (!ctx)
+       {
+               return NULL;
+       }
+       BN_CTX_start(ctx);
+       pq1 = BN_CTX_get(ctx);
+       /* p|q - 1 */
+       if (!BN_sub(pq1, pq, BN_value_one()))
+       {
+               goto error;
+       }
+       /* d (mod p|q -1) */
+       res = BN_secure_new();
+       if (!BN_mod(res, d, pq1, ctx))
+       {
+               BN_clear_free(res);
+               res = NULL;
+               goto error;
+       }
+
+error:
+       BN_CTX_end(ctx);
+       BN_CTX_free(ctx);
+       return res;
+}
+
+/**
+ * Calculates qinv = q^-1 (mod p) for the Chinese remainder algorithm.
+ */
+static BIGNUM *qinv(BIGNUM *q, BIGNUM *p)
+{
+       BN_CTX *ctx;
+       BIGNUM *res = NULL;
+
+       ctx = BN_CTX_new();
+       if (!ctx)
+       {
+               return NULL;
+       }
+       BN_CTX_start(ctx);
+       /* q^-1 (mod p) */
+       res = BN_secure_new();
+       if (!BN_mod_inverse(res, q, p, ctx))
+       {
+               BN_clear_free(res);
+               res = NULL;
+               goto error;
+       }
+
+error:
+       BN_CTX_end(ctx);
+       BN_CTX_free(ctx);
+       return res;
+}
+
 /*
  * See header
  */
@@ -458,7 +648,7 @@ openssl_rsa_private_key_t *openssl_rsa_private_key_load(key_type_t type,
                        return &this->public;
                }
        }
-       else if (n.ptr && e.ptr && d.ptr && p.ptr && q.ptr && coeff.ptr)
+       else if (n.ptr && e.ptr && d.ptr)
        {
                BIGNUM *bn_n, *bn_e, *bn_d, *bn_p, *bn_q;
                BIGNUM *dmp1 = NULL, *dmq1 = NULL, *iqmp = NULL;
@@ -470,32 +660,56 @@ openssl_rsa_private_key_t *openssl_rsa_private_key_load(key_type_t type,
                bn_d = BN_bin2bn((const u_char*)d.ptr, d.len, NULL);
                if (!RSA_set0_key(this->rsa, bn_n, bn_e, bn_d))
                {
-                       destroy(this);
-                       return NULL;
+                       goto error;
 
                }
-               bn_p = BN_bin2bn((const u_char*)p.ptr, p.len, NULL);
-               bn_q = BN_bin2bn((const u_char*)q.ptr, q.len, NULL);
+               if (p.ptr && q.ptr)
+               {
+                       bn_p = BN_bin2bn((const u_char*)p.ptr, p.len, NULL);
+                       bn_q = BN_bin2bn((const u_char*)q.ptr, q.len, NULL);
+               }
+               else
+               {
+                       if (!calculate_pq(bn_n, bn_e, bn_d, &bn_p, &bn_q))
+                       {
+                               goto error;
+                       }
+               }
                if (!RSA_set0_factors(this->rsa, bn_p, bn_q))
                {
-                       destroy(this);
-                       return NULL;
+                       goto error;
                }
                if (exp1.ptr)
                {
                        dmp1 = BN_bin2bn((const u_char*)exp1.ptr, exp1.len, NULL);
                }
+               else
+               {
+                       dmp1 = dmodpq1(bn_d, bn_p);
+               }
                if (exp2.ptr)
                {
                        dmq1 = BN_bin2bn((const u_char*)exp2.ptr, exp2.len, NULL);
                }
-               iqmp = BN_bin2bn((const u_char*)coeff.ptr, coeff.len, NULL);
+               else
+               {
+                       dmq1 = dmodpq1(bn_d, bn_q);
+               }
+               if (coeff.ptr)
+               {
+                       iqmp = BN_bin2bn((const u_char*)coeff.ptr, coeff.len, NULL);
+               }
+               else
+               {
+                       iqmp = qinv(bn_q, bn_p);
+               }
                if (RSA_set0_crt_params(this->rsa, dmp1, dmq1, iqmp) &&
                        RSA_check_key(this->rsa) == 1)
                {
                        return &this->public;
                }
        }
+error:
        destroy(this);
        return NULL;
 }