#include "u.h"
#include "libc.h"
#include "dat.h"
#include "fns.h"

/*
 * Shared error message strings.
 * Eintr is raised by checkinterrupt() when a note is pending;
 * the other two are not used in this file (presumably raised
 * by other modules -- confirm against callers).
 */
char Eintr[] = "interrupted";
char Ehungup[] = "i/o on hungup channel";
char Emountrpc[] = "mount rpc error";

/*
 * Terminate the process abnormally by calling abort().
 * Used by mallocz() when an allocation fails.
 */
void
__abort(void)
{
	abort();
}

/*
 * Allocate len bytes with malloc and return the block.
 * When zero is non-zero the block is cleared to all zero bytes.
 * Never returns nil: a failed allocation calls __abort().
 */
void*
mallocz(int len, int zero)
{
	void *p;

	p = malloc(len);
	if(p == nil)
		__abort();
	if(zero != 0)
		memset(p, 0, len);
	return p;
}

/*
 * Prepare a Lock for use by initializing the Win32 critical
 * section embedded in it.  Must be paired with freelock().
 */
void
initlock(Lock *l)
{
	InitializeCriticalSection(&l->cs);
}
/*
 * Release the resources of a Lock set up by initlock() by
 * deleting its critical section.
 */
void
freelock(Lock *l)
{
	DeleteCriticalSection(&l->cs);
}

/*
 * Acquire l by entering its critical section.
 */
void
lock(Lock *l)
{
	EnterCriticalSection(&l->cs);
}

/*
 * Release l by leaving its critical section.
 */
void
unlock(Lock *l)
{
	LeaveCriticalSection(&l->cs);
}


/*
 * Thread registry.
 * tls_thread is the TLS slot (allocated in main) in which each thread
 * stores a pointer to its own Thread record.
 * threads is the head of the list of registered Thread records, linked
 * through Thread.next; threadslock guards that list.
 */
static int tls_thread;
static Thread *threads;
static Lock threadslock;

/*
 * Fetch the calling thread's Thread record from the tls_thread
 * TLS slot (stored there by startthread).
 */
static Thread*
getthread(void)
{
	Thread *self;

	self = TlsGetValue(tls_thread);
	return self;
}

int
getpid(void)
{
	return getthread()->pid;
}

HANDLE
getevent(void)
{
	return getthread()->ev;
}

/*
 * Common entry point for every thread, including the initial one
 * (main calls it directly; threadcreate passes it to CreateThread).
 *
 * aux points to an array of three pointers:
 *   [0] the function to run, void (*)(void*)
 *   [1] the argument passed to that function
 *   [2] an event handle to signal once setup is complete, or nil
 *
 * The Thread record lives on this thread's own stack for the whole
 * lifetime of the thread; it is published through the tls_thread slot
 * and the global threads list, and unlinked again before returning.
 * Always returns 0.
 */
static DWORD WINAPI
startthread(LPVOID aux)
{
	Thread t, **tt;
	void (*proc)(void *);
	void *arg;

	/* copy proc/arg out of aux before telling the creator we have started */
	proc = ((void**)aux)[0];
	arg = ((void**)aux)[1];

	/* build this thread's record: empty error stack, private event */
	memset(&t, 0, sizeof(t));
	initlock(&t.l);
	t.pid = GetCurrentThreadId();
	t.errp = t.err;
	/* NOTE(review): the CreateEvent result is not checked for failure */
	t.ev = CreateEvent(nil, FALSE, FALSE, nil);
	TlsSetValue(tls_thread, &t);

	/* make the record visible to threadnotify() */
	lock(&threadslock);
	t.next = threads;
	threads = &t;
	unlock(&threadslock);

	/* last access to aux: signal the waiting creator, if any */
	if(((void**)aux)[2])
		SetEvent(((void**)aux)[2]);

	/*
	 * Run the thread body inside an error frame.
	 * Presumably waserror() yields non-zero when error() unwinds to
	 * this frame, in which case we fall through to cleanup -- confirm
	 * against the waserror definition.
	 */
	if(!waserror()){
		(*proc)(arg);
		poperror();
	}

	/* unlink the record so nobody can find it after we return */
	lock(&threadslock);
	for(tt = &threads; *tt; tt = &((*tt)->next)){
		if(*tt == &t){
			*tt = t.next;
			t.next = nil;
			break;
		}
	}
	unlock(&threadslock);

	freelock(&t.l);
	CloseHandle(t.ev);

	return 0;
}

int
threadcreate(void (*proc)(void *), void *arg)
{
	HANDLE h;
	DWORD pid;
	void *aux[3];

	aux[0] = proc;
	aux[1] = arg;
	aux[2] = getevent();
	if((h = CreateThread(nil, STACKSIZE, startthread, aux, 0, &pid)) == nil)
		return -1;
	CloseHandle(h);
	WaitForSingleObject(aux[2], INFINITE);
	return pid;
}

/*
 * Post note to the registered thread whose pid matches and signal
 * that thread's event so a pending rsleep() wakes up.
 *
 * Returns -1 if no such thread is registered, otherwise the result
 * of SetEvent on the target's event.
 */
int
threadnotify(int pid, char *note)
{
	Thread *target;
	int status;

	status = -1;
	lock(&threadslock);

	/* walk the registry until the pid matches or the list ends */
	target = threads;
	while(target && target->pid != pid)
		target = target->next;

	if(target){
		/* threadslock is still held, so target cannot unregister here */
		lock(&target->l);
		target->note = note;
		status = SetEvent(target->ev);
		unlock(&target->l);
	}

	unlock(&threadslock);
	return status;
}

/*
 * Consume any note pending for the calling thread.
 * If there is one, it is cleared, the thread is detached from the
 * Rendez it was registered on (if any), and Eintr is raised with
 * error() after the thread's lock has been released.
 */
void
checkinterrupt(void)
{
	Thread *t;
	Rendez *r;
	char *pending;

	t = getthread();

	lock(&t->l);
	pending = t->note;
	if(pending){
		t->note = nil;
		r = t->r;
		if(r){
			/* the Rendez and the thread must point at each other */
			assert(r->t == t);
			r->t = nil;
			t->r = nil;
		}
	}
	unlock(&t->l);

	if(pending)
		error(Eintr);
}

/*
 * Predicate that ignores its argument and always yields 0.
 * Its signature matches the test callback taken by rsleep().
 */
int
return0(void *arg)
{
	(void)arg;
	return 0;
}

/*
 * Prepare a Rendez for use: set up its lock and mark it as having
 * no sleeping thread.
 */
void
initrendez(Rendez *r)
{
	initlock(&r->l);
	r->t = nil;
}
/*
 * Tear down a Rendez set up by initrendez().
 * Asserts that no thread is currently registered on it, then
 * frees its lock.
 */
void
freerendez(Rendez *r)
{
	assert(r->t == nil);
	freelock(&r->l);
}

/*
 * Block the calling thread on r until (*test)(a) returns non-zero
 * or a note is pending for the thread.
 *
 * Locks are always taken in the order r->l, then t->l.  The test
 * function is called with both locks held.  At most one thread may
 * be registered on a Rendez at a time (asserted on entry).
 *
 * On the way out checkinterrupt() is called, which raises Eintr via
 * error() if a note is pending.
 */
void
rsleep(Rendez *r, int (*test)(void *), void *a)
{
	Thread *t;

	t = getthread();
	lock(&r->l);
	lock(&t->l);
	assert(t->r == nil);
	assert(r->t == nil);
	while(!(*test)(a) && !t->note){
		/* register as the sleeper so rwakeup() can find us */
		r->t = t;
		t->r = r;
		/* drop both locks while blocked on the thread's event */
		unlock(&t->l);
		unlock(&r->l);
		WaitForSingleObject(t->ev, INFINITE);
		lock(&r->l);
		lock(&t->l);
		/* unregister (the waker may already have done so) and re-test */
		t->r = nil;
		r->t = nil;
	}
	unlock(&t->l);
	unlock(&r->l);
	checkinterrupt();
}

/*
 * Wake the thread sleeping on r, if there is one: detach it from
 * the Rendez and signal its event.
 *
 * Returns 0 when nobody was sleeping, otherwise the result of
 * SetEvent on the sleeper's event.
 */
int
rwakeup(Rendez *r)
{
	Thread *t;
	int woke;

	woke = 0;
	lock(&r->l);
	t = r->t;
	if(t){
		/* same lock order as rsleep: Rendez first, then thread */
		lock(&t->l);
		assert(t->r == r);
		t->r = nil;
		r->t = nil;
		woke = SetEvent(t->ev);
		unlock(&t->l);
	}
	unlock(&r->l);
	return woke;
}

/*
 * Advance the calling thread's error-stack pointer by one slot and
 * return the new top slot.  Asserts that the new top is still
 * inside the thread's err array.
 */
Errjmp*
pusherror(void)
{
	Thread *t;
	Errjmp *j;

	t = getthread();
	t->errp++;
	j = t->errp;
	assert(j < (t->err + nelem(t->err)));
	return j;
}

/*
 * Step the calling thread's error-stack pointer back by one slot and
 * return the slot that was on top before the pop.  Asserts that the
 * popped slot was above the base of the err array.
 */
Errjmp*
poperror(void)
{
	Thread *t;
	Errjmp *j;

	t = getthread();
	j = t->errp;
	t->errp = j - 1;
	assert(j > t->err);
	return j;
}

char*
errorstr(void)
{
	return getthread()->errstr;
}

/*
 * Raise an error: format the printf-style message into the calling
 * thread's error string and hand control to nexterror().
 *
 * fmt and the variadic arguments follow vsnprintf conventions; the
 * message is truncated to fit in ERRMAX bytes including the NUL.
 * NOTE(review): assumes the buffer returned by errorstr() holds at
 * least ERRMAX bytes, and that nexterror() does not return -- confirm
 * against their definitions.
 */
void
error(char *fmt, ...)
{
	va_list a;
	char tmp[ERRMAX];

	/*
	 * Format into a temporary first, so the arguments may themselves
	 * point into the thread's error string without being clobbered
	 * mid-format.
	 */
	va_start(a, fmt);
	vsnprintf(tmp, sizeof(tmp), fmt, a);
	/* every va_start needs a matching va_end before leaving the function */
	va_end(a);
	/* some older C runtimes do not NUL-terminate on truncation */
	tmp[sizeof(tmp)-1] = '\0';

	/* tmp is terminated within ERRMAX bytes, so the copy is too */
	strncpy(errorstr(), tmp, ERRMAX);
	nexterror();
}

/*
 * Thread body for the initial thread: unpack the {&argc, argv}
 * pair built by main and call threadmain(argc, argv).
 */
static void
startthreadmain(void *aux)
{
	void **pair;
	int *argcp;
	char **argv;

	pair = aux;
	argcp = pair[0];
	argv = pair[1];
	threadmain(*argcp, argv);
}

/*
 * Program entry point.  Sets up the thread registry (TLS slot, list
 * lock, empty list) and then runs threadmain(argc, argv) on the
 * initial thread by calling startthread directly, with no event to
 * signal.
 */
void
main(int argc, char *argv[])
{
	void *cmdline[2], *boot[3];

	/* arguments handed through to startthreadmain */
	cmdline[0] = &argc;
	cmdline[1] = argv;

	/* startthread's aux triple: function, argument, no handshake event */
	boot[0] = startthreadmain;
	boot[1] = cmdline;
	boot[2] = nil;

	/* registry state must exist before the first startthread call */
	threads = nil;
	initlock(&threadslock);
	tls_thread = TlsAlloc();

	startthread(boot);
}
