/*
 * firewall.c - Packet filtering for diald.
 *
 * Copyright (c) 1994 Eric Schenk.
 * All rights reserved.
 *
 * Permission is hereby granted, without written agreement and without
 * license or royalty fees, to use, copy, modify, and distribute this
 * software and its documentation for any purpose, provided that the
 * above copyright notice and the following two paragraphs appear in
 * all copies of this software.
 * 
 * IN NO EVENT SHALL ERIC SCHENK BE LIABLE TO ANY PARTY FOR
 * DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT
 * OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF ERIC
 * SCHENK HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * ERIC SCHENK SPECIFICALLY DISCLAIMS ANY WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS FOR A PARTICULAR PURPOSE.  THE SOFTWARE PROVIDED HEREUNDER IS
 * ON AN "AS IS" BASIS, AND ERIC SCHENK HAS NO OBLIGATION TO
 * PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
 */

#include "diald.h"
#include <string.h>

/*
 * Per-unit firewall state, indexed by unit number. It is set up lazily by
 * init_units() the first time check_firewall() or ctl_firewall() runs.
 */
static FW_unit units[FW_NRUNIT];
static int initialized = 0;	/* nonzero once init_units() has run */
/*
 * Impulse generator parameters. fw_impulse_update() copies them from the
 * currently active impulse rule, and del_impulse() consumes them. They are
 * single file-level values, not per-unit fields. They are not static,
 * presumably because other modules read them (confirm).
 *   impulse_init_time - initial impulse period (the rule's timeout2);
 *                       del_impulse() clears it when an impulse period ends
 *   impulse_time      - regular impulse period (the rule's timeout)
 *   impulse_fuzz      - timer length del_impulse() arms after an impulse
 *                       period ends (the rule's fuzz)
 */
int impulse_init_time = 0;
int impulse_time = 0;
int impulse_fuzz = 0;

/* Forward declaration: add_connection() installs it as a timer callback. */
static void del_connection(FW_Connection *);

/*
 * Initialize the units.
 */

static void init_units(void)
{
    int i;

    for (i = 0; i < FW_NRUNIT; i++) {
	units[i].used = 0;
	units[i].filters = NULL;
	units[i].last = NULL;
	units[i].connections = malloc(sizeof(FW_Connection));
	units[i].nrules = 0;
	units[i].nfilters = 0;
	if (!units[i].connections) {
	    syslog(LOG_ERR,"Out of memory! AIIEEE!");
	    die(1);
	}
	units[i].connections->next = units[i].connections->prev
	    = units[i].connections;
    }
    initialized = 1;
}

/*
 * Is the time given by "clock" inside any disjunct of the time-slot list
 * "slot"?
 *
 * A disjunct matches when two things hold. First, the local time of day,
 * in seconds since midnight, lies in [start, end] inclusive. Second, the
 * bits for today's weekday, day of month and month are all set in its
 * wday/mday/month masks.
 *
 * Returns 1 at the first matching disjunct. Returns 0 if none match,
 * which includes an empty (NULL) list.
 */

static unsigned int in_slot(FW_Timeslot *slot, time_t *clock)
{
    struct tm *ltime = localtime(clock);
    /* seconds since local midnight ("tod" avoids shadowing ctime()) */
    int tod = ltime->tm_hour*60*60+ltime->tm_min*60+ltime->tm_sec;

    /* Debug output; "#ifdef 0" is not a valid directive, so use "#if 0". */
#if 0
    syslog(LOG_INFO,"slot check: %d %d %d %d",
	ltime->tm_sec+ltime->tm_min*60+ltime->tm_hour*60*60, ltime->tm_wday,
	ltime->tm_mday, ltime->tm_mon);
#endif

    while (slot) {
#if 0
    syslog(LOG_INFO,"slot def: %d %d %x %x %x",
	slot->start, slot->end, slot->wday, slot->mday, slot->month);
#endif
	if ((slot->start <= tod)
	&&  (tod <= slot->end)
	&&  (slot->wday & (1<<ltime->tm_wday))
	&&  (slot->mday & (1<<(ltime->tm_mday-1)))	/* tm_mday is 1-based */
	&&  (slot->month & (1<<ltime->tm_mon))) {
	    return 1;
	}
	slot = slot->next;
    }

    return 0;
}


/*
 * How long until the time-slot list "slot" next becomes active?
 *
 * Returns 0 if *clock already lies inside the slot (see in_slot()).
 * Otherwise it returns the number of seconds until the earliest disjunct
 * that applies today (weekday, day-of-month and month masks all match)
 * and has not started yet. If that is later than the next local midnight,
 * or there is no such disjunct, the seconds remaining until midnight are
 * returned instead.
 */

static unsigned int slot_start_timeout(FW_Timeslot *slot, time_t *clock)
{
    struct tm *now_tm = localtime(clock);
    int tod, soonest;

    if (in_slot(slot,clock))
	return 0;

    /* Not inside the slot right now: look for the next start today. */
    tod = now_tm->tm_hour*60*60 + now_tm->tm_min*60 + now_tm->tm_sec;
    soonest = 24*60*60 - tod;

    for (; slot != NULL; slot = slot->next) {
	int applies_today = (slot->wday & (1<<now_tm->tm_wday))
	    && (slot->mday & (1<<(now_tm->tm_mday-1)))
	    && (slot->month & (1<<now_tm->tm_mon));

	if (applies_today
	    && (slot->start >= tod)
	    && (soonest >= (slot->start - tod)))
	    soonest = slot->start - tod;
    }

    return soonest;
}

/*
 * How long will the time-slot list "slot" stay active?
 *
 * Returns 0 if *clock is not inside the slot (see in_slot()). Otherwise it
 * considers every disjunct that covers the current time today and returns
 * the largest (end - now), in seconds. If no covering disjunct is found,
 * the seconds remaining until local midnight are returned. That is not
 * expected to happen, because in_slot() has just tested the same
 * conditions.
 */

static unsigned int slot_end_timeout(FW_Timeslot *slot, time_t *clock)
{
    struct tm *now_tm = localtime(clock);
    FW_Timeslot *s;
    int tod, longest = -1;

    if (!in_slot(slot,clock))
	return 0;

    /* Inside the slot: find the disjunct that keeps it open the longest. */
    tod = now_tm->tm_hour*60*60 + now_tm->tm_min*60 + now_tm->tm_sec;

    for (s = slot; s != NULL; s = s->next) {
	if (!(s->wday & (1<<now_tm->tm_wday))
	    || !(s->mday & (1<<(now_tm->tm_mday-1)))
	    || !(s->month & (1<<now_tm->tm_mon))
	    || !(s->start <= tod)
	    || !(tod <= s->end))
	    continue;
	if (longest <= (s->end - tod))
	    longest = s->end - tod;
    }

    return (longest == -1) ? 24*60*60 - tod : longest;
}

/*
 * Add a connection to the queue.
 */


static void add_connection(FW_unit *unit, FW_ID *id, unsigned int timeout)
{
    FW_Connection *c = unit->connections->next;

    /* look for a connection that matches this one */
    while (c != unit->connections) {
	if (memcmp((unsigned char *)&c->id,
		(unsigned char *)id,sizeof(FW_ID))==0)
	   break;
	c = c->next;
    }
    if (c == unit->connections) {
	if (timeout > 0) {
	    /* no matching connection, add one */
	    c = malloc(sizeof(FW_Connection));
	    if (c == 0) {
	       syslog(LOG_ERR,"Out of memory! AIIEEE!");
	       die(1);
	    }
	    c->id = *id;
	    init_timer(&c->timer);
	    c->timer.data = (int)c;
	    c->timer.function = (void *)(int)del_connection;
	    c->next = unit->connections->next;
	    c->prev = unit->connections;
	    unit->connections->next->prev = c;
	    unit->connections->next = c;
	    c->timer.expires = timeout;
	    add_timer(&c->timer);
	    if (debug&DEBUG_CONNECTION_QUEUE)
    		syslog(LOG_INFO,"Adding connection %d @ %d - timeout %d",(int)c,
			time(0),timeout);
	}
    } else {
	/* found a matching connection, toss it's old timer */
	del_timer(&c->timer);
	if (timeout > 0) {
	    c->timer.expires = timeout;
	    add_timer(&c->timer);
	    if (debug&DEBUG_CONNECTION_QUEUE)
    		syslog(LOG_INFO,"Adding connection %d @ %d - timeout %d",(int)c,
			time(0),timeout);
	} else {
	    /* timeout = 0, so toss the connection */
	    del_connection(c);
	}
    }
}

/*
 * Unlink connection c from the circular connection queue it is on.
 *
 * add_connection() also installs this as the expiry callback of each
 * connection's timer. The node itself is not freed here; for example,
 * ctl_firewall()'s IP_FW_QFLUSH frees it after calling this.
 *
 * NOTE(review): when this runs as the timer callback, nobody frees the
 * node, so expired connections appear to leak. Before freeing here,
 * confirm that the timer code does not touch the timer after the callback
 * returns.
 */

static void del_connection(FW_Connection *c)
{
    /* %p / %ld: a pointer and a time_t do not match %d on LP64 hosts */
    if (debug&DEBUG_CONNECTION_QUEUE)
	syslog(LOG_INFO,"Deleting connection %p @ %ld",(void *)c,(long)time(0));
    c->next->prev = c->prev;
    c->prev->next = c->next;
}

/*
 * Timer callback of the impulse generator. It flips unit->impulse_mode
 * between its two phases and re-arms unit->impulse for the next phase:
 *
 *   mode 1 -> 0: re-armed with impulse_time (only if impulse_time > 0).
 *   mode 0 -> 1: impulse_init_time is cleared, then the timer is re-armed
 *                with impulse_fuzz (only if impulse_fuzz > 0).
 *
 * If the chosen period is not positive, the timer is left unarmed and the
 * generator stops in the new mode.
 */
static void del_impulse(FW_unit *unit)
{
    int expires;

    if (unit->impulse_mode) {
	unit->impulse_mode = 0;
	expires = impulse_time;
    } else {
	unit->impulse_mode = 1;
	impulse_init_time = 0;	/* zero the initial impulse time */
	expires = impulse_fuzz;
    }

    if (expires > 0) {
	/* no intermediate (int) casts: they truncate pointers on LP64 */
	unit->impulse.data = (unsigned long)unit;
	unit->impulse.function = (void *)del_impulse;
	unit->impulse.expires = expires;
	if (debug&DEBUG_CONNECTION_QUEUE)
	    syslog(LOG_INFO,"Refreshing impulse generator: mode %d, time %ld @ %ld",
		unit->impulse_mode,(long)unit->impulse.expires,(long)time(0));
	add_timer(&unit->impulse);
    }
}

/*
 * Recompute unit->force from the UP/DOWN ("forcing") rules.
 *
 * The cached answer is kept while unit->force_etime is still in the
 * future. Otherwise the rules are scanned in list order. The first UP or
 * DOWN rule whose time slot is active now sets force to 1 (UP) or 2
 * (DOWN); if none is active, force becomes 0.
 *
 * force_etime is then set to the earliest moment the answer could change,
 * bounded by 24 hours from now. That moment is either when an
 * earlier-listed, currently inactive slot opens, or when the chosen slot
 * closes.
 */

static void fw_force_update(FW_unit *unit)
{
    time_t current = time(0);
    int horizon = 24*60*60;
    FW_Filters *rule;

    /* the previous answer is still valid */
    if (unit->force_etime > current)
	return;

    unit->force = 0;

    for (rule = unit->filters; rule != NULL; rule = rule->next) {
	int secs;

	if (rule->filt.type != FW_TYPE_UP && rule->filt.type != FW_TYPE_DOWN)
	    continue;

	secs = slot_start_timeout(rule->filt.times,&current);
	if (secs <= 0) {
	    /* active now: this rule wins; note when its slot closes */
	    secs = slot_end_timeout(rule->filt.times,&current);
	    if (secs < horizon)
		horizon = secs;
	    unit->force = (rule->filt.type == FW_TYPE_UP) ? 1 : 2;
	    break;
	}

	/* not active yet: the answer may change when its slot opens */
	if (secs < horizon)
	    horizon = secs;
    }

    unit->force_etime = current + horizon;
}

/*
 * Re-evaluate which impulse rule (if any) applies now, then restart the
 * unit's impulse generator to match.
 *
 * Nothing is done while the previously computed schedule is still valid
 * (clock < unit->impulse_etime), unless "force" is set; ctl_firewall()
 * passes force = 1 on IP_FW_UP. The first IMPULSE rule whose slot is
 * active supplies the timeout, timeout2 (initial timeout) and fuzz
 * values. impulse_etime is set to the earliest moment the choice could
 * change: an earlier-listed slot opening, the chosen slot closing, or
 * 24 hours from now at most.
 *
 * The old impulse timer is always cancelled. A new one is armed only
 * while the unit is up and there is a timeout to run.
 */

static void fw_impulse_update(FW_unit *unit, int force)
{
    FW_Filters *fw;
    int mintime = 24*60*60;
    int itimeout = 0, ifuzz = 0, ftimeout = 0;
    time_t clock = time(0);

    /* check if the current impulse schedule has expired */
    if (clock < unit->impulse_etime && !force) return;

    for (fw = unit->filters; fw != NULL; fw = fw->next) {
	int timeout;

	if (fw->filt.type != FW_TYPE_IMPULSE)
	    continue;

	/* check when the rule is next applicable */
	timeout = slot_start_timeout(fw->filt.times,&clock);
	if (timeout > 0) {
	    /* Will be applicable soon: the schedule changes when its slot
	     * starts. */
	    if (timeout < mintime)
		mintime = timeout;
	    continue;
	}

	/* Applicable now: use this rule and note when its slot ends. */
	timeout = slot_end_timeout(fw->filt.times,&clock);
	ifuzz = fw->filt.fuzz;
	itimeout = fw->filt.timeout;
	ftimeout = fw->filt.timeout2;
	if (timeout < mintime)
	    mintime = timeout;
	break;
    }
    unit->impulse_etime = clock + mintime;

    del_timer(&unit->impulse);
    if (unit->up && (itimeout > 0 || (force && ftimeout > 0))) {
	/* place the current impulse generator into the impulse queue */
	impulse_time = itimeout;
	impulse_init_time = ftimeout;
	impulse_fuzz = ifuzz;
	unit->impulse_mode = 0;
	/* no intermediate (int) casts: they truncate pointers on LP64 */
	unit->impulse.data = (unsigned long)unit;
	unit->impulse.function = (void *)del_impulse;
	/* Use the initial timeout only when forced AND it is set. Otherwise
	 * use the regular one; a forced update with timeout2 == 0 used to
	 * arm the timer with 0. */
	unit->impulse.expires = (force && ftimeout > 0) ? ftimeout : itimeout;
	add_timer(&unit->impulse);
	if (debug&DEBUG_CONNECTION_QUEUE)
	    syslog(LOG_INFO,"Refreshing impulse generator: mode %d, time %ld @ %ld",
		unit->impulse_mode,(long)unit->impulse.expires,(long)time(0));
    }
}

static void log_packet(int accept, struct iphdr *pkt, int len,  int rule)
{
    char saddr[20], daddr[20];
    struct in_addr addr;
    int sport = 0, dport = 0;
    struct tcphdr *tcp = (struct tcphdr *)((char *)pkt + 4*pkt->ihl);
    struct udphdr *udp = (struct udphdr *)tcp;

    addr.s_addr = pkt->saddr;
    strcpy(saddr,inet_ntoa(addr));
    addr.s_addr = pkt->daddr;
    strcpy(daddr,inet_ntoa(addr));

    if (pkt->protocol == IPPROTO_TCP || pkt->protocol == IPPROTO_UDP)
	sport = ntohs(udp->source), dport = ntohs(udp->dest);

    if (pkt->protocol == IPPROTO_TCP) {
	syslog(LOG_INFO,
	    "filter %s rule %d proto %d len %d seq %x ack %x flags %02x%s%s%s%s%s%s packet %s,%d => %s,%d",
	    (accept)?"accepted":"ignored",rule,
	    pkt->protocol,
	    htons(pkt->tot_len),
	    htonl(tcp->th_seq), htonl(tcp->th_ack),
	    tcp->th_flags,
	    (tcp->th_flags&TH_FIN) ? " FIN" : "",
	    (tcp->th_flags&TH_SYN) ? " SYN" : "",
	    (tcp->th_flags&TH_RST) ? " RST" : "",
	    (tcp->th_flags&TH_PUSH) ? " PUSH" : "",
	    (tcp->th_flags&TH_ACK) ? " ACK" : "",
	    (tcp->th_flags&TH_URG) ? " URG" : "",
	    saddr, sport, daddr, dport);
    } else {
	syslog(LOG_INFO,
	    "filter %s rule %d proto %d len %d packet %s,%d => %s,%d",
	    (accept)?"accepted":"ignored",rule,
	    pkt->protocol,
	    htons(pkt->tot_len),
	    saddr, sport, daddr, dport);
    }
}

/*
 * Dump one filter to syslog: a line with its header fields, then one line
 * per term. A term's offset is printed as its low 7 bits followed by 'd'
 * when bit 0x80 is set, or 'h' otherwise (presumably data vs. IP header;
 * confirm against FW_IN_DATA).
 */
void print_filter(FW_Filter *filter)
{
    FW_Term *t = filter->terms;
    int n;

    syslog(LOG_INFO,"filter: prl %d log %d type %d cnt %d tm %d",
	filter->prule,filter->log,filter->type,
	filter->count,filter->timeout);
    for (n = 0; n < filter->count; n++, t++)
	syslog(LOG_INFO,"    term: shift %d op %d off %d%c msk %x tst %x",
	    t->shift, t->op,
	    t->offset&0x7f,
	    (t->offset&0x80)?'d':'h',
	    t->mask, t->test);
}

/*
 * Check a packet against the filter rules of firewall unit "unitnum".
 *
 * pkt points at an IP header. Each term is matched against a 32-bit
 * big-endian word, which is shifted and masked first. The word is read at
 * the term's offset, either from the IP header or from the bytes after it
 * (selected by FW_IN_DATA). len is only passed on to log_packet().
 *
 * Rules are tried in list order, and the first applicable rule whose
 * protocol rule and terms all match decides. A non-IGNORE rule also
 * (re)queues the connection with the rule's timeout; a timeout of 0
 * removes it.
 *
 * Returns -1 for an invalid unit number. Otherwise it returns 1 if the
 * matching rule is non-IGNORE with a nonzero timeout, 0 for any other
 * matching rule, and 1 when no rule matches.
 *
 * NOTE(review): term offsets are not checked against len, so a short
 * packet can be read past its end.
 */
int check_firewall(int unitnum, unsigned char *pkt, int len)
{
    FW_unit *unit;
    FW_Filters *fw;
    unsigned char *data, *base;
    FW_ProtocolRule *prule;
    FW_Term *term;
    int i,v,rule,word;
    /* in_slot() takes a time_t *; this was previously declared clock_t. */
    time_t clock = time(0);

    if (!initialized) init_units();

    if (unitnum < 0 || unitnum >= FW_NRUNIT) {
	/* FIXME: set an errorno? */
	return -1;
    }

    unit = &units[unitnum];
    fw = unit->filters;

    /* data: first byte after the IP header (ihl counts 32-bit words) */
    data = pkt + 4*((struct iphdr *)pkt)->ihl;

    rule = 1;
    while (fw) {
#if 0
	print_filter(&fw->filt);
#endif
	/* Is this rule currently applicable? BRINGUP rules apply only while
	 * the unit is down, and KEEPUP rules only while it is up. IMPULSE
	 * rules are never matched against packets. Every rule must also be
	 * inside its time slot. */
	if ((unit->up && fw->filt.type == FW_TYPE_BRINGUP)
	   || (!unit->up && fw->filt.type == FW_TYPE_KEEPUP)
	   || !in_slot(fw->filt.times,&clock)
	   || fw->filt.type == FW_TYPE_IMPULSE)
	    goto next_rule;

	/* Check the protocol rule */
	prule = &unit->prules[fw->filt.prule];
	if (!(FW_PROTO_ALL(prule->protocol)
	|| prule->protocol == ((struct iphdr *)pkt)->protocol))
	    goto next_rule;

	/* Check the terms.
	 * NOTE(review): a filter with count > FW_MAX_TERMS looks meant to
	 * continue in the following list entries. However, the switch below
	 * needs i > FW_MAX_TERMS and the current entry's count to be 0. Both
	 * cannot hold for such an entry, so terms[FW_MAX_TERMS] and beyond
	 * are indexed instead. Verify against the filter parser and the size
	 * of terms[]. */
	for (i = 0;
	(fw->filt.count > FW_MAX_TERMS) || (i < fw->filt.count); i++) {
	    if (i > FW_MAX_TERMS && fw->filt.count == 0) {
		fw = fw->next, i = 0;
		if (fw == NULL) break;
	    }
	    term = &fw->filt.terms[i];
	    /* Fetch the 4 bytes at the term's offset. memcpy avoids the
	     * misaligned, type-punned int load used previously. */
	    base = FW_IN_DATA(term->offset) ? data : pkt;
	    memcpy(&word, &base[FW_OFFSET(term->offset)], sizeof word);
	    v = (ntohl(word) >> term->shift) & term->mask;
#if 0
	    syslog(LOG_INFO,"testing ip %x:%x data %x:%x mask %x shift %x test %x v %x",
		ntohl(*(int *)(&pkt[FW_OFFSET(term->offset)])),
		*(int *)(&pkt[FW_OFFSET(term->offset)]),
		ntohl(*(int *)(&data[FW_OFFSET(term->offset)])),
		*(int *)(&data[FW_OFFSET(term->offset)]),
		term->mask,
		term->shift,
		term->test,
		v);
#endif
	    switch (term->op) {
	    case FW_EQ: if (v != term->test) goto next_rule; break;
	    case FW_NE: if (v == term->test) goto next_rule; break;
	    case FW_GE: if (v >= term->test) goto next_rule; break;
	    case FW_LE: if (v <= term->test) goto next_rule; break;
	    }
	}
	/* Ran off the end of the list while following a continuation: there
	 * is no rule left to apply, so treat it as "no match" rather than
	 * dereferencing NULL below. */
	if (fw == NULL)
	    break;

	/* Ok, we matched a rule. What are we suppose to do? */
#if 0
	if (fw->filt.log)
#endif
        if (debug&DEBUG_FILTER_MATCH)
	    log_packet(fw->filt.type!=FW_TYPE_IGNORE,(struct iphdr *)pkt,len,rule);

	/* Check if this entry goes into the queue or not */
	if (fw->filt.type != FW_TYPE_IGNORE) {
	    /* Store the connection. Timeout = 0 means delete. The id is
	     * built from the bytes selected by the protocol rule's codes. */
	    FW_ID id;
	    for (i = 0; i < FW_ID_LEN; i++)
		id.id[i] = (FW_IN_DATA(prule->codes[i])?data:pkt)
			    [FW_OFFSET(prule->codes[i])];
	    add_connection(unit,&id,fw->filt.timeout);
	}
	/* Return 1 if accepting rule with non zero timeout, 0 otherwise */
	return (fw->filt.type != FW_TYPE_IGNORE && fw->filt.timeout > 0);

next_rule: /* try the next filter */
	fw = fw->next;
	rule++;
    }
    /* Failed to match any rule. This means we ignore the packet.
     * NOTE(review): this still returns 1, the same value as a matching
     * accept rule; confirm this is what callers expect. */
    if (debug&DEBUG_FILTER_MATCH)
        log_packet(0,(struct iphdr *)pkt,len,0);
    return 1;
}

int ctl_firewall(int op, struct firewall_req *req)
{
    FW_unit *unit;
    if (!initialized) init_units();

    /* Need to check that req is OK */

    if (req && req->unit >= FW_NRUNIT) return -1; /* ERRNO */

    if (req) unit = &units[req->unit];
    else unit = units;
    
    switch (op) {
    case IP_FW_QFLUSH:
	if (!req) return -1; /* ERRNO */
	{
	    FW_Connection *c;
	    while ((c = unit->connections->next) != unit->connections) {
		del_timer(&c->timer);
		del_connection(c);
		free((void *)c);
	    }
	    return 0;
	}
    case IP_FW_QCHECK:
	if (!req) return -1; /* ERRNO */
	
	fw_force_update(unit);
	fw_impulse_update(unit,0);

	return (unit->force == 2
		|| (unit->force == 0
		    && !(unit->up && unit->impulse_mode == 0
		   	 && (impulse_init_time > 0 || impulse_time > 0))
		    && unit->connections->next == unit->connections));


    case IP_FW_PFLUSH:
	if (!req) return -1; /* ERRNO */
	unit->nrules = 0;
	return 0;
    /* PFLUSH implies FFLUSH */
    case IP_FW_FFLUSH:
	if (!req) return -1; /* ERRNO */
	{
	    FW_Filters *next, *filt = unit->filters;
	    while (filt)
	    	{ next = filt->next; free(filt); filt = next; }
	    unit->filters = NULL;
	    unit->last = NULL;
	}
	return 0;
    case IP_FW_AFILT:
	if (!req) return -1; /* ERRNO */
	{
	    FW_Filters *filters = malloc(sizeof(FW_Filters));
	    if (filters == 0) {
		syslog(LOG_ERR,"Out of memory! AIIEEE!");
		return -1; /* ERRNO */
	    }
	    filters->next = 0;
	    filters->filt = req->fw_arg.filter;
	    if (unit->last) unit->last->next = filters;
	    if (!unit->filters) unit->filters = filters;
	    unit->last = filters;
	    unit->nfilters++;
	}
	return 0;
    case IP_FW_APRULE:
	if (!req) return -1; /* ERRNO */
	if (unit->nrules >= FW_MAX_PRULES) return -1; /* ERRNO */
	unit->prules[(int)unit->nrules] = req->fw_arg.rule;
	return unit->nrules++;
    /* Printing does nothing right now */
    case IP_FW_PCONN:
	if (!req) return -1; /* ERRNO */
	{
	    unsigned long atime = time(0);
            unsigned long tstamp = timestamp();
	    FW_Connection *c;
	    char saddr[20], daddr[20];
    	    struct in_addr addr;
	    syslog(LOG_INFO,"up = %d, forcing = %d, impulse = %d, iitime = %d, itime = %d, ifuzz = %d, itimeout = %d, timeout = %d, next alarm = %d",
		unit->up,unit->force, unit->impulse_mode, impulse_init_time, impulse_time,
		impulse_fuzz,
		unit->impulse.expected-tstamp,unit->force_etime-atime,next_alarm());
	    for (c=unit->connections->next; c!=unit->connections; c=c->next) {
                addr.s_addr = c->id.id[1] + (c->id.id[2]<<8)
                        + (c->id.id[3]<<16) + (c->id.id[4]<<24);
                strcpy(saddr,inet_ntoa(addr));
                addr.s_addr = c->id.id[5] + (c->id.id[6]<<8)
                        + (c->id.id[7]<<16) + (c->id.id[8]<<24);
                strcpy(daddr,inet_ntoa(addr));
                syslog(LOG_INFO,
                        "ttl %d, %d - %s/%d => %s/%d",
                        c->timer.expected-tstamp, c->id.id[0],
                        saddr, c->id.id[10]+(c->id.id[9]<<8),
                        daddr, c->id.id[12]+(c->id.id[11]<<8));
	    }
	    return 0;
	}
	return 0;
    case IP_FW_PPRULE:
	if (!req) return -1; /* ERRNO */
	return 0;
    case IP_FW_PFILT:
	if (!req) return -1; /* ERRNO */
	return 0;
    /* Opening and closing firewalls is cooperative right now.
     * Also, it does nothing to change the behavior of a device
     * associated with the firewall.
     */
    case IP_FW_OPEN:
	{
	    int i;
	    for (i = 0; i < FW_NRUNIT; i++)
		if (units[i].used == 0) {
		    struct firewall_req mreq;
		    mreq.unit = i;
		    ctl_firewall(IP_FW_PFLUSH,&mreq);
		    units[i].used = 1;
		    units[i].force_etime = 0;
		    units[i].impulse_etime = 0;
		    return i;
		}
	    return -1;	/* ERRNO */
	}
    case IP_FW_CLOSE:
	{
	    struct firewall_req mreq;
	    if (!req) return -1; /* ERRNO */
	    mreq.unit = req->unit;
	    ctl_firewall(IP_FW_PFLUSH,&mreq);
	    unit->used = 0;
	    return 0;
	}
    case IP_FW_UP:
	unit->up = 1;
	fw_force_update(unit);
	fw_impulse_update(unit,1);
	return 0;

    case IP_FW_DOWN:
	unit->up = 0;
	/* turn off the impulse generator */
	del_timer(&unit->impulse);
	return 0;
    }
    return -1; /* ERRNO */
}
