#include "parpre_mat.h"

/* Logging event ids; filled in by ParPreDenseInit() and passed to
   PLogEventBegin/PLogEventEnd by the routines in this file. */
static int events[2];
/* Indices into events[] */
#define MULT_EVENT 0
#define INVERT_EVENT 1

/*
   Select the BLAS/LAPACK routines by scalar type: the d-prefixed
   (double precision real) names unless USE_PETSC_COMPLEX is defined,
   the z-prefixed (double precision complex) names otherwise.

   EXTERN is the linkage specifier put in front of the prototypes below;
   the complex branch uses extern "C" (presumably because that build is
   compiled as C++ -- confirm against the build setup).

   CAST is placed in front of every Scalar pointer handed to these
   routines.  It is empty in the real build and (double *) in the complex
   build, so that the arguments match the double* prototypes below.
*/
#if !defined(USE_PETSC_COMPLEX)
#define GEMM dgemm_
#define GETRF dgetrf_
#define GETRI dgetri_
#define EXTERN extern
#define CAST
#else
#define GEMM zgemm_
#define GETRF zgetrf_
#define GETRI zgetri_
#define EXTERN extern "C"
#define CAST (double *)
#endif

/* Matrix-matrix product; arguments in BLAS order:
   (transa,transb,m,n,k,alpha,A,lda,B,ldb,beta,C,ldc) */
EXTERN void GEMM(char *,char *,int *,int *,int *,double *,
		 double *,int *,double *,int *,double *,
		 double *,int *);
/* LU factorization; arguments in LAPACK order: (m,n,A,lda,ipiv,info) */
EXTERN void GETRF(int*,int*,double*,int*,int*,int*);
/* Inverse from an LU factorization; arguments in LAPACK order:
   (n,A,lda,ipiv,work,lwork,info) */
EXTERN void GETRI(int*,double*,int*,int*,double*,int*,int*);

#undef __FUNC__
#define __FUNC__ "MatMatMultDense"
/*
   MatMatMultDense - Forms c = a*b by passing the raw arrays of the three
   matrices to GEMM with both operands untransposed.

   Input:   a (rows x inner), b (inner x cols)
   Output:  c (rows x cols), overwritten (GEMM is called with beta = 0)

   Returns 0 on success.  An error is raised if the column count of a
   differs from the row count of b, or if c does not have the size of the
   product.  The arrays are used with leading dimension equal to the row
   count reported by MatGetSize.
*/
int MatMatMultDense(Mat a,Mat b,Mat c)
{
  Scalar alpha=1.,beta=0.;
  Scalar *aval,*bval,*cval;
  int    rows,cols,inner,b_rows,c_rows,c_cols,ierr;

  PetscFunctionBegin;
  PLogEventBegin(events[MULT_EVENT],0,0,0,0);

  /* gather and verify the dimensions of operands and result */
  ierr = MatGetSize(a,&rows,&inner); CHKERRQ(ierr);
  ierr = MatGetSize(b,&b_rows,&cols); CHKERRQ(ierr);
  if (inner!=b_rows) SETERRQ(1,0,"Inner dimension mismatch");
  ierr = MatGetSize(c,&c_rows,&c_cols); CHKERRQ(ierr);
  if (!(c_rows==rows && c_cols==cols)) SETERRQ(1,0,"Result size mismatch");

  /* borrow the value arrays */
  ierr = MatGetArray(a,&aval); CHKERRQ(ierr);
  ierr = MatGetArray(b,&bval); CHKERRQ(ierr);
  ierr = MatGetArray(c,&cval); CHKERRQ(ierr);

  /* c = 1*a*b + 0*c */
  GEMM("N","N",&rows,&cols,&inner,
       CAST &alpha,
       CAST aval,&rows,
       CAST bval,&inner,
       CAST &beta,
       CAST cval,&rows);

  /* hand the arrays back */
  ierr = MatRestoreArray(a,&aval); CHKERRQ(ierr);
  ierr = MatRestoreArray(b,&bval); CHKERRQ(ierr);
  ierr = MatRestoreArray(c,&cval); CHKERRQ(ierr);

  PLogFlops(2*cols*rows*inner);
  PLogEventEnd(events[MULT_EVENT],0,0,0,0);
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "MatTMatMultDense"
/*
   MatTMatMultDense - Forms c = a^T * b by passing the raw arrays of the
   three matrices to GEMM with the first operand transposed.

   Input:   a (k x m), b (k x n)
   Output:  c (m x n), overwritten (GEMM is called with beta = 0)

   Returns 0 on success.  An error is raised if a and b do not have the
   same number of rows, or if c is not m x n.

   Note: the arrays are used with leading dimension equal to the row count
   reported by MatGetSize, as in MatMatMultDense.  The stored array of a
   has k rows, so its leading dimension is k, not m; with transa = "T"
   BLAS requires lda >= k.
*/
int MatTMatMultDense(Mat a,Mat b,Mat c)
{
  int ierr,m,n,k,mm,nn,kk;
  Scalar *a_ar,*b_ar,*c_ar,one=1.,zero=0.;

  PetscFunctionBegin;
  PLogEventBegin(events[MULT_EVENT],0,0,0,0);
  /* a is stored k x m; its transpose is the m x k left operand */
  ierr = MatGetSize(a,&k,&m); CHKERRQ(ierr);
  ierr = MatGetSize(b,&kk,&n); CHKERRQ(ierr);
  if (k!=kk) SETERRQ(1,0,"Inner dimension mismatch");
  ierr = MatGetSize(c,&mm,&nn); CHKERRQ(ierr);
  if (mm!=m || nn!=n) SETERRQ(1,0,"Result size mismatch");
  ierr = MatGetArray(a,&a_ar); CHKERRQ(ierr);
  ierr = MatGetArray(b,&b_ar); CHKERRQ(ierr);
  ierr = MatGetArray(c,&c_ar); CHKERRQ(ierr);
  /* lda = k (rows of the stored a), ldb = k (rows of b), ldc = m */
  GEMM("T","N",&m,&n,&k,
       CAST &one,CAST a_ar,&k,CAST b_ar,&k,CAST &zero,CAST c_ar,&m);
  ierr = MatRestoreArray(a,&a_ar); CHKERRQ(ierr);
  ierr = MatRestoreArray(b,&b_ar); CHKERRQ(ierr);
  ierr = MatRestoreArray(c,&c_ar); CHKERRQ(ierr);
  PLogFlops(2*n*m*k);
  PLogEventEnd(events[MULT_EVENT],0,0,0,0);
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "MatHMatMultDense"
/*
   MatHMatMultDense - Forms c = a^H * b by passing the raw arrays of the
   three matrices to GEMM with transa = "C" (conjugate transpose; for the
   real-valued GEMM this is the plain transpose).

   Input:   a (k x m), b (k x n)
   Output:  c (m x n), overwritten (GEMM is called with beta = 0)

   Returns 0 on success.  An error is raised if a and b do not have the
   same number of rows, or if c is not m x n.

   Note: the arrays are used with leading dimension equal to the row count
   reported by MatGetSize, as in MatMatMultDense.  The stored array of a
   has k rows, so its leading dimension is k, not m; with transa = "C"
   BLAS requires lda >= k.
*/
int MatHMatMultDense(Mat a,Mat b,Mat c)
{
  int ierr,m,n,k,mm,nn,kk;
  Scalar *a_ar,*b_ar,*c_ar,one=1.,zero=0.;

  PetscFunctionBegin;
  PLogEventBegin(events[MULT_EVENT],0,0,0,0);
  /* a is stored k x m; its conjugate transpose is the m x k left operand */
  ierr = MatGetSize(a,&k,&m); CHKERRQ(ierr);
  ierr = MatGetSize(b,&kk,&n); CHKERRQ(ierr);
  if (k!=kk) SETERRQ(1,0,"Inner dimension mismatch");
  ierr = MatGetSize(c,&mm,&nn); CHKERRQ(ierr);
  if (mm!=m || nn!=n) SETERRQ(1,0,"Result size mismatch");
  ierr = MatGetArray(a,&a_ar); CHKERRQ(ierr);
  ierr = MatGetArray(b,&b_ar); CHKERRQ(ierr);
  ierr = MatGetArray(c,&c_ar); CHKERRQ(ierr);
  /* lda = k (rows of the stored a), ldb = k (rows of b), ldc = m */
  GEMM("C","N",&m,&n,&k,
       CAST &one,CAST a_ar,&k,CAST b_ar,&k,CAST &zero,CAST c_ar,&m);
  ierr = MatRestoreArray(a,&a_ar); CHKERRQ(ierr);
  ierr = MatRestoreArray(b,&b_ar); CHKERRQ(ierr);
  ierr = MatRestoreArray(c,&c_ar); CHKERRQ(ierr);
  PLogFlops(2*n*m*k);
  PLogEventEnd(events[MULT_EVENT],0,0,0,0);
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "MatInverseDense"
/*
   MatInverseDense - Writes the inverse of a into b.

   The values of a are copied into the array of b, LU-factored in place
   with GETRF and then inverted in place with GETRI.  a is not modified.

   Input:   a (n x n, nonempty)
   Output:  b (n x n), overwritten with the inverse of a

   Returns 0 on success.  An error is raised if
     - b does not have the size of a,
     - the matrices are empty or not square,
     - the temporary work arrays cannot be allocated,
     - GETRF or GETRI reports failure through its info argument
       (info > 0: an exactly zero pivot, i.e. the matrix is singular;
        info < 0: an illegal argument).
   On a LAPACK failure the contents of b are not a valid inverse.
*/
int MatInverseDense(Mat a,Mat b)
{
  int    *i_tmp;
  int    ierr,m,n,mm,nn,info,lwork;
  Scalar *a_ar,*b_ar,*d_tmp;

  PetscFunctionBegin;
  PLogEventBegin(events[INVERT_EVENT],0,0,0,0);

  /* consistency checks */
  ierr = MatGetSize(a,&m,&n); CHKERRQ(ierr);
  ierr = MatGetSize(b,&mm,&nn); CHKERRQ(ierr);
  if (mm!=m || nn!=n) SETERRQ(1,0,"Result size mismatch");
  if (!m || !n) SETERRQ(1,0,"Cannot invert empty matrices");
  /* GETRI works on the factors of a square matrix only */
  if (m!=n) SETERRQ(1,0,"Cannot invert non-square matrices");

  /* get data, temp allocation */
  ierr = MatGetArray(a,&a_ar); CHKERRQ(ierr);
  ierr = MatGetArray(b,&b_ar); CHKERRQ(ierr);
  lwork = n*n;
  i_tmp = (int *) PetscMalloc(n*sizeof(int));            /* pivot indices */
  d_tmp = (Scalar *) PetscMalloc(lwork*sizeof(Scalar));  /* GETRI workspace */
  if (!i_tmp || !d_tmp) {
    /* release whatever was obtained before reporting the failure */
    if (i_tmp) PetscFree(i_tmp);
    if (d_tmp) PetscFree(d_tmp);
    ierr = MatRestoreArray(a,&a_ar); CHKERRQ(ierr);
    ierr = MatRestoreArray(b,&b_ar); CHKERRQ(ierr);
    SETERRQ(1,0,"Out of memory");
  }

  /* first factor */
  PetscMemcpy((void*)b_ar,(void*)a_ar,n*n*sizeof(Scalar));
  GETRF(&n,&n,CAST b_ar,&n,i_tmp,&info);
  if (!info) {
    PLogFlops(2*n*n*n/3);

    /* then invert; only meaningful if the factorization succeeded */
    GETRI(&n,CAST b_ar,&n,i_tmp,CAST d_tmp,&lwork,&info);
    if (!info) {
      PLogFlops(2*n*n*n);
    }
  }

  /* clean up; temporaries go first so that a failing restore cannot leak them */
  PetscFree(i_tmp); PetscFree(d_tmp);
  ierr = MatRestoreArray(a,&a_ar); CHKERRQ(ierr);
  ierr = MatRestoreArray(b,&b_ar); CHKERRQ(ierr);

  /* report LAPACK failures instead of silently returning garbage */
  if (info<0) SETERRQ(1,0,"Illegal argument passed to LAPACK");
  if (info>0) SETERRQ(1,0,"Matrix is singular, cannot invert");

  PLogEventEnd(events[INVERT_EVENT],0,0,0,0);
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "ParPreDenseInit"
/*
   ParPreDenseInit - Registers the two logging events used by the dense
   routines in this file and stores their ids in events[].

   Returns 0 on success, or the error from PLogEventRegister.
*/
int ParPreDenseInit()
{
  int ierr;

  PetscFunctionBegin;

  /* event shared by the matrix-matrix product routines */
  ierr = PLogEventRegister(&events[MULT_EVENT],"MatMatMult      ",PETSC_NULL);
  CHKERRQ(ierr);

  /* event for the dense inversion */
  ierr = PLogEventRegister(&events[INVERT_EVENT],"MatInvert       ",PETSC_NULL);
  CHKERRQ(ierr);

  PetscFunctionReturn(0);
}
