#include "u.h"
#include "libc.h"
#include "dat.h"
#include "fns.h"

#include <userenv.h>
#include <Tlhelp32.h>
#include <Winternl.h>

/*
 * getprocessid returns the process id of the process handle h, or ~0
 * when it cannot be determined.
 *
 * GetProcessId is looked up in kernel32 at run time (presumably because
 * not every supported Windows version exports it); when it is missing,
 * the id is read from NtQueryInformationProcess in ntdll instead.  The
 * lookups are cached in statics.  Both queries need h to carry query
 * access (PROCESS_QUERY_INFORMATION or the limited variant).
 */
static DWORD
getprocessid(HANDLE h)
{
	typedef DWORD (WINAPI *pfnGetProcID)(HANDLE h);
	typedef NTSTATUS (WINAPI *pfnQueryInformationProcess)(
		HANDLE ProcessHandle,
		PROCESSINFOCLASS ProcessInformationClass,
		PVOID ProcessInformation,
		ULONG ProcessInformationLength,
		PULONG ReturnLength);

	static pfnGetProcID getPID;
	static pfnQueryInformationProcess ntQIP;
	PROCESS_BASIC_INFORMATION info;
	ULONG size;
	DWORD pid;

	if(getPID == nil && ntQIP == nil){
		getPID = (pfnGetProcID)GetProcAddress(GetModuleHandle("KERNEL32.DLL"), "GetProcessId");
		if(getPID == nil)
			ntQIP = (pfnQueryInformationProcess)GetProcAddress(GetModuleHandle("NTDLL.DLL"), "NtQueryInformationProcess");
		if(getPID == nil && ntQIP == nil)
			return ~0;
	}
	if(getPID != nil){
		pid = getPID(h);
		/* GetProcessId reports failure as 0; map it to our own sentinel */
		return pid != 0 ? pid : ~0;
	}

	memset(&info, 0, sizeof(info));
	size = 0;
	/* negative NTSTATUS values are errors; info would just be left zeroed */
	if(ntQIP(h, ProcessBasicInformation, &info, sizeof(info), &size) < 0)
		return ~0;
	return (DWORD)info.UniqueProcessId;
}

static void
TerminateProcessTree(HANDLE h, UINT code)
{
	HANDLE hsnap;
	DWORD pid;
	PROCESSENTRY32 pe;

	if(hsnap = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0)){
		memset(&pe, 0, sizeof(PROCESSENTRY32));
		pe.dwSize = sizeof(PROCESSENTRY32);

		if(Process32First(hsnap, &pe)){
			pid = getprocessid(h);
			do {
				if(pe.th32ParentProcessID == pid){
					HANDLE h2;
					if(h2 = OpenProcess(PROCESS_TERMINATE, FALSE, pe.th32ProcessID)){
						TerminateProcessTree(h2, code);
						CloseHandle(h2);
					}
				}
			} while(Process32Next(hsnap, &pe));
		}
		CloseHandle(hsnap);
	}
	TerminateProcess(h, code);
}

/*
 * Read method of a translating channel: reads are passed unchanged to
 * the wrapped channel stored in ch->aux.
 */
static int
transread(Chan *ch, void *data, int len, vlong off)
{
	Chan *slave;

	slave = (Chan*)ch->aux;
	return readchan(slave, data, len, off);
}
/*
 * Close method of a translating channel: releases the wrapped channel
 * stored in ch->aux.
 */
static void
transclose(Chan *ch)
{
	Chan *slave;

	slave = (Chan*)ch->aux;
	freechan(slave);
}
/*
 * mktranschan returns a new channel wrapping slave.  Reads go straight
 * to slave (transread), writes go through fwrite, and the close hook
 * (transclose) frees slave.  slave is kept in the new channel's aux
 * field, where fwrite is expected to find it.
 */
Chan*
mktranschan(Chan *slave, int (*fwrite)(Chan *, void *, int, vlong))
{
	Chan *c;

	c = mkchan();
	c->read = transread;
	c->write = fwrite;
	c->close = transclose;
	c->aux = slave;
	return c;
}

/*
 * transLF2CRLF is a translating-channel write method: it forwards data
 * to the wrapped channel ch->aux with every "\n" expanded to "\r\n",
 * and always reports all len bytes as consumed.
 *
 * Text between newlines is written as one chunk.  Empty chunks (a
 * newline at the start of the buffer, or two newlines in a row) are not
 * written, because a zero-length write on a pipe can be read by the
 * other end as end of file.  A zero-length call writes nothing.
 */
static int
transLF2CRLF(Chan *ch, void *data, int len, vlong off)
{
	char *b, *p, *e;

	b = data;
	e = b + len;
	for(p = b; p < e; p++){
		if(*p != '\n')
			continue;
		if(p > b)
			off += writechan((Chan*)ch->aux, b, p - b, off);
		off += writechan((Chan*)ch->aux, "\r\n", 2, off);
		b = p+1;
	}
	if(p > b)
		writechan((Chan*)ch->aux, b, p - b, off);
	return len;
}
/*
 * transCRLF2LF is a translating-channel write method: it forwards data
 * to the wrapped channel ch->aux with every '\r' removed, and always
 * reports all len bytes as consumed.
 *
 * Every carriage return is dropped, not only those followed by '\n'.
 * The function keeps no state between calls, so this also handles a
 * "\r\n" pair that is split across two writes.  Empty chunks are not
 * written, because a zero-length write can be read by the other end
 * as end of file.
 */
static int
transCRLF2LF(Chan *ch, void *data, int len, vlong off)
{
	char *b, *p, *e;

	b = data;
	e = b + len;
	for(p = b; p < e; p++){
		if(*p != '\r')
			continue;
		if(p > b)
			off += writechan((Chan*)ch->aux, b, p - b, off);
		b = p+1;
	}
	if(p > b)
		writechan((Chan*)ch->aux, b, p - b, off);
	return len;
}

/* State of one relay thread copying from one channel to another. */
typedef struct Relay Relay;
struct Relay
{
	int	pid;		/* relay thread (threadcreate result); -1 when not running */
	int	notepid;	/* thread sent "interrupt" when the relay ends */
	Chan	*from;		/* channel read from; freed by the relay */
	Chan	*to;		/* channel written to; freed by the relay */
};

/*
 * relay is the body of a relay thread.  aux is a Relay.  It copies from
 * rl->from to rl->to until readchan returns a negative count or an
 * error is raised.  It then frees both channels, marks itself finished
 * by setting rl->pid to -1, and sends "interrupt" to rl->notepid so the
 * waiting thread re-checks its condition.
 *
 * notepid is copied out first: once pid is -1 the owner may stop
 * waiting, and the Relay (presumably on the owner's stack) can go away.
 */
static void
relay(void *aux)
{
	uchar buf[0x1000];
	Relay *rl;
	int got, notepid;

	rl = aux;
	notepid = rl->notepid;
	if(!waserror()){
		for(;;){
			got = readchan(rl->from, buf, sizeof(buf), rl->from->off);
			if(got < 0)
				break;
			writechan(rl->to, buf, got, rl->to->off);
		}
		poperror();
	}
	freechan(rl->from);
	freechan(rl->to);
	rl->pid = -1;
	threadnotify(notepid, "interrupt");
}

/*
 * dontinherit replaces *ph with a non-inheritable duplicate of the same
 * handle (same access rights) and closes the original, so that child
 * processes created with handle inheritance do not receive it.
 * Raises an error if the duplication fails.
 */
static void
dontinherit(HANDLE *ph)
{
	HANDLE self, dup;

	self = GetCurrentProcess();
	if(DuplicateHandle(self, *ph, self, &dup, 0, FALSE, DUPLICATE_SAME_ACCESS) == 0)
		error("DuplicateHandle: %x", GetLastError());
	CloseHandle(*ph);
	*ph = dup;
}

/*
 * relayfinished is an rsleep condition.  aux points to an array of
 * three Relay structures; it returns 1 once all of their threads have
 * ended (pid == -1), otherwise 0.
 */
int
relayfinished(void *aux)
{
	Relay *rr;
	int i;

	rr = aux;
	for(i = 0; i < 3; i++)
		if(rr[i].pid != -1)
			return 0;
	return 1;
}

/*
 * shell runs cmd as a Windows process whose standard input, output and
 * error are connected through pipes to cin, cout and cerr.
 *
 * Input has "\n" expanded to "\r\n" on its way to the process; output
 * and error have '\r' removed.  Each direction is copied by its own
 * relay thread.
 *
 * If utok is non-nil, the process is created with CreateProcessAsUser
 * using that token, user's loaded profile, the user's environment block
 * and the profile directory as working directory.  Otherwise it is
 * created with CreateProcess, with this process's environment and
 * directory.
 *
 * If cnote is non-nil, shell reads notes from it until one is "hangup"
 * or "interrupt", or until the read fails.  A relay thread ending sends
 * this thread an "interrupt" note, presumably breaking that read.
 * Without cnote it waits until all three relays have ended.  In either
 * case the process tree is then terminated and shell waits for the
 * relays to finish.
 *
 * Errors are raised with error(); the waserror handlers release the
 * pipes and the rendezvous, and kill the process if it was started.
 */
void
shell(Chan *cin, Chan *cout, Chan *cerr, Chan *cnote, char *cmd, char *user, HANDLE utok)
{
	Chan *rpipe[2], *wpipe[2], *epipe[2];
	Relay rr[3];
	PROCESS_INFORMATION pi;
	STARTUPINFO si;
	DWORD err;
	Rendez r;

	initrendez(&r);
	if(waserror()){
		freerendez(&r);
		nexterror();
	}
	mkpipechan(rpipe);
	if(waserror()){
		freechan(rpipe[0]);
		freechan(rpipe[1]);
		nexterror();
	}
	mkpipechan(wpipe);
	if(waserror()){
		freechan(wpipe[0]);
		freechan(wpipe[1]);
		nexterror();
	}
	mkpipechan(epipe);
	if(waserror()){
		freechan(epipe[0]);
		freechan(epipe[1]);
		nexterror();
	}

	/*
	 * The process is created with bInheritHandles TRUE; make our ends
	 * of the pipes non-inheritable so only the child's ends are passed.
	 */
	dontinherit(pfilehandle(rpipe[1]));
	dontinherit(pfilehandle(wpipe[0]));
	dontinherit(pfilehandle(epipe[0]));

	memset(&si, 0, sizeof(si));
	si.cb = sizeof(si);
	si.dwFlags = STARTF_USESTDHANDLES;
	si.hStdInput = *pfilehandle(rpipe[0]);
	si.hStdOutput = *pfilehandle(wpipe[1]);
	si.hStdError = *pfilehandle(epipe[1]);
	si.lpDesktop = "";
	memset(&pi, 0, sizeof(pi));
	if(utok){
		PROFILEINFO pri;
		char prof[MAX_PATH];
		DWORD nprof;
		void *env;

		nprof = sizeof(prof)-1;
		if(!GetUserProfileDirectory(utok, prof, &nprof))
			error("GetUserProfileDirectory: %x", GetLastError());
		memset(&pri, 0, sizeof(pri));
		pri.dwSize = sizeof(pri);
		pri.lpUserName = user;
		pri.lpProfilePath = prof;

		/* temporarily change back to LOCAL SYSTEM user to load the profile */
		if(!RevertToSelf())
			error("RevertToSelf: %x", GetLastError());
		if(!LoadUserProfile(utok, &pri)){
			err = GetLastError();
			ImpersonateLoggedOnUser(utok);
			error("LoadUserProfile: %x", err);
		}
		ImpersonateLoggedOnUser(utok);

		if(!CreateEnvironmentBlock(&env, utok, FALSE)){
			/* the profile is loaded at this point; don't leak it */
			err = GetLastError();
			UnloadUserProfile(utok, pri.hProfile);
			error("CreateEnvironmentBlock: %x", err);
		}
		if(!CreateProcessAsUser(utok, nil, cmd, nil, nil, TRUE, 
		    CREATE_UNICODE_ENVIRONMENT|CREATE_NEW_PROCESS_GROUP|CREATE_NO_WINDOW,
			env, prof, &si, &pi)){
			err = GetLastError();
			DestroyEnvironmentBlock(env);
			UnloadUserProfile(utok, pri.hProfile);
			error("CreateProcessAsUser: %x", err);
		}
		DestroyEnvironmentBlock(env);
		UnloadUserProfile(utok, pri.hProfile);
	} else {
		if(!CreateProcess(nil, cmd, nil, nil, TRUE, 
		    CREATE_NEW_PROCESS_GROUP|CREATE_NO_WINDOW,
			nil, nil, &si, &pi))
			error("CreateProcess: %x", GetLastError());
	}

	/*
	 * The child holds its own copies of its ends now.  Close ours, so
	 * that reading its output sees EOF once the child's ends are closed.
	 */
	freechan(rpipe[0]); rpipe[0] = nil;
	freechan(wpipe[1]); wpipe[1] = nil;
	freechan(epipe[1]); epipe[1] = nil;

	memset(rr, 0, sizeof(rr));
	rr[0].pid = rr[1].pid = rr[2].pid = -1;
	/* from here on, an error kills the process and waits for any relays already started */
	if(waserror()){
		TerminateProcessTree(pi.hProcess, 1);
		CloseHandle(pi.hThread);
		CloseHandle(pi.hProcess);
		threadnotify(rr[0].pid, "interrupt");
		while(waserror())
			;
		rsleep(&r, relayfinished, rr);
		poperror();
		nexterror();
	}

	/* spawn relay threads */
	rr[0].notepid = getpid();
	rr[0].from = getchan(cin);
	rr[0].to = mktranschan(rpipe[1], transLF2CRLF); rpipe[1] = nil;
	if((rr[0].pid = threadcreate(relay, &rr[0])) < 0){
		freechan(rr[0].from);
		freechan(rr[0].to);
		error("relay stdin");
	}
	rr[1].notepid = getpid();
	rr[1].from = wpipe[0]; wpipe[0] = nil;
	rr[1].to = mktranschan(getchan(cout), transCRLF2LF);
	if((rr[1].pid = threadcreate(relay, &rr[1])) < 0){
		freechan(rr[1].from);
		freechan(rr[1].to);
		error("relay stdout");
	}
	rr[2].notepid = getpid();
	rr[2].from = epipe[0]; epipe[0] = nil;
	rr[2].to = mktranschan(getchan(cerr), transCRLF2LF);
	if((rr[2].pid = threadcreate(relay, &rr[2])) < 0){
		freechan(rr[2].from);
		freechan(rr[2].to);
		error("relay stderr");
	}

	poperror(); /* process */
	poperror(); /* epipe[] */
	poperror(); /* wpipe[] */
	poperror(); /* rpipe[] */

	if(cnote){
		char buf[ERRMAX];
		int n;

		/* other notes are ignored; an error (including a relay's interrupt) ends the loop */
		while(!waserror()){
			if((n = readchan(cnote, buf, sizeof(buf)-1, cnote->off)) <= 0)
				error(Ehungup);
			poperror();

			buf[n] = 0;
			if(!strcmp(buf, "hangup"))
				break;
			if(!strcmp(buf, "interrupt"))
				break;
		}
	} else {
		/*
		 * NOTE(review): this also waits for the stdin relay, which
		 * presumably stays blocked reading cin until the client sends
		 * more input or closes it, even after the process has exited.
		 * Confirm this is intended.
		 */
		while(waserror())
			;
		rsleep(&r, relayfinished, rr);
		poperror();
	}

	TerminateProcessTree(pi.hProcess, 1);
	CloseHandle(pi.hThread);
	CloseHandle(pi.hProcess);
	/* wake the stdin relay, which may still be blocked reading cin */
	threadnotify(rr[0].pid, "interrupt");

	/* rr lives on this stack: don't return before every relay has let go of it */
	while(waserror())
		;
	rsleep(&r, relayfinished, rr);
	poperror();

	poperror(); /* r */
	freerendez(&r);
}
