/**********************************************************************
Copyright 1999 by ITG Australia.

                        All Rights Reserved

Permission to use, copy, modify, and distribute this software and its
documentation for any purpose and without fee is hereby granted,
provided that the above copyright notice appear in all copies and that
both that copyright notice and this permission notice appear in
supporting documentation, and that the names of ITG Australia or ITGA
not be used in advertising or publicity pertaining to distribution of
the software without specific, written prior permission.

ITG AUSTRALIA DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO
EVENT SHALL ITG AUSTRALIA BE LIABLE FOR ANY SPECIAL, INDIRECT OR
CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF
USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.
**********************************************************************/
#include <ctype.h>
#include "Sybase.h"

staticforward PyTypeObject BulkcopyType;	/* shared type descriptor */

/* Non-zero if the Python object @v is a Bulkcopy object.  The test is
 * against the type object BulkcopyType (BulkcopyObj is the C struct
 * typedef and cannot have its address taken).
 */
#define is_Bulkcopy_object(v) ((v)->ob_type == &BulkcopyType)

/* Split the first @str_len bytes of @row_str into columns, treating
 * every ',' as a column separator.  Each column's buff/buff_len in
 * @self is pointed at its slice of @row_str (no copying is done, and
 * the separators are not included).  Returns non-zero when the row
 * contained exactly one field per table column, zero when it had too
 * few or too many.
 */
static int comma_split(BulkcopyObj *self, char *row_str, int str_len)
{
    char *pos = row_str;	/* start of the current column */
    int remain = str_len;	/* bytes left from @pos onwards */
    int col_num;		/* index of the current column */
    ColInfo *col = self->col_info;

    for (col_num = 0; col_num < self->num_cols; col_num++, col++) {
	int width = 0;		/* bytes before the next separator */

	while (width < remain && pos[width] != ',')
	    width++;
	col->buff = pos;
	col->buff_len = width;

	pos += col->buff_len;
	remain -= col->buff_len;

	debug_msg(self->debug, "  comma_split: %s<%d:%.*s>[%d]\n",
		  col->fmt.name,
		  col->buff_len, col->buff_len, col->buff,
		  remain);

	/* Running out of data is only a success on the final column */
	if (remain == 0)
	    return col_num == self->num_cols - 1;

	pos++;			/* step over the ',' separator */
	remain--;
    }

    /* Every column was filled but data is left over */
    return 0;
}

/* Split the first @str_len bytes of @row_str into columns, taking the
 * column types into account.  Columns typed CHAR, BINARY, LONGCHAR,
 * LONGBINARY, TEXT or IMAGE are first tried at exactly their declared
 * width (fmt.maxlength), so their data may itself contain commas; if
 * the data at that width is not followed by a ',' or the end of the
 * row, the column falls back to comma scanning like every other type.
 * Each column's buff/buff_len in @self is pointed at its slice of
 * @row_str.  Returns non-zero when the row contained exactly one
 * field per table column, zero otherwise.
 */
static int size_split(BulkcopyObj *self, char *row_str, int str_len)
{
    char *pos = row_str;	/* start of the current column */
    int remain = str_len;	/* bytes left from @pos onwards */
    int col_num;		/* index of the current column */
    ColInfo *col = self->col_info;

    for (col_num = 0; col_num < self->num_cols; col_num++, col++) {
	int fixed;		/* try the declared width first? */
	int width;		/* candidate column width */

	switch (col->fmt.datatype) {
	case CS_LONGCHAR_TYPE:
	case CS_TEXT_TYPE:
	case CS_CHAR_TYPE:
	case CS_IMAGE_TYPE:
	case CS_LONGBINARY_TYPE:
	case CS_BINARY_TYPE:
	    fixed = 1;
	    break;
	default:
	    fixed = 0;
	    break;
	}
	debug_msg(self->debug,
		  fixed ? "  size_split: %s/%d not variable\n"
			: "  size_split: %s/%d is variable\n",
		  col->fmt.name, col->fmt.datatype);

	col->buff = pos;

	/* The declared width is kept only if it fits in the remaining
	 * data and ends exactly at the end of the row or at a ','.
	 */
	width = col->fmt.maxlength;
	if (fixed && width <= remain && (width == remain || pos[width] == ','))
	    col->buff_len = width;
	else {
	    for (width = 0; width < remain && pos[width] != ','; width++)
		;
	    col->buff_len = width;
	}

	pos += col->buff_len;
	remain -= col->buff_len;

	debug_msg(self->debug, "  size_split: %s<%d:%.*s>[%d]\n",
		  col->fmt.name,
		  col->buff_len, col->buff_len, col->buff,
		  remain);

	/* Running out of data is only a success on the final column */
	if (remain == 0)
	    return col_num == self->num_cols - 1;

	pos++;			/* step over the ',' separator */
	remain--;
    }

    /* Every column was filled but data is left over */
    return 0;
}

/* Split the NUL-terminated row @row_str into the columns of @self,
 * ignoring a trailing "\n", "\r\n" or "\r".  Plain comma splitting is
 * tried first, then the type-aware size_split().  Returns non-zero on
 * success; on failure raises DataError and returns zero.
 */
static int bulk_split(BulkcopyObj *self, char *row_str)
{
    int usable = strlen(row_str);	/* bytes that hold column data */

    if (usable > 0 && row_str[usable - 1] == '\n')
	usable--;
    if (usable > 0 && row_str[usable - 1] == '\r')
	usable--;

    if (comma_split(self, row_str, usable))
	return 1;
    if (size_split(self, row_str, usable))
	return 1;

    raise_exception_string(DataError, "data does not match table");
    return 0;
}

/* Parse a CSV string in @row_str and bulkcopy it to the Sybase
 * server.  Every column is bound as CS_CHAR_TYPE text pointing into
 * @row_str; an empty column is bound with indicator -1 (NULL).
 * Return non-zero on success, zero (with an exception raised) on
 * failure.
 */
static int bulk_send_str(BulkcopyObj *self, char *row_str)
{
    int idx;			/* iterate over table columns */
    ColInfo *col;		/* current column description */

    if (!bulk_split(self, row_str))
	return 0;

    /* Bind the data to the table columns.  The split function has set
     * col->buff and col->buff_len for us.
     */
    for (idx = 0, col = self->col_info; idx < self->num_cols; idx++, col++) {
	CS_DATAFMT fmt;		/* bind each column  */

	/* CS_DATAFMT has more fields than the ones set below (e.g.
	 * locale); zero them so blk_bind() never reads stack garbage.
	 */
	memset(&fmt, 0, sizeof(fmt));
	fmt.datatype = CS_CHAR_TYPE;
	fmt.format = CS_FMT_UNUSED;
	fmt.maxlength = col->fmt.maxlength;
	fmt.count = 1;
	if (col->buff_len == 0)
	    col->indicator = -1;
	else
	    col->indicator = 0;
	debug_msg(self->debug, "  blk_bind(%s(%d/%d)=%.*s)\n",
		  col->fmt.name,
		  col->buff_len, col->fmt.maxlength,
		  col->buff_len, col->buff);
	if (blk_bind(self->blk, idx + 1, &fmt, col->buff, &col->buff_len,
		     &col->indicator) != CS_SUCCEED) {
	    raise_exception(self->conn_info, "blk_bind failed");
	    return 0;
	}
    }

    /* Now transfer the row data
     */
    debug_msg(self->debug, "  blk_rowxfer()\n");
    if (blk_rowxfer(self->blk) == CS_FAIL) {
	raise_exception(self->conn_info, "blk_rowxfer failed");
	return 0;
    }

    return 1;
}

/* bulkcopy a row of data in @row_data to the Sybase server.  Return
 * non-zero on success, zero on failure.
 */
static int bulk_send_sequence(BulkcopyObj *self, PyObject *row_data)
{
    int idx;			/* iterate over table columns */
    ColInfo *col;		/* current column description */

    /* Bind the data to the table columns.  The split function has set
     * col->buff and col->buff_len for us.
     */
    for (idx = 0, col = self->col_info; idx < self->num_cols; idx++, col++) {
	CS_DATAFMT fmt;		/* bind each column  */
	PyObject *col_data;

	col_data = PySequence_GetItem(row_data, idx);
	if (col_data == NULL) {
	    raise_exception_string(ProgrammingError, "not enough fields");
	    return 0;
	}
	/* Only borrowing the item
	 */
	Py_DECREF(col_data);
	/* Bind the column buffer to the input data
	 */
	if (col_data == Py_None) {
	    fmt.datatype = CS_CHAR_TYPE;
	    col->buff = "";
	    col->buff_len = 0;
	    debug_msg(self->debug, "  blk_bind(%s(%d/%d) = None)\n",
		      col->fmt.name,
		      col->buff_len, col->fmt.maxlength,
		      col->buff_len, col->buff);
	} else if (PyString_Check(col_data)) {
	    fmt.datatype = CS_CHAR_TYPE;
	    col->buff_len = PyString_Size(col_data);
	    col->buff = PyString_AsString(col_data);
	    debug_msg(self->debug, "  blk_bind(%s(%d/%d) = '%.*s')\n",
		      col->fmt.name,
		      col->buff_len, col->fmt.maxlength,
		      col->buff_len, col->buff);
	} else if (PyInt_Check(col_data)) {
	    fmt.datatype = CS_INT_TYPE;
	    col->buff_len = sizeof(CS_INT);
	    col->v.cs_int = PyInt_AsLong(col_data);
	    col->buff = &col->v.cs_int;
	    debug_msg(self->debug, "  blk_bind(%s = %d)\n",
		      col->fmt.name,
		      col->buff_len, col->v.cs_int);
	} else if (PyLong_Check(col_data)) {
	    fmt.datatype = CS_INT_TYPE;
	    col->buff_len = sizeof(CS_INT);
	    col->v.cs_int = PyLong_AsLong(col_data);
	    col->buff = &col->v.cs_int;
	    debug_msg(self->debug, "  blk_bind(%s = %dL)\n",
		      col->fmt.name,
		      col->buff_len, col->v.cs_int);
	} else if (PyFloat_Check(col_data)) {
	    fmt.datatype = CS_FLOAT_TYPE;
	    col->buff_len = sizeof(CS_INT);
	    col->v.cs_float = PyFloat_AsDouble(col_data);
	    col->buff = &col->v.cs_float;
	    debug_msg(self->debug, "  blk_bind(%s = %f)\n",
		      col->fmt.name,
		      col->buff_len, col->fmt.maxlength,
		      col->buff_len, col->v.cs_float);
	} else {
	    raise_exception_string(ProgrammingError, "bad field type");
	    return 0;
	}
	fmt.format = CS_FMT_UNUSED;
	fmt.maxlength = col->fmt.maxlength;
	fmt.count = 1;
	if (col_data == Py_None)
	    col->indicator = -1;
	else
	    col->indicator = 0;

	if (blk_bind(self->blk, idx + 1, &fmt, col->buff, &col->buff_len,
		     &col->indicator) != CS_SUCCEED) {
	    raise_exception(self->conn_info, "blk_bind failed");
	    return 0;
	}
    }

    /* Now transfer the row data
     */
    debug_msg(self->debug, "  blk_rowxfer()\n");
    if (blk_rowxfer(self->blk) == CS_FAIL) {
	raise_exception(self->conn_info, "blk_rowxfer failed");
	return 0;
    }

    return 1;
}

/* Build a new Python list holding one string per table column, copied
 * from the buff/buff_len slices left by the last split.  Returns a new
 * reference, or NULL with an exception set on failure.
 */
static PyObject *build_row_list(BulkcopyObj *self)
{
    PyObject *result;		/* list under construction */
    int i;

    result = PyList_New(self->num_cols);
    if (result == NULL)
	return NULL;

    for (i = 0; i < self->num_cols; i++) {
	ColInfo *col = &self->col_info[i];
	PyObject *item = PyString_FromStringAndSize(col->buff, col->buff_len);

	/* PyList_SetItem() steals @item, even when it fails */
	if (item == NULL || PyList_SetItem(result, i, item) != 0) {
	    Py_DECREF(result);
	    return NULL;
	}
    }

    return result;
}

/* Translate Sybase type name to Sybase type.  Searched by name in
 * bulk_column_format(); the table is never modified, so it is const.
 */
static const struct {
    const char *name;		/* name of type */
    CS_INT type;		/* type identifier */
} sybase_types[] = {
    { "char", CS_CHAR_TYPE },
    { "binary", CS_BINARY_TYPE },
    { "longchar", CS_LONGCHAR_TYPE },
    { "longbinary", CS_LONGBINARY_TYPE },
    { "text", CS_TEXT_TYPE },
    { "image", CS_IMAGE_TYPE },
    { "tinyint", CS_TINYINT_TYPE },
    { "smallint", CS_SMALLINT_TYPE },
    { "int", CS_INT_TYPE },
    { "real", CS_REAL_TYPE },
    { "float", CS_FLOAT_TYPE },
    { "bit", CS_BIT_TYPE },
    { "datetime", CS_DATETIME_TYPE },
    { "smalldatetime", CS_DATETIME4_TYPE },
    { "money", CS_MONEY_TYPE },
    { "smallmoney", CS_MONEY4_TYPE },
    { "numeric", CS_NUMERIC_TYPE },
    { "decimal", CS_DECIMAL_TYPE },
    { "varchar", CS_VARCHAR_TYPE },
    { "varbinary", CS_VARBINARY_TYPE },
    { "long", CS_LONG_TYPE },
    { "ushort", CS_USHORT_TYPE }
};
#define NUM_SYBASE_TYPES (sizeof(sybase_types) / sizeof(sybase_types[0]))

/* Override the type of a column.  @col_num is a Sybase column number,
 * so the first column is numbered 1.  The basic column type name at
 * the start of @col_type is matched case-insensitively and in full
 * against sybase_types.  Type parameters are then parsed:
 * "(p,s)" or "(p)" for numeric/decimal, and "(n)" for char, binary,
 * varchar, varbinary, longchar and longbinary.  Anything after that,
 * such as 'null', 'not null' or 'primary key', is currently ignored.
 *
 * Columns types are initialised by blk_describe() in bulk_setup().
 * For some reason, Sybase returns CS_CHAR_TYPE for CS_VARCHAR_TYPE
 * columns.  This makes it a bit hard to do the column splitting that
 * handles variable and fixed length columns differently, which is why
 * the type can be overridden here.
 *
 * Returns non-zero on success, zero (with ProgrammingError raised) on
 * failure.
 */
static int bulk_column_format(BulkcopyObj *self, int col_num, char *col_type)
{
    int name_len;		/* length of type name in @col_type */
    int type_idx;		/* index into sybase_types */
    int precision, scale;	/* for numeric columns */
    int maxlength;		/* for character/binary columns */
    ColInfo *col;		/* current column description */

    /* Turn Sybase column number into a C array index (0-based) and
     * check for out of range values.
     */
    col_num--;
    if (col_num < 0 || col_num >= self->num_cols) {
	raise_exception_string(ProgrammingError, "bad column number");
	return 0;
    }
    col = &self->col_info[col_num];

    /* Find the basic column type.  The whole name must match: a prefix
     * comparison alone resolves "long" to "longchar" and accepts an
     * empty name as "char".  The unsigned char cast keeps isalnum()
     * defined for bytes above 0x7f.
     */
    for (name_len = 0; isalnum((unsigned char)col_type[name_len]); name_len++)
	;
    for (type_idx = 0; type_idx < (int)NUM_SYBASE_TYPES; type_idx++)
	if ((int)strlen(sybase_types[type_idx].name) == name_len
	    && strncasecmp(sybase_types[type_idx].name, col_type, name_len) == 0)
	    break;
    if (type_idx == (int)NUM_SYBASE_TYPES) {
	raise_exception_string(ProgrammingError, "unknown column type");
	return 0;
    }

    /* Skip over type name and trailing spaces
     */
    col_type += name_len;
    while (*col_type == ' ')
	col_type++;

    /* Extract additional type parameters if necessary
     */
    switch (sybase_types[type_idx].type) {
    case CS_NUMERIC_TYPE:
    case CS_DECIMAL_TYPE:
	/* (p,s) or (p)
	 */
	if (sscanf(col_type, "(%d,%d)", &precision, &scale) != 2) {
	    scale = 0;
	    if (sscanf(col_type, "(%d)", &precision) != 1) {
		raise_exception_string(ProgrammingError, "expected precision and scale");
		return 0;
	    }
	}
	col->fmt.precision = precision;
	col->fmt.scale = scale;
	break;
    case CS_CHAR_TYPE:
    case CS_BINARY_TYPE:
    case CS_VARCHAR_TYPE:
    case CS_VARBINARY_TYPE:
    case CS_LONGCHAR_TYPE:
    case CS_LONGBINARY_TYPE:
	/* (n) -- must be positive: size_split() indexes the row data at
	 * fmt.maxlength for fixed width columns.
	 */
	if (sscanf(col_type, "(%d)", &maxlength) != 1 || maxlength <= 0) {
	    raise_exception_string(ProgrammingError, "expected length");
	    return 0;
	}
	col->fmt.maxlength = maxlength;
	break;
    default:
	/* text, image, integer, float, bit, datetime and money types
	 * take no parameters.
	 */
	break;
    }

    /* Set the new column type: the CS type code itself, not the
     * entry's position in sybase_types.
     */
    col->fmt.datatype = sybase_types[type_idx].type;
    return 1;
}

/* Call blk_done() on the bulk copy context of @self.  @type must be
 * CS_BLK_ALL or CS_BLK_BATCH; @num_rows receives the number of rows
 * uploaded since the previous call.  Returns non-zero on success; if
 * blk_done() reports CS_FAIL an exception is raised and zero returned.
 */
static int bulk_batch(BulkcopyObj *self, CS_INT type, CS_INT *num_rows)
{
    int failed;

    debug_msg(self->debug, "  blk_done()\n");
    failed = (blk_done(self->blk, type, num_rows) == CS_FAIL);
    if (failed)
	raise_exception(self->conn_info, "blk_done() failed");
    return !failed;
}

/* Prepare the bulk copy context for loading rows into self->table.
 * The steps are:
 *   - allocate a CS_BLK_IN context with blk_alloc()/blk_init();
 *   - count the table's columns;
 *   - read each column's description into the newly allocated
 *     self->col_info array.
 * Returns non-zero on success, zero (with an exception raised) on
 * failure.  A context that was already allocated stays in self->blk
 * so that bulk_free() can release it.
 */
static int bulk_setup(BulkcopyObj *self)
{
    int idx;			/* iterate over table columns */
    CS_DATAFMT fmt;		/* format of data column */
    ColInfo *col;		/* current column description */

    /* Initialise a bulk copy context
     */
    debug_msg(self->debug, "  blk_alloc()\n");
    if (blk_alloc(self->conn_info->conn,
		  BLK_VERSION_100, &self->blk) != CS_SUCCEED) {
	raise_exception(self->conn_info, "blk_alloc() failed");
	return 0;
    }

    debug_msg(self->debug, "  blk_init()\n");
    if (blk_init(self->blk,
		 CS_BLK_IN, self->table, strlen(self->table)) == CS_FAIL) {
	raise_exception(self->conn_info, "blk_init() failed");
	return 0;
    }

    /* Count the number of columns in the table by describing them
     * until blk_describe() fails.
     */
    for (self->num_cols = 0;; self->num_cols++) {
	debug_msg(self->debug, "  blk_describe() [test %d]\n", self->num_cols);
	if (blk_describe(self->blk, self->num_cols + 1, &fmt) != CS_SUCCEED)
	    break;
    }
    /* The last blk_describe() caused an error.  We expected that
     * error, so clear all messages on the connection.
     */
    conn_clear_messages(self->conn_info);

    /* Without this check a table with no describable columns gives a
     * zero-byte allocation below (possibly a misleading MemoryError)
     * or an object whose every row is rejected by the split code.
     */
    if (self->num_cols == 0) {
	raise_exception_string(ProgrammingError, "table has no columns");
	return 0;
    }

    /* Get the description of all table columns
     */
    self->col_info = Py_Malloc(self->num_cols * sizeof(*self->col_info));
    if (self->col_info == NULL)
	return 0;
    memset(self->col_info, 0, self->num_cols * sizeof(*self->col_info));
    for (idx = 0, col = self->col_info; idx < self->num_cols; idx++, col++) {
	debug_msg(self->debug, "  blk_describe() [get %d]\n", idx);
	if (blk_describe(self->blk, idx + 1, &col->fmt) != CS_SUCCEED) {
	    raise_exception(self->conn_info, "blk_describe() failed");
	    return 0;
	}
    }

    return 1;
}

/* Finish and free the Sybase bulkcopy context.  A copy still in
 * progress is completed with blk_done(CS_BLK_ALL) before the context
 * is dropped with blk_drop().
 *
 * This runs from Bulkcopy_dealloc(), possibly while another exception
 * is propagating.  Any error raised by the final bulk_batch() is
 * therefore discarded, and the caller's exception state is restored.
 * Afterwards self->blk is NULL, so calling this again is harmless.
 */
static void bulk_free(BulkcopyObj *self)
{
    if (self->blk == NULL)
	return;

    if (self->copy_in_progress) {
	CS_INT num_rows;
	PyObject *exc_type, *exc_value, *exc_tb;

	PyErr_Fetch(&exc_type, &exc_value, &exc_tb);
	bulk_batch(self, CS_BLK_ALL, &num_rows);
	/* Drops any exception bulk_batch() raised */
	PyErr_Restore(exc_type, exc_value, exc_tb);
	self->copy_in_progress = 0;
    }
    debug_msg(self->debug, "  blk_drop()\n");
    blk_drop(self->blk);
    self->blk = NULL;
}

/* ---------------------------------------------------------------- */

/* Bulkcopy.rowxfer(row): send one row to the server.  @row is either a
 * CSV string (split like Bulkcopy.split()) or a tuple/list of column
 * values.  Marks the copy as in progress so that it is finished when
 * the object is freed.  Returns None; raises DataError for any other
 * argument type.
 */
static PyObject *Bulkcopy_rowxfer(BulkcopyObj *self, PyObject *args)
{
    PyObject *row;		/* row string or column sequence */
    int sent;			/* did the row get transferred? */

    if (!PyArg_ParseTuple(args, "O", &row))
	return NULL;

    if (PyString_Check(row))
	sent = bulk_send_str(self, PyString_AsString(row));
    else if (PyTuple_Check(row) || PyList_Check(row))
	sent = bulk_send_sequence(self, row);
    else {
	raise_exception_string(DataError, "expected string or sequence");
	sent = 0;
    }
    if (!sent)
	return NULL;

    self->copy_in_progress = 1;
    Py_INCREF(Py_None);
    return Py_None;
}

/* Bulkcopy.split(row): split the CSV string @row into columns without
 * sending it, returning a list of one string per table column.
 */
static PyObject *Bulkcopy_split(BulkcopyObj *self, PyObject *args)
{
    char *line;

    if (!PyArg_ParseTuple(args, "s", &line))
	return NULL;
    if (!bulk_split(self, line))
	return NULL;

    return build_row_list(self);
}

/* Bulkcopy.format(col_num, col_type): override the type of the 1-based
 * column @col_num using a type string such as "varchar(20)".  Returns
 * None.
 */
static PyObject *Bulkcopy_format(BulkcopyObj *self, PyObject *args)
{
    int column;			/* 1-based column number */
    char *type_str;		/* column type text */

    if (PyArg_ParseTuple(args, "is", &column, &type_str)
	&& bulk_column_format(self, column, type_str)) {
	Py_INCREF(Py_None);
	return Py_None;
    }
    return NULL;
}

/* Bulkcopy.batch(): call blk_done(CS_BLK_BATCH) and return, as an int,
 * the number of rows uploaded since the previous batch()/done().
 */
static PyObject *Bulkcopy_batch(BulkcopyObj *self, PyObject *args)
{
    CS_INT rows_sent;

    if (!PyArg_ParseTuple(args, "")
	|| !bulk_batch(self, CS_BLK_BATCH, &rows_sent))
	return NULL;

    return PyInt_FromLong(rows_sent);
}

/* Bulkcopy.done(): call blk_done(CS_BLK_ALL), mark the copy as no
 * longer in progress and return, as an int, the number of rows
 * uploaded since the previous batch()/done().
 */
static PyObject *Bulkcopy_done(BulkcopyObj *self, PyObject *args)
{
    CS_INT rows_sent;

    if (!PyArg_ParseTuple(args, ""))
	return NULL;
    if (!bulk_batch(self, CS_BLK_ALL, &rows_sent))
	return NULL;

    /* Nothing left for bulk_free() to finish */
    self->copy_in_progress = 0;
    return PyInt_FromLong(rows_sent);
}

/* Methods of Bulkcopy objects, resolved by Bulkcopy_getattr() */
static struct PyMethodDef Bulkcopy_methods[] = {
    { "format",  (PyCFunction)Bulkcopy_format,  METH_VARARGS },	/* override a column type */
    { "split",   (PyCFunction)Bulkcopy_split,   METH_VARARGS },	/* split a CSV row */
    { "rowxfer", (PyCFunction)Bulkcopy_rowxfer, METH_VARARGS },	/* send one row */
    { "batch",   (PyCFunction)Bulkcopy_batch,   METH_VARARGS },	/* blk_done(CS_BLK_BATCH) */
    { "done",    (PyCFunction)Bulkcopy_done,    METH_VARARGS },	/* blk_done(CS_BLK_ALL) */
    { NULL, NULL }							/* sentinel */
};

/* Destructor: finish and drop the CT-Lib bulk context, then release the
 * table name, the column descriptions and the object itself.
 */
static void Bulkcopy_dealloc(BulkcopyObj *self)
{
    char *table = self->table;
    ColInfo *columns = self->col_info;

    bulk_free(self);

    if (columns != NULL)
	Py_Free(columns);
    if (table != NULL)
	Py_Free(table);

    PyMem_DEL(self);
}

/* Attribute lookup: names are resolved solely through Py_FindMethod()
 * against Bulkcopy_methods.
 */
static PyObject *Bulkcopy_getattr(BulkcopyObj *self, char *name)
{
    PyObject *obj = (PyObject *)self;

    return Py_FindMethod(Bulkcopy_methods, obj, name);
}

/* ---------------------------------------------------------------- */

static PyTypeObject BulkcopyType = { /* main python type-descriptor */
    /* type header */
    PyObject_HEAD_INIT(&PyType_Type)
    0,				/* ob_size */
    "Bulkcopy",			/* tp_name */
    sizeof(BulkcopyObj),	/* tp_basicsize */
    0,				/* tp_itemsize */

    /* standard methods */
    (destructor)Bulkcopy_dealloc,/* tp_dealloc */
    (printfunc)0,
    (getattrfunc)Bulkcopy_getattr, /* tp_getattr */
    (setattrfunc)0,
    (cmpfunc)0,
    (reprfunc)0
};

/* ---------------------------------------------------------------- */

/* Implement the Connect.bulkcopy() method (not in DB API 2.0 spec).
 * Creates a new bulkcopy object for loading data into @table on the
 * connection @conn_info.  A private copy of @table is kept.  Returns
 * a new reference, or NULL with an exception set on failure.
 *
 * NOTE(review): @conn_info is stored without taking a reference;
 * confirm the connection is guaranteed to outlive this object, since
 * bulk_free() still uses it from the deallocator.
 */
PyObject *bulkcopy_new(ConnInfo *conn_info, char *table, int debug)
{
    BulkcopyObj *self;

    self = PyObject_NEW(BulkcopyObj, &BulkcopyType);
    if (self == NULL)
	return NULL;

    /* Initialise every field the deallocator looks at before anything
     * can fail.
     */
    self->blk = NULL;
    self->table = NULL;
    self->num_cols = 0;
    self->col_info = NULL;
    self->copy_in_progress = 0;

    self->conn_info = conn_info;
    self->debug = debug;

    self->table = Py_Malloc(strlen(table) + 1);
    if (self->table == NULL) {
	/* Release the half-built object instead of leaking it */
	Py_DECREF(self);
	return NULL;
    }
    strcpy(self->table, table);

    if (!bulk_setup(self)) {
	Py_DECREF(self);
	return NULL;
    }

    return (PyObject*)self;
}
