/* Copyright (C) 1979-1996 TcX AB & Monty Program KB & Detron HB
   
   This software is distributed with NO WARRANTY OF ANY KIND.  No author or
   distributor accepts any responsibility for the consequences of using it, or
   for whether it serves any particular purpose or works at all, unless he or
   she says so in writing.  Refer to the Free Public License (the "License")
   for full details.
   
   Every copy of this file must include a copy of the License, normally in a
   plain ASCII text file named PUBLIC.  The License grants you the right to 
   copy, modify and redistribute this file, but only under certain conditions
   described in the License.  Among other things, the License requires that
   the copyright notice and this notice be preserved on all copies. */

/* A lexical scanner on a temporary buffer with a yacc interface */

#include "mysql_priv.h"
#include "sql_lex.h"
#include <m_ctype.h>
#include <hash.h>

/* Macros to look like lex */

#define yyGet()		*(lex->ptr++)
#define yyGetLast()	lex->ptr[-1]
#define yyPeek()	lex->ptr[0]
#define yyUnget()	lex->ptr--
#define yySkip()	lex->ptr++

static pthread_key_t THR_LEX;

#define is_op(c)	(c == '<' || c == '>' || c == '=')
#define is_bool(c)	(c == '&' || c == '|')

#define TOCK_NAME_LENGTH 20

typedef struct st_symbol {
  char	*name;
  int	tok;
  uchar length;
} SYMBOL;

// Symbols are split into separate arrays so that field names may be the
// same as function names.
// These are kept sorted (by byte value) for human lookup; the lexer itself
// finds them through a hash table built by lex_init(), so order doesn't
// affect lookups.

static SYMBOL symbols[] = {
  { "&&",	AND,0},
  { "<",	LT,0},
  { "<=",	LE,0},
  { "<>",	NE,0},
  { "=",	EQ,0},
  { ">",	GT_SYMBOL,0},
  { ">=",	GE,0},
  { "add",	ADD,0},
  { "all",	ALL,0},
  { "alter",	ALTER,0},
  { "and",	AND,0},
  { "as",	AS,0},
  { "asc",	ASC,0},
  { "auto_increment",AUTO_INC,0},
  { "bigint",	BIGINT,0},
  { "blob",	BLOB,0},
  { "by",	BY,0},
  { "cascade",	CASCADE,0},
  { "change",	CHANGE,0},
  { "char",	CHAR,0},
  { "character",CHAR,0},
  { "column",	COLUMN_SYMBOL,0},
  { "columns",	COLUMNS,0},
  { "create",	CREATE,0},
  { "data",	DATA_SYMBOL,0},
  { "databases",DATABASES,0},
  { "date",	DATE_SYMBOL,0},
  { "dec",	DECIMAL,0},
  { "decimal",	DECIMAL,0},
  { "default",	DEFAULT,0},
  { "delete",	DELETE,0},
  { "desc",	DESC,0},
  { "describe",	DESCRIBE,0},
  { "distinct",	DISTINCT,0},
  { "double",	DOUBLE,0},
  { "drop",	DROP,0},
  { "enclosed",	ENCLOSED,0},
  { "escaped",	ESCAPED,0},
  { "fields",	COLUMNS,0},
  { "float",	FLOAT,0},
  { "float4",	FLOAT,0},
  { "float8",	DOUBLE,0},
  { "foreign",	FOREIGN,0},
  { "from",	FROM,0},
  { "group",	GROUP,0},
  { "having",	HAVING,0},
  { "ignore",	IGNORE,0},
  { "in",	IN_SYMBOL,0},
  { "index",	INDEX,0},
  { "infile",	INFILE,0},
  { "insert",	INSERT,0},
  { "int",	INT,0},
  { "int1",	TINYINT,0},
  { "int2",	SMALLINT,0},
  { "int3",	MEDIUMINT,0},
  { "int4",	INT,0},
  { "int8",	BIGINT,0},
  { "integer",	INT,0},
  { "into",	INTO,0},
  { "is",	IS,0},
  { "key",	KEY_SYMBOL,0},
  { "keys",	KEYS,0},
  { "like",	LIKE,0},
  { "limit",	LIMIT,0},
  { "lines",	LINES,0},
  { "load",	LOAD,0},
  { "longblob",	LONGBLOB,0},
  { "mediumblob",MEDIUMBLOB,0},
  { "mediumint",MEDIUMINT,0},
  { "not",	NOT,0},
  { "null",	NULL_SYMBOL,0},
  { "numeric",	NUMERIC,0},
  { "on",	ON,0},
  { "optionally",OPTIONALLY,0},
  { "or",	OR,0},
  { "order",	ORDER_SYMBOL,0},
  { "outfile",	OUTFILE,0},
  { "precision",PRECISION,0},
  { "primary",	PRIMARY_SYMBOL,0},
  { "procedure",PROCEDURE,0},
  { "real",	REAL,0},
  { "references",REFERENCES,0},
  { "regexp",	REGEXP,0 },
  { "replace",	REPLACE,0},
  { "restrict",	RESTRICT,0},
  { "rlike",	REGEXP,0 },			// Like in mSQL2
  { "select",	SELECT_SYMBOL,0},
  { "set",	SET,0},
  { "show",	SHOW,0},
  { "smallint",	SMALLINT,0},
  { "straight_join", STRAIGHT_JOIN,0},
  { "table",	TABLE_SYMBOL,0},
  { "tables",	TABLES,0},
  { "terminated",TERMINATED,0},
  { "time",	TIME_SYMBOL,0},
  { "timestamp",TIMESTAMP,0},
  { "tinyblob",	TINYBLOB,0},
  { "tinyint",	TINYINT,0},
  { "unique",	UNIQUE_SYMBOL,0},
  { "unsigned",	UNSIGNED,0},
  { "update",	UPDATE_SYMBOL,0},
  { "values",	VALUES,0},
  { "varchar",	VARCHAR,0},
  { "varying",	VARYING,0},
  { "where",	WHERE,0},
  { "zerofill",	ZEROFILL,0},
  { "||",	OR,0},
};


/*
  Function names.  find_keyword() searches this table only when the name is
  directly followed by '(', so columns may share these names.  Kept sorted
  for human lookup; the lexer uses a hash table built by lex_init().
*/
static SYMBOL sql_functions[] = {
  { "abs",	ABS,0},
  { "ascii",	ASCII,0},
  { "avg",	AVG_SUM,0},
  { "between",	BETWEEN,0},
  { "bit_count",BIT_COUNT,0},
  { "ceiling",	CEILING,0},
  { "char",	CHAR,0},
  { "concat",	CONCAT,0},
  { "count",	COUNT_SUM,0},
  { "curdate",	CURDATE,0},
  { "database",	DATABASE,0},
  { "elt",	ELT_FUNC,0},
  { "exp",	EXP,0},
  { "field",	ELT_FUNC,0},			// For compatibility
  { "floor",	FLOOR,0},
  { "format",	FORMAT,0},
  { "from_days",FROM_DAYS,0},			// Convert string or number
  { "group_unique_users", GROUP_UNIQUE_USERS,0},
  { "if",	IF,0},
  { "ifnull",	IFNULL,0},
  { "insert",	INSERT,0},
  { "instr",	LOCATE,0},			// unireg function
  { "interval",	INTERVAL,0},
  { "isnull",	ISNULL,0},
  { "lcase",	LCASE,0},
  { "left",	LEFT,0},
  { "length",	LENGTH,0},
  { "locate",	LOCATE,0},
  { "log",	LOG,0},
  { "log10",	LOG10,0},
  { "ltrim",	LTRIM,0},
  { "max",	MAX_SUM,0},
  { "mid",	SUBSTRING,0},			// unireg function
  { "min",	MIN_SUM,0},
  { "mod",	MOD_SYMBOL,0},
  { "now",	NOW_SYMBOL,0},
  { "password",	PASSWORD,0},
  { "period_add",  PERIOD_ADD,0},
  { "period_diff", PERIOD_DIFF,0},
  { "pow",	POW,0},
  { "rand",	RAND,0},
  { "repeat",	REPEAT,0},
  { "replace",	REPLACE,0},
  { "right",	RIGHT,0},
  { "round",	ROUND,0},
  { "rtrim",	RTRIM,0},
  { "sign",	SIGN,0},
  { "sqrt",	SQRT,0},
  { "strcmp",	STRCMP,0},
  { "substring",SUBSTRING,0},
  { "sum",	SUM_SUM,0},
  { "to_days",	TO_DAYS,0},			// Convert string or number
  { "ucase",	UCASE,0},
  { "unique_users", UNIQUE_USERS,0},
  { "user",	USER,0},
  { "weekday",	WEEKDAY,0},
};


/* Hash tables over symbols[] and sql_functions[]; filled by lex_init() */
static HASH sym_hash;
static HASH fun_hash;

static uchar *get_text(LEX *lex);
static uchar *get_token(LEX *lex);

/*
  Hash key callback for the symbol tables: the key of a SYMBOL record is
  its name, with the length taken from the precomputed 'length' field.
*/
static byte* get_hash_key(const byte *buff,uint *length)
{
  const SYMBOL *entry=(const SYMBOL*) buff;
  *length=entry->length;
  return (byte*) entry->name;
}

void lex_init(void)
{
  uint i;
  if (init_hash(&sym_hash,array_elements(symbols),0,0,get_hash_key,
		NULL) ||
      init_hash(&fun_hash,array_elements(sql_functions),0,0,
		get_hash_key,NULL))
    exit(1);

  for (i=0 ; i < array_elements(symbols) ; i++)
  {
    symbols[i].length=(uchar) strlen(symbols[i].name);
    if (hash_insert(&sym_hash,(byte*) (symbols+i)))
      exit(1);
  }
  for (i=0 ; i < array_elements(sql_functions) ; i++)
  {
    sql_functions[i].length=(uchar) strlen(sql_functions[i].name);
    if (hash_insert(&fun_hash,(byte*) (sql_functions+i)))
      exit(1);
  }
  VOID(pthread_key_create(&THR_LEX,NULL));
}


/* Release the keyword and function hash tables; call when the daemon ends */
void lex_free(void)
{
  free_hash(&fun_hash);
  free_hash(&sym_hash);
}


/* Return the LEX that lex_start() registered for the calling thread */
LEX *_current_lex(void)
{
  void *value=pthread_getspecific(THR_LEX);
  return (LEX*) value;
}


/*
  Prepare a LEX for scanning the 'length' bytes at 'buf' and register it as
  the calling thread's LEX (returned by _current_lex()).  The LEX comes from
  sql_alloc(); per the original note it doesn't have to be freed.
*/
LEX *lex_start(uchar *buf,uint length)
{
  LEX *lex=(LEX*) sql_alloc(sizeof(LEX));
  VOID(pthread_setspecific(THR_LEX,(void*) lex));

  lex->ptr=buf;
  lex->end_of_query=buf+length;
  lex->next_state=0;			// Query may start with a signed number
  lex->yylineno=1;
  lex->length=0;
  lex->create_refs=0;
  lex->dummyBuf[1]=0;			// Terminator for 1 char tokens
  lex->expr_list.empty();
  return lex;
}

/*
  Finish with a LEX: delete whatever is left in lex->expr_list (per the
  original note, elements can remain after an error while parsing
  sql-varargs).
*/
void lex_end(LEX *lex)
{
  LEX *done=lex;
  done->expr_list.delete_elements();
}


/*
  Look up the token [lex->tok_start, lex->ptr) as a keyword, ignoring case.

  If 'function' is set (the name is directly followed by '('), the function
  table is searched first and the ordinary keyword table second, which
  allows 'insert into table(field,field,field)values(value,value)'.
  Otherwise only the keyword table is searched.

  On a match sets lex->yytext and lex->yylval->num and returns the token;
  returns 0 if the name is longer than TOCK_NAME_LENGTH or not found.
*/
static int find_keyword(LEX *lex,bool function)
{
  uchar *tok=lex->tok_start;
  uint len=(uint) (lex->ptr - tok);
  if (len > TOCK_NAME_LENGTH)
    return 0;				// Longer than any keyword we look up

  char buff[TOCK_NAME_LENGTH];
  memcpy(buff,tok,(size_t) len);
  casedn(buff,(size_t) len);		// Table names are lower case

  SYMBOL *symbol=0;
  if (function)
    symbol=(SYMBOL*) hash_search(&fun_hash,buff,len);
  if (!symbol)
    symbol=(SYMBOL*) hash_search(&sym_hash,buff,len);
  if (!symbol)
    return 0;

  lex->yytext = (uchar*) tok;
  lex->yylval->num = symbol->tok;
  return symbol->tok;
}


/*
  Copy the current token into sql_alloc() memory and NUL-terminate it.
  Expects lex->ptr one char past the token's end (the lookahead char); that
  char is pushed back first.  Sets lex->yytoklen to the token length.
*/

static uchar *get_token(LEX *lex)
{
  yyUnget();				// Now ptr points just after the token
  uint len=(uint) (lex->ptr - lex->tok_start);
  lex->yytoklen=len;

  uchar *copy=(uchar*) sql_alloc(len+1);
  memcpy((byte*) copy,(byte*) lex->tok_start,len);
  copy[len]=0;
  return copy;
}


/*
  Return an unescaped text literal without quotes.

  Called with lex->ptr just after the opening quote; the same character
  must close the string.  Sets lex->tok_start to the first character after
  the opening quote.  Inside the string:
    - a doubled quote stands for one quote character,
    - \n \t \r \b \0 are translated to the corresponding control char,
    - \_ and \% keep their backslash so LIKE wildcards stay escaped,
    - any other backslash-escaped character is copied without the backslash.

  On success returns a NUL-terminated copy allocated with sql_alloc(), sets
  lex->yytoklen to its length and leaves lex->ptr just after the closing
  quote.  Returns 0 if the query ends before the closing quote.
*/

static uchar *get_text(LEX *lex)
{
  uchar quote= yyGetLast();		// The opening quote also closes it
  lex->tok_start=lex->ptr;

  /* Find the closing quote, stepping over escapes and doubled quotes */
  for (;;)
  {
    if (lex->ptr == lex->end_of_query)
      return 0;				// Unexpected end of query
    uchar ch= yyGet();
    if (ch == '\\')
    {
      if (lex->ptr == lex->end_of_query)
	return 0;			// Backslash at the very end
      yySkip();				// Escaped char can't close the string
    }
    else if (ch == quote)
    {
      if (yyGet() != quote)		// A doubled quote doesn't close it
      {
	yyUnget();
	break;				// Found the closing quote
      }
    }
  }

  /* Copy the body, translating escapes and collapsing doubled quotes */
  uchar *from= lex->tok_start;
  uchar *stop= lex->ptr-1;		// The closing quote
  uchar *result= (uchar*) sql_alloc((uint) (stop-from)+1);
  uchar *out= result;
  while (from != stop)
  {
    uchar ch= *from;
    if (ch == '\\' && from+1 != stop)
    {
      ch= *++from;
      switch (ch) {
      case 'n': *out++= '\n'; break;
      case 't': *out++= '\t'; break;
      case 'r': *out++= '\r'; break;
      case 'b': *out++= '\b'; break;
      case '0': *out++= 0;    break;	// Ascii null
      case '_':
      case '%':
	*out++= '\\';			// Remember prefix for wildcard
	*out++= ch;
	break;
      default:
	*out++= ch;
	break;
      }
    }
    else
    {
      *out++= ch;
      if (ch == quote)
	from++;				// Skip the second quote of a pair
    }
    from++;
  }
  *out=0;
  lex->yytoklen=(uint) (out-result);
  return result;
}


/*
  yylex() keeps lex->next_state between calls.  Besides the start state 0
  (a '-' or '+' may begin a signed number), these states are remembered
  from the previous call:
   13 ; found end of query
   14 ; last token was an ident, text or number (which can't be followed by
	a signed number, so '-'/'+' are returned as single chars)

  Returns a keyword/operator token from find_keyword(), IDENT, NUM,
  REAL_NUM, TEXT, END_OF_INPUT, 0 after end of input was already reported,
  or a single character as its own token.  Integers of at most 9 chars
  (sign included) are NUM; longer ones are returned as REAL_NUM.

  NOTE(review): end of input is detected only by a 0 byte (state 0 never
  compares with lex->end_of_query), so the query buffer is presumably
  NUL-terminated - confirm with callers of lex_start().
*/

int yylex(void *arg)
{
  reg1	uchar	c=0;
  int	tokval;
  uint	state;
  uchar  *start;
  LEX	*lex=current_lex;
  YYSTYPE *yylval=(YYSTYPE*) arg;

  lex->yylval=yylval;			// The global state
  lex->tok_end=lex->ptr;		// Prev token ends here
  state=lex->next_state; lex->next_state=14;
  start=lex->tok_start=lex->ptr;	// Start of real token
  for (;;)
  {
    switch(state) {
    case 14:				// Next is operator or keyword
    case 0:				// Start of token
      // Skip leading white space, counting new lines
      for (c=yyGet() ; (c && !isgraph(c)) ; c= yyGet())
      {
	if (c == '\n')
	  lex->yylineno++;
      }
      start=lex->tok_start=lex->ptr-1;	// Start of real token
      if (c == '\'' || c == '"')
      {
	state = 11;			// String
	break;
      }
      if (isalpha(c) || c == '_')
      {
	state = 2;			// Keyword or ident
	break;
      }
      if (isdigit(c))
      {
	state = 16;			// int,real or ident
	break;
      }
      if ((c == '-' || c == '+') && state == 0)
      {
	state = 7;			// Signed int or real
	break;
      }
      if (c == '.')			// Actually real shouldn't start
      {					// with . but allow them anyhow
	if (isdigit(yyGet()))
	  state = 9;			// Real
	else
	  state = 1;			// return '.'
	break;
      }
      if (is_op(c))
      {
	state = 10;			// Compare op
	break;
      }
      if (is_bool(c) && yyPeek() == c)
      {
	(void) yyGet();
	state = 15;			// && or ||
	break;
      }
      if (c == '#')
      {
	state = 12;			// Comment
	break;
      }

      if (c == 0)
      {					// End state
	lex->next_state=13;		// Mark for next loop
	return(END_OF_INPUT);
      }
      // fall through
    default:				// Shut up compiler
    case 1:				// Unknown or single char token
      lex->ptr=start;			// Set to first char
      lex->dummyBuf[0]=c=yyGet();
      yylval->str=(char*) (lex->yytext=lex->dummyBuf);
      if (c != ')')
	lex->next_state= 0;		// Allow signed numbers
      if (c == ',')
	lex->tok_start=lex->ptr;	// Let tok_start point at next item
      return((int) c);

    case 2:				// Incomplete keyword or ident
      while (isalnum(c=yyGet()) || c == '_') ;
      if (c == '.' && isalpha(yyPeek()))
	lex->next_state=3;
      else
      {					// '(' must follow directly if function
	yyUnget();
	if ((tokval = find_keyword(lex,c == '(')))
	{
	  lex->next_state= 0;		// Allow signed numbers
	  return(tokval);		// Was keyword
	}
	yySkip();			// next state does a unget
      }
      state = 5;			// Found complete ident
      break;

    case 3:				// Found ident and now '.'
      lex->next_state=4;		// Next is an ident (not a keyword)
      lex->dummyBuf[0]=c=yyGet();	// should be '.'
      yylval->str=(char*) (lex->yytext=lex->dummyBuf);
      return((int) c);

    case 16:				// number or ident which starts with number
      while (isdigit((c = yyGet()))) ;
      if (!isalpha(c) && c != '_')
      {					// Can't be identifier
	state=17;
	break;
      }
      // fall through

    case 4:				// Incomplete ident
      while (isalnum((c = yyGet())) || c == '_') ;
      if (c == '.' && isalpha(yyPeek()))
	lex->next_state=3;
      // fall through

    case 5:				// Complete ident
      yylval->str= (char*) (lex->yytext = get_token(lex));
      return(IDENT);

    case 7:				// Incomplete signed number
      c=yyGet();
      if (!isdigit(c))
      {
	// Like an unsigned '.', a sign + '.' only starts a number when a
	// digit follows; otherwise return the sign as a single char.
	if (c != '.' || !isdigit(yyPeek()))
	{
	  state = 1;			// Return sign as single char
	  break;
	}
	// Push the '.' back so case 8 stops on it and case 17 continues
	// with the real number (previously "-.5" was returned as NUM).
	yyUnget();
      }
      // fall through
    case 8:				// Incomplete real or int number
      while (isdigit(c=yyGet())) ;
      // fall through
    case 17:				// Complete int or incomplete real
      if (c != '.')
      {
	// Found complete integer number. Change long numbers to real
	yylval->str = (char*) (lex->yytext = get_token(lex));
	return(lex->yytoklen <= 9 ? NUM : REAL_NUM);
      }
      // fall through
    case 9:				// Incomplete real number
      while (isdigit(c = yyGet())) ;
      if (c == 'e' || c == 'E')
      {
	c = yyGet();
	if (c == '-' || c == '+')	// Exponent sign is optional ("1.5e10")
	  c = yyGet();
	if (!isdigit(c))
	{				// No digit in exponent
	  state= 1;
	  break;
	}
	while (isdigit(yyGet())) ;
      }
      yylval->str=(char*) (lex->yytext = get_token(lex));
      return(REAL_NUM);

    case 10:				// Incomplete comparison operator
      c=yyGet();			// May be 2 long
      if (!is_op(c))
	yyUnget();
      if ((tokval = find_keyword(lex,0))) // Should be comparison operator
      {
	lex->next_state= 0;		// Allow signed numbers
	return(tokval);
      }
      state = 1;			// Something fishy found
      break;

    case 15:
      tokval = find_keyword(lex,0);	// Is a bool operator
      lex->next_state= 0;		// Allow signed numbers
      return(tokval);

    case 11:				// Incomplete text string
      if (!(lex->yytext = get_text(lex)))
      {
	state = 1;			// Read char by char
	break;
      }
      yylval->str = (char*) lex->yytext;
      return(TEXT);

    case 12:				//  Comment
      while ((c = yyGet()) != '\n' && c) ;
      yyUnget();			// Safety against eof
      state = 0;			// Try again
      break;
    case 13:
      lex->next_state=13;
      return(0);			// We found end of input last time
    }
  }
}
