#include "dat.h"
#include "fns.h"

/*
 * Printable names for the generic transport, user-auth and connection
 * message numbers, indexed by message type.  Slots not listed are nil.
 * Key-exchange (30-49) and auth-method (60-79) numbers are named by
 * per-connection hooks instead (see msgname).
 */
static char *msgtab[] = {
	/* transport layer */
	[1]	"SSH_MSG_DISCONNECT",
	[2]	"SSH_MSG_IGNORE",
	[3]	"SSH_MSG_UNIMPLEMENTED",
	[4]	"SSH_MSG_DEBUG",
	[5]	"SSH_MSG_SERVICE_REQUEST",
	[6]	"SSH_MSG_SERVICE_ACCEPT",

	/* algorithm negotiation */
	[20]	"SSH_MSG_KEXINIT",
	[21]	"SSH_MSG_NEWKEYS",

	/* user authentication */
	[50]	"SSH_MSG_USERAUTH_REQUEST",
	[51]	"SSH_MSG_USERAUTH_FAILURE",
	[52]	"SSH_MSG_USERAUTH_SUCCESS",
	[53]	"SSH_MSG_USERAUTH_BANNER",

	/* connection protocol: global requests */
	[80]	"SSH_MSG_GLOBAL_REQUEST",
	[81]	"SSH_MSG_REQUEST_SUCCESS",
	[82]	"SSH_MSG_REQUEST_FAILURE",

	/* connection protocol: channels */
	[90]	"SSH_MSG_CHANNEL_OPEN",
	[91]	"SSH_MSG_CHANNEL_OPEN_CONFIRMATION",
	[92]	"SSH_MSG_CHANNEL_OPEN_FAILURE",
	[93]	"SSH_MSG_CHANNEL_WINDOW_ADJUST",
	[94]	"SSH_MSG_CHANNEL_DATA",
	[95]	"SSH_MSG_CHANNEL_EXTENDED_DATA",
	[96]	"SSH_MSG_CHANNEL_EOF",
	[97]	"SSH_MSG_CHANNEL_CLOSE",
	[98]	"SSH_MSG_CHANNEL_REQUEST",
	[99]	"SSH_MSG_CHANNEL_SUCCESS",
	[100]	"SSH_MSG_CHANNEL_FAILURE",
};

/*
 * Text for SSH_MSG_DISCONNECT reason codes, indexed by code
 * (RFC 4253 section 11.1).  Slot 0 is not a defined code.
 *
 * Every slot must be non-nil: disconnectmsg bounds-checks with nelem(),
 * so a nil sentinel here would be handed back to callers that pass the
 * result straight to putstring.
 */
static char *disconnecttab[] = {
	"",
	"host not allowed to connect",			/* 1 */
	"protocol error",				/* 2 */
	"key exchange failed",				/* 3 */
	"<reserved>",					/* 4 */
	"mac error",					/* 5 */
	"compression error",				/* 6 */
	"service not available",			/* 7 */
	"protocol version not supported",		/* 8 */
	"host key not verifiable",			/* 9 */
	"connection lost",				/* 10 */
	"application",					/* 11 */
	"too many connections",				/* 12 */
	"authentication cancelled by user",		/* 13 */
	"no more authentication methods available",	/* 14 */
	"illegal username",				/* 15 */
};

/*
 * msgname returns a printable name for m's message type.
 *
 * Types 30-49 and 60-79 are delegated to the connection's
 * msgnamekex/msgnameauth hooks when those are set.  Otherwise msgtab
 * is consulted.  Numbers with no table entry (including the holes
 * between the listed ranges) are rendered as "#n" in a static buffer,
 * so that result is only valid until the next call.
 */
char*
msgname(Msg *m)
{
	Conn *c;
	char *s;
	static char buf[128];

	if(m == nil)
		return "nil message";
	assert(m->c != nil);

	c = m->c;
	if(30 <= m->type && m->type <= 49 && c->msgnamekex != nil)
		return c->msgnamekex(m->type);
	if(60 <= m->type && m->type <= 79 && c->msgnameauth != nil)
		return c->msgnameauth(m->type);
	/* msgtab is sparse; unnamed slots are nil and fall through to "#n" */
	if(0 <= m->type && m->type < nelem(msgtab) && (s = msgtab[m->type]) != nil)
		return s;
	snprint(buf, sizeof buf, "#%d", m->type);
	return buf;
}

char*
disconnectmsg(int n)
{
	if(n >= 0 && n < nelem(disconnecttab))
		return disconnecttab[n];

	return "unknown reason";
}

/*
 * disconnect marks c as hung up, queues an SSH_MSG_DISCONNECT carrying
 * reason (and its text) on wc, and then waits on c->dislock.  Nothing
 * is sent if c is already hung up.
 */
void
disconnect(Conn *c, Channel *wc, int reason)
{
	Msg *m;
	char *why;

	if(c->wirestate != WShup){
		why = disconnectmsg(reason);
		debug(DbgProto, "disconnect: 0x%lux %s",
		      getcallerpc(&c), why);

		/* XXX: message queued; should preempt */
		c->wirestate = WShup;
		m = allocmsg(c, SSH_MSG_DISCONNECT, 0);
		putlong(&m->b, reason);		/* reason code */
		putstring(&m->b, why);		/* description */
		putstring(&m->b, "");		/* language tag: empty */
		sendmsg(wc, m);
		recvp(c->dislock);
	}
}

/*
 * msgqueue appends m to the end of q.
 *
 * q->tail always points at the link slot where the next message goes:
 * &q->head for an empty queue, otherwise &last->link.  The previous
 * version left tail at &q->head after the first insert, so a second
 * insert overwrote q->head and dropped the first message.
 */
void
msgqueue(MsgQueue *q, Msg *m)
{
	m->link = nil;			/* m becomes the new last element */
	if(q->head == nil)
		q->tail = &q->head;
	*q->tail = m;
	q->tail = &m->link;
}

/*
 * msgdequeue removes and returns the first message of q, or nil when
 * q is empty.  q->tail is not touched; msgqueue reseeds it whenever it
 * finds the queue empty.
 */
Msg*
msgdequeue(MsgQueue *q)
{
	Msg *first;

	first = q->head;
	if(first == nil)
		return nil;
	q->head = first->link;
	return first;
}

/*
 * Message buffer size limits, in bytes.  allocmsg never hands out a
 * buffer smaller than MINPKTSZ and refuses requests over MAXPKTSZ.
 * XXX: both values are arbitrary.
 */
enum {
	MINPKTSZ	= 16384,	/* 16K */
	MAXPKTSZ	= 262144	/* 256K */
};

/*
 * allocmsg creates a message of the given type for connection c with a
 * buffer of len bytes, raised to MINPKTSZ if smaller.  Panics if len
 * exceeds MAXPKTSZ.
 *
 * The first six buffer bytes are reserved for packet length (4), pad
 * length (1) and type (1).  The type byte is stored now so the payload
 * is already complete for kex, and b.wp starts right after it.
 */
Msg*
allocmsg(Conn *c, int type, int len)
{
	Msg *m;
	uchar *buf;

	if(len > MAXPKTSZ)
		panic("allocmsg: %d too big", len);
	if(len < MINPKTSZ)
		len = MINPKTSZ;

	m = emalloc(sizeof(Msg));
	buf = emalloc(len);

	m->c = c;
	m->type = type;

	m->b.bp = buf;
	buf[5] = type;				/* so payload is complete for kex */
	m->b.wp = &buf[6];			/* after len, pad len, and type */
	m->b.ep = m->b.ebp = buf + len;

	return m;
}

/*
 * reallocmsg makes m's buffer hold len bytes and sets b.ep to bp+len.
 *
 * If the current allocation (bp..ebp) is already large enough, only
 * b.ep moves.  Otherwise the buffer is reallocated and bp, rp, wp, ep
 * and ebp are rebased onto the new storage.  Panics on a negative
 * length or one over MAXPKTSZ.
 */
void
reallocmsg(Msg *m, int len)
{
	uchar *p;

	if(len < 0)
		panic("reallocmsg: negative length %d", len);
	if(len > MAXPKTSZ)
		panic("reallocmsg: %d too big", len);

	if(m->b.ebp - m->b.bp >= len){
		m->b.ep = m->b.bp+len;
		return;
	}

	p = erealloc(m->b.bp, len);
	/* rebase using offsets from the old bp before replacing it */
	m->b.wp = p + (m->b.wp - m->b.bp);
	m->b.rp = p + (m->b.rp - m->b.bp);
	m->b.ep = p + len;
	m->b.ebp = m->b.ep;
	m->b.bp = p;
}

/*
 * freemsg releases m's packet buffer and m itself.  A nil m is ignored.
 */
void
freemsg(Msg *m)
{
	if(m != nil){
		free(m->b.bp);
		free(m);
	}
}

/*
 * badmsg frees m and panics with a message naming it.  A nil m is
 * reported as "premature eof".  When msg is non-nil, the panic also
 * says what was expected instead.
 */
void
badmsg(Msg *m, char *msg)
{
	char *what;

	what = (m == nil) ? "premature eof" : msgname(m);
	freemsg(m);
	if(msg != nil)
		panic("unexpected message: %s; expected %s", what, msg);
	panic("unexpected message: %s", what);
}

/*
 * msgpayload returns a Blob (made with copyblob) holding m's payload.
 * The payload runs from the type byte, just past the 4-byte length and
 * 1-byte pad length, up to b.wp.
 */
Blob*
msgpayload(Msg *m)
{
	uchar *start;

	start = &m->b.bp[4+1];
	return copyblob(start, m->b.wp - start);
}

/*
 * sendmsg passes m along on channel wc; it is a thin wrapper over sendp.
 */
void
sendmsg(Channel *wc, Msg *m)
{
	sendp(wc, m);
	return;
}

/*
 * recvmsg takes the next message from channel rc and panics if the
 * channel yields nil.  Unless type is MSGIGN, the message must be of
 * that type; otherwise badmsg is called.
 */
Msg*
recvmsg(Channel *rc, int type)
{
	Msg *m;

	if((m = recvp(rc)) == nil)
		panic("premature EOF: %r");
	if(type == MSGIGN || m->type == type)
		return m;
	badmsg(m, nil);
	return m;
}

/*
 * readmsg0 reads one binary packet from c's input side, decrypting it
 * and verifying its MAC when those are enabled.
 *
 * On return, m->type holds the message type, b.rp points just past
 * the type byte, and b.wp and b.ep both point at the end of the payload.
 * Returns nil, with the error string set, on a short read.  Malformed
 * packets and MAC mismatches panic.
 *
 * Wire layout: uint32 packet length (counting neither itself nor the
 * MAC), byte padding length, payload (type byte first), padding, MAC.
 */
static Msg*
readmsg0(Conn *c)
{
	int pad, n;
	uchar seqno[4];
	ulong len, mlen, rempktlen, remtotlen;

	Msg *m;

	m = allocmsg(c, MSGIGN, 0);

	/* read and decrypt blksz or 8 bytes, whichever is larger */
	if(c->cipherstate[0] != nil && c->cipher[0]->blksz > 8)
		n = c->cipher[0]->blksz;
	else
		n = 8;

	if(ioreadn(c->io[0], c->fd[0], m->b.bp, n) != n){
		werrstr("short net read");
		freemsg(m);
		return nil;
	}

	if(c->cipherstate[0] != nil)
		c->cipher[0]->decrypt(c->cipherstate[0], m->b.bp, n);

	/*
	 * The length field doesn't include itself and is peer-controlled,
	 * so bound it before doing arithmetic with it:
	 *  - >= 12, because a packet is at least 16 bytes;
	 *  - at least n-4, so the n bytes already read are inside it;
	 *  - small enough that packet plus MAC fits in MAXPKTSZ.
	 * Without the upper bound, rempktlen/remtotlen can wrap and the
	 * decrypt below would run far past the buffer.
	 */
	len = LONG(m->b.bp);
	if(len < 12 || len+4 < n || len > MAXPKTSZ)
		panic(Edecode);
	rempktlen = len - (n - 4);	/* remaining packet length */
	remtotlen = rempktlen;
	if(c->hmacstate[0] != nil)	/* include HMAC, if any */
		remtotlen += c->hmac[0]->hmaclen;
	if(remtotlen > MAXPKTSZ - n)
		panic(Edecode);
	reallocmsg(m, remtotlen+n);	/* keep what we've read */
	if(ioreadn(c->io[0], c->fd[0], m->b.bp+n, remtotlen) != remtotlen){
		werrstr("short net read");
		freemsg(m);
		return nil;
	}
	/* update write pointer to end of payload+padding (ignoring HMAC) */
	m->b.wp = m->b.bp+n+rempktlen;

	if(c->cipherstate[0] != nil)	/* decrypt remainder, but not HMAC */
		c->cipher[0]->decrypt(c->cipherstate[0], m->b.bp+n, rempktlen);
	m->b.rp = m->b.bp + 4;		/* skip length field */
	/* payload (len-pad-1) must hold at least the type byte */
	if((pad = m->b.rp[0]) >= len-1)
		panic(Edecode);
	m->b.rp++;			/* skip pad length */
	m->b.wp -= pad;			/* ignore padding */
	m->b.ep = m->b.rp+len-pad-1;	/* end of payload */
	if(m->b.ep > m->b.ebp || m->b.ep < m->b.bp)
		panic(Edecode);
	if(m->b.wp > m->b.ep || m->b.wp < m->b.bp)
		panic(Edecode);
	if(c->hmacstate[0] != nil){	/* HMAC(key, seqno || entire packet) */
		/*
		 * Unencrypted HMAC starts at m->b.ep+pad
		 * m->b.ep points to the end of the payload per above
		 */
		if(m->b.ep+pad+c->hmac[0]->hmaclen > m->b.ebp)
			panic(Edecode);
		mlen = len+4;		/* HMAC includes initial length */
		PLONG(seqno, c->seqno[0]);
		c->hmac[0]->hmac(c->hmacstate[0], seqno, sizeof(seqno), nil);
		c->hmac[0]->hmac(c->hmacstate[0], m->b.bp, mlen, c->digest[0]);
		if(memcmp(c->digest[0], m->b.ep+pad, c->hmac[0]->hmaclen) != 0)
			panic(Ebadmac);
	}
	c->seqno[0]++;
	m->type = *m->b.rp++;

	return m;
}

/*
 * readmsg returns the next packet from c, skipping SSH_MSG_IGNORE and
 * SSH_MSG_DEBUG.  DEBUG text is filtered and logged.
 *
 * SSH_MSG_DISCONNECT is reported via error(), marks c hung up, and ends
 * in sysfatal.  If type is not MSGIGN, the result must be of that type
 * (or badmsg is called).  With MSGIGN, a failed read yields nil.  If c
 * is already hung up, returns nil with Ehungup set.
 */
Msg*
readmsg(Conn *c, int type)
{
	Msg *m;
	ulong code;
	char *text, *reason;

	if(c->wirestate == WShup){
		werrstr(Ehungup);
		return nil;
	}

	for(m = readmsg0(c); m != nil; m = readmsg0(c)){
		debug(DbgProto, "received %s len %d",
		      msgname(m), m->b.ep - m->b.rp + 1); /* +1: type */

		if(m->type == SSH_MSG_IGNORE){
			freemsg(m);
			continue;
		}
		if(m->type == SSH_MSG_DEBUG){
			getbool(&m->b);		/* display flag: unused */
			text = getstring(&m->b);
			filterstring(text);
			debug(DbgProto, "remote DEBUG: %s", text);
			freemsg(m);
			continue;
		}
		if(m->type == SSH_MSG_DISCONNECT){
			code = getlong(&m->b);
			text = getstring(&m->b);
			filterstring(text);
			reason = disconnectmsg(code);
			error(c, "ssh disconnect: %s: %s", reason, text);
			c->wirestate = WShup;
			freemsg(m);
			m = nil;
			sysfatal(Ehungup);
			continue;
		}
		break;		/* anything else goes back to the caller */
	}

	/* XXX: what to do with SSH_MSG_UNIMPLEMENTED */
	if(type != MSGIGN && (m == nil || m->type != type))
		badmsg(m, nil);
	return m;
}

/*
 * writemsg frames m as an SSH binary packet and writes it to c's
 * output side.  It fills in the length, pad length, type and padding,
 * MACs the plaintext packet when a MAC is active, encrypts when a
 * cipher is active, then writes the packet followed by the MAC.
 *
 * The payload must already lie between bp+6 and b.wp (the type byte
 * sits at bp[5]).  m is freed on every path.  Returns 0 on success, or
 * -1 with the error string set on a short write.
 */
int
writemsg(Msg *m)
{
	ulong len;
	int i, n, pad;
	uchar *p, seqno[4];

	Conn *c = m->c;

	if(c->cipherstate[1] != nil && c->cipher[1]->blksz > 8)
		n = c->cipher[1]->blksz;
	else
		n = 8;				/* XXX: multiple of blksz? */
	len = m->b.wp - m->b.bp;		/* b.bp: pkt, not payload */
	pad = n - len % n;			/* pad to max(8, blksz) */
	if(pad < 4)				/* with min. padding of 4 */
		pad += n;
	assert(pad < 256);			/* max. padding is 255 */
	len += pad;

	/* strip length and padding length fields from reported length */
	debug(DbgProto, "sending %s len %d", msgname(m), len-pad-4-1);

	/*
	 * Grow the buffer first: reallocmsg may move it, so no pointer
	 * into the buffer may be taken before this call.
	 */
	reallocmsg(m, len);			/* make room for padding */
	p = m->b.bp;
	PLONG(p, len-4);
	p[4] = pad;
	p[5] = m->type;

	/* padding occupies the last pad bytes of the packet, right after the payload */
	p += len - pad;
	if(c->cipherstate[1] != nil){
		for(i=0; i<pad; i++)
			p[i] = fastrand();
	}else
		memset(p, 0, pad);

	if(c->hmacstate[1] != nil){
		PLONG(seqno, c->seqno[1]);
		c->hmac[1]->hmac(c->hmacstate[1], seqno, sizeof(seqno), nil);
		c->hmac[1]->hmac(c->hmacstate[1], m->b.bp, len, c->digest[1]);
	}

	if(c->cipherstate[1] != nil)
		c->cipher[1]->encrypt(c->cipherstate[1], m->b.bp, len);

	if(iowriten(c->io[1], c->fd[1], m->b.bp, len) != len){
		werrstr("short net write");
		freemsg(m);
		return -1;
	}

	if(c->hmacstate[1] != nil){
		n = c->hmac[1]->hmaclen;
		if(iowriten(c->io[1], c->fd[1], c->digest[1], n) != n){
			werrstr("short net write");
			freemsg(m);
			return -1;
		}
	}

	c->seqno[1]++;
	freemsg(m);

	return 0;
}

/* 
 * Binary packet parsing and writing; could use some consolidation
 * Why the SSH protocol "designers" felt the need for so many types
 * is a legitimate question.
 */

/*
 * getbyte consumes and returns the next byte of b; panics with Edecode
 * if b is exhausted.
 */
uchar
getbyte(Blob *b)
{
	uchar v;

	if(!(b->rp < b->ep))
		panic(Edecode);
	v = b->rp[0];
	b->rp += 1;
	return v;
}

/*
 * getshort consumes two bytes from b and returns them decoded with
 * SHORT; panics with Edecode if fewer than two bytes remain.
 */
ushort
getshort(Blob *b)
{
	ushort v;

	if(b->ep - b->rp < 2)
		panic(Edecode);
	v = SHORT(b->rp);
	b->rp = &b->rp[2];
	return v;
}

/*
 * getbool consumes one byte from b and returns 1 if it is non-zero,
 * 0 otherwise.
 */
int
getbool(Blob *b)
{
	return getbyte(b) ? 1 : 0;
}

/*
 * getlong consumes four bytes from b and returns them decoded with
 * LONG; panics with Edecode if fewer than four bytes remain.
 */
ulong
getlong(Blob *b)
{
	ulong v;

	if(b->ep - b->rp < 4)
		panic(Edecode);
	v = LONG(b->rp);
	b->rp = &b->rp[4];
	return v;
}

/*
 * getbytes consumes n bytes from b and returns a pointer to them inside
 * b's own buffer (nothing is copied).  Panics with Edecode if n is
 * negative or more than what remains.
 *
 * n is often peer-controlled: getmpint passes a uint32 read from the
 * packet, which becomes negative as an int when large.  The old
 * "rp+n > ep" test let such values through and moved rp backwards.
 */
void*
getbytes(Blob *b, int n)
{
	uchar *p;

	if(n < 0 || b->rp > b->ep || n > b->ep - b->rp)
		panic(Edecode);
	p = b->rp;
	b->rp += n;

	return p;
}

/*
 * getstring consumes an SSH string (uint32 length, then bytes) from b
 * and returns it NUL-terminated, in place in b's buffer.
 *
 * To make room for the NUL, the bytes are shifted down one position
 * over the last byte of the already-consumed length field.  Panics with
 * Edecode if the length exceeds what remains in b.
 */
char*
getstring(Blob *b)
{
	char *p;
	ulong len;

	/* overwrites length to make room for terminating NUL */
	len = getlong(b);
	/*
	 * Compare against the bytes remaining: with a peer-supplied len,
	 * rp+len can wrap around and slip past an "rp+len > ep" test.
	 */
	if(b->rp > b->ep || len > (ulong)(b->ep - b->rp))
		panic(Edecode);
	p = (char*)b->rp-1;
	memmove(p, b->rp, len);
	p[len] = '\0';
	b->rp += len;

	return p;
}

/*
 * getbstring consumes an SSH string (uint32 length, then bytes) from b.
 * It stores the length in *len and returns a pointer to the bytes inside
 * b's buffer (not copied, not NUL-terminated).  Panics with Edecode if
 * the length exceeds what remains in b.
 */
uchar*
getbstring(Blob *b, int *len)
{
	uchar *p;
	ulong n;

	n = getlong(b);
	/*
	 * Check the raw uint32 against the bytes remaining before it is
	 * narrowed to int; a large value would otherwise go negative and
	 * pass an "rp+len > ep" test.
	 */
	if(b->rp > b->ep || n > (ulong)(b->ep - b->rp))
		panic(Edecode);
	*len = n;
	p = b->rp;
	b->rp += n;

	return p;
}

/*
 * getblobraw consumes exactly len bytes from b (no length prefix) and
 * returns a copy of them made with copyblob.  Panics with Edecode if
 * len is negative or more than what remains; previously it copied
 * whatever lay beyond b.ep.
 */
Blob*
getblobraw(Blob *b, int len)
{
	Blob *p;

	if(len < 0 || b->rp > b->ep || len > b->ep - b->rp)
		panic(Edecode);
	p = copyblob(b->rp, len);
	b->rp += len;

	return p;
}

/*
 * getblobstring consumes an SSH string from b and returns its bytes
 * copied into a new Blob via copyblob.
 */
Blob*
getblobstring(Blob *b)
{
	int n;
	uchar *data;

	data = getbstring(b, &n);
	return copyblob(data, n);
}

/*
 * getnamelist consumes an SSH string from b and returns the result of
 * parsenamelist on it.
 */
Namelist*
getnamelist(Blob *b)
{
	return parsenamelist(getstring(b));
}

/*
 * getmpint consumes an mpint (uint32 length, then that many bytes) from
 * b and converts it with betomp.
 *
 * XXX: sign botch.  A set top bit is treated as a sign flag: it is
 * cleared in place (in b's buffer) and the result is marked negative,
 * rather than decoding two's complement.
 */
mpint*
getmpint(Blob *b)
{
	ulong len;
	uchar *bytes;
	mpint *v;
	int negative;

	len = getlong(b);
	bytes = getbytes(b, len);
	negative = len > 0 && (bytes[0] & 0x80) != 0;
	if(negative)
		bytes[0] &= 0x7F;
	v = betomp(bytes, len, nil);
	v->sign = negative ? -1 : 1;
	return v;
}

/*
 * getcert decodes a public-key certificate from b using the
 * implementation chosen by certwiresniff for connection c.  Returns nil,
 * with the error string set, if no acceptable implementation is found.
 */
Cert*
getcert(Conn *c, Blob *b)
{
	CertImpl *impl;

	impl = certwiresniff(c, b);
	if(impl != nil)
		return impl->pubdecode(b);
	werrstr("unknown or prohibited certificate type");
	return nil;
}

/*
 * getsig decodes a signature from b using the implementation chosen by
 * sigwiresniff for connection c.  Returns nil, with the error string
 * set, if no acceptable implementation is found.
 */
Sig*
getsig(Conn *c, Blob *b)
{
	SigImpl *impl;

	impl = sigwiresniff(c, b);
	if(impl != nil)
		return impl->decode(b);
	werrstr("unknown or prohibited signature type");
	return nil;
}

/*
 * putbyte appends byte x to b; panics with Eencode if b is full.
 */
void
putbyte(Blob *b, uchar x)
{
	if(!(b->wp < b->ep))
		panic(Eencode);
	*b->wp++ = x;
}

/*
 * putshort appends x to b encoded with PSHORT (2 bytes); panics with
 * Eencode if fewer than two bytes of room remain.
 */
void
putshort(Blob *b, ushort x)
{
	if(b->ep - b->wp < 2)
		panic(Eencode);
	PSHORT(b->wp, x);
	b->wp = &b->wp[2];
}

/*
 * putbool appends x to b as a single byte, 1 if x is non-zero, else 0.
 */
void
putbool(Blob *b, int x)
{
	putbyte(b, x ? 1 : 0);
}

/*
 * putlong appends x to b encoded with PLONG (4 bytes); panics with
 * Eencode if fewer than four bytes of room remain.
 */
void
putlong(Blob *b, ulong x)
{
	if(b->ep - b->wp < 4)
		panic(Eencode);
	PLONG(b->wp, x);
	b->wp = &b->wp[4];
}

/*
 * putbytes appends the n bytes at v to b; panics with Eencode if they
 * do not fit.
 */
void
putbytes(Blob *b, void *v, int n)
{
	uchar *dst;

	dst = b->wp;
	if(dst+n > b->ep)
		panic(Eencode);
	memmove(dst, v, n);
	b->wp = dst + n;
}

/*
 * putstring appends NUL-terminated s to b as an SSH string: a uint32
 * length followed by the bytes, without the NUL.
 */
void
putstring(Blob *b, char *s)
{
	int n;

	n = strlen(s);
	putlong(b, n);		/* length prefix */
	putbytes(b, s, n);	/* body, no terminator */
}

/*
 * putbstring appends the len bytes at s to b as an SSH string (uint32
 * length prefix, then the bytes).
 */
void
putbstring(Blob *b, uchar *s, int len)
{
	int n = len;

	putlong(b, n);		/* length prefix */
	putbytes(b, s, n);	/* body */
}

/*
 * putblobraw appends p's written bytes (bp..wp) to b with no length
 * prefix.
 */
void
putblobraw(Blob *b, Blob *p)
{
	int n;

	n = p->wp - p->bp;
	putbytes(b, p->bp, n);
}

/*
 * putblobstring appends p's written bytes (bp..wp) to b as an SSH
 * string (uint32 length prefix, then the bytes).
 */
void
putblobstring(Blob *b, Blob *p)
{
	int n;

	n = p->wp - p->bp;
	putbstring(b, p->bp, n);
}

/*
 * putnamelist appends p to b as an SSH name-list: a uint32 length,
 * then the names joined by commas.  An empty list has length 0.
 * Previously it had length nstr-1 = -1, i.e. 0xFFFFFFFF on the wire.
 */
void
putnamelist(Blob *b, Namelist *p)
{
	int i, len;

	len = 0;
	for(i=0; i<p->nstr; i++){
		if(i != 0)
			len++;			/* comma separator */
		len += strlen(p->strtab[i]);
	}
	putlong(b, len);
	for(i=0; i<p->nstr; i++){
		if(i != 0)
			putbyte(b, ',');
		putbytes(b, p->strtab[i], strlen(p->strtab[i]));
	}
}

/*
 * putmpint appends p to b as an mpint: a uint32 length, then the
 * big-endian bytes from mptobe.  Zero is written with length 0.  A
 * leading zero byte is inserted when the top bit of the first byte is
 * set.  Panics with Eencode if mptobe fails.
 *
 * XXX: sign botch; only non-negative values are supported (asserted).
 */
void
putmpint(Blob *b, mpint *p)
{
	int n, lead;
	uchar *buf;

	if(mpcmp(mpzero, p) == 0){
		putlong(b, 0);
		return;
	}
	n = mptobe(p, nil, 0, &buf);
	if(n < 0)
		panic(Eencode);
	assert(p->sign >= 0);

	/* extra 0x00 so a set high bit is not taken as a negative sign */
	lead = n > 0 && p->sign >= 0 && (buf[0] & 0x80) != 0;
	putlong(b, n + lead);
	if(lead)
		putbyte(b, 0);
	putbytes(b, buf, n);
	free(buf);
}

/*
 * putcert appends p's public-key encoding to b via its pubencode
 * method; panics with Eencode if that reports failure.
 */
void
putcert(Blob *b, Cert *p)
{
	int rv;

	rv = p->pubencode(p, b);
	if(rv != 0)
		panic(Eencode);
}

/*
 * putsig appends p's encoding to b via its encode method; panics with
 * Eencode if that reports failure.
 */
void
putsig(Blob *b, Sig *p)
{
	int rv;

	rv = p->encode(p, b);
	if(rv != 0)
		panic(Eencode);
}

/*
 * sizefoo(x) gives the number of bytes needed for the
 * wire representation of type foo.  It is allowed to
 * overestimate.
 */

/*
 * sizestring returns the wire size of s as an SSH string: a 4-byte
 * length prefix plus the characters, without the NUL.
 */
int
sizestring(char *s)
{
	int n;

	n = strlen(s);
	return 4 + n;
}

/*
 * sizeblobstring returns the wire size of b as an SSH string: a 4-byte
 * length prefix plus sizeblob(b).
 */
int
sizeblobstring(Blob *b)
{
	return 4 + sizeblob(b);
}

/*
 * sizempint returns an upper bound on p's wire size as an mpint: a
 * 4-byte length, one byte of sign slop, and every digit of p counted in
 * full.
 */
int
sizempint(mpint *p)
{
#ifdef NOTDEF
	if(mpcmp(mpzero, p) == 0)
		return 4;
#endif

	return 4 + 1 + (p->top+1)*Dbytes;
}

/*
 * sizecert returns the size reported by p's pubsize method for its
 * public-key encoding.
 */
int
sizecert(Cert *p)
{
	return (*p->pubsize)(p);
}

/*
 * sizesig returns the size reported by p's size method for its encoding.
 */
int
sizesig(Sig *p)
{
	return (*p->size)(p);
}
