#include "u.h"
#include "libc.h"
#include "dat.h"
#include "fns.h"

#include <wincrypt.h>
#include <wintrust.h>
#include <schannel.h>

#define SECURITY_WIN32
#include <security.h>
#include <sspi.h>

#define DLL_NAME TEXT("Secur32.dll")	/* SSPI provider DLL (see inittls) */
#define NT4_DLL_NAME TEXT("Security.dll")	/* SSPI provider DLL used on NT 4 */

enum {
	/*
	 * Largest TLS record on the wire: 2^14 bytes of plaintext plus at
	 * most 2048 bytes of ciphertext expansion plus the 5-byte record
	 * header (RFC 5246, section 6.2.3).  readtls can only decrypt once a
	 * whole record sits in the input buffer, so the buffer must hold one.
	 */
	Maxtlsrecord = (1<<14) + 2048 + 5,
};

typedef struct TlsAux TlsAux;

/* Per-connection state of a TLS channel created by tlsserver. */
struct TlsAux
{
	Chan *slave;	/* underlying transport channel */
	CtxtHandle ctx;	/* Schannel security context (set up by negotiate) */
	CredHandle srvcred;	/* server credentials from getcreds */
	SecPkgContext_StreamSizes sizes;	/* record header/trailer/max sizes */
	Lock wl, rl;	/* wl: writetls and closedown; rl: readtls */
	struct {
		int n;
		uchar *b;
	} backoff;	/* decrypted bytes not yet returned by readtls; protected by rl */
	struct {
		uchar *b, *p, *e;	/* buffer start, fill point, end */
		SecBuffer bufs[4];
		SecBufferDesc desc;
		uchar buf[Maxtlsrecord];
	} i, o;	/* i: ciphertext read from slave; o: ciphertext to write */
};

/* SSPI dispatch table, filled in by inittls. */
static PSecurityFunctionTable sspi;

void
inittls(void)
{
    INIT_SECURITY_INTERFACE pInitSecurityInterface;
    OSVERSIONINFO VerInfo;
    UCHAR lpszDLL[MAX_PATH];
	HMODULE hdll;

	memset(&VerInfo, 0, sizeof(VerInfo));
    VerInfo.dwOSVersionInfoSize = sizeof(OSVERSIONINFO);
    if (!GetVersionEx(&VerInfo))
		error("GetVersionEx: %x", GetLastError());
    if (VerInfo.dwPlatformId == VER_PLATFORM_WIN32_NT && VerInfo.dwMajorVersion == 4)
        strcpy(lpszDLL, NT4_DLL_NAME);
    else
        strcpy(lpszDLL, DLL_NAME);
	if((hdll = LoadLibrary(lpszDLL)) == nil)
		error("LoadLibrary %s: %x", lpszDLL, GetLastError());
    if((pInitSecurityInterface = (INIT_SECURITY_INTERFACE)GetProcAddress(hdll, "InitSecurityInterfaceA")) == nil)
		error("GetProcAddress InitSecurityInterfaceA: %x", GetLastError());
    if((sspi = pInitSecurityInterface()) == nil)
		error("InitSecurityInterface: %x", GetLastError());
}

/*
 * Acquire inbound (server-side) Schannel credentials into *creds, using
 * the certificate found in the local machine "MY" store with
 * CERT_FIND_SUBJECT_STR_A and user as the search string.
 * The store is opened on first use and kept open for later calls.
 * Raises error() if user is nil or empty, or if any step fails.
 */
static void
getcreds(char *user, CredHandle *creds)
{
	static HCERTSTORE certstore = nil;
	SCHANNEL_CRED SchannelCred;
	PCCERT_CONTEXT pcert;	/* const: matches the find result and paCred's type */
	TimeStamp expiry;
	SECURITY_STATUS status;

	if(user == nil || user[0] == '\0')
		error("no cert subject given");
	if(certstore == nil){
		if((certstore = CertOpenStore(CERT_STORE_PROV_SYSTEM, X509_ASN_ENCODING, 0, CERT_SYSTEM_STORE_LOCAL_MACHINE, L"MY")) == nil)
			error("CertOpenStore MY: %x", GetLastError());
	}
	if((pcert = CertFindCertificateInStore(certstore, X509_ASN_ENCODING, 0, CERT_FIND_SUBJECT_STR_A, user, nil)) == nil)
		error("CertFindCertificateInStore: %x", GetLastError());
	if(waserror()){
		CertFreeCertificateContext(pcert);
		nexterror();
	}
	memset(&SchannelCred, 0, sizeof(SchannelCred));
	SchannelCred.dwVersion = SCHANNEL_CRED_VERSION;
	SchannelCred.cCreds = 1;
	SchannelCred.paCred = &pcert;
	SchannelCred.dwFlags = SCH_CRED_NO_SYSTEM_MAPPER;
	SchannelCred.grbitEnabledProtocols = 0;
	if((status = sspi->AcquireCredentialsHandle(
		nil,			// Name of principal
		UNISP_NAME_A,		// Name of package
		SECPKG_CRED_INBOUND,	// Flags indicating use
		nil,			// Pointer to logon ID
		&SchannelCred,		// Package specific data
		nil,			// Pointer to GetKey() func
		nil,			// Value to pass to GetKey()
		creds,			// (out) Cred Handle
		&expiry)) != SEC_E_OK)	// (out) Lifetime (optional)
		error("AcquireCredentialsHandle: %x", status);
	poperror();
	/* release this function's reference to the certificate */
	CertFreeCertificateContext(pcert);
}

/*
 * After a handshake round, keep only the input bytes Schannel did not
 * consume.  These are reported as a SECBUFFER_EXTRA in a->i.bufs[1] and
 * are always the tail of a->i.b .. a->i.p.  They are moved to the front
 * of a->i.buf; otherwise the input buffer is emptied.  The tail is
 * located from the end of the input, as Microsoft's Schannel samples do,
 * instead of relying on bufs[1].pvBuffer being filled in.
 */
static void
tlskeeptail(TlsAux *a)
{
	unsigned long x;

	if(a->i.bufs[1].BufferType != SECBUFFER_EXTRA){
		a->i.p = a->i.b;
		return;
	}
	x = a->i.bufs[1].cbBuffer;
	if(x > (unsigned long)(a->i.p - a->i.b))
		x = a->i.p - a->i.b;	/* defensive: never read before a->i.b */
	memmove(a->i.b, a->i.p - x, x);
	a->i.p = a->i.b + x;
}

/*
 * Write n bytes of an SSPI-allocated handshake token to a->slave and
 * free it with FreeContextBuffer.  The token is freed even if the write
 * raises an error.  A short write raises Ehungup.
 */
static void
tlssendtoken(TlsAux *a, void *tok, int n)
{
	if(waserror()){
		sspi->FreeContextBuffer(tok);
		nexterror();
	}
	if(writechan(a->slave, tok, n, a->slave->off) != n)
		error(Ehungup);
	poperror();
	sspi->FreeContextBuffer(tok);
}

/*
 * Run the server side of the TLS handshake over a->slave with
 * AcceptSecurityContext, establishing a->ctx and filling a->sizes.
 * Handshake records are read into a->i.buf.  On return, any bytes already
 * received beyond the handshake are left in a->i.b .. a->i.p for readtls.
 * Raises error() on failure.
 */
static void
negotiate(TlsAux *a)
{
	CtxtHandle *pctx, *pcctx;
	SECURITY_STATUS status;
	TimeStamp expiry;
	DWORD oflags;
	void *tok;
	int n, ntok;

	pcctx = nil;		/* existing context: none until the first round creates it */
	pctx = &a->ctx;		/* where the first successful round stores the new context */

	status = SEC_I_CONTINUE_NEEDED;
	while(status == SEC_I_CONTINUE_NEEDED || status == SEC_E_INCOMPLETE_MESSAGE ||
		status == SEC_I_INCOMPLETE_CREDENTIALS){
		if(a->i.p == a->i.b || status == SEC_E_INCOMPLETE_MESSAGE){
			/* a full buffer can never complete the message */
			if(a->i.p >= a->i.e)
				error("tls handshake message too large");
			if((n = readchan(a->slave, a->i.p, a->i.e - a->i.p, a->slave->off)) <= 0)
				error(Ehungup);
			a->i.p += n;
		}

		a->i.bufs[0].pvBuffer = a->i.b;
		a->i.bufs[0].cbBuffer = a->i.p - a->i.b;
		a->i.bufs[0].BufferType = SECBUFFER_TOKEN;
		/* reset each round: Schannel reports unconsumed input here */
		a->i.bufs[1].pvBuffer = nil;
		a->i.bufs[1].cbBuffer = 0;
		a->i.bufs[1].BufferType = SECBUFFER_EMPTY;
		a->i.desc.cBuffers = 2;

		a->o.bufs[0].pvBuffer = nil;
		a->o.bufs[0].cbBuffer = 0;
		a->o.bufs[0].BufferType = SECBUFFER_TOKEN;
		a->o.desc.cBuffers = 1;

		oflags = 0;	/* may not be written when the call fails */
		status = sspi->AcceptSecurityContext(
			&a->srvcred,
			pcctx,
			&a->i.desc,
			ASC_REQ_SEQUENCE_DETECT |
			ASC_REQ_REPLAY_DETECT |
			ASC_REQ_CONFIDENTIALITY |
			ASC_REQ_EXTENDED_ERROR |
			ASC_REQ_ALLOCATE_MEMORY |
			ASC_REQ_STREAM,
			SECURITY_NATIVE_DREP,
			pctx,
			&a->o.desc,
			&oflags,
			&expiry);
		/*
		 * Later rounds continue the context created by the first one.
		 * An incomplete first message creates no context, so keep
		 * passing nil as the existing context until a round completes.
		 */
		if(pctx != nil && status != SEC_E_INCOMPLETE_MESSAGE){
			pcctx = pctx;
			pctx = nil;
		}

		/* take ownership of any token SSPI allocated (ASC_REQ_ALLOCATE_MEMORY) */
		tok = a->o.bufs[0].pvBuffer;
		ntok = a->o.bufs[0].cbBuffer;
		a->o.bufs[0].BufferType = SECBUFFER_EMPTY;
		a->o.bufs[0].pvBuffer = nil;
		a->o.bufs[0].cbBuffer = 0;
		if(tok != nil){
			/* send it on progress, or when it carries an error alert for the peer */
			if(ntok > 0 && (status == SEC_E_OK || status == SEC_I_CONTINUE_NEEDED ||
			   (FAILED(status) && (oflags & ASC_RET_EXTENDED_ERROR))))
				tlssendtoken(a, tok, ntok);
			else
				sspi->FreeContextBuffer(tok);
		}

		if(status == SEC_E_OK){
			tlskeeptail(a);
			if((status = sspi->QueryContextAttributes(&a->ctx, SECPKG_ATTR_STREAM_SIZES, &a->sizes)) != SEC_E_OK)
				error("QueryContextAttributes: %x", status);
			return;
		}
		if(FAILED(status) && status != SEC_E_INCOMPLETE_MESSAGE)
			break;
		/* on an incomplete message keep everything and read more */
		if(status != SEC_E_INCOMPLETE_MESSAGE && status != SEC_I_INCOMPLETE_CREDENTIALS)
			tlskeeptail(a);
	}
	error("AcceptSecurityContext: %x", status);
}

/*
 * Best-effort TLS shutdown: apply SCHANNEL_SHUTDOWN to a->ctx, ask
 * AcceptSecurityContext for the resulting token and write it to
 * a->slave.  Any failure, including a write error, is ignored.
 * Holds wl throughout because the a->o descriptor is shared with writetls.
 */
static void
closedown(TlsAux *a)
{
	SECURITY_STATUS st;
	TimeStamp lifetime;
	DWORD token, retflags;
	SecBuffer *ob;

	ob = &a->o.bufs[0];
	lock(&a->wl);

	token = SCHANNEL_SHUTDOWN;
	ob->pvBuffer = &token;
	ob->cbBuffer = sizeof(token);
	ob->BufferType = SECBUFFER_TOKEN;
	a->o.desc.cBuffers = 1;
	st = sspi->ApplyControlToken(&a->ctx, &a->o.desc);

	if(!FAILED(st)){
		/* let Schannel allocate the shutdown token */
		ob->pvBuffer = NULL;
		ob->BufferType = SECBUFFER_TOKEN;
		ob->cbBuffer = 0;
		a->o.desc.cBuffers = 1;
		st = sspi->AcceptSecurityContext(&a->srvcred, &a->ctx, nil,
			ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT |
			ASC_REQ_CONFIDENTIALITY | ASC_REQ_EXTENDED_ERROR |
			ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_STREAM,
			SECURITY_NATIVE_DREP, nil, &a->o.desc, &retflags, &lifetime);
		if(!FAILED(st) && ob->cbBuffer != 0 && ob->pvBuffer != nil){
			/* the write is best effort: swallow its error */
			if(!waserror()){
				writechan(a->slave, ob->pvBuffer, ob->cbBuffer, a->slave->off);
				poperror();
			}
			sspi->FreeContextBuffer(ob->pvBuffer);
			ob->BufferType = SECBUFFER_EMPTY;
			ob->pvBuffer = nil;
			ob->cbBuffer = 0;
		}
	}

	unlock(&a->wl);
}

/*
 * Channel close hook: run closedown(), then release the slave channel,
 * the security context, the credentials, any buffered plaintext, the
 * locks, and finally the TlsAux itself.
 */
static void
closetls(Chan *ch)
{
	TlsAux *aux = ch->aux;

	closedown(aux);
	freechan(aux->slave);

	/* SSPI state */
	sspi->DeleteSecurityContext(&aux->ctx);
	sspi->FreeCredentialsHandle(&aux->srvcred);

	/* local memory and locks */
	free(aux->backoff.b);
	freelock(&aux->rl);
	freelock(&aux->wl);
	free(aux);
}

/*
 * Read up to len bytes of plaintext from the TLS stream into data.
 *
 * Plaintext left over from an earlier record (a->backoff) is returned
 * first.  Otherwise ciphertext is accumulated in a->i.buf and decrypted
 * with DecryptMessage.  Records that yield no application data are
 * skipped instead of being reported as end of stream.  Plaintext beyond
 * len is saved in a->backoff.  Undecrypted bytes of the following
 * record (SECBUFFER_EXTRA) are moved to the front of a->i.buf.
 *
 * Returns the number of bytes stored in data.  Returns 0 when reading
 * a->slave returns <= 0 or when the context has expired.  Raises error()
 * if decryption fails or a record cannot fit in a->i.buf.
 * Serialized by rl; off is ignored.
 */
static int
readtls(Chan *ch, void *data, int len, vlong off)
{
	SECURITY_STATUS status;
	SecBuffer *db;
	TlsAux *a;
	int i, n, gotdata;

	a = ch->aux;
	lock(&a->rl);
	if(waserror()){
		unlock(&a->rl);
		nexterror();
	}
	if(a->backoff.b && (n = a->backoff.n)){
		if(n > len)
			n = len;
		memmove(data, a->backoff.b, n);
		if(a->backoff.n == n){
			free(a->backoff.b);
			a->backoff.n = 0;
			a->backoff.b = nil;
		} else {
			memmove(a->backoff.b, a->backoff.b + n, a->backoff.n - n);
			a->backoff.n -= n;
		}
		goto out;
	}

	for(;;){
		/* bufs[0] holds the ciphertext; Schannel fills in the others */
		a->i.bufs[0].BufferType = SECBUFFER_DATA;
		a->i.bufs[0].pvBuffer = a->i.b;
		a->i.bufs[0].cbBuffer = a->i.p - a->i.b;
		for(i = 1; i < nelem(a->i.bufs); i++){
			a->i.bufs[i].BufferType = SECBUFFER_EMPTY;
			a->i.bufs[i].pvBuffer = nil;
			a->i.bufs[i].cbBuffer = 0;
		}
		a->i.desc.cBuffers = nelem(a->i.bufs);

		status = sspi->DecryptMessage(&a->ctx, &a->i.desc, 0, nil);
		if(status == SEC_E_INCOMPLETE_MESSAGE){
			/* a full buffer can never complete the record */
			if(a->i.p >= a->i.e)
				error("tls record too large");
			if((n = readchan(a->slave, a->i.p, a->i.e - a->i.p, a->slave->off)) <= 0){
				n = 0;
				goto out;
			}
			a->i.p += n;
			continue;
		}
		n = 0;
		if(status == SEC_I_CONTEXT_EXPIRED)
			goto out;
		if(status != SEC_E_OK)
			error("DecryptMessage: %x", status);

		db = nil;
		for(i = 1; i < (int)a->i.desc.cBuffers; i++){
			if(a->i.bufs[i].BufferType == SECBUFFER_DATA){
				db = &a->i.bufs[i];
				break;
			}
		}
		gotdata = db != nil && db->cbBuffer > 0;
		if(gotdata){
			n = db->cbBuffer;
			if(n > len){
				a->backoff.n = n - len;
				a->backoff.b = mallocz(a->backoff.n, 0);
				if(a->backoff.b == nil){
					a->backoff.n = 0;
					error("readtls: out of memory");
				}
				memmove(a->backoff.b, (uchar*)db->pvBuffer + len, a->backoff.n);
				n = len;
			}
			memmove(data, db->pvBuffer, n);
		}

		/*
		 * Always keep the undecrypted start of the next record, even
		 * when this record carried no data.  The bytes just decrypted
		 * in place must not be fed to DecryptMessage again.
		 */
		a->i.p = a->i.b;
		for(i = 1; i < (int)a->i.desc.cBuffers; i++){
			if(a->i.bufs[i].BufferType == SECBUFFER_EXTRA){
				memmove(a->i.p, a->i.bufs[i].pvBuffer, a->i.bufs[i].cbBuffer);
				a->i.p += a->i.bufs[i].cbBuffer;
			}
		}
		if(gotdata)
			break;
		/* record held no application data: decrypt the next one */
	}

out:
	poperror();
	unlock(&a->rl);
	return n;
}

/*
 * Encrypt len bytes from data and write them to a->slave as one or more
 * TLS records.  Each record carries at most sizes.cbMaximumMessage bytes
 * and, together with its header and trailer, fits in the output buffer.
 * Returns the number of bytes consumed.  Raises error() if encryption
 * fails, and Ehungup on a short write.  Serialized by wl; off is ignored.
 */
static int
writetls(Chan *ch, void *data, int len, vlong off)
{
	SECURITY_STATUS status;
	SecBuffer *hdr, *body, *trl;
	uchar *start, *cur, *end;
	TlsAux *a;
	int chunk, reclen, done;

	a = ch->aux;
	hdr = &a->o.bufs[0];
	body = &a->o.bufs[1];
	trl = &a->o.bufs[2];

	lock(&a->wl);
	if(waserror()){
		unlock(&a->wl);
		nexterror();
	}
	start = cur = data;
	end = start + len;
	for(; cur < end; cur += chunk){
		/* largest payload allowed by Schannel and by the buffer */
		chunk = a->sizes.cbMaximumMessage;
		if((end - cur) < chunk)
			chunk = end - cur;
		if(chunk + a->sizes.cbHeader + a->sizes.cbTrailer > sizeof(a->o.buf))
			chunk = sizeof(a->o.buf) - (a->sizes.cbHeader + a->sizes.cbTrailer);

		/* layout in o.buf: [header][payload][trailer], contiguous */
		a->o.p = a->o.b = a->o.buf;
		a->o.e = a->o.b + a->sizes.cbHeader + chunk + a->sizes.cbTrailer;

		hdr->BufferType = SECBUFFER_STREAM_HEADER;
		hdr->pvBuffer = a->o.b;
		hdr->cbBuffer = a->sizes.cbHeader;

		body->BufferType = SECBUFFER_DATA;
		body->pvBuffer = a->o.b + a->sizes.cbHeader;
		body->cbBuffer = chunk;

		trl->BufferType = SECBUFFER_STREAM_TRAILER;
		trl->pvBuffer = a->o.b + a->sizes.cbHeader + chunk;
		trl->cbBuffer = a->sizes.cbTrailer;

		memset(hdr->pvBuffer, 0, hdr->cbBuffer);
		memmove(body->pvBuffer, cur, body->cbBuffer);
		memset(trl->pvBuffer, 0, trl->cbBuffer);
		a->o.desc.cBuffers = 3;

		status = sspi->EncryptMessage(&a->ctx, 0, &a->o.desc, 0);
		if(FAILED(status))
			error("EncryptMessage: %x", status);
		/* EncryptMessage may adjust the sizes; send what it reports */
		reclen = hdr->cbBuffer + body->cbBuffer + trl->cbBuffer;
		if(writechan(a->slave, a->o.b, reclen, a->slave->off) != reclen)
			error(Ehungup);
	}
	done = cur - start;
	poperror();
	unlock(&a->wl);
	return done;
}

/*
 * Initialise the input (a->i) and output (a->o) staging areas.  For each
 * one, the b/p/e cursors span its own buf (empty: p == b), every
 * SecBuffer is marked SECBUFFER_EMPTY, and desc points at its bufs.
 */
static void
initbufdescr(TlsAux *a)
{
	int i;

	a->i.p = a->i.b = a->i.buf;
	a->i.e = a->i.b + sizeof(a->i.buf);
	for(i=0; i<nelem(a->i.bufs); i++){
		a->i.bufs[i].BufferType = SECBUFFER_EMPTY;
		a->i.bufs[i].pvBuffer = nil;
		a->i.bufs[i].cbBuffer = 0;
	}
	a->i.desc.ulVersion = SECBUFFER_VERSION;
	a->i.desc.pBuffers = a->i.bufs;
	a->i.desc.cBuffers = nelem(a->i.bufs);

	/* output cursors must cover a->o.buf, not the input buffer */
	a->o.p = a->o.b = a->o.buf;
	a->o.e = a->o.b + sizeof(a->o.buf);
	for(i=0; i<nelem(a->o.bufs); i++){
		a->o.bufs[i].BufferType = SECBUFFER_EMPTY;
		a->o.bufs[i].pvBuffer = nil;
		a->o.bufs[i].cbBuffer = 0;
	}
	a->o.desc.ulVersion = SECBUFFER_VERSION;
	a->o.desc.pBuffers = a->o.bufs;
	a->o.desc.cBuffers = nelem(a->o.bufs);
}

/*
 * Wrap slave in a server-side TLS channel using the certificate selected
 * by cert (see getcreds), and run the handshake before returning.
 *
 * On success the returned channel owns slave; closetls frees it.  On
 * failure, everything allocated here (the TlsAux, credentials and any
 * security context) is released.  slave is left to the caller, and
 * error() is raised with a "tlsserver: " prefix.
 */
Chan*
tlsserver(Chan *slave, char *cert)
{
	TlsAux *a;
	Chan *ch;

	if(waserror())
		error("tlsserver: %s", errorstr());
	if((a = mallocz(sizeof(*a), 1)) == nil)
		error("out of memory");
	initlock(&a->rl);
	initlock(&a->wl);
	a->slave = slave;
	a->backoff.n = 0;
	a->backoff.b = nil;
	initbufdescr(a);
	/* mark ctx unused so the error path can tell whether negotiate created one */
	SecInvalidateHandle(&a->ctx);
	if(waserror()){
		freelock(&a->rl);
		freelock(&a->wl);
		free(a);
		nexterror();
	}
	getcreds(cert, &a->srvcred);
	if(waserror()){
		if(SecIsValidHandle(&a->ctx))
			sspi->DeleteSecurityContext(&a->ctx);
		sspi->FreeCredentialsHandle(&a->srvcred);
		nexterror();
	}
	negotiate(a);
	poperror();
	poperror();
	poperror();

	ch = mkchan();
	ch->aux = a;
	ch->read = readtls;
	ch->write = writetls;
	ch->close = closetls;
	return ch;
}
