#include <math.h>
#include "petsc.h"
#include "mpi.h"
#include "parpre_mat.h"
#include "petscsles.h"
#include "src/mat/matimpl.h"
#include "src/mat/parpre_matimpl.h"
#include "petscpc.h"
#include "./ml_head.h"

#include "src/sles/pc/utils/auxs.c"

/****************************************************************
 * Level splitting routines.
 * None of these use the multilevel data structure
 ****************************************************************/

#undef __FUNC__
#define __FUNC__ "SetupOneSubVectorSet"
/* Build the work vectors and scatters belonging to one index set.

   comm     communicator on which the parallel vectors are created
   set      index set into `u` selecting the entries of this set
   n        local size of `set`
   u        vector on the current level that is scattered from / to
   par[i]   (out, i=0..2) parallel vectors with `n` local entries
   seq[i]   (out, i=0..2) sequential vectors of length `n`, created
            on the array storage of the matching par[i]
   get      (out) scatter from `u` at `set` into entries 0..n-1 of *seq[0]
   put      (out) the reverse scatter, from *seq[0] back into `u` */
static int SetupOneSubVectorSet
(MPI_Comm comm,IS set,int n,Vec u,
 Vec *par[3],Vec *seq[3],VecScatter *get,VecScatter *put)
{
  IS packed;
  Scalar *storage;
  int which,ierr;

  PetscFunctionBegin;

  for (which=0; which<3; which++) {
    /* the first parallel vector is created from scratch,
       the other two are duplicates of it */
    if (which==0) {
      ierr = VecCreateMPI(comm,n,PETSC_DECIDE,par[0]); CHKERRQ(ierr);
    } else {
      ierr = VecDuplicate(*par[0],par[which]); CHKERRQ(ierr);
    }
    /* sequential vector on top of the same array */
    ierr = VecGetArray(*par[which],&storage); CHKERRQ(ierr);
    ierr = VecCreateSeqWithArray(MPI_COMM_SELF,n,storage,seq[which]);
    CHKERRQ(ierr);
    ierr = VecRestoreArray(*par[which],&storage); CHKERRQ(ierr);
  }

  /* scatters between `u` (indexed by `set`) and the packed local vector */
  ierr = ISCreateStride(MPI_COMM_SELF,n,0,1,&packed); CHKERRQ(ierr);
  ierr = VecScatterCreate(u,set,*seq[0],packed,get); CHKERRQ(ierr);
  ierr = VecScatterCreate(*seq[0],packed,u,set,put); CHKERRQ(ierr);
  ierr = ISDestroy(packed); CHKERRQ(ierr);

  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "SetupSubVectors"
/* Create the sub vectors for the colour set (suffix 1) and the rest
   set (suffix 2) on a level, together with the scatters to and from
   these two sets.

   For each set three parallel vectors are produced (g, h, k), each
   paired with a sequential vector (u, v, w respectively) that is
   created on the same array storage.  get_clr / get_rest move entries
   of `u` selected by `clr` / `rest` into *u1 / *u2; put_clr / put_rest
   move them back. */
static int SetupSubVectors
(MPI_Comm comm,IS clr,IS rest,Vec u,
 Vec *g1,Vec *u1,Vec *h1,Vec *v1,Vec *k1,Vec *w1,
 Vec *g2,Vec *u2,Vec *h2,Vec *v2,Vec *k2,Vec *w2,
 VecScatter *get_clr,VecScatter *put_clr,
 VecScatter *get_rest,VecScatter *put_rest)
{
  Vec *par[3],*seq[3];
  int n_clr,n_rest,ierr;

  PetscFunctionBegin;
  ierr = ISGetSize(clr,&n_clr); CHKERRQ(ierr);
  ierr = ISGetSize(rest,&n_rest); CHKERRQ(ierr);

  /* colour set */
  par[0] = g1; par[1] = h1; par[2] = k1;
  seq[0] = u1; seq[1] = v1; seq[2] = w1;
  ierr = SetupOneSubVectorSet(comm,clr,n_clr,u,par,seq,get_clr,put_clr);
  CHKERRQ(ierr);

  /* rest set */
  par[0] = g2; par[1] = h2; par[2] = k2;
  seq[0] = u2; seq[1] = v2; seq[2] = w2;
  ierr = SetupOneSubVectorSet(comm,rest,n_rest,u,par,seq,get_rest,put_rest);
  CHKERRQ(ierr);

  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "SetDiagonalMatrices"
/* Prepare the inverted diagonal for a level and its restriction to
   the colour set.

   diag         diagonal of the level; it is replaced IN PLACE by its
                reciprocal
   strong_mat   only read when make_strong is nonzero
   make_strong  nonzero: *cdiag becomes a new vector holding the
                reciprocal of the diagonal of strong_mat;
                zero: *cdiag is simply set to `diag` (same object,
                no copy is made)
   get_clr      scatter used to pull the colour-set entries
   u1           template for the layout of *cdiag1
   cdiag        (out) see make_strong
   cdiag1       (out) duplicate of u1 filled with the get_clr scatter
                of *cdiag */
static int SetDiagonalMatrices
(Vec diag,Mat strong_mat,int make_strong,VecScatter get_clr,Vec u1,
 Vec *cdiag,Vec *cdiag1)
{
  int ierr;

  PetscFunctionBegin;

  /* the caller's diagonal is inverted regardless of make_strong */
  ierr = VecReciprocal(diag); CHKERRQ(ierr);

  if (!make_strong) {
    /* no separate strong matrix: share the caller's vector */
    *cdiag = diag;
  } else {
    ierr = VecDuplicate(diag,cdiag); CHKERRQ(ierr);
    ierr = MatGetDiagonal(strong_mat,*cdiag); CHKERRQ(ierr);
    ierr = VecReciprocal(*cdiag); CHKERRQ(ierr);
  }

  /* local vector receiving the 1-block part of the diagonal */
  ierr = VecDuplicate(u1,cdiag1); CHKERRQ(ierr);

  /* the colour-set scatter picks the 1-block entries out of *cdiag */
  ierr = VecScatterBegin
    (*cdiag,*cdiag1,INSERT_VALUES,SCATTER_FORWARD,get_clr);
  CHKERRQ(ierr);
  ierr = VecScatterEnd
    (*cdiag,*cdiag1,INSERT_VALUES,SCATTER_FORWARD,get_clr);
  CHKERRQ(ierr);

  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "ExtractBlocks"
/* Extract the four matrix blocks corresponding to the colour and
   rest sets. Both local (on-processor) and distributed versions
   are needed. We get some from the original matrix, some
   from the strong rendering of it. This will need some more thought.

   clr, rest   colour and rest index sets on this processor
   mat         original matrix; also supplies the communicator
   smat        "strong" rendering of mat

   Sub-blocks taken from smat (rows x columns):
     clr   x grest  -> *g12 (through MatrixAij2MpiAbut) and kept as *a12
     rest  x gclr   -> *g21 (through MatrixAij2MpiAbut), then destroyed
     grest x clr    -> *a21
     rest  x clr    -> *l21
   Sub-blocks taken from mat:
     clr   x gclr   -> *g11 (through MatrixAij2MpiAbut), then destroyed
     rest  x grest  -> *g22 (through MatrixAij2MpiAbut), then destroyed

   gclr / grest are obtained from ISGetGlobalContent; presumably they
   hold the content of clr / rest gathered over all processors --
   TODO confirm against that helper.

   *a12, *a21 and *l21 are handed out directly from the arrays
   returned by MatGetSubMatrices; only the arrays themselves are
   freed here, so these three matrices stay alive after return. */
static int ExtractBlocks
(IS clr,IS rest,Mat mat,Mat smat,
 Mat *g12,Mat *a12,  Mat *g21,Mat *a21,Mat *l21,  Mat *g11,Mat *g22)
{
  Mat *res_mat;
  MPI_Comm comm = mat->comm;
#define STRNG 4
#define ORIG 2
  IS grest,gclr,i_is[STRNG+ORIG],j_is[STRNG+ORIG];
  int loc,local_next,ierr;

  /* bracket the routine like every other function in this file */
  PetscFunctionBegin;

  ierr = ISGetSize(clr,&loc); CHKERRQ(ierr);
  ierr = ISGetSize(rest,&local_next); CHKERRQ(ierr);
  ierr = ISGetGlobalContent(comm,clr,&gclr); CHKERRQ(ierr);
  ierr = ISGetGlobalContent(comm,rest,&grest); CHKERRQ(ierr); 

  /* entries 0..STRNG-1: row/column sets for the strong matrix */
  i_is[0] = clr;   i_is[1] = rest; i_is[2] = grest; i_is[3] = rest;
  j_is[0] = grest; j_is[1] = gclr; j_is[2] = clr;   j_is[3] = clr;

  /* entries STRNG..STRNG+ORIG-1: row/column sets for the original matrix */
  i_is[STRNG] = clr;  i_is[STRNG+1] = rest;
  j_is[STRNG] = gclr; j_is[STRNG+1] = grest;

  /* get the off-diagonal subblocks from the strong matrix */
  ierr = MatGetSubMatrices(smat,STRNG,i_is,j_is,MAT_INITIAL_MATRIX,
			   &res_mat); CHKERRQ(ierr);
  ierr = MatrixAij2MpiAbut(comm,local_next,res_mat[0],0,g12); CHKERRQ(ierr);
  *a12 = res_mat[0];
  ierr = MatrixAij2MpiAbut(comm,loc,res_mat[1],0,g21); CHKERRQ(ierr);
  ierr = MatDestroy(res_mat[1]); CHKERRQ(ierr);
  *a21 = res_mat[2];
  *l21 = res_mat[3];
  /* free only the array: res_mat[0], [2] and [3] were handed out above */
  PetscFree(res_mat);

  /* get diagonal blocks from the original matrix */
  ierr = MatGetSubMatrices(mat,ORIG,i_is+STRNG,j_is+STRNG,MAT_INITIAL_MATRIX,
			   &res_mat); CHKERRQ(ierr);
  ierr = MatrixAij2MpiAbut(comm,loc,res_mat[0],1,g11); CHKERRQ(ierr);
  ierr = MatDestroy(res_mat[0]); CHKERRQ(ierr);
  ierr = MatrixAij2MpiAbut(comm,local_next,res_mat[1],1,g22); CHKERRQ(ierr);
  ierr = MatDestroy(res_mat[1]); CHKERRQ(ierr);
  PetscFree(res_mat);

  ierr = ISDestroy(gclr); CHKERRQ(ierr);
  ierr = ISDestroy(grest); CHKERRQ(ierr);

  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "TranslateSetToTopLevel"
/* Translate one index set into top-level numbering.

   set      index set on the current level
   rstart   offset subtracted from each entry of `set` before it is
            used as a position in `top`
   top      array of top-level indices for the current level
   n        (out) size of `set`
   mapped   (out) newly allocated array of *n translated indices
            (allocated with one spare slot); the caller frees it
            with PetscFree */
static int TranslateSetToTopLevel
(IS set,int rstart,int *top,int *n,int **mapped)
{
  int *here,*out,pos,ierr;

  PetscFunctionBegin;
  ierr = ISGetSize(set,n); CHKERRQ(ierr);
  out = (int *) PetscMalloc((*n+1)*sizeof(int)); CHKPTRQ(out);
  ierr = ISGetIndices(set,&here); CHKERRQ(ierr);
  for (pos=0; pos<*n; pos++) {
    out[pos] = top[here[pos]-rstart];
  }
  ierr = ISRestoreIndices(set,&here); CHKERRQ(ierr);
  *mapped = out;
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "SetGlobalIndices"
/* Given the colour and rest sets, relative to the current level,
   compute their global indices in the top level.

   comm        not used by this routine
   rstart      offset subtracted from the entries of set1 / set2 to
               turn them into positions in `indices`
   indices     top-level indices of the current level
   set1, set2  colour and rest sets on the current level
   indices1    (out) set1 in top-level numbering
   indices2    (out) set2 in top-level numbering
   g_indices2  (out) a second index set built from the same array as
               *indices2, so the two have identical content

   All three outputs are created on MPI_COMM_SELF. */
static int SetGlobalIndices(MPI_Comm comm,int rstart,
		     IS indices,IS set1,IS set2,
		     IS *indices1,IS *indices2,IS *g_indices2)
{
  int count,*top,*mapped,ierr;
  PetscFunctionBegin;

  ierr = ISGetIndices(indices,&top); CHKERRQ(ierr);

  /* colour set */
  ierr = TranslateSetToTopLevel(set1,rstart,top,&count,&mapped);
  CHKERRQ(ierr);
  ierr = ISCreateGeneral(MPI_COMM_SELF,count,mapped,indices1); CHKERRQ(ierr);
  PetscFree(mapped);

  /* rest set: two index sets from one translated array */
  ierr = TranslateSetToTopLevel(set2,rstart,top,&count,&mapped);
  CHKERRQ(ierr);
  ierr = ISCreateGeneral(MPI_COMM_SELF,count,mapped,indices2); CHKERRQ(ierr);
  ierr = ISCreateGeneral(MPI_COMM_SELF,count,mapped,g_indices2);
  CHKERRQ(ierr);
  PetscFree(mapped);

  ierr = ISRestoreIndices(indices,&top); CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "PartitionMatrix"
/* Find colour and rest sets on the current level, split the matrix
   accordingly, and decide whether or not to continue depending
   on the size of the rest set.

   Input:
     mat          matrix of the current level; supplies the communicator
     u, v         vectors forwarded to MakeStrongMatrix and
                  SplitOffOneLevel; u is also the vector the scatters
                  are built against
     level, grid_choice, orth, isize, jsize
                  forwarded unchanged to SplitOffOneLevel / SetFandCsets
     diag         diagonal of the level; replaced in place by its
                  reciprocal (see SetDiagonalMatrices)
     indices      top-level indices of the current level
     trace        bit mask; AMLTraceProgress and AMLTraceIndexSets are
                  tested here, any nonzero value enables the
                  "out of rest" messages
     make_strong  nonzero: build a strong matrix with MakeStrongMatrix
                  (using `weight`) and destroy it before returning;
                  zero: use mat itself
     cutoff       threshold for the size of the rest set

   Output:
     g1..w1, g2..w2, get_clr..put_rest   from SetupSubVectors
     cdiag, cdiag1                       from SetDiagonalMatrices
     indices1, indices2, g_indices2      from SetGlobalIndices
     g12, a12, g21, a21, l21, g11, g22   from ExtractBlocks
     back_track   set to the all_zero flag of SplitOffOneLevel
     stop_now     nonzero if one_zero or all_zero is set, or if the
                  largest local rest-set size over all processors is
                  below cutoff

   NOTE(review): `ind` and `rest` returned by SplitOffOneLevel are not
   destroyed here; whether SetFandCsets takes them over (e.g. returns
   them as set1/set2) cannot be seen from this file -- confirm. */
int PartitionMatrix
(Mat mat,Vec u,Vec v,int level,AMLCoarseGridChoice grid_choice,
 Vec *g1,Vec *u1,Vec *h1,Vec *v1,Vec *k1,Vec *w1,
 Vec *g2,Vec *u2,Vec *h2,Vec *v2,Vec *k2,Vec *w2,
 VecScatter *get_clr,VecScatter *put_clr,
 VecScatter *get_rest,VecScatter *put_rest,
 Vec diag,Vec *cdiag,Vec *cdiag1,
 IS indices,IS *indices1,IS *indices2,IS *g_indices2,
 Mat *g12,Mat *a12,  Mat *g21,Mat *a21,Mat *l21,  Mat *g11,Mat *g22,
 int trace,int make_strong,double weight,
 int cutoff,int *stop_now,int *back_track,
 int orth,int isize,int jsize)
{
  MPI_Comm comm = mat->comm;
  Mat strong_mat;
  int ierr;
  
  PetscFunctionBegin;

  if (make_strong) {
    ierr = MakeStrongMatrix(mat,0,weight,u,v,&strong_mat,trace); CHKERRQ(ierr);
  } else {
    strong_mat = mat;
  }

  /* split this level into F and C */
  {
    IS ind,rest,set1,set2; int one_zero,all_zero;
    if (trace & AMLTraceProgress) PetscPrintf(comm,"..splitting index set\n");
    ierr = SplitOffOneLevel(strong_mat,u,v,level, &ind,&rest,
			    &all_zero,&one_zero,
			    orth,isize,jsize);
    CHKERRQ(ierr);
    if (trace) {
      if (all_zero) {
	PetscPrintf(comm,"All processors out of rest\n");
      } else if (one_zero) {
	PetscPrintf(comm,"Some (not all) processors out of rest\n");
      }
    }
    ierr = SetFandCsets(comm,level,trace,ind,rest,
			&set1,&set2,grid_choice,all_zero); CHKERRQ(ierr);
    if (trace & AMLTraceProgress) PetscPrintf(comm,"..sub vectors\n");
    ierr = SetupSubVectors
      (comm,set1,set2,u,
       g1,u1,h1,v1,k1,w1,
       g2,u2,h2,v2,k2,w2, get_clr,put_clr,get_rest,put_rest);
    CHKERRQ(ierr);
    /* needs *get_clr and *u1, so it must follow SetupSubVectors */
    ierr = SetDiagonalMatrices
      (diag,strong_mat,make_strong,*get_clr,*u1,cdiag,cdiag1);
    CHKERRQ(ierr);

    {
      int rstart,rend;
      ierr = MatGetOwnershipRange(mat,&rstart,&rend); CHKERRQ(ierr);
      ierr = SetGlobalIndices(comm,rstart,indices,set1,set2,
			      indices1,indices2,g_indices2);
      CHKERRQ(ierr);
      if (trace & AMLTraceIndexSets) {
	PetscPrintf(comm,"Index set 1:\n");
	ierr = ISView(*indices1,0); CHKERRQ(ierr);
	PetscPrintf(comm,"Index set 2:\n");
	ierr = ISView(*indices2,0); CHKERRQ(ierr);
      }
    }
    *back_track = all_zero;
    {
      int s2,ss2;
      ierr = ISGetSize(set2,&ss2); CHKERRQ(ierr);
      /* largest local rest-set size over all processors; the result
	 feeds the stop decision, so do not ignore a failed reduction */
      ierr = MPI_Allreduce(&ss2,&s2,1,MPI_INT,MPI_MAX,comm); CHKERRQ(ierr);
      /*if (trace)
	PetscPrintf(comm,"Level %d decisions: one proc zero = %d, all procs zero = %d, max remaining size = %d (cutoff=%d)\n",level,one_zero,all_zero,s2,cutoff);*/
      *stop_now = one_zero || all_zero || (s2<cutoff);
    }
    if (trace & AMLTraceProgress) PetscPrintf(comm,"..extracting blocks\n");
    ierr = ExtractBlocks(set1,set2,mat,strong_mat,g12,a12,g21,a21,l21,g11,g22);
    CHKERRQ(ierr);
    ierr = ISDestroy(set1); CHKERRQ(ierr);
    ierr = ISDestroy(set2); CHKERRQ(ierr);
  }
  /* strong_mat is only ours to destroy when it was built above */
  if (make_strong) {
    ierr = MatDestroy(strong_mat); CHKERRQ(ierr);
  }
  
  PetscFunctionReturn(0);
}

