#include "u.h"
#include "libc.h"
#include "dat.h"
#include "fns.h"

/* Service name: used for the service registration, the syslog tag and the registry path. */
#define PROGNAME "wcpu"
/* Registry key (under HKEY_LOCAL_MACHINE) that holds this service's configuration values. */
#define CONFBASE "SYSTEM\\CurrentControlSet\\Services\\" PROGNAME "\\Config"

/* Port the server listens on (passed to tcpserver). */
enum { PORT_NCPU = 17010 };

/* Debug flag; not referenced in this part of the file (presumably read elsewhere). */
int debug;

void
setconfstr(char *key, char *str)
{
	HKEY hk;

	if(RegOpenKey(HKEY_LOCAL_MACHINE, CONFBASE, &hk))
		if(RegCreateKey(HKEY_LOCAL_MACHINE, CONFBASE, &hk))
			return;
	if(key && str){
		RegSetValueEx(hk, key, 0, REG_SZ, str, strlen(str)+1);
	}
	RegCloseKey(hk);
}

char*
getconfstr(char *key, char *buf, int nbuf)
{
	HKEY hk;
	DWORD t;
	DWORD n;

	if(RegOpenKey(HKEY_LOCAL_MACHINE, CONFBASE, &hk))
		error("cant open registry: %x", GetLastError());
	if(waserror()){
		RegCloseKey(hk);
		nexterror();
	}
	n = nbuf-1;
	if(RegQueryValueEx(hk, key, nil, &t, buf, &n))
		error("cant read registry key %s: %x", key, GetLastError());
	if(t != REG_SZ)
		n = 0;
	buf[n] = 0;
	poperror();
	RegCloseKey(hk);
	return buf;
}

/*
 * writestr writes str, including its terminating NUL, to ch at ch->off.
 * If the write raises an error, the error is swallowed when ignore is
 * non-zero and re-raised to the caller otherwise.
 * thing (presumably a label for error messages) is currently unused.
 */
void
writestr(Chan *ch, char *str, char *thing, int ignore)
{
	if(!waserror()){
		writechan(ch, str, strlen(str)+1, ch->off);
		poperror();
		return;
	}
	if(!ignore)
		nexterror();
}

/*
 * readstr reads a NUL-terminated string from ch into str one byte at a
 * time, issuing each read at ch->off.  str has room for len bytes.
 *
 * Returns 0 once the terminating NUL has been stored.  Returns -1 if a
 * read fails or returns no data before the NUL arrives, or if the NUL does
 * not fit in len bytes (str is then not guaranteed to be terminated).
 */
int
readstr(Chan *ch, char *str, int len)
{
	int n;

	/* len > 0 rather than len != 0: a negative len must not walk off the buffer. */
	for(; len > 0; len--, str++){
		n = readchan(ch, str, 1, ch->off);
		/*
		 * A zero-length read (presumably end of file) leaves *str untouched;
		 * treat it as failure instead of testing stale buffer contents.
		 */
		if(n <= 0)
			return -1;
		if(*str == 0)
			return 0;
	}
	return -1;
}

/*
 * client serves one incoming connection; aux is that connection's Chan
 * (server passes client to tcpserver).  The exchange implemented below:
 *   1. read a NUL-terminated "authmethod[ args]" string and authenticate
 *      with doauth, which fills in user and a token handle (utok);
 *   2. read a NUL-terminated string that is either a directory or
 *      "!command", in which case a directory string follows; the command
 *      defaults to cmd.exe;
 *   3. send "FS" and "/" (each with its NUL), expect "OK" back, then mount
 *      the tree served over the connection and attach as user;
 *   4. open the remote dev/cons (required) and dev/cpunote (optional) and
 *      run the command with shell(), using cons for all three standard
 *      streams.
 * Errors are logged with syslog and re-raised; each waserror() below
 * releases the resource acquired just before it.
 *
 * NOTE(review): ch, m and note are assigned after the waserror() whose
 * handler uses them.  If waserror() is setjmp-based, such non-volatile
 * locals may hold indeterminate values after the longjmp — confirm.
 */
static void
client(void *aux)
{
	char buf[MaxStr], xdir[MaxStr], cmd[MaxStr], user[MaxStr], *s, *name[8];
	HANDLE utok;
	Chan *ch, *cons, *root, *note;
	Mnt *m;

	m = nil;
	ch = aux;
	if(waserror()){
		syslog(1, PROGNAME, "client: %s", errorstr());
		freechan(ch);
		freemount(m);
		nexterror();
	}
	/* Authentication: "method" or "method args"; s stays nil when there is no space. */
	if(readstr(ch, cmd, sizeof(cmd)) < 0)
		error("readstr authmethod");
	if(s = strchr(cmd, ' '))
		*s++ = 0;
	utok = 0;
	/* doauth returns the channel to use from here on; it replaces ch. */
	ch = doauth(ch, cmd, s, user, sizeof(user), &utok);
	if(waserror()){
		CloseHandle(utok);
		nexterror();
	}
	/* NOTE(review): %x with a HANDLE may truncate on 64-bit builds — confirm. */
	syslog(1, PROGNAME, "client: auth ok user=%s token=%x", user, utok);

	/*
	 * Command selection: "!command" followed by the directory, or just the
	 * directory (the command then defaults to cmd.exe).
	 * NOTE(review): xdir is never used after it is read — confirm whether it
	 * should be passed on as the shell's working directory.
	 */
	cmd[0] = 0;
	if(readstr(ch, xdir, sizeof(xdir)) < 0)
		error("readstr dir/cmd");
	if(xdir[0] == '!') {
		for(s = &xdir[1]; *s==' '; s++)
				;
		strncpy(cmd, s, sizeof(cmd));
		if(readstr(ch, xdir, sizeof(xdir)) < 0)
			error("readstr dir");
	} else {
		strncpy(cmd, "cmd.exe", sizeof(cmd));
	}

	/* Ask the peer to serve its "/" tree over this channel; it must answer "OK". */
	writestr(ch, "FS", "FS", 0);
	writestr(ch, "/", "exportfs dir", 0);
	/* NOTE(review): this read uses offset 0, unlike the ch->off used elsewhere — confirm. */
	if(readchan(ch, buf, sizeof(buf), 0) != 2 || buf[0] != 'O' || buf[1] != 'K')
		error("remote tree");

	/* The mount presumably owns the channel now; clear ch so the outer handler does not free it as well. */
	m = mkmount(ch); ch = nil;
	root = mntattach(m, user, nil);
	if(waserror()){
		freechan(root);
		nexterror();
	}

	/* note must be nil before the handler below can run, because that handler frees it. */
	note = nil;
	name[0] = "dev";
	name[1] = "cons";
	cons = mntwalk(root, name, 2);
	if(waserror()){
		freechan(note);
		freechan(cons);
		nexterror();
	}
	mntopen(cons, 2);
	/* dev/cpunote is optional: on any error, drop it and continue with note == nil. */
	if(waserror()){
		freechan(note);
		note = nil;
	} else {
		name[0] = "dev";
		name[1] = "cpunote";
		note = mntwalk(root, name, 2);
		mntopen(note, 0);
		poperror();
	}

	syslog(1, PROGNAME, "client: shell %s", cmd);
	/* If the shell fails, try to report the error on the remote console before re-raising. */
	if(waserror()){
		if(!waserror()){
			chanprint(cons, PROGNAME ": %s\n", errorstr());
			poperror();
		}
		nexterror();
	}
	shell(cons, cons, cons, note, cmd, user, utok);
	poperror();
	syslog(1, PROGNAME, "client: shell %s terminated", cmd);

	/* Normal exit: pop the remaining handlers innermost first, releasing each resource. */
	poperror();
	freechan(note);
	freechan(cons);

	poperror();
	freechan(root);

	poperror();
	CloseHandle(utok);

	poperror();
	freemount(m);
}

/* Status handle returned by RegisterServiceCtrlHandler in svcmain. */
static SERVICE_STATUS_HANDLE hsvcstate;
/* Status last reported to the service manager; updated by svcmain and svcctl. */
static SERVICE_STATUS svcstate;
/*
 * Event created in threadmain.  svcctl signals it on stop/shutdown and
 * server signals it when it exits; svcmain waits on it.
 */
static HANDLE hsvcshutdown;

/*
 * svcctl is the control handler registered with the service manager in
 * svcmain.  A stop or shutdown request reports SERVICE_STOP_PENDING and
 * signals hsvcshutdown so svcmain can finish.  An interrogate request
 * re-reports the current status.  All other controls are ignored.
 */
static void WINAPI
svcctl(DWORD req)
{
	if(req == SERVICE_CONTROL_STOP || req == SERVICE_CONTROL_SHUTDOWN){
		svcstate.dwCurrentState = SERVICE_STOP_PENDING;
		SetServiceStatus(hsvcstate, &svcstate);
		SetEvent(hsvcshutdown);
	}else if(req == SERVICE_CONTROL_INTERROGATE)
		SetServiceStatus(hsvcstate, &svcstate);
}
/*
 * svcmain is the ServiceMain entry point listed in threadmain's service
 * table.  It registers svcctl as the control handler and reports the
 * service as running.  It then blocks until hsvcshutdown is signalled
 * (by svcctl on stop/shutdown, or by server when it exits) and reports
 * the service as stopped.
 *
 * The parameter types match LPSERVICE_MAIN_FUNCTIONA (DWORD, LPSTR*).
 * argc and argv are unused.
 */
static VOID WINAPI
svcmain(DWORD argc, LPSTR *argv)
{
	hsvcstate = RegisterServiceCtrlHandler(PROGNAME, svcctl);
	if(hsvcstate == 0){
		/* Without a status handle nothing can be reported to the service manager. */
		syslog(1, PROGNAME, "RegisterServiceCtrlHandler: %x", GetLastError());
		return;
	}

	memset(&svcstate, 0, sizeof(svcstate));
	svcstate.dwServiceType = SERVICE_WIN32_OWN_PROCESS;
	svcstate.dwWin32ExitCode = NO_ERROR;
	svcstate.dwCurrentState = SERVICE_START_PENDING;
	svcstate.dwControlsAccepted = 0;
	SetServiceStatus(hsvcstate, &svcstate);

	/* The server already runs in its own thread, so there is no further start-up work. */
	svcstate.dwControlsAccepted |= (SERVICE_ACCEPT_STOP | SERVICE_ACCEPT_SHUTDOWN);
	svcstate.dwCurrentState = SERVICE_RUNNING;
	SetServiceStatus(hsvcstate, &svcstate);

	WaitForSingleObject(hsvcshutdown, INFINITE);

	svcstate.dwControlsAccepted &= ~(SERVICE_ACCEPT_STOP | SERVICE_ACCEPT_SHUTDOWN);
	svcstate.dwCurrentState = SERVICE_STOPPED;
	SetServiceStatus(hsvcstate, &svcstate);
}
static void
svcinstall(int install)
{
	SC_HANDLE scm, svc;

	if((scm = OpenSCManager(nil, nil, SC_MANAGER_ALL_ACCESS)) == nil)
		error("OpenSCManager: %x", GetLastError());
	if(waserror()){
		CloseServiceHandle(scm);
		nexterror();
	}
	if(install){
		char depend[64], path[MAX_PATH], cmd[1024];

		GetModuleFileName(nil, path, sizeof(path));
		snprintf(cmd, sizeof(cmd), "\"%s\"", path);
		memset(depend, 0, sizeof(depend));
		snprintf(depend, sizeof(depend), "%c%s", SC_GROUP_IDENTIFIER, "TDI");
		if((svc = CreateService(scm,
			PROGNAME,		// service name
			PROGNAME,		// display name 
			SERVICE_ALL_ACCESS,
			SERVICE_WIN32_OWN_PROCESS|SERVICE_INTERACTIVE_PROCESS,
			SERVICE_AUTO_START,
			SERVICE_ERROR_NORMAL,
			cmd,
			nil,			// load order group
			nil,			// tag in load order 
			depend,			// load order dependency
			nil,			// username
			nil)) == nil)	// password
			error("CreateService: %x", GetLastError());
		/* create the config key */
		setconfstr(nil, nil);
	} else {
		if ((svc = OpenService(scm, PROGNAME, SERVICE_ALL_ACCESS)) == nil)
			error("OpenService: %x", GetLastError());
		if(!DeleteService(svc)){
			CloseServiceHandle(svc);
			error("DeleteService: %x", GetLastError());
		}
	}
	CloseServiceHandle(svc);
	poperror();
	CloseServiceHandle(scm);
}

/*
 * server is the listener thread started by threadmain.  It initialises TLS
 * and sockets, then runs tcpserver on PORT_NCPU with client as the handler.
 * When tcpserver returns or an error is raised (the error is logged), it
 * logs the shutdown and signals hsvcshutdown so svcmain can report the
 * service as stopped.  aux is unused.
 */
static void
server(void *aux)
{
	if(!waserror()){
		inittls();
		initsock();
		tcpserver(PORT_NCPU, client);
		poperror();
	}else
		syslog(1, PROGNAME, "server: %s", errorstr());

	syslog(1, PROGNAME, "shutting down");
	SetEvent(hsvcshutdown);
}


void
usage(void)
{
	chanprint(mkstdchan(2), "usage: " PROGNAME " [ -c tlscertsubject ] [ -i | -u ]\n");
	exit(1);
}

void 
threadmain(int argc, char *argv[])
{
	int startup;

	static SERVICE_TABLE_ENTRY svctab[] = {
		{ PROGNAME, svcmain },
		{ nil, nil },
	};

	if(waserror()){
		syslog(1, PROGNAME, "error: %s", errorstr());
		nexterror();
	}
	startup = 1;
	ARGBEGIN{
	case 'i':
	case 'u':
		svcinstall(argv[0][1] == 'i');
		startup = 0;
		break;
	case 'c':
		setconfstr("tlscertsubject", EARGF(usage()));
		startup = 0;
		break;
	default:
		usage();
	}ARGEND;
	if(argc)
		usage();
	if(startup){
		syslog(1, PROGNAME, "starting up");
		hsvcshutdown = CreateEvent(nil, FALSE, FALSE, nil);
		threadcreate(server, nil);
		StartServiceCtrlDispatcher(svctab);
	}
	poperror();
}
