#include <math.h>
#include "petsc.h"
#include "mpi.h"
#include "parpre_mat.h"
#include "sles.h"
#include "src/mat/matimpl.h"
#include "src/mat/parpre_matimpl.h"
#include "pc.h"
#include "./ml_impl.h"
#include "./ml_head.h"

extern int MakeStrongMatrix(MC_OneLevel_struct *this_level,Mat *strong_matrix);

#include "src/pc/utils/auxs.c"

#undef __FUNC__
#define __FUNC__ "SetupSubVectors"
/*
   SetupSubVectors - Create the work vectors and scatter contexts for the
   two index sets of one level, and extract the 1-block part of cdiag.

   Input:
     clr        - index set of the first block ("1")
     rest       - index set of the second block ("2")
     this_level - level structure; this routine reads comm, u and cdiag,
                  and fills u1,v1,w1,cdiag1,g1,h1,k1, (u2,v2,g2,h2),
                  get_clr,put_clr, (get_rest,put_rest).

   The 2-block objects are only created when ISGetSize(rest) is nonzero.

   Returns 0 on success; any PETSc error code is propagated via CHKERRQ.
*/
static int SetupSubVectors(IS clr,IS rest,MC_OneLevel_struct *this_level)
{
  MPI_Comm comm = this_level->comm;
  int s1,s2; Scalar *arr;
  int ierr;

  /* sub vectors: s1/s2 are the sizes reported for the two index sets */
  ierr = ISGetSize(clr,&s1); CHKERRQ(ierr);
  ierr = ISGetSize(rest,&s2); CHKERRQ(ierr);
  /* local (sequential) vectors of length s1 */
  ierr = VecCreateSeq(MPI_COMM_SELF,s1,&(this_level->u1)); CHKERRQ(ierr);
  ierr = VecDuplicate(this_level->u1,&(this_level->v1)); CHKERRQ(ierr);
  ierr = VecDuplicate(this_level->u1,&(this_level->w1)); CHKERRQ(ierr);
  ierr = VecDuplicate(this_level->u1,&(this_level->cdiag1)); CHKERRQ(ierr);
  /* global (parallel) vectors with s1 as local size */
  ierr = VecCreateMPI(comm,s1,PETSC_DECIDE,&(this_level->g1));
  CHKERRQ(ierr);
  ierr = VecDuplicate(this_level->g1,&(this_level->h1)); CHKERRQ(ierr);
  ierr = VecDuplicate(this_level->g1,&(this_level->k1)); CHKERRQ(ierr);
  /* NOTE(review): VecCreateMPI on comm sits under a test of s2; if s2 is a
     per-process count this presumably requires s2 to be zero on all
     processes or on none -- confirm against the callers. */
  if (s2) {
    ierr = VecCreateSeq(MPI_COMM_SELF,s2,&(this_level->u2)); CHKERRQ(ierr);
    ierr = VecCreateMPI(comm,s2,PETSC_DECIDE,&(this_level->g2));
    CHKERRQ(ierr);
    ierr = VecDuplicate(this_level->u2,&(this_level->v2)); CHKERRQ(ierr);
    ierr = VecDuplicate(this_level->g2,&(this_level->h2)); CHKERRQ(ierr);
  }
  /* Alias local and global vectors: the array of each sequential vector is
     placed into the matching parallel vector (u1->g1, v1->h1, w1->k1,
     u2->g2, v2->h2). The arrays obtained with VecGetArray are not
     restored in this routine. */
  ierr = VecGetArray(this_level->u1,&arr); CHKERRQ(ierr);
  ierr = VecPlaceArray(this_level->g1,arr); CHKERRQ(ierr);
  ierr = VecGetArray(this_level->v1,&arr); CHKERRQ(ierr);
  ierr = VecPlaceArray(this_level->h1,arr); CHKERRQ(ierr);
  ierr = VecGetArray(this_level->w1,&arr); CHKERRQ(ierr);
  ierr = VecPlaceArray(this_level->k1,arr); CHKERRQ(ierr);
  if (s2) {
    ierr = VecGetArray(this_level->u2,&arr); CHKERRQ(ierr);
    ierr = VecPlaceArray(this_level->g2,arr); CHKERRQ(ierr);
    ierr = VecGetArray(this_level->v2,&arr); CHKERRQ(ierr);
    ierr = VecPlaceArray(this_level->h2,arr); CHKERRQ(ierr);
  }
  /* gather/scatter to sub vectors: get_* maps entries of this_level->u
     selected by the index set onto positions 0..s-1 of the sub vector,
     put_* is the reverse mapping */
  {
    IS contig_1;
    ierr = ISCreateStride(MPI_COMM_SELF,s1,0,1,&contig_1); CHKERRQ(ierr);
    ierr = VecScatterCreate(this_level->u,clr,this_level->u1,contig_1,
			    &(this_level->get_clr));
    CHKERRQ(ierr);
    ierr = VecScatterCreate(this_level->u1,contig_1,this_level->u,clr,
			    &(this_level->put_clr));
    CHKERRQ(ierr);
    /* the stride index set is only needed while building the scatters */
    ierr = ISDestroy(contig_1); CHKERRQ(ierr);
    if (s2) {
      IS contig_2;
      ierr = ISCreateStride(MPI_COMM_SELF,s2,0,1,&contig_2); CHKERRQ(ierr);
      ierr = VecScatterCreate(this_level->u,rest,this_level->u2,contig_2,
			      &(this_level->get_rest));
      CHKERRQ(ierr);
      ierr = VecScatterCreate(this_level->u2,contig_2,this_level->u,rest,
			      &(this_level->put_rest));
      CHKERRQ(ierr);
      ierr = ISDestroy(contig_2); CHKERRQ(ierr);
    }
  }

  /* use the scatter contexts to get the diagonal on the 1-block:
     cdiag1 receives the clr entries of cdiag */
  ierr = VecScatterBegin
    (this_level->cdiag,this_level->cdiag1,
     INSERT_VALUES,SCATTER_FORWARD,this_level->get_clr);
  CHKERRQ(ierr);
  ierr = VecScatterEnd
    (this_level->cdiag,this_level->cdiag1,
     INSERT_VALUES,SCATTER_FORWARD,this_level->get_clr);
  CHKERRQ(ierr);

  return 0;
}

#undef __FUNC__
#define __FUNC__ "ExtractBlocks"
/*
   ExtractBlocks - Cut the sub-blocks belonging to the index sets clr
   ("1") and rest ("2") out of the strong matrix and the level matrix.

   From smat (STRNG extractions):
     rows clr      x cols all-rest -> kept as a12, and converted to g12
     rows rest     x cols all-clr  -> converted to g21, sequential copy freed
     rows all-rest x cols clr      -> kept as a21
     rows rest     x cols clr      -> kept as l21
   From this_level->mat (ORIG extractions):
     rows clr  x cols all-clr      -> converted to g11, sequential copy freed
     rows rest x cols all-rest     -> converted to g22, sequential copy freed

   "all-clr"/"all-rest" are the index sets returned by ISGetGlobalContent.
   The first argument handed to MatrixAij2MpiAbut is this_level->size1 or
   this_level->size2 (presumably a local size of the parallel result --
   TODO confirm against MatrixAij2MpiAbut).

   Returns 0 on success; PETSc error codes are propagated via CHKERRQ.
*/
static int ExtractBlocks
(IS clr,IS rest,MC_OneLevel_struct *this_level,Mat smat)
{
#define STRNG 4
#define ORIG 2
  Mat level_mat = this_level->mat;
  MPI_Comm comm = level_mat->comm;
  int n_one = this_level->size1;
  int n_two = this_level->size2;
  IS all_clr,all_rest;
  IS row_is[STRNG+ORIG],col_is[STRNG+ORIG];
  Mat *blocks;
  int ierr;

  ierr = ISGetGlobalContent(comm,clr,&all_clr); CHKERRQ(ierr);
  ierr = ISGetGlobalContent(comm,rest,&all_rest); CHKERRQ(ierr);

  /* row/column selections for the strong matrix ... */
  row_is[0] = clr;      col_is[0] = all_rest;
  row_is[1] = rest;     col_is[1] = all_clr;
  row_is[2] = all_rest; col_is[2] = clr;
  row_is[3] = rest;     col_is[3] = clr;
  /* ... and for the diagonal blocks of the original matrix */
  row_is[STRNG]   = clr;  col_is[STRNG]   = all_clr;
  row_is[STRNG+1] = rest; col_is[STRNG+1] = all_rest;

  /* off-diagonal blocks, taken from the strong matrix */
  ierr = MatGetSubMatrices(smat,STRNG,row_is,col_is,MAT_INITIAL_MATRIX,
			   &blocks);
  CHKERRQ(ierr);
  ierr = MatrixAij2MpiAbut(n_two,blocks[0],comm,&(this_level->g12));
  CHKERRQ(ierr);
  this_level->a12 = blocks[0];
  ierr = MatrixAij2MpiAbut(n_one,blocks[1],comm,&(this_level->g21));
  CHKERRQ(ierr);
  ierr = MatDestroy(blocks[1]); CHKERRQ(ierr);
  this_level->a21 = blocks[2];
  this_level->l21 = blocks[3];
  PetscFree(blocks);

  /* diagonal blocks, taken from the original matrix */
  ierr = MatGetSubMatrices(level_mat,ORIG,row_is+STRNG,col_is+STRNG,
			   MAT_INITIAL_MATRIX,&blocks);
  CHKERRQ(ierr);
  ierr = MatrixAij2MpiAbut(n_one,blocks[0],comm,&(this_level->g11));
  CHKERRQ(ierr);
  ierr = MatDestroy(blocks[0]); CHKERRQ(ierr);
  ierr = MatrixAij2MpiAbut(n_two,blocks[1],comm,&(this_level->g22));
  CHKERRQ(ierr);
  ierr = MatDestroy(blocks[1]); CHKERRQ(ierr);
  PetscFree(blocks);

  /* the gathered index sets were only needed for the extraction */
  ierr = ISDestroy(all_clr); CHKERRQ(ierr);
  ierr = ISDestroy(all_rest); CHKERRQ(ierr);

  return 0;
}

#undef __FUNC__
#define __FUNC__ "PartitionMatrix"
/*
   PartitionMatrix - Split one level into two index sets and set up the
   vectors and matrix blocks that belong to the split.

   Input:
     this_level  - level structure (mat, diag, level, ... are read;
                   cdiag, g11, g22 and the objects created by
                   SetupSubVectors/ExtractBlocks are written)
     make_strong - nonzero: split on the matrix from MakeStrongMatrix,
                   zero: split on this_level->mat itself
   Output:
     set2g        - passed through to SplitOffOneLevel
     early_return - set to 1 if SplitOffOneLevel reports global_r==0 on a
                    level > 0, otherwise 0
     local_rest   - passed through to SplitOffOneLevel

   Side effects: this_level->diag is replaced by its reciprocal in place.
   When make_strong is zero, cdiag is the same Vec as diag.

   Returns 0 on success; PETSc error codes are propagated via CHKERRQ.
*/
int PartitionMatrix(MC_OneLevel_struct *this_level,int make_strong,
		    IS *set2g,int *early_return,int *local_rest)
{
  Mat strong_mat; IS set1,set2;
  int lvl = this_level->level;
  int global_r, ierr;

  if (make_strong) {
    ierr = MakeStrongMatrix(this_level,&strong_mat); CHKERRQ(ierr);
  } else {
    strong_mat = this_level->mat;
  }

  /* split this level into F and C */
  ierr = SplitOffOneLevel
    (this_level,strong_mat,&set1,&set2,set2g,&global_r,local_rest);
  CHKERRQ(ierr);
  if ( (global_r==0) && (lvl>0) ) {
    printf("Whole matrix was independent\n");
    *early_return = 1;
  } else {
    *early_return = 0;
  }

  /* analyse this level, get subblocks and all that */
  ierr = VecReciprocal(this_level->diag); CHKERRQ(ierr);
  if (make_strong) {
    /* the strong matrix gets its own inverted diagonal */
    ierr = VecDuplicate(this_level->diag,&(this_level->cdiag)); CHKERRQ(ierr);
    ierr = MatGetDiagonal(strong_mat,this_level->cdiag); CHKERRQ(ierr);
    ierr = VecReciprocal(this_level->cdiag); CHKERRQ(ierr);
  } else {
    this_level->cdiag = this_level->diag;
  }
  ierr = SetupSubVectors(set1,set2,this_level); CHKERRQ(ierr);

  /* sub blocks */
  if (*early_return) {
    /* nothing left over: the whole matrix serves as the 1-block */
    this_level->g11 = this_level->mat;
    this_level->g22 = 0;
  } else {
    ierr = ExtractBlocks(set1,set2,this_level,strong_mat); CHKERRQ(ierr);
  }
  ierr = ISDestroy(set1); CHKERRQ(ierr);
  ierr = ISDestroy(set2); CHKERRQ(ierr);
  if (make_strong) {
    /* only destroy the matrix this routine created itself */
    ierr = MatDestroy(strong_mat); CHKERRQ(ierr);
  }

  return 0;
}

#undef __FUNC__
#define __FUNC__ "NextMatFill"
/*
   NextMatFill - Build a new parallel AIJ matrix holding the rows of orig
   with (part of) the rows of fill subtracted.

   Input:
     this_level - supplies comm and fill_method
     fill       - matrix whose entries are subtracted; its ownership range
                  determines the rows handled on this process
     orig       - matrix whose rows are copied into the result
   Output:
     res_mat    - the newly created and assembled matrix

   What is subtracted depends on this_level->fill_method:
     AMLFillDiag   - only the diagonal of fill
     AMLFillStrong - absolute values of fill: the diagonal entry always;
                     an off-diagonal entry stays in its own position when
                     its absolute value exceeds avg, otherwise it is added
                     to the diagonal. avg is the sum of off-diagonal
                     absolute values divided by the total number of
                     entries in the row.
     AMLFillFull   - every entry of fill
     AMLFillNone and unknown values are errors.

   The result is preallocated with the row lengths of orig, used for both
   the diagonal and off-diagonal preallocation arguments.

   Returns 0 on success; PETSc error codes are propagated via CHKERRQ.
*/
static int NextMatFill(MC_OneLevel_struct *this_level,Mat fill,Mat orig,
		       Mat *res_mat)
{
  Mat res;
  int rstart,rend,loc,ierr;
  MPI_Comm comm = this_level->comm;
  AMLFillMethod meth = this_level->fill_method;

  if (meth == AMLFillNone) SETERRQ(1,0,"NextMatFill: should have been caught");

  ierr = MatGetOwnershipRange(fill,&rstart,&rend); CHKERRQ(ierr);
  loc = rend-rstart;

  /* create the result, preallocated from the row lengths of orig */
  {
    int ncols,Row,*band;
    band = (int *) PetscMalloc((loc+1)*sizeof(int)); CHKPTRQ(band);
    for (Row=rstart; Row<rend; Row++) {
      ierr = MatGetRow(orig,Row,&ncols,PETSC_NULL,PETSC_NULL); CHKERRQ(ierr);
      band[Row-rstart] = ncols;
      ierr = MatRestoreRow(orig,Row,&ncols,PETSC_NULL,PETSC_NULL);
      CHKERRQ(ierr);
    }
    ierr = MatCreateMPIAIJ
      (comm,loc,loc,PETSC_DECIDE,PETSC_DECIDE,0,band,0,band,&res);
    CHKERRQ(ierr);
    PetscFree(band);
  }

  /* copy the locally owned rows of orig */
  {
    int Row,ncols,*cols; Scalar *vals;
    for (Row=rstart; Row<rend; Row++) {
      ierr = MatGetRow(orig,Row,&ncols,&cols,&vals); CHKERRQ(ierr);
      ierr = MatSetValues(res,1,&Row,ncols,cols,vals,ADD_VALUES);
      CHKERRQ(ierr);
      ierr = MatRestoreRow(orig,Row,&ncols,&cols,&vals); CHKERRQ(ierr);
    }
  }

  if (meth == AMLFillDiag) {
    Vec d; Scalar *v; int iRow;
    /* extract only the diagonal */
    ierr = VecCreateMPI(comm,loc,PETSC_DECIDE,&d); CHKERRQ(ierr);
    ierr = MatGetDiagonal(fill,d); CHKERRQ(ierr);
    ierr = VecGetArray(d,&v); CHKERRQ(ierr);

    for (iRow=0; iRow<loc; iRow++) {
      int Row=rstart+iRow; Scalar val=-v[iRow];
      ierr = MatSetValues(res,1,&Row,1,&Row,&val,ADD_VALUES);
      CHKERRQ(ierr);
    }
    ierr = VecRestoreArray(d,&v); CHKERRQ(ierr);
    ierr = VecDestroy(d); CHKERRQ(ierr);
  } else if (meth == AMLFillStrong) {
    int iRow;
    for (iRow=0; iRow<loc; iRow++) {
      int Row=iRow+rstart,ncols,*cols,iCol,mc=-1;
      Scalar *vals, maxval = 0.0, sumval = 0.0;
      ierr = MatGetRow(fill,Row,&ncols,&cols,&vals); CHKERRQ(ierr);
      /* first pass: subtract |diagonal|, gather off-diagonal statistics */
      for (iCol=0; iCol<ncols; iCol++) {
	Scalar val,aval;
	val = vals[iCol]; aval = fabs(val);
	if (cols[iCol]==Row) {
	  Scalar v=-aval;
	  ierr = MatSetValues(res,1,&Row,1,&Row,&v,ADD_VALUES);
	  CHKERRQ(ierr);
	} else {
	  if (aval > maxval) {maxval = aval; mc = iCol;}
	  sumval += aval;
	}
      }
      /* second pass, only if a nonzero off-diagonal entry was seen
	 (which also guarantees ncols > 0 for the division below) */
      if (mc>-1) {
	Scalar avg = sumval/ncols;
	for (iCol=0; iCol<ncols; iCol++) {
	  Scalar tst = fabs(vals[iCol]);
	  Scalar v = -tst;
	  if (cols[iCol]!=Row) {
	    if (tst>avg) {
	      /* large entry: subtract in place */
	      ierr = MatSetValues(res,1,&Row,1,cols+iCol,&v,ADD_VALUES);
	      CHKERRQ(ierr);
	    } else {
	      /* small entry: move it onto the diagonal */
	      ierr = MatSetValues(res,1,&Row,1,&Row,&v,ADD_VALUES);
	      CHKERRQ(ierr);
	    }
	  }
	}
      }
      /* restore the same (global) row that was obtained above */
      ierr = MatRestoreRow(fill,Row,&ncols,&cols,&vals); CHKERRQ(ierr);
    }
  } else if (meth == AMLFillFull) {
    int iRow;
    for (iRow=0; iRow<loc; iRow++) {
      int Row=iRow+rstart,ncols,*cols,iCol; Scalar *vals;
      ierr = MatGetRow(fill,Row,&ncols,&cols,&vals); CHKERRQ(ierr);
      for (iCol=0; iCol<ncols; iCol++) {
	Scalar val = -vals[iCol]; int Col=cols[iCol];
	ierr = MatSetValues(res,1,&Row,1,&Col,&val,ADD_VALUES);
	CHKERRQ(ierr);
      }
      /* restore the same (global) row that was obtained above */
      ierr = MatRestoreRow(fill,Row,&ncols,&cols,&vals); CHKERRQ(ierr);
    }
  } else SETERRQ(1,0,"NextMatFill: unknown method");

  ierr = MatAssemblyBegin(res,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd(res,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);

  *res_mat = res;

  return 0;
}

#undef __FUNC__
#define __FUNC__ "Set11Solver"
/*
   Set11Solver - Create and set up the solver for the (1,1) block.

   Input:
     this_level - supplies comm, local_pctype, g11 and g1
   Output:
     a11_solve  - solver generated by ParPreGenerateSLES and configured by
                  ParPreSetupSLES with this_level->g11 and this_level->g1

   Returns 0 on success; PETSc error codes are propagated via CHKERRQ.
*/
int Set11Solver(MC_OneLevel_struct *this_level,SLES *a11_solve)
{
  int ierr = ParPreGenerateSLES(this_level->comm,a11_solve);

  CHKERRQ(ierr);

  /* attach the preconditioner type, the 1-block matrix and a vector */
  ierr = ParPreSetupSLES(*a11_solve,this_level->local_pctype,
			 this_level->g11,this_level->g1);
  CHKERRQ(ierr);

  return 0;
}

#undef __FUNC__
#define __FUNC__ "MakeModVector"
/*
   MakeModVector - Compare the given (1,1) solver with a Jacobi-
   preconditioned CG solve on the vector g12 * (all ones).

   Steps, in order:
     g2 is set to 1.0 everywhere, g1 = g12 * g2
     h1 = result of SLESSolve with a11_solve, right-hand side g1
     k1 = result of a temporary CG/Jacobi solver on g11, right-hand side
          g1 (tolerances 1.e-6, 0., 2., at most 20 iterations)
     VecAXPY(-1,k1,h1), then VecPointwiseDivide(h1,g1,k1)
   The result is left in k1 (presumably (h1-k1)/g1 entrywise -- confirm
   the argument order against this PETSc version); g1, g2 and h1 are
   overwritten as work space.

   Returns 0 on success; PETSc error codes are propagated via CHKERRQ.
*/
static int MakeModVector
(MPI_Comm comm,SLES a11_solve,Mat g11,Mat g12,
 Vec g1,Vec g2,Vec h1,Vec k1)
{
  Scalar plus_one = 1.0;
  Scalar minus_one = -1.0;
  SLES ref_solve;
  KSP ref_ksp;
  PC ref_pc;
  int n_iter;
  int ierr;

  /* right-hand side: the 1-2 block applied to the vector of ones */
  ierr = VecSet(&plus_one,g2); CHKERRQ(ierr);
  ierr = MatMult(g12,g2,g1); CHKERRQ(ierr);

  /* solve with the solver under inspection */
  ierr = SLESSolve(a11_solve,g1,h1,&n_iter); CHKERRQ(ierr);

  /* reference solve: CG with Jacobi on the 1-block */
  ierr = SLESCreate(comm,&ref_solve); CHKERRQ(ierr);
  ierr = SLESSetOperators(ref_solve,g11,g11,SAME_PRECONDITIONER);
  CHKERRQ(ierr);

  ierr = SLESGetKSP(ref_solve,&ref_ksp); CHKERRQ(ierr);
  ierr = KSPSetType(ref_ksp,KSPCG); CHKERRQ(ierr);
  ierr = KSPSetTolerances(ref_ksp,1.e-6,0.,2.,20); CHKERRQ(ierr);

  ierr = SLESGetPC(ref_solve,&ref_pc); CHKERRQ(ierr);
  ierr = PCSetType(ref_pc,PCJACOBI); CHKERRQ(ierr);

  ierr = SLESSetUp(ref_solve,g1,k1); CHKERRQ(ierr);
  ierr = SLESSolve(ref_solve,g1,k1,&n_iter); CHKERRQ(ierr);
  ierr = SLESDestroy(ref_solve); CHKERRQ(ierr);

  /* difference of the two solutions, divided pointwise by the rhs */
  ierr = VecAXPY(&minus_one,k1,h1); CHKERRQ(ierr);
  ierr = VecPointwiseDivide(h1,g1,k1); CHKERRQ(ierr);

  return 0;
}

#undef __FUNC__
#define __FUNC__ "MakeTrans12"
/*
   MakeTrans12 - Build the sequential matrix that is the entrywise sum of
   MatSolveMat_AIJ(a11_solve,a12) and a row-scaled copy of a12.

   Steps, in order:
     solved = MatSolveMat_AIJ(this_level->a11_solve, this_level->a12)
     MakeModVector(...) is run on g11,g12,g1,g2,h1,k1 of this level
     this_level->a12 is scaled in place with MatDiagonalScale(a12,w1,NULL)
       (w1 shares its array with k1, see SetupSubVectors)
     the result gets, row by row, the entries of solved and then of the
       scaled a12, both with ADD_VALUES

   Side effects: this_level->a12 is destroyed before returning; the
   pointer in the structure is left untouched.

   Output:
     trans_mat - the new matrix; the caller owns it.

   Returns 0 on success; PETSc error codes are propagated via CHKERRQ.
*/
static int MakeTrans12(MC_OneLevel_struct *this_level,Mat *trans_mat)
{
  Mat solved,sum;
  Mat term[2];
  int n_rows,n_cols,r,t;
  int ierr;

  ierr = MatSolveMat_AIJ(this_level->a11_solve,this_level->a12,&solved);
  CHKERRQ(ierr);

  /* compute the modification vector, then scale the rows of a12 with it */
  ierr = MakeModVector(this_level->comm,this_level->a11_solve,
		       this_level->g11,this_level->g12,
		       this_level->g1,this_level->g2,
		       this_level->h1,this_level->k1);
  CHKERRQ(ierr);
  ierr = MatDiagonalScale(this_level->a12,this_level->w1,PETSC_NULL);
  CHKERRQ(ierr);

  /* accumulate both terms into a fresh sequential matrix */
  ierr = MatGetSize(this_level->a12,&n_rows,&n_cols); CHKERRQ(ierr);
  ierr = MatCreateSeqAIJ(MPI_COMM_SELF,n_rows,n_cols,0,0,&sum);
  CHKERRQ(ierr);
  term[0] = solved;
  term[1] = this_level->a12;
  for (r=0; r<n_rows; r++) {
    for (t=0; t<2; t++) {
      int len,*idx; Scalar *entries;
      ierr = MatGetRow(term[t],r,&len,&idx,&entries); CHKERRQ(ierr);
      ierr = MatSetValues(sum,1,&r,len,idx,entries,ADD_VALUES);
      CHKERRQ(ierr);
      ierr = MatRestoreRow(term[t],r,&len,&idx,&entries); CHKERRQ(ierr);
    }
  }
  ierr = MatAssemblyBegin(sum,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd(sum,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);

  ierr = MatDestroy(solved); CHKERRQ(ierr);

  /* a12 has been scaled and no longer is the original block:
     destroy it so that nobody uses it by mistake */
  ierr = MatDestroy(this_level->a12); CHKERRQ(ierr);

  *trans_mat = sum;
  return 0;
}

#undef __FUNC__
#define __FUNC__ "CreateNextSystem"
/*
   CreateNextSystem - Form the matrix of the next level from the blocks
   of this level.

   A correction matrix is computed in one of two ways, depending on the
   flag returned by MultilevelHasDiagonall11:
     flag set   - g21 times a copy of g12 that is row-scaled by cdiag1
     flag unset - l21 times the matrix from MakeTrans12, converted to a
                  parallel matrix with MatrixAij2MpiAbut
   The correction and g22 are then handed to NextMatFill, whose result is
   returned in schur. The correction matrix is destroyed afterwards.

   Returns 0 on success; PETSc error codes are propagated via CHKERRQ.
*/
int CreateNextSystem(MC_OneLevel_struct *this_level,Mat *schur)
{
  Mat correction;
  int diag11;
  int ierr;

  /* eliminate and form the next system */
  ierr = MultilevelHasDiagonall11(this_level,&diag11); CHKERRQ(ierr);
  if (!diag11) {
    Mat seq_trans,seq_prod;
    int n_local,unused;

    /* general case: work on the sequential blocks, then go parallel */
    ierr = MakeTrans12(this_level,&seq_trans); CHKERRQ(ierr);
    ierr = MatMatMult_AIJ(this_level->l21,seq_trans,&seq_prod);
    CHKERRQ(ierr);
    ierr = MatDestroy(seq_trans); CHKERRQ(ierr);
    ierr = MatGetLocalSize(this_level->l21,&n_local,&unused); CHKERRQ(ierr);
    ierr = MatrixAij2MpiAbut(n_local,seq_prod,this_level->comm,&correction);
    CHKERRQ(ierr);
    ierr = MatDestroy(seq_prod); CHKERRQ(ierr);
  } else {
    Mat scaled12;

    /* diagonal case: scale a copy of g12 by cdiag1 and multiply by g21 */
    ierr = MatConvert(this_level->g12,MATSAME,&scaled12); CHKERRQ(ierr);
    ierr = MatDiagonalScale(scaled12,this_level->cdiag1,0); CHKERRQ(ierr);
    ierr = MatMatMult_MPIAIJ(this_level->g21,scaled12,&correction);
    CHKERRQ(ierr);
    ierr = MatDestroy(scaled12); CHKERRQ(ierr);
  }

  ierr = NextMatFill(this_level,correction,this_level->g22,schur);
  CHKERRQ(ierr);
  ierr = MatDestroy(correction); CHKERRQ(ierr);

  return 0;
}
