#include <math.h>
#include "petsc.h"
#include "mpi.h"
#include "parpre_mat.h"
#include "sles.h"
#include "src/mat/matimpl.h"
#include "src/mat/parpre_matimpl.h"
#include "src/mat/impls/aij/mpi/mpiaij.h"
#include "pc.h"
#include "./ml_head.h"

#include "src/pc/utils/auxs.c"

#undef __FUNC__
#define __FUNC__ "NextMatFill"
/* From the original (2,2) block and the fill matrix, compute
   the next level coefficient matrix.

   The result starts as a copy of the locally owned rows of `orig`; the
   fill is then subtracted according to `fill_method`:
     AMLFillDiag   - only the diagonal of `fill` is subtracted;
     AMLFillFull   - every stored entry of `fill` is subtracted;
     AMLFillStrong - the AMLFillFull result is computed first and then
                     handed to MakeStrongMatrix;
     AMLFillNone   - rejected with an error.

   Parameters:
     comm         - communicator on which the result matrix is created
     fill_method  - see above
     modification - only forwarded to MakeStrongMatrix
     repair       - AMLFillDiag, real scalars only; governs rows whose fill
                    diagonal exceeds the original diagonal ("breakdown"):
                      0  report only,
                      1  do not subtract the fill in those rows,
                      2  add the largest local excess to every local diagonal,
                      3  as 2, using the largest excess over all ranks of comm
     trace        - bit mask; AMLTraceFill prints the maximum row length in
                    the AMLFillStrong case
     fill, orig   - input matrices; the row ownership range of `fill` is used
                    to index both, so they are assumed to share one row
                    distribution
     d1, d2       - work vectors; overwritten with the diagonals of `orig`
                    and `fill` for AMLFillDiag, forwarded to MakeStrongMatrix
                    for AMLFillStrong
     res_mat      - output: the assembled result matrix (the AMLFillStrong
                    branch below destroys the matrix it obtains this way, so
                    the caller is expected to do the same)

   Returns 0 on success, otherwise the error code raised through
   CHKERRQ/SETERRQ. */
int NextMatFill(MPI_Comm comm,AMLFillMethod fill_method,
		int modification,int repair,int trace,
		Mat fill,Mat orig,Vec d1,Vec d2,Mat *res_mat)
{
  Mat res;
  int rstart,rend,loc,ierr;

  if (fill_method == AMLFillNone) {
    SETERRQ(1,0,"NextMatFill: should have been caught");
  }

  ierr = MatGetOwnershipRange(fill,&rstart,&rend); CHKERRQ(ierr);
  loc = rend-rstart;

  /* allocate resultant matrix;
     make some guesstimate of the nonzeros per row */
  {
    int ncols,Row,*band;
    /* loc+1 keeps the request non-empty when this rank owns no rows */
    band = (int *) PetscMalloc((loc+1)*sizeof(int)); CHKPTRQ(band);
    for (Row=rstart; Row<rend; Row++) {
      ierr = MatGetRow(orig,Row,&ncols,PETSC_NULL,PETSC_NULL); CHKERRQ(ierr);
      band[Row-rstart] = ncols;
      ierr = MatRestoreRow(orig,Row,&ncols,PETSC_NULL,PETSC_NULL);
      CHKERRQ(ierr);
    }
    /* the row length of `orig` serves as the estimate for both the
       diagonal and the off-diagonal part */
    ierr = MatCreateMPIAIJ
      (comm,loc,loc,PETSC_DECIDE,PETSC_DECIDE,0,band,0,band,&res);
    CHKERRQ(ierr);
    PetscFree(band);
  }
  /* copy original matrix values into resultant matrix */
  {
    int Row,ncols,*cols; Scalar *vals;
    for (Row=rstart; Row<rend; Row++) {
      ierr = MatGetRow(orig,Row,&ncols,&cols,&vals); CHKERRQ(ierr);
      ierr = MatSetValues(res,1,&Row,ncols,cols,vals,ADD_VALUES);
      CHKERRQ(ierr);
      ierr = MatRestoreRow(orig,Row,&ncols,&cols,&vals); CHKERRQ(ierr);
    }
  }
  if (fill_method == AMLFillDiag) {
    Scalar *v1,*v2; int iRow,bd = 0; double bdx = 0.;
    /* extract only the diagonal */
    ierr = MatGetDiagonal(orig,d1); CHKERRQ(ierr);
    ierr = MatGetDiagonal(fill,d2); CHKERRQ(ierr);
    ierr = VecGetArray(d1,&v1); CHKERRQ(ierr);
    ierr = VecGetArray(d2,&v2); CHKERRQ(ierr);

    for (iRow=0; iRow<loc; iRow++) {
      int Row=rstart+iRow; Scalar val;
      val = -v2[iRow];
#if !defined(USE_PETSC_COMPLEX)
      /* breakdown: subtracting the fill would push the diagonal below zero */
      if (v2[iRow]>v1[iRow]) {
        bd++; bdx = PetscMax(bdx,PetscAbsScalar(v1[iRow]-v2[iRow]));
        if (repair==1) val = 0.;
      }
#endif
      ierr = MatSetValues(res,1,&Row,1,&Row,&val,ADD_VALUES); CHKERRQ(ierr);
    }
#if !defined(USE_PETSC_COMPLEX)
    {
      double bdx_g = bdx;
      if (repair==3) {
        /* Collective call: it has to be reached by every rank of comm,
           including ranks that saw no local breakdown, otherwise the
           ranks that did would wait forever. */
        ierr = MPI_Allreduce((void*)&bdx,(void*)&bdx_g,1,MPI_DOUBLE,MPI_MAX,
                             comm);
        CHKERRQ(ierr);
      }
      if (bd>0) {
        if (!repair) {
          printf("breakdown in %d places\n",bd);
        } else if (repair==1) {
          printf("breakdown in %d places, repaired with max %e\n",bd,bdx);
        } else if (repair>1) {
          /* NOTE(review): the shift is applied only on ranks that had a
             local breakdown, also for repair==3 -- confirm whether a
             uniform shift on all ranks is intended in that mode. */
          printf("breakdown in %d places, corrected by %e\n",bd,bdx_g);
          for (iRow=0; iRow<loc; iRow++) {
            int Row=rstart+iRow;
            ierr = MatSetValues(res,1,&Row,1,&Row,&bdx_g,ADD_VALUES);
            CHKERRQ(ierr);
          }
        }
      }
    }
#endif
    ierr = VecRestoreArray(d1,&v1); CHKERRQ(ierr);
    ierr = VecRestoreArray(d2,&v2); CHKERRQ(ierr);
  } else if (fill_method == AMLFillStrong) {
    Mat tmp;
    ierr = NextMatFill(comm,AMLFillFull,0,repair,trace,fill,orig,d1,d2,&tmp);
    CHKERRQ(ierr);
    /* NOTE(review): `res` already holds a copy of `orig` at this point.
       If MakeStrongMatrix creates a fresh matrix in its output argument,
       that copy is leaked -- confirm against MakeStrongMatrix. */
    ierr = MakeStrongMatrix(tmp,modification,d1,d2,&res,trace); CHKERRQ(ierr);
    ierr = MatDestroy(tmp); CHKERRQ(ierr);
    if (trace & AMLTraceFill) {
      int a;
      ierr = MatMaxRowLen_MPIAIJ(res,&a); CHKERRQ(ierr);
      PetscPrintf(comm,"Max row length of sparsified matrix: %d\n",a);
    }
  } else if (fill_method == AMLFillFull) {
    int iRow;
    for (iRow=0; iRow<loc; iRow++) {
      int Row=iRow+rstart,ncols,*cols,iCol; Scalar *vals;
      ierr = MatGetRow(fill,Row,&ncols,&cols,&vals); CHKERRQ(ierr);
      for (iCol=0; iCol<ncols; iCol++) {
        Scalar val = -vals[iCol]; int Col=cols[iCol];
        ierr = MatSetValues(res,1,&Row,1,&Col,&val,ADD_VALUES);
        CHKERRQ(ierr);
      }
      /* restore with the same (global) row index that was fetched */
      ierr = MatRestoreRow(fill,Row,&ncols,&cols,&vals); CHKERRQ(ierr);
    }
  } else {
    SETERRQ(1,0,"NextMatFill: unknown method");
  }
  ierr = MatAssemblyBegin(res,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd(res,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);

  *res_mat = res;

  return 0;
}

#undef __FUNC__
#define __FUNC__ "MakeModVector"
/* Compute the vectors that MakeTrans12 requests before it scales a12.

   Steps, in order:
     g2 <- all ones
     g1 <- g12 * g2
     h1 <- a11_solve applied to g1
     k1 <- result of a temporary solver on g11 applied to g1
           (CG, Jacobi preconditioner, tolerances 1.e-6/0./2., 20 iterations)
     h1 <- h1 - k1
     k1 <- pointwise quotient of h1 and g1 (VecPointwiseDivide(h1,g1,k1))

   All four vectors g1, g2, h1 and k1 are overwritten.  The temporary
   solver is created on `comm` and destroyed before returning.
   Returns 0 on success, otherwise the error code raised through CHKERRQ. */
static int MakeModVector
(MPI_Comm comm,SLES a11_solve,Mat g11,Mat g12,
 Vec g1,Vec g2,Vec h1,Vec k1)
{
  SLES   g11_solver;
  KSP    krylov;
  PC     precond;
  Scalar plus_one = 1.0, minus_one = -1.0;
  int    iterations,ierr;

  /* right hand side: g12 applied to the vector of ones */
  ierr = VecSet(&plus_one,g2);
  CHKERRQ(ierr);
  ierr = MatMult(g12,g2,g1);
  CHKERRQ(ierr);

  /* first solution, using the solver supplied by the caller */
  ierr = SLESSolve(a11_solve,g1,h1,&iterations);
  CHKERRQ(ierr);

  /* second solution, using a throw-away solver built on g11 */
  ierr = SLESCreate(comm,&g11_solver);
  CHKERRQ(ierr);
  ierr = SLESSetOperators(g11_solver,g11,g11,SAME_PRECONDITIONER);
  CHKERRQ(ierr);

  ierr = SLESGetKSP(g11_solver,&krylov);
  CHKERRQ(ierr);
  ierr = KSPSetType(krylov,KSPCG);
  CHKERRQ(ierr);
  ierr = KSPSetTolerances(krylov,1.e-6,0.,2.,20);
  CHKERRQ(ierr);

  ierr = SLESGetPC(g11_solver,&precond);
  CHKERRQ(ierr);
  ierr = PCSetType(precond,PCJACOBI);
  CHKERRQ(ierr);

  ierr = SLESSetUp(g11_solver,g1,k1);
  CHKERRQ(ierr);
  ierr = SLESSolve(g11_solver,g1,k1,&iterations);
  CHKERRQ(ierr);
  ierr = SLESDestroy(g11_solver);
  CHKERRQ(ierr);

  /* difference of the two solutions, then divided entry by entry by g1 */
  ierr = VecAXPY(&minus_one,k1,h1);
  CHKERRQ(ierr);
  ierr = VecPointwiseDivide(h1,g1,k1);
  CHKERRQ(ierr);

  return 0;
}

#undef __FUNC__
#define __FUNC__ "MakeTrans12"
/* Build the matrix returned in *trans_mat as the entrywise sum of
     (a) the matrix produced by MatSolveMat_AIJ(a11_solve,a12), and
     (b) a12 after MatDiagonalScale with w1 as its first scaling vector.

   The result is a sequential AIJ matrix with the dimensions reported by
   MatGetSize(a12).

   Side effects: g1, g2, h1 and k1 are overwritten by MakeModVector;
   a12 is scaled in place and then DESTROYED, so the caller must not use
   it afterwards.

   NOTE(review): MakeModVector writes only g1, g2, h1 and k1, while the
   scaling uses w1, which is not written in this function -- confirm that
   the caller fills w1 (or aliases it to one of the other vectors).

   Returns 0 on success, otherwise the error code raised through CHKERRQ. */
static int MakeTrans12(MPI_Comm comm,SLES a11_solve,Mat g11,Mat a12,Mat g12,
		       Vec g1,Vec g2,Vec h1,Vec k1,Vec w1,
		       Mat *trans_mat)
{
  Mat    solved12,summed;
  Scalar *row_vals;
  int    nrows,ncolumns,r,row_len,*row_cols;
  int    ierr;

  /* (a): must be computed while a12 is still unscaled */
  ierr = MatSolveMat_AIJ(a11_solve,a12,&solved12);
  CHKERRQ(ierr);

  /* (b): scale a12 in place */
  ierr = MakeModVector(comm,a11_solve,g11,g12,g1,g2,h1,k1);
  CHKERRQ(ierr);
  ierr = MatDiagonalScale(a12,w1,PETSC_NULL);
  CHKERRQ(ierr);

  /* accumulate both contributions row by row */
  ierr = MatGetSize(a12,&nrows,&ncolumns);
  CHKERRQ(ierr);
  ierr = MatCreateSeqAIJ(MPI_COMM_SELF,nrows,ncolumns,0,0,&summed);
  CHKERRQ(ierr);
  for (r=0; r<nrows; r++) {
    ierr = MatGetRow(solved12,r,&row_len,&row_cols,&row_vals);
    CHKERRQ(ierr);
    ierr = MatSetValues(summed,1,&r,row_len,row_cols,row_vals,ADD_VALUES);
    CHKERRQ(ierr);
    ierr = MatRestoreRow(solved12,r,&row_len,&row_cols,&row_vals);
    CHKERRQ(ierr);

    ierr = MatGetRow(a12,r,&row_len,&row_cols,&row_vals);
    CHKERRQ(ierr);
    ierr = MatSetValues(summed,1,&r,row_len,row_cols,row_vals,ADD_VALUES);
    CHKERRQ(ierr);
    ierr = MatRestoreRow(a12,r,&row_len,&row_cols,&row_vals);
    CHKERRQ(ierr);
  }
  ierr = MatAssemblyBegin(summed,MAT_FINAL_ASSEMBLY);
  CHKERRQ(ierr);
  ierr = MatAssemblyEnd(summed,MAT_FINAL_ASSEMBLY);
  CHKERRQ(ierr);

  ierr = MatDestroy(solved12);
  CHKERRQ(ierr);

  /* a12 no longer holds what its name suggests; destroy it so that
     nobody uses it by mistake */
  ierr = MatDestroy(a12);
  CHKERRQ(ierr);

  *trans_mat = summed;
  return 0;
}

#undef __FUNC__
#define __FUNC__ "NextLevelFill"
/* Compute the fill matrix for the next level and return it in *fill.

   Three paths, selected by schur_choice and grid_choice:

   1. schur_choice == AMLSchurVariational
        M  = copy of g12 scaled by cdiag1 (first argument of
             MatDiagonalScale)
        Mt = copy of g21 scaled by cdiag1 (second argument of
             MatDiagonalScale)
        *fill = g21 M + Mt ( g12 - g11 M )
      *fill is created on `comm` with the local dimensions of g22.

   2. grid_choice == AMLCoarseGridDependent
        *fill = g21 * (copy of g12 scaled by cdiag1)

   3. otherwise
        MakeTrans12 builds a sequential transfer matrix (this destroys
        a12 and overwrites g1, g2, h1, k1); it is multiplied from the left
        by l21 and the product is turned into a matrix on `comm` by
        MatrixAij2MpiAbut, using the local row count of l21.

   Returns 0 on success, otherwise the error code raised through CHKERRQ. */
int NextLevelFill
(MPI_Comm comm,AMLCoarseGridChoice grid_choice,AMLSchurChoice schur_choice,
 Mat g21,Mat l21,SLES a11_solve,Mat g11,Vec cdiag1,Mat g12,Mat a12,Mat g22,
 Vec g1,Vec g2,Vec h1,Vec k1,Vec w1,Mat *fill)
{
  int ierr;

  PetscFunctionBegin;
  if (schur_choice == AMLSchurVariational) {
    Mat m12,m21,a11_m,defect,m21_defect,a21_m;
    int nloc_rows,nloc_cols;

    /* m12 = M approx A11 inv A12, m21 = Mt approx A21 A11 inv */
    ierr = MatConvert(g12,MATSAME,&m12);
    CHKERRQ(ierr);
    ierr = MatConvert(g21,MATSAME,&m21);
    CHKERRQ(ierr);
    ierr = MatDiagonalScale(m12,cdiag1,0);
    CHKERRQ(ierr);
    ierr = MatDiagonalScale(m21,0,cdiag1);
    CHKERRQ(ierr);

    /* defect = A12 - A11 M */
    ierr = MatMatMult_MPIAIJ(g11,m12,&a11_m);
    CHKERRQ(ierr);
    ierr = MatMatSub_MPIAIJ(g12,a11_m,&defect);
    CHKERRQ(ierr);
    ierr = MatDestroy(a11_m);
    CHKERRQ(ierr);

    /* m21_defect = Mt ( A12 - A11 M ) */
    ierr = MatMatMult_MPIAIJ(m21,defect,&m21_defect);
    CHKERRQ(ierr);
    ierr = MatDestroy(defect);
    CHKERRQ(ierr);

    /* fill = A21 M + Mt ( A12 - A11 M ) */
    ierr = MatMatMult_MPIAIJ(g21,m12,&a21_m);
    CHKERRQ(ierr);
    ierr = MatGetLocalSize(g22,&nloc_rows,&nloc_cols);
    CHKERRQ(ierr);
    ierr = MatCreateMPIAIJ(comm,nloc_rows,nloc_cols,
                           PETSC_DECIDE,PETSC_DECIDE,0,0,0,0,fill);
    CHKERRQ(ierr);
    ierr = MatMatAdd_MPIAIJ(m21_defect,a21_m,MatAlreadyCreated,fill);
    CHKERRQ(ierr);

    ierr = MatDestroy(m12);
    CHKERRQ(ierr);
    ierr = MatDestroy(m21);
    CHKERRQ(ierr);
    ierr = MatDestroy(m21_defect);
    CHKERRQ(ierr);
    ierr = MatDestroy(a21_m);
    CHKERRQ(ierr);
  } else if (grid_choice == AMLCoarseGridDependent) {
    Mat scaled12;

    ierr = MatConvert(g12,MATSAME,&scaled12);
    CHKERRQ(ierr);
    ierr = MatDiagonalScale(scaled12,cdiag1,0);
    CHKERRQ(ierr);
    ierr = MatMatMult_MPIAIJ(g21,scaled12,fill);
    CHKERRQ(ierr);
    ierr = MatDestroy(scaled12);
    CHKERRQ(ierr);
  } else {
    Mat local_trans,local_fill;
    int local_rows,local_cols;

    /* note: MakeTrans12 destroys a12 */
    ierr = MakeTrans12(comm,a11_solve,g11,a12,g12,g1,g2,h1,k1,w1,
                       &local_trans);
    CHKERRQ(ierr);
    ierr = MatMatMult_AIJ(l21,local_trans,&local_fill);
    CHKERRQ(ierr);
    ierr = MatDestroy(local_trans);
    CHKERRQ(ierr);

    /* only the row count is needed; the column count is discarded */
    ierr = MatGetLocalSize(l21,&local_rows,&local_cols);
    CHKERRQ(ierr);
    ierr = MatrixAij2MpiAbut(local_rows,local_fill,comm,fill);
    CHKERRQ(ierr);
    ierr = MatDestroy(local_fill);
    CHKERRQ(ierr);
  }
  PetscFunctionReturn(0);
}
