#include "petsclog.h"
#include "parpre_pc.h"
#include "ksp.h"
#include "src/vec/vecimpl.h"
#include "src/mat/impls/aij/mpi/mpiaij.h"
#include "parpre_subdomains.h"
#include "pcschwarz.h"

extern int MatMaxRowLen_MPIAIJ(Mat A, int *rowlen);
extern int MatIncreaseOverlap(Mat mat,int n, IS *is,int ov);

extern int ParPreTraceBackErrorHandler
    (int,char*,char*,char*,int,int,char*,void*);

/****************************************************************
 * Auxiliary routines for Schwarz preconditioners               *
****************************************************************/
#undef __FUNC__
#define __FUNC__ "PCSchwarzSetHaloSize"
/* Worker for PCSchwarzSetHaloSize: validate the arguments and store the
   halo width in the Schwarz data. Called with ParPreTraceBackErrorHandler
   pushed; returns a nonzero error code on invalid input. */
static int PCSchwarzSetHaloSize_Private(PC pc, int ns)
{
  PC_Schwarz_struct *pc_data;

  if (!PCIsSchwarzMethod(pc))
    SETERRQ(1,0,"Can only set halo for Schwarz methods");
  if (ns<0) SETERRQ(1,0,"Schwarz method halo should be non-negative");

  pc_data = (PC_Schwarz_struct *) pc->data;
  pc_data->halo_width = ns;
  return 0;
}

/* PCSchwarzSetHaloSize - set the halo width of a Schwarz preconditioner.

   Input: pc - preconditioner; must be a Schwarz method
          ns - halo width, must be >= 0; it is later passed as the overlap
               count to MatIncreaseOverlap when the subdomains are built
   Returns 0 on success, a nonzero error code otherwise.

   The ParPre traceback handler is pushed for the duration of the call and
   popped on every path (including errors), so the handler stack is left
   balanced for the caller. */
int PCSchwarzSetHaloSize(PC pc, int ns)
{
  int ierr,ierr2;

  ierr = PetscPushErrorHandler(&ParPreTraceBackErrorHandler,0);
  CHKERRQ(ierr);
  ierr  = PCSchwarzSetHaloSize_Private(pc,ns);
  ierr2 = PetscPopErrorHandler();
  CHKERRQ(ierr);
  CHKERRQ(ierr2);
  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCSchwarzGetIndices"
/* PCSchwarzGetIndices - return the index set of global rows forming this
   processor's overlapping Schwarz subdomain.
   The IS is returned by reference and is destroyed by PCDestroy_Schwarz,
   so the caller should not destroy it. */
int PCSchwarzGetIndices(PC pc,IS *is)
{
  int is_schwarz = PCIsSchwarzMethod(pc);

  if (!is_schwarz) {
    SETERRQ(1,0,"Trying to get Schwarz indices for non-Schwarz method");
  }
  *is = ((PC_Schwarz_struct *) pc->data)->Schwarz_rows;
  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCSetFromOptions_Schwarz"
/* PCSetFromOptions_Schwarz - read the Schwarz runtime options
     -pc_halo_size <n> : halo width, forwarded to PCSchwarzSetHaloSize
   and then let the parallel local-solve layer read its own options. */
int PCSetFromOptions_Schwarz(PC pc)
{
  int halo,found,ierr;

  ierr = OptionsGetInt(pc->prefix,"-pc_halo_size",&halo,&found);
  CHKERRQ(ierr);
  if (found) {
    ierr = PCSchwarzSetHaloSize(pc,halo);
    CHKERRQ(ierr);
  }

  ierr = PCParallelLocalSolveSetFromOptions(pc);
  CHKERRQ(ierr);
  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCSchwarzGetLocalSLES"
/* PCSchwarzGetLocalSLES - return the local subdomain solver.
   The query is forwarded to the PC of the inner solver (ssor_solver),
   which is where the local SLES lives. */
int PCSchwarzGetLocalSLES(PC schwarz_pc,SLES *local_sles)
{
  SLES inner_sles = ((PC_Schwarz_struct *) schwarz_pc->data)->ssor_solver;
  PC inner_pc;
  int ierr;

  ierr = SLESGetPC(inner_sles,&inner_pc);
  CHKERRQ(ierr);
  ierr = PCParallelGetLocalSLES(inner_pc,local_sles);
  CHKERRQ(ierr);
  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCSchwarzSetPipeType"
/* PCSchwarzSetPipeType - forward a subdomain pipeline type (with its
   auxiliary object x) to the PC of the inner solver.
   Registered as the Schwarz callback in PCSchwartzInstallSSORSubSolve. */
static int PCSchwarzSetPipeType(PC schwarz_pc,PipelineType type,PetscObject x)
{
  SLES inner_sles = ((PC_Schwarz_struct *) schwarz_pc->data)->ssor_solver;
  PC inner_pc;
  int ierr;

  ierr = SLESGetPC(inner_sles,&inner_pc);
  CHKERRQ(ierr);
  ierr = PCParallelSubdomainPipelineSetType(inner_pc,type,x);
  CHKERRQ(ierr);
  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCSchwarzGetPipeType"
/* PCSchwarzGetPipeType - query the subdomain pipeline type from the PC of
   the inner solver.
   Registered as the Schwarz callback in PCSchwartzInstallSSORSubSolve. */
static int PCSchwarzGetPipeType(PC schwarz_pc,PipelineType *type)
{
  SLES inner_sles = ((PC_Schwarz_struct *) schwarz_pc->data)->ssor_solver;
  PC inner_pc;
  int ierr;

  ierr = SLESGetPC(inner_sles,&inner_pc);
  CHKERRQ(ierr);
  ierr = PCParallelSubdomainPipelineGetType(inner_pc,type);
  CHKERRQ(ierr);
  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCSchwartzInstallSSORSubSolve"
/* PCSchwartzInstallSSORSubSolve - set up the inner solver of a Schwarz PC.
   Creates the inner SLES (ssor_solver), registers the Schwarz callbacks
   for local-SLES access and subdomain pipeline type with the parallel-PC
   layer, and makes the inner preconditioner a PCGenBlockSSOR.
   (The "Schwartz" spelling is part of the public name and is kept.) */
int PCSchwartzInstallSSORSubSolve(PC pc)
{
  PC_Schwarz_struct *ctx = (PC_Schwarz_struct *) pc->data;
  PC inner_pc;
  int ierr;

  ierr = ParPreGenerateSLES(pc->comm,&ctx->ssor_solver);
  CHKERRQ(ierr);

  /* callbacks through which the generic parallel-PC interface reaches
     the inner solver */
  ierr = PCParallelSetGetLocalSLES(pc,&PCSchwarzGetLocalSLES);
  CHKERRQ(ierr);
  ierr = PCParallelSetSubdomainPipelineSetType(pc,&PCSchwarzSetPipeType);
  CHKERRQ(ierr);
  ierr = PCParallelSetSubdomainPipelineGetType(pc,&PCSchwarzGetPipeType);
  CHKERRQ(ierr);

  ierr = SLESGetPC(ctx->ssor_solver,&inner_pc);
  CHKERRQ(ierr);
  ierr = PCSetType(inner_pc,PCGenBlockSSOR);
  CHKERRQ(ierr);
  return 0;
}

/****************************************************************
 * Internal auxiliaries
 ****************************************************************/
#undef __FUNC__
#define __FUNC__ "GetOverlappingRange"
/* GetOverlappingRange - build this processor's overlapping subdomain.
   Starts from the locally owned rows [first,last) of base_mat, grows the
   set with MatIncreaseOverlap using pc's halo_width, and sorts it.
   Output: *local_size     - number of locally owned (non-overlapped) rows
           *domain_indices - new sorted IS on MPI_COMM_SELF holding the
                             overlapping rows; the caller owns it. */
static int GetOverlappingRange
(Mat base_mat,PC pc,int *local_size,IS *domain_indices)
{
  PC_Schwarz_struct *ctx = (PC_Schwarz_struct *) pc->data;
  int first,last,nowned,ierr;
  IS grown;

  ierr = MatGetOwnershipRange(base_mat,&first,&last);
  CHKERRQ(ierr);
  nowned = last-first;

  ierr = ISCreateStride(MPI_COMM_SELF,nowned,first,1,&grown);
  CHKERRQ(ierr);
  ierr = MatIncreaseOverlap(base_mat,1,&grown,ctx->halo_width);
  CHKERRQ(ierr);
  ierr = ISSort(grown);
  CHKERRQ(ierr);

  *local_size = nowned;
  *domain_indices = grown;
  return 0;
}

#undef __FUNC__
#define __FUNC__ "MakeSchwarzVectors"
/* MakeSchwarzVectors
   called from PCSetup_Schwarz.

   Creates the vectors and data movers linking the original distribution
   to the extended (overlapping) one:
     normal_vec          - parallel vector with local_size local entries
     extended_vec        - parallel vector with one local entry per row of
                           Schwarz_rows
     extended_vec2       - duplicate of extended_vec
     scatter_to_extended - VecScatter normal_vec[Schwarz_rows] -> extended_vec
     pipe_from_extended  - VecPipeline extended_vec -> normal_vec[Schwarz_rows],
                           typed by the PC's subdomain pipeline type and set up
   Output: *unwrap_low - first global index of the local part of extended_vec */
static int MakeSchwarzVectors(PC pc,int local_size,int *unwrap_low)
{
  PC_Schwarz_struct *sd = (PC_Schwarz_struct *) pc->data;
  MPI_Comm comm = pc->comm;
  int n_ext,ext_hi,ierr;
  IS ext_range;
  PipelineType ptype;

  /* vectors in the original and in the extended distribution */
  ierr = VecCreateMPI(comm,local_size,PETSC_DECIDE,&sd->normal_vec);
  CHKERRQ(ierr);
  ierr = ISGetSize(sd->Schwarz_rows,&n_ext);
  CHKERRQ(ierr);
  ierr = VecCreateMPI(comm,n_ext,PETSC_DECIDE,&sd->extended_vec);
  CHKERRQ(ierr);
  ierr = VecDuplicate(sd->extended_vec,&sd->extended_vec2);
  CHKERRQ(ierr);

  /* contiguous block of extended_vec owned by this processor */
  ierr = VecGetOwnershipRange(sd->extended_vec,unwrap_low,&ext_hi);
  CHKERRQ(ierr);
  ierr = ISCreateStride(MPI_COMM_SELF,n_ext,*unwrap_low,1,&ext_range);
  CHKERRQ(ierr);

  /* original -> extended */
  ierr = VecScatterCreate(sd->normal_vec,sd->Schwarz_rows,
                          sd->extended_vec,ext_range,
                          &sd->scatter_to_extended);
  CHKERRQ(ierr);

  /* extended -> original, as a pipeline */
  ierr = VecPipelineCreate(comm,sd->extended_vec,ext_range,
                           sd->normal_vec,sd->Schwarz_rows,
                           &sd->pipe_from_extended);
  CHKERRQ(ierr);
  ierr = PCParallelSubdomainPipelineGetType(pc,&ptype);
  CHKERRQ(ierr);
  ierr = VecPipelineSetType(sd->pipe_from_extended,ptype,
                            (PetscObject)(pc->pmat));
  CHKERRQ(ierr);
  ierr = VecPipelineSetup(sd->pipe_from_extended);
  CHKERRQ(ierr);

  /* NOTE(review): ext_range is never destroyed. It is unclear from this
     file whether VecPipelineCreate keeps a reference to it; confirm before
     adding an ISDestroy here. */
  return 0;
}

#undef __FUNC__
#define __FUNC__ "TranslateNumberings"
/* TranslateNumberings
   Map the original global row numbering of base_mat to the extended
   numbering of schwarz_mat.

   Input:  base_mat     - original parallel matrix
           schwarz_mat  - extended matrix; its local rows [Lo,Hi) correspond,
                          in order, to the entries of Schwarz_rows
           Schwarz_rows - this processor's overlapping rows, original numbering
   Output: *big_r - newly allocated array whose length is the global row count
                    of base_mat. (*big_r)[r] is the extended index used for
                    original row r:
                      - this processor's own copy when r is in Schwarz_rows;
                      - otherwise the copy held by the processor that owns r
                        in base_mat.
                    The caller must PetscFree it.

   Collective on base_mat's communicator (MPI_Allgather/MPI_Allgatherv).
   NOTE(review): relies on every processor's Schwarz_rows containing its own
   owned rows [lo,hi), as GetOverlappingRange presumably guarantees. */
static int TranslateNumberings(Mat base_mat,Mat schwarz_mat,IS Schwarz_rows,
			       int **big_r)
{
  MPI_Comm comm = base_mat->comm;
  int *big_row;
  int lo,hi,size, Lo,Hi, i,idum,ierr;

  ierr = MatGetSize(base_mat,&size,&idum); CHKERRQ(ierr);
  ierr = MatGetOwnershipRange(base_mat,&lo,&hi); CHKERRQ(ierr);
  ierr = MatGetOwnershipRange(schwarz_mat,&Lo,&Hi); CHKERRQ(ierr);

  /* -1 marks "not in this processor's extended domain". 0 must not be the
     marker: it is a valid original row and a valid extended index. Using 0
     made any halo copy of original row 0 fall back to rank 0's copy. */
  big_row = (int *) PetscMalloc(size*sizeof(int)); CHKPTRQ(big_row);
  for (i=0; i<size; i++) big_row[i] = -1;

  /* translate original -> extended for every row held locally */
  {
    int *rows;
    ierr = ISGetIndices(Schwarz_rows,&rows); CHKERRQ(ierr);
    for (i=Lo; i<Hi; i++)
      big_row[rows[i-Lo]] = i;
    ierr = ISRestoreIndices(Schwarz_rows,&rows);
    CHKERRQ(ierr);
  }

#define INFO_TAG 1
#define DATA_TAG 2
  /* for rows outside the local extended domain, use the translation made by
     the owning processor (its entries for its own range [lo,hi) are set) */
  {
    int ntids,*lows,*sizs,*big_row2;

    MPI_Comm_size(comm,&ntids);
    lows = (int *) PetscMalloc((ntids+1)*sizeof(int)); CHKPTRQ(lows);
    sizs = (int *) PetscMalloc(ntids*sizeof(int)); CHKPTRQ(sizs);
    MPI_Allgather( &lo,1,MPI_INT, lows,1,MPI_INT, comm);
    lows[ntids] = size;
    for (i=0; i<ntids; i++) sizs[i] = lows[i+1]-lows[i];
    big_row2 = (int *) PetscMalloc(size*sizeof(int)); CHKPTRQ(big_row2);
    MPI_Allgatherv
      ( big_row+lo,hi-lo,MPI_INT, big_row2,sizs,lows,MPI_INT,comm);

    /* mark extended border */
    for (i=0; i<size; i++)
      if (big_row[i]<0) big_row[i] = big_row2[i];
    PetscFree(big_row2); PetscFree(lows); PetscFree(sizs);
  }

  *big_r = big_row;
  return 0;
}

#undef __FUNC__
#define __FUNC__ "UnOverlapBlocks"
/* UnOverlapBlocks
   Build pc_data->Schwarz_mat, the parallel matrix in which this processor
   owns its own copy of every row of its overlapping subdomain
   (pc_data->Schwarz_rows).

   1. Extract rows Schwarz_rows (all columns) of base_mat as a submatrix.
   2. Create Schwarz_mat, preallocated with that submatrix's row lengths.
   3. Renumber each row's columns into the extended numbering
      (TranslateNumberings) and insert it at global row Schwarz_low+irow.

   Requires pc_data->Schwarz_rows and pc_data->Schwarz_low to be set.
   All temporaries (submatrix, row-length, translation and column buffers)
   are released before returning. */
static int UnOverlapBlocks(Mat base_mat,PC pc)
{
  PC_Schwarz_struct *pc_data = (PC_Schwarz_struct *) pc->data;
  MPI_Comm    comm = base_mat->comm;
  int domain_size,irow,ierr;
  Mat *ret;

  {
    IS all_indices;
    int Isize,Jsize;

    ierr = MatGetSize(base_mat,&Isize,&Jsize); CHKERRQ(ierr);
    ierr = ISCreateStride
      (MPI_COMM_SELF,Jsize,0,1,&all_indices);
    CHKERRQ(ierr);
    ierr = MatGetSubMatrices
      (base_mat,1,&(pc_data->Schwarz_rows),&all_indices,
       MAT_INITIAL_MATRIX,&ret);
    CHKERRQ(ierr);
    ierr = ISDestroy(all_indices); CHKERRQ(ierr);
  }

  ierr = ISGetSize(pc_data->Schwarz_rows,&domain_size); CHKERRQ(ierr);

  {
    int *big_row,*band,*newcols,maxlen = 0;

    /* get the number of nonzeros per row, and the longest row */
    band = (int *) PetscMalloc((domain_size+1)*sizeof(int)); CHKPTRQ(band);
    for (irow=0; irow<domain_size; irow++) {
      int nz;
      ierr = MatGetRow(ret[0],irow,&nz,PETSC_NULL,PETSC_NULL);
      CHKERRQ(ierr);
      /* record the length before MatRestoreRow, which may reset it */
      band[irow] = nz;
      if (nz > maxlen) maxlen = nz;
      ierr = MatRestoreRow(ret[0],irow,&nz,PETSC_NULL,PETSC_NULL);
      CHKERRQ(ierr);
    }

    ierr = MatCreateMPIAIJ
      (comm,domain_size,domain_size,PETSC_DECIDE,PETSC_DECIDE,
       PETSC_NULL,band,PETSC_NULL,band,&(pc_data->Schwarz_mat));
    CHKERRQ(ierr);
    ierr = TranslateNumberings
      (base_mat,pc_data->Schwarz_mat,pc_data->Schwarz_rows,&big_row);
    CHKERRQ(ierr);

    /* renumbered columns go into a private buffer: the arrays returned by
       MatGetRow belong to the matrix and are not written to */
    newcols = (int *) PetscMalloc((maxlen+1)*sizeof(int)); CHKPTRQ(newcols);

    /* now pour the Schwarz matrix into the big matrix */
    for (irow=0; irow<domain_size; irow++) {
      int j,ncols,*cols,row; Scalar *vals;
      ierr = MatGetRow(ret[0],irow,&ncols,&cols,&vals); CHKERRQ(ierr);
      for (j=0; j<ncols; j++)
        newcols[j] = big_row[cols[j]];
      row = pc_data->Schwarz_low+irow;
      ierr = MatSetValues
        (pc_data->Schwarz_mat,1,&row,ncols,newcols,vals,INSERT_VALUES);
      CHKERRQ(ierr);
      ierr = MatRestoreRow(ret[0],irow,&ncols,&cols,&vals); CHKERRQ(ierr);
    }

    PetscFree(newcols);
    PetscFree(big_row);
    PetscFree(band);
  }

  /* release the extracted submatrix and the array MatGetSubMatrices
     allocated for it (MAT_INITIAL_MATRIX) */
  ierr = MatDestroy(ret[0]); CHKERRQ(ierr);
  PetscFree(ret);

  ierr = MatAssemblyBegin
    (pc_data->Schwarz_mat,MAT_FINAL_ASSEMBLY);
  CHKERRQ(ierr);
  ierr = MatAssemblyEnd
    (pc_data->Schwarz_mat,MAT_FINAL_ASSEMBLY);
  CHKERRQ(ierr);

  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCSetup_Schwarz"
/* Worker for PCSetup_Schwarz. It runs with ParPreTraceBackErrorHandler
   pushed, so errors raised here are reported by that handler. It:
     - checks that the preconditioning matrix is MATMPIAIJ;
     - builds the overlapping index set;
     - creates the extended vectors, scatter and pipeline;
     - creates the unoverlapped Schwarz matrix;
     - sets up the inner SSOR solver on that matrix. */
static int PCSetup_Schwarz_Private(PC pc)
{
  Mat base_mat = pc->pmat;
  PC_Schwarz_struct *pc_data = (PC_Schwarz_struct *) pc->data;
  int local_size,ierr;

  if (!(base_mat->type==MATMPIAIJ)) {
    SETERRQ(1,0,"Overlapping preconditioner only implemented for AIJMPI\n");
  }

  /* Get the Schwarz matrix */
  ierr = GetOverlappingRange(base_mat,pc,&local_size,&(pc_data->Schwarz_rows));
  CHKERRQ(ierr);

  /* create vectors, scatters/pipelines */
  ierr = MakeSchwarzVectors(pc,local_size,&(pc_data->Schwarz_low));
  CHKERRQ(ierr);

  ierr = UnOverlapBlocks(base_mat,pc); CHKERRQ(ierr);

  ierr = SLESSetOperators
    (pc_data->ssor_solver,pc_data->Schwarz_mat,pc_data->Schwarz_mat,
     (MatStructure)0);
  CHKERRQ(ierr);
  ierr = SLESSetUp(pc_data->ssor_solver,pc_data->extended_vec,
                   pc_data->extended_vec2);
  CHKERRQ(ierr);

  return 0;
}

/* PCSetup_Schwarz - set up a Schwarz preconditioner from pc->pmat.
   The ParPre traceback handler is pushed for the duration of the setup and
   popped on every path, including errors. Previously an early error return
   left it on the global handler stack. */
int PCSetup_Schwarz(PC pc)
{
  int ierr,ierr2;

  ierr = PetscPushErrorHandler(&ParPreTraceBackErrorHandler,0);
  CHKERRQ(ierr);
  ierr  = PCSetup_Schwarz_Private(pc);
  ierr2 = PetscPopErrorHandler();
  CHKERRQ(ierr);
  CHKERRQ(ierr2);

  return 0;
}

#undef __FUNC__
#define __FUNC__ "PCView_Schwarz"
/* PCView_Schwarz - print the halo width on pc->comm, then view the
   subdomains of the inner solver's PC with the given viewer. */
int PCView_Schwarz(PC pc,Viewer viewer)
{
  PC_Schwarz_struct *sd;
  PC inner_pc;
  int ierr;

  PetscFunctionBegin;
  sd = (PC_Schwarz_struct *) pc->data;
  PetscPrintf(pc->comm," Halo width: %d\n",sd->halo_width);
  PetscPrintf(pc->comm,">> Schwarz subdomains\n");

  ierr = SLESGetPC(sd->ssor_solver,&inner_pc);
  CHKERRQ(ierr);
  ierr = PCSubdomainsView(inner_pc,viewer);
  CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "PCDestroy_Schwarz"
/* PCDestroy_Schwarz - release everything created by
   PCSchwartzInstallSSORSubSolve and PCSetup_Schwarz. The sequence is:
   inner solver, subdomains, Schwarz matrix, vectors, scatter/pipeline,
   index set. The order is kept as-is, since later objects may still be
   referenced by earlier ones. The PC_Schwarz_struct itself is not freed
   here. */
int PCDestroy_Schwarz(PC pc)
{
  PC_Schwarz_struct *sd = (PC_Schwarz_struct *) pc->data;
  int ierr;

  PetscFunctionBegin;
  /* solvers and subdomain data */
  ierr = SLESDestroy(sd->ssor_solver);         CHKERRQ(ierr);
  ierr = PCParallelDestroySubdomains(pc);      CHKERRQ(ierr);
  ierr = MatDestroy(sd->Schwarz_mat);          CHKERRQ(ierr);

  /* vectors in the original and extended distributions */
  ierr = VecDestroy(sd->normal_vec);           CHKERRQ(ierr);
  ierr = VecDestroy(sd->extended_vec);         CHKERRQ(ierr);
  ierr = VecDestroy(sd->extended_vec2);        CHKERRQ(ierr);

  /* data movers between the two distributions */
  ierr = VecScatterDestroy(sd->scatter_to_extended); CHKERRQ(ierr);
  ierr = VecPipelineDestroy(sd->pipe_from_extended); CHKERRQ(ierr);

  ierr = ISDestroy(sd->Schwarz_rows);          CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
