
/*
 * Fast shell-like pattern matching
 *
 * Copyright 1996 by E. Toernig
 *
 */

#include <stdio.h>
#include <stdlib.h>
#include <linux/string.h>	/* the asm versions */
#include "generic.h"
#include "cattr.h"
#include "fmatch.h"

/*
 * debug(fmt, ...) - printf-style trace output to stderr.
 *
 * Expands to an fprintf() call when DEBUG is defined and to nothing
 * otherwise.  Uses the standard C99 variadic macro form; every call
 * must pass at least the format string.
 */
#ifdef DEBUG
#define debug(...) fprintf(stderr, __VA_ARGS__)
#else
#define debug(...)
#endif

/* Size in bytes of one character-set bitmap: one bit for each of the 256 byte values. */
#define set_size (256/8)

/*
 * Opcodes of the compiled match program.  Each opcode is one byte and
 * its operands follow inline.  As interpreted by fmatch_exec():
 *
 *   EXIT_FALSE              stop, no match
 *   EXIT_TRUE               stop, match
 *   NO_CASE                 continue on an upper-cased copy of the subject
 *   MATCH_EMPTY             stop; match iff no subject characters are left
 *   CHECK_LEN n             fail if fewer than n characters are left
 *   FIND_ARG a b c          set up a search: window start = s + a,
 *                           window end = e - b, find offset = c; the
 *                           following opcode becomes the restart point
 *   LFIND_CHR c             search the window forwards for byte c
 *   RFIND_CHR c             search the window backwards for byte c
 *   LFIND_SET bitmap        search forwards for a byte contained in the set
 *   RFIND_SET bitmap        search backwards for a byte contained in the set
 *   LMATCH_ANY n            skip n characters at the left end
 *   RMATCH_ANY n            skip n characters at the right end
 *   LMATCH_STR n bytes      compare n literal bytes at the left end
 *   RMATCH_STR n bytes      compare n literal bytes at the right end
 *   LMATCH_SET n bitmaps    test n characters at the left end, one bitmap each
 *   RMATCH_SET n bitmaps    test n characters at the right end, one bitmap each
 *
 * A failing L/RMATCH_STR or L/RMATCH_SET jumps back to the restart point.
 */
enum token
{
    EXIT_FALSE, EXIT_TRUE, NO_CASE, MATCH_EMPTY, CHECK_LEN, FIND_ARG,
    LFIND_CHR, LFIND_SET, LMATCH_ANY, LMATCH_STR, LMATCH_SET,
    RFIND_CHR, RFIND_SET, RMATCH_ANY, RMATCH_STR, RMATCH_SET
};

/*
 * Direction argument of comp_match(): M_LEFT selects the LMATCH_*
 * opcodes, any other value selects the RMATCH_* opcodes.
 * NOTE(review): M_UNKNOWN is not referenced in the code visible here.
 */
enum match_dir
{
    M_LEFT, M_RIGHT, M_UNKNOWN
};


/*
 * Reverse memchr: scan the n bytes starting at _m from the last byte
 * towards the first and return a pointer to the first byte found that
 * equals c, or 0 if there is none (or n is 0).
 */
static inline void *
memrchr(const void *_m, char c, size_t n)
{
    const char *base = _m;
    size_t left;

    for (left = n; left > 0; left--)
	if (base[left - 1] == c)
	    return (char *)base + (left - 1);

    return 0;
}


/* Mark character c as a member of the bitmap set. */
static void
add_set(u8 *set, u8 c)
{
    const int byte = c / 8;
    const int bit = c % 8;

    set[byte] = set[byte] | (1 << bit);
}

/*
 * Membership test: returns non-zero (the raw bit value, not 1) if
 * character c is in the bitmap set, 0 otherwise.
 */
static int
tst_set(u8 *set, u8 c)
{
    const int mask = 1 << (c % 8);

    return set[c / 8] & mask;
}


/*
 * Parse a bracket expression starting at str into the bitmap set
 * (set_size bytes).
 *
 * A leading '^' inverts the finished bitmap.  "x-y" adds every character
 * from x up to y.  A '-' always consumes the character that follows it
 * as the range end; if no character preceded the '-', only that
 * following character is added.
 *
 * Returns a pointer just past the closing ']', or 0 if str does not
 * start with '[' or the string ends before a ']' is seen.
 */
static u8 *
parse_set(u8 *str, u8 *set)
{
    int negate, prev, ch, k;

    if (str[0] != '[')
	return 0;
    str++;

    memset(set, 0, set_size);

    negate = (*str == '^');
    if (negate)
	str++;

    prev = 999;			/* larger than any character: no range start yet */
    for (;;)
    {
	ch = *str++;
	if (ch == 0 || ch == ']')
	    break;

	if (ch == '-')
	{
	    ch = *str++;
	    if (ch == 0)
		break;
	    for (; prev < ch; prev++)
		add_set(set, prev);
	}

	add_set(set, ch);
	prev = ch;
    }

    if (negate)
	for (k = 0; k < set_size; k++)
	    set[k] = ~set[k];

    return ch != 0 ? str : 0;
}


/*
	The pattern consists of an alternating sequence of fixed length
	subpatterns (literal chars, sets, and '?' {marked 'a'}) and
	variable length subpatterns ('*').

	Examples: '*a*a*a*'   'a*a*a*'   '*a*a*a'   'a*a*a'

	First:
	    The sequence '**' is equivalent to '*'.

	    The sequence '*?*' may be reduced to '*?'.

	    The sequence '*?' is equivalent to '?*'.

	    That means: the '?' in a sequence of '*' and '?'
	    may be collected at one end of the sequence.

	Second:
	    The sum of all fixed length patterns must be <= length of string.

	Third:
	    Examples 2-4 are a special case of example 1. We can match
	    the first/last 'a' and remove it, giving something like example 1.

	Fourth:
	    Assuming a sequence like '*a*b*c*': Search for 'a' from the
	    left and 'c' from the right. If found, remove '*a' and 'c*'
	    from the pattern and try to match the remaining pattern on
	    the string which begins at the end of 'a' and ends at the
	    beginning of 'c'.
*/

/*
 * One node of a broken-up pattern, as filled in by break_pattern().
 */
struct pat_dat
{
    char type;	/* '?' any chars, '*' star, '[' set run, 'a' literal run;
		   fmatch_compile() may rewrite it to '-' (emit nothing) */
    char pad;	/* not used in the code visible here */
    short len;	/* characters covered: count of '?', sets or literal
		   characters; 0 for '*' */
    void *data;	/* 0 for '?' and '*'; first bitmap for '['; start of the
		   literal text inside the pattern for 'a' */
};

/*
 * Pattern statistics gathered by stat_pattern(); used by fmatch_compile()
 * to size its work buffers.  A "string" is a run of adjacent elements of
 * the same kind, which becomes one node.
 */
struct pat_stat
{
    int nodes;	/* required nodes */
    int minlen;	/* number of characters required */
    int star;	/* number of '*' */
    int anychr;	/* number of '?' */
    int anystr;	/* number of '?' strings */
    int litchr;	/* number of literal chars */
    int litstr;	/* number of literal strings */
    int setchr;	/* number of char sets */
    int setstr;	/* number of char set strings */
};



static int
stat_pattern(char *str, struct pat_stat *st)
{
    u8 tmp_set[set_size];
    int i, j;

    if (str == 0)
	return 0;

    memset(st, 0, sizeof(*st));

    while (*str)
    {
	for (i = j = 0; *str == '*' || *str == '?'; str++)
	    *str == '*' ? i++ : j++;
	if (i)
	    st->star++;
	if (j)
	    st->anychr += j, st->anystr++;

	for (i = 0; *str == '['; i++)
	    if ((str = parse_set(str, tmp_set)) == 0)
		return 0;
	if (i)
	    st->setchr += i, st->setstr++;

	for (i = 0; *str && *str != '*' && *str != '?' && *str != '['; i++)
	    str++;
	if (i)
	    st->litchr += i, st->litstr++;
    }

    st->nodes = st->star + st->anystr + st->setstr + st->litstr;
    st->minlen = st->anychr + st->setchr + st->litchr;

    return 1;
}

/*
 * Split the pattern str into nodes, written consecutively to p.
 * Bitmaps of bracket sets are written consecutively to s (set_size
 * bytes each).  The caller provides room for both.
 *
 * For a run of '*' and '?', the '?' node (len = number of '?') is
 * emitted before the '*' node.  Results of parse_set() are not checked
 * here.
 */
static void
break_pattern(char *str, struct pat_dat *p, u8 *s)
{
    while (*str != 0)
    {
	int stars = 0, anys = 0, nsets = 0, nlit = 0;
	u8 *first_set;
	char *lit;

	/* wildcard run */
	while (*str == '*' || *str == '?')
	{
	    if (*str == '*')
		stars++;
	    else
		anys++;
	    str++;
	}
	if (anys)
	{
	    p->type = '?';
	    p->len = anys;
	    p->data = 0;
	    p++;
	}
	if (stars)
	{
	    p->type = '*';
	    p->len = 0;
	    p->data = 0;
	    p++;
	}

	/* run of bracket sets, one bitmap each */
	first_set = s;
	while (*str == '[')
	{
	    str = (char *)parse_set((u8 *)str, s);
	    s += set_size;
	    nsets++;
	}
	if (nsets)
	{
	    p->type = '[';
	    p->len = nsets;
	    p->data = first_set;
	    p++;
	}

	/* run of literal characters; data points into the pattern itself */
	lit = str;
	while (*str && *str != '*' && *str != '?' && *str != '[')
	{
	    str++;
	    nlit++;
	}
	if (nlit)
	{
	    p->type = 'a';
	    p->len = nlit;
	    p->data = lit;
	    p++;
	}
    }
}

/*
 * Emit the match opcode for node p at prg and return the new write
 * position.
 *
 *   '-'  emits nothing
 *   'a'  L/RMATCH_STR, length, literal bytes
 *   '?'  L/RMATCH_ANY, count
 *   '['  L/RMATCH_SET, count, bitmaps (in reverse order unless dir is
 *        M_LEFT)
 *
 * Any other node type aborts the program.
 */
static u8 *
comp_match(enum match_dir dir, u8 *prg, struct pat_dat *p)
{
    const int left = (dir == M_LEFT);
    const u8 *sets;
    int k, src;

    debug("comp_match %c %d\n", p->type, p->len);

    if (p->type == '-')
	return prg;

    if (p->type == 'a')
    {
	prg[0] = left ? LMATCH_STR : RMATCH_STR;
	prg[1] = p->len;
	memcpy(prg + 2, p->data, p->len);
	return prg + 2 + p->len;
    }

    if (p->type == '?')
    {
	prg[0] = left ? LMATCH_ANY : RMATCH_ANY;
	prg[1] = p->len;
	return prg + 2;
    }

    if (p->type == '[')
    {
	*prg++ = left ? LMATCH_SET : RMATCH_SET;
	*prg++ = p->len;
	sets = p->data;
	for (k = 0; k < p->len; k++)
	{
	    /* right-hand matching consumes characters back to front */
	    src = left ? k : p->len - 1 - k;
	    memcpy(prg, sets + src * set_size, set_size);
	    prg += set_size;
	}
	return prg;
    }

    debug("***comp_match: ILL_TYPE***\n");
    abort();
    return prg;
}


/*
 * Compile the shell-like pattern str into a byte-code program for
 * fmatch_exec().
 *
 * flags:
 *   FMATCH_AUTOSS  if the pattern has no '*', '?' or bracket set, act as
 *                  if FMATCH_SUBSTR had been given
 *   FMATCH_SUBSTR  wrap the pattern in a leading and a trailing '*'
 *   FMATCH_NOCASE  upper-case the pattern with to_upper(); if it has any
 *                  literal or set, also emit NO_CASE so the subject is
 *                  upper-cased at match time
 *
 * Returns a program obtained from malloc(), or 0 if stat_pattern()
 * rejects the pattern, the pattern needs more than 255 fixed characters,
 * or malloc() fails.  All intermediate buffers come from alloca().
 * Internal inconsistencies call abort().
 */
void *
fmatch_compile(char *str, int flags)
{
    struct pat_stat st;
    struct pat_dat *pdp;
    u8 *csp, *prg, *prg_save, *x;
    int prgmem, l, i;

    if (not stat_pattern(str, &st))
	return 0;

    if (st.minlen > 255)	/* sorry, we can't handle more */
	return 0;

    if (flags & FMATCH_AUTOSS)
	if (st.star + st.anychr + st.setchr == 0)
	    flags |= FMATCH_SUBSTR;

    /* Build a working copy: optionally upper-cased, optionally wrapped
       in '*'.  The +3 covers the two '*' and the terminator. */
    if (flags & (FMATCH_NOCASE | FMATCH_SUBSTR))
    {
	u8 *s = x = alloca(strlen(str) + 3);

	if (flags & FMATCH_SUBSTR)
	    *x++ = '*';
	while (*str)
	    *x++ = (flags & FMATCH_NOCASE) ? to_upper(*str++) : *str++;
	if (flags & FMATCH_SUBSTR)
	    *x++ = '*';
	*x++ = 0;
	str = s;

	/* the added '*' change the node counts, so count again */
	if (flags & FMATCH_SUBSTR)
	    stat_pattern(str, &st);
    }

    /* worst case prg memory calculation */
    prgmem = 32;
    prgmem += (st.star + 1) * 4;
    prgmem += st.litstr * 4 + st.litchr;
    prgmem += st.setstr * (3 + set_size) + st.setchr * set_size;
    prgmem += st.anystr * 2;

    pdp = alloca(st.nodes * sizeof(*pdp));	/* pattern nodes */
    csp = alloca(st.setchr * set_size);		/* bitmaps of all sets */
    prg = alloca(prgmem);			/* program being built */
    prg_save = prg;

    /* Case-insensitive and something to compare: upper-case the pattern
       (again) into a fresh copy and start the program with NO_CASE. */
    if (flags & FMATCH_NOCASE)
	if (st.litchr + st.setchr)
	{
	    x = alloca(strlen(str)+1);
	    for (i = 0; x[i] = to_upper(str[i]); ++i)
		;
	    str = x;
	    *prg++ = NO_CASE;
	}

    break_pattern(str, pdp, csp);

    l = st.nodes;

    /* A fixed-length node at either end is matched without a search, so
       make sure the subject is long enough first. */
    if (l && (pdp[0].len || pdp[l-1].len))
	*prg++ = CHECK_LEN, *prg++ = st.minlen;

    /* Peel fixed-length nodes off the left end ... */
    while (l && pdp[0].len)
	prg = comp_match(M_LEFT, prg, pdp++), l--;

    /* ... and off the right end. */
    while (l && pdp[l-1].len)
	prg = comp_match(M_RIGHT, prg, pdp + --l);

    while (l > 2)	/* '*xxx*' NOTE: pdp[0] and pdp[l-1] is '*' */
    {
	int f, a, b;

	/* a = total number of fixed characters still to be matched */
	for (a = i = 0; i < l; ++i)
	    a += pdp[i].len;

	/* First choice: a literal in the fixed run that follows the
	   leading '*'.  b = fixed characters of the run before it.
	   FIND_ARG operands, in order: added to s for the window start,
	   subtracted from e for the window end, find offset. */
	f = 0;
	for (b = 0, i = 1; pdp[i].len; b += pdp[i++].len)
	    if (pdp[i].type == 'a')
	    {
		debug("compile LFIND_CHR '%c' %d\n", *(u8 *)pdp[i].data, pdp[i].len);
		*prg++ = FIND_ARG;
		*prg++ = b + pdp[i].len - 1;
		*prg++ = a - b - pdp[i].len;
		*prg++ = b + pdp[i].len - 1;
		*prg++ = LFIND_CHR;
		/* search for the last character of the literal */
		*prg++ = ((u8 *)pdp[i].data)[pdp[i].len-1];
		/* The run is a single character, which the find itself
		   has verified: emit nothing if this is the last run,
		   otherwise just skip one character. */
		if (pdp[1].len == 1 && pdp[2].len == 0)
		    pdp[1].type = (l == 3) ? '-' : '?';
		for (i = 1; pdp[i].len; ++i)
		    prg = comp_match(M_LEFT, prg, pdp + i);
		/* continue with the '*' that ended this run as pdp[0] */
		pdp += i;
		l -= i;
		f = 1;
		break;
	    }
	/* Second choice: a literal in the fixed run that precedes the
	   trailing '*', searched from the right.  b = fixed characters
	   of the run after it. */
	if (f == 0)
	    for (b = 0, i = l - 1; pdp[--i].len; b += pdp[i].len)
		if (pdp[i].type == 'a')
		{
		    debug("compile RFIND_CHR '%c' %d\n", *(u8 *)pdp[i].data, pdp[i].len);
		    *prg++ = FIND_ARG;
		    *prg++ = a - b - pdp[i].len;
		    *prg++ = b + pdp[i].len - 1;
		    *prg++ = b + pdp[i].len;
		    *prg++ = RFIND_CHR;
		    /* search for the first character of the literal */
		    *prg++ = *(u8 *)pdp[i].data;
		    /* NOTE(review): pdp[l-1] is the trailing '*' (len 0),
		       so this condition looks like it can never hold;
		       the left-hand counterpart tests the run itself
		       (pdp[1]/pdp[2]) - confirm whether pdp[l-2]/pdp[l-3]
		       was intended. */
		    if (pdp[l-1].len == 1 && pdp[l-2].len == 0)
			pdp[l-1].type = (l == 3) ? '-' : '?';
		    for (i = l - 1; pdp[--i].len; )
			prg = comp_match(M_RIGHT, prg, pdp + i);
		    /* the '*' that precedes this run becomes the last node */
		    l = i + 1;
		    f = 1;
		    break;
		}
	/* Third choice: a bracket set in the run after the leading '*';
	   search for a character in its first bitmap. */
	if (f == 0)
	    for (b = 0, i = 1; pdp[i].len; b += pdp[i++].len)
		if (pdp[i].type == '[')
		{
		    debug("compile LFIND_SET %d\n", pdp[i].len);
		    *prg++ = FIND_ARG;
		    *prg++ = b;
		    *prg++ = a - b - 1;
		    *prg++ = b;
		    *prg++ = LFIND_SET;
		    memcpy(prg, pdp[i].data, set_size);
		    prg += set_size;
		    if (pdp[1].len == 1 && pdp[2].len == 0)
			pdp[1].type = (l == 3) ? '-' : '?';
		    for (i = 1; pdp[i].len; ++i)
			prg = comp_match(M_LEFT, prg, pdp + i);
		    pdp += i;
		    l -= i;
		    f = 1;
		    break;
		}
	/* neither a literal nor a set to search for */
	if (f == 0)
	{
	    debug("***compile_pattern: NO_CHR/NO_SET***\n");
	    abort();
	}
    }

    /* No node left: the rest of the subject must be empty.
       One node left (a '*'): anything matches. */
    if (l == 0)
	*prg++ = MATCH_EMPTY;
    else if (l == 1)
	*prg++ = EXIT_TRUE;
    else
    {
	debug("***compile_pattern: l==2***\n");
	abort();
    }

    /* Sanity check of the worst-case estimate; note that it runs after
       the program bytes have already been written. */
    if (prg - prg_save > prgmem)
    {
	debug("***compile_pattern: out of mem***\n");
	abort();
    }

    debug("memsize = %d, memdiff = %d\n", prg - prg_save, prgmem - (prg - prg_save));

    /* hand back a heap copy of exactly the used size (0 if malloc fails) */
    if (x = malloc(prg - prg_save))
	memcpy(x, prg_save, prg - prg_save);

    return x;
}




int
fmatch_exec(void *_prg, char *s, char *e)
{
    u8 *prg = _prg;
    u8 *fs = s;			/* find start */
    u8 *fe = e;			/* find end */
    int fo = 0;			/* find offset */
    u8 *fp = (char[]){ EXIT_FALSE }; /* find restart-prg */
    int i;

    for (;;)
	switch (*prg++)
	{
	    case EXIT_FALSE:
		debug("EXIT_FALSE\n");
	    	return 0;
	    case EXIT_TRUE:
		debug("EXIT_TRUE\n");
	    	return 1;
	    case MATCH_EMPTY:
		debug("MATCH_EMPTY\n");
	    	return e == s;
	    case NO_CASE:
		debug("NO_CASE\n");
		fs = fe = alloca(e - s);
		while (s < e)
		    *fe++ = to_upper(*s++);
		s = fs;
		e = fe;
		break;
	    case CHECK_LEN:
		debug("CHECK_LEN %d\n", prg[0]);
	    	if (e - s < *prg++)
		    return 0;
		break;
	    case FIND_ARG:
		debug("FIND_ARG %d %d %d\n", prg[0], prg[1], prg[2]);
	    	fs = s + *prg++;
		fe = e - *prg++;
		fo = *prg++;
		fp = prg;
		if (fs >= fe)
		    return 0;
		break;
	    case LFIND_CHR:
		debug("LFIND_CHR '%c'\n", prg[0]);
		fs = memchr(fs, *prg++, fe - fs);
		if (fs == 0)
		    return 0;
		s = fs++ - fo;
		break;
	    case LFIND_SET:
		debug("LFIND_SET ...\n");
		while (fs < fe && not tst_set(prg, *fs))
		    fs++;
		if (fs == fe)
		    return 0;
		prg += set_size;
		s = fs++ - fo;
		break;
	    case LMATCH_ANY:
		debug("LMATCH_ANY %d\n", prg[0]);
	    	s += *prg++;
		break;
	    case LMATCH_STR:
		debug("LMATCH_STR '%.*s'%d at '%.*s'\n", prg[0], prg+1, prg[0], e - s, s);
		i = *prg++;
		if (memcmp(s, prg, i))
		    prg = fp;
		else
		    prg += i, s += i;
		break;
	    case LMATCH_SET:
		debug("LMATCH_SET %d\n", prg[0]);
		i = *prg++;
		while (i && tst_set(prg, *s))
		{
		    i--;
		    s++;
		    prg += set_size;
		}
		if (i)
		    prg = fp;
		break;
	    case RFIND_CHR:
		debug("RFIND_CHR '%c'\n", prg[0]);
		fe = memrchr(fs, *prg++, fe - fs);
		if (fe == 0)
		    return 0;
		e = fe + fo;
		break;
	    case RFIND_SET:
		debug("RFIND_SET ...\n");
		while (fs < fe && not tst_set(prg, *--fe))
		    ;
		if (fs == fe)
		    return 0;
		prg += set_size;
		e = fe + fo;
		break;
	    case RMATCH_ANY:
		debug("RMATCH_ANY %d\n", prg[0]);
	    	e -= *prg++;
		break;
	    case RMATCH_STR:
		debug("RMATCH_STR '%.*s'\n", prg[0], prg+1);
		i = *prg++;
		e -= i;
		if (memcmp(e, prg, i))
		    prg = fp;
		else
		    prg += i;
		break;
	    case RMATCH_SET:
		debug("RMATCH_SET %d\n", prg[0]);
		i = *prg++;
		while (i && tst_set(prg, *--e))
		{
		    i--;
		    prg += set_size;
		}
		if (i)
		    prg = fp;
		break;
	    default:
	    	debug("***ILL_TOKEN (%d)***\n", prg[-1]);
	}
}


int
fmatch(void *pat, char *str)
{
    if (*(u8*)pat == EXIT_TRUE)
	return 1;
    if (*(u8*)pat == MATCH_EMPTY)
	return *str == 0;
    return fmatch_exec(pat, str, str + strlen(str));
}


#ifdef DEBUG_FMATCH

/*
 * Test driver: compile every command line argument as a
 * case-insensitive pattern (with automatic substring mode) and report
 * whether it matches a fixed sample string.
 */
int
main(int argc, char **argv)
{
    char *pattern;
    char *prog;

    while (--argc != 0)
    {
	pattern = *++argv;
	prog = fmatch_compile(pattern, FMATCH_NOCASE|FMATCH_AUTOSS);

	if (prog == 0)
	{
	    fprintf(stderr, "compile error\n");
	    continue;
	}

	if (fmatch(prog, "abcdefgHIjklMnopqrStuvwxyz"))
	    printf("'%s' matches 'abc..xyz'\n", pattern);
	else
	    printf("'%s' doesn't match 'abc..xyz'\n", pattern);

	free(prog);
    }
}

#endif
