#include <math.h>
#include "petsc.h"
#include "mpi.h"
#include "parpre_mat.h"
#include "petscsles.h"
#include "src/mat/matimpl.h"
#include "src/mat/parpre_matimpl.h"
#include "src/mat/impls/aij/mpi/mpiaij.h"
#include "petscpc.h"
#include "./ml_head.h"

#include "src/sles/pc/utils/auxs.c"

#undef __FUNC__
#define __FUNC__ "NextMatFill"
/* From the original (2,2) block and the fill matrix, compute
   the next level coefficient matrix.

   Input:
     comm          communicator; used for trace output and for the global
                   reduction performed when repair==3
     fill_method   AMLFillDiag, AMLFillStrong or AMLFillFull; AMLFillNone
                   and any other value are rejected with an error
     modification, weight
                   handed to MakeStrongMatrix (AMLFillStrong only)
     trace         handed to MakeStrongMatrix; the AMLTraceFill bit also
                   enables a report of the maximum row length
     repair        AMLFillDiag with real scalars only: treatment of rows in
                   which the fill diagonal exceeds the original diagonal
                     0  report only
                     1  leave the original diagonal untouched in such rows
                     2  add the largest local excess to every local
                        diagonal entry
                     3  as 2, but with the largest excess over all of comm
     fill, orig    fill matrix and original (2,2) block
     d1, d2        work vectors; the AMLFillDiag branch stores the diagonals
                   of orig and fill in them
   Output:
     res_mat       the new matrix, assembled
   Returns 0 on success, a nonzero error code otherwise. */
int NextMatFill(MPI_Comm comm,AMLFillMethod fill_method,
		int modification,int repair,int trace,double weight,
		Mat fill,Mat orig,Vec d1,Vec d2,Mat *res_mat)
{
  Mat res;
  int rstart,rend,loc,ierr;

  PetscFunctionBegin;
  if (fill_method == AMLFillNone)
    SETERRQ(1,0,"NextMatFill: FillNone should have been caught");

  ierr = MatGetOwnershipRange(fill,&rstart,&rend); CHKERRQ(ierr);
  loc = rend-rstart;

  if (fill_method == AMLFillDiag) {
    Scalar *v1,*v2; int iRow,bd = 0; double bdx = 0.;
    /* start from a copy of orig; only its diagonal is changed below */
    ierr = MatConvert(orig,MATSAME,&res); CHKERRQ(ierr);
    /* extract only the diagonal */
    ierr = MatGetDiagonal(orig,d1); CHKERRQ(ierr);
    ierr = MatGetDiagonal(fill,d2); CHKERRQ(ierr);
    ierr = VecGetArray(d1,&v1); CHKERRQ(ierr);
    ierr = VecGetArray(d2,&v2); CHKERRQ(ierr);

    /* res(i,i) = orig(i,i) - fill(i,i); count the rows where that
       difference is negative and remember the largest excess */
    for (iRow=0; iRow<loc; iRow++) {
      int Row=rstart+iRow; Scalar val;
      val = -v2[iRow];
#if !defined(USE_PETSC_COMPLEX)
      if (v2[iRow]>v1[iRow]) {
	bd++; bdx = PetscMax(bdx,PetscAbsScalar(v1[iRow]-v2[iRow]));
	if (repair==1) val = 0.;
      }
#endif
      ierr = MatSetValues(res,1,&Row,1,&Row,&val,ADD_VALUES); CHKERRQ(ierr);
    }
#if !defined(USE_PETSC_COMPLEX)
    {
      double shift = bdx;               /* local maximum excess */
      int do_shift = (repair>1 && bd>0);

      if (repair==3) {
	/* The reduction is collective, so every process has to take part
	   whether or not it saw a breakdown itself; it runs on the
	   communicator handed in by the caller. */
	ierr = MPI_Allreduce((void*)&bdx,(void*)&shift,1,MPI_DOUBLE,MPI_MAX,comm);
	CHKERRQ(ierr);
	/* bdx is positive exactly where a breakdown was counted, so a
	   positive maximum means at least one process broke down */
	do_shift = (shift>0.);
      }
      if (bd>0) {
	if (!repair) printf("breakdown in %d places\n",bd);
	else if (repair==1)
	  printf("breakdown in %d places, repaired with max %e\n",bd,bdx);
	else if (repair>1)
	  printf("breakdown in %d places, corrected by %e\n",bd,shift);
      }
      if (do_shift) {
	/* shift the whole local diagonal */
	for (iRow=0; iRow<loc; iRow++) {
	  int Row=rstart+iRow;
	  ierr = MatSetValues(res,1,&Row,1,&Row,&shift,ADD_VALUES);
	  CHKERRQ(ierr);
	}
      }
    }
#endif
    ierr = VecRestoreArray(d1,&v1); CHKERRQ(ierr);
    ierr = VecRestoreArray(d2,&v2); CHKERRQ(ierr);
  } else if (fill_method == AMLFillStrong) {
    Mat tmp;
    /* sparsify the full difference orig - fill */
    ierr = MatMatSub_MPIAIJ(orig,fill,&tmp); CHKERRQ(ierr);
    ierr = MakeStrongMatrix
      (tmp,modification,weight,d1,d2,&res,trace); CHKERRQ(ierr);
    ierr = MatDestroy(tmp); CHKERRQ(ierr);
    if (trace & AMLTraceFill) {
      int a;
      ierr = MatMaxRowLen_MPIAIJ(res,&a); CHKERRQ(ierr);
      PetscPrintf(comm,"Max row length of sparsified matrix: %d\n",a);
    }
  } else if (fill_method == AMLFillFull) {
    ierr = MatMatSub_MPIAIJ(orig,fill,&res); CHKERRQ(ierr);
  } else SETERRQ(1,0,"NextMatFill: unknown method");
  ierr = MatAssemblyBegin(res,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd(res,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);

  *res_mat = res;

  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "ScaleModVector"
/* Scale the rows of t1 by a diagonal obtained from the action on a
   constant vector.  With e = u2 (every entry set to -1) this forms
     u1 = a12 e
     v1 ~ a11^{-1} u1     (inexact TFQMR solve, Jacobi preconditioner)
     w1 = t1 e
     u1 = v1 / w1         (pointwise)
   and then scales t1 from the left with u1.

   comm           communicator for the temporary solver and the messages
   t1             matrix that is scaled in place
   a11, a12       blocks used to build the reference vector v1
   u1, v1, w1     work vectors; all three are overwritten
   u2             work vector; overwritten with the constant -1
   Returns 0 on success, a nonzero error code otherwise. */
static int ScaleModVector
(MPI_Comm comm,Mat t1,Mat a11,Mat a12,Vec u1,Vec v1,Vec w1,Vec u2)
{
  int ierr,its;
  Scalar mone = -1.0;
  SLES solve1; KSP ksp; PC pc;

  PetscFunctionBegin;
  /* constant vector */
  ierr = VecSet(&mone,u2); CHKERRQ(ierr);

  /* setup a solve for A11; not too exact */
  ierr = SLESCreate(comm,&solve1); CHKERRQ(ierr);
  ierr = SLESSetOperators(solve1,a11,a11,SAME_PRECONDITIONER); CHKERRQ(ierr);
  ierr = SLESGetKSP(solve1,&ksp); CHKERRQ(ierr);
  ierr = KSPSetType(ksp,/*KSPCG*/KSPTFQMR); CHKERRQ(ierr);
  ierr = KSPSetTolerances(ksp,1.e-5,0.,20.,50); CHKERRQ(ierr);
  ierr = SLESGetPC(solve1,&pc); CHKERRQ(ierr);
  ierr = PCSetType(pc,PCJACOBI); CHKERRQ(ierr);

  /* u1 = A12 e */
  ierr = MatMult(a12,u2,u1); CHKERRQ(ierr);

  /* v1 = A11inv u1 */
  ierr = SLESSolve(solve1,u1,v1,&its); CHKERRQ(ierr);
  PetscPrintf(comm,"Modification solved in %d iterations \n",its);

#if !defined(USE_PETSC_COMPLEX)
  /* Diagnostic only: count negative entries.  An ordering test is
     meaningless for complex scalars, hence the guard.
     NOTE(review): the count is taken over u1 (the right-hand side), not
     over the solution v1 -- confirm that this is what was meant. */
  {
    int n,n1,i; Scalar *ar;
    ierr = VecGetLocalSize(u1,&n); CHKERRQ(ierr);
    ierr = VecGetArray(u1,&ar); CHKERRQ(ierr);
    n1 = 0; for (i=0; i<n; i++) if (ar[i]<0.) n1++;
    if (n1) printf("negative mods: %d\n",n1);
    ierr = VecRestoreArray(u1,&ar); CHKERRQ(ierr);
  }
#endif
  ierr = SLESDestroy(solve1); CHKERRQ(ierr);

  /* w1 = M e */
  ierr = MatMult(t1,u2,w1); CHKERRQ(ierr);

  /* D M e = A11inv A12 e, so u1 = D = v1/w1 */
  ierr = VecPointwiseDivide(v1,w1,u1); CHKERRQ(ierr);

  /* scale from the left */
  ierr = MatDiagonalScale(t1,u1,PETSC_NULL); CHKERRQ(ierr);

  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "DiagModVector"
/* Add a diagonal correction to fill, derived from the action on a
   constant vector.  With e = u2 (every entry set to -1) this forms
     v2 = fill e
     u1 = a12 e
     v1 ~ a11^{-1} u1     (inexact TFQMR solve, Jacobi preconditioner)
     w2 = a21 v1
     v2 = w2 - v2
   then adds -v2[i] to the diagonal entry of each locally owned row i of
   fill and reassembles fill.

   comm           communicator for the temporary solver and the messages
   fill           matrix whose diagonal is modified in place
   a21, a11, a12  blocks used to build the reference vector w2
   u1, v1         work vectors, overwritten
   w1             not referenced
   u2, v2, w2     work vectors, overwritten
   Returns 0 on success, a nonzero error code otherwise. */
static int DiagModVector
(MPI_Comm comm,Mat fill,Mat a21,Mat a11,Mat a12,
 Vec u1,Vec v1,Vec w1,Vec u2,Vec v2,Vec w2)
{
  int ierr,its;
  Scalar mone = -1.0;
  SLES solve1; KSP ksp; PC pc;

  PetscFunctionBegin;
  /* constant vector */
  ierr = VecSet(&mone,u2); CHKERRQ(ierr);

  /* v2 = fill * e */
  ierr = MatMult(fill,u2,v2); CHKERRQ(ierr);

  /* setup a solve for A11; not too exact */
  ierr = SLESCreate(comm,&solve1); CHKERRQ(ierr);
  ierr = SLESSetOperators(solve1,a11,a11,SAME_PRECONDITIONER); CHKERRQ(ierr);
  ierr = SLESGetKSP(solve1,&ksp); CHKERRQ(ierr);
  ierr = KSPSetType(ksp,/*KSPCG*/KSPTFQMR); CHKERRQ(ierr);
  ierr = KSPSetTolerances(ksp,1.e-5,0.,20.,50); CHKERRQ(ierr);
  ierr = SLESGetPC(solve1,&pc); CHKERRQ(ierr);
  ierr = PCSetType(pc,PCJACOBI); CHKERRQ(ierr);

  /* u1 = A12 e */
  ierr = MatMult(a12,u2,u1); CHKERRQ(ierr);

  /* v1 = A11inv A12 e = A11inv u1 */
  ierr = SLESSolve(solve1,u1,v1,&its); CHKERRQ(ierr);
  PetscPrintf(comm,"Modification solved in %d iterations \n",its);

#if !defined(USE_PETSC_COMPLEX)
  /* Diagnostic only: count negative entries.  An ordering test is
     meaningless for complex scalars, hence the guard.
     NOTE(review): the count is taken over u1 (the right-hand side), not
     over the solution v1 -- confirm that this is what was meant. */
  {
    int n,n1,i; Scalar *ar;
    ierr = VecGetLocalSize(u1,&n); CHKERRQ(ierr);
    ierr = VecGetArray(u1,&ar); CHKERRQ(ierr);
    n1 = 0; for (i=0; i<n; i++) if (ar[i]<0.) n1++;
    if (n1) printf("negative mods: %d\n",n1);
    ierr = VecRestoreArray(u1,&ar); CHKERRQ(ierr);
  }
#endif
  ierr = SLESDestroy(solve1); CHKERRQ(ierr);

  /* w2 =  A21 A11inv A12 e = A21 v1 */
  ierr = MatMult(a21,v1,w2); CHKERRQ(ierr);

  /* D e = A21 A11inv A12 e - fill e.  The result is left in v2
     (v2 = w2 - v2); w2 itself is not changed by this call. */
  ierr = VecAYPX(&mone,w2,v2); CHKERRQ(ierr);

  /* add -v2 to the diagonal of the locally owned rows of fill */
  {
    int i,s,f; Scalar *a;
    ierr = VecGetOwnershipRange(v2,&f,&s); CHKERRQ(ierr); s -= f;
    ierr = VecGetArray(v2,&a); CHKERRQ(ierr);
    for (i=0; i<s; i++) {
      int Row = f+i; Scalar v=-a[i];
      ierr = MatSetValues(fill,1,&Row,1,&Row,&v,ADD_VALUES); CHKERRQ(ierr);
    }
    ierr = MatAssemblyBegin(fill,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
    ierr = VecRestoreArray(v2,&a); CHKERRQ(ierr);
    ierr = MatAssemblyEnd(fill,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  }
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "AddModVector"
/* Build in k1 the diagonal that is to be added as a modification.

   With e = g2 (every entry set to 1) the steps are
     g1 = g12 e
     h1 ~ g11^{-1} g1          (temporary CG solve, Jacobi preconditioner)
     k1 = a11appx_solve applied to g1
     h1 = h1 - k1
     k1 = h1 / g1              (pointwise)

   comm            communicator for the temporary solver
   a11appx_solve   existing solver for the approximate (1,1) block
   g11, g12        matrix blocks
   g1, g2, h1, k1  work vectors, all overwritten; the result is in k1
   Returns 0 on success, a nonzero error code otherwise. */
static int AddModVector
(MPI_Comm comm,SLES a11appx_solve,Mat g11,Mat g12,
 Vec g1,Vec g2,Vec h1,Vec k1)
{
  Scalar plus_one = 1.0;
  Scalar minus_one = -1.0;
  SLES   exact_solve;
  KSP    krylov;
  PC     precond;
  int    iterations;
  int    ierr;

  PetscFunctionBegin;

  /* e: the vector of all ones */
  ierr = VecSet(&plus_one,g2);
  CHKERRQ(ierr);

  /* temporary solver on g11 itself */
  ierr = SLESCreate(comm,&exact_solve);
  CHKERRQ(ierr);
  ierr = SLESSetOperators(exact_solve,g11,g11,SAME_PRECONDITIONER);
  CHKERRQ(ierr);
  ierr = SLESGetKSP(exact_solve,&krylov);
  CHKERRQ(ierr);
  ierr = KSPSetType(krylov,KSPCG);
  CHKERRQ(ierr);
  ierr = KSPSetTolerances(krylov,1.e-6,0.,2.,20);
  CHKERRQ(ierr);
  ierr = SLESGetPC(exact_solve,&precond);
  CHKERRQ(ierr);
  ierr = PCSetType(precond,PCJACOBI);
  CHKERRQ(ierr);

  /* right-hand side g1 = A12 e */
  ierr = MatMult(g12,g2,g1);
  CHKERRQ(ierr);

  /* h1 = A11inv A12 e; the temporary solver is no longer needed after this */
  ierr = SLESSolve(exact_solve,g1,h1,&iterations);
  CHKERRQ(ierr);
  ierr = SLESDestroy(exact_solve);
  CHKERRQ(ierr);

  /* k1 = A11approxinv A12 e */
  ierr = SLESSolve(a11appx_solve,g1,k1,&iterations);
  CHKERRQ(ierr);

  /* difference of the two solves, divided pointwise by A12 e */
  ierr = VecAXPY(&minus_one,k1,h1);
  CHKERRQ(ierr);
  ierr = VecPointwiseDivide(h1,g1,k1);
  CHKERRQ(ierr);

  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "RSmake12"
/* Build *b12 from g12 by distributing strong connections of g11.

   For every locally owned row:
     - entries of the g11 row whose magnitude is below weight times the
       magnitude of the diagonal are lumped into the diagonal value;
     - for every other off-diagonal entry, with column j, row j of g12 is
       normalised by minus its row sum and added to row `row` of *b12;
     - finally the row of *b12 is scaled by the reciprocal of the lumped
       diagonal.

   comm      communicator for the temporary scaling vector
   weight    threshold separating weak from strong connections
   g11, g12  input blocks; g11 must store its diagonal explicitly
   b12       output; created here as a duplicate of g12 and then modified
   Returns 0 on success, a nonzero error code otherwise.

   NOTE(review): row j of g12 is fetched with MatGetRow, so this presumably
   only works when every strongly connected column j is a locally owned
   row -- confirm against the matrix distribution.
   NOTE(review): the coupling value Fvals[iFcol] does not enter the entries
   that are distributed; confirm this matches the intended formula. */
static int RSmake12(MPI_Comm comm,double weight,Mat g11,Mat g12,Mat *b12)
{
  int irow,rows,cols,first,last,ierr;
  Vec v1;

  PetscFunctionBegin;
  ierr = MatDuplicate(g12,MAT_COPY_VALUES,b12); CHKERRQ(ierr);
  ierr = MatGetLocalSize(g12,&rows,&cols); CHKERRQ(ierr);
  ierr = MatGetOwnershipRange(g12,&first,&last); CHKERRQ(ierr);
  ierr = VecCreateMPI(comm,rows,PETSC_DECIDE,&v1); CHKERRQ(ierr);

  for (irow=0; irow<rows; irow++) {
    int row=first+irow,nFcols,*Fcols,iFcol,dicol=-1;
    Scalar *Fvals,dival;

    ierr = MatGetRow(g11,row,&nFcols,&Fcols,&Fvals); CHKERRQ(ierr);

    /* locate the diagonal entry of this row */
    for (iFcol=0; iFcol<nFcols; iFcol++) if (Fcols[iFcol]==row) dicol=iFcol;
    if (dicol<0) SETERRQ(1,1,"Could not find diagonal");
    dival = Fvals[dicol];

    for (iFcol=0; iFcol<nFcols; iFcol++) {
      /* the diagonal is already in dival; never lump it into itself */
      if (Fcols[iFcol]==row) continue;

      /* weak/strong is decided on the matrix VALUES, not on the
	 column indices */
      if (PetscAbsScalar(Fvals[iFcol])<weight*PetscAbsScalar(Fvals[dicol])) {
	/* weak connection: lump into the diagonal */
	dival += Fvals[iFcol];
      } else {
	/* strong connection: distribute the normalised row of g12.
	   No restriction to the columns present in row `row` of g12
	   is applied; all columns of the neighbour row are used. */
	int ncols,*ccols,tcol; Scalar *vals,dia=0;
	ierr = MatGetRow(g12,Fcols[iFcol],&ncols,&ccols,&vals); CHKERRQ(ierr);
	for (tcol=0; tcol<ncols; tcol++) dia -= vals[tcol];
	for (tcol=0; tcol<ncols; tcol++) {
	  Scalar v=vals[tcol]/dia;
	  ierr = MatSetValues(*b12,1,&row,1,ccols+tcol,&v,ADD_VALUES);
	  CHKERRQ(ierr);
	}
	ierr = MatRestoreRow(g12,Fcols[iFcol],&ncols,&ccols,&vals); CHKERRQ(ierr);
      }
    }
    ierr = MatRestoreRow(g11,row,&nFcols,&Fcols,&Fvals); CHKERRQ(ierr);

    if (dival<=0.) printf("A11 breakdown at %d\n",row);
    dival = 1./dival;
    /* NOTE(review): PETSc documents that vector assembly must follow
       VecSetValues before the vector is used; no assembly call is made
       here -- confirm and add it if required. */
    ierr = VecSetValues(v1,1,&row,&dival,INSERT_VALUES); CHKERRQ(ierr);
  }
  ierr = MatAssemblyBegin(*b12,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd(*b12,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  /* scale every row by the reciprocal of its lumped diagonal */
  ierr = MatDiagonalScale(*b12,v1,0); CHKERRQ(ierr);
  ierr = VecDestroy(v1); CHKERRQ(ierr);

  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "NextLevelFill"
/* Compute the fill matrix for the next level.

   schur_choice == AMLSchurVariational:
     M = *b12 is obtained from MatSolveMat_MPIAIJ(a11_solve,g12,...),
     optionally row-scaled by ScaleModVector (mod & AMLModScale), and
       *fill = g21 M + Mt (g12 - g11 M);
     with mod & AMLModDiag the diagonal of *fill is then corrected by
     DiagModVector.
   schur_choice == AMLSchurElimination:
     grid_choice == AMLCoarseGridDependent:
       *fill = g21 * (g12 scaled from the left by cdiag1)
     otherwise:
       *fill = g21 * (result of MatSolveMat_MPIAIJ on g12), and if mod is
       nonzero, plus g21 * (g12 scaled from the left by the vector that
       AddModVector leaves in k1).
   Any other schur_choice is an error.

   comm                communicator for messages and new objects
   g21, g11, g12, g22  matrix blocks; g22 only supplies the local sizes
   b12                 output of the variational branch (the matrix M)
   a11_solve           solver handed to MatSolveMat_MPIAIJ / AddModVector
   cdiag1              scaling vector of the grid dependent elimination
   g1,h1,k1,g2,h2,k2   work vectors
   mod, trace          bit flags (AMLMod..., AMLTrace...)
   fill                output: the fill matrix
   level, l21, a21, indices1, indices2, g_indices2, a12, u1, v1, w1, u2
   and weight are not referenced in this function.
   Returns 0 on success, a nonzero error code otherwise. */
int NextLevelFill
(MPI_Comm comm,int level,
 AMLCoarseGridChoice grid_choice,AMLSchurChoice schur_choice,
 Mat g21,Mat l21,Mat a21,Mat *b12, IS indices1,IS indices2,IS g_indices2,
 SLES a11_solve,Mat g11,Vec cdiag1,Mat g12,Mat a12,Mat g22,
 Vec g1,Vec h1,Vec k1, Vec g2,Vec h2,Vec k2,
 Vec u1,Vec v1,Vec w1,Vec u2,int mod,int trace,double weight,Mat *fill)
{
  int ierr;

  PetscFunctionBegin;
  if (schur_choice == AMLSchurVariational) {
    Mat t1,t2,t3,t4; int i2size,j2size,res_lsize,idum;

    ierr = MatGetLocalSize(g22,&i2size,&j2size); CHKERRQ(ierr);

    /* c12 = M approx A11 inv A12 */
    /*
    ierr = RSmake12(comm,weight,g11,g12,b12); CHKERRQ(ierr);
    */
    if (trace & AMLTraceProgress) PetscPrintf(comm,"..solve a11inv a12\n");
    ierr = MatSolveMat_MPIAIJ(a11_solve,g12,g1,h1,b12); CHKERRQ(ierr);
    if (mod & AMLModScale) {
      if (trace & AMLTraceProgress) PetscPrintf(comm,"Scale mod vector\n");
      ierr = ScaleModVector(comm,*b12,g11,g12,g1,h1,k1,g2); CHKERRQ(ierr);
    }

    if (trace & AMLTraceProgress) PetscPrintf(comm,"..products\n");
    /* t1 = A11 M */
    ierr = MatMatMult_MPIAIJ(g11,*b12,&t1); CHKERRQ(ierr);
    /* t2 = A12 - A11 M  = A12 - t1 */
    ierr = MatMatSub_MPIAIJ(g12,t1,&t2); CHKERRQ(ierr);
    ierr = MatDestroy(t1); CHKERRQ(ierr);
    /* t4 = A21 M */
    ierr = MatMatMult_MPIAIJ(g21,*b12,&t4); CHKERRQ(ierr);
    ierr = MatGetLocalSize(t4,&res_lsize,&idum); CHKERRQ(ierr);
    /* t3 = Mt ( A12 - A11 M ) = Mt t2 */
    ierr = MatTMatMult_MPIAIJ(res_lsize,*b12,t2,&t3); CHKERRQ(ierr);
    ierr = MatDestroy(t2); CHKERRQ(ierr);

    /* the result matrix is created here and filled by the add below */
    ierr = MatCreateMPIAIJ(comm,i2size,j2size,PETSC_DECIDE,PETSC_DECIDE,
			   5,0,5,0,fill); CHKERRQ(ierr);

    /* fill = A21 M + Mt ( A12 - A11 M ) = t4 + t3 */
    ierr = MatMatAdd_MPIAIJ(t3,t4,MAT_REUSE_MATRIX,fill); CHKERRQ(ierr);
    ierr = MatDestroy(t3); CHKERRQ(ierr);
    ierr = MatDestroy(t4); CHKERRQ(ierr);
    if (mod & AMLModDiag) {
      if (trace & AMLTraceProgress) PetscPrintf(comm,"Diag mod vector\n");
      ierr = DiagModVector
	(comm,*fill,g21,g11,g12,g1,h1,k1,g2,h2,k2); CHKERRQ(ierr);
    }
  } else if (schur_choice == AMLSchurElimination) {
    if (grid_choice == AMLCoarseGridDependent) {
      Mat trans;
      /* fill = A21 * diag(cdiag1) * A12 */
      ierr = MatConvert(g12,MATSAME,&trans); CHKERRQ(ierr);
      ierr = MatDiagonalScale(trans,cdiag1,0); CHKERRQ(ierr);
      ierr = MatMatMult_MPIAIJ(g21,trans,fill); CHKERRQ(ierr);
      ierr = MatDestroy(trans); CHKERRQ(ierr);
    } else {
      Mat ltrans,mfill;
      PetscPrintf
	(comm,"\n WARNING undebugged case elim&course grid independent\n\n");
      ierr = MatSolveMat_MPIAIJ(a11_solve,g12,g1,h1,&ltrans); CHKERRQ(ierr);
      ierr = MatMatMult_MPIAIJ(g21,ltrans,&mfill); CHKERRQ(ierr);
      ierr = MatDestroy(ltrans); CHKERRQ(ierr);
      if (mod) {
	Mat c12,xfill;
	/* AddModVector leaves the mod vector in k1, which is the vector
	   used for the scaling below */
	ierr = AddModVector(comm,a11_solve,g11,g12,g1,g2,h1,k1); CHKERRQ(ierr);
	ierr = MatConvert(g12,MATSAME,&c12); CHKERRQ(ierr);
	ierr = MatDiagonalScale(c12,k1,PETSC_NULL); CHKERRQ(ierr);
	/* the return code of this product must be checked before ierr
	   is reused by the next call */
	ierr = MatMatMult_MPIAIJ(g21,c12,&xfill); CHKERRQ(ierr);
	ierr = MatDestroy(c12); CHKERRQ(ierr);
	ierr = MatMatAdd_MPIAIJ(mfill,xfill,MAT_INITIAL_MATRIX,fill);
	CHKERRQ(ierr);
	ierr = MatDestroy(mfill); CHKERRQ(ierr);
	ierr = MatDestroy(xfill); CHKERRQ(ierr);
      } else *fill = mfill;
    }
  } else SETERRQ(1,1,"Unknown Schur elimination type\n");
  PetscFunctionReturn(0);
}
