/*
   Defines a Processor Quot preconditioner for any Mat implementation.
   (The routines in this file are named GenBlockSSOR / GBSSOR.)
*/
#include "is.h"
#include "src/mat/impls/aij/mpi/mpiaij.h"
#include "src/mat/impls/aij/seq/aij.h"
#include "src/mat/parpre_matimpl.h"
#include "src/vec/vecimpl.h"
#include "src/pc/pcimpl.h"
#include "parpre_mat.h"
#include "parpre_pipeline.h"
#include "parpre_subdomains.h"
#include "parpre_pc.h"
#include "sles.h"
#include "options.h"

/* NOTE(review): CHUNCKSIZE is not referenced anywhere in the visible part
   of this file -- confirm whether it is still needed. */
#define CHUNCKSIZE   100

/* Routines that are called below but defined in other translation units. */

/* Called from PCSetFromOptions_GBSSOR. */
extern int PCCustomPipelineSetFromOptions(PC pc);
/* NOTE(review): MatTranspose_AIJ is declared here but not called in the
   visible part of this file. */
extern int MatTranspose_AIJ(Mat a, Mat *b);
/* Error handler that PCSetup_GBSSOR and PCApply_GBSSOR push on entry. */
extern int ParPreTraceBackErrorHandler
    (int,char*,char*,char*,int,int,char*,void*);

/*
   Private context of the GenBlockSSOR preconditioner; stored in pc->data.
*/
typedef struct {
  /* Subdomain bookkeeping.  NOTE(review): presumably kept as the first
     member so that the generic PCParallel* routines can access it through
     pc->data -- confirm against parpre_subdomains.h. */
  PCParallelSubdomainStruct subdomains;
  /* Nonzero: PCSetup_GBSSOR builds a filled-in diagonal block;
     zero: the plain diagonal block of the matrix is used. */
  int global_factorisation;
  /* Work vectors created by PCSetup_GBSSOR on more than one process:
     local1/local2 are sequential vectors, global_vec/global_vec2 are MPI
     vectors created on the arrays of local1/local2 respectively, and
     border_vec is produced by MatCreateScatterPipeline. */
  Vec local1,local2,border_vec,global_vec,global_vec2;
  /* Fill strategy handed to BlockFillDiag. */
  FillMethod fillmethod;
} PC_GenBlockSSOR_struct;

/****************************************************************
 * User Interface                                               *
 ****************************************************************/

#undef __FUNC__
#define __FUNC__ "PCGenBlockSSORSetNoGlobalFactorisation"
/*
   PCGenBlockSSORSetNoGlobalFactorisation - switches the global
   factorisation off by clearing the flag in the preconditioner context.

   Input: pc - the preconditioner; pc->data is taken to be a
               PC_GenBlockSSOR_struct.
   Returns 0.
*/
int PCGenBlockSSORSetNoGlobalFactorisation(PC pc)
{
  PC_GenBlockSSOR_struct *ctx;

  PetscFunctionBegin;
  ctx = (PC_GenBlockSSOR_struct *) pc->data;
  ctx->global_factorisation = 0;
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "PCGenBlockSSORSetGlobalFactorisation"
/*
   PCGenBlockSSORSetGlobalFactorisation - switches the global
   factorisation on by setting the flag in the preconditioner context.

   Input: pc - the preconditioner; pc->data is taken to be a
               PC_GenBlockSSOR_struct.
   Returns 0.
*/
int PCGenBlockSSORSetGlobalFactorisation(PC pc)
{
  PC_GenBlockSSOR_struct *ctx;

  PetscFunctionBegin;
  ctx = (PC_GenBlockSSOR_struct *) pc->data;
  ctx->global_factorisation = 1;
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "PCGenBlockSSORSetFillMethod"
/*
   PCGenBlockSSORSetFillMethod - records the fill method in the
   preconditioner context; PCSetup_GBSSOR later passes it to BlockFillDiag.

   Input: pc         - the preconditioner
          fillmethod - the value to store
   Returns 0.
*/
int PCGenBlockSSORSetFillMethod(PC pc,FillMethod fillmethod)
{
  PC_GenBlockSSOR_struct *ctx;

  PetscFunctionBegin;
  ctx = (PC_GenBlockSSOR_struct *) pc->data;
  ctx->fillmethod = fillmethod;
  PetscFunctionReturn(0);
}

/****************************************************************
 * Local Routines                                               *
 ****************************************************************/

#undef __FUNC__
#define __FUNC__ "BlockFillDiag"
/*
   BlockFillDiag - cuts an rs-by-rs block out of fill_band and combines it
   with dia_mat through MatMatSubtract_AIJ.

   Input:
     dia_mat   - first operand of MatMatSubtract_AIJ
     fill_band - matrix the block is extracted from
     fill      - not used by this routine
     r0        - first index of the second (offset) index set
     rs        - number of indices in both index sets
   Output:
     fill_dia  - matrix produced by MatMatSubtract_AIJ
                 (presumably dia_mat minus the extracted block -- confirm
                 against MatMatSubtract_AIJ)
   Returns 0, or the error code of the first failing call.
*/
static int BlockFillDiag(Mat dia_mat,Mat fill_band,FillMethod fill,
			 int r0,int rs,Mat *fill_dia)
{
  IS  zero_based,offset_based;
  Mat *extracted;
  Mat band_block;
  int ierr;

  /* index sets 0..rs-1 and r0..r0+rs-1 */
  ierr = ISCreateStride(MPI_COMM_SELF,rs,0,1,&zero_based);
  CHKERRQ(ierr);
  ierr = ISCreateStride(MPI_COMM_SELF,rs,r0,1,&offset_based);
  CHKERRQ(ierr);

  /* one submatrix: zero_based is passed first, offset_based second */
  ierr = MatGetSubMatrices(fill_band,1,&zero_based,&offset_based,
			   MAT_INITIAL_MATRIX,&extracted);
  CHKERRQ(ierr);

  ierr = ISDestroy(offset_based); CHKERRQ(ierr);
  ierr = ISDestroy(zero_based); CHKERRQ(ierr);

  /* keep the single matrix, release the array that carried it */
  band_block = extracted[0];
  PetscFree(extracted);

  ierr = MatMatSubtract_AIJ(dia_mat,band_block,fill_dia); CHKERRQ(ierr);
  ierr = MatDestroy(band_block); CHKERRQ(ierr);

  return 0;
}

#undef __FUNC__
#define __FUNC__ "BlockSetPrec"
/*
   BlockSetPrec - registers fill_dia as the subdomain system of pc.

   fill_dia->data is read as a Mat_SeqAIJ; its n field gives the length of
   a temporary sequential vector that is handed to
   PCParallelSetSubdomainSystem together with the matrix and destroyed
   again afterwards.

   Returns 0, or the error code of the first failing call.
*/
static int BlockSetPrec(PC pc,Mat fill_dia)
{
  Mat_SeqAIJ *seq = (Mat_SeqAIJ *)fill_dia->data;
  Vec        shape_vec;
  int        ierr;

  ierr = VecCreateSeq(MPI_COMM_SELF,seq->n,&shape_vec); CHKERRQ(ierr);
  ierr = PCParallelSetSubdomainSystem(pc,fill_dia,shape_vec);
  CHKERRQ(ierr);
  ierr = VecDestroy(shape_vec); CHKERRQ(ierr);

  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCSetup_GBSSOR"
/*
   PCSetup_GBSSOR - prepares the GenBlockSSOR preconditioner for pc->pmat.

   On a single process the matrix itself is registered as the subdomain
   system and nothing else is set up.  On several processes the routine
     1. creates the scatter pipeline and the border vector,
     2. creates the local work vectors and MPI vectors on their arrays,
     3. registers either the plain diagonal block or, when
        global_factorisation is set, a filled-in diagonal block as the
        subdomain system.

   pc->pmat->data is read as a Mat_MPIAIJ on the multi-process path;
   BlockSetPrec reads its matrix argument as a Mat_SeqAIJ.

   Returns 0, or the error code of the first failing call.
*/
int PCSetup_GBSSOR(PC pc)
{
  Mat base_mat = pc->pmat;
  PC_GenBlockSSOR_struct *pc_data = (PC_GenBlockSSOR_struct *) pc->data;
  int ierr,np;

  ierr = PetscPushErrorHandler
    (&ParPreTraceBackErrorHandler,0);
  CHKERRQ(ierr);

  MPI_Comm_size(pc->comm,&np);
  if (np==1) {
    ierr = BlockSetPrec(pc,base_mat); CHKERRQ(ierr);
    /* balance the push above; without this every single-process setup
       would leave one more handler on the error handler stack */
    ierr = PetscPopErrorHandler(); CHKERRQ(ierr);
    return 0;
  }
  {
    VecPipeline main_pipe;
    PipelineType pipe_type;
    PetscObject pipe_obj;

    ierr = PCParallelSubdomainPipelineGetType(pc,&pipe_type); CHKERRQ(ierr);
    ierr = PCParallelSubdomainPipelineGetObject(pc,&pipe_obj); CHKERRQ(ierr);
    if (pipe_type<PIPELINE_NONE) SETERRQ(1,(int)pipe_type,"Invalid pipe type");
    ierr = MatCreateScatterPipeline
      (base_mat,pipe_type,pipe_obj,&(pc_data->border_vec),&main_pipe);
    CHKERRQ(ierr);
    ierr = PCParallelSubdomainsSetPipeline(pc,main_pipe); CHKERRQ(ierr);
  }

  /* create local vectors; each MPI vector is created on the array of the
     corresponding sequential vector */
  {
    int len,dum; Scalar *a;
    ierr = MatGetLocalSize(base_mat,&len,&dum); CHKERRQ(ierr);
    ierr = VecCreateSeq(MPI_COMM_SELF,len,&(pc_data->local1)); CHKERRQ(ierr);
    ierr = VecGetArray(pc_data->local1,&a); CHKERRQ(ierr);
    ierr = VecCreateMPIWithArray
      (base_mat->comm,len,PETSC_DECIDE,a,&(pc_data->global_vec));
    CHKERRQ(ierr);
    ierr = VecRestoreArray(pc_data->local1,&a); CHKERRQ(ierr);

    ierr = VecDuplicate(pc_data->local1,&(pc_data->local2)); CHKERRQ(ierr);
    ierr = VecGetArray(pc_data->local2,&a); CHKERRQ(ierr);
    ierr = VecCreateMPIWithArray
      (base_mat->comm,len,PETSC_DECIDE,a,&(pc_data->global_vec2));
    CHKERRQ(ierr);
    ierr = VecRestoreArray(pc_data->local2,&a); CHKERRQ(ierr);
  }

  /* Determine local coefficient matrix */
  {
    Mat_MPIAIJ *Aij = (Mat_MPIAIJ *) base_mat->data;
    Mat dia_mat = Aij->A,off_mat = Aij->B, fill_mat,get_mat,fill_dia;
    MatGatherCtx ectx;
    Mat_SeqAIJ *off = (Mat_SeqAIJ *) off_mat->data;

    if (pc_data->global_factorisation) {
      /* Set up the row elimination structure */
      Mat fill_band;

      /* Create the preconditioner matrix as copy of the original */
      ierr = MatConvert(base_mat,MATSAME,&fill_mat); CHKERRQ(ierr);
      ierr = MatSetOption(fill_mat,MAT_COLUMNS_SORTED); CHKERRQ(ierr);
      /*ierr = MatSetOption(fill_mat,NO_NEW_NONZERO_LOCATIONS);*/
      {
	IS wanted;
	VecPipeline main_pipe;

	/* the rows to gather are the entries of the matrix' garray */
	ierr = ISCreateGeneral(MPI_COMM_SELF,off->n,Aij->garray,&wanted);
	CHKERRQ(ierr);
	ierr = MatGatherCtxCreate(fill_mat,wanted,&ectx); CHKERRQ(ierr);
	ierr = ISDestroy(wanted); CHKERRQ(ierr);
	/* this return code used to be overwritten unchecked, so a failure
	   let an undefined main_pipe reach VecPipelineSetup */
	ierr = PCParallelSubdomainsGetPipeline(pc,&main_pipe); CHKERRQ(ierr);
	ierr = VecPipelineSetup(main_pipe); CHKERRQ(ierr);
	ectx->vs = main_pipe;
      }
      ierr = MatGatherRowsPipelineBegin(fill_mat,PIPELINE_UP,ectx,&get_mat);
      CHKERRQ(ierr);
      /* create the fill */
      ierr = MatMatMult_AIJ(off_mat,get_mat,&fill_band); CHKERRQ(ierr);
      /* Compute a new diagonal block by filling old rows
       * onto an old diagonal block */
      ierr = BlockFillDiag(dia_mat,fill_band,pc_data->fillmethod,
			   Aij->rstart,Aij->m,&fill_dia);
      CHKERRQ(ierr);

      ierr = MatDestroy(fill_band); CHKERRQ(ierr);
    }
    else
      fill_dia = dia_mat;

    /* Store the new diagonal block as preconditioner operator.
       NOTE(review): in the global_factorisation case fill_dia is a new
       matrix that this routine never destroys -- confirm that
       PCParallelSetSubdomainSystem takes ownership of it. */
    ierr = BlockSetPrec(pc,fill_dia); CHKERRQ(ierr);

    if (pc_data->global_factorisation) {
      Mat_MPIAIJ *Cij = (Mat_MPIAIJ *) fill_mat->data;
      Mat fill_off = Cij->B, fill_solve;
      SLES local_method;

      /* multiply the off-diagonal block */
      ierr = PCParallelGetLocalSLES(pc,&local_method); CHKERRQ(ierr);
      ierr = MatSolveMat_AIJ(local_method,fill_off,&fill_solve); CHKERRQ(ierr);

      /* replace the off-diagonal block of the copy by the solved one */
      ierr = MatDestroy(fill_off); CHKERRQ(ierr);
      Cij->B = fill_solve;
      ierr = MatGatherRowsPipelineEnd(fill_mat,PIPELINE_UP,ectx,&get_mat);
      CHKERRQ(ierr);
      ierr = MatDestroy(get_mat); CHKERRQ(ierr);
      ierr = MatDestroy(fill_mat); CHKERRQ(ierr);
      ierr = MatGatherCtxDestroy(ectx); CHKERRQ(ierr);
    }
  }
  ierr = PetscPopErrorHandler(); CHKERRQ(ierr);

  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCApply_GBSSOR"
/*
   PCApply_GBSSOR - applies the preconditioner: y is computed from x.

   Two paths, selected by the subdomain pipeline type:
     PIPELINE_NONE - one local solve per process, carried out directly on
                     the arrays of x and y;
     otherwise     - an upward sweep followed by a downward sweep through
                     the pipeline, each sweep subtracting the off-diagonal
                     block times the received border values from x before
                     the local solve.

   pc->mat->data is read as a Mat_MPIAIJ.  The work vectors in the context
   are the ones created by PCSetup_GBSSOR.

   Returns 0, or the error code of the first failing call.
*/
static int PCApply_GBSSOR(PC pc,Vec x,Vec y)
{
  PC_GenBlockSSOR_struct *pc_data = (PC_GenBlockSSOR_struct *) pc->data;
  /* NOTE(review): this routine reads pc->mat whereas PCSetup_GBSSOR reads
     pc->pmat -- confirm that this is intended. */
  Mat base_mat = pc->mat;
  Mat_MPIAIJ *Aij = (Mat_MPIAIJ *) (base_mat->data);
  Mat off_diag = Aij->B;
  int ierr,its;
  Vec local1 = pc_data->local1, local2 = pc_data->local2;
  SLES local_method;
  PipelineType pipe_type;

  ierr = PetscPushErrorHandler(&ParPreTraceBackErrorHandler,0);
  CHKERRQ(ierr);

  ierr = PCParallelGetLocalSLES(pc,&local_method); CHKERRQ(ierr);
  ierr = PCParallelSubdomainPipelineGetType(pc,&pipe_type); CHKERRQ(ierr);
  if (pipe_type==PIPELINE_NONE) {

    /* >>>> solve in parallel <<<< */

    /* ps/qs remember the work vectors' own arrays, xa/ya are the arrays
       of x and y.
       NOTE(review): none of the four VecGetArray calls below has a
       matching VecRestoreArray -- confirm whether that is acceptable for
       the vector types in use. */
    Scalar *ps,*qs,*xa,*ya;
    /* redirect input: local1 temporarily uses the array of x */
    ierr = VecGetArray(local1,&ps); CHKERRQ(ierr);
    ierr = VecGetArray(x,&xa); CHKERRQ(ierr);
    ierr = VecPlaceArray(local1,xa); CHKERRQ(ierr);
    /* redirect output: local2 temporarily uses the array of y */
    ierr = VecGetArray(local2,&qs); CHKERRQ(ierr);
    ierr = VecGetArray(y,&ya); CHKERRQ(ierr);
    ierr = VecPlaceArray(local2,ya); CHKERRQ(ierr);
    /* solve */
    ierr = SLESSolve(local_method,local1,local2,&its); CHKERRQ(ierr);
    /* undo the redirection: give both work vectors their own arrays back */
    ierr = VecPlaceArray(local1,ps); CHKERRQ(ierr);
    ierr = VecPlaceArray(local2,qs); CHKERRQ(ierr);
  } else {

    /* >>>> solve in a pipeline <<<< */

    /* global1/global2 were created in PCSetup_GBSSOR on the arrays of
       local1/local2, so writing global1 fills the right-hand side of the
       local solve and global2 holds its result */
    Vec global1 = pc_data->global_vec, global2 = pc_data->global_vec2;
    Vec border = pc_data->border_vec;
    VecPipeline main_pipe;
    Scalar zero = 0.0, mone = -1.0, one = 1.0;
    int flg;

    ierr = PCParallelSubdomainsGetPipeline(pc,&main_pipe); CHKERRQ(ierr);

    /* forward solve from global1 into global2,
       where first global1=x-off*border */
    ierr = VecSet(&zero,border); CHKERRQ(ierr);
    ierr = VecPipelineBegin
      (global2,border, INSERT_VALUES, SCATTER_FORWARD,PIPELINE_UP,main_pipe);
    CHKERRQ(ierr);
    
    ierr = MatMultAXBY_AIJ(mone,off_diag,border,one,x,global1); CHKERRQ(ierr);

    /* solve */
    ierr = SLESSolve(local_method,local1,local2,&its); CHKERRQ(ierr);
    
    ierr = VecPipelineEnd
      (global2,border, INSERT_VALUES, SCATTER_FORWARD,PIPELINE_UP,main_pipe);
    CHKERRQ(ierr);

    /* backward solve from global1 into y */
    ierr = VecSet(&zero,border); CHKERRQ(ierr);
    ierr = VecPipelineBegin
      (global2,border,INSERT_VALUES,SCATTER_FORWARD,PIPELINE_DOWN,main_pipe);
    CHKERRQ(ierr);
    
    /* the process that starts the downward pipe keeps its result from the
       forward sweep; every other process recomputes the right-hand side
       and solves again */
    ierr = VecPipelineIsStartOfPipe(main_pipe,PIPELINE_DOWN,&flg);
    CHKERRQ(ierr);
    if (!flg) {
      ierr = MatMultAXBY_AIJ(mone,off_diag,border,one,x,global1);
      CHKERRQ(ierr);
    
      /* solve */
      ierr = SLESSolve(local_method,local1,local2,&its); CHKERRQ(ierr);
    }
    ierr = VecPipelineEnd
      (global2,border,INSERT_VALUES,SCATTER_FORWARD,PIPELINE_DOWN,main_pipe);
    CHKERRQ(ierr);
    /* hand the result to the caller */
    ierr = VecCopy(global2,y); CHKERRQ(ierr);
  }
  
  ierr = PetscPopErrorHandler(); CHKERRQ(ierr);
  return 0;

}

#undef __FUNC__
#define __FUNC__ "PCApplyrich_GBSSOR"
/*
   PCApplyrich_GBSSOR - placeholder: performs no work, leaves all of its
   arguments untouched and reports success.

   NOTE(review): a caller receives 0 even though y has not been updated.
*/
static int PCApplyrich_GBSSOR(PC pc,Vec b,Vec y,Vec w,int its)
{
  int status = 0;   /* nothing is computed, so nothing can fail */

  return status;
}

#undef __FUNC__
#define __FUNC__ "PCDestroy_GBSSOR"
/*
   PCDestroy_GBSSOR - releases the local solver and the work vectors of
   the preconditioner.

   The MPI vectors global_vec/global_vec2 were created by PCSetup_GBSSOR
   with VecCreateMPIWithArray on the arrays of local1/local2; they are
   destroyed first, while the vectors owning the storage still exist.
   Every handle is set to 0 after it has been destroyed.

   NOTE(review): the null checks only help if the handles are 0 when
   PCSetup_GBSSOR has not created the vectors (single process, or setup
   never run) -- confirm that pc->data starts out zeroed.
   NOTE(review): border_vec (from MatCreateScatterPipeline) is not
   destroyed here -- confirm who owns it.

   Returns 0, or the error code of the first failing call.
*/
int PCDestroy_GBSSOR(PC pc)
{
  PC_GenBlockSSOR_struct *pc_data = (PC_GenBlockSSOR_struct *) pc->data;
  SLES local_method;
  int ierr;

  ierr = PCParallelGetLocalSLES(pc,&local_method); CHKERRQ(ierr);
  ierr = SLESDestroy(local_method); CHKERRQ(ierr);

  /* the MPI wrappers were previously leaked */
  if (pc_data->global_vec) {
    ierr = VecDestroy(pc_data->global_vec); CHKERRQ(ierr);
    pc_data->global_vec = 0;
  }
  if (pc_data->global_vec2) {
    ierr = VecDestroy(pc_data->global_vec2); CHKERRQ(ierr);
    pc_data->global_vec2 = 0;
  }

  /* the sequential vectors that own the storage */
  if (pc_data->local1) {
    ierr = VecDestroy(pc_data->local1); CHKERRQ(ierr);
    pc_data->local1 = 0;
  }
  if (pc_data->local2) {
    ierr = VecDestroy(pc_data->local2); CHKERRQ(ierr);
    pc_data->local2 = 0;
  }
  return 0;

}

#undef __FUNC__
#define __FUNC__ "PCSetFromOptions_GBSSOR"
/*
   PCSetFromOptions_GBSSOR - configures the preconditioner from the
   options database, using pc->prefix.

   Options looked up here:
     -pc_global_fac    switch the global factorisation on
     -pc_noglobal_fac  switch the global factorisation off
   -pc_noglobal_fac is examined last, so it wins when both are present.
   Afterwards the local-solve and custom-pipeline option handlers run.

   Returns 0, or the error code of the first failing call.
*/
static int PCSetFromOptions_GBSSOR(PC pc)
{
  int ierr;
  int has_global,has_noglobal;

  ierr = OptionsHasName(pc->prefix,"-pc_global_fac",&has_global);
  CHKERRQ(ierr);
  if (has_global) {
    ierr = PCGenBlockSSORSetGlobalFactorisation(pc);
    CHKERRQ(ierr);
  }

  ierr = OptionsHasName(pc->prefix,"-pc_noglobal_fac",&has_noglobal);
  CHKERRQ(ierr);
  if (has_noglobal) {
    ierr = PCGenBlockSSORSetNoGlobalFactorisation(pc);
    CHKERRQ(ierr);
  }

  ierr = PCParallelLocalSolveSetFromOptions(pc);
  CHKERRQ(ierr);
  ierr = PCCustomPipelineSetFromOptions(pc);
  CHKERRQ(ierr);

  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCView_GBSSOR"
/*
   PCView_GBSSOR - reports the state of the global factorisation flag with
   PetscPrintf on pc->comm, then lets PCSubdomainsView describe the
   subdomains on the given viewer.

   The flag line goes through PetscPrintf, not through viewer.

   Returns 0, or the error code of PCSubdomainsView.
*/
static int PCView_GBSSOR(PC pc,Viewer viewer)
{
  PC_GenBlockSSOR_struct *ctx;
  int ierr;

  PetscFunctionBegin;
  ctx = (PC_GenBlockSSOR_struct *) pc->data;
  if (!ctx->global_factorisation) {
    PetscPrintf(pc->comm,"No global factorisation\n");
  } else {
    PetscPrintf(pc->comm,"Global factorisation\n");
  }
  ierr = PCSubdomainsView(pc,viewer);
  CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "PCCreate_GBSSOR"
/*
   PCCreate_GBSSOR - installs the GenBlockSSOR operations on pc and
   creates its private context.

   Defaults: sequential pipeline, no global factorisation, fill method
   FillDiag.  The work-vector handles start out as 0 so that later code
   can tell whether PCSetup_GBSSOR has created them.

   Returns 0, or the error code of the first failing call.
*/
int PCCreate_GBSSOR(PC pc)
{
  int ierr;

  pc->apply     = PCApply_GBSSOR;
  /* No Richardson variant is provided: PCApplyrich_GBSSOR in this file is
     an empty stub that returns 0, so installing it would make a caller of
     pc->applyrich see success without any work having been done. */
  pc->applyrich = 0;
  pc->destroy   = PCDestroy_GBSSOR;
  pc->setfromoptions   = PCSetFromOptions_GBSSOR;
  pc->printhelp = 0;
  pc->setup     = PCSetup_GBSSOR;
  /*  pc->type      = PCGenBlockSSOR;*/
  pc->view      = PCView_GBSSOR;

  /* pc->data is dereferenced below, so a failure here must not be
     ignored */
  ierr = PCParallelSubdomainsCreate(pc,sizeof(PC_GenBlockSSOR_struct));
  CHKERRQ(ierr);
  ierr = PCParallelSubdomainPipelineSetType
    (pc,PIPELINE_SEQUENTIAL,(PetscObject)PETSC_NULL); CHKERRQ(ierr);

  {
    PC_GenBlockSSOR_struct *pc_data = (PC_GenBlockSSOR_struct *) pc->data;
    pc_data->global_factorisation = 0;
    pc_data->fillmethod = FillDiag;
    pc_data->local1      = 0;
    pc_data->local2      = 0;
    pc_data->border_vec  = 0;
    pc_data->global_vec  = 0;
    pc_data->global_vec2 = 0;
  }

  return 0;
}
