#include "libc.h"
#include "auth.h"

enum {
	ARgiveup = 100,
};

static struct {
	char *verb;
	int val;
} tab[] = {
	"ok",			ARok,
	"done",		ARdone,
	"error",		ARerror,
	"needkey",	ARneedkey,
	"badkey",		ARbadkey,
	"phase",		ARphase,
	"toosmall",	ARtoosmall,
	"error",		ARerror,
};

static char qsep[] = " \t\r\n";

int
utfnlen(char *s, long m)
{
	/*
	 * utfnlen: count the runes in at most the first m bytes of s.
	 * Counting stops early at a NUL byte, or at a multi-byte sequence
	 * that is cut off by the m-byte limit (that partial rune is not counted).
	 */
	char *end;
	long count;
	Rune ignored;
	int byte;

	end = s + m;
	count = 0;
	while(s < end){
		byte = *(uchar*)s;
		if(byte >= Runeself){
			/* multi-byte rune: make sure all of it lies inside the limit */
			if(!fullrune(s, end-s))
				break;
			s += chartorune(&ignored, s);
		}else if(byte == '\0'){
			break;
		}else{
			s++;
		}
		count++;
	}
	return count;
}

int
utflen(char *s)
{
	/*
	 * utflen: count the runes in the NUL-terminated UTF-8 string s.
	 * ASCII bytes advance one byte; anything else is decoded with
	 * chartorune to find its length.
	 */
	long runes;
	Rune r;
	int b;

	for(runes = 0; (b = *(uchar*)s) != 0; runes++){
		if(b >= Runeself)
			s += chartorune(&r, s);
		else
			s++;
	}
	return runes;
}

/*
 * qtoken: unquote, in place, the token that starts at s.
 *
 * Characters from sep end the token, except inside a single-quoted
 * section.  Quote characters are removed; inside a quoted section a
 * doubled quote ('') produces one literal quote.  The unquoted text is
 * written back starting at s and is always NUL-terminated.
 *
 * Returns the input position at which scanning for the next token
 * should resume.  This is either the separator or NUL that ended the
 * token, or one past a separator that was overwritten with NUL.
 */
static char*
qtoken(char *s, char *sep)
{
	int quoting;
	char *t;

	quoting = 0;
	t = s;	/* s is output string, t is input string */
	while(*t!='\0' && (quoting || utfrune(sep, *t)==nil)){
		if(*t != '\''){
			*s++ = *t++;
			continue;
		}
		/* *t is a quote */
		if(!quoting){
			quoting = 1;
			t++;
			continue;
		}
		/* quoting and we're on a quote */
		if(t[1] != '\''){
			/* end of quoted section; absorb closing quote */
			t++;
			quoting = 0;
			continue;
		}
		/* doubled quote; fold two quotes into one */
		t++;
		*s++ = *t++;
	}
	/*
	 * Terminate the output.  If nothing was unquoted, s and t coincide
	 * and s sits on the separator, so step t past the NUL just written.
	 * Otherwise *s is leftover input, and the separator at t is left
	 * for the caller to skip.
	 */
	if(*s != '\0'){
		*s = '\0';
		if(t == s)
			t++;
	}
	return t;
}

/*
 * etoken: find the end of the token that starts at t, without
 * modifying the string.
 *
 * Characters from sep end the token except inside single quotes.  A
 * doubled quote ('') inside a quoted section is skipped over as part of
 * the token.  Returns a pointer to the separator or NUL that ends the
 * token.
 */
static char*
etoken(char *t, char *sep)
{
	int quoting;

	/* move to end of next token */
	quoting = 0;
	while(*t!='\0' && (quoting || utfrune(sep, *t)==nil)){
		if(*t != '\''){
			t++;
			continue;
		}
		/* *t is a quote */
		if(!quoting){
			quoting = 1;
			t++;
			continue;
		}
		/* quoting and we're on a quote */
		if(t[1] != '\''){
			/* end of quoted section; absorb closing quote */
			t++;
			quoting = 0;
			continue;
		}
		/* doubled quote; skip both characters (nothing is folded here) */
		t += 2;
	}
	return t;
}

int
gettokens(char *s, char **args, int maxargs, char *sep)
{
	/*
	 * gettokens: split s in place into at most maxargs tokens separated
	 * by characters from sep.  Quotes protect separators but are NOT
	 * removed.  Separators before each token are overwritten with NUL,
	 * which terminates the preceding token.  Once maxargs tokens have
	 * been found, the last one keeps the rest of the string.
	 * Returns the number of tokens stored in args.
	 */
	int count;

	count = 0;
	while(count < maxargs){
		for(; *s != '\0' && utfrune(sep, *s) != nil; s++)
			*s = '\0';
		if(*s == '\0')
			break;
		args[count++] = s;
		s = etoken(s, sep);
	}
	return count;
}

int
tokenize(char *s, char **args, int maxargs)
{
	/*
	 * tokenize: split s in place into at most maxargs tokens separated by
	 * the whitespace in qsep.  Single-quote quoting is removed from each
	 * token by qtoken, which also NUL-terminates it.
	 * Returns the number of tokens stored in args.
	 */
	int count;

	count = 0;
	while(count < maxargs){
		/* skip separators between tokens */
		for(; *s != '\0' && utfrune(qsep, *s) != nil; s++)
			;
		if(*s == '\0')
			break;
		args[count] = s;
		s = qtoken(s, qsep);
		count++;
	}
	return count;
}

/*
 * getfields: split str in place into at most max fields separated by any
 * rune in set.  Stores pointers to the fields in args and returns how many
 * were stored.
 *
 * mflag == 0: every separator delimits a field, so adjacent separators
 *   produce empty fields.
 * mflag != 0: leading separators are skipped and runs of separators count
 *   as one, so no empty fields are produced.
 *
 * Each separator that is consumed is overwritten with NUL; for a
 * multi-byte separator only its first byte is overwritten.  Once max
 * fields exist, the next separator stops the scan, so the last field
 * keeps the rest of the string.
 */
int
getfields(char *str, char **args, int max, int mflag, char *set)
{
	Rune r;
	int nr, intok, narg;

	if(max <= 0)
		return 0;

	narg = 0;
	args[narg] = str;
	if(!mflag)
		narg++;
	intok = 0;	/* mflag mode: currently inside a (counted) field */
	for(;; str += nr) {
		nr = chartorune(&r, str);
		if(r == 0)
			break;
		if(utfrune(set, r)) {
			if(narg >= max)
				break;
			*str = 0;
			intok = 0;
			/* tentative start of the next field; in mflag mode it is only
			 * counted once a non-separator rune is seen */
			args[narg] = str + nr;
			if(!mflag)
				narg++;
		} else {
			if(!intok && mflag)
				narg++;
			intok = 1;
		}
	}
	return narg;
}

/*
 * classify: match the factotum reply in buf (n bytes, NUL-terminated at
 * buf[n] by auth_rpc) against the verbs in tab.
 *
 * A verb matches when it is the whole reply, or when it is followed by a
 * space.  On a match, rpc->arg and rpc->narg describe the text after the
 * space (an empty string when there is none), and the verb's AR* code is
 * returned.  Otherwise the error string is set and ARrpcfailure is returned.
 */
static int
classify(char *buf, uint n, AuthRpc *rpc)
{
	int i, len;

	for(i=0; i<nelem(tab); i++){
		len = strlen(tab[i].verb);
		if(n >= len && memcmp(buf, tab[i].verb, len) == 0 && (n==len || buf[len]==' ')){
			if(n==len){
				/*
				 * Point at the NUL that terminates buf, not at a string
				 * literal.  Callers may write through rpc->arg
				 * (auth_getuserpasswd stores arg[narg] = '\0'), and a
				 * literal may live in read-only memory.
				 */
				rpc->narg = 0;
				rpc->arg = buf+len;
			}else{
				rpc->narg = n - (len+1);
				rpc->arg = buf+len+1;
			}
			return tab[i].val;
		}
	}
	werrstr("malformed rpc response: %s", buf);
	return ARrpcfailure;
}

AuthRpc*
auth_allocrpc(int afd)
{
	/*
	 * auth_allocrpc: allocate a zero-filled AuthRpc that talks over the
	 * already-open descriptor afd.  Returns nil if allocation fails.
	 */
	AuthRpc *r;

	if((r = mallocz(sizeof *r, 1)) != nil)
		r->afd = afd;
	return r;
}

/*
 * auth_freerpc: release an AuthRpc obtained from auth_allocrpc.
 *
 * The request/reply buffers may still hold secrets; for example,
 * auth_getuserpasswd reads a password into ibuf.  The structure is
 * therefore cleared before it is freed.  The stores go through a
 * volatile pointer so the compiler cannot drop them as dead writes to
 * memory about to be freed.  A nil rpc is ignored.
 */
void
auth_freerpc(AuthRpc *rpc)
{
	volatile char *p;
	unsigned long i;

	if(!rpc)
		return;
	p = (volatile char*)rpc;
	for(i = 0; i < sizeof *rpc; i++)
		p[i] = 0;
	free(rpc);
}

/*
 * auth_rpc: send one request to factotum over rpc->afd and read the reply.
 *
 * The request is verb, a space, then na bytes from a (a may be nil when
 * na is 0).  The reply is stored NUL-terminated in rpc->ibuf, and classify
 * sets rpc->arg/rpc->narg to its argument text.
 *
 * Returns the reply's AR* code: ARtoobig if the request does not fit in
 * AuthRpcMax bytes, ARrpcfailure on I/O or parse errors.  The error string
 * is set to something useful for each non-ok outcome.
 */
uint
auth_rpc(AuthRpc *rpc, char *verb, void *a, int na)
{
	int l, n, type;
	char *f[4];

	l = strlen(verb);
	if(na+l+1 > AuthRpcMax){
		werrstr("rpc too big");
		return ARtoobig;
	}

	memmove(rpc->obuf, verb, l);
	rpc->obuf[l] = ' ';
	if(na > 0)	/* a is nil for e.g. "read"; memmove from nil is undefined even for 0 bytes */
		memmove(rpc->obuf+l+1, a, na);
	if((n=write(rpc->afd, rpc->obuf, l+1+na)) != l+1+na){
		if(n >= 0)
			werrstr("auth_rpc short write");
		return ARrpcfailure;
	}

	/*
	 * Leave room for the terminating NUL.  ibuf is an array inside
	 * AuthRpc (auth_allocrpc never allocates it separately), so sizeof
	 * gives its real capacity whatever auth.h declares.
	 */
	if((n=read(rpc->afd, rpc->ibuf, sizeof rpc->ibuf - 1)) < 0)
		return ARrpcfailure;
	rpc->ibuf[n] = '\0';

	/*
	 * Set error string for good default behavior.
	 */
	switch(type = classify(rpc->ibuf, n, rpc)){
	default:
		werrstr("unknown rpc type %d (bug in auth_rpc.c)", type);
		break;
	case ARok:
	case ARdone:
		/* successful replies: leave the error string alone */
		break;
	case ARrpcfailure:
		/* classify already set the error string */
		break;
	case ARerror:
		if(rpc->narg == 0)
			werrstr("unspecified rpc error");
		else
			werrstr("%s", rpc->arg);
		break;
	case ARneedkey:
		werrstr("needkey %s", rpc->arg);
		break;
	case ARbadkey:
		/* getfields splits rpc->arg in place; use the second line when present */
		if(getfields(rpc->arg, f, nelem(f), 0, "\n") < 2)
			werrstr("badkey %s", rpc->arg);
		else
			werrstr("badkey %s", f[1]);
		break;
	case ARphase:
		werrstr("phase error %s", rpc->arg);
		break;
	case ARtoosmall:
		werrstr("toosmall %s", rpc->arg);
		break;
	}
	return type;
}

static int
dorpc(AuthRpc *rpc, char *verb, char *val, int len, AuthGetkey *getkey)
{
	/*
	 * dorpc: issue an rpc, and while factotum answers needkey or badkey,
	 * ask getkey (with factotum's key description in rpc->arg) to supply
	 * a key, then retry.  Returns the final reply code, or ARgiveup when
	 * there is no getkey or it reports failure.
	 */
	int r;

	while((r = auth_rpc(rpc, verb, val, len)) == ARneedkey || r == ARbadkey){
		if(getkey == nil || (*getkey)(rpc->arg) < 0)
			return ARgiveup;
	}
	return r;
}

/*
 * auth_getuserpasswd: ask factotum (/mnt/factotum/rpc) for a user name and
 * password matching the attributes formatted from fmt.
 *
 * getkey, if non-nil, is called when factotum needs a key (see dorpc).
 * Returns a single malloc'd block holding the UserPasswd followed by the
 * user and passwd strings (free the returned pointer only), or nil on
 * failure with the error string set where a callee set it.
 */
UserPasswd*
auth_getuserpasswd(AuthGetkey *getkey, char *fmt, ...)
{
	AuthRpc *rpc;
	char *f[3], *p, *params;
	int fd;
	va_list arg;
	UserPasswd *up;

	up = nil;
	rpc = nil;
	params = nil;

	fd = open("/mnt/factotum/rpc", ORDWR);
	if(fd < 0)
		goto out;
	rpc = auth_allocrpc(fd);
	if(rpc == nil)
		goto out;
	quotefmtinstall();	/* just in case */
	va_start(arg, fmt);
	params = vsmprint(fmt, arg);
	va_end(arg);
	if(params == nil)
		goto out;

	if(dorpc(rpc, "start", params, strlen(params), getkey) != ARok
	|| dorpc(rpc, "read", nil, 0, getkey) != ARok)
		goto out;

	/*
	 * Terminate the argument text.  When narg is 0, arg need not point
	 * into ibuf (it may be a read-only empty string), so don't write
	 * through it; it is already an empty string.
	 */
	if(rpc->narg > 0)
		rpc->arg[rpc->narg] = '\0';
	if(tokenize(rpc->arg, f, 2) != 2){
		werrstr("bad answer from factotum");
		goto out;
	}
	/*
	 * Unquoting never lengthens the text, and the separator between the
	 * two tokens supplies the byte for the extra NUL, so narg+1 bytes
	 * hold both strings.
	 */
	up = malloc(sizeof(*up)+rpc->narg+1);
	if(up == nil)
		goto out;
	p = (char*)&up[1];
	strcpy(p, f[0]);
	up->user = p;
	p += strlen(p)+1;
	strcpy(p, f[1]);
	up->passwd = p;

out:
	free(params);
	auth_freerpc(rpc);
	/* only close a descriptor we actually opened; close(-1) would overwrite the open error */
	if(fd >= 0)
		close(fd);
	return up;
}

/*
 * auth_getkey: AuthGetkey callback that runs factotum in getkey mode
 * ("getkey -g params") to obtain the key described by params.
 *
 * Looks for /factotum, then /boot/factotum.  Returns 0 if the child
 * exited with status 0, otherwise -1 with the error string set where
 * possible.
 */
int
auth_getkey(char *params)
{
	char *name;
	Dir *d;
	int pid;
	int w, status;

	/* start /factotum to query for a key */
	name = "/factotum";
	d = dirstat(name);
	if(d == nil){
		name = "/boot/factotum";
		d = dirstat(name);
	}
	if(d == nil){
		werrstr("auth_getkey: no /factotum or /boot/factotum: didn't get key %s", params);
		return -1;
	}
	free(d);	/* only the existence check was needed */

	switch(pid = fork()){
	case -1:
		werrstr("can't fork for %s: %r", name);
		return -1;
	case 0:
		execl(name, "getkey", "-g", params, nil);
		/*
		 * exec failed: exit with a failure status.  Exiting 0 here would
		 * make the parent report success, and dorpc would loop forever.
		 */
		exits("exec");
	default:
		for(;;){
			w = wait(&status);
			if(w < 0){
				/* no child left to wait for; looping would never terminate */
				werrstr("auth_getkey: wait for %s: %r", name);
				return -1;
			}
			if(w == pid)
				return status == 0 ? 0 : -1;
			/* some other child was reaped; keep waiting for ours */
		}
	}
	return 0;
}

/*
 * getpass: fetch from factotum the ssh2 password for user on host.
 *
 * Returns a newly malloc'd copy of the password that the caller may free,
 * or nil on failure.  The UserPasswd used to fetch it is wiped and freed
 * here.  Previously the returned pointer aimed into the middle of that
 * allocation, so it could never be released.
 */
char *
getpass(char *host, char *user)
{
	UserPasswd *up;
	volatile char *v;
	char *pw;
	long n;

	up = auth_getuserpasswd(auth_getkey, "proto=pass service=ssh2 server=%q user=%q", host, user);
	if(up == nil)
		return nil;
	n = strlen(up->passwd);
	pw = malloc(n+1);
	if(pw != nil)
		memmove(pw, up->passwd, n+1);
	/* clear the plaintext copy before freeing; volatile keeps the stores */
	for(v = up->passwd; *v != '\0'; v++)
		*v = 0;
	free(up);
	return pw;
}
