#include "mpi.h"
#include "mat.h"
#include "src/mat/parpre_matimpl.h"
#include "src/mat/impls/aij/mpi/mpiaij.h"
#include "./ml_impl.h"

/****************************************************************
 * Strong and weak connections                                  *
 ****************************************************************/
#undef __FUNC__
#define __FUNC__ "MakeStrongMatrix"
/*
 * MakeStrongMatrix -- build the matrix of strong connections of one level.
 *
 * Input:
 *   this_level     level descriptor; this_level->mat is cast to Mat_MPIAIJ
 *                  below, so it is assumed to be an MPIAIJ matrix whose
 *                  column ownership range matches its row ownership range
 *                  (columns of the diagonal block are offset by rstart).
 * Output:
 *   strong_matrix  a newly created MPIAIJ matrix with the same local sizes.
 *                  It is created here and not destroyed here; the caller
 *                  owns it.
 *
 * For every locally owned row, each stored entry a(Row,Col) of the original
 * matrix is classified:
 *   - strong: copied to (Row,Col) of the new matrix;
 *   - weak:   added onto the diagonal (Row,Row) of the new matrix, so the
 *             row sum is preserved.
 * An entry is strong if |a| > weight * sqrt(rmax(Row) * cmax(Col)), where
 * rmax/cmax are the per-row / per-column values produced by
 * MatMaxRowOffDiagElement_MPIAIJ / MatMaxColOffDiagElement_MPIAIJ
 * (presumably the largest off-diagonal magnitudes -- TODO confirm).
 * In the diagonal block, the diagonal entry is always strong, and rows
 * with at most 5 diagonal-block entries keep all of those entries.
 *
 * Side effects: overwrites this_level->u (row maxima), this_level->v
 * (zeroed, then column maxima) and the ghost vector Aij->lvec of
 * this_level->mat.
 *
 * Returns 0 on success, or a nonzero PETSc error code.
 */
int MakeStrongMatrix(MC_OneLevel_struct *this_level,Mat *strong_matrix)
{
  Mat mat = this_level->mat,strong_mat;
  MPI_Comm comm = this_level->comm;
  int rstart = this_level->rstart,local_size = this_level->local_size;
  int iRow,idum,ierr;
#if TRACING
  int K=0,D=0;
#endif
  Scalar zero=0.0;

#define WEIGH 1
  Mat_MPIAIJ  *Aij = (Mat_MPIAIJ *) mat->data;
  Mat A = Aij->A, B = Aij->B;  /* diagonal and off-diagonal local blocks */
  /* weight: fraction of the geometric mean of row/column maxima that an
     entry must exceed to count as a strong connection */
  Scalar *rA,*cA,*cB,weight=.45;

  /* compare strong connections to max off diag element per row/column */
  ierr = MatMaxRowOffDiagElement_MPIAIJ(mat,this_level->u); CHKERRQ(ierr);
  ierr = VecSet(&zero,this_level->v); CHKERRQ(ierr);
  ierr = MatMaxColOffDiagElement_MPIAIJ(mat,this_level->v); CHKERRQ(ierr);
  /* bring the column maxima of ghost columns (those referenced by B)
     into Aij->lvec, indexed like B's compressed columns */
  ierr = VecScatterBegin(this_level->v,Aij->lvec,INSERT_VALUES,
			 SCATTER_FORWARD,Aij->Mvctx); CHKERRQ(ierr);
  ierr = VecScatterEnd(this_level->v,Aij->lvec,INSERT_VALUES,
		       SCATTER_FORWARD,Aij->Mvctx); CHKERRQ(ierr);
  ierr = VecGetArray(this_level->u,&rA); CHKERRQ(ierr);
  ierr = VecGetArray(this_level->v,&cA); CHKERRQ(ierr);
  ierr = VecGetArray(Aij->lvec,&cB); CHKERRQ(ierr);

  ierr = MatGetLocalSize(mat,&local_size,&idum); CHKERRQ(ierr);
  {
    /* Preallocate from the pattern of the original matrix. The strong
       matrix's diagonal-block pattern is a subset of A's pattern plus the
       diagonal (weak entries are lumped there, and A may not store it),
       and its off-diagonal pattern is a subset of B's pattern. */
    int *d_nnz,*o_nnz,nz;
    d_nnz = (int *) PetscMalloc(2*(local_size+1)*sizeof(int));
    CHKPTRQ(d_nnz);
    o_nnz = d_nnz+local_size+1;
    for (iRow=0; iRow<local_size; iRow++) {
      /* store the length before MatRestoreRow, which may reset it */
      ierr = MatGetRow(A,iRow,&nz,PETSC_NULL,PETSC_NULL); CHKERRQ(ierr);
      d_nnz[iRow] = (nz<local_size) ? nz+1 : local_size;
      ierr = MatRestoreRow(A,iRow,&nz,PETSC_NULL,PETSC_NULL); CHKERRQ(ierr);
      ierr = MatGetRow(B,iRow,&nz,PETSC_NULL,PETSC_NULL); CHKERRQ(ierr);
      o_nnz[iRow] = nz;
      ierr = MatRestoreRow(B,iRow,&nz,PETSC_NULL,PETSC_NULL); CHKERRQ(ierr);
    }
    ierr = MatCreateMPIAIJ
      (comm,local_size,local_size,PETSC_DECIDE,PETSC_DECIDE,
       0,d_nnz,0,o_nnz,&strong_mat);
    CHKERRQ(ierr);
    PetscFree(d_nnz);
  }
  for (iRow=0; iRow<local_size; iRow++) {
    int Row=rstart+iRow,ncols,*cols,iCol,auto_accept=1;
    Scalar *vals,d;
    /* get the row for as far as it's in A (locally owned columns) */
    ierr = MatGetRow(A,iRow,&ncols,&cols,&vals); CHKERRQ(ierr);
    /* short rows keep all their diagonal-block entries */
    if (ncols>5) auto_accept = 0;
    for (iCol=0; iCol<ncols; iCol++) {
      int Col=rstart+cols[iCol]; Scalar v=vals[iCol];
      if (WEIGH) d=sqrt(rA[iRow]*cA[cols[iCol]]);
      if (auto_accept || Col==Row || fabs(v)>weight*d) {
#if TRACING
	K++;
#endif
	ierr = MatSetValues
	  (strong_mat,1,&Row,1,&Col,&v,ADD_VALUES); CHKERRQ(ierr);
      } else {
#if TRACING
	D++;
#endif
	/* weak connection: lump onto the diagonal */
	ierr = MatSetValues
	  (strong_mat,1,&Row,1,&Row,&v,ADD_VALUES); CHKERRQ(ierr);
      }
    }
    ierr = MatRestoreRow(A,iRow,&ncols,&cols,&vals); CHKERRQ(ierr);
    /* get the row for as far as it's in B (ghost columns; garray maps
       B's compressed column index to the global column) */
    ierr = MatGetRow(B,iRow,&ncols,&cols,&vals); CHKERRQ(ierr);
    for (iCol=0; iCol<ncols; iCol++) {
      int Col=Aij->garray[cols[iCol]]; Scalar v=vals[iCol];
      if (WEIGH) d=sqrt(rA[iRow]*cB[cols[iCol]]);
      if (fabs(v)>weight*d) {
#if TRACING
	K++;
#endif
	ierr = MatSetValues
	  (strong_mat,1,&Row,1,&Col,&v,ADD_VALUES); CHKERRQ(ierr);
      } else {
#if TRACING
	D++;
#endif
	/* weak connection: lump onto the diagonal */
	ierr = MatSetValues
	  (strong_mat,1,&Row,1,&Row,&v,ADD_VALUES); CHKERRQ(ierr);
      }
    }
    ierr = MatRestoreRow(B,iRow,&ncols,&cols,&vals); CHKERRQ(ierr);
  }
  ierr = VecRestoreArray(this_level->u,&rA); CHKERRQ(ierr);
  ierr = VecRestoreArray(this_level->v,&cA); CHKERRQ(ierr);
  ierr = VecRestoreArray(Aij->lvec,&cB); CHKERRQ(ierr);
  ierr = MatAssemblyBegin(strong_mat,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  ierr = MatAssemblyEnd(strong_mat,MAT_FINAL_ASSEMBLY); CHKERRQ(ierr);
  *strong_matrix = strong_mat;

#if TRACING
  {int b,a; MatMaxRowLen_MPIAIJ(this_level->mat,&b);
   MatMaxRowLen_MPIAIJ(strong_mat,&a); if(!K)K=1;
   printf("Row lengths original=%d, strong=%d; discard %d=%d pct\n",
	  b,a,D,100*D/(D+K));
 }
#endif

  return 0;
}

