/*****************************************************************/
/*      mxinitializeinputobjects.c                               */
/*      Henri Casanova          			   	 */
/*****************************************************************/

#include "mex.h"
#include "core.h"
#include "client.h"
#include "matlabclient.h"
#include <string.h>
#include <sys/types.h>
#include <sys/stat.h>

/*
 * Convert the leading `count` elements of the MATLAB array `in` into a
 * newly mxCalloc'ed buffer of NetSolve elements of type `data_type`.
 *
 *   in        : MATLAB array to read from
 *   data_type : NetSolve element type (NETSOLVE_I, _D, _S, _C, _Z,
 *               _CHAR or _B); any other value leaves the buffer zeroed
 *   count     : number of elements to produce
 *   nchars    : number of characters used to size the temporary string
 *               buffer when data_type is NETSOLVE_CHAR / NETSOLVE_B
 *
 * Returns the new buffer; the caller owns it.
 */
static void *inputConvertElements(mxArray *in,int data_type,int count,
                                  int nchars)
{
  double *pr = mxGetPr(in);
  double *pi = mxGetPi(in);
  char *buf = NULL;
  void *ptr;
  int j;

  ptr = mxCalloc(count,netsolve_sizeof(data_type));

  /* Character data is fetched as a C string first, then copied bytewise */
  if (data_type == NETSOLVE_CHAR || data_type == NETSOLVE_B)
  {
    int buflen = (nchars * sizeof(mxChar)) + 1;

    buf = (char *)mxCalloc(buflen,netsolve_sizeof(data_type));
    mxGetString(in,buf,buflen);
  }

  for (j=0;j<count;j++)
  {
    switch(data_type)
    {
      case NETSOLVE_I:
        ((int*)ptr)[j] = (int)(pr[j]);
        break;
      case NETSOLVE_D:
        ((double*)ptr)[j] = pr[j];
        break;
      case NETSOLVE_S:
        ((float*)ptr)[j] = (float)(pr[j]);
        break;
      case NETSOLVE_C:
        ((scomplex*)ptr)[j].r = (float)(pr[j]);
        /* pi may be NULL (no imaginary part supplied): use 0 instead */
        ((scomplex*)ptr)[j].i = (pi != NULL) ? (float)(pi[j]) : 0.0f;
        break;
      case NETSOLVE_Z:
        ((dcomplex*)ptr)[j].r = pr[j];
        ((dcomplex*)ptr)[j].i = (pi != NULL) ? pi[j] : 0.0;
        break;
      case NETSOLVE_CHAR:
      case NETSOLVE_B:
        ((char*)ptr)[j] = buf[j];
        break;
      default:
        break;
    }
  }

  if (buf != NULL)
    mxFree(buf);

  return ptr;
}

/*
 * Initialize the input objects
 *
 *  Replaces every pd->input_objects[i] with an object built from the
 *  MATLAB argument input[i], according to the object type and data type
 *  already recorded in the problem description.
 *
 *   pd                  : problem description whose input objects are
 *                         replaced in place
 *   ninput              : number of entries in `input`; must equal
 *                         pd->nb_input_objects
 *   input               : MATLAB arguments, one per input object
 *   allocatedByNetSolve : receives an mxCalloc'ed array of `ninput` flags;
 *                         flag i is set to 1 when the data buffer handed
 *                         to object i was allocated by this function
 *                         (set to NULL if the argument count is wrong)
 *
 *  Returns 1 and sets ns_errno to NetSolveOK on success.  Returns -1 with
 *  ns_errno set to the failure reason otherwise.
 *
 *  If it fails, then it is the caller's responsibility to
 *  call free.
 */
int mxInitializeInputObjects(NS_ProblemDesc *pd,int ninput,mxArray **input,
                             int **allocatedByNetSolve)
{
  int i;
  int data_type;

  if (pd->nb_input_objects != ninput)
  {
    mexPrintf("'%s' requires %d objects in input (%d provided)\n",
              pd->nickname,pd->nb_input_objects,ninput);
    *allocatedByNetSolve = NULL;
    ns_errno = NetSolveBadProblemSpecification;
    return -1;
  }

  *allocatedByNetSolve = mxCalloc(ninput,sizeof(int));

  for (i=0;i<ninput;i++)
  {
    switch(pd->input_objects[i]->object_type)
    {
      case NETSOLVE_MATRIX:
      {
        int m = mxGetM(input[i]);
        int n = mxGetN(input[i]);
        const char *name = mxGetName(input[i]);
        void *ptr;

        data_type = pd->input_objects[i]->data_type;

        if (data_type == NETSOLVE_D && *name != '\0')
        {
          /* Double data with a non-empty name: no conversion, no allocation */
          ptr = mxGetPr(input[i]);
        }
        else
        {
          /* Unnamed doubles are copied; other types are converted */
          ptr = inputConvertElements(input[i],data_type,m*n,m*n);
          (*allocatedByNetSolve)[i] = 1;
        }

        freeObject(pd->input_objects[i]);
        pd->input_objects[i] = createMatrixObject(data_type,
                               ptr,COL_MAJOR,m,n,m);
        break;
      }
      case NETSOLVE_SPARSEMATRIX:
      {
        int m,n,f;
        int *rc_index,*rc_ptr;
        const char *name;
        void *ptr;

        if (mxIsSparse(input[i]) == 0) /* not a matlab sparse matrix */
        {
          mexPrintf("Input object %d should be a sparse matrix\n",i);
          ns_errno = NetSolveBadProblemSpecification;
          return -1;
        }

        m = mxGetM(input[i]);
        n = mxGetN(input[i]);
        rc_index = mxGetIr(input[i]);
        rc_ptr = mxGetJc(input[i]);
        f = rc_ptr[n];  /* last Jc entry: number of stored values */
        name = mxGetName(input[i]);

        data_type = pd->input_objects[i]->data_type;

        if (data_type == NETSOLVE_D && *name != '\0')
        {
          /* Double data with a non-empty name: no conversion, no allocation */
          ptr = mxGetPr(input[i]);
        }
        else
        {
          ptr = inputConvertElements(input[i],data_type,f,f);
          (*allocatedByNetSolve)[i] = 1;
        }

        freeObject(pd->input_objects[i]);
        pd->input_objects[i] = createSparseMatrixObject(data_type,
                               ptr,COL_MAJOR,m,n,f, rc_ptr, rc_index);
        break;
      }
      case NETSOLVE_VECTOR:
      {
        int m = mxGetM(input[i]);
        int n = mxGetN(input[i]);
        const char *name = mxGetName(input[i]);
        void *ptr;

        if (n != 1)
        {
          mexPrintf("Warning: input %d is a matrix - using first column only\n",i);
        }

        data_type = pd->input_objects[i]->data_type;

        if (data_type == NETSOLVE_D && *name != '\0')
        {
          /* Double data with a non-empty name: no conversion, no allocation */
          ptr = mxGetPr(input[i]);
        }
        else
        {
          /* Only the first m elements (first column) are taken */
          ptr = inputConvertElements(input[i],data_type,m,m*n);
          (*allocatedByNetSolve)[i] = 1;
        }

        freeObject(pd->input_objects[i]);
        pd->input_objects[i] = createVectorObject(data_type,ptr,m);
        break;
      }
      case NETSOLVE_SCALAR:
      {
        int m = mxGetM(input[i]);
        int n = mxGetN(input[i]);
        void *ptr;

        if ((n != 1)||(m!= 1))
        {
          mexPrintf("Warning: input %d is a not a scalar - using first element only\n",i);
        }

        /* Scalars are always copied into a one-element buffer */
        data_type = pd->input_objects[i]->data_type;
        ptr = inputConvertElements(input[i],data_type,1,m*n);

        freeObject(pd->input_objects[i]);
        pd->input_objects[i] = createScalarObject(data_type,ptr);
        (*allocatedByNetSolve)[i] = 1;
        break;
      }
      case NETSOLVE_FILE:
      {
        char *buf;
        int buflen;
        struct stat st;

        if (!mxIsChar(input[i]))
        {
          mexPrintf("Input object %d should be a string\n",i);
          ns_errno = NetSolveBadProblemSpecification;
          return -1;
        }

        buflen = mxGetM(input[i])*mxGetN(input[i])+1;
        buf = (char*)mxCalloc(buflen,sizeof(char));

        mxGetString(input[i],buf,buflen);

        /* The named file must exist */
        if (stat(buf,&st))
        {
          mexPrintf("Impossible to find file '%s'\n",buf);
          mxFree(buf);
          ns_errno = NetSolveFileError;
          return -1;
        }

        freeObject(pd->input_objects[i]);
        pd->input_objects[i] = createFileObject(buf);
        mxFree(buf);

        break;
      }
      case NETSOLVE_PACKEDFILES:
      {
        struct stat st;
        int j,k;
        char **filenames;
        int nb_files;
        int n;
        mxChar *pr;

        if (!mxIsChar(input[i]))
        {
          mexPrintf("Input object %d should be a matrix of strings\n",i);
          ns_errno = NetSolveBadProblemSpecification;
          return -1;
        }

        /* One file name per row of the character matrix */
        nb_files = mxGetM(input[i]);
        n = mxGetN(input[i]);

        filenames = (char**)mxCalloc(nb_files,sizeof(char*));
        pr = (mxChar*)mxGetData(input[i]);

        for(j=0;j<nb_files;j++)
        {
          filenames[j] = (char*)mxCalloc(n+1,sizeof(char));
          /* Row j is gathered with stride nb_files (column-major layout) */
          for(k=0;k<n;k++)
            filenames[j][k] = (char)(pr[j+k*nb_files]);
          /* Eliminate the extra white spaces */
          for (k=0;k<n;k++)
            if (filenames[j][k] == ' ')
              filenames[j][k] = '\0';

          if (stat(filenames[j],&st))
          {
            mexPrintf("Impossible to find file '%s'\n",filenames[j]);
            /* Release every name built so far before bailing out */
            for (k=0;k<=j;k++)
              mxFree(filenames[k]);
            mxFree(filenames);
            ns_errno = NetSolveFileError;
            return -1;
          }
        }

        /* Release the old object the same way every other case does */
        freeObject(pd->input_objects[i]);
        pd->input_objects[i] = createPackedFilesObject(NULL,filenames,nb_files);
        for (j=0;j<nb_files;j++)
          mxFree(filenames[j]);
        mxFree(filenames);

        break;
      }
      case NETSOLVE_UPF:
      {
        char *buf;
        char *funcname;
        int buflen;
        int language;
        struct stat st;

        if (!mxIsChar(input[i]))
        {
          mexPrintf("Input object %d should be a string\n",i);
          ns_errno = NetSolveBadProblemSpecification;
          return -1;
        }

        buflen = mxGetM(input[i])*mxGetN(input[i])+1;
        buf = (char*)mxCalloc(buflen,sizeof(char));

        mxGetString(input[i],buf,buflen);

        if (stat(buf,&st))
        {
          mexPrintf("Impossible to find file '%s'\n",buf);
          mxFree(buf);
          ns_errno = NetSolveFileError;
          return -1;
        }

        if (buflen<=2)
        {
          mexPrintf("filename '%s' too short\n",buf);
          mxFree(buf);
          ns_errno = NetSolveInvalidUPFFilename;
          return -1;
        }

        /* buf[buflen-3] is where a two-character extension starts */
        if (!strcmp(&buf[buflen-3], ".c"))
          language = UPF_LANG_C;
        else if (!strcmp(&buf[buflen-3], ".f"))
          language = UPF_LANG_FORTRAN;
        else
        {
          mexPrintf("file '%s' should end in .f or .c\n",buf);
          mxFree(buf);
          ns_errno = NetSolveInvalidUPFFilename;
          return -1;
        }

        /*
         * Function name = file name without its extension.  Allocated with
         * mxCalloc so that the mxFree below matches; the zero fill supplies
         * the terminating '\0'.
         */
        funcname = (char*)mxCalloc(buflen,sizeof(char));
        memcpy(funcname,buf,buflen-3);

        freeObject(pd->input_objects[i]);
        pd->input_objects[i] = createUPFObject(language,buf,funcname);
        mxFree(buf);
        mxFree(funcname);

        break;
      }
      case NETSOLVE_STRING:
      {
        char *buf;
        int buflen;

        if (!mxIsChar(input[i]))
        {
          mexPrintf("Input object %d should be a string\n",i);
          ns_errno = NetSolveBadProblemSpecification;
          return -1;
        }

        buflen = mxGetM(input[i])*mxGetN(input[i])+1;
        buf = (char*)mxCalloc(buflen,sizeof(char));

        mxGetString(input[i],buf,buflen);

        /* buf is handed to the string object and is not freed here */
        (*allocatedByNetSolve)[i] = 1;

        freeObject(pd->input_objects[i]);
        pd->input_objects[i] = createStringObject(buf);

        break;
      }
      default:
        mexPrintf("Unknown Object type\n");
        ns_errno = NetSolveInternalError;
        return -1;
    }
  }
  ns_errno = NetSolveOK;
  return 1;
}
