#include "dat.h"
#include "fns.h"

static	int	findauth(Conn*, Auth**, AuthContext***);
static	char	*msgnameauth(int);

/*
 * Hook this file's message-name routine into the connection:
 * c->msgnameauth is pointed at the static msgnameauth below.
 */
void
authinit(Conn *c)
{
	c->msgnameauth = msgnameauth;
}

/*
 * Client side of the user authentication exchange.
 *
 * Requests the "ssh-userauth" service, rebuilds c->authlist from
 * c->okauth, and probes the peer with a "none" request whose reply is
 * handed to authgeneric.  While the outcome is <= UApartial, the next
 * method is picked by findauth and its init/run hooks are called; a
 * method whose outcome was UAfail is removed from c->authlist before
 * the next pick.
 *
 * Returns 1 with c->wirestate set to WSauth once the loop ends, or -1
 * (error string set via werrstr) when findauth finds nothing usable.
 */
int
authclient(Conn *c, Channel *rc, Channel *wc, char *service)
{
	int status;
	char *method;
	Msg *msg;
	Auth *auth;
	AuthContext **actx;

	debug(DbgAuth, "authclient: requesting ssh-auth service");
	servicereq(c, rc, wc, "ssh-userauth");
	debug(DbgAuth, "authclient: ssh-auth service requested");

	/* start from the full set of methods listed in c->okauth */
	freenamelist(c->authlist);
	c->authlist = impl2name(c->okauth);

	/* opening probe uses the "none" method */
	method = "none";
	msg = ssh_msg_userauthreq(c, service, method);
	sendmsg(wc, msg);
	debug(DbgAuth, "authclient: sent userauthreq '%s'", method);

	msg = authrecvmsg(c, rc, MSGIGN);
	/* again == 0: don't retry on failure */
	status = authgeneric(c, rc, wc, msg, 0);
	debug(DbgAuth, "authgeneric %d", status);

	for(;;){
		if(status > UApartial)
			break;

		/* a method that failed outright is dropped from Conn->authlist */
		if(status == UAfail)
			removename(c->authlist, method);
		debug(DbgAuth, "auth. can continue: %N", c->authlist);

		if(findauth(c, &auth, &actx) < 0){
			werrstr("no authentication methods can continue");
			return -1;
		}

		method = auth->name;
		debug(DbgAuth, "selected authentication method: %s", auth->name);
		if(auth->init(c, actx) >= 0)
			status = auth->run(c, rc, wc, *actx, service);
		else{
			status = UAfail;
			error(c, "authinit %s fails: %r", auth->name);
		}
		debug(DbgAuth, "%s --> %d", method, status);
	}

	assert(status > 0);
	c->wirestate = WSauth;
	return 1;
}

/*
 * Server side of user authentication.  Not implemented: the body
 * calls panic("not implemented"), and the statement after that
 * call returns -1.  None of the arguments are used.
 */
int
authserver(Conn *c, Channel *rc, Channel *wc)
{
	USED(c); USED(rc); USED(wc);

	panic("not implemented");
	return -1;
}

/*
 * Receive a message from rc (type is passed straight to recvmsg),
 * skipping over SSH_MSG_USERAUTH_BANNER messages: each banner is
 * freed and another receive is issued.  A banner seen while
 * c->role is RServer triggers a panic first.
 *
 * XXX: authrecvmsg is a filthy hack
 */
Msg*
authrecvmsg(Conn *c, Channel *rc, int type)
{
	Msg *msg;

	msg = recvmsg(rc, type);
	while(msg->type == SSH_MSG_USERAUTH_BANNER){
		if(c->role == RServer)
			panic("client sent SSH_MSG_USERAUTH_BANNER");
		freemsg(msg);
		msg = recvmsg(rc, type);
	}

	return msg;
}

/*
 * Interpret the generic reply m to an authentication request and
 * map it to one of the UA* result codes.  rc is unused.
 *
 * SSH_MSG_USERAUTH_SUCCESS:
 *	m is freed.  In the RClient role c->wirestate becomes WSauth;
 *	in any other role the connection is disconnected with
 *	SSH_DISCONNECT_PROTOCOL_ERROR and panic is called.
 *	Result is UAsuccess.
 *
 * SSH_MSG_USERAUTH_FAILURE:
 *	In the RServer role: disconnect and panic, as above.
 *	Otherwise c->authlist is intersected with the name-list
 *	carried in m, then the boolean that follows is read (the
 *	"partial success" flag of RFC 4252).  Result is UApartial if
 *	the flag is set, else UAagain if `again' is non-zero, else
 *	UAfail.  m is freed and the error string is set to
 *	"authentication failed".
 *
 * Any other type is passed to badmsg; the result stays UAfail.
 */
int
authgeneric(Conn *c, Channel *rc, Channel *wc, Msg *m, int again)
{
	int partial, rv;
	Namelist *nl;

	USED(rc);
	assert(m->c == c);

	rv = UAfail;
	switch(m->type){
	case SSH_MSG_USERAUTH_SUCCESS:
		freemsg(m);
		if(c->role == RClient)
			c->wirestate = WSauth;
		else{
			disconnect(c, wc, SSH_DISCONNECT_PROTOCOL_ERROR);
			panic("client sent SSH_MSG_USERAUTH_SUCCESS");
		}
		rv = UAsuccess;
		break;

	case SSH_MSG_USERAUTH_FAILURE:
		if(c->role == RServer){
			disconnect(c, wc, SSH_DISCONNECT_PROTOCOL_ERROR);
			panic("client sent SSH_MSG_USERAUTH_FAILURE");
		}
		/* restrict Conn->authlist to methods that can continue */
		nl = getnamelist(&m->b);
		/*
		 * NOTE(review): nl is not released in this function; confirm
		 * whether namelistintersect takes ownership of it or whether
		 * a freenamelist(nl) is missing here.
		 */
		namelistintersect(c->authlist, nl);
		partial = getbool(&m->b);
		if(partial)
			rv = UApartial;
		else if(again)
			rv = UAagain;
		freemsg(m);
		/* the value logged is the partial *success* flag */
		debug(DbgAuth, "authentication fails: partial success: %d", partial);
		werrstr("authentication failed");
		break;

	default:
		badmsg(m, "SSH_MSG_USERAUTH_SUCCESS");	/* XXX: I lie */
		break;
	}

	return rv;
}

/*
 * Pick the next authentication method to try.
 *
 * Walks c->okauth in order and, for each implementation, scans
 * c->authlist for an entry with the same name.  On the first match
 * *a is set to that implementation and *ctx to the address of the
 * c->authcontext slot with the same index; the result is 0.
 * With no match, *a and *ctx are left nil and the result is -1.
 *
 * XXX: see namelistintersect
 */
static int
findauth(Conn *c, Auth **a, AuthContext ***ctx)
{
	int ii, ni;
	Impllist *impls;
	Namelist *names;

	*a = nil;
	*ctx = nil;
	impls = c->okauth;
	names = c->authlist;

	for(ii = 0; ii < impls->nimpl; ii++){
		for(ni = 0; ni < names->nstr; ni++){
			if(strcmp(impls->impl[ii]->name, names->strtab[ni]) != 0)
				continue;
			*a = (Auth*)impls->impl[ii];
			*ctx = (AuthContext**)&c->authcontext[ii];
			return 0;
		}
	}

	debug(DbgAuth, "no shared authentication methods can continue");
	return -1;
}

/*
 * Message-name routine installed by authinit.  The type t is
 * ignored; the answer is always the string "<unknown>".
 */
static char*
msgnameauth(int t)
{
	char *name;

	USED(t);
	name = "<unknown>";
	return name;
}

/*
 * XXX: userauth state machine not implemented
 * 1. Both should disable password authentication if no cipher or no HMAC
 * 2. 3 possible return values
 * 3. "Authentication mechanisms that can continue" list processing
 */

/*
 * Returns a partially constructed SSH_MSG_USERAUTH_REQUEST.
 * The message is allocated with allocmsg and three strings are
 * appended in order: c->user, service, method.  Nothing beyond
 * those three fields is written here.
 */
Msg*
ssh_msg_userauthreq(Conn *c, char *service, char *method)
{
	Msg *req;

	req = allocmsg(c, SSH_MSG_USERAUTH_REQUEST, 0);

	putstring(&req->b, c->user);
	putstring(&req->b, service);
	putstring(&req->b, method);
	return req;
}

/*
 * Build an SSH_MSG_USERAUTH_FAILURE message: the name-list
 * authlist followed by the boolean psuccess (partial success).
 */
Msg*
ssh_msg_userauthfailure(Conn *c, Namelist *authlist, int psuccess)
{
	Msg *fail;

	fail = allocmsg(c, SSH_MSG_USERAUTH_FAILURE, 0);

	putnamelist(&fail->b, authlist);
	putbool(&fail->b, psuccess);
	return fail;
}

/*
 * Build an SSH_MSG_USERAUTH_SUCCESS message; nothing is appended
 * to the freshly allocated message.
 */
Msg*
ssh_msg_userauthsuccess(Conn *c)
{
	return allocmsg(c, SSH_MSG_USERAUTH_SUCCESS, 0);
}
