#include "u.h"
#include "libc.h"
#include "dat.h"
#include "fns.h"

/*
 * Per-connection state for a Winsock-backed Chan; mksockchan stores it
 * in Chan.aux and sockclose frees it.
 */
typedef struct SockAux SockAux;
struct SockAux
{
	SOCKET s;	/* connected socket */
	HANDLE e;	/* event object associated with s via WSAEventSelect */
	char *err;	/* sticky error: once set, sockread/sockwrite raise it */
};

/*
 * Read up to len bytes from the socket behind ch into data; off is
 * ignored.
 *
 * WSAEventSelect (in mksockchan) puts the socket in non-blocking mode.
 * So recv is attempted first, and the function only waits on the
 * socket's event when recv reports WSAEWOULDBLOCK. Calling recv before
 * waiting matters for two reasons:
 * - FD_CLOSE is recorded only once.
 * - FD_CLOSE can arrive while data is still queued.
 * Reporting EOF as soon as FD_CLOSE is seen would discard that data.
 * Calling recv until it returns 0 drains the socket before EOF.
 *
 * Returns the number of bytes read, or 0 at end of file. After EOF the
 * channel's sticky error is set to "hangup", so later reads raise it.
 * Raises the sticky error if it is already set. On a recv failure,
 * records "hangup: recv" and raises "recv: <code>". Also raises if
 * waiting fails. checkinterrupt() is called before each wait.
 */
static int
sockread(Chan *ch, void *data, int len, vlong off)
{
	SockAux *a;
	HANDLE e[2];
	WSANETWORKEVENTS v;
	int n, err;

	a = ch->aux;
	for(;;){
		if(a->err)
			error(a->err);
		if((n = recv(a->s, data, len, 0)) != SOCKET_ERROR){
			/* 0 from a non-empty request means orderly shutdown by the peer */
			if(n == 0 && len > 0)
				a->err = "hangup";
			return n;
		}
		if((err = WSAGetLastError()) != WSAEWOULDBLOCK){
			a->err = "hangup: recv";
			error("recv: %x", err);
		}

		/* nothing queued yet: wait for FD_READ/FD_CLOSE or an interrupt */
		e[0] = a->e;
		e[1] = getevent();
		checkinterrupt();
		switch(WSAWaitForMultipleEvents(nelem(e), e, FALSE, WSA_INFINITE, FALSE)){
		case WSA_WAIT_EVENT_0:
			/*
			 * Reset a->e and clear the recorded network events. The
			 * events themselves are acted on by retrying recv at the
			 * top of the loop. The event may also have been set by
			 * sockwrite after recording a->err.
			 */
			memset(&v, 0, sizeof(v));
			if(WSAEnumNetworkEvents(a->s, a->e, &v) == SOCKET_ERROR)
				error("WSAEnumNetworkEvents: %x", WSAGetLastError());
			break;
		case WSA_WAIT_EVENT_0+1:
			/* interrupt event: loop so checkinterrupt() can act on it */
			break;
		default:
			error("WSAWaitForMultipleEvents: %x", WSAGetLastError());
		}
	}
}
/*
 * Write all len bytes from data to the socket behind ch; off is ignored.
 *
 * WSAEventSelect (in mksockchan) puts the socket in non-blocking mode,
 * so send may accept only part of the buffer, or fail with
 * WSAEWOULDBLOCK while the send buffer is full. Neither is a broken
 * connection. Keep sending the remainder, and when Winsock would block,
 * poll with select until the socket is writable again. checkinterrupt()
 * is called between polls, as sockread does before it waits.
 *
 * Returns len. Raises the channel's sticky error if it is set, either
 * on entry or after a wait. On a real send failure:
 * - records "hangup: send",
 * - signals a->e so a reader blocked in sockread wakes and sees it,
 * - raises "send: <code>".
 */
static int
sockwrite(Chan *ch, void *data, int len, vlong off)
{
	SockAux *a;
	char *p;
	int n, done, err;
	fd_set wfd;
	struct timeval tv;

	a = ch->aux;
	p = data;
	if(a->err)
		error(a->err);
	for(done = 0; done < len; done += n){
		if((n = send(a->s, p+done, len-done, 0)) != SOCKET_ERROR)
			continue;
		if((err = WSAGetLastError()) != WSAEWOULDBLOCK){
			a->err = "hangup: send";
			WSASetEvent(a->e);
			error("send: %x", err);
		}

		/* send buffer full: nothing sent this round; wait for room */
		n = 0;
		checkinterrupt();
		FD_ZERO(&wfd);
		FD_SET(a->s, &wfd);
		/* bounded wait so interrupts are rechecked periodically */
		tv.tv_sec = 0;
		tv.tv_usec = 100*1000;
		if(select(0, nil, &wfd, nil, &tv) == SOCKET_ERROR)
			error("select: %x", WSAGetLastError());
		if(a->err)
			error(a->err);
	}
	return len;
}
/*
 * Release a socket channel created by mksockchan.
 *
 * The socket is closed first. closesocket cancels the WSAEventSelect
 * association, so the event handle is closed only afterwards. This way
 * Winsock can never signal a handle that has already been closed (and
 * whose value may have been reused). Frees the SockAux and clears
 * ch->aux so it does not dangle.
 */
static void
sockclose(Chan *ch)
{
	SockAux *a;

	a = ch->aux;
	closesocket(a->s);
	CloseHandle(a->e);
	free(a);
	ch->aux = nil;
}

/*
 * Wrap the connected socket s in a new Chan whose read, write and close
 * operations go through Winsock (sockread, sockwrite, sockclose).
 *
 * s is associated with a manual-reset, initially non-signalled event
 * that Winsock signals on FD_READ and FD_CLOSE.
 *
 * On success the returned Chan owns s (sockclose closes it). On failure
 * an error is raised after releasing the event and SockAux allocated
 * here; s itself is left open for the caller to dispose of.
 */
Chan*
mksockchan(SOCKET s)
{
	Chan *ch;
	SockAux *a;
	int err;

	a = mallocz(sizeof(*a), 1);
	if(a == nil)
		error("mksockchan: out of memory");
	if((a->e = CreateEvent(nil, TRUE, FALSE, nil)) == nil){
		err = GetLastError();
		free(a);
		error("CreateEvent: %x", err);
	}
	a->s = s;
	if(WSAEventSelect(a->s, a->e, FD_READ | FD_CLOSE) == SOCKET_ERROR){
		/* capture the code before cleanup calls can overwrite it */
		err = WSAGetLastError();
		CloseHandle(a->e);
		free(a);
		error("WSAEventSelect: %x", err);
	}
	ch = mkchan();
	ch->aux = a;
	ch->read = sockread;
	ch->write = sockwrite;
	ch->close = sockclose;
	return ch;
}

/*
 * Start Winsock, requesting API version 2.0.
 * Raises "WSAStartup" if the library cannot be initialised.
 */
void
initsock(void)
{
	WSADATA wsa;
	int rc;

	rc = WSAStartup(MAKEWORD(2,0), &wsa);
	if(rc != 0)
		error("WSAStartup");
}

void
tcpserver(int port, void (*client)(void *))
{
	SOCKET s;
	SOCKADDR_IN a;

	if((s = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET)
		error("socket: %x", WSAGetLastError());
	if(waserror()){
		closesocket(s);
		nexterror();
	}
	memset(&a, 0, sizeof(a));
	a.sin_family=AF_INET;
	a.sin_port=htons(port);
	a.sin_addr.s_addr=ADDR_ANY;
	if(bind(s, (SOCKADDR*)&a, sizeof(a)) == SOCKET_ERROR)
		error("bind: %x", WSAGetLastError());
	if(listen(s, 10) == SOCKET_ERROR)
		error("listen: %x", WSAGetLastError());
	for(;;){
		SOCKET cs;

		if((cs = accept(s, nil, nil)) == INVALID_SOCKET)
			error("accept: %x", WSAGetLastError());
		if(threadcreate(client, mksockchan(cs)) < 0)
			closesocket(cs);
	}
}
