/* ssl_clnt.c */
/* Copyright (C) 1995 Eric Young (eay@mincom.oz.au).
 * All rights reserved.
 * Copyright remains Eric Young's, and as such any Copyright notices in
 * the code are not to be removed.
 * See the COPYRIGHT file in the SSLeay distribution for more details.
 */

#include <stdio.h>
#include "md5.h"
#include "ssl_locl.h"

/* Forward declarations of the file-local handshake helpers defined below.
 * When PROTO is defined full prototypes are used; otherwise the helpers
 * are declared old-style, without parameter lists (presumably so that
 * pre-ANSI compilers can still build this file). */
#ifdef PROTO
static int get_server_finish(SSL *s);
static int get_server_verify(SSL *s);
static int get_server_hello(SSL *s);
static int client_hello(SSL *s); 
static int client_master_key(SSL *s);
static int client_finish(SSL *s);
static int client_certificate(SSL *s, unsigned char *buf);
static void choose_cipher(SSL *s);
#else
static int get_server_finish();
static int get_server_verify();
static int get_server_hello();
static int client_hello(); 
static int client_master_key();
static int client_finish();
static int client_certificate();
static void choose_cipher();
#endif

/* SSL_connect - drive the client half of the handshake on s.
 *
 * The steps run strictly in this order: seed the MD5 random source from
 * the current time, ssl_clear() the connection, client_hello(),
 * get_server_hello(), client_master_key() (skipped when the server
 * reported a session hit), the chosen cipher's crypt_init() followed by
 * clearing s->clear_text, client_finish(), get_server_verify() and
 * finally get_server_finish().
 *
 * Returns 1 once every step has succeeded and 0 as soon as one fails.
 */
int SSL_connect(s)
SSL *s;
{
	CIPHER **cpp;
	unsigned long now = time(NULL);

	MD5_rand_seed(sizeof(now), (unsigned char *)&now);

	/* begin from a blank connection state */
	ssl_clear(s);

#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"start Client Hello\n");
#endif

	if (client_hello(s) <= 0)
		return 0;

#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"send Client Hello\n");
	SSL_TRACE(SSL_ERR,"Version: %d\n",s->version);
	SSL_TRACE(SSL_ERR,"Challenge:");
	ssl_print_bytes(SSL_ERR, s->challenge_length, s->challenge);
	SSL_TRACE(SSL_ERR,"\n");
#endif

	if (get_server_hello(s) <= 0)
		return 0;

	/* no session hit: this is a new connection and needs a master key */
	if (s->hit == 0) {
#ifdef SSL_DEBUG
		SSL_TRACE(SSL_ERR,"Certificate:");
		if (SSL_ERR != NULL)
			i2f_X509(SSL_ERR, s->conn->cert->x509);
		SSL_TRACE(SSL_ERR,"\n");
		for (cpp = &(s->conn->ciphers[0]); *cpp; cpp++) {
			SSL_TRACE(SSL_ERR,"Ciphers: %s\n",(*cpp)->name);
		}
		SSL_TRACE(SSL_ERR,"\n");
		SSL_TRACE(SSL_ERR,"Connection-Id:");
		ssl_print_bytes(SSL_ERR, CONNECTION_ID_LENGTH, s->conn_id);
		SSL_TRACE(SSL_ERR,"\n");
#endif
		if (client_master_key(s) <= 0)
			return 0;
	}

#ifdef SSL_DEBUG
	/* the master key and key arg exist by now, so they can be dumped */
	if (s->hit == 0) {
		SSL_TRACE(SSL_ERR,"Master Key:");
		ssl_print_bytes(SSL_ERR, s->conn->master_key_length,
				s->conn->master_key);
		SSL_TRACE(SSL_ERR,"\n");
		SSL_TRACE(SSL_ERR,"Master Key Arg:");
		ssl_print_bytes(SSL_ERR, s->conn->key_arg_length,
				s->conn->key_arg);
		SSL_TRACE(SSL_ERR,"\n");
	}
#endif

	/* everything needed for encryption is in place: start the cipher
	 * and stop sending in the clear */
	if (!s->conn->cipher->crypt_init(s, 1))
		return 0;
	s->clear_text = 0;

	if (client_finish(s) <= 0)
		return 0;

#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"Client_finish\nConnection-Id:");
	ssl_print_bytes(SSL_ERR, s->conn_id_length, s->conn_id);
	SSL_TRACE(SSL_ERR,"\n");
#endif

	if (get_server_verify(s) <= 0)
		return 0;

#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"Server_verify\nChallenge:");
	ssl_print_bytes(SSL_ERR, s->challenge_length, s->challenge);
	SSL_TRACE(SSL_ERR,"\n");
#endif

	if (get_server_finish(s) <= 0)
		return 0;

#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"Server_finish\nnew_session_id:");
	ssl_print_bytes(SSL_ERR, s->conn->session_id_length,
			s->conn->session_id);
	SSL_TRACE(SSL_ERR,"\n");
	SSL_TRACE(SSL_ERR,"SSL PROTOCOL FINSHED\n");
#endif
	return 1;
}

/* get_server_hello - read and process the server hello message.
 *
 * Reads the fixed 11 byte header and then the certificate, cipher spec
 * and connection id data whose lengths that header announces.  For a new
 * session the certificate is handed to ssl_set_certificate(), the local
 * cipher list is reduced to the entries the server also listed (matched
 * on the first byte of each 3 byte spec) and choose_cipher() picks one.
 * In both the new and the resumed case the connection id is stored in
 * s->conn_id.
 *
 * Returns 1 on success and 0 on failure.
 */
static int get_server_hello(s)
SSL *s;
	{
	unsigned char buf[SSL_MAX_RECORD_LENGTH_3_BYTE_HEADER];
	unsigned char *p;
	CIPHER **cipher,*cp;
	int i,n,j,to,total;
	int session_id_hit,cert_type,cert_length,csl,conn_id_length;

	p=buf;
	i=SSL_read(s,(char *)buf,11);
	if (i < 11)
		{
		ssl_return_error(s);
		SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_HELLO,SSL_R_SHORT_READ);
		return(0);
		}

	if (*(p++) != SSL_MT_SERVER_HELLO)
		{
		if (p[-1] != SSL_MT_ERROR)
			{
			ssl_return_error(s);
			SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_HELLO,
				SSL_R_READ_WRONG_PACKET_TYPE);
			}
		else
			SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_HELLO,
				SSL_R_PEER_ERROR);
		return(0);
		}
	s->hit=session_id_hit=(*(p++))?1:0;
	cert_type= *(p++);
	n2s(p,i);
	/* settle on the lower of the two protocol versions */
	if (i < s->version) s->version=i;
	n2s(p,cert_length);
	n2s(p,csl);
	n2s(p,conn_id_length);

	/* The three lengths are supplied by the peer.  Their sum must fit
	 * in buf, otherwise the read below would run past the end of it.
	 * The reason code is the nearest one this function already uses
	 * for a malformed message. */
	total=cert_length+csl+conn_id_length;
	if ((total < 0) || (total > (int)sizeof(buf)))
		{
		ssl_return_error(s);
		SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_HELLO,
			SSL_R_READ_WRONG_PACKET_TYPE);
		return(0);
		}

	i=SSL_read(s,(char *)buf,total);
	if (i != total)
		{
		ssl_return_error(s);
		SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_HELLO,
			SSL_R_SHORT_READ);
		return(0);
		}

	p=buf;
	if (session_id_hit)
		{
		/* client_hello() only offers a session when session_id is
		 * set, so a hit without one cannot be genuine; going on
		 * would use a cipher and certificate that were never
		 * negotiated. */
		if (s->conn->session_id == NULL)
			{
			ssl_return_error(s);
			SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_HELLO,
				SSL_R_READ_WRONG_PACKET_TYPE);
			return(0);
			}
		/* Certificate or cipher data sent along with a hit is
		 * tolerated rather than treated as an error, but it has to
		 * be stepped over so that p ends up on the connection id. */
		p+=cert_length+csl;
		}
	else
		{
		/* new session: any session id we were holding is stale */
		if (s->conn->session_id != NULL) free(s->conn->session_id);
		s->conn->session_id=NULL;
		s->conn->session_id_length=0;

#ifdef SSL_DEBUG
		SSL_TRACE(SSL_ERR,"cert_length=%d\n",cert_length);
#endif

		if (ssl_set_certificate(s,cert_type,cert_length,p) <= 0)
			{
			ssl_return_error(s);
			return(0);
			}
		p+=cert_length;

		if (csl == 0)
			{
			ssl_return_error(s);
			SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_HELLO,
				SSL_R_NO_CIPHER_LIST);
			return(0);
			}

		/* Move every local cipher the server also listed to the
		 * front of the array, in the order the server listed them;
		 * 'to' counts how many have been moved so far. */
		cipher= &(s->conn->ciphers[0]);
		to=0;
		for (i=0; i<csl; i+=3)
			{
			n=p[i];
			
			for (j=to; cipher[j] != NULL; j++)
				{
				if (n == cipher[j]->num)
					{
					cp=cipher[to];
					cipher[to]=cipher[j];
					cipher[j]=cp;
					to++;
					}
				}
			}
		/* cut the list off after the shared ciphers */
		cipher[to]=NULL;
		if (to == 0)
			{
			ssl_return_error(s);
			SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_HELLO,
				SSL_R_NO_CIPHER_MATCH);
			return(0);
			}
		p+=csl;
		/* pick a cipher - which one? */
		choose_cipher(s);
		}
	s->peer=(X509 *)s->conn->cert->x509;
	s->free_peer=0;
	/* get conn_id */
	if (s->conn_id_length != conn_id_length)
		{
		if (s->conn_id) free(s->conn_id);
		s->conn_id=(unsigned char *)malloc(conn_id_length);
		if (s->conn_id == NULL)
			{
			/* keep length and pointer consistent on failure */
			s->conn_id_length=0;
			ssl_return_error(s);
			SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_HELLO,
				SSL_R_OUT_OF_MEMORY);
			return(0);
			}
		}
	s->conn_id_length=conn_id_length;
	/* nothing to copy (and possibly no buffer) for an empty id */
	if (conn_id_length > 0)
		memcpy(s->conn_id,p,conn_id_length);
	return(1);
	}

/* client_hello - assemble and send the client hello message.
 *
 * Creates s->conn with ssl_new_conn() when there is none, otherwise the
 * pre-loaded session is used.  The message has a 9 byte header (type,
 * version, cipher spec length, session id length, challenge length)
 * followed by three bytes for every valid entry of ssl_ciphers, the
 * existing session id if there is one, and a fresh random challenge that
 * is also kept in s->challenge.  The valid ciphers are recorded, NULL
 * terminated, in s->conn->ciphers.
 *
 * Returns the result of SSL_write(), or 0 if the connection or the
 * challenge could not be allocated.
 */
static int client_hello(s)
SSL *s;
{
	unsigned char msg[SSL_MAX_RECORD_LENGTH_3_BYTE_HEADER];
	unsigned char *hdr, *body;
	CIPHER **slot;
	int idx, len;

	if ((s->conn == NULL) && !ssl_new_conn(s, 0)) {
		ssl_return_error(s);
		return 0;
	}

	slot = &(s->conn->ciphers[0]);
	hdr = msg;				/* fixed size header */
	body = hdr + 9;				/* variable data after it */

	*(hdr++) = SSL_MT_CLIENT_HELLO;		/* message type */
	s2n(SSL_CLIENT_VERSION, hdr);		/* protocol version */

	/* cipher specs: the table ends at the entry whose valid is -1 */
	len = 0;
	for (idx = 0; ssl_ciphers[idx].valid != -1; idx++) {
		if (ssl_ciphers[idx].valid) {
			*(slot++) = &(ssl_ciphers[idx]);
			*(body++) = ssl_ciphers[idx].num;
			*(body++) = ssl_ciphers[idx].noidea;
			*(body++) = ssl_ciphers[idx].keybits;
#ifdef SSL_DEBUG
			SSL_TRACE(SSL_ERR,"Ciphers: %s\n",ssl_ciphers[idx].name);
#endif
			len += 3;
		}
		*slot = NULL;
	}
	s2n(len, hdr);				/* bytes of cipher specs */

	if (s->conn->session_id == NULL) {
#ifdef SSL_DEBUG
		SSL_TRACE(SSL_ERR,"new session\n");
#endif
		s2n(0, hdr);			/* no session id offered */
	} else {
#ifdef SSL_DEBUG
		SSL_TRACE(SSL_ERR,"EXISTING SESSION_ID:");
		ssl_print_bytes(SSL_ERR, s->conn->session_id_length,
				s->conn->session_id);
		SSL_TRACE(SSL_ERR,"\n");
#endif
		len = s->conn->session_id_length;
		s2n(len, hdr);			/* session id length */
		memcpy(body, s->conn->session_id, len);
		body += len;
	}

	s->challenge_length = CHALLENGE_LENGTH;
	s2n(CHALLENGE_LENGTH, hdr);		/* challenge length */
	s->challenge = (unsigned char *)malloc(CHALLENGE_LENGTH);
	if (s->challenge == NULL) {
		ssl_return_error(s);
		SSL_errno=SSL_ERRCODE(SSL_F_CLIENT_HELLO,SSL_R_OUT_OF_MEMORY);
		return 0;
	}
	MD5_rand(CHALLENGE_LENGTH, s->challenge);	/* random challenge */
	memcpy(body, s->challenge, CHALLENGE_LENGTH);
	body += CHALLENGE_LENGTH;

	len = body - msg;			/* total message size */
	return SSL_write(s, (char *)msg, len);
}

/* client_master_key - build and send the client master key message.
 *
 * Generates a random master key of keybits/8 bytes and, when the chosen
 * cipher has a non-zero key_arg_size, random key arg data of that size.
 * The last enc_bits of the key (the whole key when enc_bits is 0) are
 * encrypted through the public_encrypt hook of s->conn->cert; whatever
 * precedes them is sent in the clear.  The 10 byte header carries the
 * message type, the three cipher bytes and the clear, encrypted and key
 * arg lengths.
 *
 * Returns the result of SSL_write() on success and 0 on failure.
 */
static int client_master_key(s)
SSL *s;
	{
	unsigned char buf[SSL_MAX_RECORD_LENGTH_3_BYTE_HEADER];/**/
	unsigned char *p,*d;
	int clear,enc,karg,n,i;

	p=buf;
	d=p+10;					/* data follows the header */
	*(p++)=SSL_MT_CLIENT_MASTER_KEY;	/* type */
	*(p++)=s->conn->cipher->num;		/* cipher type - byte 1 */
	*(p++)=s->conn->cipher->noidea;		/* cipher type - byte 2 */
	*(p++)=s->conn->cipher->keybits;	/* cipher type - byte 3 */

	/* make a master key */
	i=s->conn->master_key_length=(unsigned char)s->conn->cipher->keybits/8;
	s->conn->master_key=(unsigned char *)malloc(i);
	if (s->conn->master_key == NULL)
		{
		ssl_return_error(s);
		SSL_errno=SSL_ERRCODE(SSL_F_CLIENT_MASTER_KEY,
			SSL_R_OUT_OF_MEMORY);
		return(0);
		}
	MD5_rand(i,s->conn->master_key);

	/* make key_arg data */
	i=s->conn->key_arg_length=s->conn->cipher->key_arg_size;
	if (i == 0)
		s->conn->key_arg=NULL;
	else
		{
		s->conn->key_arg=(unsigned char *)malloc(i);
		if (s->conn->key_arg == NULL)
			{
			ssl_return_error(s);
			SSL_errno=SSL_ERRCODE(SSL_F_CLIENT_MASTER_KEY,
				SSL_R_OUT_OF_MEMORY);
			return(0);
			}
		MD5_rand(i,s->conn->key_arg);
		}

	/* Read keybits through the same unsigned char cast that was used
	 * for the key length above, so that both agree even if the field
	 * is a signed char holding 128. */
	i=(unsigned char)s->conn->cipher->keybits;
	enc=s->conn->cipher->enc_bits;
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"bits=%d enc=%d\n",i,enc);
#endif
	if (enc == 0) enc=i;	/* 0 means: encrypt the whole key */
	/* a negative or too large enc would give a clear length that
	 * reaches beyond the master key */
	if ((enc < 0) || (i < enc))
		{
		ssl_return_error(s);
		SSL_errno=SSL_ERRCODE(SSL_F_CLIENT_MASTER_KEY,
			SSL_R_BAD_INTERNEL_ERROR);
		return(0);
		}
	clear=i-enc;
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"enc=%d clear=%d\n",enc,clear);
#endif
	/* clear and enc are in bits up to here, bytes from here on */
	clear/=8; /* clear */
	s2n(clear,p);
	memcpy(d,s->conn->master_key,clear);
	d+=clear;

	enc/=8;
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"clear=%d p=%08x d=%08x\n",clear,p,d);
#endif
	/* encrypt the remainder of the key directly into the message */
	enc=s->conn->cert->public_encrypt(s->conn->cert,enc,
		&(s->conn->master_key[clear]),d);
	if (enc <= 0)
		{
		ssl_return_error(s);
		/* NOTE(review): the reason code says DECRYPT although the
		 * failing operation is an encrypt; kept as is. */
		SSL_errno=SSL_ERRCODE(SSL_F_CLIENT_MASTER_KEY,
			SSL_R_PUBLIC_KEY_DECRYPT_ERROR);
#ifdef SSL_DEBUG
		SSL_TRACE(SSL_ERR,"public key encrypt error\n");
#endif
		return(0);
		}
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"p=%08x d=%08x end=%d\n",p,d,enc);
#endif
	s2n(enc,p);
	d+=enc;
	karg=s->conn->cipher->key_arg_size;	
	s2n(karg,p); /* key arg size */
	/* key_arg is NULL when the cipher takes no key arg */
	if (karg > 0)
		memcpy(d,s->conn->key_arg,karg);
	d+=karg;

	n=d-buf;
	return(SSL_write(s,(char *)buf,n));
	}

/* client_finish - send the client finished message.
 *
 * The message is the type byte SSL_MT_CLIENT_FINISHED followed by the
 * s->conn_id_length bytes of s->conn_id.
 *
 * Returns the result of SSL_write(), or 0 when the connection id does
 * not fit in the message buffer.
 */
static int client_finish(s)
SSL *s;
	{
	unsigned char buf[SSL_MAX_RECORD_LENGTH_3_BYTE_HEADER];
	unsigned char *p;

	/* The connection id length is taken from the server hello, so it
	 * has to be checked against the room left after the type byte.
	 * The cast also turns a negative length into a value that fails
	 * the test.
	 * NOTE(review): no function code for client_finish is visible in
	 * this file, so SSL_errno is not set here - TODO add one. */
	if ((unsigned long)s->conn_id_length > (unsigned long)(sizeof(buf)-1))
		{
		ssl_return_error(s);
		return(0);
		}

	p=buf;
	*(p++)=SSL_MT_CLIENT_FINISHED;
#ifdef SSL_DEBUG
SSL_TRACE(SSL_ERR,"conn_id_length=%d\n",s->conn_id_length);
#endif
	/* an empty id has nothing to copy and may have no buffer */
	if (s->conn_id_length != 0)
		memcpy(p,s->conn_id,s->conn_id_length);
	return(SSL_write(s,(char *)buf,s->conn_id_length+1));
	}

/* client_certificate - answer a certificate request from the server.
 *
 * Reads the remainder of the request into buf: one leading byte (the
 * type, currently ignored) followed by the certificate challenge.  When
 * no client certificate is loaded an SSL_MT_ERROR message carrying
 * SSL_PE_NO_CERTIFICATE is sent and 1 is returned.  Otherwise the reply
 * holds our encoded certificate and the private_encrypt() output for an
 * MD5 digest taken over the key material, the challenge and the encoded
 * certificate held in s->conn->cert.
 *
 * buf is the caller's message buffer; it is used for the read, as
 * scratch space and for the reply.
 *
 * Returns the result of SSL_write() for the reply, 1 after reporting
 * that there is no certificate, and 0 on failure.
 */
static int client_certificate(s, buf)
SSL *s;
unsigned char *buf;
	{
	unsigned char *p,*d;
	int n/*,type*/,i;
	MD5_CTX md5s;
	unsigned char  md[16];
	int cert_ch_len;
	unsigned char cert_ch[MAX_CERT_CHALLENGE_LENGTH];

	i=SSL_read(s,(char *)buf,MAX_CERT_CHALLENGE_LENGTH+1);
	if (i<MIN_CERT_CHALLENGE_LENGTH+1)
		{
		ssl_return_error(s);
		SSL_errno=SSL_ERRCODE(SSL_F_CLIENT_CERTIFICATE,
			SSL_R_SHORT_READ);
		return(0);
		}
	/* at most MAX_CERT_CHALLENGE_LENGTH, as the read asked for one
	 * byte more than that */
	cert_ch_len=i-1;

/*	type=buf[0]; */
	/* type eq x509 */

	/* keep the challenge, buf is about to be reused */
	memcpy(cert_ch,&(buf[1]),cert_ch_len);

	if ((s->cert == NULL) || (s->cert->x509 == NULL))
		{
		/* no certificate to offer: say so and carry on; the result
		 * of this write is not checked */
		p=buf;
		*(p++)=SSL_MT_ERROR;
		s2n(SSL_PE_NO_CERTIFICATE,p);
		SSL_write(s,(char *)buf,3);
		return(1);
		}

	/* ok, now we calculate the checksum
	 * do it first so we can reuse buf :-) */
	MD5Init(&md5s);
	MD5Update(&md5s,s->conn->key_material,
		s->conn->key_material_length);
	MD5Update(&md5s,cert_ch,cert_ch_len);
	p=buf;
	n=i2D_X509(s->conn->cert->x509,&p);
	MD5Update(&md5s,buf,n);
	MD5Final(&(md[0]),&md5s);

	/* 6 byte header: message type, certificate type, certificate
	 * length and response length; the data is written through d */
	p=buf;
	d=p+6;
	*(p++)=SSL_MT_CLIENT_CERTIFICATE;
	*(p++)=SSL_CT_X509_CERTIFICATE;
	n=i2D_X509(s->cert->x509,&d);
	s2n(n,p);
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"clen=%d\n",n);
#endif

	n=s->cert->private_encrypt(s->cert,MD5_DIGEST_LENGTH,md,d);
	/* Treat a result <= 0 as failure, in the same way the result of
	 * public_encrypt is treated in client_master_key(); otherwise a
	 * bogus length would be sent and d moved by it.
	 * NOTE(review): none of the reason codes used in this function
	 * describes this failure, so SSL_errno is not set - TODO choose
	 * one. */
	if (n <= 0)
		{
		ssl_return_error(s);
		return(0);
		}
	s2n(n,p);
	d+=n;
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"rlen=%d\n",n);
#endif

	n=d-buf;
	return(SSL_write(s,(char *)buf,n));
	}

/* get_server_verify - read the server verify message and check it.
 *
 * Reads the one byte message type and then s->challenge_length bytes,
 * which must be identical to the challenge stored in s->challenge.
 *
 * Returns 1 when the challenge matches and 0 on a short read, on an
 * unexpected or error message type, or on a mismatch.
 */
static int get_server_verify(s)
SSL *s;
	{
	unsigned char buf[SSL_MAX_RECORD_LENGTH_3_BYTE_HEADER];
	unsigned char *p;
	int i;

	p=buf;
	i=SSL_read(s,(char *)buf,1);
	if (i < 1)
		{
		ssl_return_error(s);
		SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_VERIFY,
			SSL_R_SHORT_READ);
		return(0);
		}
	if (*p != SSL_MT_SERVER_VERIFY)
		{
		/* ssl_return_error() is skipped when the peer itself sent
		 * an error message */
		if (p[0] != SSL_MT_ERROR)
			{
			ssl_return_error(s);
			SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_VERIFY,
				SSL_R_READ_WRONG_PACKET_TYPE);
			}
		else
			SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_VERIFY,
				SSL_R_PEER_ERROR);
		return(0);
		}
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"cl=%d\n",s->challenge_length);
#endif
	i=SSL_read(s,(char *)buf,s->challenge_length);
	if (i < s->challenge_length)
		{
		ssl_return_error(s);
		SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_VERIFY,
			SSL_R_SHORT_READ);
		return(0);
		}
	if (memcmp(buf,s->challenge,s->challenge_length) != 0)
		{
#ifdef ERR_DEBUG
		/* ssl_print_bytes() takes the length before the data, as
		 * in every other call in this file */
		SSL_TRACE(SSL_ERR,"bad challenge\ngot: ");
		ssl_print_bytes(SSL_ERR,s->challenge_length,buf);
		SSL_TRACE(SSL_ERR,"\nwant:");
		ssl_print_bytes(SSL_ERR,s->challenge_length,s->challenge);
		SSL_TRACE(SSL_ERR,"\n");
#endif
		ssl_return_error(s);
		SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_VERIFY,
			SSL_R_CHALLENGE_IS_DIFFERENT);
		return(0);
		}
	return(1);
	}

/* get_server_finish - read the server finished message.
 *
 * Message types are read one byte at a time.  Each
 * SSL_MT_REQUEST_CERTIFICATE seen is answered through
 * client_certificate(); the loop ends at SSL_MT_SERVER_FINISHED, which
 * is followed by SESSION_ID_LENGTH bytes of session id.  On a new
 * connection that id is stored in s->conn and the connection is passed
 * to ssl_add_hash_conn(); on a resumed one it must be equal to the id
 * already held.
 *
 * Returns 1 on success and 0 on failure.
 */
static int get_server_finish(s)
SSL *s;
	{
	unsigned char buf[SSL_MAX_RECORD_LENGTH_3_BYTE_HEADER];
	unsigned char *p;
	int i;

	p=buf;
	for (;;)
		{
		i=SSL_read(s,(char *)buf,1);
		if (i < 1)
			{
			ssl_return_error(s);
			SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_FINISHED,
				SSL_R_SHORT_READ);
			return(0);
			}
		if (*p == SSL_MT_REQUEST_CERTIFICATE)
			{
#ifdef SSL_DEBUG
			SSL_TRACE(SSL_ERR,"SSL_MT_REQUEST_CERTIFICATE\n");
#endif
			/* buf is handed over for the reply; the loop then
			 * waits for the next message */
			if (client_certificate(s,buf) <= 0)
				return(0);
			}
		else if (*p != SSL_MT_SERVER_FINISHED)
			{
			/* ssl_return_error() is skipped when the peer
			 * itself sent an error message */
			if (p[0] != SSL_MT_ERROR)
				{
				ssl_return_error(s);
				SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_FINISHED,
					SSL_R_READ_WRONG_PACKET_TYPE);
				}
			else
				SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_FINISHED,
					SSL_R_PEER_ERROR);
			return(0);
			}
		else
			break;
		}

#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"sidl=%d\n",SESSION_ID_LENGTH);
#endif
	i=SSL_read(s,(char *)buf,SESSION_ID_LENGTH);
	if (i < SESSION_ID_LENGTH)
		{
		ssl_return_error(s);
		SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_FINISHED,
			SSL_R_SHORT_READ);
		return(0);
		}
	if (!s->hit) /* new connection */
		{
		/* allocate only when there is no buffer yet or the current
		 * one is too small */
		if ((s->conn->session_id == NULL) ||
			(s->conn->session_id_length < SESSION_ID_LENGTH))
			{
			if (s->conn->session_id) free(s->conn->session_id);
			s->conn->session_id=(unsigned char *)malloc(SESSION_ID_LENGTH);
			if (s->conn->session_id == NULL)
				{
				ssl_return_error(s);
				SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_FINISHED,
					SSL_R_OUT_OF_MEMORY);
				return(0);
				}
			}
		s->conn->session_id_length=SESSION_ID_LENGTH;
		memcpy(s->conn->session_id,p,SESSION_ID_LENGTH);
		ssl_add_hash_conn(s->conn); /* should return error?
		*/
		}
	else
		{
		/* Exactly SESSION_ID_LENGTH bytes were read, so that is the
		 * amount to compare.  A stored id that is missing or of a
		 * different length cannot match and is never passed to
		 * memcmp(). */
		if ((s->conn->session_id == NULL) ||
			(s->conn->session_id_length != SESSION_ID_LENGTH) ||
			(memcmp(buf,s->conn->session_id,SESSION_ID_LENGTH) != 0))
			{
#ifdef ERR_DEBUG
			SSL_TRACE(SSL_ERR,"bad session_id\ngot:\t");
			ssl_print_bytes(SSL_ERR,SESSION_ID_LENGTH,buf);
			SSL_TRACE(SSL_ERR,"\nwant:");
			/* print only what is really stored */
			if (s->conn->session_id != NULL)
				ssl_print_bytes(SSL_ERR,
					s->conn->session_id_length,
					s->conn->session_id);
			SSL_TRACE(SSL_ERR,"\n");
#endif
			ssl_return_error(s);
			SSL_errno=SSL_ERRCODE(SSL_F_GET_SERVER_FINISHED,
				SSL_R_SESSION_ID_IS_DIFFERENT);
			return(0);
			}
		}
	return(1);
	}

/* choose_cipher - select the cipher for this connection.
 *
 * Asks SSL_get_pref_cipher() for the preferred cipher names, starting at
 * index 0, and stores in s->conn->cipher the first entry of
 * s->conn->ciphers whose name equals one of them.  When the preferences
 * run out (a NULL name) without a match, s->conn->ciphers[0] is used.
 */
static void choose_cipher(s)
SSL *s;
{
	char *wanted;
	int rank, k;

	for (rank = 0; (wanted = SSL_get_pref_cipher(s, rank)) != NULL; rank++) {
		k = 0;
		while (s->conn->ciphers[k] != NULL) {
			if (strcmp(wanted, s->conn->ciphers[k]->name) == 0) {
				s->conn->cipher = s->conn->ciphers[k];
				return;
			}
			k++;
		}
	}

	/* no preference matched: fall back to the head of the list */
	s->conn->cipher = s->conn->ciphers[0];
}
