/*
 * Copyright 1999, Alexander Feldman <alex@varna.net>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of Alexander Feldman nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY ALEXANDER FELDMAN AND CONTRIBUTORS ``AS IS''
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL ALEXANDER FELDMAN OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <syslog.h>

#include <errno.h>
#include <string.h>

#ifndef WIN32
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <netinet/in.h>

#else
#include "winsock2.h"
#endif

#include "passwddio.hpp"
#include "strlcpy.h"
#include "sockio.hpp"

#ifndef WIN32
/* POSIX build: socket descriptors are read/written with plain read()/write(). */
#define READ read
#define WRITE write
#else
/* Windows build: use the Winsock recv()/send() calls (no flags) instead. */
#define READ(fd, buf, cnt) recv((fd), (buf), (cnt), 0)
#define WRITE(fd, buf, cnt) send((fd), (buf), (cnt), 0)
#endif

/*
 * Process-wide RSA keys, NULL until set.
 * pLocalRSAKey  - set by read_local_key(); used by read_password() and
 *                 send_local_key(); released by destroy_local_key().
 * pRemoteRSAKey - set by recv_remote_key(); used by write_password();
 *                 released by destroy_remote_key().
 */
CRSAKey *pLocalRSAKey = NULL;
CRSAKey *pRemoteRSAKey = NULL;

#ifdef WIN32
/*
 * Windows DLL entry point. The library needs no per-process or per-thread
 * setup or teardown, so every notification is accepted unconditionally.
 */
BOOL APIENTRY DllMain(HANDLE hModule, DWORD ul_reason_for_call, LPVOID lpReserved)
{
	(void)hModule;
	(void)ul_reason_for_call;
	(void)lpReserved;
	return TRUE;
}
#endif

/*
 * Receives one length-prefixed string from sockfd into the caller's buffer.
 *
 * pszString  - destination buffer of iMaxLength bytes; always NUL-terminated
 *              on return when iMaxLength > 0 and data was received.
 * iMaxLength - size of pszString in bytes, including room for the NUL.
 *
 * Returns the received payload length, or -1 on a read error, when no data
 * buffer was produced, when iMaxLength is not positive, or when the string
 * did not fit (in that last case the truncated string is still stored).
 */
int PASSWDDIO_API read_string(int sockfd, char *pszString, int iMaxLength)
{
	char *pcData = NULL;

	int iResult = read_string(sockfd, &pcData);
	if (iResult < 0 || NULL == pcData || iMaxLength <= 0) {
		free(pcData);
		return -1;
	}

	/* The payload comes straight off the wire and need not contain a NUL,
	 * so bound the scan by the received length instead of trusting it. */
	const char *pcNul = (const char *)memchr(pcData, '\0', (size_t)iResult);
	size_t cbString = (NULL != pcNul) ? (size_t)(pcNul - pcData) : (size_t)iResult;

	/* Need cbString + 1 bytes for the terminator; flag truncation. */
	if (cbString >= (size_t)iMaxLength) {
		cbString = (size_t)iMaxLength - 1;
		iResult = -1;
	}

	memcpy(pszString, pcData, cbString);
	pszString[cbString] = '\0';
	free(pcData);
	return iResult;
}

/*
 * Receives one length-prefixed block from sockfd as character data.
 * On success *ppcData receives the buffer allocated by read_data(), which
 * the caller must free(). Returns read_data()'s result unchanged.
 */
int PASSWDDIO_API read_string(int sockfd, char **ppcData)
{
	void **ppvData = (void **)ppcData;
	int iLength = read_data(sockfd, ppvData);
	return iLength;
}

/*
 * Sends pszString, including its terminating NUL, as one length-prefixed
 * block. Returns write_data()'s result.
 */
int PASSWDDIO_API write_string(int sockfd, char *pszString)
{
	int iLengthWithNul = strlen(pszString) + 1;
	return write_data(sockfd, pszString, iLengthWithNul);
}

/*
 * Reads an int that the peer sent in network byte order (see
 * write_integer()) and stores it in host byte order in *pi.
 *
 * Returns 1 on success, 0 if the read failed or the connection closed
 * before sizeof(int) bytes arrived. *pi is left unchanged on failure.
 */
int PASSWDDIO_API read_integer(int sockfd, int *pi)
{
	int iNetOrder = 0;
	char *pcBuffer = (char *)&iNetOrder;
	int iReceived = 0;

	/* A stream socket may deliver the value in more than one piece. */
	while (iReceived < (int)sizeof(iNetOrder)) {
		int iChunk = READ(sockfd, pcBuffer + iReceived, (int)sizeof(iNetOrder) - iReceived);
		if (iChunk <= 0)
			return 0;
		iReceived += iChunk;
	}

	*pi = ntohl(iNetOrder);
	return 1;
}

/*
 * Sends i in network byte order.
 * Returns 1 if all sizeof(int) bytes were written by one call, else 0.
 */
int PASSWDDIO_API write_integer(int sockfd, int i)
{
	int iNetOrder = htonl(i);
	int iWritten = WRITE(sockfd, (char *)&iNetOrder, sizeof(iNetOrder));
	return (iWritten == (int)sizeof(iNetOrder)) ? 1 : 0;
}

/*
 * Reads a single byte from sockfd into *pc.
 * Returns 1 if exactly one byte was read, else 0.
 */
int PASSWDDIO_API read_byte(int sockfd, unsigned char *pc)
{
	const int iWanted = sizeof(unsigned char);
	int iGot = READ(sockfd, (char *)pc, iWanted);
	return (iGot == iWanted) ? 1 : 0;
}

/*
 * Writes the single byte c to sockfd.
 * Returns 1 if exactly one byte was written, else 0.
 */
int PASSWDDIO_API write_byte(int sockfd, unsigned char c)
{
	const int iWanted = sizeof(unsigned char);
	int iSent = WRITE(sockfd, (char *)&c, iWanted);
	return (iSent == iWanted) ? 1 : 0;
}

/*
 * Receives one block framed as [int length in network order][length bytes].
 *
 * On success *ppvData receives a calloc()ed buffer of length + 1 bytes.
 * The extra trailing byte is always zero, so string consumers find a
 * terminator. The caller must free() the buffer.
 *
 * Returns the payload length, or -1 on a read/allocation error or when the
 * announced length is negative or not below MAX_DATA. *ppvData is only
 * written on success.
 */
int PASSWDDIO_API read_data(int sockfd, void **ppvData)
{
	int iDataLength = -1;
	if (!read_integer(sockfd, &iDataLength))
		return -1;

	/* The length comes from the peer: a negative value would become a huge
	 * calloc() request, and an oversized one must not be reported as success
	 * without a buffer. */
	if (iDataLength < 0 || iDataLength >= MAX_DATA)
		return -1;

	char *pcResult = (char *)calloc((size_t)iDataLength + 1, 1);
	if (NULL == pcResult)
		return -1;

	/* A single read() on a stream socket may return fewer bytes than asked. */
	int iReceived = 0;
	while (iReceived < iDataLength) {
		int iChunk = READ(sockfd, pcResult + iReceived, iDataLength - iReceived);
		if (iChunk <= 0) {
			free(pcResult);
			return -1;
		}
		iReceived += iChunk;
	}

	*ppvData = pcResult;
	return iDataLength;
}

/*
 * Sends one block framed as [int length in network order][iDataLength bytes].
 *
 * Returns iDataLength on success, or -1 if iDataLength is negative or
 * any write fails. A negative length is rejected before anything is sent.
 */
int PASSWDDIO_API write_data(int sockfd, void *pvData, int iDataLength)
{
	if (iDataLength < 0)
		return -1;

	if (!write_integer(sockfd, iDataLength))
		return -1;

	/* write()/send() may accept only part of the buffer; keep going so the
	 * peer never sees a length prefix followed by a short payload. */
	const char *pcData = (const char *)pvData;
	int iSent = 0;
	while (iSent < iDataLength) {
		int iChunk = WRITE(sockfd, pcData + iSent, iDataLength - iSent);
		if (iChunk <= 0)
			return -1;
		iSent += iChunk;
	}

	return iDataLength;
}

/*
 * Reads an RSA-encrypted block from sockfd, decrypts it with the local key
 * and copies the leading string into pszPassword (iMaxLength bytes,
 * always NUL-terminated on success).
 *
 * Returns 1 on success, or when no local key is loaded (then nothing is
 * read and pszPassword is untouched). Returns 0 if reading or decrypting
 * throws, or if iMaxLength leaves no room for a terminator.
 */
int PASSWDDIO_API read_password(int sockfd, char *pszPassword, int iMaxLength)
{
	if (NULL == pLocalRSAKey)
		return 1;

	try
	{
		CRSABlock cRSABlock(*pLocalRSAKey);
		cRSABlock.Read(sockfd);
		cRSABlock.Decrypt();

		/* Read the block first so the stream stays in sync, then bail. */
		if (iMaxLength <= 0)
			return 0;

		/* Reserve one byte: strncpy() does not terminate when it
		 * fills the whole count. */
		unsigned uCopy = Min(cRSABlock.GetDataSize(), (unsigned)(iMaxLength - 1));
		strncpy(pszPassword, (char *)cRSABlock.GetData(), uCopy);
		pszPassword[uCopy] = '\0';
	}
	catch (...)
	{
		return 0;
	}
	return 1;
}

/*
 * Encrypts pszPassword (with its NUL) under the remote RSA key and writes
 * the block to sockfd. The rest of the block after the password is filled
 * from cPRNG.GetRandomData().
 *
 * Returns 1 on success, or when no remote key is set (nothing is sent).
 * Returns 0 if the password does not fit in one block, allocation fails,
 * or encryption/writing throws. Thrown C strings are logged via syslog.
 */
int PASSWDDIO_API write_password(int sockfd, char *pszPassword)
{
	if (NULL == pRemoteRSAKey)
		return 1;

	char *pbBuffer = NULL;
	int iBlockLength = 0;
	int iResult = 0;

	try
	{
		int iPasswordLength = strlen(pszPassword) + 1;
		iBlockLength = (pRemoteRSAKey->GetModulusSize() - 1) / BITSINBYTE - BYTESINWORD * 2;
		if (iPasswordLength > iBlockLength)
			return 0;	/* nothing allocated yet */
		pbBuffer = (char *)calloc(iBlockLength, 1);
		if (NULL == pbBuffer)
			return 0;
		cPRNG.GetRandomData(pbBuffer + iPasswordLength, iBlockLength - iPasswordLength);
		memcpy(pbBuffer, pszPassword, iPasswordLength);
		CRSABlock cRSABlock(*pRemoteRSAKey, pbBuffer, iBlockLength);
		cRSABlock.Encrypt();
		cRSABlock.Write(sockfd);
		iResult = 1;
	}
	catch (const char *p)	/* also matches thrown char* and string literals */
	{
		syslog(LOG_NOTICE, "%s", p);
	}
	catch (...)
	{
	}

	/* Release the plaintext buffer on every path, including exceptions
	 * thrown after allocation. Scrub it first through a volatile pointer
	 * so the stores are not optimized away. */
	if (NULL != pbBuffer) {
		volatile char *pvWipe = pbBuffer;
		for (int i = 0; i < iBlockLength; i++)
			pvWipe[i] = 0;
		free(pbBuffer);
	}
	return iResult;
}

/*
 * Loads the private key from pszKeyFile and makes it the process-wide
 * pLocalRSAKey, deleting any previously loaded local key.
 *
 * Returns 0 on success, 1 if the key object cannot be allocated, or 2 if
 * ReadPrivate() throws. On failure pLocalRSAKey is unchanged and nothing
 * is leaked.
 */
int PASSWDDIO_API read_local_key(char *pszKeyFile)
{
	CRSAKey *pNewRSAKey = NULL;
	try
	{
		pNewRSAKey = new CRSAKey();
	}
	catch (...)
	{
		return 1;	/* standard new reports failure by throwing */
	}
	if (NULL == pNewRSAKey)	/* pre-standard compilers return NULL */
		return 1;

	try
	{
		pNewRSAKey->ReadPrivate(pszKeyFile, true);
	}
	catch (...)
	{
		delete pNewRSAKey;
		return 2;
	}

	delete pLocalRSAKey;
	pLocalRSAKey = pNewRSAKey;
	return 0;
}

/*
 * Receives the peer's public key from sockfd and checks it against the
 * key stored at pszStoredKey.
 * - If pszStoredKey exists (or stat() fails for a reason other than
 *   ENOENT), that stored key must match the received one.
 * - If it does not exist, the received key is written there.
 * On success the received key becomes pRemoteRSAKey, replacing (and
 * deleting) any previous remote key.
 *
 * Returns 0 on success and one of these on failure:
 *   1  the key object cannot be allocated
 *   2  ReadPublicKey() throws
 *   3  the stored key cannot be read
 *   4  the stored key differs from the received one
 *   5  the received key cannot be written
 * On failure pRemoteRSAKey is unchanged and the new key is freed.
 */
int PASSWDDIO_API recv_remote_key(int sockfd, char *pszStoredKey)
{
	CRSAKey *pNewRSAKey = NULL;
	try
	{
		pNewRSAKey = new CRSAKey();
	}
	catch (...)
	{
		return 1;	/* standard new reports failure by throwing */
	}
	if (NULL == pNewRSAKey)	/* pre-standard compilers return NULL */
		return 1;

	try
	{
		pNewRSAKey->ReadPublicKey(sockfd, false);
	}
	catch (...)
	{
		delete pNewRSAKey;
		return 2;
	}

	int iResult = 0;
	struct stat strStat;
	/* errno is only consulted when stat() actually failed. */
	if ((-1 != stat(pszStoredKey, &strStat)) || (errno != ENOENT)) {
		CRSAKey cOldRSAKey;
		try
		{
			cOldRSAKey.ReadPublic(pszStoredKey);
		}
		catch (...)
		{
			iResult = 3;
		}
		if (0 == iResult &&
			(cOldRSAKey.GetModulus() != pNewRSAKey->GetModulus() ||
			 cOldRSAKey.GetPublic() != pNewRSAKey->GetPublic()))
			iResult = 4;
	} else {
		try
		{
			pNewRSAKey->WritePublic(pszStoredKey);
		}
		catch (...)
		{
			iResult = 5;
		}
	}

	if (0 != iResult) {
		delete pNewRSAKey;
		return iResult;
	}

	delete pRemoteRSAKey;
	pRemoteRSAKey = pNewRSAKey;
	return 0;
}

/*
 * Sends the public part of the loaded local key to sockfd.
 * Returns 0 on success, or 1 if no local key is loaded or writing throws.
 */
int PASSWDDIO_API send_local_key(int sockfd)
{
	/* read_local_key() may not have been called (or may have failed). */
	if (NULL == pLocalRSAKey)
		return 1;

	try
	{
		pLocalRSAKey->WritePublicKey(sockfd, false);
	}
	catch (...)
	{
		return 1;
	}
	return 0;
}

/*
 * Releases the loaded local key, if any, and resets pLocalRSAKey to NULL.
 * Safe to call repeatedly: deleting a NULL pointer does nothing.
 */
void PASSWDDIO_API destroy_local_key()
{
	delete pLocalRSAKey;
	pLocalRSAKey = NULL;
}

/*
 * Releases the received remote key, if any, and resets pRemoteRSAKey to
 * NULL. Safe to call repeatedly: deleting a NULL pointer does nothing.
 */
void PASSWDDIO_API destroy_remote_key()
{
	delete pRemoteRSAKey;
	pRemoteRSAKey = NULL;
}
