#include <stdio.h>
#include <math.h>
/* general Petsc includes */
#include "petsc.h"
#ifdef petsc29
#include "petscsles.h"
#include "petscmat.h"
#include "petscis.h"
#else
#include "sles.h"
#include "mat.h"
#include "is.h"
#endif
#include "./petsc_matrix_auxs.h"
#include "./petsc_impl.h"
#include "matrix_auxs.h"
#include "netsolve_petsc.h"

/* Forward declarations of the KSP callbacks defined later in this file.
   UserConverge is registered by petsc_iterative_solve as the KSP
   convergence test when the caller supplies info->convtest.
   TrueMonitor and TrueConvergenceTest recompute the true residual
   ||b - A x|| for debugging; nothing in this file currently installs them. */
int UserConverge(KSP ksp, int it, PetscReal rnorm,
		 KSPConvergedReason *reason,void *mctx);
int TrueMonitor(KSP ksp,int it,PetscReal cg_err,void *data);
int TrueConvergenceTest(KSP ksp,int it,PetscReal cg_err,
			KSPConvergedReason *outcome,PetscReal *data);

/* Prototype of the solver core, defined below.  Parameters:
     comm            communicator the PETSc vectors/matrix are created on
     mat_el, idx, pointers
                     local matrix rows in CRS form (values, column
                     indices, row pointers), handed to petsc_matrix_from_crs
     x_el            right-hand side, local part
     first, isize    passed to petsc_matrix_from_crs / used as the local
                     vector length (presumably first local row and number
                     of local rows -- confirm against petsc_matrix_from_crs)
     nnzero          number of nonzeros; not used by the definition below
     return_vec_el   receives the solution, local part
     rtol, atol, dtol, maxit
                     KSP tolerances
     info            PETSc-specific block (optional convtest hook, cached norms)
     conv, its       out: 1/0 converged flag and iteration count
   Returns 0 on success. */
int petsc_iterative_solve
(MPI_Comm comm,
 double *mat_el,double *x_el,
 int first,int isize,int nnzero,
 int *idx,int *pointers,
 double *return_vec_el,
 double rtol,double atol,double dtol,int maxit,
 petsc_info_block info,
 int *conv,int *its
 );

/*
 * Top-level entry point: initialise PETSc, solve the distributed CRS
 * system with petsc_iterative_solve, store the convergence result in
 * `info` via iterative_set_return_params, and finalise PETSc.
 *
 * Argc, Args       command line handed to PetscInitialize
 * comm             communicator the solve runs on
 * first, local_size
 *                  passed through to petsc_iterative_solve
 * values, indices, pointers
 *                  local matrix rows in CRS form
 * rhs_vector       right-hand side, local part
 * solution_vector  receives the solution, local part
 * info             generic iterative block; it is cast to petsc_info_block
 *                  to read rtol/atol/dtol/maxit
 *
 * Returns 0 on success.  Error codes are handed to ERR_RETURN (defined
 * elsewhere; presumably it returns the nonzero code).
 *
 * Once PetscInitialize has succeeded, PetscFinalize is reached on every
 * path.  A failed nonzero count or solve therefore no longer leaves PETSc
 * (and any MPI it started) initialised.
 */
int petsc_iterative_driver
(int Argc,char **Args, MPI_Comm comm,
 int first,int local_size,
 double *values,int *indices,int *pointers,
 double *rhs_vector,double *solution_vector,
 iterative_info_block info)
{
  petsc_info_block petsc_info = (petsc_info_block) info;
  int conv,its,nnz,ierr,finalize_err;

  ierr = PetscInitialize(&Argc,&Args,PETSC_NULL,PETSC_NULL); ERR_RETURN(ierr);

  /* From here on, errors are remembered rather than returned at once,
     so that the PetscFinalize below is always executed. */
  ierr = crs_nnzeros(pointers,local_size,0,&nnz);
  if (!ierr) {
    ierr = petsc_iterative_solve
      (comm,values,rhs_vector, first,local_size,nnz,
       indices,pointers, solution_vector,
       petsc_info->rtol,petsc_info->atol,petsc_info->dtol,petsc_info->maxit,
       petsc_info,
       &conv,&its);
  }
  if (!ierr)
    iterative_set_return_params(info,conv,its);

  finalize_err = PetscFinalize();
  ERR_RETURN(ierr);          /* report the solve error first, if any */
  ERR_RETURN(finalize_err);
  return 0;
}

/*
 * Solve A y = x with PETSc's BiCGStab (KSPBCGS), preconditioned by Jacobi.
 *
 * VecCreateMPIWithArray wraps the right-hand side xv and the solution yv
 * into parallel PETSc vectors without copying them.  The solution is
 * therefore written directly into yv.  yv is zeroed first, so the initial
 * guess is always 0.
 *
 * The matrix is built from the local CRS arrays (mv, idx, pointers) by
 * petsc_matrix_from_crs.  nz is accepted for interface compatibility but
 * is not used here.
 *
 * If info->convtest is set, ||A||_inf is stored in info->anorm and
 * UserConverge is installed as the convergence test.  Otherwise PETSc's
 * default test applies.  The default KSP monitor is always installed.
 *
 * On return:
 *   - a non-negative iteration count from SLESSolve gives
 *     *conv = 1, *its = count;
 *   - a negative count is treated as non-convergence:
 *     *conv = 0, *its = -count.
 *
 * Returns 0 on success.  Every error code is handed to ERR_RETURN.
 */
int petsc_iterative_solve
(MPI_Comm comm,
 double *mv,double *xv, int first,int local_size,int nz,
 int *idx,int *pointers,
 double *yv,
 double rtol,double atol,double dtol,int maxit,
 petsc_info_block info,
 int *conv,int *its
 )
{
  Mat       A;
  Vec       X,V;
  int       ierr,p_its;
  Scalar    zero = 0.;

  ierr = VecCreateMPIWithArray(comm,local_size,PETSC_DECIDE,xv,&X); ERR_RETURN(ierr);
  ierr = VecCreateMPIWithArray(comm,local_size,PETSC_DECIDE,yv,&V); ERR_RETURN(ierr);
  ierr = VecSet(&zero,V); ERR_RETURN(ierr);   /* zero initial guess */

  ierr = petsc_matrix_from_crs
    (comm,first,local_size,mv,idx,pointers, &A); ERR_RETURN(ierr);

  {
    SLES method;

    ierr = SLESCreate(comm,&method); ERR_RETURN(ierr);
    ierr = SLESSetOperators(method,A,A,SAME_PRECONDITIONER); ERR_RETURN(ierr);
    {
      KSP ksp;
      ierr = SLESGetKSP(method,&ksp); ERR_RETURN(ierr);
      ierr = KSPSetType(ksp,KSPBCGS); ERR_RETURN(ierr);
      ierr = KSPSetTolerances(ksp,rtol,atol,dtol,maxit); ERR_RETURN(ierr);
      ierr = KSPSetMonitor
	(ksp,&KSPDefaultMonitor,(void*)PETSC_NULL,(void*)PETSC_NULL);
      ERR_RETURN(ierr);
      if (info->convtest) {
	/* UserConverge passes ||A||_inf to the user test via info->anorm. */
	ierr = MatNorm(A,NORM_INFINITY,&info->anorm); ERR_RETURN(ierr);
	ierr = KSPSetConvergenceTest
	  (ksp,UserConverge,info); ERR_RETURN(ierr);
      }
      /* Debugging hooks: TrueMonitor (via KSPSetMonitor) or
	 TrueConvergenceTest (via KSPSetConvergenceTest, with a PetscReal
	 as context) can be installed here to follow the true residual
	 ||b - A x|| instead of the recursive one. */
    }
    {
      PC pc;
      ierr = SLESGetPC(method,&pc); ERR_RETURN(ierr);
      ierr = PCSetType(pc,PCJACOBI); ERR_RETURN(ierr);
    }
    ierr = SLESSolve(method,X,V,&p_its); ERR_RETURN(ierr);
    ierr = SLESDestroy(method); ERR_RETURN(ierr);
  }

  if (p_its<0) {
    *conv = 0; *its = -p_its;
  } else {
    *conv = 1; *its = p_its;
  }

  ierr = MatDestroy(A); ERR_RETURN(ierr);
  ierr = VecDestroy(X); ERR_RETURN(ierr);
  ierr = VecDestroy(V); ERR_RETURN(ierr);

  return 0;
}

/*
 * KSP convergence callback, installed when the caller supplied
 * info->convtest.
 *
 * It collects the quantities the user test needs and delegates the
 * decision to info->convtest, which writes the outcome into *reason:
 *   - ||b||_2 and the initial residual norm, captured at iteration 0
 *     and cached in info->bnorm / info->r0norm;
 *   - the current residual norm rnorm;
 *   - ||x||_2 of the current iterate, stored in info->xnorm;
 *   - the KSP tolerances and iteration limit.
 *
 * mctx must be the petsc_info_block given to KSPSetConvergenceTest.
 * info->anorm is expected to have been filled in by the caller beforehand.
 */
int UserConverge(KSP ksp, int it, PetscReal rnorm,
		 KSPConvergedReason *reason,void *mctx)
{
  petsc_info_block info = (petsc_info_block) mctx;
  Vec v;
  double epsr,epsa,epsd;
  int maxit,ierr,outcome;
  PetscFunctionBegin;
  if (it==0) {
    Vec rhs;
    ierr = KSPGetRhs(ksp,&rhs); CHKERRQ(ierr);
    ierr = VecNorm(rhs,NORM_2,&info->bnorm); CHKERRQ(ierr);
    info->r0norm = rnorm;
    *reason = 0;
  }
  ierr = KSPGetTolerances(ksp,&epsr,&epsa,&epsd,&maxit); CHKERRQ(ierr);
  /* KSPBuildSolution yields the current iterate for every KSP type;
     KSPGetSolution is only up to date for methods that update x in place. */
  ierr = KSPBuildSolution(ksp,PETSC_NULL,&v); CHKERRQ(ierr);
  ierr = VecNorm(v,NORM_2,&info->xnorm); CHKERRQ(ierr);
  /* The user test writes through an int*.  Go through a real int rather
     than reinterpreting the enum object, whose underlying integer type is
     implementation-defined. */
  outcome = (int)*reason;
  ierr = (info->convtest)
    (info->anorm,rnorm,info->r0norm,info->bnorm,info->xnorm,
     epsr,epsa,epsd, it,maxit,&outcome); CHKERRQ(ierr);
  *reason = (KSPConvergedReason)outcome;
  PetscFunctionReturn(0);
}

/*
 * Diagnostic KSP convergence test based on the true residual.
 *
 * Instead of the recursive residual cg_err supplied by KSP, it recomputes
 * ||b - A x||_2 from the current iterate and prints it on every iteration.
 *
 * data  caller-owned storage.  At iteration 0 the initial true residual
 *       is stored there; later iterations measure reduction against it.
 *
 * *outcome is set to:
 *    1  converged (atol or rtol criterion met)
 *   -1  diverged: the residual exceeds dtol times the initial residual
 *       (PETSc's relative divergence tolerance), or it > maxit
 *    0  keep iterating
 */
int TrueConvergenceTest(KSP ksp,int it,PetscReal cg_err,
			KSPConvergedReason *outcome,PetscReal *data)
{
  PC pc; Mat amat,pmat; MatStructure flag; Vec b,x,ax;
  MPI_Comm comm;
  Scalar mone = -1.; PetscReal res_err,rtol,atol,dtol;
  int maxit,ierr;

  ierr = PetscObjectGetComm((PetscObject)ksp,&comm); CHKERRQ(ierr);
  ierr = KSPGetTolerances(ksp,&rtol,&atol,&dtol,&maxit); CHKERRQ(ierr);
  ierr = KSPGetRhs(ksp,&b); CHKERRQ(ierr);
  ierr = KSPBuildSolution(ksp,PETSC_NULL,&x); CHKERRQ(ierr);
  ierr = VecDuplicate(x,&ax); CHKERRQ(ierr);
  ierr = KSPGetPC(ksp,&pc); CHKERRQ(ierr);
  ierr = PCGetOperators(pc,&amat,&pmat,&flag); CHKERRQ(ierr);
  ierr = MatMult(amat,x,ax); CHKERRQ(ierr);
  ierr = VecAXPY(&mone,b,ax); CHKERRQ(ierr);      /* ax <- A x - b */
  ierr = VecNorm(ax,NORM_2,&res_err); CHKERRQ(ierr);
  /* ax is a scratch vector created on every call; release it so the
     test does not leak one vector per iteration. */
  ierr = VecDestroy(ax); CHKERRQ(ierr);
  PetscPrintf(comm,"[%d] true error is %e\n",it,res_err);
  if (it==0) {
    *data = res_err;
    if (atol>0 && res_err<atol) *outcome=1; else *outcome=0;
  } else {
    PetscReal err0 = *data;
    if (atol>0 && res_err<atol) *outcome=1;
    /* multiply instead of dividing, so err0 == 0 needs no special case */
    else if (rtol>0 && res_err<rtol*err0) *outcome=1;
    /* dtol is relative to the initial residual; exceeding it is divergence */
    else if (dtol>0 && res_err>dtol*err0) *outcome=-1;
    else if (it>maxit) *outcome=-1;
    else *outcome=0;
  }
  return 0;
}

/*
 * Diagnostic KSP monitor.
 *
 * For the current iterate x it computes the true residual ||b - A x||_2
 * and prints it next to the recursive residual cg_err reported by KSP.
 * The output goes through printf, so every process prints.
 * data is not used.
 */
int TrueMonitor(KSP ksp,int it,PetscReal cg_err,void *data)
{
  PC pc; Mat amat,pmat; MatStructure flag; Vec b,x,ax;
  Scalar mone = -1.; PetscReal res_err; int ierr;

  ierr = KSPGetRhs(ksp,&b); CHKERRQ(ierr);
  ierr = KSPBuildSolution(ksp,PETSC_NULL,&x); CHKERRQ(ierr);
  ierr = VecDuplicate(x,&ax); CHKERRQ(ierr);
  ierr = KSPGetPC(ksp,&pc); CHKERRQ(ierr);
  ierr = PCGetOperators(pc,&amat,&pmat,&flag); CHKERRQ(ierr);
  ierr = MatMult(amat,x,ax); CHKERRQ(ierr);
  ierr = VecAXPY(&mone,b,ax); CHKERRQ(ierr);      /* ax <- A x - b */
  ierr = VecNorm(ax,NORM_2,&res_err); CHKERRQ(ierr);
  /* ax is a scratch vector created on every call; release it so the
     monitor does not leak one vector per iteration. */
  ierr = VecDestroy(ax); CHKERRQ(ierr);
  printf("Confusius says: recursive error is %e; true error is %e\n",
	 cg_err,res_err);
  return 0;
}

