/* ssl_srvr.c */
/* Copyright (C) 1995 Eric Young (eay@mincom.oz.au).
 * All rights reserved.
 * Copyright remains Eric Young's, and as such any Copyright notices in
 * the code are not to be removed.
 * See the COPYRIGHT file in the SSLeay distribution for more details.
 */

#include <stdio.h>
#include "md5.h"
#include "ssl_locl.h"

/* Forward declarations of the file-local handshake steps used by
 * SSL_accept().  With PROTO defined they are full ANSI prototypes;
 * otherwise old-style K&R declarations without parameter types. */
#ifdef PROTO
static int get_client_master_key(SSL *s);
static int get_client_hello(SSL *s);
static int server_hello(SSL *s); 
static int get_client_finished(SSL *s);
static int server_verify(SSL *s);
static int server_finish(SSL *s);
static int request_certificate(SSL *s, unsigned char *buf);
#else
static int get_client_master_key();
static int get_client_hello();
static int server_hello(); 
static int get_client_finished();
static int server_verify();
static int server_finish();
static int request_certificate();
#endif

/* Run the server side of the handshake on s.
 *
 * Steps, in order: read CLIENT-HELLO, send SERVER-HELLO, read
 * CLIENT-MASTER-KEY (only when the client did not resume a cached
 * session, i.e. s->hit is 0), initialise the negotiated cipher and
 * clear s->clear_text, send SERVER-VERIFY, read CLIENT-FINISHED, and
 * finally send SERVER-FINISHED (which may first request a client
 * certificate).
 *
 * Returns 1 when every step succeeds, 0 as soon as one fails; the
 * failing step has already queued its own error.
 */
int SSL_accept(s)
SSL *s;
	{
	unsigned long l=time(NULL);

	/* Mix the current time into the random pool before any random
	 * values (e.g. the connection id) are generated. */
	MD5_rand_seed(sizeof(l),(unsigned char *)&l);

	ERR_clear_error();
	/* init things to blank */
	ssl_clear(s);

	if (get_client_hello(s) <= 0) return(0);
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"got Client Hello\n");
	SSL_TRACE(SSL_ERR,"Version: %d\n",s->version);
	SSL_TRACE(SSL_ERR,"Challenge:");
	ssl_print_bytes(SSL_ERR,s->challenge_length,s->challenge);
	SSL_TRACE(SSL_ERR,"\n");
#endif
	
	if (server_hello(s) <= 0) return(0);
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"sent Master Hello:\n");
	SSL_TRACE(SSL_ERR,"Connection_Id:");
	ssl_print_bytes(SSL_ERR,CONNECTION_ID_LENGTH,s->conn_id);
	SSL_TRACE(SSL_ERR,"\n");
#endif

	/* A resumed session (s->hit) already has its master key. */
	if (!s->hit)
		{
		if (get_client_master_key(s) <= 0) return(0);
#ifdef SSL_DEBUG
		SSL_TRACE(SSL_ERR,"got Client Master Key\n");
		SSL_TRACE(SSL_ERR,"Master Key:");
		ssl_print_bytes(SSL_ERR,s->conn->master_key_length,
			s->conn->master_key);
		SSL_TRACE(SSL_ERR,"\n");
		SSL_TRACE(SSL_ERR,"Master Key Arg:");
		ssl_print_bytes(SSL_ERR,s->conn->key_arg_length,
			s->conn->key_arg);
		SSL_TRACE(SSL_ERR,"\n");
#endif
		}

	/* OK, we have now sent all the stuff needed to start encrypting;
	 * the next packet back will be encrypted, so let's go. */
	s->conn->cipher->crypt_init(s,0);
	s->clear_text=0;

	if (server_verify(s) <= 0) return(0);
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"Server verify\nChallenge:");
	ssl_print_bytes(SSL_ERR,s->challenge_length,s->challenge);
	SSL_TRACE(SSL_ERR,"\n");
#endif

	if (get_client_finished(s) <= 0) return(0);
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"Client Finish\nConnection-id:");
	ssl_print_bytes(SSL_ERR,s->conn_id_length,s->conn_id);
	SSL_TRACE(SSL_ERR,"\n");
#endif

	if (server_finish(s) <= 0) return(0);
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"Server finish\nNew-Session_id:");
	ssl_print_bytes(SSL_ERR,s->conn->session_id_length,
		s->conn->session_id);
	SSL_TRACE(SSL_ERR,"\n");
	SSL_TRACE(SSL_ERR,"SSL PROTOCOL FINISHED\n");
#endif
	return(1);
	}

/* Read and process the CLIENT-MASTER-KEY message.
 *
 * Layout as parsed below: a 10-byte header (message type, 3-byte cipher
 * kind, then 2-byte CLEAR-KEY, ENCRYPTED-KEY and KEY-ARG lengths),
 * followed by those three variable-length fields in that order.
 * The cipher kind is matched (on its first byte only) against the list
 * in s->conn->ciphers; the encrypted part is decrypted in place with
 * the server's private key, and clear-key || decrypted-key becomes
 * s->conn->master_key.  KEY-ARG is copied to s->conn->key_arg.
 *
 * Returns 1 on success, 0 on failure (an error is queued).
 */
static int get_client_master_key(s)
SSL *s;
	{
	unsigned char buf[SSL_MAX_RECORD_LENGTH_3_BYTE_HEADER];
	int i,n,clear,enc,keya;
	unsigned char *p;
	CIPHER **cipher;

	p=buf;
	i=SSL_read(s,(char *)buf,10);
	if (i < 10)
		{
		if ((i > 0) && (buf[0] == SSL_MT_ERROR))
			{
			SSLerr(SSL_F_GET_CLIENT_MASTER_KEY,SSL_R_PEER_ERROR);
			}
		else
			{
			ssl_return_error(s);
			SSLerr(SSL_F_GET_CLIENT_MASTER_KEY,SSL_R_SHORT_READ);
			}
		return(0);
		}
	
	if (*(p++) != SSL_MT_CLIENT_MASTER_KEY)
		{
		if (p[-1] != SSL_MT_ERROR)
			{
			ssl_return_error(s);
			SSLerr(SSL_F_GET_CLIENT_MASTER_KEY,
				SSL_R_READ_WRONG_PACKET_TYPE);
			}
		else
			SSLerr(SSL_F_GET_CLIENT_MASTER_KEY,SSL_R_PEER_ERROR);
		return(0);
		}

	/* The chosen cipher must be one we offered in SERVER-HELLO. */
	cipher= &(s->conn->ciphers[0]);
	while (*cipher)
		{
		if ((*cipher)->num == *p) break;
		cipher++;
		}
	if (*cipher == NULL)
		{
		ssl_return_error(s);
		SSLerr(SSL_F_GET_CLIENT_MASTER_KEY,SSL_R_NO_CIPHER_MATCH);
		return(0);
		}
	s->conn->cipher= *cipher;
	p+=3;
	n2s(p,clear);
	n2s(p,enc);
	n2s(p,keya);
	n=clear+enc+keya;
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"clear=%d\n",clear);
	SSL_TRACE(SSL_ERR,"enc  =%d\n",enc);
	SSL_TRACE(SSL_ERR,"keya =%d\n",keya);
#endif
	/* The three lengths come straight off the wire.  SSL_read is asked
	 * for n bytes into buf, so refuse anything that does not fit rather
	 * than overrunning the stack buffer.  (A peer that lies about the
	 * lengths would otherwise end up at the short-read error below.) */
	if ((n < 0) || (n > (int)sizeof(buf)))
		{
		ssl_return_error(s);
		SSLerr(SSL_F_GET_CLIENT_MASTER_KEY,SSL_R_SHORT_READ);
		return(0);
		}
	/* NOTE(review): clear/enc/keya are not checked against the selected
	 * cipher's key sizes here - confirm whether that is done elsewhere. */
	i=SSL_read(s,(char *)buf,(unsigned int)n);
	if (i != n)
		{
		ssl_return_error(s);
		SSLerr(SSL_F_GET_CLIENT_MASTER_KEY,SSL_R_SHORT_READ);
		return(0);
		}

	p=buf;
	/* do key_arg before we unpack the crypted key (the decryption below
	 * works in place inside buf). */
	if (keya > 0)
		{
		s->conn->key_arg=(unsigned char *)malloc((unsigned int)keya);
		if (s->conn->key_arg == NULL)
			{
			ssl_return_error(s);
			SSLerr(SSL_F_GET_CLIENT_MASTER_KEY,
				ERR_R_MALLOC_FAILURE);
			return(0);
			}
		/* Only copy when a destination exists: memcpy with a NULL
		 * pointer is undefined even for a zero length. */
		memcpy(s->conn->key_arg,&(buf[clear+enc]),(unsigned int)keya);
		}
	s->conn->key_arg_length=keya;

	if (s->conn->cert->privatekey == NULL)
		{
		ssl_return_error(s);
		SSLerr(SSL_F_GET_CLIENT_MASTER_KEY,SSL_R_NO_PRIVATEKEY);
		return(0);
		}
	enc=s->cert->private_decrypt(s->cert,enc,&(buf[clear]),&(buf[clear]));
	if (enc <= 0)
		{
		ssl_return_error(s);
		return(0);
		}
	s->conn->master_key=(unsigned char *)malloc((unsigned int)enc+clear);
	if (s->conn->master_key == NULL)
		{
		ssl_return_error(s);
		SSLerr(SSL_F_GET_CLIENT_MASTER_KEY,ERR_R_MALLOC_FAILURE);
		return(0);
		}
	/* master key = clear part followed by the decrypted part */
	memcpy(s->conn->master_key,buf,(unsigned int)enc+clear);
	s->conn->master_key_length=enc+clear;
	return(1);
	}

/* Read and process the CLIENT-HELLO message.
 *
 * Layout as parsed below: a 9-byte header (message type, 2-byte client
 * version, then 2-byte CIPHER-SPECS, SESSION-ID and CHALLENGE lengths),
 * followed by cipher specs, session id and challenge in that order.
 *
 * Effects: lowers s->version to the client's version if that is older;
 * either resumes a cached connection (s->hit=1) or creates a new one;
 * fills s->conn->ciphers with the NULL-terminated list of locally valid
 * ciphers the client offered; stores a malloc'd copy of the challenge
 * in s->challenge.
 *
 * Returns 1 on success, 0 on failure (an error is queued).
 */
static int get_client_hello(s)
SSL *s;
	{
	unsigned char buf[SSL_MAX_RECORD_LENGTH_3_BYTE_HEADER];
	int i,j,n;
	int cipher_spec_length;
	int challenge_length;
	int session_id_length;
	unsigned char *p;
	CIPHER **cipher;
	CIPHER **seen;

	p=buf;
	i=SSL_read(s,(char *)buf,9);
	if (i < 9)
		{
		if ((i > 0) && (buf[0] == SSL_MT_ERROR))
			SSLerr(SSL_F_GET_CLIENT_HELLO,SSL_R_PEER_ERROR);
		else
			{
			ssl_return_error(s);
			SSLerr(SSL_F_GET_CLIENT_HELLO,SSL_R_SHORT_READ);
			}
		return(0);
		}
	
	if (*(p++) != SSL_MT_CLIENT_HELLO)
		{
		if (p[-1] != SSL_MT_ERROR)
			{
			ssl_return_error(s);
			SSLerr(SSL_F_GET_CLIENT_HELLO,
				SSL_R_READ_WRONG_PACKET_TYPE);
			}
		else
			SSLerr(SSL_F_GET_CLIENT_HELLO,SSL_R_PEER_ERROR);
		return(0);
		}
	n2s(p,i);
	if (i < s->version) s->version=i;
	n2s(p,cipher_spec_length);
	n2s(p,session_id_length);
	n2s(p,challenge_length);
	n=cipher_spec_length+challenge_length+session_id_length;
	/* The lengths come straight off the wire.  SSL_read is asked for n
	 * bytes into buf, so refuse anything that does not fit rather than
	 * overrunning the stack buffer. */
	if ((n < 0) || (n > (int)sizeof(buf)))
		{
		ssl_return_error(s);
		SSLerr(SSL_F_GET_CLIENT_HELLO,SSL_R_SHORT_READ);
		return(0);
		}
	i=SSL_read(s,(char *)buf,(unsigned int)n);
	if (i != n)
		{
		ssl_return_error(s);
		SSLerr(SSL_F_GET_CLIENT_HELLO,SSL_R_SHORT_READ);
		return(0);
		}

	p=buf;
	/* get session-id before cipher stuff so we can get out session
	 * structure if it is cached */
	/* session-id */
	if ((session_id_length != 0) && 
		(session_id_length != SESSION_ID_LENGTH))
		{
		ssl_return_error(s);
		SSLerr(SSL_F_GET_CLIENT_HELLO,SSL_R_BAD_SESSION_ID_LENGTH);
		return(0);
		}

#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"session_id_length=%d\n",session_id_length);
#endif
	if (session_id_length == 0)
		{
		if (!ssl_new_conn(s,1))
			{
			ssl_return_error(s);
			return(0);
			}
		}
	else
		{
#ifdef SSL_DEBUG
		SSL_TRACE(SSL_ERR,"session_id:");
		ssl_print_bytes(SSL_ERR,session_id_length,
			&(p[cipher_spec_length]));
		SSL_TRACE(SSL_ERR,"\n");
#endif
		/* 1: cached connection found, -1: error, else: not cached */
		i=ssl_get_prev_conn(s,session_id_length,
			&(p[cipher_spec_length]));
		if (i == 1)
			{
			/* previous conn */
			s->hit=1;
#ifdef SSL_DEBUG
			SSL_TRACE(SSL_ERR,"GOT PREVIOUS CONNECTION\n");

#endif
			}
		else if (i == -1)
			{
			ssl_return_error(s);
			return(0);
			}
		else
			{
			if (!ssl_new_conn(s,1))
				{
				ssl_return_error(s);
				return(0);
				}
			}
		}

#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"cipher spec length=%d\n",cipher_spec_length);
#endif
	/* Each cipher spec is 3 bytes; only its first byte is compared with
	 * ssl_ciphers[].num.  A client may repeat a cipher kind, so each
	 * matching cipher is recorded at most once.
	 * NOTE(review): the capacity of s->conn->ciphers is not visible
	 * here; de-duplicating bounds the entries written to the number of
	 * distinct valid ssl_ciphers entries (plus the NULL terminator),
	 * which is presumably what it is sized for - confirm. */
	cipher= &(s->conn->ciphers[0]);
	for (i=0; i<cipher_spec_length; i+=3)
		{
		n=p[i];
		j=0;
		while (ssl_ciphers[j].valid != -1)
			{
			if (ssl_ciphers[j].valid)
				{
				if (ssl_ciphers[j].num == n)
					{
					for (seen= &(s->conn->ciphers[0]);
						seen != cipher; seen++)
						{
						if (*seen == &(ssl_ciphers[j]))
							break;
						}
					if (seen == cipher)
						*(cipher++)= &(ssl_ciphers[j]);
					break;
					}
				}
#ifdef SSL_DEBUG
			SSL_TRACE(SSL_ERR,"Cipher:%02x %02x %02x %s\n",
				p[i],p[i+1],p[i+2],ssl_ciphers[j].name);
#endif
			j++;
			}
		}
	*cipher=NULL;
	p+=cipher_spec_length;
	/* done cipher selection */

	/* session id extracted already */
	p+=session_id_length;

	/* challenge */
	/* NOTE(review): challenge_length is not range-checked here; a zero
	 * length may make malloc return NULL and be reported as a malloc
	 * failure. */
	s->challenge_length=challenge_length;
	s->challenge=(unsigned char *)malloc((unsigned int)challenge_length);
	if (s->challenge == NULL)
		{
		ssl_return_error(s);
		SSLerr(SSL_F_GET_CLIENT_HELLO,ERR_R_MALLOC_FAILURE);
		return(0);
		}
	memcpy(s->challenge,p,(unsigned int)challenge_length);
	return(1);
	}

/* Build and send the SERVER-HELLO message.
 *
 * Layout written below: an 11-byte header (message type, session-id-hit
 * flag, certificate type, 2-byte version, then 2-byte certificate,
 * cipher-specs and connection-id lengths) followed by the certificate,
 * the cipher specs and the connection id.  For a resumed session
 * (s->hit) the certificate and cipher-spec fields are empty.
 *
 * For a new session the connection takes a reference to s->cert.
 * A fresh random connection id is generated and kept in s->conn_id.
 *
 * Returns the SSL_write() result, 0 if no certificate is available,
 * or -1 if the connection id cannot be allocated.
 */
static int server_hello(s)
SSL *s;
	{
	unsigned char buf[SSL_MAX_RECORD_LENGTH_3_BYTE_HEADER];
	unsigned char *p,*d;
	int n,hit;
	CIPHER **cipher;

	p=buf;
	d=p+11;		/* variable-length data goes after the header */
	*(p++)=SSL_MT_SERVER_HELLO;		/* type */
	hit=s->hit;
	*(p++)=(unsigned char)hit;
	if (!hit)
		{			/* else add cert to session */
		if (s->conn->cert != NULL)
			{
			ssl_cert_free(s->conn->cert);
#ifdef SSL_DEBUG
			SSL_TRACE(SSL_ERR,"bad - defined certificate - perhaps a timeout\n");
#endif
			}
		s->conn->cert=s->cert;
		/* s->cert may be unset; the NULL check below reports that
		 * instead of crashing here. */
		if (s->cert != NULL)
			s->cert->references++;
		}

	if (s->conn->cert == NULL)
		{
		ssl_return_error(s);
		SSLerr(SSL_F_SERVER_HELLO,SSL_R_NO_CERTIFICATE_SPECIFIED);
		return(0);
		}
	if (hit)
		{
#ifdef SSL_DEBUG
		SSL_TRACE(SSL_ERR,"EXISTING SESSION BEING USED\n");
#endif
		*(p++)=0;			/* no certificate type */
		s2n(s->version,p);		/* version */
		s2n(0,p);			/* cert len */
		s2n(0,p);			/* ciphers len */
		}
	else
		{
		*(p++)=s->cert->cert_type;	/* put certificate type */
		s2n(s->version,p);		/* version */
		n=i2D_X509(s->cert->x509,NULL);
		s2n(n,p);			/* certificate length */
		i2D_X509(s->cert->x509,&d);
		n=0;
		cipher= &(s->conn->ciphers[0]);
		while (*cipher)			/* add cipher data */
			{
#ifdef SSL_DEBUG
			SSL_TRACE(SSL_ERR,"Cipher: %s\n",(*cipher)->name);
#endif
			*(d++)=(*cipher)->num;
			*(d++)=(*cipher)->noidea;
			*(d++)=(*cipher)->keybits;
			cipher++;
			n+=3;
			}
		s2n(n,p);			/* add cipher length */
		}

	/* make and send conn_id */
	s2n(CONNECTION_ID_LENGTH,p);	/* add conn_id length */
	s->conn_id=(unsigned char *)malloc(CONNECTION_ID_LENGTH);
	if (s->conn_id == NULL)
		{
		ssl_return_error(s);
		SSLerr(SSL_F_SERVER_HELLO,ERR_R_MALLOC_FAILURE);
		return(-1);
		}
	s->conn_id_length=CONNECTION_ID_LENGTH;
	MD5_rand((int)s->conn_id_length,s->conn_id);
	memcpy(d,s->conn_id,CONNECTION_ID_LENGTH);
	d+=CONNECTION_ID_LENGTH;

	n=d-buf;
	return(SSL_write(s,(char *)buf,(unsigned int)n));
	}

/* Read the CLIENT-FINISHED message and check that it echoes the
 * connection id sent in SERVER-HELLO (s->conn_id).
 *
 * Returns 1 on a match, 0 on any failure (an error is queued).
 */
static int get_client_finished(s)
SSL *s;
	{
	unsigned char buf[SSL_MAX_RECORD_LENGTH_3_BYTE_HEADER];
	unsigned char *p;
	int i;

	p=buf;
	i=SSL_read(s,(char *)buf,1);
	if (i < 1)
		{
		ssl_return_error(s);
		SSLerr(SSL_F_GET_CLIENT_FINISHED,SSL_R_SHORT_READ);
		return(0);
		}
	if (*p != SSL_MT_CLIENT_FINISHED)
		{
		if (*p != SSL_MT_ERROR)
			{
			ssl_return_error(s);
			SSLerr(SSL_F_GET_CLIENT_FINISHED,
				SSL_R_READ_WRONG_PACKET_TYPE);
			}
		else
			SSLerr(SSL_F_GET_CLIENT_FINISHED,SSL_R_PEER_ERROR);
		return(0);
		}
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"conn_id_length=%d\n",s->conn_id_length);
#endif

	i=SSL_read(s,(char *)buf,(unsigned int)s->conn_id_length);
	if (i < s->conn_id_length)
		{
		ssl_return_error(s);
		SSLerr(SSL_F_GET_CLIENT_FINISHED,SSL_R_SHORT_READ);
		return(0);
		}
	if (memcmp(buf,s->conn_id,(unsigned int)s->conn_id_length) != 0)
		{
#ifdef ERR_DEBUG
		/* ssl_print_bytes takes (stream, length, bytes), as everywhere
		 * else in this file. */
		SSL_TRACE(SSL_ERR,"bad message id\n");
		SSL_TRACE(SSL_ERR,"got :");
		ssl_print_bytes(SSL_ERR,s->conn_id_length,buf);
		SSL_TRACE(SSL_ERR,"\nwant:");
		ssl_print_bytes(SSL_ERR,s->conn_id_length,s->conn_id);
		SSL_TRACE(SSL_ERR,"\n");
#endif
		ssl_return_error(s);
		SSLerr(SSL_F_GET_CLIENT_FINISHED,
			SSL_R_CONNECTION_ID_IS_DIFFERENT);
		return(0);
		}
	return(1);
	}

/* Send SERVER-VERIFY: a single type byte followed by the challenge
 * bytes that get_client_hello() stored in s->challenge.
 * Returns the SSL_write() result.
 */
static int server_verify(s)
SSL *s;
	{
	unsigned char msg[SSL_MAX_RECORD_LENGTH_3_BYTE_HEADER];
	unsigned int clen;

	clen=(unsigned int)s->challenge_length;
	msg[0]=SSL_MT_SERVER_VERIFY;
	memcpy(&(msg[1]),s->challenge,clen);
	return(SSL_write(s,(char *)msg,clen+1));
	}

/* Send SERVER-FINISHED: a type byte followed by the session id.
 *
 * When s->verify_mode has SSL_VERIFY_PEER set, the client certificate
 * exchange (request_certificate) runs first and any failure there aborts
 * with 0.  The connection is registered via ssl_add_hash_conn() before
 * the session id is sent.  Otherwise returns the SSL_write() result.
 */
static int server_finish(s)
SSL *s;
	{
	unsigned char msg[SSL_MAX_RECORD_LENGTH_3_BYTE_HEADER];
	unsigned int sid_len;

	/* msg doubles as the scratch buffer for the certificate exchange */
	if ((s->verify_mode & SSL_VERIFY_PEER) &&
		(request_certificate(s,msg) <= 0))
		return(0);

	ssl_add_hash_conn(s->conn);

	sid_len=(unsigned int)s->conn->session_id_length;
	msg[0]=SSL_MT_SERVER_FINISHED;
	memcpy(&(msg[1]),s->conn->session_id,sid_len);
	return(SSL_write(s,(char *)msg,sid_len+1));
	}

/* send the request and check the response */
static int request_certificate(s, buf)
SSL *s;
unsigned char *buf;
	{
	unsigned char *p,*p2,*buf2;
	unsigned char ccd[MAX_CERT_CHALLENGE_LENGTH];
	int i,ctype,clen,rlen;
	X509 *x509;

#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"REQUEST_CERTIFICATE\n");
#endif
	MD5_rand(MIN_CERT_CHALLENGE_LENGTH,ccd);
	p=buf;
	*(p++)=SSL_MT_REQUEST_CERTIFICATE;
	*(p++)=SSL_AT_MD5_WITH_RSA_ENCRYPTION;
	memcpy(p,ccd,MIN_CERT_CHALLENGE_LENGTH);
	SSL_write(s,(char *)buf,MIN_CERT_CHALLENGE_LENGTH+2);

	i=SSL_read(s,(char *)buf,6);
	if (i <= 0)
		{
		ssl_return_error(s);
		SSLerr(SSL_F_REQUEST_CERTIFICATE,SSL_R_SHORT_READ);
		return(0);
		}
	p=buf;
	if ((*p == SSL_MT_ERROR) && (i >= 3))
		{
		n2s(p,i);
		if (s->verify_mode & SSL_VERIFY_FAIL_IF_NO_PEER_CERT)
			{
			SSLerr(SSL_F_REQUEST_CERTIFICATE,
				SSL_R_PEER_DID_NOT_RETURN_A_CERT);
			return(0);
			}
		return(1);
		}
	if ((*(p++) != SSL_MT_CLIENT_CERTIFICATE) || (i < 6)) /* very bad */
		{
		ssl_return_error(s);
		SSLerr(SSL_F_REQUEST_CERTIFICATE,SSL_R_SHORT_READ);
		return(0);
		}
	/* ok we have a response */
	ctype= *(p++); /* certificate type, there is only one right now. */
	n2s(p,clen);
	n2s(p,rlen);
#ifdef SSL_DEBUG
	SSL_TRACE(SSL_ERR,"clen=%d\nrlen=%d\n",clen,rlen);
#endif
	i=SSL_read(s,(char *)buf,(unsigned int)(clen+rlen));
	if (i < (clen+rlen))
		{
		ssl_return_error(s);
		SSLerr(SSL_F_REQUEST_CERTIFICATE,SSL_R_SHORT_READ);
		return(0);
		}

	p=buf;
	x509=(X509 *)X509_new_D2i_X509(clen,p);
	if (x509 == NULL)
		{
		SSLerr(SSL_F_REQUEST_CERTIFICATE,ERR_R_X509_LIB);
		return(0);
		}
	p+=clen;
	i=X509_verify(x509,s->verify_callback);

	if (i)	/* we like the packet, now check the chksum */
		{
		MD5_CTX md5s;
		unsigned char  md[16];
		CERT c;

		MD5Init(&md5s);
		MD5Update(&md5s,s->conn->key_material,
			(unsigned int)s->conn->key_material_length);
		MD5Update(&md5s,ccd,MIN_CERT_CHALLENGE_LENGTH);

		i=i2D_X509(s->cert->x509,NULL);
		buf2=(unsigned char *)malloc((unsigned int)i+10);
		if (buf2 == NULL)
			{
			SSLerr(SSL_F_REQUEST_CERTIFICATE,ERR_R_MALLOC_FAILURE);
			return(0);
			}
		p2=buf2;
		i=i2D_X509(s->cert->x509,&p2);
		MD5Update(&md5s,buf2,(unsigned int)i);
		free(buf2);
		MD5Final(&(md[0]),&md5s);

		c.publickey=ssl_rsa_extract_public_key(x509);
		if (c.publickey == NULL) return(0);
		i=ssl_rsa_public_decrypt(&c,rlen,p,p);

		if (i == 16) 
			{
			if (memcmp(p,md,16) == 0)
				{
				if ((s->peer_status == SSL_PEER_IN_SSL) &&
					(s->peer != NULL))
					X509_free(x509);
				s->peer=x509;
				s->peer_status=SSL_PEER_IN_SSL;
				return(1);
				}
			i=0;
			SSLerr(SSL_F_REQUEST_CERTIFICATE,SSL_R_BAD_CHECKSUM);
			}
		else
			SSLerr(SSL_F_REQUEST_CERTIFICATE,
				SSL_R_BAD_CHECKSUM_DECODE);
		}
	X509_free(x509);
	return(i);
	}
