#include "petsclog.h"
#include "src/vec/vecimpl.h"
#include "src/mat/impls/aij/mpi/mpiaij.h"
#include "src/vec/utils/vpipe.h"
#include "parpre_pc.h"
#include "src/mat/impls/aij/mpi/mpixtra.h"
#include "pcschwarz.h"
#include "pcextra.h"

extern int PCIsSchwarzMethod(PC pc); /* Shouldn't be: it's in parpre_pc.h */
extern int LocalSolveSetFromOptions(PC pc);
extern int VecPipelineCreate(Vec xin,IS ix,Vec yin,IS iy,VecPipeline *newctx);
extern int MatMaxRowLen_MPIAIJ(Mat A, int *rowlen);
extern int PCParallelInitLocalMethod(PCPstruct *pc_data, Mat mat, Vec vec);
extern int ParPreTraceBackErrorHandler
    (int,char*,char*,char*,int,int,char*,void*);

/****************************************************************
 * Auxiliary routines for Schwarz preconditioners               *
****************************************************************/
#undef __FUNC__
#define __FUNC__ "PCSchwarzSetHaloSize"
/*
 * PCSchwarzSetHaloSize - set the number of halo (overlap) layers that
 * PCSetup_Schwarz adds around the locally owned rows.
 *
 *   pc - preconditioner; must be a Schwarz method (PCIsSchwarzMethod)
 *   ns - halo width, must be >= 0 (0 means no overlap layers)
 *
 * Returns 0 on success, a PETSc error code otherwise.
 */
int PCSchwarzSetHaloSize(PC pc, int ns)
{
  PC_Schwarz_struct *pc_data;
  int ierr;

  /* Validate before pushing the ParPre handler: SETERRQ returns
   * immediately, so checking afterwards would leave the handler pushed. */
  if (!PCIsSchwarzMethod(pc))
    SETERRQ(1,0,"Can only set halo for Schwarz methods");
  if (ns<0) SETERRQ(1,0,"Schwarz method halo should be non-negative");

  ierr = PetscPushErrorHandler(&ParPreTraceBackErrorHandler,MPI_COMM_WORLD);
  CHKERRQ(ierr);
  pc_data = (PC_Schwarz_struct *) pc->data;
  pc_data->halo_width = ns;
  ierr = PetscPopErrorHandler(); CHKERRQ(ierr);
  return 0;
}

#undef __FUNC__
#define __FUNC__ "ExtendMatrix"
/*
 * ExtendMatrix - grow the local Schwarz matrix by one halo layer.
 *
 * In/out:
 *   Schwarz_mat        - sequential matrix with *nSchwarz_rows rows and
 *                        base_mat's global width; it is destroyed and
 *                        replaced by a matrix that also holds the new rows
 *   nSchwarz_rows,
 *   Schwarz_rows       - local -> global row map; the global rows listed in
 *                        *result_rows_to_get are appended to it
 *   result_rows_to_get - on entry: the global rows to fetch for this layer
 *                        (this stash is destroyed here);
 *                        on exit: a new stash with the global indices that
 *                        the fetched rows reference but that are not yet
 *                        local, i.e. the rows for the next layer
 * In:
 *   base_mat           - distributed MPIAIJ matrix the rows are taken from
 *   comm, glob_max_rowlen, off_max_rowlen - currently unused
 *
 * Column convention of the result: a column c < *nSchwarz_rows refers to
 * local row c; a column c >= *nSchwarz_rows refers to entry
 * (c - *nSchwarz_rows) of the returned stash.
 */
int ExtendMatrix(Mat *Schwarz_mat,Mat base_mat,
		 int *nSchwarz_rows, int **Schwarz_rows,
		 IndStash *result_rows_to_get,
		 MPI_Comm comm,
		 int glob_max_rowlen,int off_max_rowlen)
{
  Mat_MPIAIJ *Aij = (Mat_MPIAIJ *) base_mat->data;
  int ierr,*t_rows,nt_rows;
  int N = Aij->N;
  Mat tmat;
  IndStash prev_rows_to_get = *result_rows_to_get, rows_to_get;
  int nrows, row;
  MatGatherCtx gs; Mat get_mat=0;

  /* Step 1: a bigger matrix containing the old rows, with room for the
   * new rows after them */
  nt_rows = *nSchwarz_rows+prev_rows_to_get->n;
  /* +1: avoid a zero-size allocation when there are no rows at all */
  t_rows = (int *) PetscMalloc( nt_rows*sizeof(int)+1 ); CHKPTRQ(t_rows);
  PetscMemcpy(t_rows,*Schwarz_rows,*nSchwarz_rows*sizeof(int));
  PetscMemcpy(t_rows+*nSchwarz_rows,
	      prev_rows_to_get->array,prev_rows_to_get->n*sizeof(int));

  ierr = MatCreateSeqAIJ(MPI_COMM_SELF,nt_rows,N,5,0,&tmat); CHKERRQ(ierr);
  {
    int i,j,idx;
    for ( i=0; i<*nSchwarz_rows; i++ ) {
      int ncol,*cols;
      Scalar *vals,v;

      ierr = MatGetRow(*Schwarz_mat,i,&ncol,&cols,&vals); CHKERRQ(ierr);
      for (idx=0; idx<ncol; idx++) {
	j = cols[idx]; v = vals[idx];
	ierr = MatSetValues(tmat,1,&i,1,&j,&v,INSERT_VALUES); CHKERRQ(ierr);
      }
      ierr = MatRestoreRow(*Schwarz_mat,i,&ncol,&cols,&vals); CHKERRQ(ierr);
    }
  }

  /* install the new row map and matrix */
  PetscFree(*Schwarz_rows);
  *nSchwarz_rows = nt_rows; *Schwarz_rows = t_rows;
  ierr = MatDestroy(*Schwarz_mat); CHKERRQ(ierr);
  *Schwarz_mat = tmat;

  /* Step 2: gather the requested rows of base_mat */
  nrows = prev_rows_to_get->n;
  if (nrows>0) {
    IS wanted;
    ierr = ISCreateGeneral
      (MPI_COMM_SELF,nrows,prev_rows_to_get->array,&wanted);
    CHKERRQ(ierr);
    ierr = MatGatherCtxCreate(base_mat,wanted,&gs); CHKERRQ(ierr);
    ierr = MatGatherRows(base_mat,gs,&get_mat); CHKERRQ(ierr);
    ierr = MatGatherCtxDestroy(gs); CHKERRQ(ierr);
    ierr = ISDestroy(wanted); CHKERRQ(ierr);
  }

  /* Step 3: insert the gathered rows. Column indices that are not yet
   * local are stashed as rows to get for the next halo layer. */
  ierr = NewIndexStash(&rows_to_get); CHKERRQ(ierr);
  for (row=0; row<nrows; row++) {
    int global_row,local_row,rowlen,j,k;
    int *this_row; Scalar *this_val;

    ierr = MatGetRow(get_mat,row,&rowlen,&this_row,&this_val);
    CHKERRQ(ierr);
    global_row = prev_rows_to_get->array[row]; local_row = -1;
    for (k=0; k<*nSchwarz_rows; k++) {/* global -> local */
      if ((*Schwarz_rows)[k] == global_row) {local_row = k; break;}
    }
    if (local_row==-1) SETERRQ(1,0,"Could not find global row locally");

    for (j=0; j<rowlen; j++) {
      int idx = this_row[j], col = -1;
      Scalar val = this_val[j];

      for (k=0; k<*nSchwarz_rows; k++) {/* global -> local */
	if ((*Schwarz_rows)[k] == idx) {col = k; break;}
      }
      if (col == -1) {
	/* Not local yet. If an earlier row already stashed this index,
	 * reuse its column; the "last stash slot" would be wrong then. */
	for (k=0; k<rows_to_get->n; k++) {
	  if (rows_to_get->array[k] == idx) {col = *nSchwarz_rows+k; break;}
	}
	if (col == -1) {
	  ierr = StashIndex(rows_to_get,1,&idx); CHKERRQ(ierr);
	  /* assumes StashIndex appends at the end of the array */
	  col = *nSchwarz_rows+rows_to_get->n-1;
	}
      }
      ierr = MatSetValues
	(*Schwarz_mat,1,&local_row,1,&col,&val,INSERT_VALUES);
      CHKERRQ(ierr);
    }
    ierr = MatRestoreRow(get_mat,row,&rowlen,&this_row,&this_val);
    CHKERRQ(ierr);
  }

  if (get_mat) {
    ierr = MatDestroy(get_mat); CHKERRQ(ierr);
  }
  ierr = MatAssemblyBegin(*Schwarz_mat,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd(*Schwarz_mat,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);

  /* the entry stash has been fully consumed (its indices were copied into
   * t_rows and the IS); hand back the stash for the next layer instead */
  ierr = DestroyIndexStash(prev_rows_to_get); CHKERRQ(ierr);
  *result_rows_to_get = rows_to_get;

  return 0;
 }

/*
  {
    int irow,row,icol,col; Scalar val;
    for (irow=0; irow<nrows; irow++) {
      row = xrows[irow];
      printf("Received row %d=%d",irow,row);
      for (icol=0; icol<ncols[irow]; icol++) {
	col = cols[irow][icol];
	val = vals[irow][icol];
	printf(" elt %d=%e",col,val);
      }
      printf("\n");
    }
  }
*/

#undef __FUNC__
#define __FUNC__ "PCSetFromOptions_Schwarz"
/*
 * PCSetFromOptions_Schwarz - read Schwarz-specific options from the
 * options database (using pc->prefix):
 *   -pc_halo_size <n> : halo width, forwarded to PCSchwarzSetHaloSize
 * then lets the local solver read its own options.
 */
int PCSetFromOptions_Schwarz(PC pc)
{
  int halo, have_halo, ierr;

  ierr = OptionsGetInt(pc->prefix,"-pc_halo_size",&halo,&have_halo);
  CHKERRQ(ierr);
  if (have_halo) {ierr = PCSchwarzSetHaloSize(pc,halo); CHKERRQ(ierr);}

  ierr = LocalSolveSetFromOptions(pc); CHKERRQ(ierr);
  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCSetup_Schwarz"
/*
 * PCSetup_Schwarz - build the data of an overlapping Schwarz preconditioner
 * from the MPIAIJ operator pc->mat.  Stores in pc->data:
 *   Schwarz_rows/nSchwarz_rows - global indices of the local domain rows:
 *                  the owned rows first, then halo_width layers of
 *                  neighbouring rows (see ExtendMatrix)
 *   Schwarz_mat  - square sequential matrix on those rows
 *   Schwarz_off  - sequential matrix with the couplings from the domain to
 *                  the next, not included, layer (the "edge")
 *   domain_vec, edge_vec and the scatter / pipeline contexts that move
 *                  data between a global vector and the domain, interior,
 *                  rim and edge parts
 * and finally sets up the local solver on Schwarz_mat.
 * Returns 0 on success, a PETSc error code otherwise.
 */
int PCSetup_Schwarz(PC pc)
{
  Mat base_mat = pc->mat;
  MPI_Comm    comm = base_mat->comm;
  PC_Schwarz_struct *pc_data = (PC_Schwarz_struct *) pc->data;
  Mat_MPIAIJ *Aij;
  Mat dia_mat, off_mat;
  Mat_SeqAIJ *bij;
  int         ierr;
  int k,glob_max_rowlen;
  int nhalos = pc_data->halo_width;
  IndStash rows_to_get;

  /* Check the type before base_mat->data is interpreted as MPIAIJ data.
   * Doing it before the handler push keeps the handler stack balanced. */
  if (base_mat->type != MATMPIAIJ) {
    SETERRQ(1,0,"Overlapping preconditioner only implemented for AIJMPI\n");
  }
  Aij = (Mat_MPIAIJ *) base_mat->data;
  dia_mat = Aij->A; off_mat = Aij->B;
  bij = (Mat_SeqAIJ *) off_mat->data;

  ierr = PetscPushErrorHandler(&ParPreTraceBackErrorHandler,MPI_COMM_WORLD);
  CHKERRQ(ierr);

  /* >>>> Schwarz Matrix Construction <<<< */

  /* the Schwarz rows are initially the locally owned rows */
  pc_data->nSchwarz_rows = Aij->m;
  /* +1: avoid a zero-size allocation on a process without rows */
  pc_data->Schwarz_rows =
    (int *) PetscMalloc(pc_data->nSchwarz_rows*sizeof(int)+1);
  CHKPTRQ(pc_data->Schwarz_rows);
  for (k=0; k<Aij->m; k++)
    pc_data->Schwarz_rows[k] = Aij->rstart+k;

  /* the Schwarz matrix first of all contains the local matrix */
  ierr = MatCreateSeqAIJ
    (MPI_COMM_SELF,Aij->m,Aij->N,5,0,&pc_data->Schwarz_mat);
  CHKERRQ(ierr);

  {
    int i,j,col,ncols,*cols; Scalar *vals;
    for ( i=0; i<Aij->m; i++ ) {
      ierr = MatGetRow(dia_mat,i,&ncols,&cols,&vals); CHKERRQ(ierr);
      ierr = MatSetValues
	(pc_data->Schwarz_mat,1,&i,ncols,cols,vals,INSERT_VALUES);
      CHKERRQ(ierr);
      ierr = MatRestoreRow(dia_mat,i,&ncols,&cols,&vals); CHKERRQ(ierr);

      /* Off-diagonal column c goes to local column Aij->m+c, the position
       * Aij->garray[c] gets in the stash of rows to fetch below.  The
       * shifted index is computed in a local instead of writing into the
       * array returned by MatGetRow, which must not be modified. */
      ierr = MatGetRow(off_mat,i,&ncols,&cols,&vals); CHKERRQ(ierr);
      for (j=0; j<ncols; j++) {
	col = cols[j]+Aij->m;
	ierr = MatSetValues
	  (pc_data->Schwarz_mat,1,&i,1,&col,&vals[j],INSERT_VALUES);
	CHKERRQ(ierr);
      }
      ierr = MatRestoreRow(off_mat,i,&ncols,&cols,&vals); CHKERRQ(ierr);
    }
  }

  ierr = MatAssemblyBegin(pc_data->Schwarz_mat,MAT_FINAL_ASSEMBLY);
  CHKERRQ(ierr);
  ierr = MatAssemblyEnd(pc_data->Schwarz_mat,MAT_FINAL_ASSEMBLY);
  CHKERRQ(ierr);

  {
    int len;
    ierr = MatMaxRowLen_MPIAIJ(base_mat,&len); CHKERRQ(ierr);
    MPI_Allreduce(&len,&glob_max_rowlen,1,MPI_INT,MPI_MAX,comm);
  }

  /* /// Major loop over the halo width /// */
  ierr = NewIndexStash(&rows_to_get); CHKERRQ(ierr);
  ierr = StashIndex(rows_to_get,bij->n,Aij->garray); CHKERRQ(ierr);

  /* "> 0" so that a negative width cannot make the loop run away */
  while (nhalos-- > 0) {
    ierr = ExtendMatrix
      (&(pc_data->Schwarz_mat),base_mat,
       &(pc_data->nSchwarz_rows),&(pc_data->Schwarz_rows),
       &rows_to_get,comm,glob_max_rowlen,bij->n);
    CHKERRQ(ierr);
  }

  {
    int i,j,idx,nt_rows = pc_data->nSchwarz_rows;
    Mat tmat;

    /* At this point the Schwarz matrix contains everything we need,
     * but it is Aij->N wide, ie, it is not square.  Split it into the
     * square domain part and the couplings to the edge. */
    tmat = pc_data->Schwarz_mat;
    ierr = MatCreateSeqAIJ
      (MPI_COMM_SELF,nt_rows,nt_rows,5,0,&pc_data->Schwarz_mat);
    CHKERRQ(ierr);
    ierr = MatCreateSeqAIJ
      (MPI_COMM_SELF,nt_rows,Aij->N-nt_rows,3,0,&pc_data->Schwarz_off);
    CHKERRQ(ierr);

    for ( i=0; i<pc_data->nSchwarz_rows; i++ ) {
      int loc,ncol,*cols; Scalar *vals,v;

      ierr = MatGetRow(tmat,i,&ncol,&cols,&vals); CHKERRQ(ierr);
      for (idx=0; idx<ncol; idx++) {
	j = cols[idx]; v = vals[idx];
	if (j<pc_data->nSchwarz_rows) {
	  ierr = MatSetValues
	    (pc_data->Schwarz_mat,1,&i,1,&j,&v,INSERT_VALUES);
	  CHKERRQ(ierr);
	} else {
	  loc = j-pc_data->nSchwarz_rows;
	  ierr = MatSetValues
	    (pc_data->Schwarz_off,1,&i,1,&loc,&v,INSERT_VALUES);
	  CHKERRQ(ierr);
	}
      }
      ierr = MatRestoreRow(tmat,i,&ncol,&cols,&vals); CHKERRQ(ierr);
    }
    ierr = MatDestroy(tmat); CHKERRQ(ierr);
  }

  ierr = MatAssemblyBegin
    (pc_data->Schwarz_mat,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd
    (pc_data->Schwarz_mat,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);

  ierr = MatAssemblyBegin
    (pc_data->Schwarz_off,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd
    (pc_data->Schwarz_off,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);

  /* >>>> scatter and gather into the halo part <<<< */

  {IS int_in_glob,rim_in_glob,edge_in_glob, int_in_loc,rim_in_loc;
   IS from_domain,to_domain; IS edge_itself;
   Vec gvec;
   /* Schwarz_rows holds the Aij->m owned rows first (set above), so the
    * interior / rim split is at Aij->m */
   int int_size = Aij->m;
   int rim_size = pc_data->nSchwarz_rows-Aij->m;
   int edge_size = rows_to_get->n;

   /* vectors: global, local extended, and edge */
   ierr = VecCreateMPI(comm,Aij->n,Aij->N,&gvec);
   CHKERRQ(ierr);
   ierr = VecCreateSeq
     (MPI_COMM_SELF,pc_data->nSchwarz_rows,&pc_data->domain_vec);
   CHKERRQ(ierr);
   ierr = VecCreateSeq
     (MPI_COMM_SELF,edge_size,&pc_data->edge_vec);
   CHKERRQ(ierr);

   /* index sets: rim/int in global, extended, and separate vectors */
   ierr = ISCreateGeneral
     (comm,int_size,pc_data->Schwarz_rows,&int_in_glob);
   CHKERRQ(ierr);
   ierr = ISCreateGeneral
     (comm,rim_size,&(pc_data->Schwarz_rows[int_size]),&rim_in_glob);
   CHKERRQ(ierr);
   ierr = ISCreateGeneral
     (comm,edge_size,rows_to_get->array,&edge_in_glob);
   CHKERRQ(ierr);

   ierr = ISCreateStride
     (MPI_COMM_SELF,rim_size,int_size,1,&rim_in_loc);
   CHKERRQ(ierr);
   ierr = ISCreateStride
     (MPI_COMM_SELF,int_size,0,1,&int_in_loc);
   CHKERRQ(ierr);

   /* index sets describing extended domain */
   ierr = ISCreateGeneral
     (comm,pc_data->nSchwarz_rows,pc_data->Schwarz_rows,&from_domain);
   CHKERRQ(ierr);
   ierr = ISCreateStride
     (MPI_COMM_SELF,pc_data->nSchwarz_rows,0,1,&to_domain);
   CHKERRQ(ierr);

   /* make scatter contexts for the extended domain */
   ierr = VecScatterCreate
     (gvec,from_domain,pc_data->domain_vec,to_domain,&pc_data->get_xdomain);
   CHKERRQ(ierr);
   ierr = VecScatterCreate
     (pc_data->domain_vec,to_domain,gvec,from_domain,&pc_data->put_xdomain);
   CHKERRQ(ierr);
   ierr = ISDestroy(from_domain); CHKERRQ(ierr);
   ierr = ISDestroy(to_domain); CHKERRQ(ierr);

   /* make scatter contexts for the rim / internal */
   ierr = VecPipelineCreate
     (gvec,rim_in_glob,pc_data->domain_vec,rim_in_loc,&pc_data->lift_xrim);
   CHKERRQ(ierr);
   ierr = VecPipelineCreate
     (pc_data->domain_vec,rim_in_loc,gvec,rim_in_glob,&pc_data->drop_xrim);
   CHKERRQ(ierr);

   ierr = VecScatterCreate
     (gvec,int_in_glob,pc_data->domain_vec,int_in_loc,&pc_data->lift_xint);
   CHKERRQ(ierr);
   ierr = VecScatterCreate
     (pc_data->domain_vec,int_in_loc,gvec,int_in_glob,&pc_data->drop_xint);
   CHKERRQ(ierr);

   /* scatter context for the edge */
   ierr = ISCreateStride(MPI_COMM_SELF,edge_size,0,1,&edge_itself);
   CHKERRQ(ierr);
   ierr = VecPipelineCreate
     (gvec,edge_in_glob,pc_data->edge_vec,edge_itself,&pc_data->lift_edge);
   CHKERRQ(ierr);
   /* NOTE(review): only lift_edge and drop_xrim get the custom pipeline,
    * lift_xrim and drop_edge do not -- confirm this is intended */
   ierr = VecPipelineSetCustomPipelineFromPCPstruct
     (pc_data->lift_edge,&(pc_data->comm_method)); CHKERRQ(ierr);
   ierr = VecPipelineSetCustomPipelineFromPCPstruct
     (pc_data->drop_xrim,&(pc_data->comm_method)); CHKERRQ(ierr);
   ierr = VecPipelineCreate
     (pc_data->edge_vec,edge_itself,gvec,edge_in_glob,&pc_data->drop_edge);
   CHKERRQ(ierr);

   ierr = ISDestroy(int_in_glob); CHKERRQ(ierr);
   ierr = ISDestroy(rim_in_glob); CHKERRQ(ierr);
   ierr = ISDestroy(edge_in_glob); CHKERRQ(ierr);
   ierr = ISDestroy(int_in_loc); CHKERRQ(ierr);
   ierr = ISDestroy(rim_in_loc); CHKERRQ(ierr);
   ierr = ISDestroy(edge_itself); CHKERRQ(ierr);
   ierr = VecDestroy(gvec); CHKERRQ(ierr);
 }

  ierr = DestroyIndexStash(rows_to_get); CHKERRQ(ierr);

  /* Initialise the local solution method */
  ierr = SLESSetOperators
    (pc_data->comm_method.local_method,
     pc_data->Schwarz_mat,pc_data->Schwarz_mat,0);
  CHKERRQ(ierr);
  {
    PC local_pc;
    ierr = PCParallelGetLocalPC(pc,&local_pc); CHKERRQ(ierr);
    ierr = PCSetVector(local_pc,pc_data->domain_vec); CHKERRQ(ierr);
    ierr = PCSetUp(local_pc); CHKERRQ(ierr);
  }

  ierr = PetscPopErrorHandler(); CHKERRQ(ierr);

  return 0;
}

