#include "dat.h"
#include "fns.h"

/*
 * State carried through one diffie-hellman-group-exchange run.
 * Set up by kexdhgexinit and released by kexdhgexfree.
 */
struct KexState {
	/* leading inputs to the exchange hash */
	char	*cliver;	/* V_C: private copy of the client version string */
	char	*srvver;	/* V_S: private copy of the server version string */
	Blob	*clikex;	/* I_C: client kex payload, as handed to init */
	Blob	*srvkex;	/* I_S: server kex payload, as handed to init */

	Digest	*d;		/* digest used to compute the exchange hash */

	/* group size bounds carried in the GEX request, in bits */
	u32int	min, desired, max;

	/* group parameters: modulus, (p-1)/2, generator */
	mpint	*p, *q, *g;

	/* K: shared secret; e = g^x mod p (client); f = g^y mod p (server) */
	mpint	*K, *e, *f;

	/* secret exponents: x on the client, y on the server */
	mpint	*x, *y;
};

/*
 * Built-in group size bounds, in bits.  The client falls back to
 * these when the matching configuration value reads as 0.
 */
enum {
	DHMin	= 1536,		/* smallest modulus the client accepts */
	DHDef	= 1536,		/* size asked for by preference */
	DHMax	= 8192		/* upper bound sent in the request */
};

/*
 * Nota Bene: must fall back to old GEX for datafellows;
 * bug-for-bug compatibility and old GEX not implemented
 */

/*
 * Message numbers private to diffie-hellman-group-exchange.
 * They are consecutive, starting at SSH_MSG_KEX_DH_GEX_REQUEST_OLD;
 * msgtab below depends on that ordering.
 */
enum {
	SSH_MSG_KEX_DH_GEX_REQUEST_OLD		= 30,
	SSH_MSG_KEX_DH_GEX_GROUP,
	SSH_MSG_KEX_DH_GEX_INIT,
	SSH_MSG_KEX_DH_GEX_REPLY,
	SSH_MSG_KEX_DH_GEX_REQUEST
};

/*
 * Printable names, indexed by
 * (message number - SSH_MSG_KEX_DH_GEX_REQUEST_OLD).
 * Keep in the same order as the enum above.
 */
static char *msgtab[] = {
	"SSH_MSG_KEX_DH_GEX_REQUEST_OLD",
	"SSH_MSG_KEX_DH_GEX_GROUP",
	"SSH_MSG_KEX_DH_GEX_INIT",
	"SSH_MSG_KEX_DH_GEX_REPLY",
	"SSH_MSG_KEX_DH_GEX_REQUEST"
};

/*
 * Return the name of kex message number t, or "<unknown>" when t is
 * outside the range covered by msgtab.  The returned string is static.
 */
static char*
msgnamedhgex(int t)
{
	/* test the lower bound first so the subtraction cannot go negative */
	if(t < SSH_MSG_KEX_DH_GEX_REQUEST_OLD)
		return "<unknown>";
	if(t - SSH_MSG_KEX_DH_GEX_REQUEST_OLD >= nelem(msgtab))
		return "<unknown>";
	return msgtab[t - SSH_MSG_KEX_DH_GEX_REQUEST_OLD];
}

/*
 * Allocate the state for one group-exchange run on connection c.
 *
 * The version strings are duplicated from c; the cli and srv blobs
 * are stored by reference, not copied, and are later passed to
 * clearblob by kexdhgexfree.  The digest is fixed to SHA-1.
 * Also installs msgnamedhgex as c's kex message-name lookup.
 */
static KexState*
kexdhgexinit(Conn *c, Blob *cli, Blob *srv)
{
	KexState *ks;

	ks = emalloc(sizeof *ks);

	/* kex payloads: kept by reference */
	ks->clikex = cli;
	ks->srvkex = srv;

	/* version strings: private copies */
	ks->cliver = estrdup(c->cversion);
	ks->srvver = estrdup(c->sversion);

	ks->d = &digestsha1;

	c->msgnamekex = msgnamedhgex;
	return ks;
}

/*
 * Release everything held by s, then s itself.  A nil s is ignored.
 * The digest pointer refers to a static descriptor and is left alone.
 */
static void
kexdhgexfree(KexState *s)
{
	if(s != nil){
		/* version strings and kex payloads */
		free(s->cliver);
		free(s->srvver);
		clearblob(s->clikex);
		clearblob(s->srvkex);

		/* group parameters */
		mpclear(s->p);
		mpclear(s->q);
		mpclear(s->g);

		/* shared secret, public values, secret exponents */
		mpclear(s->K);
		mpclear(s->e);
		mpclear(s->f);
		mpclear(s->x);
		mpclear(s->y);

		free(s);
	}
}

/* forward declaration: the client and server routines call this before its definition */
static	uchar	*kexdhgexhash(Conn*, KexState*, Blob*);

/*
 * Client side of diffie-hellman-group-exchange.
 *
 * Sends the group size request (legacy single-value form when
 * BugOldGpExchange is set), reads the server's group (p, g), picks a
 * secret exponent x, sends e = g^x mod p, then reads the reply holding
 * the host certificate, f and the signature.  The certificate is
 * checked with kexcertcheck, K = f^x mod p is computed, and the
 * signature over the exchange hash is verified.
 *
 * On success the certificate and signature are installed in
 * c->hostcert / c->hostsig (replacing and freeing any previous ones)
 * and 0 is returned.  kexdhgexhash also records K and H in
 * c->kexkey / c->kexhash.  On failure -1 is returned with the error
 * string set; everything allocated locally is released.
 *
 * rc is the channel messages are received from, wc the channel they
 * are sent on.
 */
static int
kexdhgexclient(Conn *c, Channel *rc, Channel *wc)
{
	int n;
	uchar *H;
	u32int min, max, def;

	Msg *m;
	Sig *sig;
	Cert *cert;
	KexState *s;
	Blob *certblob, *sigblob, *tmpblob;

	/* everything the Error path releases starts out nil */
	certblob = sigblob = tmpblob = nil;
	cert = nil;
	sig = nil;
	H = nil;
	m = nil;

	s = c->kexstate;
	s->K = mpnew(0);
	s->e = mpnew(0);
	s->q = mpnew(0);
	s->f = nil;	/* set from getmpint() when the reply is parsed */
	s->x = nil;
	s->y = nil;

	/* send server group size parameters */
	min = cfgint(c, "dhgpmin");
	max = cfgint(c, "dhgpmax");
	def = cfgint(c, "dhgpdef");
	if(min == 0)
		min = DHMin;
	if(max == 0)
		max = DHMax;
	if(def == 0)
		def = DHDef;
	s->min = min;
	s->max = max;
	s->desired = def;

	if(c->bugflags & BugOldGpExchange){
		m = allocmsg(c, SSH_MSG_KEX_DH_GEX_REQUEST_OLD, 0);
		putlong(&m->b, def);
	}else{
		m = allocmsg(c, SSH_MSG_KEX_DH_GEX_REQUEST, 0);
		putlong(&m->b, min);
		putlong(&m->b, def);
		putlong(&m->b, max);
	}
	sendmsg(wc, m);
	m = nil;	/* handed to the send channel */

	/* read group and sub-group generator */
	m = recvmsg(rc, SSH_MSG_KEX_DH_GEX_GROUP);
	s->p = getmpint(&m->b);
	s->g = getmpint(&m->b);
	freemsg(m);
	m = nil;

	/* NOTE(review): assumes getmpint yields nil on a malformed field, as getblobstring does */
	if(s->p == nil || s->g == nil){
		werrstr("kexdhgex: malformed group from server");
		goto Error;
	}

	/*
	 * The group comes from the peer: treat an undersized modulus as
	 * a failed exchange, like the other checks below, not a panic.
	 */
	n = mpsignif(s->p);
	if(n < 0 || (u32int)n < min){
		werrstr("kexdhgex: too few bits in GF(p): %d", n);
		goto Error;
	}

	/* q <- (p-1)/2 */
	mpsub(s->p, mpone, s->q);
	mpright(s->q, 1, s->q);

	/* x <-R- (1, q): reject both 0 and 1 */
	do{
		s->x = mpnrand(s->q, s->x);
	}while(mpcmp(s->x, mpone) <= 0);

	mpexp(s->g, s->x, s->p, s->e);		/* e <- g^x (p) */

	m = allocmsg(c, SSH_MSG_KEX_DH_GEX_INIT, 0);
	putmpint(&m->b, s->e);
	sendmsg(wc, m);
	m = nil;	/* handed to the send channel */

	/*
	 * string: K_S, server host key and certificates
	 * mpint: f
	 * string: signature of H
	 *
	 * m stays allocated until the blobs taken from it are cleared.
	 */
	m = recvmsg(rc, SSH_MSG_KEX_DH_GEX_REPLY);
	if((certblob = getblobstring(&m->b)) == nil){
		werrstr("kexdhgex: no host certificate: %r");
		goto Error;
	}

	/* copy as getcert() is destructive and we still need the original */
	tmpblob = copyblob(certblob->bp, certblob->wp - certblob->bp);
	if((cert = getcert(c, tmpblob)) == nil){
		werrstr("kexdhgex: bad host certificate: %r");
		goto Error;
	}
	freeblob(tmpblob);
	tmpblob = nil;

	/* look up (certificate, host) pair in local database */
	if(kexcertcheck(c, cert) != 0)
		goto Error;

	/* f must lie in [1, p) */
	s->f = getmpint(&m->b);
	if(s->f == nil || mpcmp(mpone, s->f) > 0 || mpcmp(s->p, s->f) <= 0){
		werrstr("kexdhgex: bogus server challenge");
		goto Error;
	}
	if((sigblob = getblobstring(&m->b)) == nil){
		werrstr("kexdhgex: no host signature on session id: %r");
		goto Error;
	}
	if((sig = getsig(c, sigblob)) == nil){
		werrstr("kexdhgex: bad host signature on session id: %r");
		goto Error;
	}

	mpexp(s->f, s->x, s->p, s->K);		/* K <- f^x (p) */
	H = kexdhgexhash(c, s, certblob);

	/* We verified certificate itself above; now verify signature */
	if(cert->verify(cert, sig, H, s->d->digestlen) != 0){
		werrstr("kexdhgex: invalid server signature on session id");
		goto Error;
	}
	memset(H, 0, s->d->digestlen);
	free(H);
	H = nil;

	/* save certificate and signature for future reference */
	if(c->hostcert != nil)
		c->hostcert->free(c->hostcert);
	c->hostcert = cert;
	/* we already said it was ok in getcert, so we better find it */
	c->hostcertimpl = (CertImpl*)findimpl(c->okhostcert, cert->name);
	assert(c->hostcertimpl != nil);

	if(c->hostsig != nil)
		c->hostsig->free(c->hostsig);
	c->hostsig = sig;
	c->hostsigimpl = (SigImpl*)findimpl(c->okhostsig, sig->name);
	/* we already said it was ok in getsig, so we better find it */
	assert(c->hostsigimpl != nil);

	clearblob(certblob);
	clearblob(sigblob);
	freemsg(m);
	return 0;

 Error:
	/* scrub and drop the hash if it was computed but not accepted */
	if(H != nil){
		memset(H, 0, s->d->digestlen);
		free(H);
	}
	/* cert and sig are still ours here: they were never installed in c */
	if(sig != nil)
		sig->free(sig);
	if(cert != nil)
		cert->free(cert);
	clearblob(certblob);
	clearblob(sigblob);
	clearblob(tmpblob);
	if(m != nil)
		freemsg(m);
	return -1;
}

/*
 * Server side of diffie-hellman-group-exchange.
 *
 * Reads the client's group size request (legacy single-value form
 * when BugOldGpExchange is set) and records the bounds in the state,
 * since kexdhgexhash feeds them into the exchange hash.  Sends the
 * group (p, g), picks a secret exponent y, computes f = g^y mod p,
 * reads e, derives K = e^y mod p, signs the exchange hash with the
 * host certificate and sends certificate, f and signature.
 *
 * XXX: choosing a group that fits the requested bounds is not
 * implemented; nothing in this file sets s->p and s->g on the server
 * side.  Until that exists the routine fails cleanly instead of
 * dereferencing nil.
 *
 * Returns 0 on success, -1 with the error string set on failure.
 * rc is the channel messages are received from, wc the channel they
 * are sent on.
 */
static int
kexdhgexserver(Conn *c, Channel *rc, Channel *wc)
{
	int n;
	uchar *H;
	u32int min, max, def;

	Msg *m;
	Sig *sig;
	KexState *s;
	Blob *tmpblob;

	tmpblob = nil;

	s = c->kexstate;
	s->K = mpnew(0);
	s->f = mpnew(0);
	s->q = mpnew(0);	/* result buffer for (p-1)/2 below */
	s->y = nil;

	if(c->bugflags & BugOldGpExchange){
		m = recvmsg(rc, SSH_MSG_KEX_DH_GEX_REQUEST_OLD);
		def = getlong(&m->b);
		min = max = def;
	}else{
		m = recvmsg(rc, SSH_MSG_KEX_DH_GEX_REQUEST);
		min = getlong(&m->b);
		def = getlong(&m->b);
		max = getlong(&m->b);
	}
	freemsg(m);

	/* the exchange hash covers the bounds exactly as the client sent them */
	s->min = min;
	s->desired = def;
	s->max = max;

	/* XXX: find appropriate group for (min, def, max) and set s->p, s->g */
	if(s->p == nil || s->g == nil){
		werrstr("kexdhgex: no suitable group (group selection not implemented)");
		goto Error;
	}

	m = allocmsg(c, SSH_MSG_KEX_DH_GEX_GROUP, 0);
	putmpint(&m->b, s->p);
	putmpint(&m->b, s->g);
	sendmsg(wc, m);

	/* q <- (p-1)/2 */
	mpsub(s->p, mpone, s->q);
	mpright(s->q, 1, s->q);

	/* y <-R- (1, q): reject 0 and 1, as the client does for x */
	do{
		s->y = mpnrand(s->q, s->y);
	}while(mpcmp(s->y, mpone) <= 0);
	mpexp(s->g, s->y, s->p, s->f);		/* f <- g^y (p) */

	m = recvmsg(rc, SSH_MSG_KEX_DH_GEX_INIT);
	s->e = getmpint(&m->b);
	freemsg(m);
	/* e must lie in [1, p) */
	if(s->e == nil || mpcmp(mpone, s->e) > 0 || mpcmp(s->p, s->e) <= 0){
		werrstr("kexdhgex: bogus client challenge");
		goto Error;
	}

	mpexp(s->e, s->y, s->p, s->K);		/* K <- e^y (p) */

	/* serialise the host certificate; it is hashed and then sent */
	n = sizecert(c->hostcert);
	tmpblob = mkblob(n);
	putcert(tmpblob, c->hostcert);
	H = kexdhgexhash(c, s, tmpblob);
	sig = c->hostcert->sign(c->hostcert, H, s->d->digestlen);
	memset(H, 0, s->d->digestlen);
	free(H);
	if(sig == nil){
		werrstr("kexdhgex: cannot sign exchange hash");
		goto Error;
	}

	/* certificate and signature each travel wrapped as a string */
	m = allocmsg(c, SSH_MSG_KEX_DH_GEX_REPLY, 0);

	putblobstring(&m->b, tmpblob);
	clearblob(tmpblob);
	tmpblob = nil;
	putmpint(&m->b, s->f);

	n = sizesig(sig);
	tmpblob = mkblob(n);
	putsig(tmpblob, sig);
	putblobstring(&m->b, tmpblob);
	clearblob(tmpblob);
	tmpblob = nil;
	sig->free(sig);

	sendmsg(wc, m);

	return 0;

 Error:
	clearblob(tmpblob);
	return -1;
}


/*
 * Compute the exchange hash H for the current state.
 *
 * The hashed buffer is
 *	V_C || V_S || I_C || I_S || K_S || min || n || max ||
 *	p || g || e || f || K
 * where only n (the desired size) is included when BugOldGpExchange
 * is set.  cert is the serialised host certificate K_S.
 *
 * Side effects: c->kexkey is set to a new blob holding K and
 * c->kexhash to a new blob holding a copy of H.
 *
 * Returns H in freshly allocated memory of s->d->digestlen bytes;
 * the callers in this file zero and free it.
 */
static uchar*
kexdhgexhash(Conn *c, KexState *s, Blob *cert)
{
	ulong len;
	uchar *hash;
	Blob *buf;

	/* space for every field; both kex payloads get a 4-byte length */
	len = sizestring(s->cliver) + sizestring(s->srvver);
	len += 4+sizeblob(s->clikex);
	len += 4+sizeblob(s->srvkex);
	len += sizeblobstring(cert);
	len += 4+4+4;	/* min, n, max; the legacy form writes only n */
	len += sizempint(s->p) + sizempint(s->g);
	len += sizempint(s->e) + sizempint(s->f) + sizempint(s->K);
	buf = mkblob(len);

	/* versions */
	putstring(buf, s->cliver);
	putstring(buf, s->srvver);

	/* kex payloads, each preceded by its length */
	putlong(buf, sizeblob(s->clikex));
	putblobraw(buf, s->clikex);
	putlong(buf, sizeblob(s->srvkex));
	putblobraw(buf, s->srvkex);

	/* host certificate */
	putblobstring(buf, cert);

	/* requested group size(s) */
	if(!(c->bugflags & BugOldGpExchange)){
		putlong(buf, s->min);
		putlong(buf, s->desired);
		putlong(buf, s->max);
	}else
		putlong(buf, s->desired);

	/* group, public values, shared secret */
	putmpint(buf, s->p);
	putmpint(buf, s->g);
	putmpint(buf, s->e);
	putmpint(buf, s->f);
	putmpint(buf, s->K);

	/* digest the assembled buffer, then discard it */
	hash = emalloc(s->d->digestlen);
	s->d->digest(buf->bp, sizeblob(buf), hash, nil);
	clearblob(buf);

	debug(DbgCrypto, "kexdhgex: H = %.*H", s->d->digestlen, hash);

	/* leave K and H on the connection */
	len = sizempint(s->K);
	c->kexkey = mkblob(len);
	putmpint(c->kexkey, s->K);
	c->kexhash = copyblob(hash, s->d->digestlen);

	return hash;
}

/*
 * Method descriptor for diffie-hellman-group-exchange-sha1, tying
 * the method name and digest to the routines defined in this file.
 */
Kex kexdhgex = {
	"diffie-hellman-group-exchange-sha1",	/* method name */
	Sign,		/* NOTE(review): flag defined elsewhere; presumably marks a method needing a signing host key -- confirm */
	&digestsha1,	/* same digest kexdhgexinit stores in the state */
	kexdhgexinit,
	kexdhgexfree,
	kexdhgexclient,
	kexdhgexserver
};
