#include "petsc.h"
#include "mpi.h"
#include "is.h"
#include "mat.h"
#include "src/mat/parpre_matimpl.h"
#include "src/mat/impls/aij/mpi/mpiaij.h"
#include "src/pc/pcextra.h"
#include "./ml_impl.h"

/****************************************************************
 * Colouring and finding independent sets                       *
 ****************************************************************/

#if !RUGE
#undef __FUNC__
#define __FUNC__ "FreeClr"
/*
 * FreeClr - return the smallest positive integer that does not occur
 * among the first nn entries of n.
 *
 *   nn : number of entries of n to inspect
 *   n  : list of colours already taken (read only)
 *
 * Each time the current candidate is found in the list it is increased
 * by one and the scan restarts from the first entry, so the order of
 * the entries does not matter.
 */
static int FreeClr(int nn,int *n)
{
  int candidate = 1;
  int pos = 0;

  while (pos < nn) {
    if (n[pos] == candidate) {
      /* candidate is taken: try the next one, rescanning from the start */
      candidate++;
      pos = 0;
    } else {
      pos++;
    }
  }
  return candidate;
}

#undef __FUNC__
#define __FUNC__ "LookAtThisRow"
/*
 * LookAtThisRow - inspect the neighbours listed in one matrix row.
 *
 *   this_var   : column index to skip (the variable itself)
 *   this_rand  : random value attached to the variable being examined
 *   idx, nidx  : column indices of the row and their count
 *   colour_now : set to 0 as soon as a neighbour with a larger random
 *                value is found that has no colour yet; otherwise untouched
 *   ran, clr   : random values and colours, indexed by column index
 *   n, neigh   : running count and list of colours collected so far; the
 *                colour of every already coloured neighbour with a larger
 *                random value is appended
 *
 * Neighbours whose random value is not larger than this_rand are ignored.
 * Always returns 0.
 */
static int LookAtThisRow
(int this_var, Scalar this_rand,int *idx,int nidx, int *colour_now,
 Scalar *ran,Scalar *clr,int *n,int *neigh)
{
  int pos = 0;

  while (pos < nidx) {
    int nb = idx[pos++];

    if (nb == this_var) continue;
    if (!(ran[nb] > this_rand)) continue;

    if (clr[nb] == 0.) {
      /* a dominating neighbour is still uncoloured: this variable must wait */
      *colour_now = 0;
      break;
    }
    neigh[(*n)++] = (int)clr[nb];
  }
  return 0;
}

#undef __FUNC__
#define __FUNC__ "LookAtUncolouredVar"
/*
 * LookAtUncolouredVar - decide whether a local variable can be coloured
 * in the current pass and, if so, with which colour.
 *
 *   this_var    : local row number of the variable
 *   A, B        : the two matrices whose row this_var lists the neighbours
 *                 (the caller passes the A and B members of a Mat_MPIAIJ)
 *   colour_now  : in: nonzero; out: 0 if some neighbour with a larger
 *                 random value is still uncoloured
 *   clr_array,  ran_array  : colours / random values indexed by columns of A
 *   clrb_array, ranb_array : colours / random values indexed by columns of B
 *   neighb      : scratch array receiving the colours of the neighbours
 *   clr         : out: smallest colour not found among the collected
 *                 neighbour colours (computed even when *colour_now is 0)
 *
 * The row of B is only examined if the row of A did not already veto
 * colouring.  For B the "skip" index is -1, so no column is skipped.
 */
static int LookAtUncolouredVar
(int this_var,Mat A,Mat B, int *colour_now,
 Scalar *clr_array,Scalar *clrb_array,Scalar *ran_array,Scalar *ranb_array,
 int *neighb,int *clr)
{
  int ierr;
  int row_len,*row_cols;
  int n_found = 0;

  /* neighbours listed in the row of A */
  ierr = MatGetRow(A,this_var,&row_len,&row_cols,PETSC_NULL); CHKERRQ(ierr);
  ierr = LookAtThisRow(this_var,ran_array[this_var],
                       row_cols,row_len,
                       colour_now,ran_array,clr_array,
                       &n_found,neighb);
  CHKERRQ(ierr);
  ierr = MatRestoreRow(A,this_var,&row_len,&row_cols,PETSC_NULL);
  CHKERRQ(ierr);

  /* neighbours listed in the row of B, only if still a candidate */
  if (*colour_now) {
    ierr = MatGetRow(B,this_var,&row_len,&row_cols,PETSC_NULL); CHKERRQ(ierr);
    if (row_len) {
      ierr = LookAtThisRow(-1,ran_array[this_var],
                           row_cols,row_len,
                           colour_now,ranb_array,clrb_array,
                           &n_found,neighb);
      CHKERRQ(ierr);
    }
    ierr = MatRestoreRow(B,this_var,&row_len,&row_cols,PETSC_NULL);
    CHKERRQ(ierr);
  }

  *clr = FreeClr(n_found,neighb);
  return 0;
}

#undef __FUNC__
#define __FUNC__ "FindIndSet"
/*
 * FindIndSet - colour the variables of a parallel AIJ matrix and split
 * the locally owned variables into those of colour 1 and all others.
 *
 *   mat       : matrix; its data pointer is used as a Mat_MPIAIJ
 *   randoms   : one random value per variable (read only)
 *   clrs      : on return holds the colour of every local variable
 *   indp_vars : out: index set of the local variables with colour 1,
 *               numbered from the start of the ownership range of randoms
 *   rest_vars : out: index set of all other local variables
 *
 * The routine sweeps over the uncoloured local variables repeatedly.  A
 * variable receives a colour in a sweep if no neighbour with a larger
 * random value is still uncoloured (see LookAtUncolouredVar); after each
 * sweep the colours are scattered to the border vector so that other
 * processes see them.  Sweeping stops when no process has uncoloured
 * variables left.  All processes of mat->comm must call this together,
 * since every sweep ends with a reduction over that communicator.
 *
 * aij->lvec is used as work space for the border colours and is
 * overwritten.
 *
 * Returns 0 on success, a nonzero error code otherwise.
 */
static int FindIndSet
(Mat mat,Vec randoms,Vec clrs,IS *indp_vars,IS *rest_vars) 
{
  Mat_MPIAIJ *aij = (Mat_MPIAIJ *) mat->data;
  Scalar *ran_array,*ranb_array, *clr_array,*clrb_array, zero=0.0;
  Vec ran_bord,clr_bord = aij->lvec;
  int *neighb, rstart,rend,local_size,var,coloured, ierr;

  ierr = VecGetOwnershipRange(randoms,&rstart,&rend); CHKERRQ(ierr);
  local_size = rend-rstart;
  ierr = VecDuplicate(aij->lvec,&ran_bord); CHKERRQ(ierr);

  /* colour 0 means "not yet coloured"; clear before taking the arrays */
  ierr = VecSet(&zero,clrs); CHKERRQ(ierr);
  ierr = VecSet(&zero,clr_bord); CHKERRQ(ierr);

  ierr = VecGetArray(randoms,&ran_array); CHKERRQ(ierr);
  ierr = VecGetArray(ran_bord,&ranb_array); CHKERRQ(ierr);
  ierr = VecGetArray(clrs,&clr_array); CHKERRQ(ierr);
  ierr = VecGetArray(clr_bord,&clrb_array); CHKERRQ(ierr);
  {
    /* scratch space for the neighbour colours of one variable; the +1
       keeps the request nonzero, like the other allocations in this file */
    int rl;
    ierr = MatMaxRowLen_MPIAIJ(mat,&rl); CHKERRQ(ierr);
    neighb = (int *) PetscMalloc((rl+1)*sizeof(int)); CHKPTRQ(neighb);
  }

  /* the random values never change: one scatter to the border suffices */
  ierr = VecScatterBegin
    (randoms,ran_bord,INSERT_VALUES,SCATTER_FORWARD,aij->Mvctx); CHKERRQ(ierr);
  ierr = VecScatterEnd
    (randoms,ran_bord,INSERT_VALUES,SCATTER_FORWARD,aij->Mvctx); CHKERRQ(ierr);

  /*>>>> Loop until completely coloured <<<<*/
  coloured = 0;
  for (;;) {
    int l_rem,g_rem;

    for (var=0; var<local_size; var++) {
      if (!clr_array[var]) {
        int colour_now = 1,clr;
        ierr = LookAtUncolouredVar
          (var,aij->A,aij->B,&colour_now,clr_array,clrb_array,
           ran_array,ranb_array,neighb,&clr); CHKERRQ(ierr);
        if (colour_now) {
          coloured++; clr_array[var] = (Scalar) clr;
        }
      }
    }

    /* largest number of uncoloured variables on any process */
    l_rem = local_size-coloured;
    ierr = MPI_Allreduce(&l_rem,&g_rem,1,MPI_INT,MPI_MAX,mat->comm);
    CHKERRQ(ierr);
    if (!g_rem) break;
    if (g_rem<0) SETERRQ(1,0,"Cannot happen: negative points to colour");

    /* publish the colours found in this sweep to the border vector */
    ierr = VecScatterBegin
      (clrs,clr_bord,INSERT_VALUES,SCATTER_FORWARD,aij->Mvctx); CHKERRQ(ierr);
    ierr = VecScatterEnd
      (clrs,clr_bord,INSERT_VALUES,SCATTER_FORWARD,aij->Mvctx); CHKERRQ(ierr);
  }

  /* release everything that is only needed while colouring */
  PetscFree(neighb);
  ierr = VecRestoreArray(randoms,&ran_array); CHKERRQ(ierr);
  ierr = VecRestoreArray(ran_bord,&ranb_array); CHKERRQ(ierr);
  ierr = VecRestoreArray(clr_bord,&clrb_array); CHKERRQ(ierr);
  ierr = VecDestroy(ran_bord); CHKERRQ(ierr);

  {
    int n_coloured=0,n_not_coloured=0;
    int *points_coloured,*points_not_coloured;

    /* first pass: count, so both lists can be allocated exactly */
    for (var=0; var<local_size; var++) {
      if ((int)clr_array[var]==1) n_coloured++; else n_not_coloured++;
    }

    points_coloured = (int *) PetscMalloc((n_coloured+1)*sizeof(int));
    CHKPTRQ(points_coloured);
    points_not_coloured = (int *) PetscMalloc((n_not_coloured+1)*sizeof(int));
    CHKPTRQ(points_not_coloured);

    /* second pass: fill, shifting local numbers by rstart */
    n_coloured = 0; n_not_coloured = 0;
    for (var=0; var<local_size; var++) {
      if ((int)clr_array[var]==1)
        points_coloured[n_coloured++] = var+rstart;
      else
        points_not_coloured[n_not_coloured++] = var+rstart;
    }

    /* the colour array is no longer read below */
    ierr = VecRestoreArray(clrs,&clr_array); CHKERRQ(ierr);

    ierr = ISCreateGeneral
      (mat->comm,n_coloured,points_coloured,indp_vars);
    CHKERRQ(ierr); PetscFree(points_coloured);

    ierr = ISCreateGeneral
      (mat->comm,n_not_coloured,points_not_coloured,rest_vars);
    CHKERRQ(ierr); PetscFree(points_not_coloured);
  }
  
  return 0;
}
#endif

#if RUGE
#undef __FUNC__
#define __FUNC__ "FakeIndset"
/*
 * FakeIndset - build the "independent" and "rest" index sets from a
 * fixed geometric pattern on an isize x jsize grid instead of from a
 * colouring.
 *
 *   this_level : supplies comm, level, isize, jsize and the range
 *                [rstart,rend) of point numbers handled by this process
 *   indp_vars  : out: index set of the selected points
 *   rest_vars  : out: index set of the remaining points
 *
 * Grid point (istep,jstep) has number istep*jsize+jstep.  Only numbers
 * in [rstart,rend) are considered.  On even levels a point is selected
 * unless exactly one of istep and jstep is odd (checkerboard pattern);
 * on odd levels a point is selected when istep is even.
 *
 * Returns 0 on success, a nonzero error code otherwise.
 */
static int FakeIndset(MC_OneLevel_struct *this_level,
		      IS *indp_vars,IS *rest_vars)
{
  MPI_Comm comm = this_level->comm;
  int istep,jstep,ierr;
  int even_level = (this_level->level%2 == 0);
  IndStash points_coloured,points_not_coloured;

  ierr = NewIndexStash(&points_coloured); CHKERRQ(ierr);
  ierr = NewIndexStash(&points_not_coloured); CHKERRQ(ierr);
  for (istep=0; istep<this_level->isize; istep++) {
    for (jstep=0; jstep<this_level->jsize; jstep++) {
      int num = istep*this_level->jsize+jstep;
      int selected;

      /* skip points outside this process's range */
      if (num<this_level->rstart || num>=this_level->rend) continue;

      if (even_level)
        selected = ((istep%2)+(jstep%2) != 1);
      else
        selected = (istep%2 == 0);

      if (selected) {
        ierr = StashIndex(points_coloured,1,&num); CHKERRQ(ierr);
      } else {
        ierr = StashIndex(points_not_coloured,1,&num); CHKERRQ(ierr);
      }
    }
  }
  ierr = ISCreateGeneral
    (comm,points_coloured->n,points_coloured->array,indp_vars);
  CHKERRQ(ierr); ierr = DestroyIndexStash(points_coloured); CHKERRQ(ierr);
  ierr = ISCreateGeneral
    (comm,points_not_coloured->n,points_not_coloured->array,rest_vars);
  CHKERRQ(ierr); ierr = DestroyIndexStash(points_not_coloured); CHKERRQ(ierr);
  
  return 0;
}
#endif

#undef __FUNC__
#define __FUNC__ "SplitIndAndRest"
/*
 * SplitIndAndRest - produce one independent set and its complement.
 *
 *   this_level : level data; without RUGE its vectors u and v are used
 *                for the random values and the colours respectively
 *   mat        : matrix handed to FindIndSet (not used when RUGE is set)
 *   ind, rest  : out: the independent set and the remaining variables;
 *                written only after the split has succeeded
 *
 * Without RUGE, u is filled with random values and FindIndSet derives
 * the sets from them; with RUGE the fixed pattern of FakeIndset is used.
 *
 * NOTE(review): the random context is created on MPI_COMM_SELF; whether
 * different processes then draw different values depends on how
 * PetscRandomCreate seeds its generator -- confirm.
 *
 * Returns 0 on success, a nonzero error code otherwise.
 */
static int SplitIndAndRest(MC_OneLevel_struct *this_level,Mat mat,
			   IS *ind,IS *rest)
{
  IS chosen,others;
  int ierr;

#if !RUGE
  {
    PetscRandom rctx;

    ierr = PetscRandomCreate(MPI_COMM_SELF,RANDOM_DEFAULT,&rctx);
    CHKERRQ(ierr);
    ierr = VecSetRandom(rctx,this_level->u); CHKERRQ(ierr);

    ierr = FindIndSet(mat,this_level->u,this_level->v,&chosen,&others);
    CHKERRQ(ierr);

    ierr = PetscRandomDestroy(rctx); CHKERRQ(ierr);
  }
#else
  ierr = FakeIndset(this_level,&chosen,&others); CHKERRQ(ierr);
#endif

  *ind = chosen;
  *rest = others;
  return 0;
}

#undef __FUNC__
#define __FUNC__ "SplitOffOneLevel"
/*
 * SplitOffOneLevel - split the variables of one level into two sets and
 * translate both into the numbering stored in this_level->indices.
 *
 *   this_level  : level data; size1, size2 and indices1 are set here
 *   mat         : matrix of this level
 *   set1, set2  : out: the two index sets in this level's numbering
 *   set2_g      : out: set2 translated through this_level->indices
 *   global_rest : out: sum over all processes of the local "rest" sizes
 *   local_rest  : out: the local "rest" size, or 0 on every process as
 *                 soon as any single process has no rest points left
 *
 * The independent set becomes set1 (and the rest set2) when nothing is
 * left globally or when grid_choice is AMLCoarseGridDependent; otherwise
 * the roles are swapped.  size1/size2 are the entry counts used for the
 * translated sets; the size belonging to the rest set is *local_rest,
 * which may be 0 even though the rest index set itself is not empty.
 *
 * All processes of this_level->comm must call this together.
 * Returns 0 on success, a nonzero error code otherwise.
 */
int SplitOffOneLevel
(MC_OneLevel_struct *this_level,Mat mat,
 IS *set1,IS *set2,IS *set2_g,int *global_rest,int *local_rest)
{
  IS ind,rest;
  int local_i,local_r,ierr;

  ierr = SplitIndAndRest(this_level,mat,&ind,&rest); CHKERRQ(ierr);
  ierr = ISGetSize(ind,&local_i); CHKERRQ(ierr);
  ierr = ISGetSize(rest,&local_r); CHKERRQ(ierr);

  /* Is one processor out of points? Are all? */
  ierr = MPI_Allreduce
    (&local_r,global_rest,1,MPI_INT,MPI_SUM,this_level->comm);
  CHKERRQ(ierr);
  {
    /* the product of the 0/1 flags is 1 only if every process has rest */
    int s_local_r = local_r ? 1 : 0;
    int s_global_r;

    ierr = MPI_Allreduce
      (&s_local_r,&s_global_r,1,MPI_INT,MPI_PROD,this_level->comm);
    CHKERRQ(ierr);
    *local_rest = s_global_r * local_r;
  }
  if (*global_rest==0)
    PetscPrintf(this_level->comm,"All processors out of rest\n");
  else if (*local_rest==0)
    PetscPrintf(this_level->comm,"One processors out of rest\n");

  if ( (*global_rest==0) ||
       (this_level->grid_choice==AMLCoarseGridDependent) )
    {
      *set1 = ind; *set2 = rest;
      this_level->size1 = local_i; this_level->size2 = *local_rest; 
    } else {
      *set1 = rest; *set2 = ind;
      this_level->size1 = *local_rest; this_level->size2 = local_i;
    }
  printf("Colour has %d points, %d points remaining\n",
	 this_level->size1,this_level->size2);
  
  {/* get the global indices of the next set */
    int s,*i,*ig1,*ig2,*ig,var;

    ierr = ISGetIndices(this_level->indices,&ig); CHKERRQ(ierr);

    s = this_level->size2;
    ig2 = (int *) PetscMalloc((s+1)*sizeof(int)); CHKPTRQ(ig2);
    ierr = ISGetIndices(*set2,&i); CHKERRQ(ierr);
    for (var=0; var<s; var++) ig2[var] = ig[i[var]-this_level->rstart];
    ierr = ISRestoreIndices(*set2,&i); CHKERRQ(ierr);
    ierr = ISCreateGeneral(this_level->comm,s,ig2,set2_g); CHKERRQ(ierr);
    /* the index set does not keep the array (FindIndSet in this file
       frees its arrays after ISCreateGeneral too), so release it */
    PetscFree(ig2);

    s = this_level->size1;
    ig1 = (int *) PetscMalloc((s+1)*sizeof(int)); CHKPTRQ(ig1);
    ierr = ISGetIndices(*set1,&i); CHKERRQ(ierr);
    for (var=0; var<s; var++) ig1[var] = ig[i[var]-this_level->rstart];
    ierr = ISRestoreIndices(*set1,&i); CHKERRQ(ierr);
    ierr = ISRestoreIndices(this_level->indices,&ig); CHKERRQ(ierr);
    ierr = ISCreateGeneral(this_level->comm,s,ig1,&(this_level->indices1));
    CHKERRQ(ierr);
    PetscFree(ig1);
  }
  return 0;
}

