#include "u.h"
#include "libc.h"
#include "dat.h"
#include "fns.h"
#include "fcall.h"

/* 9p driver, stolen from plan9 kernel */

#define MAXRPC (IOHDRSZ+8192)

/*
 * One 9P transaction on a mount: the T-message to send, the R-message
 * received for it, and the bookkeeping used to multiplex several such
 * transactions over the single channel of a Mnt (see mountio/mountmux).
 */
typedef struct Mntrpc Mntrpc;
struct Mntrpc
{
	Chan*	c;		/* Channel for whom we are working (nil for attach and flush rpcs) */
	Mntrpc*	list;		/* next rpc on the mount's pending queue (m->queue) */
	Fcall	request;	/* Outgoing file system protocol message */
	uint	reqlen;		/* bytes written for the request (set in mountio) */
	Fcall 	reply;		/* Incoming reply */
	Rendez	r;	/* wait for response */
	Mnt*	m;		/* Mount device during rpc (set by mountio, read by rpcattn) */
	uchar*	rpc;		/* buffer the request is marshalled into */
	uint	rpclen;		/* len of buffer */
	Mntrpc*	flushed;	/* message this one flushes (set only on Tflush rpcs) */
	uchar	*h, *b;	/* h: malloc'd raw reply message; b: Rread data within h */
	char	done;		/* Rpc completed */
};

/*
 * Allocate a tag for a new 9P request.
 *
 * Tags travel as 16-bit values, so the counter is folded into that range.
 * 0xFFFF is NOTAG, which is reserved for Tversion; when the counter wraps
 * onto it, that value is skipped so no ordinary request is sent with it.
 * NOTE(review): tags are not checked against requests still outstanding.
 */
int
alloctag(void)
{
	static long tag = 1;
	int t;

	do
		t = incref(&tag) & 0xFFFF;
	while(t == 0xFFFF);	/* NOTAG */
	return t;
}

/*
 * Create a mount structure that speaks 9P over channel ch.
 * The new Mnt starts with a single reference owned by the caller.
 */
Mnt*
mkmount(Chan *ch)
{
	Mnt *mnt = mallocz(sizeof *mnt, 1);

	initlock(&mnt->l);
	mnt->c = ch;
	mnt->ref = 1;
	return mnt;
}

/*
 * Drop one reference to m (m may be nil).  When the last reference goes,
 * the underlying channel, the lock, the version string and m itself are freed.
 */
void
freemount(Mnt *m)
{
	if(m == nil)
		return;
	if(decref(&m->ref) != 0)
		return;
	freechan(m->c);
	freelock(&m->l);
	free(m->version);
	free(m);
}

static int mntwrite(Chan *c, void *buf, int n, vlong off);
static int mntread(Chan *c, void *buf, int n, vlong off);
static void mntclose(Chan *c);

/*
 * Make a new channel whose read/write/close go through mount m.
 * The channel holds its own reference to m.
 */
Chan*
mkmountchan(Mnt *m)
{
	Chan *nc = mkchan();

	incref(&m->ref);
	nc->mnt = m;
	nc->mchan = m->c;
	/* I/O on this channel becomes 9P rpcs on m */
	nc->read = mntread;
	nc->write = mntwrite;
	nc->close = mntclose;
	return nc;
}

/*
 * Allocate a zeroed rpc for channel c (nil for attach/flush) with an
 * msize-byte request buffer and a freshly allocated tag.
 */
static Mntrpc*
mntralloc(Chan *c, ulong msize)
{
	Mntrpc *rpc = mallocz(sizeof *rpc, 1);

	initrendez(&rpc->r);
	rpc->rpclen = msize;
	rpc->rpc = mallocz(msize, 0);
	rpc->request.tag = alloctag();
	rpc->c = c;
	rpc->flushed = nil;
	rpc->done = 0;
	return rpc;
}

/*
 * Release an rpc, its request buffer, and any reply message it still holds.
 */
static void
mntfree(Mntrpc *r)
{
	free(r->h);
	free(r->rpc);
	freerendez(&r->r);
	free(r);
}

/*
 * Give up the reader role on m (m->pid = 0).  Then walk the pending queue
 * and rwakeup the not-yet-done rpcs in order, stopping at the first one
 * for which rwakeup returns nonzero (presumably: a sleeper was actually
 * woken).  That rpc can then take over reading.
 */
static void
mntgate(Mnt *m)
{
	Mntrpc *q;

	lock(&m->l);
	m->pid = 0;
	for(q = m->queue; q != nil; q = q->list){
		if(q->done != 0)
			continue;
		if(rwakeup(&q->r))
			break;
	}
	unlock(&m->l);
}

/*
 * Deliver the reply just read into r to the queued rpc it answers,
 * matched by tag.  That rpc is unlinked from m->queue and marked done.
 * If it is not r itself, the decoded reply and its message buffer are
 * moved to it and its owner is woken.  A reply whose tag matches nothing
 * is left in r (r->h / r->b).
 */
static void
mountmux(Mnt *m, Mntrpc *r)
{
	Mntrpc **l, *q;

	lock(&m->l);
	l = (Mntrpc**)&m->queue;
	for(q = *l; q; q = q->list) {
		/* look for a reply to a message */
		if(q->request.tag == r->reply.tag) {
			*l = q->list;
			if(q != r) {
				/*
				 * Completed someone else.
				 * Trade pointers to receive buffer.
				 * q may still own an unclaimed message from a
				 * read it did earlier as the reader; free it
				 * instead of leaking it when q->h is replaced.
				 */
				free(q->h);
				q->reply = r->reply;
				q->b = r->b;
				q->h = r->h;
				r->b = nil;
				r->h = nil;
			}
			q->done = 1;
			unlock(&m->l);
			/*
			 * NOTE(review): once done is set and the lock dropped,
			 * q's owner may see done without sleeping, return, and
			 * mntfree q before this rwakeup touches q->r.  Confirm
			 * that rsleep/rwakeup make this safe.
			 */
			if(q != r)
				rwakeup(&q->r);
			return;
		}
		l = &q->list;
	}
	unlock(&m->l);
}

/*
 * Mark r done and unlink it from m's pending queue if it is still there.
 */
static void
mntqrm(Mnt *m, Mntrpc *r)
{
	Mntrpc **pp;

	lock(&m->l);
	r->done = 1;
	for(pp = (Mntrpc**)&m->queue; *pp != nil; pp = &(*pp)->list){
		if(*pp == r){
			*pp = r->list;
			break;
		}
	}
	unlock(&m->l);
}

/*
 * Build a Tflush rpc that cancels r, and link it back to r via ->flushed.
 * If r is itself a Tflush, the new flush names the same oldtag, so every
 * flush in a chain refers to the request that was first interrupted.
 */
static Mntrpc*
mntflushalloc(Mntrpc *r, ulong iounit)
{
	Mntrpc *fr;
	int reflush;

	fr = mntralloc(0, iounit);
	fr->request.type = Tflush;
	reflush = r->request.type == Tflush;
	fr->request.oldtag = reflush ? r->request.oldtag : r->request.tag;
	fr->flushed = r;
	return fr;
}

/*
 * Walk the chain r -> r->flushed -> ... built by mntflushalloc.
 * Any rpc in it that has not completed gets reply type Rflush and is taken
 * off m's queue.  Every rpc that has a ->flushed link (the Tflush rpcs) is
 * freed; the last one, the caller's original rpc, is left allocated.
 */
static void
mntflushfree(Mnt *m, Mntrpc *r)
{
	Mntrpc *next;

	for(; r != nil; r = next){
		next = r->flushed;
		if(r->done == 0){
			r->reply.type = Rflush;
			mntqrm(m, r);
		}
		if(next != nil)
			mntfree(r);
	}
}
/*
 * Read one complete 9P message from m's channel and unpack it into
 * r->reply.  On success r->h holds the malloc'd message.  r->b points at
 * the data of an Rread, or at the end of the message for other types.
 * Raises Emountrpc on EOF/read failure or a malformed message, and
 * "rpc too big" if the size field exceeds m->msize.
 */
static void
mntrpcread(Mnt *m, Mntrpc *r)
{
	int n, t;
	ulong len, hlen;
	uchar *msg, *rp, *ep;

	r->reply.type = 0;
	r->reply.tag = 0;

	/*
	 * A message from an earlier pass that no queued rpc claimed
	 * (mountmux found no matching tag) is still parked in r->h.
	 * Release it now; otherwise it leaks when r->h is reassigned below.
	 */
	free(r->h);
	r->h = nil;
	r->b = nil;

	msg = mallocz(IOHDRSZ+m->msize, 0);
	if(waserror()){
		free(msg);
		nexterror();
	}
	/* read at least length, type, and tag */
	rp = msg;
	ep = msg + BIT32SZ+BIT8SZ+BIT16SZ;
	while(rp < ep){
		if((n = readchan(m->c, rp, ep - rp, m->c->off)) <= 0)
			error(Emountrpc);
		rp += n;
	}
	len = GBIT32(msg);
	/* the size field counts itself, type and tag; anything shorter is garbage */
	if(len < BIT32SZ+BIT8SZ+BIT16SZ)
		error(Emountrpc);
	if(len > m->msize)
		error("rpc too big");
	/* read the rest of the message */
	ep = msg + len;
	while(rp < ep){
		if((n = readchan(m->c, rp, ep - rp, m->c->off)) <= 0)
			error(Emountrpc);
		rp += n;
	}
	t = msg[BIT32SZ];
	switch(t){
	case Rread:
		/* size[4] type[1] tag[2] count[4], then data */
		hlen = BIT32SZ+BIT8SZ+BIT16SZ+BIT32SZ;
		break;
	default:
		hlen = len;
		break;
	}
	if(convM2S(msg, len, &r->reply) <= 0)
		error(Emountrpc);
	poperror();
	r->h = msg;
	r->b = msg + hlen;
}

/*
 * rsleep condition for a waiting rpc: true once its reply has been
 * delivered, or once nobody holds the reader role on its mount.
 */
static int
rpcattn(void *v)
{
	Mntrpc *r = v;

	if(r->done)
		return 1;
	return r->m->pid == 0;
}

/*
 * Send rpc r on mount m and wait until its reply is in r->reply.
 *
 * The request is put on m->queue and written to the channel.  Only one
 * process at a time reads replies: the one whose pid is in m->pid.  It
 * passes each reply to the matching queued rpc (mountmux) until its own
 * reply arrives, then releases the role (mntgate).  Other callers sleep
 * until either their reply has been delivered or the role is free.
 *
 * Errors: on Eintr the request is not abandoned.  A Tflush for it is
 * allocated, and the waserror loop re-enters to send the flush and wait
 * for that reply instead.  Any other error runs mntflushfree and
 * propagates.  Every exit runs mntflushfree, which frees only the Tflush
 * rpcs made here and leaves the caller's rpc allocated.
 */
static void
mountio(Mnt *m, Mntrpc *r)
{
	int n;

	while(waserror()){
		char *err;

		/* drop the reader role if this process held it */
		if(m->pid == getpid())
			mntgate(m);

		err = errorstr();
		if(strcmp(err, Eintr) != 0){
			mntflushfree(m, r);
			nexterror();
		}
		/* interrupted: flush r, then loop to send and await the Tflush */
		r = mntflushalloc(r, m->msize);
	}

	/* queue r before sending so a reply read by another process can find it */
	lock(&m->l);
	r->m = m;
	r->list = m->queue;
	m->queue = r;
	unlock(&m->l);

	/* Transmit a file system rpc */
	if(m->msize == 0)
		error("msize");
	n = convS2M(&r->request, r->rpc, m->msize);
	if(n < 0)
		error(Emountrpc);
	if(writechan(m->c, r->rpc, n, m->c->off) != n)
		error(Emountrpc);
	r->reqlen = n;

	/* Gate readers onto the mount point one at a time */
	for(;;) {
		lock(&m->l);
		if(m->pid == 0)
			break;	/* leaves m->l held; released after claiming the role */
		unlock(&m->l);
		rsleep(&r->r, rpcattn, r);
		/* the current reader may have delivered our reply meanwhile */
		if(r->done){
			poperror();
			mntflushfree(m, r);
			return;
		}
	}
	m->pid = getpid();
	unlock(&m->l);
	/* act as reader, dispatching replies until our own has arrived */
	while(r->done == 0) {
		mntrpcread(m, r);
		mountmux(m, r);
	}
	mntgate(m);
	poperror();
	mntflushfree(m, r);
}

/*
 * Run rpc r on mount m and check what came back.
 * Raises the server's text for Rerror, Eintr if the request was flushed,
 * and Emountrpc for any reply type that does not answer the request.
 */
static void
mountrpc(Mnt *m, Mntrpc *r)
{
	int t;

	r->reply.tag = 0;
	r->reply.type = Tmax;	/* can't ever be a valid message type */
	mountio(m, r);

	t = r->reply.type;
	if(t == Rerror)
		error("%s", r->reply.ename);
	if(t == Rflush)
		error(Eintr);
	if(t != r->request.type+1)
		error(Emountrpc);
}

/*
 * Negotiate the protocol version and message size on m's channel with a
 * Tversion/Rversion exchange.
 *
 * msize 0 means MAXRPC, and the request is further capped at the
 * channel's iounit when that is nonzero.  version nil or "" means
 * VERSION9P.  On success m->version and m->msize hold the values the
 * server accepted.  Raises an error for bad arguments, I/O failure, or an
 * unacceptable reply (including the server's own Rerror text).
 */
void
mntversion(Mnt *m, char *version, ulong msize)
{
	Fcall f;
	Chan *c;
	uchar *msg;
	long k, l;

	c = m->c;
	if(msize == 0)
		msize = MAXRPC;
	if(msize > c->iounit && c->iounit != 0)
		msize = c->iounit;
	if(version == nil || version[0] == 0)
		version = VERSION9P;
	/*
	 * msize is unsigned, so "msize < 0" could never be true; test the
	 * sign bit as the (presumably originally signed) check intended.
	 */
	if((long)msize < 0)
		error("bad iounit in version call");
	if(strncmp(version, VERSION9P, strlen(VERSION9P)) != 0)
		error("bad 9P version specification");
	f.type = Tversion;
	f.tag = NOTAG;
	f.msize = msize;
	f.version = version;
	msg = mallocz(8192+IOHDRSZ, 0);
	if(waserror()){
		free(msg);
		nexterror();
	}
	k = convS2M(&f, msg, 8192+IOHDRSZ);
	if(k == 0)
		error("bad fversion conversion on send");
	if((l = writechan(c, msg, k, c->off)) < k)
		error("short write in fversion");
	if((k = readchan(c, msg, 8192+IOHDRSZ, c->off)) <= 0)
		error("EOF receiving fversion reply");
	l = convM2S(msg, k, &f);
	if(l != k)
		error("bad fversion conversion on reply");
	if(f.type != Rversion){
		/* ename is server-supplied text: never use it as a format */
		if(f.type == Rerror)
			error("%s", f.ename);
		error("unexpected reply type in fversion");
	}
	if(f.msize > (u32int)msize)
		error("server tries to increase msize in fversion");
	if(f.msize<256 || f.msize>1024*1024)
		error("nonsense value of msize in fversion");
	if(strncmp(f.version, version, strlen(f.version)) != 0)
		error("bad 9P version returned from server");
	/* m->version may already be set: mntattach calls this whenever msize is 0 */
	free(m->version);
	m->version = strdup(f.version);	/* f.version points into msg */
	m->msize = f.msize;
	poperror();
	free(msg);
}

/*
 * Attach to tree spec on mount m as user uname (no auth fid).  The version
 * is negotiated first if that has not been done.  Returns a new mount
 * channel whose fid refers to the root of the attached tree.
 */
Chan*
mntattach(Mnt *m, char *uname, char *spec)
{
	Chan *root;
	Mntrpc *rpc;

	if(m->version == nil || m->msize == 0)
		mntversion(m, nil, 0);

	root = mkmountchan(m);
	if(waserror()){
		freechan(root);
		nexterror();
	}
	rpc = mntralloc(0, m->msize);
	if(waserror()){
		mntfree(rpc);
		nexterror();
	}
	rpc->request.type = Tattach;
	rpc->request.fid = root->fid;
	rpc->request.afid = NOFID;
	rpc->request.uname = uname;
	rpc->request.aname = spec;
	mountrpc(m, rpc);
	root->qid = rpc->reply.qid;
	poperror();	/* rpc */
	mntfree(rpc);
	poperror();	/* root */
	return root;
}

/*
 * Return the mount behind c.  Raises an error unless c is a mount channel
 * whose mchan still matches its mount's channel.
 */
static Mnt*
mntchk(Chan *c)
{
	Mnt *m = c->mnt;

	if(m == nil || m->c != c->mchan)
		error("not a mnt chan");
	return m;
}

/*
 * Walk from c through the nname path elements in name and return a new
 * mount channel for the result.  nname == 0 clones c.  Partial walks are
 * reported as "file not found"; more than MAXWELEM elements are rejected.
 */
Chan*
mntwalk(Chan *c, char **name, int nname)
{
	Mnt *m;
	Chan *nc;
	Mntrpc *r;

	if(nname > MAXWELEM)
		error("too many name elements");
	m = mntchk(c);
	nc = mkmountchan(m);
	if(waserror()){
		freechan(nc);
		nexterror();
	}
	r = mntralloc(c, m->msize);
	if(waserror()) {
		mntfree(r);
		nexterror();
	}
	r->request.type = Twalk;
	r->request.fid = c->fid;
	r->request.newfid = nc->fid;
	r->request.nwname = nname;
	memmove(r->request.wname, name, nname*sizeof(char*));
	mountrpc(m, r);
	if(r->reply.nwqid > nname)
		error("too many QIDs returned by walk");
	if(r->reply.nwqid < nname)
		error("file not found");
	/*
	 * A zero-element walk only clones the fid and returns no qids, so
	 * wqid[nwqid-1] would index wqid[-1]; the clone names c's file.
	 */
	if(r->reply.nwqid == 0)
		nc->qid = c->qid;
	else
		nc->qid = r->reply.wqid[r->reply.nwqid-1];
	poperror(); /* r */
	mntfree(r);
	poperror(); /* nc */
	return nc;
}

/*
 * Send a fid-only request of type t (Tclunk or Tremove, per the callers)
 * for c's fid.  Errors from the server propagate to the caller.
 */
static void
mntclunk(Chan *c, int t)
{
	Mnt *m = mntchk(c);
	Mntrpc *rpc = mntralloc(c, m->msize);

	if(waserror()){
		mntfree(rpc);
		nexterror();
	}
	rpc->request.fid = c->fid;
	rpc->request.type = t;
	mountrpc(m, rpc);
	poperror();
	mntfree(rpc);
}

/*
 * Close hook for mount channels: clunk c's fid on a best-effort basis.
 * Any error raised by the clunk is swallowed.
 */
static void
mntclose(Chan *c)
{
	if(waserror())
		return;
	mntclunk(c, Tclunk);
	poperror();
}

/*
 * Ask the server to remove the file behind c's fid (a Tremove request).
 * Errors propagate to the caller.
 */
void
mntremove(Chan *c)
{
	int type = Tremove;

	mntclunk(c, type);
}

/*
 * Issue a Topen or Tcreate (type) on c's fid with mode omode.  For Tcreate,
 * name and perm are sent too.  On success c takes the returned qid, its
 * offset is reset to 0, and its iounit is the server's value limited to
 * m->msize-IOHDRSZ (a zero iounit also means that limit).  Returns c.
 */
static Chan*
mntopencreate(int type, Chan *c, char *name, int omode, ulong perm)
{
	Mnt *m;
	Mntrpc *rpc;

	m = mntchk(c);
	rpc = mntralloc(c, m->msize);
	if(waserror()){
		mntfree(rpc);
		nexterror();
	}
	rpc->request.fid = c->fid;
	rpc->request.type = type;
	rpc->request.mode = omode;
	if(type == Tcreate){
		rpc->request.name = name;
		rpc->request.perm = perm;
	}
	mountrpc(m, rpc);

	c->off = 0;
	c->qid = rpc->reply.qid;
	c->iounit = rpc->reply.iounit;
	if(c->iounit == 0 || c->iounit > m->msize-IOHDRSZ)
		c->iounit = m->msize-IOHDRSZ;	/* never more than one message holds */
	poperror();
	mntfree(rpc);
	return c;
}

/* Open c's fid with mode omode (Topen); returns c. */
Chan*
mntopen(Chan *c, int omode)
{
	Chan *oc;

	oc = mntopencreate(Topen, c, nil, omode, 0);
	return oc;
}

/*
 * Move up to n bytes between buf and c's file at offset off with Tread or
 * Twrite (type).  The transfer is split into requests of at most
 * m->msize-IOHDRSZ bytes, and it stops early when the server moves fewer
 * bytes than were asked for.  Returns the total number of bytes moved.
 */
long
mntrdwr(int type, Chan *c, void *buf, long n, vlong off)
{
	Mnt *m;
	Mntrpc *rpc;
	char *p;
	ulong total, want, got;

	m = mntchk(c);
	p = buf;
	total = 0;
	do {
		rpc = mntralloc(c, m->msize);
		if(waserror()){
			mntfree(rpc);
			nexterror();
		}
		want = n;
		if(want > m->msize-IOHDRSZ)
			want = m->msize-IOHDRSZ;
		rpc->request.type = type;
		rpc->request.fid = c->fid;
		rpc->request.offset = off;
		rpc->request.data = p;
		rpc->request.count = want;
		mountrpc(m, rpc);

		/* never trust the server to report more than was requested */
		want = rpc->request.count;
		got = rpc->reply.count;
		if(got > want)
			got = want;
		if(type == Tread)
			memmove(p, rpc->b, got);
		poperror();
		mntfree(rpc);

		off += got;
		p += got;
		total += got;
		n -= got;
	} while(got == want && n != 0);
	return total;
}

/* Channel write hook: write n bytes from buf at offset off via Twrite. */
static int
mntwrite(Chan *c, void *buf, int n, vlong off)
{
	long nw;

	nw = mntrdwr(Twrite, c, buf, n, off);
	return nw;
}
/* Channel read hook: read up to n bytes into buf at offset off via Tread. */
static int
mntread(Chan *c, void *buf, int n, vlong off)
{
	long nr;

	nr = mntrdwr(Tread, c, buf, n, off);
	return nr;
}
