/*
  File: rsa-generate.c

  Authors:
        Mika Kojo <mkojo@ssh.fi>
        Tatu Ylonen <ylo@cs.hut.fi>
        Tero T Mononen <tmo@ssh.fi>

  Description:

        RSA key generation, adapted from Tatu Ylonen's original SSH
        implementation.

        Descriptions of the RSA algorithm can be found, e.g., in the
        following sources:

  - Bruce Schneier: Applied Cryptography.  John Wiley & Sons, 1994.
  - Jennifer Seberry and Josef Pieprzyk: Cryptography: An Introduction to
    Computer Security.  Prentice-Hall, 1989.
  - Man Young Rhee: Cryptography and Secure Data Communications.  McGraw-Hill,
    1994.
  - R. Rivest, A. Shamir, and L. M. Adleman: Cryptographic Communications
    System and Method.  US Patent 4,405,829, 1983.
  - Hans Riesel: Prime Numbers and Computer Methods for Factorization.
    Birkhauser, 1994.

  Copyright:
        Copyright (c) 1995-2001 SSH Communications Security Corp, Finland.
        All rights reserved.
*/


#include "sshincludes.h"
#include "sshmp.h"
#include "sshgenmp.h"
#include "sshcrypt.h"
#include "sshpk.h"
#include "rsa.h"
#include "sshasn1.h"

#ifdef WITH_RSA

#define SSH_DEBUG_MODULE "SshCryptoRSA"






#define SSH_RSA_MINIMUM_PADDING 10
#define SSH_RSA_MAX_BYTES       65535








/* Given distinct primes p and q (p < q), derives the RSA key components
   n, e, d and u.

   The public exponent e is the smallest odd integer greater than
   2^ebits (ebits >= 1) that is relatively prime to (p-1)(q-1); for
   ebits = 16 this normally yields 65537.  d is the inverse of e modulo
   lcm(p-1, q-1) = (p-1)(q-1) / gcd(p-1, q-1), and u = p^-1 mod q is the
   coefficient used by CRT private-key operations.  n = p * q.

   Returns TRUE on success.  Returns FALSE if any multiprecision step
   produced NaN, including a failed gcd during the exponent search; the
   contents of n, e, d and u are then unspecified. */

static Boolean
derive_rsa_keys(SshMPInt n, SshMPInt e, SshMPInt d, SshMPInt u,
                SshMPInt p, SshMPInt q,
                unsigned int ebits)
{
  SshMPIntStruct p_minus_1, q_minus_1, aux, phi, G, F;
  Boolean rv = TRUE;

  /* Initialize. */
  ssh_mp_init(&p_minus_1);
  ssh_mp_init(&q_minus_1);
  ssh_mp_init(&aux);
  ssh_mp_init(&phi);
  ssh_mp_init(&G);
  ssh_mp_init(&F);

  /* Compute p-1 and q-1. */
  ssh_mp_sub_ui(&p_minus_1, p, 1);
  ssh_mp_sub_ui(&q_minus_1, q, 1);

  /* phi = (p - 1) * (q - 1); the number of positive integers less than p*q
     that are relatively prime to p*q. */
  ssh_mp_mul(&phi, &p_minus_1, &q_minus_1);

  /* G = gcd(p-1, q-1), the number of "spare key sets" for the modulus.
     The smallest G can get is 2.  This tells in practice nothing about
     the safety of primes p and q. */
  ssh_mp_gcd(&G, &p_minus_1, &q_minus_1);

  /* F = phi / G = lcm(p-1, q-1); the number of relatively prime numbers
     per spare key set. */
  ssh_mp_div_q(&F, &phi, &G);

  /* Start e at 2^ebits - 1 so that the first step of the loop below
     yields the odd value 2^ebits + 1. */
  ssh_mp_set_ui(e, 1);
  ssh_mp_mul_2exp(e, e, ebits);
  ssh_mp_sub_ui(e, e, 1);

  /* Keep adding 2 until e is relatively prime to (p-1)(q-1).  The loop
     also stops if the gcd computation fails (aux becomes NaN). */
  do
    {
      ssh_mp_add_ui(e, e, 2);
      ssh_mp_gcd(&aux, e, &phi);
    }
  while (!ssh_mprz_isnan(&aux) && ssh_mp_cmp_ui(&aux, 1) != 0);

  /* d is the multiplicative inverse of e, mod F.  Could also be mod
     (p-1)(q-1); however, we try to choose the smallest possible d. */
  ssh_mp_mod_invert(d, e, &F);

  /* u is the multiplicative inverse of p, mod q, if p < q.  It is used
     when doing private key RSA operations using the chinese remainder
     theorem method. */
  ssh_mp_mod_invert(u, p, q);

  /* n = p * q (the public modulus). */
  ssh_mp_mul(n, p, q);

  /* A NaN aux means the search loop exited without proving that e is
     coprime to phi, so that counts as a failure just like a NaN
     component. */
  if (ssh_mprz_isnan(&aux) || ssh_mprz_isnan(e) ||
      ssh_mprz_isnan(n) || ssh_mprz_isnan(u) || ssh_mprz_isnan(d))
    rv = FALSE;

  /* Clear auxiliary variables. */
  ssh_mp_clear(&p_minus_1);
  ssh_mp_clear(&q_minus_1);
  ssh_mp_clear(&aux);
  ssh_mp_clear(&phi);
  ssh_mp_clear(&G);
  ssh_mp_clear(&F);

  return rv;
}

/* Derives the RSA key components n, e, d and u from the primes p and q,
   using given_e as the requested public exponent.

   If given_e is even it is first bumped to the next odd value.  Then the
   smallest odd e >= that value which is relatively prime to (p-1)(q-1)
   is chosen.  d = e^-1 mod lcm(p-1, q-1), u = p^-1 mod q, n = p * q.

   Returns:
     0  failure: given_e < 3, or a multiprecision step produced NaN
        (including a failed gcd during the exponent search);
     1  success, but the chosen e differs from given_e;
     2  success, e equals given_e. */
static int
derive_rsa_keys_with_e(SshMPInt n, SshMPInt e, SshMPInt d,
                       SshMPInt u, SshMPInt p, SshMPInt q,
                       SshMPInt given_e)
{
  SshMPIntStruct p_minus_1, q_minus_1, aux, phi, G, F;
  int rv;

  /* Initialize. */
  ssh_mp_init(&p_minus_1);
  ssh_mp_init(&q_minus_1);
  ssh_mp_init(&aux);
  ssh_mp_init(&phi);
  ssh_mp_init(&G);
  ssh_mp_init(&F);

  /* Compute p-1 and q-1. */
  ssh_mp_sub_ui(&p_minus_1, p, 1);
  ssh_mp_sub_ui(&q_minus_1, q, 1);

  /* phi = (p - 1) * (q - 1); the number of positive integers less than p*q
     that are relatively prime to p*q. */
  ssh_mp_mul(&phi, &p_minus_1, &q_minus_1);

  /* G = gcd(p-1, q-1), the number of "spare key sets" for the modulus.
     The smallest G can get is 2.  This tells in practice nothing about
     the safety of primes p and q. */
  ssh_mp_gcd(&G, &p_minus_1, &q_minus_1);

  /* F = phi / G = lcm(p-1, q-1); the number of relatively prime numbers
     per spare key set. */
  ssh_mp_div_q(&F, &phi, &G);

  /* Start from the requested exponent; anything below 3 is rejected. */
  ssh_mp_set(e, given_e);
  if (ssh_mp_cmp_ui(e, 3) < 0)
    {
      rv = 0;
      goto failed;
    }

  /* Make e odd, then step back by 2 so that the first loop iteration
     tests the (odd) starting value itself. */
  if ((ssh_mp_get_ui(e) & 0x1) == 0)
    ssh_mp_add_ui(e, e, 1);
  ssh_mp_sub_ui(e, e, 2);

  /* Keep adding 2 until e is relatively prime to (p-1)(q-1).  The loop
     also stops if the gcd computation fails (aux becomes NaN). */
  do
    {
      ssh_mp_add_ui(e, e, 2);
      ssh_mp_gcd(&aux, e, &phi);
    }
  while (!ssh_mprz_isnan(&aux) && ssh_mp_cmp_ui(&aux, 1) != 0);

  /* d is the multiplicative inverse of e, mod F.  Could also be mod
     (p-1)(q-1); however, we try to choose the smallest possible d. */
  ssh_mp_mod_invert(d, e, &F);

  /* u is the multiplicative inverse of p, mod q, if p < q.  It is used
     when doing private key RSA operations using the chinese remainder
     theorem method. */
  ssh_mp_mod_invert(u, p, q);

  /* n = p * q (the public modulus). */
  ssh_mp_mul(n, p, q);

  /* Reject the key if any step failed, in the same way derive_rsa_keys
     does; otherwise report whether e had to be adjusted. */
  if (ssh_mprz_isnan(&aux) || ssh_mprz_isnan(e) ||
      ssh_mprz_isnan(d) || ssh_mprz_isnan(u) || ssh_mprz_isnan(n))
    rv = 0;
  else if (ssh_mp_cmp(e, given_e) != 0)
    rv = 1;
  else
    rv = 2;

failed:

  /* Clear auxiliary variables. */
  ssh_mp_clear(&p_minus_1);
  ssh_mp_clear(&q_minus_1);
  ssh_mp_clear(&aux);
  ssh_mp_clear(&phi);
  ssh_mp_clear(&G);
  ssh_mp_clear(&F);

  return rv;
}

/* Builds a newly allocated SshRSAPrivateKey from the primes p and q and
   the private exponent d.  The remaining components are derived:
   e = d^-1 mod (p-1)(q-1), u = p^-1 mod q and n = p * q.  Used from the
   action make routines.

   Returns the key as void *, or NULL if allocation fails or any
   component of the resulting key is NaN. */
void *ssh_rsa_make_private_key_of_pqd(SshMPInt p, SshMPInt q, SshMPInt d)
{
  SshMPIntStruct p_minus_1, q_minus_1, phi;
  SshRSAPrivateKey *private_key;

  if ((private_key = ssh_malloc(sizeof(*private_key))) == NULL)
    return NULL;

  /* Initialize. */
  ssh_mp_init(&p_minus_1);
  ssh_mp_init(&q_minus_1);
  ssh_mp_init(&phi);

  /* Initialize the private key. */
  ssh_mp_init(&private_key->e);
  ssh_mp_init(&private_key->d);
  ssh_mp_init(&private_key->u);
  ssh_mp_init(&private_key->n);
  ssh_mp_init(&private_key->p);
  ssh_mp_init(&private_key->q);

  /* Compute p-1 and q-1. */
  ssh_mp_sub_ui(&p_minus_1, p, 1);
  ssh_mp_sub_ui(&q_minus_1, q, 1);

  /* Set the p and q. */
  ssh_mp_set(&private_key->p, p);
  ssh_mp_set(&private_key->q, q);

  /* phi = (p - 1) * (q - 1); the number of positive integers less than p*q
     that are relatively prime to p*q. */
  ssh_mp_mul(&phi, &p_minus_1, &q_minus_1);

  /* Recover the public exponent from d.
     NOTE(review): inverting modulo lcm(p-1, q-1) instead of phi would
     give the smallest valid e, matching how d is derived elsewhere in
     this file.  phi is kept so that the recovered e stays unchanged for
     existing callers. */
  ssh_mp_mod_invert(&private_key->e, d, &phi);
  ssh_mp_set(&private_key->d, d);

  /* u is the multiplicative inverse of p, mod q, if p < q.  It is used
     when doing private key RSA operations using the chinese remainder
     theorem method. */
  ssh_mp_mod_invert(&private_key->u, p, q);

  /* n = p * q (the public modulus). */
  ssh_mp_mul(&private_key->n, p, q);

  /* Compute the bit size of the key. */
  private_key->bits = ssh_mp_bit_size(&private_key->n);

  if (ssh_mprz_isnan(&private_key->p) ||
      ssh_mprz_isnan(&private_key->q) ||
      ssh_mprz_isnan(&private_key->e) ||
      ssh_mprz_isnan(&private_key->d) ||
      ssh_mprz_isnan(&private_key->u) ||
      ssh_mprz_isnan(&private_key->n))
    {
      /* Release the components while the structure holding them is
         still valid, and only then free the structure itself. */
      ssh_mp_clear(&private_key->n);
      ssh_mp_clear(&private_key->e);
      ssh_mp_clear(&private_key->d);
      ssh_mp_clear(&private_key->u);
      ssh_mp_clear(&private_key->p);
      ssh_mp_clear(&private_key->q);
      ssh_free(private_key);
      private_key = NULL;
    }

  /* Clear auxiliary variables. */
  ssh_mp_clear(&p_minus_1);
  ssh_mp_clear(&q_minus_1);
  ssh_mp_clear(&phi);

  /* Return the private key object (or NULL on failure). */
  return (void *)private_key;
}

/* Generates a new RSA private key whose modulus n = p * q is about bits
   bits long.  With the default modulus generation, n lies in
   [2^(bits-1), 2^bits).

   If e is NULL, the public exponent is chosen by derive_rsa_keys() with
   ebits = 16.  Otherwise e is passed to derive_rsa_keys_with_e() as the
   requested exponent; that function may adjust it to the next suitable
   odd value.

   Returns the newly allocated key as void *, or NULL on failure.  The
   key is presumably released with ssh_rsa_private_key_free(), as done in
   ssh_rsa_private_key_generate_action(). */

void *ssh_rsa_generate_private_key(unsigned int bits, SshMPInt e)
{
  SshMPIntStruct aux, min, max;
  unsigned int pbits;
#if SSH_USE_OLD_RSA_MODULUS_GENERATION
  unsigned int qbits;
#endif
  int ret;
  Boolean ok = FALSE;
  SshRSAPrivateKey *prv = ssh_malloc(sizeof(*prv));

  if (prv == NULL)
    return NULL;

  /* Initialize our key. */
  ssh_mp_init(&prv->q);
  ssh_mp_init(&prv->p);
  ssh_mp_init(&prv->e);
  ssh_mp_init(&prv->d);
  ssh_mp_init(&prv->u);
  ssh_mp_init(&prv->n);

  /* Auxiliary variables. */
  ssh_mp_init(&aux);
  ssh_mp_init(&min);
  ssh_mp_init(&max);

  /* Compute the number of bits in each prime. */
  pbits = bits / 2;

  /* Generate random number p. */
  ssh_mp_random_prime(&prv->p, pbits);
  if (ssh_mprz_isnan(&prv->p))
    goto cleanup;

  /* Repeat until one finds primes that are distinct.  Only q is
     regenerated; p is kept. */
retry:

#if SSH_USE_OLD_RSA_MODULUS_GENERATION
  qbits = bits - pbits;
  /* Generate random number q. */
  ssh_mp_random_prime(&prv->q, qbits);
#else /* SSH_USE_OLD_RSA_MODULUS_GENERATION */
  /* Form 2^(bits). */
  ssh_mp_set_ui(&aux, 0);
  ssh_mp_set_bit(&aux, bits);
  /* Divide to get the maximum. */
  ssh_mp_div_q(&max, &aux, &prv->p);
  /* Form 2^(bits-1) - 1 (actually 2^(bits-1) for the same effort). */
  ssh_mp_set_ui(&aux, 0);
  ssh_mp_set_bit(&aux, bits-1);
  /* Divide to get the minimum. */
  ssh_mp_div_q(&min, &aux, &prv->p);

  /* Generate a prime number between (min, max). That is,
     we get min < q < max, and then

       min * p < 2^(bits-1) < q * p < max * p < 2^(bits),

     as we desire. The higher bound is trivial, but the lower bound
     follows as min * p < 2^(bits-1) < (min + 1) * p.
     */

  ssh_mp_random_prime_within_interval(&prv->q, &min, &max);
#endif /* SSH_USE_OLD_RSA_MODULUS_GENERATION */

  if (ssh_mprz_isnan(&prv->q))
    goto cleanup;

  /* Sort them so that p < q. */
  ret = ssh_mp_cmp(&prv->p, &prv->q);
  if (ret == 0)
    goto retry;

  if (ret > 0)
    {
      ssh_mp_set(&aux, &prv->p);
      ssh_mp_set(&prv->p, &prv->q);
      ssh_mp_set(&prv->q, &aux);
    }

  /* Make certain p and q are relatively prime (in case one or both were
     false positives...  Though this is quite impossible).  A failed gcd
     (NaN) is treated as an error rather than retried, so the function
     cannot keep looping on it. */
  ssh_mp_gcd(&aux, &prv->p, &prv->q);
  if (ssh_mprz_isnan(&aux))
    goto cleanup;
  if (ssh_mp_cmp_ui(&aux, 1) != 0)
    goto retry;

  if (e == NULL)
    {
      /* Derive the key with the default exponent search starting just
         above 2^16. */
      ok = derive_rsa_keys(&prv->n, &prv->e, &prv->d, &prv->u,
                           &prv->p, &prv->q,
                           16);
    }
  else
    {
      /* Return values 1 and 2 both mean success; any (possibly
         adjusted) e is accepted. */
      ok = (derive_rsa_keys_with_e(&prv->n, &prv->e, &prv->d, &prv->u,
                                   &prv->p, &prv->q,
                                   e) != 0);
    }

  if (ok)
    {
      /* Compute the bit size of the key. */
      prv->bits = ssh_mp_bit_size(&prv->n);
    }

cleanup:
  ssh_mp_clear(&aux);
  ssh_mp_clear(&min);
  ssh_mp_clear(&max);

  if (!ok)
    {
      ssh_mp_clear(&prv->n);
      ssh_mp_clear(&prv->e);
      ssh_mp_clear(&prv->d);
      ssh_mp_clear(&prv->u);
      ssh_mp_clear(&prv->p);
      ssh_mp_clear(&prv->q);
      ssh_free(prv);
      return NULL;
    }

  return (void *)prv;
}


/* Try to handle the given data in a reasonable manner. This can
   generate and define key. */
void *ssh_rsa_private_key_generate_action(void *context)
{
  SshRSAInitCtx *ctx = context;

  if (ssh_mp_cmp_ui(&ctx->d, 0) == 0 ||
      ssh_mp_cmp_ui(&ctx->p, 0) == 0 ||
      ssh_mp_cmp_ui(&ctx->q, 0) == 0)
    {
      /* Generate with e, p and q set. */
      if (ssh_mp_cmp_ui(&ctx->e, 0) != 0 &&
          ssh_mp_cmp_ui(&ctx->p, 0) != 0 &&
          ssh_mp_cmp_ui(&ctx->q, 0) != 0)
        {
          SshRSAPrivateKey *prv ;
          int rv;

          if ((prv = ssh_malloc(sizeof(*prv))) == NULL)
            return NULL;

          ssh_rsa_private_key_init(prv);
          ssh_mp_set(&prv->q, &ctx->q);
          ssh_mp_set(&prv->p, &ctx->p);
          rv = derive_rsa_keys_with_e(&prv->n, &prv->e, &prv->d, &prv->u,
                                      &prv->p, &prv->q, &ctx->e);
          prv->bits = ssh_mp_bit_size(&prv->n);
          if (rv != 0)
            return prv;
          else
            {
              ssh_rsa_private_key_free(prv);
              return NULL;
            }
        }

      /* Cannot generate because no predefined size exists. */
      if (ctx->bits == 0)
        return NULL;

      /* Generate with e set. */
      if (ssh_mp_cmp_ui(&ctx->e, 0) != 0)
        return (void *)ssh_rsa_generate_private_key(ctx->bits,
                                                    &ctx->e);

      /* Just generate from assigned values. */
      return (void *)ssh_rsa_generate_private_key(ctx->bits, NULL);
    }
  else
    {
      if (ssh_mp_cmp_ui(&ctx->d, 0) != 0 &&
          ssh_mp_cmp_ui(&ctx->p, 0) != 0 &&
          ssh_mp_cmp_ui(&ctx->q, 0) != 0)
        {
          if (ssh_mp_cmp_ui(&ctx->e, 0) != 0 &&
              ssh_mp_cmp_ui(&ctx->n, 0) != 0 &&
              ssh_mp_cmp_ui(&ctx->u, 0) != 0)
            {
              return
                ssh_rsa_make_private_key_of_all(&ctx->p, &ctx->q,
                                                &ctx->n, &ctx->e,
                                                &ctx->d, &ctx->u);
            }

          return ssh_rsa_make_private_key_of_pqd(&ctx->p, &ctx->q, &ctx->d);
        }
    }
  return NULL;
}


#endif /* WITH_RSA */
