#include <u.h>
#include <libc.h>
#include <auth.h>
#include <fcall.h>
#include "dat.h"
#include "fns.h"

/*
 * msize returned by the server on the first successful xversion;
 * zero until then.  Later xversion calls (e.g. on reconnect) are
 * fatal if the server offers a smaller msize than this.
 */
ulong currmsize;

/*
 * Initial connection setup.  Assumes this is the only traffic
 * on the connection, so no need to multiplex or worry about
 * threads.
 */

/*
 * Run one 9P transaction on netfd: marshal *r, send it, read a single
 * reply and check that its tag and type match the request.
 *
 * errok controls Rerror handling: when zero, an Rerror reply yields -1
 * with the server's message in errstr; when non-zero, the Rerror is
 * copied into *r and 0 is returned so the caller can test r->type.
 *
 * On success *r is overwritten with the reply.  Pointer fields in the
 * reply (data, ename, version, ...) refer to a static buffer and are
 * only valid until the next call.
 *
 * Returns -1 (errstr set) on write/read failure, EOF, a malformed
 * reply, or a tag/type mismatch.  Marshalling failure is fatal.
 */
int
transaction(Fcall *r, int errok)
{
	static uchar buf[MAXPKT];
	int n;
	Fcall f;

	n = convS2M(r, buf, MAXPKT);
	if(n <= BIT32SZ)
		sysfatal("convS2M: %r");

	chat("net <- %F...\n", r);
	/* if the request never went out, waiting for a reply would hang or mislead */
	if(write(netfd, buf, n) != n)
		return -1;

	if((n = netread9pmsg(netfd, buf, MAXPKT)) <= 0){
		if(n == 0)
			werrstr("network eof");
		return -1;
	}
	if(convM2S(buf, n, &f) != n){
		werrstr("convM2S: malformed 9P reply");
		return -1;
	}
	chat("net -> %F...\n", &f);

	if(f.tag != r->tag){
		werrstr("unexpected tag %ud != %ud", f.tag, r->tag);
		return -1;
	}
	if(f.type == Rerror && !errok){
		werrstr("server: %s", f.ename);
		return -1;
	}
	/* every 9P R-message type is its T-message type plus one */
	if(f.type != Rerror && f.type != r->type+1){
		werrstr("unexpected type %ud != %ud+1", f.type, r->type);
		return -1;
	}
	*r = f;
	return 0;
}

/*
 * Negotiate protocol version and message size with the server.
 * Requests "9P2000" with msize MAXPKT.  The first successful call
 * records the server's msize in currmsize; later calls must not see
 * a smaller one.  A larger-than-requested msize, a shrunken msize or
 * a different protocol version is fatal.
 * Returns 0 on success, -1 if the transaction fails.
 */
int
xversion(void)
{
	Fcall tx;

	memset(&tx, 0, sizeof tx);
	tx.type = Tversion;
	tx.tag = NOTAG;
	tx.msize = MAXPKT;
	tx.version = "9P2000";
	if(transaction(&tx, 0) < 0)
		return -1;

	if(tx.msize > MAXPKT)
		sysfatal("server msize %ud > requested msize %ud", tx.msize, MAXPKT);

	if(currmsize == 0)
		currmsize = tx.msize;	/* first connection: remember it */
	else if(currmsize > tx.msize)
		sysfatal("server reduced msize on reconnect - was %lud now %ud", currmsize, tx.msize);

	if(strcmp(tx.version, "9P2000"))
		sysfatal("server wants to speak %s", tx.version);
	return 0;
}

/*
 * Clunk fid's remote fid on the server.
 * Returns 0 on success, -1 on error (an Rerror counts as an error).
 */
static int
xclunk(Fid *fid)
{
	Fcall req;

	memset(&req, 0, sizeof req);
	req.fid = fid->fid.remote;
	req.type = Tclunk;
	return transaction(&req, 0);
}

/*
 * Read up to n bytes from fid's remote fid into buf.  The request
 * carries no offset, so offset 0 is used.
 *
 * Returns the number of bytes the server delivered, or -1 on error.
 * A negative n, or a reply claiming more than n bytes, is an error:
 * the server's count must never be trusted to fit buf.
 */
static int
xread(Fid *fid, void *buf, int n)
{
	Fcall f;

	if(n < 0){
		werrstr("xread: negative count %d", n);
		return -1;
	}

	memset(&f, 0, sizeof f);
	f.type = Tread;
	f.fid = fid->fid.remote;
	f.count = n;

	if(transaction(&f, 0) < 0)
		return -1;

	if(f.count > n){
		werrstr("xread: server returned %ud bytes, asked for %d", f.count, n);
		return -1;
	}
	memmove(buf, f.data, f.count);
	return f.count;
}

/*
 * Write n bytes from buf to fid's remote fid.  The request carries no
 * offset, so offset 0 is used.
 * Returns the count the server reports as written, or -1 on error.
 */
static int
xwrite(Fid *fid, void *buf, int n)
{
	Fcall req;

	memset(&req, 0, sizeof req);
	req.type = Twrite;
	req.fid = fid->fid.remote;
	req.data = buf;
	req.count = n;

	if(transaction(&req, 0) < 0)
		return -1;
	return req.count;
}

/*
 * Send Tauth for uname/aname using fid's remote fid as the afid.
 * Returns 1 if the server accepted (authentication is required),
 * 0 if it answered Rerror, and -1 on transport or protocol failure.
 */
static int
xauth(Fid *fid, char *uname, char *aname)
{
	Fcall req;

	memset(&req, 0, sizeof req);
	req.type = Tauth;
	req.afid = fid->fid.remote;
	req.uname = uname;
	req.aname = aname;

	/* errok: an Rerror is an answer here, not a failure */
	if(transaction(&req, 1) < 0)
		return -1;
	return req.type != Rerror;
}

/* dorpc result: a key was needed but could not (or would not) be obtained */
enum { ARgiveup = 100 };

/*
 * Issue a factotum RPC.  While factotum answers ARneedkey or ARbadkey,
 * ask getkey to supply a key (passing factotum's reply text) and retry.
 * Returns factotum's final answer, or ARgiveup when there is no getkey
 * or getkey fails.
 */
static int
dorpc(AuthRpc *rpc, char *verb, char *val, int len, AuthGetkey *getkey)
{
	int r;

	while((r = auth_rpc(rpc, verb, val, len)) == ARneedkey || r == ARbadkey){
		if(getkey == nil || (*getkey)(rpc->arg) < 0)
			return ARgiveup;	/* can't get a key, or user punted */
	}
	return r;
}

/*
 * Proxy an authentication conversation between factotum (rpc) and the
 * server (the auth fid afid); this just relays what factotum says.
 * params is the factotum "start" string.
 *
 * Each round asks factotum to "read":
 *   ARok    - factotum produced a message; write it to afid.
 *   ARphase - factotum wants input; read from afid and "write" it back,
 *             reading more while factotum answers ARtoosmall with the
 *             total size it needs.
 *   ARdone  - finished; return factotum's AuthInfo.
 *
 * On success the caller's errstr is restored to its value on entry.
 * Returns nil with errstr set on failure.
 */
static AuthInfo*
xfauthproxy(Fid *afid, AuthRpc *rpc, AuthGetkey *getkey, char *params)
{
	char *buf;
	int m, n, need, ret;
	AuthInfo *a;
	char oerr[ERRMAX];

	rerrstr(oerr, sizeof oerr);
	werrstr("UNKNOWN AUTH ERROR");

	if(dorpc(rpc, "start", params, strlen(params), getkey) != ARok){
		werrstr("xfauth_proxy start: %r");
		return nil;
	}

	/* one spare byte keeps buf NUL-terminated for the %s in error messages */
	buf = malloc(AuthRpcMax+1);
	if(buf == nil)
		return nil;
	for(;;){
		switch(dorpc(rpc, "read", nil, 0, getkey)){
		case ARdone:
			free(buf);
			a = auth_getinfo(rpc);
			errstr(oerr, sizeof oerr);	/* no error, restore whatever was there */
			return a;
		case ARok:
			if(xwrite(afid, rpc->arg, rpc->narg) != rpc->narg){
				werrstr("auth_proxy write fd: %r");
				goto Error;
			}
			break;
		case ARphase:
			n = 0;
			memset(buf, 0, AuthRpcMax+1);
			while((ret = dorpc(rpc, "write", buf, n, getkey)) == ARtoosmall){
				/*
				 * factotum reports the total size it needs; it must
				 * exceed what we already hold (else the read count
				 * would be zero or negative) and fit in buf.
				 */
				need = atoi(rpc->arg);
				if(need <= n || need > AuthRpcMax)
					break;
				m = xread(afid, buf+n, need-n);
				if(m <= 0){
					if(m == 0)
						werrstr("auth_proxy short read: %s", buf);
					goto Error;
				}
				n += m;
			}
			if(ret != ARok){
				werrstr("auth_proxy rpc write: %s: %r", buf);
				goto Error;
			}
			break;
		default:
			werrstr("auth_proxy rpc: %r");
			goto Error;
		}
	}
Error:
	free(buf);
	return nil;
}

/*
 * Authenticate over authfid using a freshly allocated factotum RPC.
 * fmt and the following arguments build the factotum start parameters.
 * Returns the resulting AuthInfo (the caller releases it with
 * auth_freeAI) or nil on failure.
 */
static AuthInfo*
xauthproxy(Fid *authfid, AuthGetkey *getkey, char *fmt, ...)
{
	char *p;
	va_list arg;
	AuthInfo *ai;
	AuthRpc *rpc;

	va_start(arg, fmt);
	p = vsmprint(fmt, arg);
	va_end(arg);
	if(p == nil)
		return nil;	/* out of memory; xfauthproxy would strlen(nil) */

	if((rpc = auth_allocrpc_wrap()) == nil){
		free(p);
		return nil;
	}

	ai = xfauthproxy(authfid, rpc, getkey, p);
	free(p);
	auth_freerpc(rpc);
	return ai;
}

/*
 * Attach fid to the server's tree aname as uname, authenticated by afid
 * (NOFID when afid is nil).  Stores the root qid in *q.
 * Returns 0 on success, -1 on error.
 */
static int
xattach(Fid *fid, Fid *afid, char *uname, char *aname, Qid *q)
{
	Fcall req;

	memset(&req, 0, sizeof req);
	req.type = Tattach;
	req.fid = fid->fid.remote;
	if(afid == nil)
		req.afid = NOFID;
	else
		req.afid = afid->fid.remote;
	req.uname = uname;
	req.aname = aname;

	if(transaction(&req, 0) < 0)
		return -1;
	*q = req.qid;
	return 0;
}

/*
 * Attach a->rootfid to the server's tree a->aname as eve.  If the
 * server accepts Tauth, first run a p9any client conversation over a
 * new auth fid via factotum.  On success the root qid is stored in
 * a->rootqid and 0 is returned; -1 on failure.  Any auth fid is
 * released before returning.
 */
int
authattach(Attach *a)
{
	AuthInfo *ai;
	Fid *authfid;
	int r;

	authfid = allocfid(fidgen++);

	switch(xauth(authfid, eve, a->aname)){
	case -1:
		/* Tauth failed outright; the fid was never set up remotely, just free it */
		freefid(authfid);
		return -1;
	case 0:
		/* Rerror: server needs no authentication, attach with NOFID */
		freefid(authfid);
		authfid = nil;
		break;
	case 1:
		ai = xauthproxy(authfid, auth_getkey, "proto=p9any role=client");
		if(ai == nil){
			xclunk(authfid);
			freefid(authfid);
			return -1;
		}
		auth_freeAI(ai);
		break;
	}

	r = xattach(a->rootfid, authfid, eve, a->aname, &a->rootqid);

	/* the auth fid is no longer needed whether or not the attach worked */
	if(authfid != nil){
		xclunk(authfid);
		freefid(authfid);
	}
	return r < 0 ? -1 : 0;
}

