#include "dat.h"
#include "fns.h"

/*
 * SSH insists on a baroque algorithm negotiation and most of
 * the resulting cruft ends up in the handling of key exchanges.
 * There is substantial room for improvement here.
 */

static	Hmac	*okhmac(Conn*, Blob*, Blob*);
static	Cipher	*okencrypt(Conn*, Blob*, Blob*);
static	int	okcompress(Conn*, Blob*, Blob*);
static	int	okkex(Conn*, char*, char*, Namelist*, Namelist*);

static	int	newkeys(Conn*, int);
static	void	genkey(Conn*, Blob*, Blob*, char, int, Blob**);

static	int	dokex(Conn*, Channel*, Channel*, Msg*);
static	int	kexnegotiate(Conn*, Channel*, Channel*);

static	Msg	*ssh_msg_newkeys(Conn*);
static	Msg	*ssh_msg_kexinit(Conn*, int);

/*
 * Start a key exchange from our side.  Builds our SSH_MSG_KEXINIT
 * (with first_kex_packet_follows = 0: we never send a guess), keeps a
 * copy of its payload in c->kexblob[0] for the exchange-hash computation
 * done later, and queues the message on wc.  rc is unused.
 */
void
kexinit(Conn *c, Channel *rc, Channel *wc)
{
	Msg *ours;

	USED(rc);

	assert(c->kexblob[0] == nil);
	assert(c->kexblob[1] == nil);
	assert(c->wirestate >= WSraw);

	c->kexinit = 1;

	ours = ssh_msg_kexinit(c, 0);
	/* capture the payload before the message is handed to sendmsg */
	c->kexblob[0] = msgpayload(ours);
	sendmsg(wc, ours);
}

/*
 * Run one complete key exchange triggered by the peer's SSH_MSG_KEXINIT m
 * (consumed by dokex).  Our own KEXINIT is queued first if kexinit has not
 * already been called.  dokex negotiates, runs the exchange and switches
 * the write side to the new keys.  We then wait for the peer's
 * SSH_MSG_NEWKEYS and switch the read side.  Once the read-side switch has
 * been attempted, successfully or not, a nil is sent on c->keylock[0].
 *
 * Returns 0 on success.  On any failure the connection is disconnected
 * with SSH_DISCONNECT_KEY_EXCHANGE_FAILED and -1 is returned.
 */
int
kexrun(Conn *c, Channel *rc, Channel *wc, Msg *m)
{
	Blob *b;
	int dbg;

	if(c->kexinit == 0)
		kexinit(c, rc, wc);
	if(dokex(c, rc, wc, m) < 0)
		goto Error;

	/* wait for the peer's NEWKEYS before touching the read side */
	m = recvmsg(rc, SSH_MSG_NEWKEYS);
	freemsg(m);

	/* compute new read-side keys and start using them */
	if(newkeys(c, OREAD) != 0){
		/* XXX: too late; other side will be using new keys */
		debug(DbgCrypto|DbgAuth, "key exchange fails");
		sendp(c->keylock[0], nil);
		goto Error;
	}
	/* mark key exchange complete */
	c->kexinit = 0;
	if(c->wirestate < WSkex)
		c->wirestate = WSkex;
	sendp(c->keylock[0], nil);

	b = c->sessionid;
	dbg = DbgCrypto|DbgAuth;

	debug(dbg, "SSH_MSG_NEWKEYS");
	/* the %.* precision argument must be an int, not a pointer difference */
	debug(dbg, "\tsession ID: %.*H", (int)(b->wp - b->bp), b->bp);
	debug(dbg, "\tkey exchange: %s", c->kex->name);
	debug(dbg, "\tclient cipher: %s", c->cipher[0]->name);
	debug(dbg, "\tserver cipher: %s", c->cipher[1]->name);
	debug(dbg, "\tclient hmac: %s", c->hmac[0]->name);
	debug(dbg, "\tserver hmac: %s", c->hmac[1]->name);

	return 0;

 Error:
	debug(DbgCrypto|DbgAuth, "key exchange fails: %r");
	disconnect(c, wc, SSH_DISCONNECT_KEY_EXCHANGE_FAILED);
	return -1;
}

/*
 * Check the peer's host certificate cert against the certificates saved
 * in the databases named by the "hostcertdb" configuration value (parsed
 * as a name-list and searched in order for cert->name).
 *
 * If a saved certificate is found, it must compare equal to cert.
 * If none is found, a warning is reported, cert is written to
 * c->bstderr, and the certificate is accepted only when strict is 0.
 *
 * Returns 0 to accept, -1 to reject (with the error string set).
 */
int
kexcertcheck(Conn *c, Cert *cert)
{
	int n, status;
	char *db, *ident;
	Cert *known;
	Namelist *dbs;

	status = -1;
	ident = nil;
	known = nil;

	db = cfgstr(c, "hostcertdb");
	assert(db != nil);
	dbs = parsenamelist(db);

	for(n = 0; n < dbs->nstr; n++){
		if(certpublookup(dbs->strtab[n], c, cert->name, &ident, &known) == 0)
			break;
	}

	if(n < dbs->nstr){
		/* found a saved certificate in dbs->strtab[n] */
		free(ident);
		if(cert->compare(cert, known) != 0){
			error(c, "WARNING: certificate does not match the saved one for '%s' in %s", c->host, dbs->strtab[n]);
			werrstr("certificate mismatch");
		}else{
			status = 0;
			debug(DbgAuth, "host certificate matches");
		}
		known->free(known);
	}else{
		/* unknown host: show the certificate and defer to the strict setting */
		error(c, "Warning: couldn't find host certificate: %r");
		werrstr("unknown certificate");
		Bwrcertpub(&c->bstderr, cert, c->host);
		Bflush(&c->bstderr);
		if(strict == 0)
			status = 0;
	}

	freenamelist(dbs);
	return status;
}

/*
 * Carry out the exchange once the peer's SSH_MSG_KEXINIT m has arrived.
 * m is consumed; its payload is kept in c->kexblob[1].
 *
 * Negotiates algorithms and runs the kex protocol via kexnegotiate, then
 * releases the per-exchange state whether or not that succeeded.  On
 * success:
 *	- the first exchange hash becomes the session identifier;
 *	- SSH_MSG_NEWKEYS is sent;
 *	- after a value arrives on c->keylock[1], the write side is switched
 *	  to the new keys.
 *
 * Returns 0 on success, -1 on failure.
 */
static int
dokex(Conn *c, Channel *rc, Channel *wc, Msg *m)
{
	int rv;
	Blob *kh;

	assert(m->type == SSH_MSG_KEXINIT);
	assert(c->kexblob[0] != nil);
	assert(c->kexblob[1] == nil);

	c->kexblob[1] = msgpayload(m);
	freemsg(m);

	/* run the key exchange protocol(s) */
	rv = kexnegotiate(c, rc, wc);

	/*
	 * Clean up key exchange state unconditionally.  kexnegotiate can
	 * fail before any algorithm was agreed (c->kex never set) or before
	 * c->kex->init produced any state, so only free state that exists.
	 */
	if(c->kex != nil && c->kexstate != nil)
		c->kex->free(c->kexstate);		/* -> clikex, srvkex */
	c->kexstate = nil;

	clearblob(c->kexblob[0]);
	clearblob(c->kexblob[1]);
	c->kexblob[0] = nil;
	c->kexblob[1] = nil;

	if(rv < 0)
		return -1;

	kh = c->kexhash;

	/* session identifier is the first exchange hash */
	if(c->sessionid == nil)
		c->sessionid = copyblob(kh->bp, kh->wp - kh->bp);

	/*
	 * Key exchange OK:
	 * 	make write side real; wait for NEWKEYS to make read side real
	 */
	m = ssh_msg_newkeys(c);
	sendmsg(wc, m);
	recvp(c->keylock[1]);		/* synchronize NEWKEYS */
	newkeys(c, OWRITE);

	return 0;
}

/*
 * At this point we have both sent and received the SSH_MSG_KEXINIT packet.
 * We now run the negotiation algorithms to determine which cryptographic
 * protocols we have to run next.  Rather than determine our role in this
 * mess, we parse both client and server packets and use the connection to
 * look up algorithm attributes.  The result is rather clumsy.
 *
 * c->kexblob[0] holds our own KEXINIT payload (stored by kexinit) and
 * c->kexblob[1] holds the peer's (stored by dokex).  Both are consumed
 * field by field, so the get* calls below must follow the KEXINIT wire
 * layout exactly:
 *	message type, cookie, kex, host key,
 *	cipher x2, hmac x2, compression x2, language x2,
 *	first_kex_packet_follows.
 *
 * On success c->kex, c->kexcipher[], c->kexhmac[] and c->kexstate are set,
 * the kex protocol has been run, and 0 is returned.  Otherwise -1.
 */
static int
kexnegotiate(Conn *c, Channel *rc, Channel *wc)
{
	char *s, *t;
	int optimistic;
	int kexguessed;
	int i, k, rv, found;

	Msg *m;
	Blob *cli, *srv;
	Blob *clikex, *srvkex;
	Namelist *ckexl, *skexl, *csrvl, *ssrvl;

	kexguessed = 0;

	/*
	 * Squirrel away copies of the SSH_MSG_KEXINIT packets'
	 * payloads for later use by the kex exchange algorithm.
	 *
	 * NOTE(review): cli is always our own packet and srv the peer's,
	 * so the names are literal only when c->role == RClient.
	 */
	cli = c->kexblob[0];
	srv = c->kexblob[1];
	clikex = copyblob(cli->bp, cli->wp - cli->bp);
	srvkex = copyblob(srv->bp, srv->wp - srv->bp);

	getbyte(cli);					/* dump message type */
	getbyte(srv);

	getbytes(cli, 16);				/* cookie */
	getbytes(srv, 16);

	/*
	 * Figure out key exchange and host authentication algorithms.
	 * We can set corresponding entries in Conn immediately.
	 */
	found = 0;
	ckexl = getnamelist(cli);			/* kex list */
	csrvl = getnamelist(cli);			/* host key list */
	skexl = getnamelist(srv);
	ssrvl = getnamelist(srv);

	debug(DbgPacket, "server okkex: %N", skexl);
	debug(DbgPacket, "server okhostcert: %N", ssrvl);

	if(ckexl->nstr < 1 || skexl->nstr < 1){
		werrstr("bad key exchange algorithm list");
		goto Kex;
	}

	/*
	 * If the preferred (first) algorithms are the same, that is the
	 * choice, provided okkex finds a usable host key algorithm.
	 *
	 * A guessed first kex packet is only right if the preferred host key
	 * algorithms agree as well (RFC 4253 sec. 7.1).  If okkex rejects the
	 * pair, fall through and try the remaining candidates instead of
	 * failing.
	 */
	s = ckexl->strtab[0];
	t = skexl->strtab[0];
	if(*s && *t && strcmp(s, t) == 0 && okkex(c, s, t, csrvl, ssrvl)){
		found = 1;
		if(csrvl->nstr > 0 && ssrvl->nstr > 0
		&& strcmp(csrvl->strtab[0], ssrvl->strtab[0]) == 0)
			kexguessed = 1;
		goto Kex;
	}

	/* otherwise take the first entry of cli's list that okkex accepts */
	for(i=0; i<ckexl->nstr; i++){
		s = ckexl->strtab[i];
		for(k=0; k<skexl->nstr; k++){
			if(okkex(c, s, skexl->strtab[k], csrvl, ssrvl)){
				found = 1;
				goto Kex;
			}
		}
	}

	werrstr("can't agree on key exchange algorithm");
 Kex:
	freenamelist(ckexl);	ckexl = nil;
	freenamelist(csrvl);	csrvl = nil;
	freenamelist(skexl);	skexl = nil;
	freenamelist(ssrvl);	ssrvl = nil;
	if(found == 0)
		goto Error;

	/* We now agree on a key exchange algorithm */

	/* encryption: client --> server; then server --> client */
	if((c->kexcipher[1 - c->role] = okencrypt(c, cli, srv)) == nil)
		goto Error;
	if((c->kexcipher[c->role] = okencrypt(c, cli, srv)) == nil)
		goto Error;

	/* HMAC: client --> server; then server --> client */
	if((c->kexhmac[1 - c->role] = okhmac(c, cli, srv)) == nil)
		goto Error;
	if((c->kexhmac[c->role] = okhmac(c, cli, srv)) == nil)
		goto Error;

	/* XXX: compression: client --> server; then server --> client */
	if(!okcompress(c, cli, srv))
		goto Error;
	if(!okcompress(c, cli, srv))
		goto Error;

	/*
	 * Language lists: there are two per packet (client --> server,
	 * then server --> client); we can ignore them (thank god) but must
	 * consume both to reach first_kex_packet_follows.
	 */
	getstring(cli);
	getstring(srv);
	getstring(cli);
	getstring(srv);

	/*
	 * Discard an erroneous optimistic packet, if any.  The flag that
	 * matters is the peer's, which is always in srv (kexblob[1]),
	 * whatever our role.
	 */
	optimistic = getbool(srv);
	if(optimistic && !kexguessed && (c->bugflags & BugFirstKex) == 0){
		m = recvmsg(rc, MSGIGN);
		freemsg(m);
	}

	/* clikex and srvkex are the responsibility of c->kexstate upon success */
	if((c->kexstate = c->kex->init(c, clikex, srvkex)) == nil)
		goto Error;
	clikex = nil;
	srvkex = nil;

	/* now run the key exchange */
	if(c->role == RClient)
		rv = c->kex->client(c, rc, wc);
	else
		rv = c->kex->server(c, rc, wc);
	if(rv < 0)
		goto Error;
	return 0;

 Error:
	/* the name-lists were already released at Kex */
	freeblob(clikex);
	freeblob(srvkex);
	return -1;
}

/*
 * Install the keys from the most recent key exchange for one direction.
 * mode selects the slot: OREAD uses slot 0, OWRITE uses slot 1.
 *
 * All six derivations ('A' through 'F') are computed every time, but only
 * the selected slot's IV, cipher key and MAC key are handed to the new
 * cipher/HMAC state.  All six blobs are then passed to clearblob.  Any
 * previous cipher/HMAC state in the slot is released first.
 *
 * Always returns 0.
 */
static int
newkeys(Conn *c, int mode)
{
	int cs, slot, j;
	Blob *kh, *kk;
	Blob *ivs[2], *ckeys[2], *mkeys[2];

	assert(mode == OREAD || mode == OWRITE);
	assert(c->role == RClient || c->role == RServer);

	debug(DbgCrypto, "newkeys %s", mode==OREAD?"READ":"WRITE");

	kk = c->kexkey;
	kh = c->kexhash;

	/*
	 * Letters 'A', 'C', 'E' are the client-to-server derivations of
	 * RFC 4253 sec. 7.2; they go to slot cs (1 for a client, 0 for a
	 * server).  'B', 'D', 'F' go to the other slot.
	 * XXX: botch: only compute necessary keys
	 */
	cs = (c->role == RClient) ? 1 : 0;
	genkey(c, kh, kk, 'A', c->kexcipher[cs]->ivlen, &ivs[cs]);
	genkey(c, kh, kk, 'B', c->kexcipher[1-cs]->ivlen, &ivs[1-cs]);
	genkey(c, kh, kk, 'C', c->kexcipher[cs]->keylen, &ckeys[cs]);
	genkey(c, kh, kk, 'D', c->kexcipher[1-cs]->keylen, &ckeys[1-cs]);
	genkey(c, kh, kk, 'E', c->kexhmac[cs]->keylen, &mkeys[cs]);
	genkey(c, kh, kk, 'F', c->kexhmac[1-cs]->keylen, &mkeys[1-cs]);

	slot = (mode == OREAD) ? 0 : 1;

	/* replace the cipher */
	if(c->cipherstate[slot] != nil)
		c->cipher[slot]->free(c->cipherstate[slot]);
	c->cipher[slot] = c->kexcipher[slot];
	c->cipherstate[slot] = c->cipher[slot]->init(c, ckeys[slot], ivs[slot]);

	/* replace the MAC */
	if(c->hmacstate[slot] != nil)
		c->hmac[slot]->free(c->hmacstate[slot]);
	c->hmac[slot] = c->kexhmac[slot];
	c->hmacstate[slot] = c->hmac[slot]->init(c, mkeys[slot]);

	/* the MAC output buffer must match the new MAC's length */
	free(c->digest[slot]);
	c->digest[slot] = emalloc(c->hmac[slot]->hmaclen);

	for(j = 0; j < 2; j++){
		clearblob(ivs[j]);
		clearblob(ckeys[j]);
		clearblob(mkeys[j]);
	}

	return 0;
}

/*
 * Derive key material for letter ch (RFC 4253 sec. 7.2) using the
 * negotiated kex digest:
 *	K1 = HASH(K || H || ch || session_id)
 *	Kn = HASH(K || H || K1 || ... || Kn-1)
 * where k is the shared secret blob, h the exchange hash, and session_id
 * is c->sessionid.
 *
 * Enough digest blocks are concatenated to cover n bytes.  At least one
 * block is always produced, even for n <= 0, so the result may be longer
 * than n.  The new blob is returned in *out for the caller to release.
 */
static void
genkey(Conn *c, Blob *h, Blob *k, char ch, int n, Blob **out)
{
	int m;
	uchar *buf, letter;
	Blob *b, *sid;
	Digest *digest;
	DigestState *ds;

	digest = c->kex->digest;
	buf = emalloc(digest->digestlen);

	/* m = number of digest blocks needed to cover n bytes */
	for(m=0; n > 0; m++)
		n -= digest->digestlen;
	b = mkblob((m+1)*digest->digestlen);		/* alloc header+body */

	/* K1; hash the letter as a uchar, matching the other inputs */
	sid = c->sessionid;
	letter = ch;
	ds = digest->digest(k->bp, k->wp - k->bp, nil, nil);
	digest->digest(h->bp, h->wp - h->bp, nil, ds);
	digest->digest(&letter, 1, nil, ds);
	digest->digest(sid->bp, sid->wp - sid->bp, buf, ds);
	putbytes(b, buf, digest->digestlen);

	/* K2..Km: each hashes the blob's contents so far (b->bp..b->wp) */
	while(--m > 0){
		ds = digest->digest(k->bp, k->wp - k->bp, nil, nil);
		digest->digest(h->bp, h->wp - h->bp, nil, ds);
		digest->digest(b->bp, b->wp - b->bp, buf, ds);
		putbytes(b, buf, digest->digestlen);
	}

	/* buf held key material; scrub it before returning it to the heap */
	memset(buf, 0, digest->digestlen);
	free(buf);

	*out = b;
}

/*
 * Given the client's and server's suggested key exchange algorithms,
 * determine whether the pair is an acceptable match under the RFC's
 * negotiation algorithm.  Two conditions must hold:
 *	- ckex and skex must be the same name;
 *	- some host key algorithm named in both ssrvl and csrvl must have
 *	  every Sign/Encrypt capability the kex algorithm's flags ask for.
 *
 * On success the kex implementation is recorded in c->kex and 1 is
 * returned.  Otherwise 0 is returned; the error string is set when no
 * suitable host key algorithm exists.
 *
 * XXX: draft standard does not (currently) specify order for choosing
 * the host key algorithm.  We just make sure there is a valid choice and
 * deal with it later in the individual key exchange modules.
 */
static int
okkex(Conn *c, char *ckex, char *skex, Namelist *csrvl, Namelist *ssrvl)
{
	int si, ci, hostok;
	ulong need;
	char *alg;
	Kex *chosen;
	CertImpl *hk;

	if(strcmp(ckex, skex) != 0)
		return 0;

	/* make sure we know about the agreed upon algorithm and retrieve it */
	chosen = (Kex*)findimpl(c->okkex, ckex);
	assert(chosen != nil);

	/* host key capabilities required: the Sign/Encrypt bits of the kex flags */
	need = chosen->flags & (Sign|Encrypt);

	hostok = 0;
	for(si = 0; si < ssrvl->nstr && !hostok; si++){
		alg = ssrvl->strtab[si];
		for(ci = 0; ci < csrvl->nstr; ci++){
			if(strcmp(alg, csrvl->strtab[ci]) != 0)
				continue;
			hk = (CertImpl*)findimpl(c->okhostcert, alg);
			assert(hk != nil);	/* we sent it */
			if((hk->flags & need) == need){
				hostok = 1;
				break;
			}
		}
	}

	if(!hostok){
		werrstr("negotiated key exchange, but not host authentication");
		return 0;
	}

	assert(c->kex == nil);
	assert(c->kexstate == nil);
	c->kex = chosen;
	return 1;
}

/*
 * Agree on an encryption algorithm for one direction.  Consumes one
 * name-list from each of cb and sb, and picks the first entry of cb's
 * list that also appears in sb's.
 *
 * Returns the matching implementation from c->okcipher, or nil with the
 * error string set.  Both name-lists are freed on every path; previously
 * they leaked whenever a match was found.
 */
static Cipher*
okencrypt(Conn *c, Blob *cb, Blob *sb)
{
	int i, k;
	Cipher *p;
	Namelist *cli, *srv;

	cli = getnamelist(cb);
	srv = getnamelist(sb);
	debug(DbgPacket, "server okcipher: %N", srv);
	debug(DbgPacket, "client okcipher: %N", cli);

	p = nil;
	for(i=0; i<cli->nstr; i++)
		for(k=0; k<srv->nstr; k++)
			if(strcmp(cli->strtab[i], srv->strtab[k]) == 0){
				/* look up before freeing: the name lives in cli */
				p = (Cipher*)findimpl(c->okcipher, cli->strtab[i]);
				if(p == nil)
					werrstr("unknown encryption algorithm %s", cli->strtab[i]);
				goto Done;
			}
	werrstr("can't agree on encryption algorithm");

 Done:
	freenamelist(cli);
	freenamelist(srv);
	return p;
}

/*
 * Agree on a MAC algorithm for one direction.  Consumes one name-list
 * from each of cb and sb, and picks the first entry of cb's list that
 * also appears in sb's.
 *
 * Returns the matching implementation from c->okhmac, or nil with the
 * error string set.  Both name-lists are freed on every path; previously
 * they leaked whenever a match was found.
 */
static Hmac*
okhmac(Conn *c, Blob *cb, Blob *sb)
{
	int i, k;
	Hmac *p;
	Namelist *cli, *srv;

	cli = getnamelist(cb);
	srv = getnamelist(sb);
	debug(DbgPacket, "server okhmac: %N", srv);
	debug(DbgPacket, "client okhmac: %N", cli);

	p = nil;
	for(i=0; i<cli->nstr; i++)
		for(k=0; k<srv->nstr; k++)
			if(strcmp(cli->strtab[i], srv->strtab[k]) == 0){
				/* look up before freeing: the name lives in cli */
				p = (Hmac*)findimpl(c->okhmac, cli->strtab[i]);
				if(p == nil)
					werrstr("unknown HMAC algorithm %s", cli->strtab[i]);
				goto Done;
			}
	werrstr("can't agree on HMAC algorithm");

 Done:
	freenamelist(cli);
	freenamelist(srv);
	return p;
}

/*
 * Agree on a compression algorithm for one direction.  Consumes one
 * name-list from each of cb and sb and checks that they share an entry.
 * Only "none" is implemented, so the first common entry (in cb's order)
 * is asserted to be "none".
 *
 * Returns 1 on agreement, 0 (with the error string set) otherwise.
 * c is unused.
 */
static int
okcompress(Conn *c, Blob *cb, Blob *sb)
{
	int a, b, match;
	Namelist *cl, *sl;

	cl = getnamelist(cb);
	sl = getnamelist(sb);
	debug(DbgPacket, "server okcompress: %N", sl);
	debug(DbgPacket, "client okcompress: %N", cl);

	/* index into cl of the first common entry, or -1 */
	match = -1;
	for(a = 0; a < cl->nstr && match < 0; a++)
		for(b = 0; b < sl->nstr; b++)
			if(strcmp(cl->strtab[a], sl->strtab[b]) == 0){
				match = a;
				break;
			}

	if(match < 0)
		werrstr("can't agree on compression algorithm");
	else
		assert(strcmp(cl->strtab[match], "none") == 0);	/* XXX: compression */

	freenamelist(cl);
	freenamelist(sl);
	return match >= 0;
}

/*
 * Allocate an SSH_MSG_NEWKEYS message for c.
 * No payload fields are added here.
 */
static Msg*
ssh_msg_newkeys(Conn *c)
{
	return allocmsg(c, SSH_MSG_NEWKEYS, 0);
}

/*
 * Append the names of the implementations in il to b as one string,
 * formatted with %N.  Presumably this is the comma-separated SSH
 * name-list form; confirm against the %N formatter.
 * The string is also logged under the label nm.
 * c and key are currently unused.
 */
static void
setalgorithm(Blob *b, Conn *c, Impllist *il, char *key, char *nm)
{
	Namelist *names;
	char *wire;

	names = impl2name(il);
	wire = esmprint("%N", names);
	freenamelist(names);		/* the formatted copy is all we need */

	debug(DbgPacket, "%s: %s", nm, wire);
	putstring(b, wire);
	free(wire);
}

/*
 * Build our SSH_MSG_KEXINIT.  Fields, in wire order:
 *	- a 16-byte random cookie;
 *	- name-lists: kex, host key, cipher (c->s, s->c), hmac (c->s, s->c),
 *	  compression (c->s, s->c, both "none"), language (c->s, s->c, both
 *	  empty);
 *	- the first_kex_packet_follows flag g;
 *	- a reserved uint32 of 0.
 * kexnegotiate parses these fields back in the same order.
 */
static Msg*
ssh_msg_kexinit(Conn *c, int g)
{
	int i;
	char *s;
	uchar cookie[16];
	Msg *m;

	m = allocmsg(c, SSH_MSG_KEXINIT, 0);
	for(i=0; i<nelem(cookie); i++)
		cookie[i] = fastrand();
	putbytes(&m->b, cookie, sizeof(cookie));

	/* the %.* precision argument must be an int; sizeof is not */
	debug(DbgPacket, "sending kexinit: cookie = %.*H",
	      (int)sizeof(cookie), cookie);

	setalgorithm(&m->b, c, c->okkex, "kex", "okkex");
	setalgorithm(&m->b, c, c->okhostcert, "hostcert", "okhostcert");

	/*
	 * Signature algorithm list doesn't go on the wire
	 * since the relevant part can, in theory, be determined
	 * from the okhostcert list, which is marshalled
	 */
	s = consimpllist(c->okhostsig);
	debug(DbgPacket, "okhostsig: %s", s);
	free(s);

	setalgorithm(&m->b, c, c->okcipher, "cipher", "okcipher [c->s]");
	setalgorithm(&m->b, c, c->okcipher, "cipher", "okcipher [s->c]");
	setalgorithm(&m->b, c, c->okhmac, "hmac", "okhmac [c->s]");
	setalgorithm(&m->b, c, c->okhmac, "hmac", "okhmac [s->c]");

	/* XXX: compression */
	putstring(&m->b, "none");		/* client -> server */
	putstring(&m->b, "none");		/* server -> client */

	/* language lists: we offer none */
	putstring(&m->b, "");			/* client -> server */
	putstring(&m->b, "");			/* server -> client */

	/* first key exchange packet follows? */
	putbool(&m->b, g);

	/* reserved uint32 */
	putlong(&m->b, 0);

	return m;
}
