#include <errno.h>
#include <limits.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <unistd.h>

#include "config.h"
#include "conf.hpp"
#include "pidfile.hpp"
#include "server.hpp"
#include "poll.hpp"
#include "resolv.hpp"
#include "logfile.hpp"
#include "firewall.hpp"
#include "account.hpp"
#include "tokenized_string.hpp"
#include "trim.hpp"
#include "version.hpp"
#include "rsa.hpp"

// Default configuration file; main() lets the -c option replace it.
#define CONF_FILE					"/etc/alligator.conf"
// Pid file name handed to CPidFile in main().
#define PID_FILE 					"alligator.pid"
// Log file name.
#define LOG_FILE					"alligator.log"

// TCP port used when the [Global] section has no "port" entry.
#define DEFAULT_SERVERPORT		1098

// NOTE(review): the globals below have external linkage and may be referenced
// from other translation units (e.g. server code) - confirm before making
// any of them static or changing their types.

// Configuration loaded in main(); also read by signal_proc() and by the
// host-check / EnumSection callbacks in this file.
CConfigurationFile *pConfig = NULL;
// Pid file object created in main(); deleted by the SIGTERM handler.
CPidFile *pPidFile = NULL;
// Path of the configuration file (256 bytes including the terminator);
// starts as CONF_FILE and is overwritten by -c.
char szConfigFile[256] = CONF_FILE;
// Path actually passed to the CConfigurationFile constructor.
char *pszConfig = szConfigFile;

// Port the daemon listens on, read from [Global] "port" in main().
int iServerPort;

// Function prototypes
// Shared handler for SIGTERM, SIGHUP, SIGUSR1, SIGALRM and SIGCHLD.
void signal_proc(int);
// Access-list checks against [Permissions] hosts_allow / [Trusted] trusted_hosts.
bool allow_host(struct in_addr);
bool is_trusted(struct in_addr);

// Callbacks passed to CConfigurationFile::EnumSection().
void DisableEachBlocked(char *, char *);
void RotateStatic(char *, char *);
void CheckStatic(char *, char *);

// Write the command-line help text to standard output.
void usage()
{
	static const char szHelpText[] =
		"Usage: alligator [-c <config file>]\n"
		"       alligator -v\n"
		"Options:\n"
		"  -c <file>   Use alternate config file instead the default\n"
		"              '/etc/alligator.conf' system wide config file.\n"
		"  -v          Print version and exit.\n"
		"Written by: Alexander Feldman <alex@varna.net>\n\n";

	// The text has no conversion specifiers, so fputs() emits exactly what
	// printf() would.
	fputs(szHelpText, stdout);
}

// Fallback for platforms whose socket headers do not declare socklen_t:
// when the build defines _SOCKLEN_T_UNDEFINED (presumably via config.h -
// verify), use int for the address-length argument of accept().
#ifdef _SOCKLEN_T_UNDEFINED
	typedef int socklen_t;
#endif

int main(int argc, char **argv)
{
	struct sockaddr_in serv;
	struct sockaddr_in child;
	int listenfd;
	int sockfd;
	int pidstatus;
	int child_process;

	int w;
	while (-1 != (w = getopt(argc, argv, "c:v")))
		switch (w) {
			case 'c':
				strncpy(szConfigFile, optarg, sizeof(szConfigFile));
			break;
			case 'v':
				printf("%s\n", GetVersion());
				return 0;						// Successful exit
			break;
			default:
				usage();
				return 0;						// Successful exit
			break;
		}
	argc -= optind;							// optind is declared in <unistd.h>
	argv += optind;
	
	if (0 != argc) {
		usage();
		return 0;
	}

	tzset();
	
	pPidFile = new CPidFile(PID_FILE);
	if (NULL == pPidFile) {
		fprintf(stderr, "Error allocating memory...\n");
		return 1;
	}
	if (false == pPidFile->CheckPid()) {
		fprintf(stderr, "Daemon is already started...\n");
		return 1;
	}
	
	char *pszLogFile = new char[128];
	if (NULL == pszLogFile) {
		fprintf(stderr, "Error allocating memory...\n");
		return 1;
	}
	strcpy(pszLogFile, "alligator");
	strcat(pszLogFile, ".log");

#ifndef DEBUG
	int i = getdtablesize();
	for (int j = 0; j < i; j++)
		close(j);
	setsid();
#endif // DEBUG
	
	if (false == logopen(pszLogFile)) {
		delete pszLogFile;
		return 2;
	}
	delete pszLogFile;

#ifndef DEBUG
	int iPid;

	if ((iPid = fork()) > 0)
		return 0;				// Parrent process...
   
	if (iPid < 0) {
		lprintf("Initial fork error...\n");
		delete pPidFile;			
		return 1;
	}
#endif
	
	pPidFile->WritePid();

	signal(SIGTERM, signal_proc);
	signal(SIGHUP, signal_proc);
	signal(SIGINT, SIG_IGN);
	signal(SIGUSR1, signal_proc);
	signal(SIGALRM, signal_proc);
	signal(SIGCHLD, signal_proc);
	
	lprintf("Starting...\n");
	
	char *pszMessage;
	
	pConfig = new CConfigurationFile(pszConfig, &pszMessage);
	if (NULL == pConfig) {
		lprintf("Error allocating memory...\n");
		return 1;
	}
	if (NULL != pszMessage) {
		lprintf("%s\n", pszMessage);		
		return 1;
	}
	
	int iUpdateTime = pConfig->GetInteger("Global", "update_time", 60);
	if (iUpdateTime < 5 || iUpdateTime > 3600) {
		lprintf("Unvalid update time %d. Setting default value of 60 seconds...\n", iUpdateTime);
		iUpdateTime = 60;		
	}
	alarm(iUpdateTime);
	
	if ((listenfd = socket(PF_INET, SOCK_STREAM, 0)) < 0) {
		lprintf("socket error: %s\n", strerror(errno));
		return 1;
	}
	
	iServerPort = pConfig->GetInteger("Global", "port", DEFAULT_SERVERPORT);
	
	memset(&serv, 0, sizeof(serv));
	serv.sin_family = AF_INET;
	serv.sin_addr.s_addr = htonl(INADDR_ANY);
	serv.sin_port = htons(iServerPort);
	
	if (bind(listenfd, (struct sockaddr *)&serv, sizeof(serv)) < 0) {
		lprintf("bind error: %s\n", strerror(errno));
		return 1;
	}

	char *pszPrivateKey = pConfig->GetString("Security", "private_key", "/etc/passwdd.prk");
	char *pszPublicKey = pConfig->GetString("Security", "public_key", "/etc/passwdd.pbk");
	if (false == RSALoadPrivateKey(pszPrivateKey)) {
		lprintf("Error loading private key: %s\n", pszPrivateKey);
		delete pPidFile;
		delete pConfig;
		return 1;
	}
	if (false == RSALoadPublicKey(pszPublicKey)) {
		lprintf("Error loading public key: %s\n", pszPublicKey);
		delete pPidFile;
		delete pConfig;
		return 1;
	}

	if (listen(listenfd, SOMAXCONN) < 0) {
		lprintf("listen error: %s\n", strerror(errno));
		return 1;
	}

	CheckCurrent();
	pConfig->EnumSection("Blocked", DisableEachBlocked);
	pConfig->EnumSection("Static", CheckStatic);
	
	child_process = 0;
	while (0 == child_process) {			// Loop begins here
		memset(&child, 0, sizeof(child));
		socklen_t childlen = sizeof(child);
		if ((sockfd = accept(listenfd, (struct sockaddr *)&child, &childlen)) < 0) {
			if (EINTR != errno)
				lprintf("accept error (%s): %s\n", naddr2str(&child), strerror(errno));
			continue;
		}
		lprintf("opened client connection from %s\n", naddr2str(&child));
		if (false == allow_host(child.sin_addr)) {
			lprintf("unauthorized client connection from %s\n", naddr2str(&child));
			close(sockfd);
			continue;
		}
#ifndef DEBUG		
		pid_t pid;
		if ((pid = fork()) == 0) {			// This is the child
			close(listenfd);
			server(sockfd, is_trusted(child.sin_addr));
			close(sockfd);
			child_process = 1;				// Exit from main loop for the child process
		} else {									// We are parent, go on listening and error checking
			if (pid < 0)
				lprintf("fork error: %s\n", strerror(errno));
			while(waitpid(-1, &pidstatus, WNOHANG) > 0);
			close(sockfd);
			child_process = 0;
		}
#else
		server(sockfd, is_trusted(child.sin_addr));
#endif
	}
	
	return 0;
}

// Shared handler that main() installs for SIGTERM, SIGHUP, SIGUSR1, SIGALRM
// and SIGCHLD.
//   SIGALRM - runs poll() and expired(), then re-arms the alarm with the
//             configured update_time (5..3600 s, otherwise 60).
//   SIGUSR1 - runs RotateStatic() over every entry of the [Static] section.
//   SIGHUP  - no action; catching it only prevents the default action
//             (termination).
//   SIGTERM - logs, closes the log, deletes the globals and exits with 1.
//   SIGCHLD - reaps every terminated child without blocking.
// The handler is re-installed on return, for systems where signal() resets
// the disposition to SIG_DFL.
// NOTE(review): lprintf(), poll(), expired(), the configuration lookups,
// delete and exit() are not async-signal-safe. A future rework should only
// set a flag here and do the work from the main loop.
void signal_proc(int iSignal)
{
	// waitpid(), alarm() and the logging calls may overwrite errno. Restore it
	// so the interrupted code (e.g. an errno check after a failed system call
	// in main()) still sees its own value.
	int iSavedErrno = errno;
	int iUpdateTime;
	int iPidStatus;

	switch (iSignal) {
		case SIGALRM:
			if (NULL == pConfig)			// Configuration not loaded (or already released)
				break;
			poll();
			expired();
			iUpdateTime = pConfig->GetInteger("Global", "update_time", 60);
			if (iUpdateTime < 5 || iUpdateTime > 3600) {
				lprintf("Invalid update time %d. Setting default value of 60 seconds...\n", iUpdateTime);
				iUpdateTime = 60;
			}
			alarm(iUpdateTime);
		break;
		case SIGUSR1:
			// main() installs this handler before the configuration is loaded.
			if (NULL != pConfig)
				pConfig->EnumSection("Static", RotateStatic);
		break;
		case SIGHUP:
		break;
		case SIGTERM:
			lprintf("Exiting on signal %d...\n", iSignal);
			logclose();
			delete pConfig;
			delete pPidFile;
			exit(1);
		break;
		case SIGCHLD:
			while (waitpid(-1, &iPidStatus, WNOHANG) > 0)
				;
		break;
	}
	signal(iSignal, signal_proc);
	errno = iSavedErrno;
}

bool allow_host(struct in_addr lSource)
{
// We will compare the source address, the target address and target netmask
// all in host byte order. There is no matter what ordering we will use
// but is logically more correct to perform it in native ordering.	
	unsigned long int ulSourceHost = ntohl(lSource.s_addr);
	unsigned long int ulTargetHost;
	unsigned long int ulTargetMask;
	char *pszHost;
	bool fgInvert;

	CTokenizedString cAllowedHosts(pConfig->GetString("Permissions", "hosts_allow", "localhost"), ",;");

	if ((pszHost = cAllowedHosts.GetFirstString()))
		do {
			if (false == naddr2h(TrimBoth(pszHost, " \t\x0a\x0d"), &ulTargetHost, &ulTargetMask, &fgInvert)) {
				lprintf("configuration file error, skipping host\n");
				continue;
			}
			if (fgInvert && ((ulSourceHost & ulTargetMask) == ulTargetHost))
				return false;
			if ((ulSourceHost & ulTargetMask) == ulTargetHost)
				return true;
		} while ((pszHost = cAllowedHosts.GetNextString()));
	
	return false;
}

bool is_trusted(struct in_addr lSource)
{
// We will compare the source address, the target address and target netmask
// all in host byte order. There is no matter what ordering we will use
// but is logically more correct to perform it in native ordering.	
	unsigned long int ulSourceHost = ntohl(lSource.s_addr);
	unsigned long int ulTargetHost;
	unsigned long int ulTargetMask;
	char *pszHost;
	bool fgInvert;
	
	CTokenizedString cTrustedHosts(pConfig->GetString("Trusted", "trusted_hosts", ""), ";,");
	
	if ((pszHost = cTrustedHosts.GetFirstString()))
		 do {
			if (false == naddr2h(TrimBoth(pszHost, " \t\x0a\x0d"), &ulTargetHost, &ulTargetMask, &fgInvert)) {
				lprintf("configuration file error, skipping host\n");
				continue;
			}
			if (fgInvert && ((ulSourceHost & ulTargetMask) == ulTargetHost))
				return false;
			if ((ulSourceHost & ulTargetMask) == ulTargetHost)
				return true;
		 } while ((pszHost = cTrustedHosts.GetNextString()));
	
	return false;
}

void DisableEachBlocked(char *pszHost, char *pszPort)
{
	DisableHost(pszHost, atoi(pszPort));
}

// Callback for EnumSection("Static", ...): calls UserUp() for the login/host
// pair. A non-zero result disables the host, zero enables it. Both act on
// the [Local] "port_allow" port, which defaults to -1.
void CheckStatic(char *pszLogin, char *pszHost)
{
	int iUpResult = UserUp(pszLogin, "*", pszHost, "*");
	int iAllowPort = pConfig->GetInteger("Local", "port_allow", -1);

	if (iUpResult == 0) {
		EnableHost(pszHost, iAllowPort);
		return;
	}
	DisableHost(pszHost, iAllowPort);
}

// Callback for EnumSection("Static", ...), run from the SIGUSR1 handler.
// It calls UserDown() for the login/host pair, then re-evaluates the same
// pair through CheckStatic().
void RotateStatic(char *pszLogin, char *pszHost)
{
	const char *pszAny = "*";			// Wildcard used for both remaining arguments

	UserDown(pszLogin, pszAny, pszHost, pszAny);
	CheckStatic(pszLogin, pszHost);
}