/*****************************************************************/
/*      mxfillinoutputobects.c                                   */
/*      Henri Casanova          			   	 */
/*****************************************************************/

#include "mex.h"
#include "matrix.h"
#include "core.h"
#include "client.h"
#include "matlabclient.h"

/*
 * mxFillInOutputObjects()
 *
 * Convert the NetSolve output objects of problem `pd` into MATLAB
 * mxArrays and store the first `noutput` resulting arrays in `output`.
 *
 * Numeric/character data buffers are copied into a freshly created
 * mxArray and released with mxFree, except for double-precision
 * (NETSOLVE_D) matrices, vectors and sparse matrices, whose buffers are
 * adopted directly by the mxArray via mxSetPr.  Sparse row-index and
 * column-pointer arrays are always adopted via mxSetIr/mxSetJc.
 *
 * Then pd->nb_output_objects - noutput merges described by
 * pd->matlab_merge are applied: each one folds the array at index2 into
 * the array at index1 (see mxMatlabMerge), which consumes the index2
 * array.  The surviving (non-NULL) arrays are packed, in order, into
 * output[0..noutput-1]; if fewer survive than requested, the remaining
 * output slots are set to NULL.
 *
 * Returns 1 and sets ns_errno to NetSolveOK.
 */
int mxFillInOutputObjects(NS_ProblemDesc *pd,int noutput,mxArray **output)
{
  int i,j;
  NS_Object *obj;
  double *pr, *pi;
  mxArray **tmp_output;
  int nb_effective_merges;
  int index;

  /* mxCalloc zero-fills, so slots of unknown object types stay NULL */
  tmp_output = (mxArray**)mxCalloc(pd->nb_output_objects,sizeof(mxArray*));
  
  for (i=0;i<pd->nb_output_objects;i++)
  {
    obj = pd->output_objects[i];
    switch(obj->object_type)
    {
      case NETSOLVE_MATRIX:
      {
        int  n    = obj->attributes.matrix_attributes.n;
        int  m    = obj->attributes.matrix_attributes.m;
        void *ptr = obj->attributes.matrix_attributes.ptr;

        switch(obj->data_type)
        {
          case NETSOLVE_I:
            tmp_output[i] = mxCreateDoubleMatrix(m,n,mxREAL);
            pr = mxGetPr(tmp_output[i]);
            for (j=0;j<m*n;j++)
              pr[j] = (double)(((int*)ptr)[j]);
            mxFree(ptr);
            break;
          case NETSOLVE_S:
            tmp_output[i] = mxCreateDoubleMatrix(m,n,mxREAL);
            pr = mxGetPr(tmp_output[i]);
            for (j=0;j<m*n;j++)
              pr[j] = (double)(((float*)ptr)[j]);
            mxFree(ptr);
            break;
          case NETSOLVE_D:
            /* Adopt the buffer instead of copying it.  mxSetPr does not
             * free the data it displaces, so release the 1x1 placeholder
             * allocated by mxCreateDoubleMatrix first. */
            tmp_output[i] = mxCreateDoubleMatrix(1,1,mxREAL);
            mxSetM(tmp_output[i],m);
            mxSetN(tmp_output[i],n);
            mxFree(mxGetPr(tmp_output[i]));
            mxSetPr(tmp_output[i],(double*)ptr);
            mexMakeMemoryPersistent(ptr);
            break;
          case NETSOLVE_CHAR:
          case NETSOLVE_B:
          {
            /* Characters and bytes both become a MATLAB char array */
            int dims[2];
            mxChar *pr;

            dims[0] = m;
            dims[1] = n;
            tmp_output[i] = mxCreateCharArray(2,dims);
            pr = (mxChar *)mxGetPr(tmp_output[i]); 
            for(j=0;j<m*n;j++)
              pr[j] = ((char*)ptr)[j];
            mxFree(ptr);
            break;
          }
          case NETSOLVE_C:
            tmp_output[i] = mxCreateDoubleMatrix(m,n,mxCOMPLEX);
            pr = mxGetPr(tmp_output[i]);
            pi = mxGetPi(tmp_output[i]);
            for (j=0;j<m*n;j++)
            {
              pr[j] = (double)(((scomplex*)ptr)[j].r);
              pi[j] = (double)(((scomplex*)ptr)[j].i);
            }
            mxFree(ptr);
            break;
          case NETSOLVE_Z:
            tmp_output[i] = mxCreateDoubleMatrix(m,n,mxCOMPLEX);
            pr = mxGetPr(tmp_output[i]);
            pi = mxGetPi(tmp_output[i]);
            for (j=0;j<m*n;j++)
            {
              pr[j] = ((dcomplex*)ptr)[j].r;
              pi[j] = ((dcomplex*)ptr)[j].i;
            }
            mxFree(ptr);
            break;
        }
        break;
      }
      case NETSOLVE_SPARSEMATRIX:
      {
        int  n    = obj->attributes.sparsematrix_attributes.n;
        int  m    = obj->attributes.sparsematrix_attributes.m;
        int  f   = obj->attributes.sparsematrix_attributes.f;
        int  *rc_ptr = obj->attributes.sparsematrix_attributes.rc_ptr;
        int  *rc_index = obj->attributes.sparsematrix_attributes.rc_index;
        void *ptr = obj->attributes.sparsematrix_attributes.ptr;

        /* In every branch the index arrays allocated by mxCreateSparse are
         * freed and replaced by the received rc_index / rc_ptr arrays. */
        switch(obj->data_type)
        {
          case NETSOLVE_I:
            tmp_output[i] = mxCreateSparse(m,n,f,mxREAL);
            pr = mxGetPr(tmp_output[i]);
            for (j=0;j<f;j++)
              pr[j] = (double)(((int*)ptr)[j]);
            mxFree(ptr);
            mxFree(mxGetIr(tmp_output[i]));
            mxSetIr(tmp_output[i], rc_index);
            mxFree(mxGetJc(tmp_output[i]));
            mxSetJc(tmp_output[i], rc_ptr);
            break;
          case NETSOLVE_S:
            tmp_output[i] = mxCreateSparse(m,n,f,mxREAL);
            pr = mxGetPr(tmp_output[i]);
            for (j=0;j<f;j++)
              pr[j] = (double)(((float*)ptr)[j]);
            mxFree(ptr);
            mxFree(mxGetIr(tmp_output[i]));
            mxSetIr(tmp_output[i], rc_index);
            mxFree(mxGetJc(tmp_output[i]));
            mxSetJc(tmp_output[i], rc_ptr);
            break;
          case NETSOLVE_D:
            /* Adopt the value buffer; free the placeholder pr first since
             * mxSetPr does not free the data it displaces. */
            tmp_output[i] = mxCreateSparse(1,1,1,mxREAL);
            mxSetM(tmp_output[i],m);
            mxSetN(tmp_output[i],n);
            mxSetNzmax(tmp_output[i],f);
            mxFree(mxGetPr(tmp_output[i]));
            mxSetPr(tmp_output[i],(double*)ptr);
            mexMakeMemoryPersistent(ptr);
            mxFree(mxGetIr(tmp_output[i]));
            mxSetIr(tmp_output[i], rc_index);
            mxFree(mxGetJc(tmp_output[i]));
            mxSetJc(tmp_output[i], rc_ptr);
            break;
          case NETSOLVE_C:
            tmp_output[i] = mxCreateSparse(m,n,f,mxCOMPLEX);
            pr = mxGetPr(tmp_output[i]);
            pi = mxGetPi(tmp_output[i]);
            for (j=0;j<f;j++)
            {
              pr[j] = (double)(((scomplex*)ptr)[j].r);
              pi[j] = (double)(((scomplex*)ptr)[j].i);
            }
            mxFree(ptr);
            mxFree(mxGetIr(tmp_output[i]));
            mxSetIr(tmp_output[i], rc_index);
            mxFree(mxGetJc(tmp_output[i]));
            mxSetJc(tmp_output[i], rc_ptr);
            break;
          case NETSOLVE_Z:
            tmp_output[i] = mxCreateSparse(m,n,f,mxCOMPLEX);
            pr = mxGetPr(tmp_output[i]);
            pi = mxGetPi(tmp_output[i]);
            for (j=0;j<f;j++)
            {
              pr[j] = ((dcomplex*)ptr)[j].r;
              pi[j] = ((dcomplex*)ptr)[j].i;
            }
            mxFree(ptr);
            mxFree(mxGetIr(tmp_output[i]));
            mxSetIr(tmp_output[i], rc_index);
            mxFree(mxGetJc(tmp_output[i]));
            mxSetJc(tmp_output[i], rc_ptr);
            break;
        }
        break;
      }
      case NETSOLVE_VECTOR:
      {
        int  m    = obj->attributes.vector_attributes.m;
        void *ptr = obj->attributes.vector_attributes.ptr;

        switch(obj->data_type)
        {
          case NETSOLVE_I:
            tmp_output[i] = mxCreateDoubleMatrix(m,1,mxREAL);
            pr = mxGetPr(tmp_output[i]);
            for (j=0;j<m;j++)
              pr[j] = (double)(((int*)ptr)[j]);
            mxFree(ptr);
            break;
          case NETSOLVE_S:
            tmp_output[i] = mxCreateDoubleMatrix(m,1,mxREAL);
            pr = mxGetPr(tmp_output[i]);
            for (j=0;j<m;j++)
              pr[j] = (double)(((float*)ptr)[j]);
            mxFree(ptr);
            break;
          case NETSOLVE_D:
            /* Adopt the buffer; free the 1x1 placeholder first since
             * mxSetPr does not free the data it displaces. */
            tmp_output[i] = mxCreateDoubleMatrix(1,1,mxREAL);
            mxSetM(tmp_output[i],m);
            mxSetN(tmp_output[i],1);
            mxFree(mxGetPr(tmp_output[i]));
            mxSetPr(tmp_output[i],(double*)ptr);
            mexMakeMemoryPersistent(ptr);
            break;
          case NETSOLVE_CHAR:
          case NETSOLVE_B:
          {
            /* Characters and bytes both become a MATLAB char array; the
             * received buffer is released in both cases. */
            int dims[1];
            mxChar *pr;

            dims[0] = m;
            tmp_output[i] = mxCreateCharArray(1,dims);
            pr = (mxChar *)mxGetPr(tmp_output[i]); 
            for(j=0;j<m;j++)
            {
              pr[j] = ((char*)ptr)[j];
            }
            mxFree(ptr);
            break;
          }
          case NETSOLVE_C:
            tmp_output[i] = mxCreateDoubleMatrix(m,1,mxCOMPLEX);
            pr = mxGetPr(tmp_output[i]);
            pi = mxGetPi(tmp_output[i]);
            for (j=0;j<m;j++)
            {
              pr[j] = (double)(((scomplex*)ptr)[j].r);
              pi[j] = (double)(((scomplex*)ptr)[j].i);
            }
            mxFree(ptr);
            break;
          case NETSOLVE_Z:
            tmp_output[i] = mxCreateDoubleMatrix(m,1,mxCOMPLEX);
            pr = mxGetPr(tmp_output[i]);
            pi = mxGetPi(tmp_output[i]);
            for (j=0;j<m;j++)
            {
              pr[j] = ((dcomplex*)ptr)[j].r;
              pi[j] = ((dcomplex*)ptr)[j].i;
            }
            mxFree(ptr);
            break;
        }
        break;
      }
      case NETSOLVE_SCALAR:
      {
        void *ptr = obj->attributes.scalar_attributes.ptr;
        switch(obj->data_type)
        {
          case NETSOLVE_I:
            tmp_output[i] = mxCreateDoubleMatrix(1,1,mxREAL);
            pr = mxGetPr(tmp_output[i]);
            pr[0] = (double)(((int*)ptr)[0]);
            break;
          case NETSOLVE_S:
            tmp_output[i] = mxCreateDoubleMatrix(1,1,mxREAL);
            pr = mxGetPr(tmp_output[i]);
            pr[0] = (double)(((float*)ptr)[0]);
            break;
          case NETSOLVE_D:
            tmp_output[i] = mxCreateDoubleMatrix(1,1,mxREAL);
            pr = mxGetPr(tmp_output[i]);
            pr[0] = (double)(((double*)ptr)[0]);
            break;
          case NETSOLVE_CHAR:
          case NETSOLVE_B:
          {
            int dims[1];
            mxChar *pr;

            dims[0] = 1;
            tmp_output[i] = mxCreateCharArray(1,dims);
            pr = (mxChar *)mxGetPr(tmp_output[i]); 
            pr[0] = ((char*)ptr)[0];
            break;
          }
          case NETSOLVE_C:
            tmp_output[i] = mxCreateDoubleMatrix(1,1,mxCOMPLEX);
            pr = mxGetPr(tmp_output[i]);
            pi = mxGetPi(tmp_output[i]);
            pr[0] = (double)(((scomplex*)ptr)[0].r);
            pi[0] = (double)(((scomplex*)ptr)[0].i);
            break;
          case NETSOLVE_Z:
            tmp_output[i] = mxCreateDoubleMatrix(1,1,mxCOMPLEX);
            pr = mxGetPr(tmp_output[i]);
            pi = mxGetPi(tmp_output[i]);
            pr[0] = ((dcomplex*)ptr)[0].r;
            pi[0] = ((dcomplex*)ptr)[0].i;
            break;
        }
        /* Scalars are always copied, so the buffer is released here */
        mxFree(ptr);
        break;
      }
      case NETSOLVE_FILE:
      {
        tmp_output[i] = mxCreateString(obj->attributes.file_attributes.filename);
        break;
      }
      case NETSOLVE_PACKEDFILES:
      {
        tmp_output[i] = mxCreateCharMatrixFromStrings(obj->attributes.packedfiles_attributes.m,
                    (const char **)(obj->attributes.packedfiles_attributes.filenames));
        break;
      }
      case NETSOLVE_STRING:
      {
        tmp_output[i] = mxCreateString(obj->attributes.string_attributes.ptr);
        mxFree(obj->attributes.string_attributes.ptr);
        break;
       }
      default:
      {
        mexPrintf("Unknown object type\n");
      }
    }
  }

  /* Take care of the merges: each merge consumes one object, hence the
   * count of merges is the surplus of objects over requested outputs. */
  nb_effective_merges = pd->nb_output_objects - noutput;  

  for (i=0;i<nb_effective_merges;i++)
  {
    mxMatlabMerge(&(tmp_output[pd->matlab_merge[i].index1]),
                  tmp_output[pd->matlab_merge[i].index2]);
    /* The second operand was destroyed by the merge: clear its slot so it
     * is skipped below rather than returned as a dangling pointer. */
    tmp_output[pd->matlab_merge[i].index2] = NULL;
  }

  /* Pack the surviving arrays, in order, into the caller's output slots */
  index = 0;
  for (i=0;i<noutput;i++)
  {
    while (index < pd->nb_output_objects && tmp_output[index] == NULL)
      index++;
    if (index < pd->nb_output_objects)
      output[i] = tmp_output[index++];
    else
      output[i] = NULL;
  }
  mxFree(tmp_output);

  ns_errno = NetSolveOK;
  return 1;
}

/*
 * mxMatlabMerge()
 *
 * Combine two real MATLAB arrays into one complex array.  The real parts
 * come from *a1 and the imaginary parts from a2.  The result has the
 * dimensions of *a1 and replaces *a1; both input arrays are destroyed.
 *
 * If a2 holds fewer elements than *a1, the missing imaginary parts stay
 * zero instead of being read past the end of a2.  If either operand is
 * NULL, nothing is merged and neither array is touched.
 *
 * Sets ns_errno to NetSolveOK after a successful merge.
 */
void mxMatlabMerge(mxArray **a1, mxArray *a2)
{
  mxArray *merged;
  double *re_src, *im_src;
  double *pr, *pi;
  int m, n;
  size_t count, im_count, i;

  if (a1 == NULL || *a1 == NULL || a2 == NULL)
    return;

  m = mxGetM(*a1);
  n = mxGetN(*a1);
  count = (size_t)m * (size_t)n;

  /* Never read more imaginary values than a2 actually holds */
  im_count = (size_t)mxGetM(a2) * (size_t)mxGetN(a2);
  if (im_count > count)
    im_count = count;

  merged = mxCreateDoubleMatrix(m,n,mxCOMPLEX);
  pr = mxGetPr(merged);
  pi = mxGetPi(merged);
  re_src = mxGetPr(*a1);
  im_src = mxGetPr(a2);

  for (i=0;i<count;i++)
    pr[i] = re_src[i];
  /* mxCreateDoubleMatrix zero-initializes, so any imaginary parts that
   * a2 lacks remain 0 */
  for (i=0;i<im_count;i++)
    pi[i] = im_src[i];

  mxDestroyArray(*a1);
  mxDestroyArray(a2);
  *a1 = merged;
  ns_errno = NetSolveOK;
}
