/*
   Defines a generalized block SSOR (GBSSOR) preconditioner for parallel
   AIJ (MPIAIJ) matrices.
*/
#include "sles.h"
#include "is.h"
#include "src/mat/impls/aij/mpi/mpiaij.h"
#include "src/mat/impls/aij/seq/aij.h"
#include "src/vec/vecimpl.h"
#include "src/pc/pcimpl.h"
#include "src/vec/utils/vpipe.h"
#include "parpre_pc.h"
#include "src/pc/pcparallel.h"
#include "src/pc/pcextra.h"
#include "options.h"
#include "src/mat/impls/aij/mpi/mpixtra.h"

#define CHUNCKSIZE   100

extern int LocalSolveSetFromOptions(PC pc);
extern int PCPstructSetSystem(PCPstruct *pc_data, Mat mat, Vec vec);
extern int PCParallelCreateSubSLES(SLES *sles);
extern int PCParallelInitCommStruct(PC pc);
extern int PCParallelSetCommNone(PC pc);
extern int VecPipelineCreate(Vec xin,IS ix,Vec yin,IS iy,VecPipeline *newctx);
extern int VecScatterCopyToPipeline_PtoP(VecScatter in,VecPipeline *out);

extern int MatMultAXBY_AIJ
(Scalar a, Mat aijin,Vec xx, Scalar b, Vec yy,Vec zz);
extern int MatMatMult_AIJ(Mat a, Mat b, Mat *c);
extern int MatMatSubtract_AIJ(Mat a, Mat b, Mat *c);
extern int MatTranspose_AIJ(Mat a, Mat *b);
extern int ParPreTraceBackErrorHandler
    (int,char*,char*,char*,int,int,char*,void*);

/* Private data of the GBSSOR preconditioner; stored in pc->data. */
typedef struct {
  PCPstruct par_info;         /* parallel-preconditioner info; holds the local
                                 solver (local_method) and the pipeline
                                 (main_pipe) used by PCApply_GBSSOR */
  int global_factorisation;   /* nonzero: PCSetup_GBSSOR gathers rows from other
                                 processors and builds a corrected diagonal
                                 block; zero: the plain diagonal block is used */
} PC_GenBlockSSOR_struct;

/****************************************************************
 * User Interface                                               *
 ****************************************************************/

#undef __FUNC__
#define __FUNC__ "PCGenBlockSSORSetNoGlobalFactorisation"
/*
   PCGenBlockSSORSetNoGlobalFactorisation - clears the global-factorisation
   flag, so that PCSetup_GBSSOR installs the plain diagonal block of the
   matrix as the local operator.  This is also the value PCCreate_GBSSOR
   starts with.  Always returns 0.
*/
int PCGenBlockSSORSetNoGlobalFactorisation(PC pc)
{
  ((PC_GenBlockSSOR_struct *) pc->data)->global_factorisation = 0;
  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCGenBlockSSORSetGlobalFactorisation"
/*
   PCGenBlockSSORSetGlobalFactorisation - sets the global-factorisation flag,
   so that PCSetup_GBSSOR gathers rows from other processors and installs a
   corrected diagonal block as the local operator.  Always returns 0.
*/
int PCGenBlockSSORSetGlobalFactorisation(PC pc)
{
  ((PC_GenBlockSSOR_struct *) pc->data)->global_factorisation = 1;
  return 0;
}

/****************************************************************
 * Local Routines                                               *
 ****************************************************************/

#undef __FUNC__
#define __FUNC__ "BlockFillDiag"
/*
   BlockFillDiag - forms a corrected local diagonal block.

   With A_d the diagonal block (Aij->A) and B the off-diagonal block
   (Aij->B) of the MPIAIJ matrix base_mat, this computes
       P        = [ B * get_mat ] restricted to the local rows and to the
                  global columns rstart .. rstart+m-1
       fill_dia = MatMatSubtract_AIJ(A_d, P)   (presumably A_d - P)

   Input:
     base_mat - MPIAIJ matrix (its data is cast to Mat_MPIAIJ)
     get_mat  - right factor of the product; presumably the rows gathered
                from other processors in PCSetup_GBSSOR -- confirm
   Output:
     fill_dia - new matrix created by MatMatSubtract_AIJ; caller owns it

   NOTE(review): the column index set has length m (the local row count),
   so this assumes the local diagonal block is square.
*/
static int BlockFillDiag(Mat base_mat,Mat get_mat, Mat *fill_dia)
{
  Mat_MPIAIJ *mpi      = (Mat_MPIAIJ *) base_mat->data;
  Mat        diag_blk  = mpi->A;
  Mat        off_blk   = mpi->B;
  int        nlocal    = mpi->m;
  int        ierr;
  IS         local_rows,global_cols;
  Mat        product,correction,*submat;

  /* rows in local numbering, columns of the owned range in global numbering */
  ierr = ISCreateStride(MPI_COMM_SELF,nlocal,0,1,&local_rows);
  CHKERRQ(ierr);
  ierr = ISCreateStride(MPI_COMM_SELF,nlocal,mpi->rstart,1,&global_cols);
  CHKERRQ(ierr);

  /* product = B * get_mat; keep only the part in the owned columns */
  ierr = MatMatMult_AIJ(off_blk,get_mat,&product); CHKERRQ(ierr);
  ierr = MatGetSubMatrices(product,1,&local_rows,&global_cols,
                           MAT_INITIAL_MATRIX,&submat);
  CHKERRQ(ierr);
  correction = submat[0];
  PetscFree(submat);   /* only the handle array; submat[0] is kept */

  ierr = ISDestroy(global_cols); CHKERRQ(ierr);
  ierr = ISDestroy(local_rows); CHKERRQ(ierr);
  ierr = MatDestroy(product); CHKERRQ(ierr);

  ierr = MatMatSubtract_AIJ(diag_blk,correction,fill_dia); CHKERRQ(ierr);
  ierr = MatDestroy(correction); CHKERRQ(ierr);

  return 0;
}

#undef __FUNC__
#define __FUNC__ "BlockSetPrec"
/*
   BlockSetPrec - installs fill_dia (a sequential AIJ matrix) as the system
   of the local solver through PCPstructSetSystem().  A temporary sequential
   vector whose length is the column count of fill_dia is passed along and
   destroyed again before returning.
*/
static int BlockSetPrec(PC_GenBlockSSOR_struct *pc_data,Mat fill_dia)
{
  int ierr;
  int ncols = ((Mat_SeqAIJ *) fill_dia->data)->n;
  Vec work;

  ierr = VecCreateSeq(MPI_COMM_SELF,ncols,&work); CHKERRQ(ierr);
  ierr = PCPstructSetSystem(&pc_data->par_info,fill_dia,work); CHKERRQ(ierr);
  ierr = VecDestroy(work); CHKERRQ(ierr);

  return 0;
}

/* Compute A_{11}^{-1} A_{12} by solving one system per column of A_{12} */
#undef __FUNC__
#define __FUNC__ "BlockTransformOffBlock"
/*
   BlockTransformOffBlock - overwrites the off-diagonal block B (Cij->B) of
   the MPIAIJ matrix fill_mat, column by column, with the result of one
   solve with the local solver pc_data->par_info.local_method applied to
   that column.  With an exact local solver this is A_{11}^{-1} B, where
   A_{11} is the operator the local solver currently holds (installed by
   BlockSetPrec).

   Columns of B are obtained as rows of its transpose (MatTranspose_AIJ).
   Every entry of each solution column is inserted, including zeros, so B
   ends up stored densely.
*/
static int BlockTransformOffBlock(PC_GenBlockSSOR_struct *pc_data,Mat fill_mat)
{
  Mat_MPIAIJ *Cij = (Mat_MPIAIJ *) fill_mat->data;
  Mat dia_mat = Cij->A,off_mat = Cij->B,off_trans;
  int nrows = ((Mat_SeqAIJ *) dia_mat->data)->m;
  int offcols = ((Mat_SeqAIJ *) off_mat->data)->n;
  int col,ierr;
  Vec rhs,sol;
  Scalar zero = 0.0;

  /* work vectors; transpose B so that its columns can be read as rows */
  ierr = VecCreateSeq(MPI_COMM_SELF,nrows,&sol); CHKERRQ(ierr);
  ierr = MatTranspose_AIJ(off_mat,&off_trans); CHKERRQ(ierr);
  ierr = VecCreateSeq(MPI_COMM_SELF,nrows,&rhs); CHKERRQ(ierr);

  for (col=0; col<offcols; col++) {
    Scalar *vals,*tvals;
    int nelt,*elts,its,row;

    /* rhs = column `col` of B (a row of the transpose) */
    ierr = MatGetRow(off_trans,col,&nelt,&elts,&vals); CHKERRQ(ierr);
    ierr = VecSet(&zero,rhs); CHKERRQ(ierr);
    ierr = VecSetValues(rhs,nelt,elts,vals,INSERT_VALUES); CHKERRQ(ierr);
    ierr = VecAssemblyBegin(rhs); CHKERRQ(ierr);
    ierr = VecAssemblyEnd(rhs); CHKERRQ(ierr);

    ierr = SLESSolve
      (pc_data->par_info.local_method,rhs,sol,&its); CHKERRQ(ierr);

    /* store the solution back as column `col` of B */
    ierr = VecGetArray(sol,&tvals); CHKERRQ(ierr);
    for (row=0; row<nrows; row++) {
      ierr = MatSetValues(off_mat,1,&row,1,&col,tvals+row,INSERT_VALUES);
      CHKERRQ(ierr);
    }
    ierr = VecRestoreArray(sol,&tvals); CHKERRQ(ierr);

    ierr = MatRestoreRow(off_trans,col,&nelt,&elts,&vals); CHKERRQ(ierr);
  }
  ierr = VecDestroy(rhs); CHKERRQ(ierr);

  ierr = MatAssemblyBegin(off_mat,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd(off_mat,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);

  ierr = MatDestroy(off_trans); CHKERRQ(ierr);
  ierr = VecDestroy(sol); CHKERRQ(ierr);
  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCSetup_GBSSOR"
/*
   PCSetup_GBSSOR - set-up routine of the GBSSOR preconditioner.

   pc->mat is assumed to be an MPIAIJ matrix (its data is cast to Mat_MPIAIJ).
   Two things are done:

   1. The local operator of the local solver (par_info.local_method) is
      installed via BlockSetPrec:
        - global_factorisation != 0: a copy of pc->mat is made, rows are
          gathered through a MatGatherCtx pipeline, and the corrected
          diagonal block computed by BlockFillDiag is installed;
        - otherwise: the plain diagonal block Aij->A of pc->mat is installed.
   2. par_info.main_pipe is created: a VecPipeline from a parallel vector with
      pc->mat's local row layout into Aij->lvec, indexed by Aij->garray (the
      global indices of the off-diagonal columns).  PCApply_GBSSOR uses it.
*/
int PCSetup_GBSSOR(PC pc)
{
  Mat base_mat = pc->mat;
  Mat_MPIAIJ *Aij = (Mat_MPIAIJ *) base_mat->data;
  PC_GenBlockSSOR_struct *pc_data = (PC_GenBlockSSOR_struct *) pc->data;
  int ierr;

  /* Determine local coefficient matrix */
  if (pc_data->global_factorisation) {
    /* Case: global algorithm; receive rows from preceding processors */
    Mat fill_mat,get_mat,fill_dia;
    MatGatherCtx ectx;

    /* Work on a copy of the original matrix.  The copy is only needed in
     * this branch, so it is created (and destroyed) here. */
    ierr = MatConvert(base_mat,MATSAME,&fill_mat); CHKERRQ(ierr);
    ierr = MatSetOption(fill_mat,MAT_COLUMNS_SORTED); CHKERRQ(ierr);

    /* Set up the row elimination structure; the requested index set is
     * Aij->garray, i.e. the global indices of the off-diagonal columns */
    {
      Mat_SeqAIJ *off = (Mat_SeqAIJ *) Aij->B->data;
      IS wanted;

      ierr = ISCreateGeneral
        (MPI_COMM_SELF,off->n,Aij->garray,&wanted); CHKERRQ(ierr);
      ierr = MatGatherCtxCreate(fill_mat,wanted,&ectx); CHKERRQ(ierr);
      ierr = VecPipelineSetCustomPipelineFromPCPstruct
        ((VecPipeline)ectx->vs,&(pc_data->par_info)); CHKERRQ(ierr);
      ierr = ISDestroy(wanted); CHKERRQ(ierr);
    }

    ierr = MatGatherRowsPipelineBegin
      (fill_mat,PIPELINE_CUSTOM_UP,ectx,&get_mat);
    CHKERRQ(ierr);

    /* Compute a new diagonal block by filling the gathered rows
     * onto the old diagonal block */
    ierr = BlockFillDiag(fill_mat,get_mat,&fill_dia); CHKERRQ(ierr);

    /* Store the new diagonal block as preconditioner operator.
     * NOTE(review): fill_dia is not destroyed; it is installed as the local
     * operator and presumably must outlive set-up.  Nothing in this file
     * releases it later -- confirm ownership. */
    ierr = BlockSetPrec(pc_data,fill_dia); CHKERRQ(ierr);

    /* Transform the off-diagonal block of fill_mat; presumably these
     * transformed rows are what MatGatherRowsPipelineEnd passes on along
     * the pipeline -- confirm. */
    ierr = BlockTransformOffBlock(pc_data,fill_mat); CHKERRQ(ierr);

    ierr = MatGatherRowsPipelineEnd
      (fill_mat,PIPELINE_CUSTOM_UP,ectx,&get_mat);
    CHKERRQ(ierr);
    ierr = MatDestroy(get_mat); CHKERRQ(ierr);
    ierr = MatDestroy(fill_mat); CHKERRQ(ierr);
    ierr = MatGatherCtxDestroy(ectx); CHKERRQ(ierr);
  } else {
    /* Case: no global algorithm; simply use the diagonal block */
    ierr = BlockSetPrec(pc_data,Aij->A); CHKERRQ(ierr);
  }

  /* Create the pipeline used by PCApply_GBSSOR to move off-processor
   * values (columns Aij->garray) into Aij->lvec.
   * NOTE(review): tmp_g, is_g and is_l are never destroyed; whether
   * VecPipelineCreate keeps references to them is not visible here. */
  {
    Vec tmp_g; IS is_g,is_l; int len,dum, b=((Mat_SeqAIJ *)Aij->B->data)->n;
    ierr = MatGetLocalSize(base_mat,&len,&dum); CHKERRQ(ierr);
    ierr = VecCreateMPI(base_mat->comm,len,PETSC_DECIDE,&tmp_g); CHKERRQ(ierr);
    ierr = ISCreateGeneral(base_mat->comm,b,Aij->garray,&is_g); CHKERRQ(ierr);
    ierr = ISCreateStride(MPI_COMM_SELF,b,0,1,&is_l); CHKERRQ(ierr);
    ierr = VecPipelineCreate
      (tmp_g,is_g,Aij->lvec,is_l,&pc_data->par_info.main_pipe);
    CHKERRQ(ierr);
  }
  ierr = VecPipelineSetCustomPipelineFromPCPstruct
    (pc_data->par_info.main_pipe,&(pc_data->par_info)); CHKERRQ(ierr);

  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCApply_GBSSOR"
/*
   PCApply_GBSSOR - applies the GBSSOR preconditioner to x, result in y.

   Two pipelined sweeps over the processors, both using par_info.main_pipe
   to deliver off-processor values into Aij->lvec.  lvec is zeroed before
   each sweep, so the off-diagonal product only sees delivered values:

     forward  (PIPELINE_CUSTOM_UP):   tmp = x - B*lvec,   tmq = solve(tmp)
     backward (PIPELINE_CUSTOM_DOWN): tmp = tmp - B*lvec, y   = solve(tmp)

   B is the off-diagonal block of pc->mat, and "solve" is one SLESSolve with
   the local solver par_info.local_method.  The AXBY form above is presumed
   from MatMultAXBY_AIJ's argument order (a, A, x, b, y, z) -- confirm.
*/
static int PCApply_GBSSOR(PC pc,Vec x,Vec y)
{
  PCPstruct *pc_data = &((PC_GenBlockSSOR_struct *) pc->data)->par_info;
  Mat base_mat = pc->mat;
  Mat_MPIAIJ *Aij = (Mat_MPIAIJ *) (base_mat->data);
  Mat off_diag = Aij->B;
  int ierr,its;
  Vec tmp,tmq;
  Scalar zero = 0.0, mone = -1.0, one = 1.0;

  ierr = PetscPushErrorHandler(&ParPreTraceBackErrorHandler,MPI_COMM_WORLD);
  CHKERRQ(ierr);

  { /* sequential work vectors with the local length of x.
       NOTE(review): these could be created once in set-up instead of on
       every application. */
    int len;
    ierr = VecGetLocalSize(x,&len); CHKERRQ(ierr);
    ierr = VecCreateSeq(MPI_COMM_SELF,len,&tmp); CHKERRQ(ierr);
    ierr = VecDuplicate(tmp,&tmq); CHKERRQ(ierr);
  }

  /* forward sweep */
  ierr = VecSet(&zero,Aij->lvec); CHKERRQ(ierr);
  ierr = VecPipelineBegin
    (tmq,Aij->lvec,INSERT_VALUES, PIPELINE_CUSTOM_UP,pc_data->main_pipe);
  CHKERRQ(ierr);

  ierr = MatMultAXBY_AIJ(mone,off_diag,Aij->lvec,one,x,tmp); CHKERRQ(ierr);
  ierr = SLESSolve(pc_data->local_method,tmp,tmq,&its); CHKERRQ(ierr);

  ierr = VecPipelineEnd
    (tmq,Aij->lvec,INSERT_VALUES, PIPELINE_CUSTOM_UP,pc_data->main_pipe);
  CHKERRQ(ierr);

  /* backward sweep */
  ierr = VecSet(&zero,Aij->lvec); CHKERRQ(ierr);
  ierr = VecPipelineBegin
    (y,Aij->lvec,INSERT_VALUES, PIPELINE_CUSTOM_DOWN,pc_data->main_pipe);
  CHKERRQ(ierr);

  ierr = MatMultAXBY_AIJ(mone,off_diag,Aij->lvec,one,tmp,tmp); CHKERRQ(ierr);

  { /* temporarily point tmq at y's storage so the local solve writes into y */
    Scalar *y_ar,*t_ar;
    ierr = VecGetArray(tmq,&t_ar); CHKERRQ(ierr);
    ierr = VecGetArray(y,&y_ar); CHKERRQ(ierr);
    ierr = VecPlaceArray(tmq,y_ar); CHKERRQ(ierr);
    ierr = SLESSolve(pc_data->local_method,tmp,tmq,&its); CHKERRQ(ierr);
    ierr = VecPlaceArray(tmq,t_ar); CHKERRQ(ierr);
    /* balance the two VecGetArray calls above */
    ierr = VecRestoreArray(y,&y_ar); CHKERRQ(ierr);
    ierr = VecRestoreArray(tmq,&t_ar); CHKERRQ(ierr);
  }

  ierr = VecPipelineEnd
    (y,Aij->lvec,INSERT_VALUES, PIPELINE_CUSTOM_DOWN,pc_data->main_pipe);
  CHKERRQ(ierr);

  ierr = VecDestroy(tmp); CHKERRQ(ierr);
  ierr = VecDestroy(tmq); CHKERRQ(ierr);

  ierr = PetscPopErrorHandler(); CHKERRQ(ierr);
  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCApplyrich_GBSSOR"
/*
   PCApplyrich_GBSSOR - Richardson hook for GBSSOR.

   NOTE(review): this is a stub.  It returns success without reading b or
   writing y, so any caller that relies on it to perform Richardson steps
   gets no update.
*/
static int PCApplyrich_GBSSOR(PC pc,Vec b,Vec y,Vec w,int its)
{
  /* no Richardson iteration is implemented; all arguments are unused */
  (void) pc; (void) b; (void) y; (void) w; (void) its;
  return 0;
}

/* Option handling is delegated entirely to LocalSolveSetFromOptions().
   NOTE(review): arguments of the form -lap [symmetric,forward,back][omega=...]
   are not parsed in this file; if they are supported, it is presumably by
   LocalSolveSetFromOptions -- confirm. */
#undef __FUNC__
#define __FUNC__ "PCSetFromOptions_GBSSOR"
static int PCSetFromOptions_GBSSOR(PC pc)
{
  int ierr = LocalSolveSetFromOptions(pc);

  CHKERRQ(ierr);
  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCDestroy_GBSSOR"
/*
   PCDestroy_GBSSOR - releases the GBSSOR private data.

   Destroys the local solver, then frees the PC_GenBlockSSOR_struct that
   PCCreate_GBSSOR allocated with PetscNew() and stored in pc->data.

   NOTE(review): par_info.main_pipe (created in PCSetup_GBSSOR), the local
   operator installed by BlockSetPrec in the global-factorisation case, and
   whatever PCParallelInitCommStruct sets up are not released here; no
   destroy routine for them is visible in this file.
*/
int PCDestroy_GBSSOR(PetscObject obj)
{
  PC pc = (PC) obj;
  PC_GenBlockSSOR_struct *pc_data = (PC_GenBlockSSOR_struct *) pc->data;
  int ierr;

  ierr = SLESDestroy(pc_data->par_info.local_method); CHKERRQ(ierr);

  /* the type-specific destroy owns pc->data; free it after its last use */
  PetscFree(pc_data);
  pc->data = 0;
  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCCreate_GBSSOR"
/*
   PCCreate_GBSSOR - constructor for the GBSSOR preconditioner type.

   Installs the type's operations and allocates the private data.  It then
   creates the local solver (par_info.local_method) on MPI_COMM_SELF and
   configures its PC as one symmetric SOR sweep.  Global factorisation is
   left off (see PCGenBlockSSORSetGlobalFactorisation), and the parallel
   communication structure is initialised.
*/
int PCCreate_GBSSOR(PC pc)
{
  int ierr;
  PC_GenBlockSSOR_struct *bij;

  pc->apply     = PCApply_GBSSOR;
  /* PCApplyrich_GBSSOR is only a stub that returns without computing y.
   * Installing it would make Richardson iterations silently skip the
   * preconditioner, so no specialised Richardson routine is registered
   * (assumes PETSc treats applyrich == 0 as "not provided", as its own
   * PC types do). */
  pc->applyrich = 0;
  pc->destroy   = PCDestroy_GBSSOR;
  pc->setfrom   = PCSetFromOptions_GBSSOR;
  pc->printhelp = 0;
  pc->setup     = PCSetup_GBSSOR;
  pc->type      = PCGenBlockSSOR;

  bij = PetscNew(PC_GenBlockSSOR_struct); CHKPTRQ(bij);
  ierr = PCParallelInstallSubSolve
    (MPI_COMM_SELF,&(bij->par_info.local_method)); CHKERRQ(ierr);
  {
    /* local solver: a single symmetric SOR sweep */
    PC local_pc;
    ierr = SLESGetPC(bij->par_info.local_method,&local_pc); CHKERRQ(ierr);
    ierr = PCSetType(local_pc,PCSOR); CHKERRQ(ierr);
    ierr = PCSORSetIterations(local_pc,1); CHKERRQ(ierr);
    ierr = PCSORSetSymmetric(local_pc,SOR_SYMMETRIC_SWEEP); CHKERRQ(ierr);
  }
  bij->global_factorisation = 0;
  pc->data      = (void *) bij;

  ierr = PCParallelInitCommStruct(pc); CHKERRQ(ierr);

  return 0;
}
