
/* $Id: iprules.c,v 1.2 2008/10/31 13:58:14 jayrfink Exp $ */

#include "nw.h"

/* readline() */
ssize_t readline(int fd, void *vptr, size_t maxlen);

/* Match functions prototypes */
int match_proto(const struct rule *r, u_int8_t p);
int match_ip(const struct in_addr ip, const uint8_t prefix,
             struct in_addr addr);
int match_port(const struct port p, uint16_t port);
int match_flags(const struct rule *r, tcp_hdr h);

/* Get functions prototypes */
int get_type(struct rule *r, char *t);
int call_type(struct rule *r, char *t);
int get_proto(struct rule *r, char *p);
int call_proto(struct rule *r, char *p);
int get_ip(struct rule *r, struct in_addr *n_addr, uint32_t * n_prefix,
           char *addr);
int call_ip(struct rule *r, struct in_addr *n_addr, uint32_t * n_prefix,
           char *addr);
int get_port(struct rule *r, struct port *port, char *p);
int call_port(struct rule *r, struct port *port, char *p);
int get_flags(struct rule *r, char *f);
int call_flags(struct rule *r, char *f);
int get_score(struct rule *r, uint8_t s);
int call_score(struct rule *r, uint8_t s);

/* Match packet to rule */
int match_packet(const u_char * p, const struct rule *r)
{
    eth_hdr *ethernet;          /* The ethernet header */
    ip_hdr *ip;                 /* The IP header */
    tcp_hdr *tcp;               /* The TCP header */
    udp_hdr *udp;               /* The UDP header */
    uint32_t id;                /* Host id */

    /* Extract ip,tcp,udp headers from packet */
    ethernet = (eth_hdr *) (p); /* Pointer to ethernet header */
    ip = (ip_hdr *) (p + sizeof(eth_hdr));  /* Pointer to IP header */
    tcp = (tcp_hdr *) (p + sizeof(eth_hdr) + sizeof(ip_hdr)); /* Pointer to IP header */
    udp = (udp_hdr *) (p + sizeof(eth_hdr) + sizeof(ip_hdr)); /* Pointer to IP header */

    /* Match protocol */
    if (!match_proto(r, ip->ip_p)) {
#ifdef DEBUG
        printf("Proto MISS\n");
#endif
        return -1;
    }
#ifdef DEBUG
    else
        printf("Proto HIT");
#endif

    /* Match IP */
    /* Source */
    if (!match_ip(r->src_ip, r->src_prefix, ip->ip_src)) {
#ifdef DEBUG
        printf(", Source IP MISS\n");
#endif
        return -1;
    }
#ifdef DEBUG
    else
        printf(", Source IP HIT");
#endif

    /* Destination */
    if (!match_ip(r->dst_ip, r->dst_prefix, ip->ip_dst)) {
#ifdef DEBUG
        printf(", Dest IP MISS\n");
#endif
        return -1;
    }
#ifdef DEBUG
    else
        printf(", Dest IP HIT");
#endif

    /* Match another fields */
    switch (r->proto) {
    case IPPROTO_ICMP:
#ifdef DEBUG
        printf(", Match ICMP packet");
#endif
        break;
    case IPPROTO_TCP:
    case IPPROTO_UDP:
#ifdef DEBUG
        printf(", Match TCP/UDP packet");
#endif
        if (!match_port(r->src_port, ntohs(tcp->source))) {
#ifdef DEBUG
            printf(", Source PORT MISS\n");
#endif
            return -1;
        }
#ifdef DEBUG
        else
            printf(", Source PORT HIT");
#endif

        if (!match_port(r->dst_port, ntohs(tcp->dest))) {
#ifdef DEBUG
            printf(", Dest PORT MISS\n");
#endif
            return -1;
        }
#ifdef DEBUG
        else
            printf(", Dest PORT HIT");
#endif
        break;
    }

    /* Match flags */
    if (r->proto == IPPROTO_TCP)
        if (!match_flags(r, *tcp)) {
#ifdef DEBUG
            printf(", Flags MISS\n");
#endif
            return -1;
        }
#ifdef DEBUG
        else
            printf(", Flags HIT");
#endif

#ifdef DEBUG
    printf(", Score %d\n", r->score);
#endif

    return 0;
}

/* Match protocol in packet */
int match_proto(const struct rule *r, uint8_t p)
{
    if (r->proto == p)
        return 1;
    else
        return 0;
}

/* Match IP */
int
match_ip(const struct in_addr ip, const uint8_t prefix,
         struct in_addr addr)
{
    if ((ip.s_addr & prefix) == (addr.s_addr & prefix))
        return 1;
    else
        return 0;
}

/*
 * Match port: 1 when port lies inside the rule's inclusive range
 * [p.begin, p.end], 0 when it falls below or above it.
 */
int match_port(const struct port p, uint16_t port)
{
    if (port < p.begin || port > p.end)
        return 0;

    return 1;
}

/*
 * Match flags: 1 if the TCP header h satisfies the rule's flag requirement,
 * else 0.
 *
 * A rule flag value of 0 is what get_flags() stores for "any" and accepts
 * every segment. Otherwise the rule holds a single TH_* flag, and that flag
 * must be set in h.
 */
int match_flags(const struct rule *r, tcp_hdr h)
{
    /* "any": without this check such rules could never match */
    if (r->flags == 0)
        return 1;

    if ((r->flags == TH_FIN) && (h.fin))
        return 1;
    if ((r->flags == TH_SYN) && (h.syn))
        return 1;
    if ((r->flags == TH_RST) && (h.rst))
        return 1;
    if ((r->flags == TH_PUSH) && (h.psh))
        return 1;
    if ((r->flags == TH_ACK) && (h.ack))
        return 1;
    if ((r->flags == TH_URG) && (h.urg))
        return 1;

    return 0;
}

/*
 * Parse and validate one rule line.
 *
 * Three layouts are accepted. Keywords are literal as shown, and each field
 * is a single whitespace-free token of at most 15 characters:
 *
 *   TYPE proto PROTO from SRC port SPORT to DST port DPORT flags FLAGS score N
 *   TYPE proto PROTO from SRC port SPORT to DST port DPORT score N
 *   TYPE proto PROTO from SRC to DST score N
 *
 * Ports and flags omitted by the shorter layouts default to "any". Ports
 * are applied only to TCP/UDP rules and flags only to TCP rules. N must be
 * in 0-255.
 *
 * Returns a heap-allocated rule owned by the caller, or NULL if the
 * allocation fails. A line that fits none of the layouts is reported on
 * stdout, and an all-zero rule is returned. An invalid field value prints
 * a message, frees the rule and terminates the process with exit(-1),
 * either here or inside the call_* helpers.
 */
struct rule *parse_rule(char *str)
{
    /*
     * Defaults for the fields the shorter layouts leave out. Without them,
     * a TCP/UDP rule written in a shorter layout would hand uninitialised
     * stack buffers to call_port()/call_flags(). When a shorter layout
     * matches, the longer attempts stop at the keyword before the omitted
     * field, so these defaults are never overwritten.
     */
    char type[16] = "";
    char proto[16] = "";
    char s_port[16] = "any", d_port[16] = "any";
    char s_ip[16] = "", d_ip[16] = "";
    char flags[16] = "any";
    u_int score = 0;

    struct rule *rule;

    rule = calloc(1, sizeof(struct rule));
    if (rule == NULL)
        return NULL;

    /* %15s keeps every token (plus its NUL) inside its 16-byte buffer;
     * %u matches the unsigned score variable. */
    if ((sscanf(str,
                "%15s proto %15s from %15s port %15s to %15s port %15s flags %15s score %u",
                type, proto, s_ip, s_port, d_ip, d_port, flags, &score) == 8)
        ||
        (sscanf(str,
                "%15s proto %15s from %15s port %15s to %15s port %15s score %u",
                type, proto, s_ip, s_port, d_ip, d_port, &score) == 7)
        ||
        (sscanf(str, "%15s proto %15s from %15s to %15s score %u",
                type, proto, s_ip, d_ip, &score) == 5)) {

        /* Type */
        call_type(rule, type);

        /* Protocol */
        call_proto(rule, proto);

        /* Src IP */
        call_ip(rule, &rule->src_ip, &rule->src_prefix, s_ip);

        /* Dst IP */
        call_ip(rule, &rule->dst_ip, &rule->dst_prefix, d_ip);

        /* Ports only for TCP/UDP */
        if ((rule->proto == IPPROTO_TCP) || (rule->proto == IPPROTO_UDP)) {
            call_port(rule, &rule->src_port, s_port);
            call_port(rule, &rule->dst_port, d_port);
        }

        /* Flags only for TCP */
        if (rule->proto == IPPROTO_TCP)
            call_flags(rule, flags);

        /*
         * Score: call_score() takes a uint8_t, so reject values that would
         * silently wrap. This also catches negative input, which %u turns
         * into a large unsigned value.
         */
        if (score > 255) {
            printf("Invalid score: %u\n", score);
            free(rule);
            exit(-1);
        }
        call_score(rule, score);

#ifdef DEBUG
        printf("Rule parsed ok\n");
#endif
    } else {
        printf("Invalid rule : %s\n", str);
    }

    return rule;
}

/*
 * init_rules2: append every entry of the built-in default_rules table to the
 * global rules array (growing it with nrealloc() and bumping rules_num).
 * Iteration stops at the first entry whose first character is NUL.
 *
 * Returns 0 on success, -1 as soon as parse_rule() returns NULL.
 */
int init_rules2(void)
{
    struct rule *r;
    int i;

    for (i = 0; default_rules[i][0] != 0; i++) {
        r = parse_rule(default_rules[i]);
        if (r == NULL)
            return -1;

        rules_num++;
#ifdef DEBUG
        printf("Rule: %s\n", default_rules[i]);
        printf("Parsing line : %d\n", rules_num);
#endif

        rules = nrealloc(rules, rules_num * sizeof(struct rule));
        rules[rules_num - 1] = *r;
        /* The rule was copied by value into the array; release the
         * heap copy parse_rule() handed us so it does not leak. */
        free(r);
    }

    return 0;
}

/*
 * Read one line from descriptor fd into vptr, one byte at a time.
 *
 * At most maxlen - 1 bytes are stored, and the buffer is always
 * NUL-terminated (like fgets()). The newline is stored when it is reached.
 * A read() interrupted by a signal (EINTR) is retried.
 *
 * Returns:
 *   - the number of bytes stored, excluding the terminating NUL;
 *   - 0 at end of file when nothing was read (or when maxlen is 1, so no
 *     byte fits);
 *   - -1 when read() fails, with errno set by read();
 *   - -1 when maxlen is 0, with errno set to EINVAL, because not even the
 *     terminator would fit.
 */
ssize_t readline(int fd, void *vptr, size_t maxlen)
{
    char *ptr = vptr;
    size_t stored = 0;

    if (maxlen == 0) {
        errno = EINVAL;
        return -1;
    }

    while (stored + 1 < maxlen) {
        char c;
        ssize_t rc = read(fd, &c, 1);

        if (rc == 1) {
            ptr[stored++] = c;
            if (c == '\n')
                break;
        } else if (rc == 0) {
            /* EOF: hand back whatever has been collected so far */
            break;
        } else if (errno != EINTR) {
            return -1;
        }
        /* else: interrupted by a signal, retry the read */
    }

    ptr[stored] = '\0';
    return (ssize_t) stored;
}

/*
 * get_type: translate a rule-type keyword (case-insensitive) into r->type.
 * "normal" yields TYPE_NORMAL and "quick" yields TYPE_QUICK.
 * Returns 0 on success; returns -1 for any other keyword and leaves
 * r->type untouched.
 */
int get_type(struct rule *r, char *t)
{
    int is_normal = (strncasecmp(t, "normal", 16) == 0);

    if (!is_normal && strncasecmp(t, "quick", 16) != 0)
        return -1;

    r->type = is_normal ? TYPE_NORMAL : TYPE_QUICK;
    return 0;
}

/*
 * call_type: run get_type() for the rule parser.
 * If the keyword is not recognised, prints "Invalid type: <t>", frees r and
 * terminates the process with exit(-1). Returns 0 otherwise.
 */
int call_type(struct rule *r, char *t)
{
    if (get_type(r, t) != 0) {
        printf("Invalid type: %s\n", t);
        free(r);
        exit(-1);
    }
    return 0;
}

/*
 * get_proto: translate a protocol keyword (case-insensitive) into
 * r->proto. "icmp", "udp" and "tcp" map to the matching IPPROTO_* number.
 * Returns 0 on success; returns -1 for anything else and leaves r->proto
 * untouched.
 */
int get_proto(struct rule *r, char *p)
{
    static const struct {
        const char *name;
        int number;
    } known[] = {
        { "icmp", IPPROTO_ICMP },
        { "udp",  IPPROTO_UDP  },
        { "tcp",  IPPROTO_TCP  },
    };
    size_t k;

    for (k = 0; k < sizeof known / sizeof known[0]; k++) {
        if (strncasecmp(p, known[k].name, 16) == 0) {
            r->proto = known[k].number;
            return 0;
        }
    }

    return -1;
}


/*
 * call_proto: run get_proto() for the rule parser.
 * If the protocol is not recognised, prints "Invalid proto: <p>", frees r
 * and terminates the process with exit(-1). Returns 0 otherwise.
 */
int call_proto(struct rule *r, char *p)
{
    if (get_proto(r, p) != 0) {
        printf("Invalid proto: %s\n", p);
        free(r);
        exit(-1);
    }
    return 0;
}

/*
 * get_ip: parse an address specification for a rule.
 *
 * Accepted forms:
 *   - "any": address 0.0.0.0 with mask 0, which matches every address;
 *   - "a.b.c.d/len": each octet in 0-255 and len in 0-32.
 *
 * On success, *n_addr receives inet_addr() of the dotted quad and
 * *n_prefix receives netmask(len). Both are left untouched on failure.
 * r is unused; it keeps the signature parallel to the other get_* helpers.
 *
 * Returns 0 on success, -1 on malformed or out-of-range input.
 * 255.255.255.255 is also rejected, because inet_addr() reports it as
 * INADDR_NONE.
 */
int
get_ip(struct rule *r, struct in_addr *n_addr, uint32_t * n_prefix, char *addr)
{
    char ip[16];
    /* int, not uint8_t: each %d conversion stores a full int, so narrower
     * targets would be overrun on the stack. */
    int a[4];
    int p;
    uint32_t prefix;
    int i;

    if (strncasecmp(addr, "any", 16) == 0) {
        /* any IP */
        snprintf(ip, sizeof ip, "%d.%d.%d.%d", 0, 0, 0, 0);
        prefix = 0;
        p = 0;
    } else {
        /* xxx.xxx.xxx.xxx/xx */
        if (sscanf(addr, "%d.%d.%d.%d/%d", &a[0], &a[1], &a[2], &a[3], &p)
            != 5)
            return -1;

        for (i = 0; i < 4; i++)
            if (a[i] < 0 || a[i] > 255)
                return -1;

        /* An IPv4 prefix length is 0-32 */
        if (p < 0 || p > 32)
            return -1;

        /* Prepare IP: the octets were validated, so this fits in ip[] */
        snprintf(ip, sizeof ip, "%d.%d.%d.%d", a[0], a[1], a[2], a[3]);
        if (inet_addr(ip) == INADDR_NONE)
            return -1;

        /* Prepare prefix */
        prefix = netmask(p);
    }

#ifdef DEBUG
    printf("IP: %s , PREFIX: /%d %x\n", ip, p, prefix);
#endif

    /* Initialize IP and prefix in rule */
    n_addr->s_addr = inet_addr(ip);
    *n_prefix = prefix;

    return 0;
}

/*
 * call_ip: run get_ip() for the rule parser.
 * If the address does not parse, prints "Invalid ip: <addr>", frees r and
 * terminates the process with exit(-1). Returns 0 otherwise.
 */
int
call_ip(struct rule *r, struct in_addr *n_addr, uint32_t * n_prefix, char *addr)
{
    if (get_ip(r, n_addr, n_prefix, addr) != 0) {
        printf("Invalid ip: %s\n", addr);
        free(r);
        exit(-1);
    }
    return 0;
}

/*
 * get_port: parse a port specification into *port.
 *
 * Accepted forms:
 *   - "B-E": inclusive range with B <= E;
 *   - "N": a single port;
 *   - "any": 0-65535.
 * Every port number must lie within 0-65535.
 *
 * r is unused; it keeps the signature parallel to the other get_* helpers.
 * Returns 0 on success, -1 otherwise. An invalid range is also reported
 * on stdout.
 */
int get_port(struct rule *r, struct port *port, char *p)
{
    /* begin & end of range; int, not u_int16_t, because %d stores a full
     * int and a narrower target would be overrun on the stack. */
    int b, e;

    if (sscanf(p, "%d-%d", &b, &e) == 2) {
        /* ports range xxx-xxx */
        if (b < 0 || e > 65535 || b > e) {
            printf("Invalid ports range: %s\n", p);
            return -1;
        }
        port->begin = b;
        port->end = e;
    } else if (sscanf(p, "%d", &b) == 1) {
        /* single port xxx */
        if (b < 0 || b > 65535)
            return -1;
        port->begin = b;
        port->end = b;
    } else if (strncasecmp(p, "any", 16) == 0) {
        /* any port */
        port->begin = 0;
        port->end = 65535;
    } else {
        /* invalid ports */
        return -1;
    }

    return 0;
}

/*
 * call_port: run get_port() for the rule parser.
 * If the port specification is invalid, prints "Invalid ports: <p>", frees
 * r and terminates the process with exit(-1). Returns 0 otherwise.
 */
int call_port(struct rule *r, struct port *port, char *p)
{
    if (get_port(r, port, p) != 0) {
        printf("Invalid ports: %s\n", p);
        free(r);
        exit(-1);
    }
    return 0;
}

/*
 * get_flags: map a TCP flag keyword (case-insensitive) onto r->flags.
 * "fin", "syn", "rst", "push", "ack" and "urg" store the matching TH_*
 * value; "any" stores 0. Returns 0 on success; returns -1 for an unknown
 * keyword and leaves r->flags untouched.
 */
int get_flags(struct rule *r, char *f)
{
    static const struct {
        const char *keyword;
        int bits;
    } flag_table[] = {
        { "fin",  TH_FIN  },
        { "syn",  TH_SYN  },
        { "rst",  TH_RST  },
        { "push", TH_PUSH },
        { "ack",  TH_ACK  },
        { "urg",  TH_URG  },
        { "any",  0       },
    };
    const size_t count = sizeof flag_table / sizeof flag_table[0];
    size_t idx;

    for (idx = 0; idx < count; idx++) {
        if (strncasecmp(f, flag_table[idx].keyword, 16) == 0) {
            r->flags = flag_table[idx].bits;
            return 0;
        }
    }

    return -1;
}

/*
 * call_flags: run get_flags() for the rule parser.
 * If the keyword is not recognised, prints "Invalid flags: <f>", frees r
 * and terminates the process with exit(-1). Returns 0 otherwise.
 */
int call_flags(struct rule *r, char *f)
{
    if (get_flags(r, f) != 0) {
        printf("Invalid flags: %s\n", f);
        free(r);
        exit(-1);
    }
    return 0;
}

/*
 * get_score: store s as the rule's score. Every uint8_t value is accepted,
 * so this always returns 0.
 */
int get_score(struct rule *r, uint8_t s)
{
    struct rule *target = r;

    target->score = s;
    return 0;
}

/*
 * call_score: run get_score() for the rule parser.
 * If get_score() reports failure, prints "Invalid score: <s>", frees r and
 * terminates the process with exit(-1). Returns 0 otherwise.
 */
int call_score(struct rule *r, uint8_t s)
{
    if (get_score(r, s) != 0) {
        /* s is an integer: %s here would read it as a char pointer (UB) */
        printf("Invalid score: %u\n", (unsigned) s);
        free(r);
        exit(-1);
    }
    return 0;
}
