#include "dat.h"
#include "fns.h"

typedef struct Mux Mux;

/*
 * Mux bundles the arguments passed to a newly created thread
 * (xportthread, netread, netwrite): the connection plus up to
 * two channels.  Instances are allocated by mkmux; xportthread
 * frees its Mux once it has copied the fields out.
 */
struct Mux {
	Conn	*c;		/* connection the thread serves */
	Channel	*chan[2];	/* meaning depends on the thread; see mkmux callers */
};

/*
 * Indices into xportthread's Alt array; the value returned by
 * alt() is compared against these to see which channel fired.
 */
enum {
	Alrm		= 0,	/* alarm messages (strings) */
	NetR,			/* messages read from the wire by netread */
	NetW,			/* messages handed to netwrite for the wire */
	AppR,			/* wire messages delivered to the upper layers */
	AppW			/* messages submitted by the upper layers */
};

/* forward declarations of this file's private thread bodies and helper */
static	void		netread(void*);
static	void		netwrite(void*);
static	void		xportthread(void*);
static	Mux		*mkmux(Conn*, Channel*, Channel*);

/*
 * xportinit starts the transport layer for connection c and returns
 * at once; the actual work runs in a new xportthread.
 *   rc - channel on which messages read from the wire are delivered
 *        to the upper layers
 *   wc - channel on which the upper layers submit messages to be
 *        written to the wire
 */
void
xportinit(Conn *c, Channel *rc, Channel *wc)
{
	threadcreate(xportthread, mkmux(c, rc, wc), STKSZ);
}

static void
xportthread(void *arg)
{
	int fwd;
	int n, rv;
	char *msg;
	Msg *m;
	Mux *mx;
	Conn *c;
	
	Alt a[] = {
		/* c	v	op	*/
	[Alrm]	{ nil,	&msg,	CHANRCV },	/* alarm */
	[NetR]	{ nil,	&m,	CHANRCV },	/* wire read */
	[NetW]	{ nil,	&m,	CHANNOP },	/* wire write */
	[AppR]	{ nil,  &m,	CHANNOP },	/* reads from upper layers */
	[AppW]	{ nil,  &m,	CHANNOP },	/* writes from upper layers */
		{ nil,	nil,	CHANEND }
	};

	threadsetname("xportthread");

	/* enable multiplexer */
	mx = (Mux*)arg;
	c = mx->c;
	a[Alrm].c = chancreate(sizeof msg, 0);
	a[NetR].c = chancreate(sizeof(void*), 0);
	a[NetW].c = chancreate(sizeof(void*), 0);
	a[AppR].c = mx->chan[0];
	a[AppW].c = mx->chan[1];
	free(mx);

	/* negotiate SSH v2 and start transport layer */
	if(sendversion(c) < 0)
		sysfatal("protocol negotiation fails: %r");
	if(recvversion(c) < 0)
		sysfatal("protocol negotiation fails: %r");

	mx = mkmux(c, a[NetR].c, nil);
	threadcreate(netread, mx, STKSZ);
	mx = mkmux(c, a[NetW].c, nil);
	threadcreate(netwrite, mx, STKSZ);

	alarminit(a[Alrm].c);
	kexinit(c, a[NetR].c, a[NetW].c);	/* start first key exchange */

	for(;;){
	//	if(dbglevel != 0)
	//		poolcheck(mainmem);
		n = alt(a);
		switch(n){
		default:
			panic("bad alt");
			break;

		case Alrm:	/* alarm */
			if(strcmp(msg, "auth") == 0){
				if(c->wirestate < WSauth)
					disconnect(c, a[NetW].c, SSH_DISCONNECT_AUTH_CANCELLED_BY_USER);
			}else if(strcmp(msg, "nop") == 0){
				/* XXX: bug */
				sshnop(c, a[NetW].c, nrand(256));
			}else if(strcmp(msg, "debug") == 0)
				sshdebug(c, a[NetW].c, msg);
			else if(strcmp(msg, "kex") == 0){
				if(c->bugflags & BugNoReKey)
					break;
				error(c, "key re-exchange request issued");
				break;

				panic(Enotimpl);

				/*
				 * XXX: need to queue messages until
				 * KEXINIT reply is heard
				 */

				/*
				 * key re-exchange not permitted
				 * until after authentication
				 */
				if(c->wirestate >= WSauth){
					a[AppW].op = CHANNOP;
					kexinit(c, a[NetR].c, a[NetW].c);
				}
			}else
				panic("bad alarm: [%s]", msg);
			break;

		case NetR: /* netread */
			if(m == nil)
				sysfatal("premature EOF: %r");
			fwd = 1;
			if(m->type == SSH_MSG_KEXINIT){
				fwd = 0;
				rv = kexrun(c, a[NetR].c, a[NetW].c, m);
				assert(rv >= 0 || c->wirestate == WShup);
				if(rv < 0)
					goto Dead;
			}
			if(c->wirestate >= WSkex && a[NetW].op == CHANNOP)
				a[AppW].op = CHANRCV;
			if(fwd){
				a[AppR].v = a[NetR].v;
				a[AppR].op = CHANSND;
				a[NetR].op = CHANNOP;
			}
			break;

		case NetW:	/* netwrite */
			a[NetW].op = CHANNOP;
			a[AppW].op = CHANRCV;
			break;

		case AppW:	/* higher layer write */
			a[NetW].v = &m;
			a[NetW].op = CHANSND;
			a[AppW].op = CHANNOP;
			break;

		case AppR:
			a[AppR].op = CHANNOP;
			a[NetR].op = CHANRCV;
			break;
		}
	}

 Dead:
	error(c, "transport layer dies");
	threadexitsall(nil);

	return;
}

/*
 * netread is the wire-reader thread.  It repeatedly reads a message
 * from connection c with readmsg and sends it to xportthread on the
 * Mux's chan[0].  When readmsg fails, it sends nil instead, which
 * xportthread treats as EOF.  After forwarding SSH_MSG_NEWKEYS it
 * blocks on c->keylock[0] before reading again; presumably this waits
 * until the new keys are installed (TODO confirm).  The thread exits
 * once the connection reaches state WShup.
 *
 * arg is a Mux from mkmux.  It is freed here once its fields have been
 * copied, the same way xportthread frees its own Mux.
 */
static void
netread(void *arg)
{
	int t;
	Msg *m;
	Mux *mx;
	Conn *c;
	Channel *chan;

	mx = (Mux*)arg;
	c = mx->c;
	chan = mx->chan[0];
	free(mx);	/* previously leaked: only the fields are needed */
	threadsetname("netread");

	for(;;){
		m = readmsg(c, MSGIGN);
		assert(m == nil || m->c == c);
		if(c->wirestate == WShup){
			freemsg(m);
			goto Dead;
		}
		if(m == nil){
			debug(DbgPacket, "readmsg: %r");
			sendp(chan, nil);
			continue;
		}
		t = m->type;	/* read before sendp: the receiver may dispose of m */
		sendp(chan, m);
		if(t == SSH_MSG_NEWKEYS)
			recvp(c->keylock[0]);
	}

 Dead:
	threadexits(nil);
	panic(Ebotch);	/* threadexits should not return */
	return;
}

/*
 * netwrite is the wire-writer thread.  It receives messages on the
 * Mux's chan[0] and writes each one to the wire with writemsg.
 * After writing SSH_MSG_NEWKEYS it signals c->keylock[1], and after
 * writing SSH_MSG_DISCONNECT it signals c->dislock.  It loops forever.
 *
 * arg is a Mux from mkmux.  It is freed here once its fields have been
 * copied, the same way xportthread frees its own Mux.
 */
static void
netwrite(void *arg)
{
	int t;
	Msg *m;
	Mux *mx;
	Conn *c;
	Channel *chan;

	threadsetname("netwrite");

	mx = (Mux*)arg;
	c = mx->c;
	chan = mx->chan[0];
	free(mx);	/* previously leaked: only the fields are needed */

	for(;;){
		m = recvp(chan);
		/* read before writemsg; presumably writemsg consumes m (TODO confirm) */
		t = m->type;
		assert(m->c == c);
		writemsg(m);
		switch(t){
		case SSH_MSG_NEWKEYS:
			sendp(c->keylock[1], nil);
			break;

		case SSH_MSG_DISCONNECT:
			sendp(c->dislock, nil);
			break;

		default:	/* no follow-up needed for other message types */
			break;
		}
	}

	panic(Ebotch);	/* not reached */
	return;
}

/*
 * mkmux allocates a Mux holding connection c, with rc in chan[0] and
 * wc in chan[1].  Allocation uses emalloc, and the result is meant to
 * be passed as the argument of a new thread.
 */
static Mux*
mkmux(Conn *c, Channel *rc, Channel *wc)
{
	Mux *p;

	p = emalloc(sizeof *p);
	p->chan[0] = rc;
	p->chan[1] = wc;
	p->c = c;
	return p;
}
