#include <math.h>
#include <stdlib.h>
#include "error_returns.h"
#include "parpre.h"
#ifdef petsc29
#include "petscsles.h"
#include "petscksp.h"
#include "petscmat.h"
#else
#include "sles.h"
#include "ksp.h"
#include "mat.h"
#endif
#include "mec.h"

/* Compile-time switches. They are tested with #if, so keep them defined
   (as 0 or 1) rather than removing them.
   PARPRE: nonzero selects the multiplicative Schwarz preconditioner
           (PCMultiplicativeSchwarz, presumably provided via parpre.h) in
           MecSolveSystem; zero selects PETSc's additive Schwarz (PCASM).
   SCALE:  nonzero makes MecSolveSystem call ScaleSystem before solving. */
#define PARPRE 1
#define SCALE 1

/*
 * ScaleSystem -- symmetric diagonal scaling of A.
 *
 * Builds D = diag(A)^{-1/2} and replaces A, in place, by D A D
 * (MatDiagonalScale with D on both sides), which has unit diagonal.
 *
 * A      : matrix to scale; modified in place.  Its diagonal must be
 *          strictly positive on every process.
 * D_ret  : receives the new scaling vector D; the caller owns it and
 *          must VecDestroy it.
 * dnorm  : receives the 2-norm of the ORIGINAL diagonal of A.
 *
 * Returns 0 on success or a PETSc error code.  If any process finds a
 * nonpositive diagonal entry, every process returns the error, A is left
 * unscaled and *D_ret is not set.
 */
int ScaleSystem(Mat A,Vec *D_ret,double *dnorm)
{
  Vec D; MPI_Comm comm; double norm;
  int local_size,ierr,idum;

  ierr = PetscObjectGetComm((PetscObject)A,&comm); CHKERRQ(ierr);
  ierr = MatGetLocalSize(A,&local_size,&idum); CHKERRQ(ierr);
  ierr = VecCreateMPI(comm,local_size,PETSC_DECIDE,&D); CHKERRQ(ierr);

  /* get diagonal */
  ierr = MatGetDiagonal(A,D); CHKERRQ(ierr);
  ierr = VecNorm(D,NORM_2,&norm); CHKERRQ(ierr);
  *dnorm = norm;

  /* invert diagonal: d_i <- 1/sqrt(d_i) */
  {
    Scalar *d; int i,bad = 0,anybad = 0;
    ierr = VecGetArray(D,&d); CHKERRQ(ierr);
    for (i=0; i<local_size; i++) {
      if (d[i]<=0.) {bad = 1; break;}
      d[i] = 1./sqrt(d[i]);
    }
    /* always hand the array back before any error exit */
    ierr = VecRestoreArray(D,&d); CHKERRQ(ierr);
    /* agree on failure across all processes, so that no process is left
       waiting in the collective calls below while another one bails out */
    ierr = MPI_Allreduce(&bad,&anybad,1,MPI_INT,MPI_MAX,comm); CHKERRQ(ierr);
    if (anybad) {
      ierr = VecDestroy(D); CHKERRQ(ierr);
      SETERRQ(1,0,"Nonpositive diagonal element");
    }
  }

  /* scale: A <- D A D */
  ierr = MatDiagonalScale(A,D,D); CHKERRQ(ierr);

  *D_ret = D;

  return 0;
}

/* Apparently intended as a custom KSP convergence test (cf.
   KSPSetConvergenceTest).  NOTE(review): no active code in this file
   calls it and its definition is not visible here -- confirm it exists
   and that its signature matches the KSP callback before enabling it. */
int TrueConvergenceTest(KSP ksp,int it,PetscReal cg_err,
			KSPConvergedReason *outcome,PetscReal *data);
/* Debugging KSP monitor, defined below: prints the current residual and
   the preconditioned residual vectors at each iteration. */
int RandGprint(KSP ksp,int it,PetscReal rn,void *context);

/*
 * MecSolveSystem -- solve A x = rhs with preconditioned GMRES.
 *
 * When SCALE is nonzero the system is symmetrically scaled first:
 *     A x = rhs   becomes   (D A D) y = D rhs,   x = D y,
 * with D = diag(A)^{-1/2} from ScaleSystem.
 *
 * A      : system matrix.  With SCALE it is replaced IN PLACE by D A D
 *          and is not restored on return.
 * rhs    : right-hand side; not modified.
 * sol    : incoming contents are overwritten (the, possibly scaled,
 *          right-hand side is used as the nonzero initial guess); on
 *          return it holds the solution of the original system.
 * maxit, reltol, abstol, divtol : passed to KSPSetTolerances.  With SCALE,
 *          abstol is first divided by the 2-norm of diag(A).
 * its    : receives the iteration count reported by SLESSolve.
 *
 * Returns 0 on success or a PETSc error code.
 */
int MecSolveSystem
(Mat A,Vec rhs,Vec sol,
 int maxit,double reltol,double abstol,double divtol,int *its)
{
  SLES solver;
  Vec b;                   /* right-hand side of the system actually solved */
  MPI_Comm comm;
  double scale_norm = 1.;  /* stays 1 without scaling: abstol unchanged */
  int ierr;
#if SCALE
  Vec D;                   /* scaling vector diag(A)^{-1/2} */
#endif
#if !PARPRE
  IS domain;               /* ASM subdomain; kept alive until the solve is done */
#endif

  ierr = PetscObjectGetComm((PetscObject)A,&comm); CHKERRQ(ierr);

  {
    PetscReal norm;
    ierr = VecNorm(rhs,NORM_2,&norm); CHKERRQ(ierr);
    PetscPrintf(comm,"rhs norm %e\n",norm);
  }

  /* Scale the system and right hand side */
#if SCALE
  ierr = ScaleSystem(A,&D,&scale_norm); CHKERRQ(ierr);
  /* scale a private copy (b = D rhs) so the caller's rhs is untouched */
  ierr = VecDuplicate(rhs,&b); CHKERRQ(ierr);
  ierr = VecCopy(rhs,b); CHKERRQ(ierr);
  ierr = VecPointwiseMult(b,D,b); CHKERRQ(ierr);
#else
  PetscPrintf(comm,">> Scaling disabled <<\n");
  b = rhs;
#endif

  /* Create iterative method */
  ierr = SLESCreate(comm,&solver); CHKERRQ(ierr);
  ierr = SLESSetOperators(solver,A,A,0); CHKERRQ(ierr);
  {
    KSP itmeth; PC prec;

    /* Set the iterative method */
    ierr = SLESGetKSP(solver,&itmeth); CHKERRQ(ierr);
    ierr = KSPSetType(itmeth,KSPGMRES/*TFQMR*/); CHKERRQ(ierr);
    ierr = KSPSetMonitor
      (itmeth,KSPDefaultMonitor,PETSC_NULL,PETSC_NULL); CHKERRQ(ierr);
    /* Alternative, more verbose monitor:
    ierr = KSPSetMonitor
      (itmeth,RandGprint,PETSC_NULL,PETSC_NULL); CHKERRQ(ierr);*/
    /* NOTE(review): heuristic rescaling of the absolute tolerance for the
       scaled system; the rationale is not documented here. */
    abstol = abstol/scale_norm;
    PetscPrintf(comm,"setting tolerances %e,%e\n",reltol,abstol);
    ierr = KSPSetTolerances(itmeth,reltol,abstol,divtol,maxit); CHKERRQ(ierr);
    ierr = KSPSetInitialGuessNonzero(itmeth); CHKERRQ(ierr);

    /* Set the preconditioner */
    ierr = SLESGetPC(solver,&prec); CHKERRQ(ierr);
#if PARPRE
    {
      PC local_prec;
      ierr = PCSetType(prec,PCMultiplicativeSchwarz); CHKERRQ(ierr);
      ierr = PCParallelSubdomainPipelineSetType
        (prec,PIPELINE_MULTICOLOUR,(void*)A); CHKERRQ(ierr);
      ierr = PCParallelGetLocalPC(prec,&local_prec); CHKERRQ(ierr);
      ierr = PCSetType(local_prec,PCLU); CHKERRQ(ierr);
      /*ierr = PCILUSetLevels(local_prec,3); CHKERRQ(ierr);*/
      /*ierr = PCILUSetUseDropTolerance(local_prec,.01,.01,50); CHKERRQ(ierr);*/
    }
#else
    {
      int first,last,n_local,first_local,isles;
      SLES *subsles; KSP subksp; PC subpc;

      ierr = PCSetType(prec,PCASM); CHKERRQ(ierr);
      ierr = PCASMSetOverlap(prec,1); CHKERRQ(ierr);
      /* One subdomain per process: its locally owned rows.  This must be
         set before SLESSetUp, which the sub-solver loop below requires.
         NOTE(review): domain is never ISDestroy'ed; whether PCASM keeps
         its own reference depends on the PETSc version -- confirm. */
      ierr = MatGetOwnershipRange(A,&first,&last); CHKERRQ(ierr);
      ierr = ISCreateStride
        (MPI_COMM_SELF,last-first,first,1,&domain); CHKERRQ(ierr);
      ierr = PCASMSetLocalSubdomains(prec,1,&domain); CHKERRQ(ierr);
      ierr = SLESSetUp(solver,b,sol); CHKERRQ(ierr);
      ierr = PCASMGetSubSLES
        (prec,&n_local,&first_local,&subsles); CHKERRQ(ierr);
      for (isles=0; isles<n_local; isles++) {
        ierr = SLESGetKSP(subsles[isles],&subksp); CHKERRQ(ierr);
        ierr = KSPSetType(subksp,KSPPREONLY); CHKERRQ(ierr);
        ierr = SLESGetPC(subsles[isles],&subpc); CHKERRQ(ierr);
        ierr = PCSetType(subpc,PCLU); CHKERRQ(ierr);
      }
      /*    ierr = PCASMSetType(prec,PC_ASM_BASIC); CHKERRQ(ierr);*/
    }
#endif
  }

  /* the right-hand side doubles as the nonzero initial guess */
  ierr = VecCopy(b,sol); CHKERRQ(ierr);
  ierr = SLESSetUp(solver,b,sol); CHKERRQ(ierr);

  /* System solution */
  ierr = SLESSolve(solver,b,sol,its); CHKERRQ(ierr);

  /* scale solution back: x = D y */
#if SCALE
  ierr = VecPointwiseMult(sol,D,sol); CHKERRQ(ierr);
#endif

  /* Clean up */
  ierr = SLESDestroy(solver); CHKERRQ(ierr);
#if SCALE
  ierr = VecDestroy(b); CHKERRQ(ierr);
  ierr = VecDestroy(D); CHKERRQ(ierr);
#endif

  return 0;
}

/*
 * RandGprint -- debugging KSP monitor.
 *
 * At each call it builds the current residual via KSPBuildResidual,
 * applies the solver's preconditioner to it, views both vectors with the
 * default viewer, frees them, and prints a marker line.  The iteration
 * number, residual norm and context arguments are unused; they are
 * present only to match the KSP monitor callback shape.
 *
 * Returns 0 on success or a PETSc error code.
 */
int RandGprint(KSP ksp,int it,PetscReal rn,void *context)
{
  int err;
  MPI_Comm ksp_comm;
  Vec r,     /* residual; created by KSPBuildResidual, freed below */
      pr;    /* preconditioner applied to r */
  PC precon;

  err = PetscObjectGetComm((PetscObject)ksp,&ksp_comm); CHKERRQ(err);

  err = KSPBuildResidual(ksp,PETSC_NULL,PETSC_NULL,&r); CHKERRQ(err);
  err = VecDuplicate(r,&pr); CHKERRQ(err);

  err = KSPGetPC(ksp,&precon); CHKERRQ(err);
  err = PCApply(precon,r,pr); CHKERRQ(err);

  err = VecView(r,0); CHKERRQ(err);
  err = VecView(pr,0); CHKERRQ(err);

  err = VecDestroy(r); CHKERRQ(err);
  err = VecDestroy(pr); CHKERRQ(err);

  PetscPrintf(ksp_comm,"More monitoring\n");
  return 0;
}
