/*
 * Copyright (c) 2004-2005 Endace Technology Ltd, Hamilton, New Zealand.
 * All rights reserved.
 *
 * This source code is proprietary to Endace Technology Limited and no part
 * of it may be redistributed, published or disclosed except as outlined in
 * the written contract supplied with this product.
 *
 * $Id: parse_tree_node.c 12693 2010-04-01 02:33:09Z wilson.zhu $
 */

/* File header. */
#include "parse_tree_node.h"


/* CVS header. */
static const char* const kParseTreeNodeCvsHeader __attribute__ ((unused)) = "$Id: parse_tree_node.c 12693 2010-04-01 02:33:09Z wilson.zhu $";


/* Stamped into every node by ptn_init(); valid_header() rejects any pointer whose
 * mMagicNumber does not equal this value. */
static const uint32_t kParseTreeNodeMagicNumber = 0x00ddba11;

/*
 * Display names used by ptn_display(), indexed by the node's mType value.
 * Both the pointers and the strings are const: the table is read-only.
 * NOTE(review): the order is assumed to match the node_t enum declared in
 * parse_tree_node.h -- keep them in sync when adding node types.
 */
static const char* const uTypeString[] =
{
	"PTN_INVALID",
	"PTN_AND",
	"PTN_ANDNOT",
	"PTN_OR",
	"PTN_NOT",
	"PTN_EQUALS",
	"PTN_NOT_EQUALS",
	"PTN_TCP_FLAGS",
	"PTN_HOST",
	"PTN_PROTOCOL",
	"PTN_IP",
	"PTN_PORT"
};

/*
 * Display names used by ptn_display(), indexed by the node's mQualifiers value.
 * NOTE(review): the order is assumed to match the qualifiers_t enum declared in
 * parse_tree_node.h.
 */
static const char* const uQualifierString[] =
{
	"QUAL_INVALID",
	"QUAL_SRC",
	"QUAL_DST",
	"QUAL_BOTH"
};

/* The PtNodePtr handles handed out by ptn_init() point at one of these. */
typedef struct PtHeader PtHeader;
typedef PtHeader* PtHeaderPtr;

/*
 * Internal representation of a parse tree node.
 * Children form a singly linked list: mChild points at the first child, each
 * child's mNext points at the following sibling, and every node in that list
 * points back at its owner through mParent.
 */
struct PtHeader
{
	uint32_t mMagicNumber;    /* kParseTreeNodeMagicNumber while the node is usable (checked by valid_header()). */
	node_t mType;             /* Node kind, one of the PTN_* values. */
	qualifiers_t mQualifiers; /* QUAL_* qualifier (presumably source/destination selection -- confirm). */
	uint8_t mProtocol;        /* IP protocol number; verify_header() expects it nonzero only on PTN_PROTOCOL nodes. */
	uint16_t mPort;           /* Network byte order. */
	in_addr_t mHostname;      /* Network byte order. */
	uint32_t mNetmask;        /* Network byte order. */
	uint8_t mFlagsValue;      /* TCP flag bits expected within mFlagsMask; set by ptn_set_tcp_flags(). */
	uint8_t mFlagsMask;       /* TCP flag bits under test; set by ptn_set_tcp_flags(). */
	PtHeaderPtr mParent;      /* Owning node, or NULL for a root. */
	PtHeaderPtr mChild;       /* First child, or NULL if there are none. */
	PtHeaderPtr mNext;        /* Next sibling, or NULL at the end of the list. */
};


/* Internal routines. */
#if DEBUG_VERSION
/* Assert-based structural consistency check of a node (debug builds only). */
static void verify_header(PtHeaderPtr header);
#endif /* DEBUG_VERSION */
/* Returns 1 if header is non-NULL and carries the node magic number, else 0. */
static unsigned int valid_header(PtHeaderPtr header);
/* Appends a single node to the end of header's sibling list. */
static void add_sibling(PtHeaderPtr header, PtHeaderPtr sibling);
/* Short lowercase name for an IP protocol number, or "none" if unrecognised. */
static const char* protocol_to_string(uint8_t protocol);


/* Implementation of internal routines. */
#if DEBUG_VERSION
/*
 * Debug-only consistency check of a parse tree node. Fails an assert() if:
 *  - header is NULL or does not carry kParseTreeNodeMagicNumber;
 *  - mType lies outside [kFirstNodeType, kLastNodeType];
 *  - the payload fields do not match the node type: PTN_HOST needs a nonzero
 *    host and netmask, PTN_PROTOCOL a nonzero protocol, PTN_PORT a nonzero
 *    port, PTN_TCP_FLAGS a nonzero flags mask and no qualifiers; every other
 *    type must leave all payload fields zero;
 *  - the parent or next sibling is not a valid node;
 *  - any node in the child list is not a valid node or does not point back
 *    at header as its parent.
 */
static void
verify_header(PtHeaderPtr header)
{
	assert(header);
	assert(header->mMagicNumber == kParseTreeNodeMagicNumber);
	assert((kFirstNodeType <= header->mType) && (header->mType <= kLastNodeType));

	if (header->mType == PTN_HOST)
	{
		assert(header->mProtocol == 0);
		assert(header->mPort == 0);
		assert(header->mHostname);
		assert(header->mNetmask);
		assert(header->mFlagsValue == 0);
		assert(header->mFlagsMask == 0);
	}
	else if (header->mType == PTN_PROTOCOL)
	{
		assert(header->mProtocol);
		assert(header->mPort == 0);
		assert(header->mHostname == 0);
		assert(header->mNetmask == 0);
		assert(header->mFlagsValue == 0);
		assert(header->mFlagsMask == 0);
	}
	else if (header->mType == PTN_PORT)
	{
		assert(header->mProtocol == 0);
		assert(header->mPort);
		assert(header->mHostname == 0);
		assert(header->mNetmask == 0);
		assert(header->mFlagsValue == 0);
		assert(header->mFlagsMask == 0);
	}
	else if (header->mType == PTN_TCP_FLAGS)
	{
		/* mFlagsValue may legitimately be zero ("flags clear"), so it is not checked. */
		assert(header->mProtocol == 0);
		assert(header->mPort == 0);
		assert(header->mHostname == 0);
		assert(header->mNetmask == 0);
		assert(header->mQualifiers == 0);
		assert(header->mFlagsMask != 0);
	}
	else
	{
		assert(header->mProtocol == 0);
		assert(header->mPort == 0);
		assert(header->mHostname == 0);
		assert(header->mNetmask == 0);
		assert(header->mFlagsValue == 0);
		assert(header->mFlagsMask == 0);
	}

	if (header->mParent)
	{
		assert(header->mParent->mMagicNumber == kParseTreeNodeMagicNumber);
	}

	if (header->mChild)
	{
		PtHeaderPtr child = NULL;

		/* Check every node in the child list, not only the first one: each must be
		 * a live node (magic number intact) before its mParent is trusted. */
		for (child = header->mChild; child; child = child->mNext)
		{
			assert(child->mMagicNumber == kParseTreeNodeMagicNumber);
			assert(child->mParent == header);
		}
	}

	if (header->mNext)
	{
		assert(header->mNext->mMagicNumber == kParseTreeNodeMagicNumber);
	}
}
#endif /* DEBUG_VERSION */


/*
 * Cheap handle check used by every public routine: a header is usable when it
 * is non-NULL and still carries the magic number stamped by ptn_init().
 * Returns 1 for a usable header, 0 otherwise.
 */
static unsigned int
valid_header(PtHeaderPtr header)
{
	const unsigned int usable =
		(NULL != header) && (kParseTreeNodeMagicNumber == header->mMagicNumber);

	return usable ? 1 : 0;
}


/*
 * Append sibling after the last node of header's sibling list.
 * Only a single, unchained node may be appended (asserted); this could be
 * relaxed later to allow whole lists to be spliced in.
 */
static void
add_sibling(PtHeaderPtr header, PtHeaderPtr sibling)
{
	PtHeaderPtr tail = NULL;

	assert(sibling->mNext == NULL);

	/* Advance to the final node of the chain. */
	for (tail = header; tail->mNext != NULL; tail = tail->mNext)
	{
		/* Empty: the loop header does all the work. */
	}

	tail->mNext = sibling;
}


/*
 * Map an IP protocol number to a short lowercase name for display.
 * Recognises TCP, UDP, ICMP and IGRP; anything else yields "none".
 */
static const char*
protocol_to_string(uint8_t protocol)
{
	return (IPPROTO_TCP == protocol)  ? "tcp"
	     : (IPPROTO_UDP == protocol)  ? "udp"
	     : (IPPROTO_ICMP == protocol) ? "icmp"
	     : (IPPROTO_IGRP == protocol) ? "igrp"
	     : "none";
}



/* Construction/destruction. */
/*
 * Allocate a new parse tree node of the given type.
 * All payload fields and links start out zero/NULL, and the node is stamped
 * with kParseTreeNodeMagicNumber so valid_header() accepts it.
 * Returns NULL if the allocation fails. Release the node with ptn_dispose().
 */
PtNodePtr
ptn_init(node_t type)
{
	/* calloc() allocates and zero-fills in one step (replaces malloc() + memset()). */
	PtHeaderPtr header = calloc(1, sizeof *header);

	if (!header)
	{
		return NULL;
	}

	header->mType = type;
	header->mMagicNumber = kParseTreeNodeMagicNumber;

	return (PtNodePtr) header;
}


/*
 * Free a single parse tree node. Only this node's memory is released; its
 * child and sibling links are not followed. Invalid handles are ignored in
 * release builds (debug builds assert via verify_header()).
 */
void
ptn_dispose(PtNodePtr node)
{
	PtHeaderPtr header = (PtHeaderPtr) node;
	
#if DEBUG_VERSION
	verify_header(header);
#endif /* DEBUG_VERSION */

	if (valid_header(header))
	{
		/* Scrub the magic number before releasing the memory, so a stale handle to
		 * this node is less likely to pass valid_header() later. The volatile
		 * store stops the compiler from discarding it as a dead store before
		 * free(). This is best-effort only: using a node after ptn_dispose() is
		 * still undefined behaviour. */
		*(volatile uint32_t*) &header->mMagicNumber = 0;
		free(header);
	}
}


/* Accessors and mutators. */
/* Return the node's type, or PTN_INVALID when node is not a valid parse tree node. */
node_t
ptn_get_type(PtNodePtr node)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */

	if (!valid_header(hdr))
	{
		return PTN_INVALID;
	}

	return hdr->mType;
}


/*
 * Change the node's type. In release builds an invalid node is silently
 * ignored; debug builds re-verify the node after the update.
 */
void
ptn_set_type(PtNodePtr node, node_t node_type)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;
	const unsigned int is_valid = valid_header(hdr);

	if (is_valid != 0)
	{
		hdr->mType = node_type;
	}

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */
}


/* Return the node's parent, or NULL for a root or an invalid node. */
PtNodePtr
ptn_get_parent(PtNodePtr node)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */

	return valid_header(hdr) ? (PtNodePtr) hdr->mParent : NULL;
}


/*
 * Point node's parent link at parent. Both must be valid nodes, so this cannot
 * be used to detach a node (a NULL parent is rejected by valid_header()).
 * Only the parent link is written; parent's child list is left unchanged.
 */
void
ptn_set_parent(PtNodePtr node, PtNodePtr parent)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;
	PtHeaderPtr new_parent = (PtHeaderPtr) parent;
	const unsigned int both_valid = valid_header(hdr) && valid_header(new_parent);

	if (both_valid)
	{
		hdr->mParent = new_parent;
	}

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */
}


/*
 * Iterate over node's children. Pass NULL as prev_child to get the first child,
 * then pass the previously returned child to get the next one.
 * Returns NULL when there are no more children, or when node (or a non-NULL
 * prev_child) is not a valid parse tree node.
 */
PtNodePtr
ptn_get_next_child(PtNodePtr node, PtNodePtr prev_child)
{
	PtHeaderPtr header = (PtHeaderPtr) node;
	PtHeaderPtr prev_header = (PtHeaderPtr) prev_child;
	
#if DEBUG_VERSION
	verify_header(header);
#endif /* DEBUG_VERSION */

	if (!valid_header(header))
	{
		return NULL;
	}

	if (NULL == prev_header)
	{
		return (PtNodePtr) header->mChild;
	}

	/* The cursor comes from the caller: validate it before dereferencing. */
	if (!valid_header(prev_header))
	{
		return NULL;
	}

	assert(prev_header->mParent == header);

	return (PtNodePtr) prev_header->mNext;
}


/* Return the node's next sibling, or NULL at the end of the list or for an invalid node. */
PtNodePtr
ptn_get_sibling(PtNodePtr node)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */

	if (!valid_header(hdr))
	{
		return NULL;
	}

	return (PtNodePtr) hdr->mNext;
}


/*
 * Append child to the end of node's child list and make node the parent of
 * child and of every node already chained after child via mNext.
 * child must be unattached or already belong to node (asserted).
 * Does nothing if either handle is invalid.
 */
void
ptn_add_child(PtNodePtr node, PtNodePtr child)
{
	PtHeaderPtr parent_hdr = (PtHeaderPtr) node;
	PtHeaderPtr new_child = (PtHeaderPtr) child;
	PtHeaderPtr walker = NULL;

#if DEBUG_VERSION
	verify_header(parent_hdr);
#endif /* DEBUG_VERSION */

	if (!valid_header(parent_hdr) || !valid_header(new_child))
	{
		return;
	}

	assert((new_child->mParent == parent_hdr) || (new_child->mParent == NULL));

	if (parent_hdr->mChild != NULL)
	{
		/* Existing children: append after the last one. */
		add_sibling(parent_hdr->mChild, new_child);
	}
	else
	{
		parent_hdr->mChild = new_child;
	}

	/* Re-parent the new child and anything chained after it. */
	for (walker = new_child; walker != NULL; walker = walker->mNext)
	{
		walker->mParent = parent_hdr;
	}
}


/*
 * Append sibling to the end of node's sibling list and give it node's parent.
 * Does nothing if either handle is invalid. add_sibling() asserts that a single
 * node is appended; in builds where that assert is compiled out a chain may be
 * appended, and every node in it is re-parented (as ptn_add_child does).
 */
void
ptn_add_sibling(PtNodePtr node, PtNodePtr sibling)
{
	PtHeaderPtr header = (PtHeaderPtr) node;
	PtHeaderPtr sibling_header = (PtHeaderPtr) sibling;
	
#if DEBUG_VERSION
	verify_header(header);
#endif /* DEBUG_VERSION */

	if (valid_header(header) && valid_header(sibling_header))
	{
		PtHeaderPtr current = NULL;

		add_sibling(header, sibling_header);
		
		/* Ensure every appended node shares header's parent, not just the first. */
		for (current = sibling_header; current; current = current->mNext)
		{
			current->mParent = header->mParent;
		}
	}
}


/* Return the node's IP protocol number, or 0 for an invalid node. */
uint8_t
ptn_get_protocol(PtNodePtr node)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */

	return valid_header(hdr) ? hdr->mProtocol : 0;
}


/*
 * Store an IP protocol number on the node. Ignored for an invalid node in
 * release builds; debug builds re-verify the node afterwards.
 */
void
ptn_set_protocol(PtNodePtr node, uint8_t protocol)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;
	const unsigned int is_valid = valid_header(hdr);

	if (is_valid != 0)
	{
		hdr->mProtocol = protocol;
	}

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */
}


/* Return the node's port (network byte order), or 0 for an invalid node. */
uint16_t
ptn_get_port(PtNodePtr node)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */

	if (!valid_header(hdr))
	{
		return 0;
	}

	return hdr->mPort;
}


/*
 * Store a port (expected in network byte order, per the mPort field) on the
 * node. Ignored for an invalid node in release builds; debug builds
 * re-verify the node afterwards.
 */
void
ptn_set_port(PtNodePtr node, uint16_t port)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;
	const unsigned int is_valid = valid_header(hdr);

	if (is_valid != 0)
	{
		hdr->mPort = port;
	}

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */
}


/* Return the node's host address (network byte order), or 0 for an invalid node. */
in_addr_t
ptn_get_host(PtNodePtr node)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */

	return valid_header(hdr) ? hdr->mHostname : 0;
}


/* Return the node's netmask (network byte order), or 0 for an invalid node. */
in_addr_t
ptn_get_netmask(PtNodePtr node)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */

	return valid_header(hdr) ? hdr->mNetmask : 0;
}


/*
 * Store a host address and netmask (both expected in network byte order, per
 * the struct fields) on the node. Ignored for an invalid node in release
 * builds; debug builds re-verify the node afterwards.
 */
void
ptn_set_host(PtNodePtr node, in_addr_t host, uint32_t netmask)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;
	const unsigned int is_valid = valid_header(hdr);

	if (is_valid != 0)
	{
		hdr->mHostname = host;
		hdr->mNetmask = netmask;
	}

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */
}


/* Return the node's qualifiers, or QUAL_INVALID for an invalid node. */
qualifiers_t
ptn_get_qualifiers(PtNodePtr node)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */

	if (!valid_header(hdr))
	{
		return QUAL_INVALID;
	}

	return hdr->mQualifiers;
}


/*
 * Set the node's qualifiers. qualifiers must lie within
 * [kFirstQualifier, kLastQualifier] (asserted). Unlike the other setters in
 * this file, the debug consistency check runs before the update, not after it.
 */
void
ptn_set_qualifiers(PtNodePtr node, qualifiers_t qualifiers)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */

	assert((kFirstQualifier <= qualifiers) && (qualifiers <= kLastQualifier));

	if (!valid_header(hdr))
	{
		return;
	}

	hdr->mQualifiers = qualifiers;
}


/* Return the expected TCP flag bits stored by ptn_set_tcp_flags(), or 0 for an invalid node. */
uint8_t
ptn_get_tcp_flags_value(PtNodePtr node)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */

	return valid_header(hdr) ? hdr->mFlagsValue : 0;
}


/* Return the TCP flag bits under test stored by ptn_set_tcp_flags(), or 0 for an invalid node. */
uint8_t
ptn_get_tcp_flags_mask(PtNodePtr node)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */

	return valid_header(hdr) ? hdr->mFlagsMask : 0;
}


/*
 * Record a TCP flags test on the node. The mask is always the tested bits
 * (flags, which must be nonzero). The expected value is flags when the test is
 * "== nonzero" or "!= 0", and 0 when it is "== 0" or "!= nonzero".
 * Any other flag_op trips an assert and leaves the node unchanged.
 * Debug builds re-verify the node afterwards.
 */
void
ptn_set_tcp_flags(PtNodePtr node, uint8_t flags, tcp_flag_op_t flag_op, uint8_t comparand)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;

	assert(flags != 0);

	if (valid_header(hdr))
	{
		const int is_equals = (flag_op == TCP_FLAG_OP_EQUALS);
		const int is_not_equals = (flag_op == TCP_FLAG_OP_NOT_EQUALS);

		if (is_equals || is_not_equals)
		{
			/* Bits must be set exactly when "equals" pairs with a nonzero
			 * comparand, or "not equals" pairs with a zero one. */
			const int expect_set = (is_equals == (comparand != 0));

			hdr->mFlagsMask = flags;
			hdr->mFlagsValue = expect_set ? flags : 0;
		}
		else
		{
			assert(0);
		}
	}

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */
}


/* Return the number of direct children of node (0 for an invalid node). */
uint32_t
ptn_count_children(PtNodePtr node)
{
	PtHeaderPtr hdr = (PtHeaderPtr) node;
	PtHeaderPtr kid = NULL;
	uint32_t total = 0;

#if DEBUG_VERSION
	verify_header(hdr);
#endif /* DEBUG_VERSION */

	if (!valid_header(hdr))
	{
		return 0;
	}

	for (kid = hdr->mChild; kid != NULL; kid = kid->mNext)
	{
		++total;
	}

	return total;
}


/*
 * Print every field of a single node to outfile in human-readable form.
 * Nothing is printed for an invalid node. Children are not visited; see
 * ptn_display_recursive() for a whole subtree.
 */
void
ptn_display(PtNodePtr node, FILE* outfile)
{
	PtHeaderPtr header = (PtHeaderPtr) node;
	
#if DEBUG_VERSION
	verify_header(header);
#endif /* DEBUG_VERSION */

	if (valid_header(header))
	{
		const size_t type_count = sizeof(uTypeString) / sizeof(uTypeString[0]);
		const size_t qualifier_count = sizeof(uQualifierString) / sizeof(uQualifierString[0]);
		const unsigned int type_index = (unsigned int) header->mType;
		const unsigned int qualifier_index = (unsigned int) header->mQualifiers;
		/* Bounds-check the name lookups: a corrupted node, or an enum value added
		 * without updating the tables, would otherwise read past the arrays. */
		const char* type_name = (type_index < type_count) ? uTypeString[type_index] : "PTN_UNKNOWN";
		const char* qualifier_name = (qualifier_index < qualifier_count) ? uQualifierString[qualifier_index] : "QUAL_UNKNOWN";
		struct in_addr address;

		/* %p requires void*, and the uint32_t/in_addr_t fields are widened to
		 * unsigned long so the conversion specifiers match on every platform. */
		fprintf(outfile, "======================================\n");
		fprintf(outfile, "Parse tree node %p\n", (void*) header);
		fprintf(outfile, "    mMagicNumber: %lu (should be %lu)\n", (unsigned long) header->mMagicNumber, (unsigned long) kParseTreeNodeMagicNumber);
		fprintf(outfile, "    mType:        %s (%u)\n", type_name, type_index);
		fprintf(outfile, "    mQualifiers:  %s (%u)\n", qualifier_name, qualifier_index);
		fprintf(outfile, "    mPort:        %u (%u in host byte order)\n", (unsigned int) header->mPort, (unsigned int) ntohs(header->mPort));
		fprintf(outfile, "    mProtocol:    %s (%u)\n", protocol_to_string(header->mProtocol), (unsigned int) header->mProtocol);

		/* inet_ntoa() returns a buffer that the next call overwrites, so each
		 * result is printed before the next address is converted. Memory checkers
		 * may report that buffer as never freed; it is reused across calls, so it
		 * does not grow into a leak. */
		address.s_addr = header->mHostname;
		fprintf(outfile, "    mHostname:    %lu (%s)\n", (unsigned long) header->mHostname, inet_ntoa(address));
		address.s_addr = header->mNetmask;
		fprintf(outfile, "    mNetmask:     %lu (%s)\n", (unsigned long) header->mNetmask, inet_ntoa(address));

		fprintf(outfile, "    mFlagsValue:  %u\n", (unsigned int) header->mFlagsValue);
		fprintf(outfile, "    mFlagsMask:   %u\n", (unsigned int) header->mFlagsMask);
		fprintf(outfile, "    mParent:      %p\n", (void*) header->mParent);
		fprintf(outfile, "    mChild:       %p\n", (void*) header->mChild);
		fprintf(outfile, "    mNext:        %p\n", (void*) header->mNext);
		fprintf(outfile, "======================================\n");
	}
}


/*
 * Print node and, recursively, its entire subtree to outfile in pre-order:
 * the node itself first, then each child's subtree in sibling order.
 * Nothing is printed for an invalid node.
 */
void
ptn_display_recursive(PtNodePtr node, FILE* outfile)
{
	PtHeaderPtr root = (PtHeaderPtr) node;
	PtHeaderPtr kid = NULL;

#if DEBUG_VERSION
	verify_header(root);
#endif /* DEBUG_VERSION */

	if (!valid_header(root))
	{
		return;
	}

	ptn_display((PtNodePtr) root, outfile);

	for (kid = root->mChild; kid != NULL; kid = kid->mNext)
	{
		ptn_display_recursive((PtNodePtr) kid, outfile);
	}
}
