#include	<stdio.h>
#include	<stdlib.h>
#include	<string.h>
#include	<ctype.h>
#include	<sys/time.h>
#include	<stdarg.h>

#include	"smbhdr.h"
#include	"smbpd.h"
#include	"lpt.h"
#include	"netio.h"
#include	"cmd.h"
#include	"trans.h"
#include	"util.h"

/*
 * strn("literal") expands to TWO comma-separated arguments: the literal
 * itself and its length without the terminating NUL.  It is meant to be
 * dropped into strncmp() as the last two arguments, e.g.
 *	strncmp(s, strn("abc")) != 0
 * Any comparison must stay OUTSIDE the call's parentheses; writing
 * strn("abc") != 0 inside the call silently turns the length into 0 or 1.
 * Because it has to expand to two arguments it cannot be parenthesized.
 */
#define	strn(s)			s, sizeof(s)-1
/*
 * Round n up to the next multiple of 4.  The argument is parenthesized
 * so that compound expressions (a | b, c ? x : y, ...) are rounded as a
 * whole instead of being split by operator precedence.
 */
#define	round_to_long(n)	(((n) + 3) & ~0x3)

/*
 * One accepted descriptor string for one information level.
 * check_format() walks an array of these until it reaches an entry whose
 * ulevel is negative, so every table must end with such an entry.
 */
struct fmt_desc
{
	int		ulevel;		/* information level; negative ends a table */
	char		*format;	/* descriptor string accepted for that level */
};

/*
 * Return a freshly allocated copy of s; never returns a null pointer.
 * If the copy cannot be allocated a message is written to stderr and the
 * process is aborted.  The caller releases the copy with free().
 */
static char *safe_strdup(const char *s)
{
	char	*copy;

	copy = strdup(s);
	if (copy != 0)
		return (copy);
	fprintf(stderr, "Out of heap memory in safe_strdup\n");
	abort();
}

/*
 *	Routines for handling packed structures adapted from Samba package
 */

/*
 * Return the number of bytes the fixed-size part of a packed structure
 * occupies for descriptor string p (0 for a null pointer).
 *
 *	W, N		2 bytes (word / count of substructures)
 *	D, z, l		4 bytes (double word / string offset / user data offset)
 *	b<count>	4 bytes (offset to data; the count is only skipped)
 *	B<count>	<count> bytes, or 1 byte when no positive count follows
 *
 * Any other character contributes nothing.
 */
static int count_struct(char *p)
{
	int	total = 0;
	int	repeat;
	char	code;

	if (p == 0)
		return (0);
	for (;;)
	{
		code = *p;
		if (code == '\0')
			break;
		++p;			/* p now points at an optional counter */
		if (code == 'W' || code == 'N')
			total += 2;
		else if (code == 'D' || code == 'z' || code == 'l')
			total += 4;
		else if (code == 'b')
		{
			total += 4;
			(void)strtol(p, &p, 10);	/* step over the counter */
		}
		else if (code == 'B')
		{
			repeat = strtol(p, &p, 10);
			total += repeat > 0 ? repeat : 1;
		}
	}
	return (total);
}

/*
 * Look for an entry in table d whose level equals ulevel and whose
 * descriptor string is exactly format.  The table is scanned up to the
 * first entry with a negative ulevel.
 * Returns count_struct(format) when a match is found, -1 otherwise.
 */
static int check_format(struct fmt_desc *d, int ulevel, char *format)
{
	struct fmt_desc	*entry;

	for (entry = d; entry->ulevel >= 0; ++entry)
	{
		if (entry->ulevel != ulevel)
			continue;
		if (strcmp(format, entry->format) != 0)
			continue;
		return (count_struct(format));
	}
	return (-1);
}

/*
 * Pack the variable arguments as directed by the descriptor string format.
 *
 * Fixed-size fields are written consecutively starting at desc; variable
 * length items are appended starting at data, and their offset relative
 * to base is stored in the corresponding 4-byte field.
 *
 * Arguments expected per descriptor character:
 *	W, N		int, stored as 2 bytes
 *	D		long, stored as 4 bytes
 *	B<count>	char *; a single byte is copied when count is absent
 *			or <= 1, otherwise Strncpy() is given count bytes
 *	z		char *; the string and its NUL go to the data area
 *	l		char * followed by an int length; copied to the data area
 *	b<count>	char *; count bytes are copied to the data area
 * A null pointer for B or z is treated as the empty string.  The caller
 * must pass exactly the arguments the descriptor calls for.
 *
 * Returns the number of bytes appended to the data area.
 */
int pack_struct(char *format, unsigned char *base,
	unsigned char *desc, unsigned char *data, ...)
{
	va_list		ap;
	unsigned char	*fixed = desc;	/* next free byte of the fixed part */
	unsigned char	*heap = data;	/* next free byte of the data area */
	unsigned long	dword, offset;
	int		word, len;
	char		code, *src;

	va_start(ap, data);
	while ((code = *format++) != '\0')
	{
		switch (code)
		{
		case 'W':		/* word (2 byte) */
		case 'N':		/* count of substructures (word) at end */
			/* fetch into a local first: hstouc2 may be a macro */
			word = va_arg(ap, int);
			hstouc2(word, fixed);
			fixed += 2;
			break;
		case 'D':		/* double word (4 byte) */
			dword = va_arg(ap, long);
			hltouc4(dword, fixed);
			fixed += 4;
			break;
		case 'B':		/* byte (with optional counter) */
			len = strtol(format, &format, 10);
			if (len <= 0)
				len = 1;
			src = va_arg(ap, char *);
			if (src == 0)
				src = "";
			if (len == 1)	/* lone B: just copy the byte over */
				*fixed = *src;
			else
				Strncpy(fixed, src, len);
			fixed += len;
			break;
		case 'z':		/* offset to zero terminated string */
		case 'l':		/* offset to user data */
		case 'b':		/* offset to data (with counter) */
			src = va_arg(ap, char *);
			if (code == 'z')
			{
				if (src == 0)
					src = "";
				len = strlen(src) + 1;
				Strncpy(heap, src, len);
			}
			else
			{
				if (code == 'l')	/* length is the next argument */
					len = va_arg(ap, int);
				else			/* length is in the descriptor */
					len = strtol(format, &format, 10);
				memcpy(heap, src, len);
			}
			offset = (unsigned long)(heap - base);
			hltouc4(offset, fixed);
			fixed += 4;
			heap += len;
			break;
		}
	}
	va_end(ap);
	return (heap - data);
}

/*
 * Fill in the count and offset words of a transaction response.
 *
 * Assumes the buffer is big enough for the whole reply, so the "total"
 * counts and the "this buffer" counts are the same.  The parameter block
 * starts at sh->param12 and is padded to a multiple of 4 bytes; the data
 * block follows the padded parameters.
 *
 * Returns the number of bytes in setup + parameters (+ padding) + data.
 */
static int set_counts(struct smbhdr *sh, int nparam, int nsetup, int ndata)
{
	int		padded = round_to_long(nparam);
	unsigned char	*bytecount;
	int		total;

	sh->wordcount = 10 + nsetup;	/* ten fixed words plus setup words */
	hstouc2(nparam, sh->param1);	/* total param bytes */
	hstouc2(ndata, sh->param2);	/* total data bytes */
	hstouc2(0, sh->param3);		/* reserved */
	hstouc2(nparam, sh->param4);	/* this buffer param bytes */
	hstouc2(sh->param12 - sh->protocol, sh->param5);
					/* offset of parameter block */
	hstouc2(0, sh->param6);
	hstouc2(ndata, sh->param7);	/* this buffer data bytes */
	hstouc2((sh->param12 - sh->protocol) + padded, sh->param8);
					/* offset of data block */
	hstouc2(0, sh->param9);
	sh->param10[0] = nsetup;
	bytecount = sh->param11 + nsetup * 2;	/* skip past setup words */
	hstouc2(padded + ndata, bytecount);	/* set byte count */
	total = nsetup * 2 + padded + ndata;
	return (total);
}

static int api_rnetshareenum(struct smbhdr *sh, unsigned char *str1,
	unsigned char *str2, long *smblen)
{
	unsigned char		*base, *p, *q;
	unsigned short		ulevel, nprinters;
	int			i, n, anparam, formatsize, ndata;
	struct lpt_info		*l;

	if (strncmp(str1, strn("WrLeh")) != 0)
		return (ERRSRV | ERRerror);
	p = str2 + strlen(str2) + 1;
	/* should double check str2 */
	ulevel = uc2tohs(p);
	for (nprinters = 0, i = 0; i < MAXLPT; ++i)
	{
		l = &lpt[i];
		if (l->avail == FREE || l->avail == BUSY)
			++nprinters;
	}
	str2 = safe_strdup(str2);
	printf("%s RNetShareEnum, %d\n", ptime(), ulevel);
	formatsize = count_struct(str2);
	anparam = round_to_long(8);		/* 8 bytes in param block */
	p = base = sh->param12 + anparam;	/* start of descriptor area */
	q = p + formatsize * nprinters;
	ndata = 0;
	for (i = 0; i < MAXLPT; ++i)
	{
		l = &lpt[i];
		if (l->avail != FREE && l->avail != BUSY)
			continue;
		n = pack_struct(str2, base, p, q,
			l->name, "",		/* B13B: printer name */
			1,			/* W: STYPE_PRINTQ */
			"");			/* z: comment */
		p += formatsize;
		q += n;
		ndata += n;
	}
	ndata += formatsize * nprinters;
	hstouc2(NERR_Success, sh->param12);
	hstouc2(0, sh->param13);
	hstouc2(nprinters, sh->param14);
	hstouc2(nprinters, sh->param15);
	*smblen += set_counts(sh, 8, 0, ndata);
	free(str2);
	return (0);
}

/*
 * RNetServerGetInfo: describe this server.
 *
 * str1 is the parameter descriptor (must start with "WrLh"), str2 the
 * data descriptor; the 2-byte information level follows str2's NUL.
 * Only levels 0 ("B16") and 1 ("B16BBDz") are answered, because those
 * are the only descriptors the argument lists below provide values for.
 * Any other level or descriptor is refused instead of letting
 * pack_struct() read arguments that were never passed.
 *
 * The response is built in place at sh->param12 and *smblen is advanced
 * by the size reported by set_counts().
 *
 * Returns 0 on success, ERRSRV | ERRerror for a request that is not
 * understood.
 */
static int api_rnetservergetinfo(struct smbhdr *sh, unsigned char *str1,
	unsigned char *str2, long *smblen)
{
	unsigned char		*p, *q;
	unsigned short		ulevel;
	int			anparam, formatsize, ndata;
	static struct fmt_desc	desc[] =
	{
		{ 0, "B16" },
		{ 1, "B16BBDz" },
		{ -1, 0 }	/* terminator required by check_format() */
	};

	if (strncmp(str1, strn("WrLh")) != 0)
		return (ERRSRV | ERRerror);
	p = str2 + strlen(str2) + 1;
	ulevel = uc2tohs(p);
	anparam = round_to_long(6);	/* 6 bytes in param block */
	p = sh->param12 + anparam;	/* start of descriptor area */
	if ((formatsize = check_format(desc, ulevel, str2)) < 0)
		return (ERRSRV | ERRerror);
	/* now make copies of the strings because the returned data will
	   overwrite that space; we don't need str1 anymore */
	str2 = safe_strdup(str2);
	printf("%s RNetServerGetInfo, %d\n", ptime(), ulevel);
	q = p + formatsize;
	if (ulevel == 0)
		ndata = pack_struct(str2, p, p, q,
			my_host_name);		/* B16: hostname */
	else
		ndata = pack_struct(str2, p, p, q,
			my_host_name,		/* B16: hostname */
			"\004",			/* B: major version */
			"\001",			/* B: minor version */
			/* D: server type; pack_struct() fetches a long */
			(long)(SV_TYPE_SERVER | SV_TYPE_PRINTQ_SERVER |
			SV_TYPE_LOCAL_LIST_ONLY),
			PROGRAM " " VERSION);	/* z: comment */
	ndata += formatsize;
	hstouc2(NERR_Success, sh->param12);
	hstouc2(0, sh->param13);
	hstouc2(ndata, sh->param14);
	*smblen += set_counts(sh, 6, 0, ndata);
	free(str2);
	return (0);
}

/*
 * NetWkstaGetInfo: describe this machine as a workstation.
 *
 * str1 is the parameter descriptor (must start with "WrLh"), str2 the
 * data descriptor (must start with "zzzBBzz"); the 2-byte information
 * level follows str2's NUL and must be 10.
 *
 * The response is built in place at sh->param12 and *smblen is advanced
 * by the size reported by set_counts().
 *
 * Returns 0 on success, ERRSRV | ERRerror for a request that is not
 * understood.
 */
static int api_netwkstagetinfo(struct smbhdr *sh, unsigned char *str1,
	unsigned char *str2, long *smblen)
{
	unsigned char		*p, *q;
	unsigned short		ulevel;
	int			anparam, ndata;

	p = str2 + strlen(str2) + 1;
	ulevel = uc2tohs(p);
	/* strn() expands to "string, length", so each "!= 0" has to sit
	   outside the strncmp() call or it becomes part of the length */
	if (ulevel != 10 || strncmp(str1, strn("WrLh")) != 0 ||
		strncmp(str2, strn("zzzBBzz")) != 0)
		return (ERRSRV | ERRerror);
	anparam = round_to_long(6);	/* 6 bytes in param block */
	p = sh->param12 + anparam;	/* start of descriptor area */
	/* 22 = fixed size of "zzzBBzz": five 4-byte offsets + two bytes */
	q = p + 22;
	printf("%s NetWkstaGetInfo, %d\n", ptime(), ulevel);
	ndata = pack_struct("zzzBBzz", p, p, q,
		my_host_name,		/* z: local machine */
		"*",			/* z: sesssetup user */
		my_work_group,		/* z: workgroup */
		"\004", "\001",		/* BB: version */
		my_work_group,		/* z: login domain? */
		"");			/* z: ? */
	ndata += 22;
	hstouc2(NERR_Success, sh->param12);
	hstouc2(0, sh->param13);
	hstouc2(ndata, sh->param14);
	*smblen += set_counts(sh, 6, 0, ndata);
	return (0);
}

static int api_dosprintqgetinfo(struct smbhdr *sh, unsigned char *str1,
	unsigned char *str2, long *smblen)
{
	unsigned char		*str3, *qname, *p, *q;
	unsigned short		ulevel, buflen;
	int			printer, anparam, njobs;
	int			formatsize, subformatsize, ndata;
	char			*pcomment;
	enum lpt_state		pstate;
	struct smbconn		*client;
	static struct fmt_desc	desc[] =
	{
		{ 0, "B13" },
		{ 1, "B13BWWWzzzzzWW" },
		{ 2, "B13BWWWzzzzzWN" },
		{ 3, "zWWWWzzzzWWzzl" },
		{ 4, "zWWWWzzzzWNzzl" },
		{ 5, "z" },
		{-1, 0 }
	};
	static struct fmt_desc	sub_desc[] =
	{
		{ 1, "" },
		{ 2, "WB21BB16B10zWWzDDz" },
		{ 3, "" },
		{ 4, "WWzWWDDzz" },
		{ -1, 0 }
	};

	qname = str2 + strlen(str2) + 1;
	p = qname + strlen(qname) + 1;
	ulevel = uc2tohs(p);
	buflen= uc2tohs(p+2);
	str3 = p + 4;
	if (strncmp(str1, strn("zWrLh")) != 0)
		return (ERRSRV | ERRerror);
	/* sanity checks */
	if (ulevel > 5)
		return (ERRSRV | ERRerror);
	/* if ulevel != 2 and ulevel != 4, then str3 is not needed */
	if (ulevel != 2 && ulevel != 4)
		str3 = "";
	if (check_format(desc, ulevel, str2) < 0 ||
		check_format(sub_desc, ulevel, str3) < 0)
		return (ERRSRV | ERRerror);
	/* now make copies of the strings because the returned data will
	   overwrite that space; we don't need str1 anymore */
	str2 = safe_strdup(str2);
	str3 = safe_strdup(str3);
	qname = safe_strdup(qname);
	printf("%s DosPrintQGetInfo, <%s> %d %d\n",
		ptime(), qname, ulevel, buflen);
	njobs = 0;
	if ((printer = find_printer(qname)) < 0
		|| (pstate = lpt[printer].avail) == NONESUCH
		|| lpt[printer].avail == DISABLED)
	{
		pcomment = "Not valid";
		pstate = LPSTAT_ERROR;
	}
	else if ((pstate = lpt[printer].avail) == BUSY)
	{
		pcomment = "In use";
		pstate = LPSTAT_STOPPED;
		if ((client = lpt[printer].client) != 0)
			njobs = 1;
	}
	else
	{
		pcomment = "Free";
		pstate = LPSTAT_OK;
	}
	anparam = round_to_long(6);	/* 6 bytes in param block */
	p = sh->param12 + anparam;	/* start of descriptor area */
	formatsize = count_struct(str2);
	subformatsize = njobs > 0 ? count_struct(str3) : 0;
	q = p + formatsize + subformatsize;
	switch (ulevel)
	{
	case 0:
		ndata = pack_struct(str2, p, p, q,
			qname);		/* B13: printer name */
		break;
	case 1:
	case 2:
		ndata = pack_struct(str2, p, p, q,
			qname,		/* B13: printer name */
			"",		/* B: alignment */
			5,		/* W: priority */
			0,		/* W: start time */
			0,		/* W: until time */
			"",		/* z: pSepFile */
			PROGRAM,	/* z: pPrProc */
			qname,		/* z: pDestinations */
			"",		/* z: pParms */
			pcomment,	/* z: pComment */
			pstate,		/* W: status */
			njobs);		/* W/N: count */
		break;
	case 3:
	case 4:
		{
		unsigned char	drivdata[4+4+32];

		hltouc4((unsigned long)(sizeof(drivdata)), drivdata);
		hltouc4(1000L, drivdata+4);
		memset(drivdata+8, 0, 32);
		strcpy(drivdata+8, "NULL");
		ndata = pack_struct(str2, p, p, q,
			qname,		/* z: printer name */
			5,		/* W: priority */
			0,		/* W: start time */
			0,		/* W: until time */
			5,		/* W: pad1 */
			"",		/* z: pszSepFile */
			PROGRAM,	/* z: pszPrProc */
			"",		/* z: pszParms */
			pcomment,	/* z: pszComment */
			pstate,	/* W: status */
			njobs,		/* W/N: cJobs */
			qname,		/* z: pszPrinters */
			"NULL",		/* z: pszDriverName */
			drivdata,	/* l: lDrivdata */
			sizeof(drivdata));
			break;
		}
	}
	if (str3[0] != '\0' && njobs > 0)
	{
		q += ndata;
		switch (ulevel)
		{
		case 2:
			ndata += pack_struct(str3, p, p + formatsize, q,
				0,		/* W: uJobId */
				client->cname,	/* B21: szUsername */
				"",		/* B: pad */
				"",		/* B16: szNotifyName */
				"PM_Q_RAQ",	/* B10: szDataType */
				"",		/* z: pszParams */
				1,		/* W: uPosition */
				LPQ_PRINTING,	/* W: fsStatus */
				"",		/* z: pszStatus */
				(unsigned long)(client->starttime - my_tzsec),
						/* D: ulSubmitted */
				client->joblen,	/* D: ulSize */
				lpt[printer].jobname);	/* z: pszComment */
			break;
		case 4:
			ndata += pack_struct(str3, p, p + formatsize, q,
				0,		/* W: uJobId */
				client->cname,	/* B21: szUsername */
				"",		/* B: pad */
				"",		/* B16: szNotifyName */
				"PM_Q_RAQ",	/* B10: szDataType */
				"",		/* z: pszParams */
				1,		/* W: uPosition */
				LPQ_PRINTING,	/* W: fsStatus */
				"",		/* z: pszStatus */
				(unsigned long)client->starttime,
						/* D: ulSubmitted */
				client->joblen,	/* D: ulSize */
				lpt[printer].jobname,	/* z: pszComment */
				0,		/* W: uPriority */
				client->cname,	/* z: pszUsername */
				1,		/* W: uPosition */
				LPQ_PRINTING,	/* W: fsStatus */
				(unsigned long)client->starttime,
						/* D: ulSubmitted */
				client->joblen,	/* D: ulSize */
				lpt[printer].jobname,	/* z: pszComment */
				client->pname);	/* z: pszDocument */
			break;
		}
	}
	ndata += formatsize + subformatsize;
	hstouc2(NERR_Success, sh->param12);
	hstouc2(0, sh->param13);
	hstouc2(ndata, sh->param14);
	*smblen += set_counts(sh, 6, 0, ndata);
	free(str2);
	free(str3);
	free(qname);
	return (0);
}

/*
 * PrintJobInfo: the request fields are decoded (and printed when DEBUG
 * is defined) but nothing is changed; the reply is always NERR_Success
 * with a 4-byte parameter block and no data.
 *
 * Three 2-byte values follow str2's NUL: job id, information level and
 * function code.  *smblen is advanced by the size reported by
 * set_counts().  Always returns 0.
 */
static int api_printjobinfo(struct smbhdr *sh, unsigned char *str1,
	unsigned char *str2, long *smblen)
{
	unsigned char		*fields;
	unsigned short		jobid, level, func;

	fields = str2 + strlen(str2) + 1;	/* just past the descriptor */
	jobid = uc2tohs(fields);
	level = uc2tohs(fields + 2);
	func = uc2tohs(fields + 4);
#ifdef	DEBUG
	printf(" PrintJobInfo <%s> <%s> %d %d %x\n", str1, str2, jobid,
		level, func);
#endif
	hstouc2(NERR_Success, sh->param12);
	hstouc2(0, sh->param13);
	*smblen += set_counts(sh, 4, 0, 0);
	return (0);
}

/*
 * Build the reply for an API call this server does not implement (or
 * declined to answer): status NERR_notsupported, a 4-byte parameter
 * block and no data.
 *
 * The 2-byte level and buffer length that follow str2's NUL are decoded
 * only so they can be printed when DEBUG is defined.  *smblen is
 * advanced by the size reported by set_counts().  Always returns 0.
 */
static int api_not_supported(int apicmd, struct smbhdr *sh, unsigned char *str1,
	unsigned char * str2, long *smblen)
{
	unsigned char		*fields;
	unsigned short		level, space;

	fields = str2 + strlen(str2) + 1;	/* just past the descriptor */
	level = uc2tohs(fields);
	space = uc2tohs(fields + 2);
#ifdef	DEBUG
	printf(" Unsupported trans %d <%s> <%s> %d %d\n", apicmd, str1, str2,
		level, space);
#endif
	hstouc2(NERR_notsupported, sh->param12);
	hstouc2(0, sh->param13);
	*smblen += set_counts(sh, 4, 0, 0);
	return (0);
}

/*
 * Handle a transaction request addressed to "\PIPE\LANMAN".
 *
 * The count/offset words are decoded from sh, the transaction name is
 * checked, and the API number found at the start of the parameter block
 * selects a handler.  The parameter block holds the 2-byte API number
 * followed by two NUL-terminated descriptor strings (str1, str2).
 *
 * *smblen is first set to 23 and then advanced by the handler.
 *
 * Returns ERRSRV | ERRerror when the name is not "\PIPE\LANMAN" or the
 * request does not fit in this one buffer.  A handler that returns
 * non-zero, and any API without a handler, is answered through
 * api_not_supported(), whose result is returned instead.
 *
 * d, client and op are not used here.
 */
int do_transaction(struct cmdentry *d,
	struct smbconn *client, struct smbhdr *sh,
	unsigned char *op, long *smblen)
{
	unsigned char		*cursor, *str1, *str2;
	int			tpscnt, tdscnt, mprcnt, mdrcnt, msrcnt, flags;
	int			pscnt, psoff, dscnt, dsoff, suwcnt, bytecount;
	int			apicmd;
	int			status = 1;	/* non-zero: not handled yet */

	*smblen = 23;

	/* decode the request words; several are read but not used below */
	tpscnt = uc2tohs(sh->param1);
	tdscnt = uc2tohs(sh->param2);
	mprcnt = uc2tohs(sh->param3);
	mdrcnt = uc2tohs(sh->param4);
	msrcnt = sh->param5[0];
	flags = uc2tohs(sh->param6);
	pscnt = uc2tohs(sh->param10);
	psoff = uc2tohs(sh->param11);
	dscnt = uc2tohs(sh->param12);
	dsoff = uc2tohs(sh->param13);
	suwcnt = sh->param14[0];

	/* the byte count sits right after the setup words */
	cursor = (unsigned char *)&sh->param15 + suwcnt * 2;
	bytecount = uc2tohs(cursor);
	cursor += 2;		/* now at the transaction name */
	if (verbose >= 2)
		printf("%s Transaction: name:%s"
		" tpscnt:%d tdscnt:%d pscnt:%d dscnt:%d suwcnt:%d bytecount:%d\n",
		ptime(), cursor, tpscnt, tdscnt, pscnt, dscnt, suwcnt, bytecount);
	if (strcmp(cursor, "\\PIPE\\LANMAN") != 0)
		return (ERRSRV | ERRerror);
	if (pscnt < tpscnt || dscnt < tdscnt)
	{
		printf("%s Sorry, transaction too big to handle\n", ptime());
		return (ERRSRV | ERRerror);
	}

	/* parameter block: API number, then the two descriptor strings */
	cursor = (unsigned char *)&sh->protocol + psoff;
	apicmd = uc2tohs(cursor);
	str1 = cursor + 2;
	str2 = str1 + strlen(str1) + 1;

	switch (apicmd)
	{
	case RNetShareEnum:
		status = api_rnetshareenum(sh, str1, str2, smblen);
		break;
	case RNetServerGetInfo:
		status = api_rnetservergetinfo(sh, str1, str2, smblen);
		break;
	case NetWkstaGetInfo:
		status = api_netwkstagetinfo(sh, str1, str2, smblen);
		break;
	case DosPrintQGetInfo:
		status = api_dosprintqgetinfo(sh, str1, str2, smblen);
		break;
	case PrintJobInfo:
		status = api_printjobinfo(sh, str1, str2, smblen);
		break;
	/* to be implemented */
	case NetServerEnum:
	default:
		break;		/* status stays non-zero */
	}
	if (status == 0)
		return (0);
	return (api_not_supported(apicmd, sh, str1, str2, smblen));
}
