/* rsa_gen.c */
/* Copyright (C) 1995 Eric Young (eay@mincom.oz.au).
 * All rights reserved.
 * Copyright remains Eric Young's, and as such any Copyright notices in
 * the code are not to be removed.
 * See the COPYRIGHT file in the SSLeay distribution for more details.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include "crypto.h"
#include "bn.h"
#include "md5.h"
#include "X509.h"

/* The quick sieve approach to weeding out composite candidates is
 * Philip Zimmermann's, as implemented in PGP.  I have read his
 * comments and implemented my own version.
 */
#include "bn_prime.h"

/* Forward declarations of the file-local helpers.  When PROTO is defined
 * full ANSI prototypes are used; otherwise old-style K&R declarations
 * (no parameter types) are emitted for compilers without prototypes.
 */
#ifdef PROTO
static int witness(BIGNUM *a, BIGNUM *n);
static int probable_prime(BIGNUM *rand, int bits);
static int bn_extended_euclid(BIGNUM **rd, BIGNUM **rx, BIGNUM **ry, BIGNUM *a, BIGNUM *b);
static BIGNUM *bn_inverse_modn(BIGNUM *a, BIGNUM *n);
#else
static int witness();
static int probable_prime();
static int bn_extended_euclid();
static BIGNUM *bn_inverse_modn();
#endif

/* Number of witness() rounds a candidate must survive in
 * RSA_generate_prime before it is accepted as prime. */
static int RSA_prime_checks = 5;

/* Optional progress hook used by RSA_generate_prime.  Static storage
 * starts out as a null pointer, i.e. no hook installed. */
static void (*callback)();

/* Install (or, with NULL, remove) the progress hook invoked while
 * RSA_generate_prime searches for a prime.  The previous hook, if any,
 * is simply replaced. */
void RSA_set_generate_prime_callback(a)
void (*a)();
	{
	callback = a;
	}

/* Generate a probable prime of 'bits' bits.
 * Candidates come from probable_prime() (a random odd number sieved
 * against the small-prime table) and must then survive RSA_prime_checks
 * rounds of witness() using random bases of bits-1 bits; any failed round
 * discards the candidate and the search starts over.  If a progress hook
 * is installed it is called as callback(0,n) for each candidate,
 * callback(1,n) for each passed round and callback(2,0) on success.
 * Returns a newly allocated BIGNUM owned by the caller, or NULL on error
 * (the partially built candidate is freed).  The base BIGNUM 'check' is
 * cached in a static, so this function is not re-entrant.
 */
BIGNUM *RSA_generate_prime(bits)
int bits;
	{
	BIGNUM *rand;
	static BIGNUM *check=NULL;
	int i,j,c1=0,c2=0;

	rand=bn_new();
	if (rand == NULL) return(NULL);
	if (check == NULL)
		{
		check=bn_new();
		if (check == NULL) goto err;
		}
loop:
	/* make a random number and set the top and bottom bits */
	if (!probable_prime(rand,bits)) goto err;
	if (callback != NULL) callback(0,c1++);

	/* we now have a random number 'rand' to test. */
	for (i=0; i<RSA_prime_checks; i++)
		{
		if (!RSA_bn_rand(check,bits-1,0)) goto err;
		j=witness(check,rand);
		if (j == -1) goto err;
		if (j) goto loop; /* composite: try again :-( */
		if (callback != NULL) callback(1,c2++);
		}

	/* we have a prime :-) */
	if (callback != NULL) callback(2,0);
	return(rand);
err:
	/* do not leak the candidate on failure */
	bn_free(rand);
	return(NULL);
	}

/* Compositeness test of n with base a (the classic WITNESS procedure).
 * a^(n-1) mod n is computed by left-to-right square-and-multiply over the
 * bits of n-1.  After each squaring the code checks for a non-trivial
 * square root of 1 (x*x == 1 mod n with x != 1 and x != n-1).
 * Returns 1 if a proves n composite, 0 if n passes (a^(n-1) == 1 mod n
 * and no non-trivial root was seen), or -1 on error.  Temporaries come
 * from the bn register stack, which is restored on every return path.
 */
static int witness(a, n)
BIGNUM *a;
BIGNUM *n;
	{
	int k,i;
#ifdef RECP_MUL_MOD
	int nb;
#endif
	BIGNUM *d,*dd,*tmp;
	BIGNUM *d1,*d2,*x,*n1,*inv;
	int tos=bn_get_tos();

	d1=bn_get_reg();
	d2=bn_get_reg();
	x=bn_get_reg();
	n1=bn_get_reg();
	inv=bn_get_reg();

	if (	(d1 == NULL) || (d2 == NULL) ||
		(x == NULL)  || (n1 == NULL) ||
		(inv == NULL))
		goto err;

	/* d and dd alternate between the two work registers d1 and d2 */
	d=d1;
	dd=d2;
	if (!bn_one(d)) goto err;
	if (!bn_sub(n1,n,d)) goto err; /* n1=n-1; */
	k=bn_num_bits(n1);

#ifdef RECP_MUL_MOD
	nb=bn_reciprical(inv,n);
	if (nb == -1) goto err;
#endif

	for (i=k-1; i>=0; i--)
		{
		if (bn_copy(x,d) == NULL) goto err;
		/* dd = d^2 mod n */
#ifndef RECP_MUL_MOD
		if (!bn_mul_mod(dd,d,d,n)) goto err;
#else
		if (!bn_modmul_recip(dd,d,d,n,inv,nb)) goto err;
#endif
		/* non-trivial square root of 1 found: n is composite */
		if (	bn_is_one(dd) &&
			!bn_is_one(x) &&
			(bn_cmp(x,n1) != 0))
			{
			/* release the registers before this early exit too */
			bn_set_tos(tos);
			return(1);
			}
		if (bn_is_bit_set(n1,i))
			{
			/* d = d^2 * a mod n */
#ifndef RECP_MUL_MOD
			if (!bn_mul_mod(d,dd,a,n)) goto err;
#else
			if (!bn_modmul_recip(d,dd,a,n,inv,nb)) goto err;
#endif
			}
		else
			{
			/* d = d^2: just swap the roles of the two registers */
			tmp=d;
			d=dd;
			dd=tmp;
			}
		}
	/* d now holds a^(n-1) mod n; anything other than 1 proves compositeness */
	if (bn_is_one(d))
		i=0;
	else	i=1;
	bn_set_tos(tos);
	return(i);
err:
	bn_set_tos(tos);
	return(-1);
	}

/* solves ax == 1 (mod n)
 * Returns a newly allocated BIGNUM t (owned by the caller) with
 * a*t == 1 (mod n), reduced with bn_mod, or NULL on failure.  When
 * gcd(a,n) != 1 there is no solution and RSA_errno is set to
 * RSA_ERR_INVERSE_MODN_NO_SOLUTION.  Work registers come from the bn
 * register stack and are released before returning.
 */
static BIGNUM *bn_inverse_modn(a, n)
BIGNUM *a;
BIGNUM *n;
	{
	BIGNUM *t,*d,*x1,*y1;
	int tos;

	t=bn_new();
	if (t == NULL) return(NULL);

	tos=bn_get_tos();
	d=bn_get_reg();
	x1=bn_get_reg();
	y1=bn_get_reg();
	if ((d == NULL) || (x1 == NULL) || (y1 == NULL))
		goto err;

	/* d = gcd(n,a) = n*x1 + a*y1, so y1 is the inverse of a when d == 1 */
	if (!bn_extended_euclid(&d,&x1,&y1,n,a)) goto err;

	/* bring a negative coefficient back into range before reducing */
	if (y1->neg)
		{
		if (!bn_add(y1,y1,n)) goto err;
		}

	if (bn_is_one(d))
		{ if (!bn_mod(t,y1,n)) goto err; }
	else
		{
		RSA_errno=RSA_ERR_INVERSE_MODN_NO_SOLUTION;
		goto err;
		}
	bn_set_tos(tos);
	return(t);
err:
	bn_set_tos(tos);
	/* the result was never handed out, so it must not leak */
	bn_free(t);
	return(NULL);
	}

/* Recursive extended Euclidean algorithm.
 * On success *rd = gcd(a,b) and a*(*rx) + b*(*ry) == *rd; returns 1.
 * Returns 0 on error.
 * The caller supplies the BIGNUMs behind *rd, *rx and *ry.  Note that each
 * recursion level swaps the *rx and *ry pointers, so on return they may
 * point at each other's original objects; callers must re-read them.
 * Temporaries (A, B) come from the bn register stack and are released
 * before returning; recursion depth grows with the number of Euclid steps.
 */
static int bn_extended_euclid(rd, rx, ry, a, b)
BIGNUM **rd;
BIGNUM **rx;
BIGNUM **ry;
BIGNUM *a;
BIGNUM *b;
	{
	BIGNUM *A,*B,*tmp;
	int tos=bn_get_tos();

	/* base case: gcd(a,0) = a = a*1 + 0*0 (no registers taken yet) */
	if (bn_is_zero(b))
		{
		if (bn_copy(*rd,a) == NULL) goto err;
		if (!bn_one(*rx)) goto err;
		bn_zero(*ry);
		return(1);
		}

	/* A = a mod b */
	A=bn_get_reg();
	if (A == NULL) goto err;
	if (!bn_mod(A,a,b)) goto err;

	/* solve for (b, a mod b): b*x' + (a mod b)*y' = d */
	if (!bn_extended_euclid(rd,rx,ry,b,A))
		goto err;
	/* new x = y'; keep the object holding x' in tmp for the new y */
	tmp= *rx;
	*rx= *ry;
	/* A is no longer needed as a remainder; reuse it for a/b
	 * (assumes bn_div's first argument receives the quotient) */
	if (!bn_div(A,NULL,a,b)) goto err;

	B=bn_get_reg();
	if (B == NULL) goto err;

	/* new y = x' - (a/b)*y', stored in the old x' object */
	if (!bn_mul(B,*ry,A)) goto err;
	if (!bn_sub(A,tmp,B)) goto err;
	if (bn_copy(tmp,A) == NULL) goto err;
	*ry=tmp;

	bn_set_tos(tos);
	return(1);
err:
	bn_set_tos(tos);
	return(0);
	}

/* Fill 'rand' with a random number of exactly 'bits' bits: buf[0] is
 * treated as the most significant byte, its top bit is always set and any
 * bits above it are cleared; when 'prime' is non-zero the bottom bit is
 * also set so the result is odd.  MD5_rand is re-seeded with the current
 * time on every call.  A static scratch buffer is reused across calls (so
 * this is not re-entrant) and is wiped after each use so candidate key
 * material does not linger in memory.  bits <= 0 yields zero.
 * Returns 1 on success, 0 on failure (RSA_errno is set when the buffer
 * cannot be allocated).
 */
int RSA_bn_rand(rand, bits, prime)
BIGNUM *rand;
int bits;
int prime;
	{
	static int buf_length=0;
	static unsigned char *buf=NULL;
	int bit,bytes,ok;
	time_t tim;

	if (bits <= 0)
		{
		/* the only 0-bit number is 0; also avoids a negative shift below */
		bn_zero(rand);
		return(1);
		}

	bytes=(bits+7)/8;
	bit=(bits-1)%8;	/* position of the top bit within buf[0] */

	/* replace the scratch buffer if it is too small for this request */
	if ((buf_length < bytes) && (buf != NULL))
		{
		free(buf);
		buf=NULL;
		buf_length=0;
		}
	if (buf == NULL)
		{
		/* cast kept for pre-ANSI compilers where malloc returns char * */
		buf=(unsigned char *)malloc(bytes);
		if (buf == NULL)
			{
			RSA_errno=RSA_ERR_OUT_OF_MEM;
			return(0);
			}
		buf_length=bytes;
		}

	/* stir the current time into the generator; time() requires a time_t */
	time(&tim);
	MD5_rand_seed((int)sizeof(tim),(unsigned char *)&tim);
	MD5_rand(bytes,buf);

	/* set the top bit and clear everything above it */
	buf[0]|=(unsigned char)(1<<bit);
	buf[0]&=(unsigned char)((2<<bit)-1);
	if (prime) /* set bottom bit */
		buf[bytes-1]|=1;

	ok=bn_bin2bn(bytes,buf,rand) ? 1 : 0;

	/* the static buffer outlives this call: do not leave key bits in it */
	memset(buf,0,bytes);
	return(ok);
	}

static int probable_prime(rand, bits)
BIGNUM *rand;
int bits;
	{
	int i;
	BN_ULONG mods[NUMPRIMES];
	BN_ULONG delta;

	if (!RSA_bn_rand(rand,bits,1)) return(0);
	/* we now have a random number 'rand' to test. */
	for (i=1; i<NUMPRIMES; i++)
		mods[i]=bn_mod_word(rand,(BN_ULONG)primes[i]);
	delta=0;
	loop: for (i=1; i<NUMPRIMES; i++)
		{
		/* check if not a prime */
		if (((mods[i]+delta)%primes[i]) == 0)
			{
			delta+=2;
/* NEED TO CHECK FOR OVERFLOW OF DELTA */
			goto loop;
			}
		}
	if (!bn_add_word(rand,delta)) return(0);
	return(1);
	}

RSA *RSA_generate_key(bits, use_f4)
int bits;
int use_f4;
	{
	RSA *rsa;
	BIGNUM *r0,*r1,*r2,*tmp;
	int tos;

	tos=bn_get_tos();
	r0=bn_get_reg();
	r1=bn_get_reg();
	r2=bn_get_reg();
	if ((r0 == NULL) || (r1 == NULL) || (r2 == NULL))
		goto err;

	bits/=2;
	rsa=RSA_new();
	if (rsa == NULL) goto err;

	/* generate p and q */
	rsa->p=RSA_generate_prime(bits);
	if (rsa->p == NULL) goto err;
	do	{
		rsa->q=RSA_generate_prime(bits);
		if (rsa->q == NULL) goto err;
		} while (bn_cmp(rsa->p,rsa->q) == 0);
	if (bn_cmp(rsa->p,rsa->q) < 0)
		{
		tmp=rsa->p;
		rsa->p=rsa->q;
		rsa->q=tmp;
		}

	/* calculate n */
	rsa->n=bn_new();
	if (rsa->n == NULL) goto err;
	if (!bn_mul(rsa->n,rsa->p,rsa->q)) goto err;

	/* set e */ 
	rsa->e=bn_new();
	if (rsa->e == NULL) goto err;
	if (!bn_one(r0)) goto err;
	if (use_f4)
		{ if (!bn_lshift(r1,r0,16)) goto err; } /* 0x10001 */
	else
		{ if (!bn_lshift(r1,r0,1)) goto err; } /* 0x3 */
	if (!bn_add(rsa->e,r1,r0)) goto err;

	/* calculate d */
	if (!bn_sub(r1,rsa->p,r0)) goto err;	/* p-1 */
	if (!bn_sub(r2,rsa->q,r0)) goto err;	/* q-1 */
	if (!bn_mul(r0,r1,r2)) goto err;	/* (p-1)(q-1) */
	rsa->d=(BIGNUM *)bn_inverse_modn(rsa->e,r0);	/* d */
	if (rsa->d == NULL) goto err;

	/* calculate d mod (p-1) */
	rsa->dmp1=bn_new();
	if (rsa->dmp1 == NULL) goto err;
	if (!bn_mod(rsa->dmp1,rsa->d,r1)) goto err;

	/* calculate d mod (q-1) */
	rsa->dmq1=bn_new();
	if (rsa->dmq1 == NULL) goto err;
	if (!bn_mod(rsa->dmq1,rsa->d,r2)) goto err;

	/* calculate inverse of q mod p */
	rsa->iqmp=bn_inverse_modn(rsa->q,rsa->p);
	if (rsa->iqmp == NULL) goto err;

	bn_set_tos(tos);
	bn_clean_up();
	return(rsa);
err:
	bn_set_tos(tos);
	bn_clean_up();
	if (rsa != NULL) RSA_free(rsa);
	return(NULL);
	}

