/*
   Defines matrix-matrix operations for the AIJ (compressed row)
   matrix storage format.
*/
#include "src/mat/impls/aij/seq/aij.h"
#include "src/vec/vecimpl.h"
#include "petsc.h"

/*
   MatMultAXBY_AIJ - Computes zz = a * A * xx + b * yy for a sequential
   AIJ (compressed row) matrix A.

   Input Parameters:
.  a     - scalar multiplying the matrix-vector product
.  aijin - assembled sequential AIJ matrix A
.  xx    - input vector multiplied by A
.  b     - scalar multiplying yy
.  yy    - vector added to the scaled product

   Output Parameter:
.  zz    - result vector; entry i is b*yy[i] + a*(A xx)[i]

   Notes:
   Stored column indices are offset by the matrix's indexshift (0 or -1);
   adding the shift to each stored index yields the 0-based position in x.
   Every array obtained with VecGetArray is handed back with VecRestoreArray.
*/
int MatMultAXBY_AIJ(Scalar a, Mat aijin,Vec xx, Scalar b, Vec yy,Vec zz)
{
  Mat_SeqAIJ *aij = (Mat_SeqAIJ *) aijin->data;
  Scalar     *x, *y, *z, *v, sum;
  int        m = aij->m, n, i, *idx, *ii, shift = aij->indexshift, ierr;

  if (!aijin->assembled)
    SETERRQ(1,0,"MatMultAXBY_SeqAIJ:Not for unassembled matrix");
  ierr = VecGetArray(xx,&x); CHKERRQ(ierr);
  ierr = VecGetArray(yy,&y); CHKERRQ(ierr);
  ierr = VecGetArray(zz,&z); CHKERRQ(ierr);

  idx  = aij->j;
  v    = aij->a;
  ii   = aij->i;
  for ( i=0; i<m; i++ ) {
    n    = ii[i+1] - ii[i];
    sum  = 0.0;
    /* shift applied to the index (not to the x pointer) so that no pointer
       before the start of the array is ever formed */
    while (n--) sum += *v++ * x[*idx++ + shift];
    z[i] = b * y[i] + a * sum;
  }

  ierr = VecRestoreArray(xx,&x); CHKERRQ(ierr);
  ierr = VecRestoreArray(yy,&y); CHKERRQ(ierr);
  ierr = VecRestoreArray(zz,&z); CHKERRQ(ierr);
  /* one multiply and one add per stored entry, plus b*y + a*sum per row */
  PLogFlops(2*aij->nz+3*m);
  return 0;
}

/*
   MatMatMult_AIJ - Forms the product C = A * B of two sequential AIJ
   matrices.

   Input Parameters:
.  a - left factor A (m x n)
.  b - right factor B (n x p)

   Output Parameter:
.  c - newly created and assembled sequential AIJ matrix (m x p)

   Notes:
   Stored row pointers and column indices are offset by each matrix's
   indexshift (0 or -1); adding the shift gives 0-based positions.
   For every stored A(ia,k), row k of B is scaled and accumulated into row ia
   of C with ADD_VALUES.
*/
int MatMatMult_AIJ(Mat a, Mat b, Mat *c)
{
  Mat_SeqAIJ *aij = (Mat_SeqAIJ *) a->data;
  Mat_SeqAIJ *bij = (Mat_SeqAIJ *) b->data;
  int        nd, ia, ib, ja, jb, ierr, ct = 0;
  int        ashift = aij->indexshift, bshift = bij->indexshift;

  if ( aij->n != bij->m )
    SETERRQ(1,0,"MatMatMult_AIJ: Non conforming matrix dimensions");

  /* Per-row preallocation guess.  Evaluated in double so nz_A*nz_B cannot
     overflow int, guarded against an empty inner dimension, and capped at
     the number of columns of C (no row can hold more). */
  if (aij->n > 0) {
    double est = ((double) aij->nz * (double) bij->nz) /
                 ((double) aij->n * (double) bij->m);
    nd = (est > (double) bij->n) ? bij->n : (int) est;
  } else {
    nd = 0;
  }
  ierr = MatCreateSeqAIJ(a->comm,aij->m,bij->n,nd,PETSC_NULL,c);
  CHKERRQ(ierr);

  for (ia=0; ia<aij->m; ia++) {
    for (ja=aij->i[ia]+ashift; ja<aij->i[ia+1]+ashift; ja++) {
      Scalar av = aij->a[ja];

      /* the column of A(ia,.) selects the row of B to combine directly */
      ib = aij->j[ja] + ashift;
      for (jb=bij->i[ib]+bshift; jb<bij->i[ib+1]+bshift; jb++) {
        int    i = ia, j = bij->j[jb] + bshift;
        Scalar v = av * bij->a[jb];
        ierr = MatSetValues(*c,1,&i,1,&j,&v,ADD_VALUES); CHKERRQ(ierr);
        ct++;
      }
    }
  }
  /* one multiply and one accumulate per partial product */
  PLogFlops(2*ct);
  ierr = MatAssemblyBegin(*c,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd(*c,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  return 0;
}

/*
   MatMatAdd_AIJ - Stores the sum A + B into an existing sequential AIJ
   matrix C of the same dimensions.

   Input Parameters:
.  a - first summand A (m x n)
.  b - second summand B (m x n)
.  c - caller-created m x n sequential AIJ matrix that receives A + B

   Notes:
   Each row is produced by merging the (column-sorted) stored entries of A
   and B; entries present in both are added.  Stored indices are converted to
   0-based positions with each matrix's own indexshift (0 or -1), so C- and
   Fortran-indexed inputs are both handled.  Values are written with
   INSERT_VALUES and C is assembled before return.
*/
int MatMatAdd_AIJ(Mat a, Mat b, Mat c)
{
  Mat_SeqAIJ *aij = (Mat_SeqAIJ *) a->data;
  Mat_SeqAIJ *bij = (Mat_SeqAIJ *) b->data;
  Mat_SeqAIJ *cij = (Mat_SeqAIJ *) c->data;
  int        ashift = aij->indexshift, bshift = bij->indexshift;
  int        past_end = aij->n + 1; /* larger than any valid column index */
  int        row, ierr, ct = 0;

  if ( aij->m != bij->m || aij->m != cij->m ||
       aij->n != bij->n || aij->n != cij->n )
    SETERRQ(1,0,"MatMatAdd_AIJ: Non conforming matrix dimensions");

  for (row=0; row<aij->m; row++) {
    int i    = row;
    int ja   = aij->i[row] + ashift, aend = aij->i[row+1] + ashift;
    int jb   = bij->i[row] + bshift, bend = bij->i[row+1] + bshift;

    while (ja < aend || jb < bend) {
      int    aj = past_end, bj = past_end, j;
      Scalar av = 0.0, bv = 0.0, v;

      /* an exhausted row reports the sentinel column so the other wins */
      if (ja < aend) {aj = aij->j[ja] + ashift; av = aij->a[ja];}
      if (jb < bend) {bj = bij->j[jb] + bshift; bv = bij->a[jb];}

      if (bj < aj)       {v = bv;      j = bj; jb++;}
      else if (bj == aj) {v = av + bv; j = bj; ja++; jb++;}
      else               {v = av;      j = aj; ja++;}

      ct++;
      ierr = MatSetValues(c,1,&i,1,&j,&v,INSERT_VALUES); CHKERRQ(ierr);
    }
  }
  PLogFlops(ct);
  ierr = MatAssemblyBegin(c,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd(c,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  return 0;
}

/*
   MatMatSubtract_AIJ - Forms C = A - B for two sequential AIJ matrices of
   the same dimensions.

   Input Parameters:
.  a - minuend A (m x n)
.  b - subtrahend B (m x n)

   Output Parameter:
.  c - newly created and assembled m x n sequential AIJ matrix

   Notes:
   Each row is produced by merging the (column-sorted) stored entries of A
   and B.  Stored indices are converted to 0-based positions by ADDING each
   matrix's indexshift (0 or -1), matching how MatMultAXBY_AIJ offsets x.
*/
int MatMatSubtract_AIJ(Mat a, Mat b, Mat *c)
{
  Mat_SeqAIJ *aij = (Mat_SeqAIJ *) a->data;
  Mat_SeqAIJ *bij = (Mat_SeqAIJ *) b->data;
  int        ashift = aij->indexshift, bshift = bij->indexshift;
  int        past_end = aij->n + 1; /* larger than any valid column index */
  int        row, ierr, ct = 0;

  if ( aij->m != bij->m || aij->n != bij->n )
    SETERRQ(1,0,"MatMatSubtract_AIJ: Non conforming matrix dimensions");

  ierr = MatCreateSeqAIJ(a->comm,aij->m,aij->n,0,0,c); CHKERRQ(ierr);

  for (row=0; row<aij->m; row++) {
    int i  = row;
    int ja = aij->i[row] + ashift, aend = aij->i[row+1] + ashift;
    int jb = bij->i[row] + bshift, bend = bij->i[row+1] + bshift;

    while (ja < aend || jb < bend) {
      int    aj = past_end, bj = past_end, j;
      Scalar av = 0.0, bv = 0.0, v;

      /* an exhausted row reports the sentinel column so the other wins */
      if (ja < aend) {aj = aij->j[ja] + ashift; av = aij->a[ja];}
      if (jb < bend) {bj = bij->j[jb] + bshift; bv = bij->a[jb];}

      if (bj < aj)       {v = -bv;     j = bj; jb++;}
      else if (bj == aj) {v = av - bv; j = bj; ja++; jb++;}
      else               {v = av;      j = aj; ja++;}

      ct++;
      ierr = MatSetValues(*c,1,&i,1,&j,&v,INSERT_VALUES); CHKERRQ(ierr);
    }
  }
  PLogFlops(ct);
  ierr = MatAssemblyBegin(*c,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd(*c,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);

  return 0;
}

/*
   MatTranspose_AIJ - Creates B = A^T for a sequential AIJ matrix A.

   Input Parameter:
.  a - sequential AIJ matrix A (m x n)

   Output Parameter:
.  b - newly created and assembled n x m sequential AIJ matrix

   Notes:
   Stored row pointers and column indices are converted to 0-based
   positions by ADDING the matrix's indexshift (0 or -1).
*/
int MatTranspose_AIJ(Mat a, Mat *b)
{
  Mat_SeqAIJ *aij = (Mat_SeqAIJ *) a->data;
  int        ia, ja, i, j, ierr, shift = aij->indexshift;
  Scalar     v;

  ierr = MatCreateSeqAIJ(a->comm,aij->n,aij->m,0,0,b);
  CHKERRQ(ierr);

  for (ia=0; ia<aij->m; ia++) {
    for (ja=aij->i[ia]+shift; ja<aij->i[ia+1]+shift; ja++) {
      i = ia; j = aij->j[ja] + shift;
      v = aij->a[ja];
      /* entry (i,j) of A becomes entry (j,i) of B */
      ierr = MatSetValues(*b,1,&j,1,&i,&v,ADD_VALUES); CHKERRQ(ierr);
    }
  }
  ierr = MatAssemblyBegin(*b,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd(*b,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  return 0;
}

#include "sles.h"
/*
   MatSolveMat_AIJ - Runs the linear solver once per column of off_mat and
   gathers the solutions into the columns of a new matrix.

   Input Parameters:
.  solver  - SLES context applied to each right-hand side
.  off_mat - sequential AIJ matrix whose columns are the right-hand sides

   Output Parameter:
.  off_mult - newly created, assembled sequential AIJ matrix with the same
              dimensions as off_mat; column k holds the solution for
              right-hand side column k (exact zeros are not stored)

   Notes:
   Columns are read as rows of an explicit transpose of off_mat.
*/
int MatSolveMat_AIJ(SLES solver,Mat off_mat,Mat *off_mult)
{
  Mat_SeqAIJ *src    = (Mat_SeqAIJ *) off_mat->data;
  int        n_rows  = src->m;
  int        n_cols  = src->n;
  int        k, ierr;
  Mat        trans;
  Vec        b, x;

  /* work space: solution vector, transposed input, right-hand side, result */
  ierr = VecCreateSeq(MPI_COMM_SELF,n_rows,&x); CHKERRQ(ierr);
  ierr = MatTranspose_AIJ(off_mat,&trans); CHKERRQ(ierr);
  ierr = VecCreateSeq(MPI_COMM_SELF,n_rows,&b); CHKERRQ(ierr);
  ierr = MatCreateSeqAIJ(MPI_COMM_SELF,n_rows,n_cols,0,0,off_mult);
  CHKERRQ(ierr);

  for (k=0; k<n_cols; k++) {
    Scalar zero = 0, *row_vals, *soln;
    int    row_len, *row_idx, its, p, r;

    ierr = MatGetRow(trans,k,&row_len,&row_idx,&row_vals); CHKERRQ(ierr);

    /* load column k of off_mat into the right-hand side */
    ierr = VecSet(&zero,b); CHKERRQ(ierr);
    for (p=0; p<row_len; p++) {
      ierr = VecSetValues(b,1,&row_idx[p],&row_vals[p],INSERT_VALUES);
      CHKERRQ(ierr);
    }
    ierr = VecAssemblyBegin(b); CHKERRQ(ierr);
    ierr = VecAssemblyEnd(b); CHKERRQ(ierr);

    ierr = SLESSolve(solver,b,x,&its); CHKERRQ(ierr);

    /* copy the nonzero solution entries into column k of the result */
    ierr = VecGetArray(x,&soln); CHKERRQ(ierr);
    for (r=0; r<n_rows; r++) {
      Scalar entry = soln[r];
      if (entry == 0.) continue;
      ierr = MatSetValues(*off_mult,1,&r,1,&k,&entry,INSERT_VALUES);
      CHKERRQ(ierr);
    }
    ierr = VecRestoreArray(x,&soln); CHKERRQ(ierr);

    ierr = MatRestoreRow(trans,k,&row_len,&row_idx,&row_vals); CHKERRQ(ierr);
  }
  ierr = VecDestroy(b); CHKERRQ(ierr);

  ierr = MatAssemblyBegin(*off_mult,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd(*off_mult,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);

  ierr = MatDestroy(trans); CHKERRQ(ierr);
  ierr = VecDestroy(x); CHKERRQ(ierr);

  return 0;
}
