#include "dat.h"
#include "fns.h"

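/*
 * XXX: disconnect handling is unfinished: connectthread panics right
 * after calling disconnect(), and cclose panics where it should disconnect.
 */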

typedef struct Entry Entry;
typedef struct Mux Mux;

/*
 * Indices into Mux.a, the array handed to alt().
 * Channel i's delivery entry lives at ABase+i.
 */
enum {
	ACtl	= 0,	/* control messages */
	ANetR	= 1,	/* network read */
	ANetW	= 2,	/* network write */
	AMux	= 3,	/* I/O thread to mux */
	ABase	= 4	/* first per-channel entry */
};

/*
 * Entry.state values.  Numbering starts at 1, so zero is not one of
 * the named states.
 */
enum {
	Calloc		= 1,	/* allocated, but not open */
	Copen		= 2,
	Cclosing	= 3,
	Cclosed		= 4
};

/* One per-channel slot in Mux.tab; the slot index is our channel number. */
struct Entry {
	char		*type;			/* channel type */
	int		state;			/* Calloc, Copen, Cclosing or Cclosed */
	Chan		chan;			/* ids and window parameters, filled in by cchanop on open confirmation */
	Msg		*m;			/* packet from network; the alt send buffer for this channel */
	MsgQueue	q;			/* queued from network, awaiting delivery via m */
};

/*
 * State of the connection multiplexer; an instance lives on
 * connectthread's stack.
 */
struct Mux {
	Conn		*c;		/* the SSH connection */
	Channel		*disw;		/* passed to disconnect(); NOTE(review): not given a real channel anywhere in this file */

	u32int		ndesc;		/* number of entries in tab */
	Msg		*r, *w, *m;	/* alt buffers: network read, network write, mux */
	char		*ctl;		/* alt buffer for control messages */

	Alt		*a;		/* ABase fixed entries, then ndesc channel entries, then CHANEND */
	Entry		*tab;		/* per-channel state, indexed by channel number */
};

/* File-local helpers, in alphabetical order. */
static	u32int	cchanalloc(Mux *mx);
static	void	cchanop(Mux *mx, Msg *m);
static	void	cclose(Mux *mx, int i);
static	void	cdispatch(Mux *mx, Msg *m);
static	void	cglobal(Mux *mx, Msg *m);
static	void	chanthread(void *arg);

/*
 * connectthread runs the multiplexer for one SSH connection.  It
 * loops on alt() over:
 *	ACtl	control strings from cn->ctl (currently ignored),
 *	ANetR	messages read from the network (cn->netr),
 *	ANetW	a pending write to the network (cn->netw),
 *	AMux	outgoing messages from channel threads,
 *	ABase+i	delivery of queued network messages to channel i.
 * ANetW and AMux are toggled so that only one outgoing message is in
 * flight at a time.
 *
 * arg is a Connection*, which is freed once its fields are copied.
 * Never returns: a protocol error disconnects and then panics.
 *
 * XXX: currently assumes we are a client
 */
void
connectthread(void *arg)
{
	int n;
	u32int i;

	Mux mx;
	Connection *cn;

	threadsetname("connectthread");

	cn = (Connection*)arg;
	/* zero everything so fields not set below (e.g. disw) are nil, not stack garbage */
	memset(&mx, 0, sizeof mx);
	mx.c = cn->c;
	mx.ndesc = 0;
	mx.tab = emalloc(mx.ndesc*sizeof *mx.tab);
	/* fixed entries, one per channel, plus the CHANEND terminator */
	mx.a = emalloc((ABase+mx.ndesc+1)*sizeof mx.a[0]);
	mx.a[ACtl].c = cn->ctl;
	mx.a[ACtl].v = &mx.ctl;
	mx.a[ACtl].op = CHANRCV;
	mx.a[ANetR].c = cn->netr;
	mx.a[ANetR].v = &mx.r;
	mx.a[ANetR].op = CHANRCV;
	mx.a[ANetW].c = cn->netw;
	mx.a[ANetW].v = &mx.w;
	mx.a[ANetW].op = CHANNOP;
	mx.a[AMux].c = chancreate(sizeof(Msg*), 0);
	mx.a[AMux].v = &mx.m;
	mx.a[AMux].op = CHANRCV;
	mx.a[ABase+mx.ndesc].op = CHANEND;

	free(cn);

	for(;;){
		n = alt(mx.a);
		/* valid results are 0..ABase+ndesc-1; ABase+ndesc is the terminator */
		if(n < 0 || n >= ABase+mx.ndesc)
			panic("bad alt");
		switch(n){
		case ACtl:
			fprint(2, "WARNING: ignored ctl message: %s\n",mx.ctl);
			free(mx.ctl);
			break;

		case ANetR:
			switch(mx.r->type){
			case SSH_MSG_GLOBAL_REQUEST:
			case SSH_MSG_REQUEST_SUCCESS:
			case SSH_MSG_REQUEST_FAILURE:
				cglobal(&mx, mx.r);
				break;

			case SSH_MSG_CHANNEL_OPEN:
			case SSH_MSG_CHANNEL_OPEN_CONFIRMATION:
			case SSH_MSG_CHANNEL_OPEN_FAILURE:
				cchanop(&mx, mx.r);
				break;

			case SSH_MSG_CHANNEL_CLOSE:
			case SSH_MSG_CHANNEL_WINDOW_ADJUST:
			case SSH_MSG_CHANNEL_DATA:
			case SSH_MSG_CHANNEL_EXTENDED_DATA:
			case SSH_MSG_CHANNEL_EOF:
			case SSH_MSG_CHANNEL_REQUEST:
			case SSH_MSG_CHANNEL_SUCCESS:
			case SSH_MSG_CHANNEL_FAILURE:
				cdispatch(&mx, mx.r);
				break;

			default:
				/* any other message type is unexpected at this layer */
				goto Error;
			}
			break;

		case ANetW:
			/* write finished: accept the next outgoing message */
			mx.a[ANetW].op = CHANNOP;
			mx.a[AMux].op = CHANRCV;
			break;

		case AMux:
			i = mx.m->chan;
			assert(i < mx.ndesc);
			/* NOTE(review): advances the close state on every outgoing message, not only CHANNEL_CLOSE */
			cclose(&mx, i);
			mx.w = mx.m;
			/* stop taking outgoing messages until this one is written */
			mx.a[AMux].op = CHANNOP;
			mx.a[ANetW].op = CHANSND;
			break;

		default:
			assert(n >= ABase);
			i = n - ABase;
			assert(i < mx.ndesc);
			/* previous delivery to channel i finished; offer the next one, if any */
			mx.tab[i].m = msgdequeue(&mx.tab[i].q);
			if(mx.tab[i].m != nil){
				cclose(&mx, i);
				mx.a[ABase+i].op = CHANSND;
			}else{
				/* queue empty: go idle so cdispatch restarts delivery and alt never sends nil */
				mx.a[ABase+i].op = CHANNOP;
			}
			break;
		}
	}
 Error:
	disconnect(mx.c, mx.disw, SSH_DISCONNECT_PROTOCOL_ERROR);
	/* XXX: fix me */

	panic(Ebotch);
}

static void
chanthread(void *arg)
{
	Alt a[3];
	Msg *r, *w;
	Ioproc *io[2];
	Channel **argv;

	argv = (Channel**)arg;
	memset(a, 0, sizeof a);
	a[0].c = argv[0];
	a[0].v = &r;
	a[0].op = CHANRCV;
	a[1].c = argv[1];
	a[1].v = &w;
	a[1].op = CHANNOP;

	USED(io);

	switch(alt(a)){
	case 0:
	case 1:
		break;

	default:
		panic("bad alt");
	}

	panic(Ebotch);
}

/*
 * cchanop handles SSH_MSG_CHANNEL_OPEN, _OPEN_CONFIRMATION and
 * _OPEN_FAILURE read from the network.  The first uint32 read from
 * m->b is taken as our channel number and must name a slot in mx->tab.
 * A confirmation or failure is only accepted for a slot in state Calloc
 * (allocated, open not yet answered); anything else is a protocol error
 * and leads to disconnect().  m is freed on every path that returns.
 *
 * NOTE(review): for CHANNEL_OPEN the first field on the wire is the
 * channel-type string, so i is not a channel number there; harmless
 * today because OPEN is rejected (client) or unimplemented (server).
 */
static void
cchanop(Mux *mx, Msg *m)
{
	u32int i;

	i = getlong(&m->b);
	m->chan = i;
	if(i >= mx->ndesc)
		goto Error;

	switch(m->type){
	case SSH_MSG_CHANNEL_OPEN:
		if(mx->c->role == RClient)
			goto Error;
		panic(Enotimpl);
		break;

	case SSH_MSG_CHANNEL_OPEN_FAILURE:
		/* only an open we are still waiting on can fail */
		if(mx->tab[i].state != Calloc)
			goto Error;
		mx->tab[i].state = Cclosed;
		freemsg(m);
		break;

	case SSH_MSG_CHANNEL_OPEN_CONFIRMATION:
		/* a confirmation for a slot we did not open would clobber its state */
		if(mx->tab[i].state != Calloc)
			goto Error;
		mx->tab[i].state = Copen;
		mx->tab[i].chan.id = i;
		mx->tab[i].chan.remid = getlong(&m->b);
		mx->tab[i].chan.winsize = getlong(&m->b);
		mx->tab[i].chan.winused = 0;
		mx->tab[i].chan.maxpacket = getlong(&m->b);
		/* all fields copied out; nothing else holds m */
		freemsg(m);
		break;
	}

	return;

 Error:
	freemsg(m);
	disconnect(mx->c, mx->disw, SSH_DISCONNECT_PROTOCOL_ERROR);
	return;
}

/*
 * cdispatch queues a per-channel message read from the network for
 * delivery to its channel.  The channel number is the first uint32 read
 * from m->b and must name a slot in mx->tab.  If that channel's alt
 * entry is idle (CHANNOP), the head of its queue is offered at once by
 * switching the entry to CHANSND.
 *
 * m is consumed: it is queued on success and freed on a protocol error
 * (after which disconnect() is called), matching cchanop.
 */
static void
cdispatch(Mux *mx, Msg *m)
{
	u32int i;

	i = getlong(&m->b);
	m->chan = i;
	if(i >= mx->ndesc)
		goto Error;
	msgqueue(&mx->tab[i].q, m);
	if(mx->a[ABase+i].op == CHANNOP){
		/* NOTE(review): advances the close state on every delivery, not only CHANNEL_CLOSE */
		cclose(mx, i);
		mx->tab[i].m = msgdequeue(&mx->tab[i].q);
		mx->a[ABase+i].op = CHANSND;
	}

	return;

 Error:
	/* m was not queued anywhere; free it so it does not leak */
	freemsg(m);
	disconnect(mx->c, mx->disw, SSH_DISCONNECT_PROTOCOL_ERROR);
	return;
}

/*
 * cglobal would handle connection-wide requests and their replies
 * (GLOBAL_REQUEST, REQUEST_SUCCESS, REQUEST_FAILURE).  It is not
 * implemented: any such message aborts via panic, naming the message.
 */
static void
cglobal(Mux *mx, Msg *m)
{
	/* XXX: only TCP/IP forwarding needs these, and we don't forward */
	panic("global request: %p %p %s", mx, m, msgname(m));
}

/*
 * cclose advances channel i one step through its shutdown sequence,
 * Copen -> Cclosing -> Cclosed.  Closing a channel that is already
 * Cclosed panics; any other state (e.g. Calloc) is left unchanged.
 */
static void
cclose(Mux *mx, int i)
{
	Entry *e;

	e = &mx->tab[i];
	if(e->state == Copen)
		e->state = Cclosing;
	else if(e->state == Cclosing)
		e->state = Cclosed;
	else if(e->state == Cclosed){
		/* XXX: should disconnect() instead of aborting */
		panic(Ebotch);
	}
}

/*
 * cchanalloc returns the index of a fresh channel slot in mx->tab:
 * zeroed and marked Calloc.  A Cclosed slot is reused if there is one;
 * otherwise both mx->tab and the alt array mx->a grow by one entry and
 * the CHANEND terminator moves to the new end.
 *
 * The slot's alt entry mx->a[ABase+i] is left idle (CHANNOP) with its
 * value pointer aimed at tab[i].m.  Its channel (.c) is not set here.
 * NOTE(review): presumably the caller installs the per-channel Channel;
 * confirm.
 */
static u32int
cchanalloc(Mux *mx)
{
	u32int i, j;

	for(i=0; i<mx->ndesc; i++)
		if(mx->tab[i].state == Cclosed)
			break;
	if(i == mx->ndesc){
		mx->ndesc++;
		mx->tab = erealloc(mx->tab, mx->ndesc*sizeof(mx->tab[0]));
		/* fixed entries, one per channel, plus the CHANEND terminator */
		mx->a = erealloc(mx->a, (ABase+mx->ndesc+1)*sizeof(mx->a[0]));
		/* ABase+i was the old terminator; clear it and the new last entry */
		memset(&mx->a[ABase+i], 0, 2*sizeof(mx->a[0]));
		mx->a[ABase+mx->ndesc].op = CHANEND;
		/* tab may have moved: repoint every channel's alt buffer (separate index keeps i intact) */
		for(j=0; j<mx->ndesc; j++)
			mx->a[ABase+j].v = &mx->tab[j].m;
	}

	memset(&mx->tab[i], 0, sizeof mx->tab[i]);
	mx->tab[i].state = Calloc;
	/* nothing queued yet: don't let alt send the (now nil) tab[i].m */
	mx->a[ABase+i].op = CHANNOP;

	return i;
}
