#include "dat.h"
#include "fns.h"

/*
 * Per-exchange state for diffie-hellman-group1-sha1.
 * Allocated by kexdhinit and released (with everything it points to)
 * by kexdhfree.
 */
struct KexState {
	/* version strings, hashed as V_C and V_S */
	char	*cliver, *srvver;
	/* client and server KEXINIT blobs, hashed as I_C and I_S */
	Blob	*clikex, *srvkex;

	/* digest used for the exchange hash H */
	Digest	*d;
	/* group: modulus p, q = (p-1)/2, generator g */
	mpint	*p, *q, *g;

	/* shared secret K and the public values e (client) and f (server) */
	mpint	*K, *e, *f;
	/* private exponents: y is chosen by the server, x by the client */
	mpint	*y, *x;
};

/*
 * Prime modulus for diffie-hellman-group1-sha1: the 1024-bit MODP group
 * known as Oakley Group 2 (ISAKMP/Oakley, RFC 2409), generated by
 * Richard Schroeppel at U. Arizona:
 *	p = 2^1024 - 2^960 - 1 + 2^64 * (floor(2^894 * Pi) + 129093)
 * Stored as hexadecimal text; kexdhinit parses it with strtomp(..., 16, ...)
 * and pairs it with the generator 2.
 */
static char p[] =	"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1"
			"29024E088A67CC74020BBEA63B139B22514A08798E3404DD"
			"EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245"
			"E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED"
			"EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381"
			"FFFFFFFFFFFFFFFF";

/* defined below; referenced before their definitions */
static char	*msgnamekexdh(int t);
static uchar	*kexdhhash(Conn *c, KexState *s, Blob *cert);

/*
 * Allocate the state for a diffie-hellman-group1-sha1 exchange on c.
 *
 * cli and srv are the client's and server's KEXINIT blobs; the state keeps
 * the pointers and kexdhfree later clears them, so ownership passes here.
 * The connection's version strings are duplicated.  The group is set up
 * as p (parsed from the hex constant above), q = (p-1)/2 and g = 2, and
 * msgnamekexdh is installed as c->msgnamekex.
 *
 * Every multiprecision allocation is checked and reported via
 * panic(Ememory), so a nil value can never reach mpsub/mpright below.
 */
static KexState*
kexdhinit(Conn *c, Blob *cli, Blob *srv)
{
	KexState *s;

	s = emalloc(sizeof *s);
	s->clikex = cli;
	s->srvkex = srv;
	s->cliver = estrdup(c->cversion);
	s->srvver = estrdup(c->sversion);
	s->d = &digestsha1;
	if((s->p = strtomp(p, nil, 16, nil)) == nil)
		panic(Ememory);
	if((s->q = mpnew(0)) == nil)
		panic(Ememory);
	if((s->g = mpcopy(mptwo)) == nil)	/* group generated by 2 */
		panic(Ememory);
	mpsub(s->p, mpone, s->q);		/* q = (p-1)/2 */
	mpright(s->q, 1, s->q);
	c->msgnamekex = msgnamekexdh;

	return s;
}

/*
 * Release a KexState built by kexdhinit along with everything it owns:
 * both KEXINIT blobs, both version strings and every mpint.  Passing nil
 * is a no-op.
 *
 * NOTE(review): the server path never assigns x and the client path never
 * assigns y, so this relies on emalloc returning zeroed memory and on
 * mpclear accepting nil — confirm.
 */
static void
kexdhfree(KexState *s)
{
	if(s != nil){
		/* negotiation inputs */
		clearblob(s->clikex);
		clearblob(s->srvkex);
		free(s->cliver);
		free(s->srvver);

		/* group parameters */
		mpclear(s->p);
		mpclear(s->q);
		mpclear(s->g);

		/* secrets and public values of this exchange */
		mpclear(s->K);
		mpclear(s->e);
		mpclear(s->f);
		mpclear(s->y);
		mpclear(s->x);

		free(s);
	}
}

/*
 * Server half of diffie-hellman-group1-sha1 (RFC 4253 section 8).
 *
 * Chooses a private exponent y in (1, q) and computes f = g^y mod p.
 * It then reads the client's SSH_MSG_KEXDH_INIT carrying e and rejects e
 * outside [1, p-1], as well as a missing e.  Next it derives
 * K = e^y mod p, hashes the exchange (kexdhhash also stores K and H in c)
 * and signs H with c->hostcert.  Finally it sends SSH_MSG_KEXDH_REPLY
 * containing the host certificate, f and the signature.
 *
 * Returns 0 on success, or -1 with the error string set.
 */
static int
kexdhserver(Conn *c, Channel *rc, Channel *wc)
{
	int n;
	uchar *H;

	Msg *m;
	Sig *sig;
	Blob *tmp;
	KexState *s;

	s = c->kexstate;
	s->K = mpnew(0);
	s->f = mpnew(0);

	/* random y in (1, q): reject 0 and 1 exactly as the client does for x */
	s->y = nil;
	do{
		s->y = mpnrand(s->q, s->y);
	}while(mpcmp(s->y, mpone) <= 0);
	mpexp(s->g, s->y, s->p, s->f);

	m = recvmsg(rc, SSH_MSG_KEXDH_INIT);
	s->e = getmpint(&m->b);
	freemsg(m);
	/* RFC 4253: e outside [1, p-1] MUST NOT be accepted; nil = malformed */
	if(s->e == nil || mpcmp(mpone, s->e) > 0 || mpcmp(s->p, s->e) <= 0){
		werrstr("kexdh: bogus client challenge");
		goto Error;
	}

	mpexp(s->e, s->y, s->p, s->K);

	/* K_S: encoded host certificate, hashed into H and sent to the client */
	n = sizecert(c->hostcert);
	tmp = mkblob(n);
	putcert(tmp, c->hostcert);
	H = kexdhhash(c, s, tmp);
	sig = c->hostcert->sign(c->hostcert, H, s->d->digestlen);
	memset(H, 0, s->d->digestlen);
	free(H);
	if(sig == nil){
		/* the result was previously dereferenced unchecked */
		werrstr("kexdh: cannot sign session id: %r");
		freeblob(tmp);
		goto Error;
	}

	/* Why make it easy when we can make it hard? */
	m = allocmsg(c, SSH_MSG_KEXDH_REPLY, 0);
	putblobstring(&m->b, tmp);
	freeblob(tmp);
	putmpint(&m->b, s->f);
	n = sizesig(sig);
	tmp = mkblob(n);
	putsig(tmp, sig);
	putblobstring(&m->b, tmp);
	freeblob(tmp);
	sig->free(sig);
	sendmsg(wc, m);

	return 0;

 Error:
	return -1;
}

/*
 * Client half of diffie-hellman-group1-sha1 (RFC 4253 section 8).
 *
 * Chooses a private exponent x in (1, q) and sends e = g^x mod p in
 * SSH_MSG_KEXDH_INIT.  It then reads SSH_MSG_KEXDH_REPLY, which carries
 * the host certificate K_S, the server value f and a signature, and
 * processes it as follows:
 *	- checks the certificate with kexcertcheck;
 *	- rejects f outside [1, p-1];
 *	- derives K = f^x mod p;
 *	- verifies the signature over the exchange hash H.
 * On success the certificate, the signature and their implementations are
 * stored in c.
 *
 * Returns 0 on success, or -1 with the error string set.  On both paths,
 * the reply message, the blobs, H and any certificate or signature not
 * handed to c are released.
 */
static int
kexdhclient(Conn *c, Channel *rc, Channel *wc)
{
	int bad;
	uchar *H;

	Msg *m;
	Sig *sig;
	Cert *cert;
	KexState *s;
	Blob *certblob, *sigblob, *tmpblob;

	m = nil;
	sig = nil;
	cert = nil;
	certblob = sigblob = tmpblob = nil;

	s = c->kexstate;
	s->K = mpnew(0);
	s->e = mpnew(0);
	/* s->f is not preallocated: getmpint below supplies it (was leaked) */
	s->x = nil;
	do{
		s->x = mpnrand(s->q, s->x);
	}while(mpcmp(s->x, mpone) <= 0);	/* we want x in (1, q): reject 0 and 1 */

	mpexp(s->g, s->x, s->p, s->e);

	m = allocmsg(c, SSH_MSG_KEXDH_INIT, 0);
	putmpint(&m->b, s->e);
	sendmsg(wc, m);

	/* the reply is ours to free; it used to be leaked on every path */
	m = recvmsg(rc, SSH_MSG_KEXDH_REPLY);

	/* Why make things easy when we can make them hard? */
	if((certblob = getblobstring(&m->b)) == nil){
		werrstr("kexdh: no host certificate: %r");
		goto Error;
	}

	/* copy as getcert() is destructive and we still need the original */
	tmpblob = copyblob(certblob->bp, certblob->wp - certblob->bp);
	if((cert = getcert(c, tmpblob)) == nil){
		werrstr("kexdh: bad host certificate: %r");
		goto Error;
	}
	freeblob(tmpblob);
	tmpblob = nil;

	/* look up (certificate, host) pair in local database */
	if(kexcertcheck(c, cert) != 0)
		goto Error;

	/* RFC 4253: f outside [1, p-1] MUST NOT be accepted; nil = malformed */
	s->f = getmpint(&m->b);
	if(s->f == nil || mpcmp(mpone, s->f) > 0 || mpcmp(s->p, s->f) <= 0){
		werrstr("kexdh: bogus server challenge");
		goto Error;
	}
	if((sigblob = getblobstring(&m->b)) == nil){
		werrstr("kexdh: no host signature on session id: %r");
		goto Error;
	}
	if((sig = getsig(c, sigblob)) == nil){
		werrstr("kexdh: bad host signature on session id: %r");
		goto Error;
	}

	mpexp(s->f, s->x, s->p, s->K);
	H = kexdhhash(c, s, certblob);

	/*
	 * We verified certificate itself above; now verify signature.
	 * Scrub and free H before acting on the result so a failed
	 * verification no longer leaks it.
	 */
	bad = cert->verify(cert, sig, H, s->d->digestlen) != 0;
	memset(H, 0, s->d->digestlen);
	free(H);
	if(bad){
		werrstr("kexdh: invalid server signature on session id");
		goto Error;
	}

	/* save certificate and signature for future reference */
	if(c->hostcert != nil)
		c->hostcert->free(c->hostcert);
	c->hostcert = cert;
	/* we already said it was ok in getcert, so we better find it */
	c->hostcertimpl = (CertImpl*)findimpl(c->okhostcert, cert->name);
	assert(c->hostcertimpl != nil);

	if(c->hostsig != nil)
		c->hostsig->free(c->hostsig);
	c->hostsig = sig;
	c->hostsigimpl = (SigImpl*)findimpl(c->okhostsig, sig->name);
	/* we already said it was ok in getsig, so we better find it */
	assert(c->hostsigimpl != nil);

	clearblob(certblob);
	clearblob(sigblob);
	freemsg(m);
	return 0;

 Error:
	/* cert and sig were not handed to c; release them here */
	if(sig != nil)
		sig->free(sig);
	/* NOTE(review): assumes kexcertcheck does not take ownership of cert on failure */
	if(cert != nil)
		cert->free(cert);
	clearblob(certblob);
	clearblob(sigblob);
	clearblob(tmpblob);
	if(m != nil)
		freemsg(m);
	return -1;
}

/*
 * Compute the exchange hash
 *	H = HASH(V_C || V_S || I_C || I_S || K_S || e || f || K)
 * with s->d.  cert is the encoded host certificate (K_S).  The KEXINIT
 * blobs are written as a 4-byte length followed by their raw bytes.
 *
 * Side effects: c->kexkey receives K encoded as an mpint, and c->kexhash
 * receives a copy of H.  The returned H is a fresh emalloc'd buffer of
 * s->d->digestlen bytes; the callers in this file zero and free it.
 * This function only hashes; signing is done by the caller.
 */
static uchar*
kexdhhash(Conn *c, KexState *s, Blob *cert)
{
	ulong len;
	uchar *digest;
	Blob *buf;

	/* size the serialized hash input */
	len = sizestring(s->cliver);
	len += sizestring(s->srvver);
	len += 4 + sizeblob(s->clikex);
	len += 4 + sizeblob(s->srvkex);
	len += sizeblobstring(cert);
	len += sizempint(s->e);
	len += sizempint(s->f);
	len += sizempint(s->K);
	buf = mkblob(len);

	/* field order is fixed by the protocol; do not rearrange */
	putstring(buf, s->cliver);
	putstring(buf, s->srvver);
	putlong(buf, sizeblob(s->clikex));
	putblobraw(buf, s->clikex);
	putlong(buf, sizeblob(s->srvkex));
	putblobraw(buf, s->srvkex);
	putblobstring(buf, cert);
	putmpint(buf, s->e);
	putmpint(buf, s->f);
	putmpint(buf, s->K);

	digest = emalloc(s->d->digestlen);
	s->d->digest(buf->bp, sizeblob(buf), digest, nil);
	/* clearblob rather than freeblob, presumably to scrub K from the input */
	clearblob(buf);

	debug(DbgCrypto, "kexdh: H = %.*H", s->d->digestlen, digest);

	/* hand K and H to the connection for later use */
	len = sizempint(s->K);
	c->kexkey = mkblob(len);
	putmpint(c->kexkey, s->K);
	c->kexhash = copyblob(digest, s->d->digestlen);

	return digest;
}

/*
 * Name a diffie-hellman key-exchange message type.  Installed as
 * c->msgnamekex by kexdhinit.  Types other than KEXDH_INIT and
 * KEXDH_REPLY yield "<unknown>".
 */
static char*
msgnamekexdh(int t)
{
	if(t == SSH_MSG_KEXDH_INIT)
		return "SSH_MSG_KEXDH_INIT";
	if(t == SSH_MSG_KEXDH_REPLY)
		return "SSH_MSG_KEXDH_REPLY";
	return "<unknown>";
}

/*
 * Method descriptor for "diffie-hellman-group1-sha1" (Oakley Group 2 with
 * SHA-1).  Positional initializer.  The fields are:
 *	- the protocol name;
 *	- Sign, presumably the host-key capability this method requires
 *	  (confirm against the Kex definition in dat.h);
 *	- the SHA-1 digest;
 *	- the init, free, client and server hooks defined in this file.
 */
Kex kexdh = {
	"diffie-hellman-group1-sha1",
	Sign,
	&digestsha1,
	kexdhinit,
	kexdhfree,
	kexdhclient,
	kexdhserver
};
