/*  split.c  */

#include "../spoolesMPI.h"

#define MYDEBUG 0

/*--------------------------------------------------------------------*/
/*
   ---------------------------------------------------------------
   TriplesBuffer -- a buffer of (int, int, double) triples that
   lives in one allocation (base) so that it can be shipped with
   a single send/receive of doubles (see TriplesBuffer_sendrecv).

   layout of base, as set up by TriplesBuffer_init() :
      int slot 0        -- number of triples, written by
                           extractTriples()
      ivec1[0..size-1]  -- starts at int slot 1
      ivec2[0..size-1]  -- starts right after ivec1
      dvec[0..size-1]   -- starts at the first double slot after
                           the integer section
   ---------------------------------------------------------------
*/
typedef struct _TriplesBuffer TriplesBuffer ;
struct _TriplesBuffer {
   /* number of triples the current layout was set up for,
      or the number actually loaded by extractTriples() */
   int      size     ;
   /* number of triples that the allocation in base can hold */
   int      maxsize  ;
   /* first coordinates, points into base */
   int      *ivec1   ;
   /* second coordinates, points into base */
   int      *ivec2   ;
   /* numeric values, points into base */
   double   *dvec    ;
   /* the one allocation that holds header, ivec1, ivec2 and dvec */
   void     *base    ;
} ;
/*
   life cycle of a buffer : allocate, reset fields, release storage,
   release the object, (re)size the storage for a number of triples
*/
static TriplesBuffer * TriplesBuffer_new ( void ) ;
static void TriplesBuffer_setDefaultFields ( TriplesBuffer *buffer ) ;
static void TriplesBuffer_clearData ( TriplesBuffer *buffer ) ;
static void TriplesBuffer_free ( TriplesBuffer *buffer ) ;
static void TriplesBuffer_init ( TriplesBuffer *buffer, int size ) ;
/*
   move the entries of inpmtx whose mapped first coordinate equals
   target into the buffer, the other entries stay in inpmtx
*/
static void extractTriples ( DInpMtx *inpmtx, TriplesBuffer *buffer,
   int target, int map[] ) ;
/*
   exchange buffers with two neighbors via one MPI_Sendrecv() call,
   stats[] is updated with message and byte counts
*/
static void TriplesBuffer_sendrecv ( TriplesBuffer *outbuffer,
   TriplesBuffer *inbuffer, int destination, int source, int tag,
   MPI_Comm comm, MPI_Status *status, int stats[], int msglvl,
   FILE *msgFile ) ;
/*
   print the triples of a buffer to a message file
*/
static void TriplesBuffer_writeForHumanEye ( TriplesBuffer *buffer,
   FILE *msgFile ) ;
/*--------------------------------------------------------------------*/
/*
   ------------------------------------------------------------
   purpose -- to split a distributed DInpMtx object

   every entry of the local matrix is owned by the process
   map[ivec1[entry]]. on return each process holds, in a new
   DInpMtx object, exactly the entries it owns, i.e., the owned
   entries of its own input matrix plus the owned entries that
   were sent by the other processes.

   outline
      1. check firsttag, the first coordinates and the map
      2. count the entries bound for each process and exchange
         the counts with MPI_Alltoall()
      3. move the locally owned entries into the new matrix
      4. for offset = 1, ..., nproc-1 send to process
         (myid + offset) mod nproc and receive from process
         (myid - offset) mod nproc with one MPI_Sendrecv()
      5. change the storage mode of the new matrix to 2,
         check that every entry is owned, free working storage

   inpmtx     -- pointer to the local DInpMtx object
      note: entries are removed from inpmtx as they are
      extracted (see extractTriples())
   mapIV      -- pointer to the map from vertices to processes
   firsttag   -- first tag value, one will be used
   stats[4]   -- statistics vector, incremented, not reset
      stats[0] -- # of messages sent
      stats[1] -- # of messages received
      stats[2] -- # of bytes sent
      stats[3] -- # of bytes received
   msglvl     -- local message level
   msgFile    -- local message file
   comm       -- MPI communication structure

   return value -- pointer to the new DInpMtx object 
      that contains the owned entries. the object is created
      here with DInpMtx_new() and is not freed by this function.

   errors -- any failed check prints a message to stderr
      and calls exit(-1)

   created  -- 97jun20, cca
   modified -- 97oct17, cca
      stats added
   ------------------------------------------------------------
*/
DInpMtx *
DInpMtx_MPI_split (
   DInpMtx    *inpmtx,
   IV         *mapIV,
   int        firsttag,
   int        stats[],
   int        msglvl,
   FILE       *msgFile,
   MPI_Comm   comm
) {
DInpMtx         *keepmtx ;
int             destination, flag, ient, incount, iproc, 
                ival1, left, myid, nent, nkeep, nowned, nproc, nvtx, 
                offset, outcount, right, source, tag, tag_bound, val ;
int             *incounts, *ivec1, *map, *outcounts, *ptagUB ;
MPI_Status      status ;
TriplesBuffer   *inbuffer, *outbuffer ;
/*
   --------------------------------------
   get id of self and number of processes
   --------------------------------------
*/
MPI_Comm_rank(comm, &myid) ;
MPI_Comm_size(comm, &nproc) ;
#if MYDEBUG > 0
fprintf(stdout, "\n proc %d : inside split", myid) ;
fflush(stdout) ;
#endif
/*
   ----------------------------------------------------------------
   check the data and the map

   the value of the MPI_TAG_UB attribute is a POINTER to an int,
   so the address of a pointer variable must be handed to 
   MPI_Attr_get(). (handing over the address of an int lets MPI
   store a pointer into it, which corrupts memory when pointers are
   wider than ints and never yields the bound itself.)
   if the attribute is not available we fall back on 32767,
   the smallest upper bound the MPI standard allows.
   ----------------------------------------------------------------
*/
ptagUB = NULL ;
flag   = 0 ;
MPI_Attr_get(MPI_COMM_WORLD, MPI_TAG_UB, &ptagUB, &flag) ;
if ( flag != 0 && ptagUB != NULL ) {
   tag_bound = *ptagUB ;
} else {
   tag_bound = 32767 ;
}
if ( firsttag < 0 || firsttag > tag_bound ) {
   fprintf(stderr, "\n fatal error in DInpMtx_MPI_split()"
           "\n firsttag = %d, tag_bound = %d", firsttag, tag_bound) ;
   exit(-1) ;
}
nent  = DInpMtx_nent(inpmtx) ;
ivec1 = DInpMtx_ivec1(inpmtx) ;
IV_sizeAndEntries(mapIV, &nvtx, &map) ;
if ( nvtx <= 0 || map == NULL ) {
   fprintf(stderr, "\n process %d : nvtx = %d, map = %p",
           myid, nvtx, (void *) map) ;
   exit(-1) ;
}
/*
   first coordinates index into map[], they must lie in [0,nvtx)
*/
if ( (val = IVmin(nent, ivec1, &ient)) < 0 ) {
   fprintf(stderr, "\n process %d : IV_min(ivec1) = %d", 
           myid, val) ;
   exit(-1) ;
}
if ( (val = IVmax(nent, ivec1, &ient)) >= nvtx ) {
   fprintf(stderr, "\n process %d : IV_max(ivec1) = %d", 
           myid, val) ;
   exit(-1) ;
}
/*
   map values are process ids, they must lie in [0,nproc)
*/
if ( (val = IV_min(mapIV)) < 0 ) {
   fprintf(stderr, "\n process %d : IVmin(mapIV) = %d", 
           myid, val) ;
   exit(-1) ;
}
if ( (val = IV_max(mapIV)) >= nproc ) {
   fprintf(stderr, "\n process %d : IVmax(mapIV) = %d", 
           myid, val) ;
   exit(-1) ;
}
/*
   ----------------------------------
   compute the number of entries that 
   must be sent to each other process

   outcounts and incounts share one 
   allocation of 2*nproc integers
   ----------------------------------
*/
outcounts = IVinit(2*nproc, 0) ;
incounts  = outcounts + nproc ;
for ( ient = nkeep = 0 ; ient < nent ; ient++ ) {
   if ( (iproc = map[ivec1[ient]]) != myid ) {
      outcounts[iproc]++ ;
   } else {
      nkeep++ ;
   }
}
/*
   -------------------------------
   do an all-to-all gather/scatter
   -------------------------------
*/
MPI_Alltoall((void *) outcounts, 1, MPI_INT,
             (void *) incounts,  1, MPI_INT, comm) ;
nowned = nkeep + IVsum(nproc, incounts) ;
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n\n incounts") ;
   IVfprintf(msgFile, nproc, incounts) ;
   fprintf(msgFile, "\n\n outcounts") ;
   IVfprintf(msgFile, nproc, outcounts) ;
   fflush(msgFile) ;
}
/*
   ------------------------------------------------------------
   allocate the send/receive buffers

   the in buffer is used first to hold the kept entries, so it
   must hold max(nkeep, largest incoming count) triples. it is
   given room for at least one triple so that its storage is
   always allocated: extractTriples() writes the triple count 
   into the buffer even when no entry is kept.
   ------------------------------------------------------------
*/
nent = nkeep ;
for ( iproc = 0 ; iproc < nproc ; iproc++ ) {
   if ( nent < incounts[iproc] ) {
      nent = incounts[iproc] ;
   }
}
if ( nent < 1 ) {
   nent = 1 ;
}
inbuffer = TriplesBuffer_new() ;
TriplesBuffer_init(inbuffer, nent) ;
#if MYDEBUG > 0
fprintf(stdout, 
        "\n proc %d : in buffer set up, size = %d", myid, nent) ;
fflush(stdout) ;
#endif
nent = 0 ;
for ( iproc = 0 ; iproc < nproc ; iproc++ ) {
   if ( nent < outcounts[iproc] ) {
      nent = outcounts[iproc] ;
   }
}
outbuffer = TriplesBuffer_new() ;
TriplesBuffer_init(outbuffer, nent) ;
#if MYDEBUG > 0
fprintf(stdout, 
        "\n proc %d : out buffer set up, size = %d", myid, nent) ;
fflush(stdout) ;
#endif
/*
   -------------------------------------------------------
   set up the new DInpMtx object to hold the owned entries
   -------------------------------------------------------
*/
keepmtx = DInpMtx_new() ;
DInpMtx_init(keepmtx, inpmtx->coordType, inpmtx->inputMode,
             nowned, 0) ;
/*
   ----------------------------------------------------------
   extract the triples from the original matrix,
   put them into the triples buffer,
   then load the triples into the keep matrix
   change the storage mode of the keep matrix to raw triples.
   ----------------------------------------------------------
*/
extractTriples(inpmtx, inbuffer, myid, map) ;
if ( msglvl > 2 ) {
   fprintf(msgFile, "\n kept entries in buffer") ;
   TriplesBuffer_writeForHumanEye(inbuffer, msgFile ) ;
   fflush(msgFile) ;
}
DInpMtx_inputTriples(keepmtx, inbuffer->size, inbuffer->ivec1,
                     inbuffer->ivec2, inbuffer->dvec) ;
DInpMtx_changeStorageMode(keepmtx, 1) ;
if ( msglvl > 1 ) {
/*
   report the mode of the matrix whose mode was just changed
*/
   fprintf(msgFile, "\n storage mode is now %d",
           DInpMtx_storageMode(keepmtx)) ;
   fflush(msgFile) ;
}
if ( msglvl > 2 ) {
   fprintf(msgFile, 
           "\n\n keepmtx after storing owned original entries") ;
   DInpMtx_writeForHumanEye(keepmtx, msgFile) ;
   fflush(msgFile) ;
}
/*
   ----------------------------------
   loop over the other processes, 
      gather values and send them off
      receive values
   ----------------------------------
*/
if ( msglvl > 1 ) {
   fprintf(msgFile, 
           "\n\n ready to split entries and send to other processes") ;
   fflush(msgFile) ;
}
tag = firsttag ;
for ( offset = 1 ; offset < nproc ; offset++ ) {
/*
   right = (myid + offset) mod nproc, left = (myid - offset) mod nproc
*/
   right = (myid + offset) % nproc ;
   if ( offset <= myid ) {
      left = myid - offset ;
   } else {
      left = nproc + myid - offset ;
   }
   outcount = outcounts[right] ;
   incount  = incounts[left] ;
   if ( msglvl > 1 ) {
      fprintf(msgFile, 
         "\n ### process %d, send %d to right %d, recv %d from left %d",
              myid, outcount, right, incount, left) ;
      fflush(msgFile) ;
   }
   if ( outcount > 0 ) {
/*
      -----------------------------------------
      load the entries to send to process right
      -----------------------------------------
*/
      if ( msglvl > 2 ) {
         fprintf(msgFile, "\n\n matrix to extract from") ;
         DInpMtx_writeForHumanEye(inpmtx, msgFile ) ;
         fflush(msgFile) ;
      }
      TriplesBuffer_init(outbuffer, outcount) ;
      extractTriples(inpmtx, outbuffer, right, map) ;
      if ( msglvl > 2 ) {
         fprintf(msgFile, "\n\n entries in outbuffer") ;
         TriplesBuffer_writeForHumanEye(outbuffer, msgFile ) ;
         fflush(msgFile) ;
      }
      destination = right ;
   } else {
/*
      nothing to send, MPI_PROC_NULL turns the send into a no-op
*/
      destination = MPI_PROC_NULL ;
   }
   if ( incount > 0 ) {
/*
      -----------------------
      set up the input buffer
      -----------------------
*/
      TriplesBuffer_init(inbuffer, incount) ;
      source = left ;
   } else {
/*
      nothing to receive, MPI_PROC_NULL turns the receive into a no-op
*/
      source = MPI_PROC_NULL ;
   }
/*
   -----------------
   do a send receive
   -----------------
*/
   TriplesBuffer_sendrecv(outbuffer, inbuffer, destination, source, tag,
                          comm, &status, stats, msglvl, msgFile) ;
   if ( incount > 0 ) {
/*
      -------------------------------------
      load the triples into the keep matrix
      -------------------------------------
*/
      if ( msglvl > 2 ) {
         fprintf(msgFile, "\n\n entries in inbuffer") ;
         TriplesBuffer_writeForHumanEye(inbuffer, msgFile ) ;
         fflush(msgFile) ;
      }
      DInpMtx_inputTriples(keepmtx, inbuffer->size, inbuffer->ivec1,
                           inbuffer->ivec2, inbuffer->dvec) ;
      if ( msglvl > 2 ) {
         fprintf(msgFile, 
                 "\n\n keepmtx after storing received entries") ;
         DInpMtx_writeForHumanEye(keepmtx, msgFile) ;
         fflush(msgFile) ;
      }
   }
}
/*
   -----------------------------
   sort and compress the entries
   -----------------------------
*/
if ( msglvl > 3 ) {
   fprintf(msgFile, "\n before changing storage mode to %d", 2) ;
   DInpMtx_writeForHumanEye(keepmtx, msgFile) ;
   fflush(msgFile) ;
}
DInpMtx_changeStorageMode(keepmtx, 2) ;
if ( msglvl > 3 ) {
   fprintf(msgFile, "\n after changing storage mode to %d", 2) ;
   DInpMtx_writeForHumanEye(keepmtx, msgFile) ;
   fflush(msgFile) ;
}
#if MYDEBUG > 0
         fprintf(stdout, "\n proc %d : entries sorted and compressed",
                 myid) ;
         fflush(stdout) ;
#endif
/*
   ------------------------------------
   check that the data we have is valid
   ------------------------------------
*/
nent  = DInpMtx_nent(keepmtx) ;
ivec1 = DInpMtx_ivec1(keepmtx) ;
for ( ient = 0 ; ient < nent ; ient++ ) {
   ival1 = ivec1[ient] ;
   if ( ival1 < 0 || ival1 >= nvtx ) {
      fprintf(stderr, 
              "\n process %d: fatal error in DInpMtx_MPI_split()"
              "\n nvtx = %d, ival1 = %d", myid, nvtx, ival1) ;
      exit(-1) ;
   }
   if ( map[ival1] != myid ) {
      fprintf(stderr, 
              "\n process %d: fatal error in DInpMtx_MPI_split()"
              "\n ival1 = %d, map = %d",
              myid, ival1, map[ival1]) ;
      exit(-1) ;
   }
}
/*
   ------------------------------------------------
   free the working storage
   (incounts lives inside the outcounts allocation)
   ------------------------------------------------
*/
TriplesBuffer_free(inbuffer) ;
TriplesBuffer_free(outbuffer) ;
IVfree(outcounts) ;

return(keepmtx) ; }

/*--------------------------------------------------------------------*/
/*
   ----------------------------------------------
   allocate a TriplesBuffer object and put it 
   into its default (empty) state

   return value -- pointer to the new object

   created -- 97nov12, cca
   ----------------------------------------------
*/
static TriplesBuffer *
TriplesBuffer_new (
   void
) {
TriplesBuffer   *fresh ;
/*
   one object, then reset every field
*/
ALLOCATE(fresh, struct _TriplesBuffer, 1) ;
TriplesBuffer_setDefaultFields(fresh) ;

return(fresh) ; }

/*--------------------------------------------------------------------*/
/*
   ---------------------------------------------------
   put the object into its default state : 
   no storage, no triples, all pointers NULL.
   nothing is freed here, see TriplesBuffer_clearData.

   a NULL buffer is a fatal error.

   created -- 97nov12, cca
   ---------------------------------------------------
*/
static void
TriplesBuffer_setDefaultFields (
   TriplesBuffer *buffer
) {
if ( buffer != NULL ) {
/*
   pointers first, then the two counters
*/
   buffer->base    = NULL ;
   buffer->dvec    = NULL ;
   buffer->ivec2   = NULL ;
   buffer->ivec1   = NULL ;
   buffer->maxsize =   0  ;
   buffer->size    =   0  ;

   return ; 
}
fprintf(stderr, 
        "\n fatal error in TriplesBuffer_setDefaultFields(%p)"
        "\n NULL pointer\n", buffer) ;
exit(-1) ; }

/*--------------------------------------------------------------------*/
/*
   -------------------------------------------------
   release the storage owned by the buffer (if any)
   and return the object to its default state.
   the object itself is not freed.

   a NULL buffer is a fatal error.

   created -- 97nov12, cca
   -------------------------------------------------
*/
static void
TriplesBuffer_clearData (
   TriplesBuffer *buffer
) {
if ( buffer == NULL ) {
   fprintf(stderr, 
           "\n fatal error in TriplesBuffer_clearData(%p)"
           "\n NULL pointer\n", buffer) ;
   exit(-1) ;
}
/*
   ivec1, ivec2 and dvec point into base, 
   so base is the only thing to release
*/
if ( buffer->base != NULL ) {
   FREE(buffer->base) ;
}
/*
   forget the pointers and sizes
*/
TriplesBuffer_setDefaultFields(buffer) ;

return ; }

/*--------------------------------------------------------------------*/
/*
   ------------------------------------------
   release the buffer's storage and then
   the object itself.

   a NULL buffer is a fatal error.

   created -- 97nov12, cca
   ------------------------------------------
*/
static void
TriplesBuffer_free (
   TriplesBuffer *buffer
) {
if ( buffer == NULL ) {
   fprintf(stderr, 
           "\n fatal error in TriplesBuffer_free(%p)"
           "\n NULL pointer\n", buffer) ;
   exit(-1) ;
}
/*
   storage first, object second
*/
TriplesBuffer_clearData(buffer) ;
FREE(buffer) ;

return ; }

/*--------------------------------------------------------------------*/
/*
   ----------------------------------------------------------------
   initialize the buffer to hold size triples

   the storage is one allocation of doubles (buffer->base) :
      int slot 0       -- header, the number of triples,
                          written by extractTriples()
      ivec1[0..size-1] -- starts at int slot 1
      ivec2[0..size-1] -- starts at int slot 1 + size
      dvec[0..size-1]  -- starts at the first double slot
                          behind the integer section
   the number of doubles needed is nper*size + 1, where nper is
   3 when sizeof(int) == sizeof(double) and 2 when 
   2*sizeof(int) == sizeof(double). any other ratio is fatal.

   storage is (re)allocated when the buffer has none or when it 
   is too small, the old contents are discarded in that case.
   storage is allocated even for size == 0 so that the header
   slot always exists (extractTriples() writes to it). 

   buffer -- buffer to set up, NULL is a fatal error
   size   -- number of triples, a negative value is a fatal error

   created -- 97nov12, cca
   ----------------------------------------------------------------
*/
static void
TriplesBuffer_init (
   TriplesBuffer   *buffer,
   int             size
) {
double   *dbuf ;
int      ndouble, nper ;
int      *ibuf ;

if ( buffer == NULL || size < 0 ) {
   fprintf(stderr, "\n fatal error in TriplesBuffer_init(%p,%d)"
           "\n bad input\n", (void *) buffer, size) ;
   exit(-1) ;
}
/*
   nper = number of doubles needed per triple
*/
if ( sizeof(int) == sizeof(double) ) {
   nper = 3 ;
} else if ( 2*sizeof(int) == sizeof(double) ) {
   nper = 2 ;
} else {
   fprintf(stderr, "\n fatal error in TriplesBuffer_init()"
           "\n sizeof(int) = %d, sizeof(double) = %d",
           (int) sizeof(int), (int) sizeof(double)) ;
   exit(-1) ;
}
if ( buffer->base == NULL || buffer->maxsize < size ) {
/*
   no storage yet or not enough, get a new allocation
*/
   TriplesBuffer_clearData(buffer) ;
   ndouble = nper*size + 1 ;
   ALLOCATE(buffer->base, double, ndouble) ;
   buffer->maxsize = size ;
}
/*
   lay out the three vectors for exactly size triples
*/
ibuf = (int *)    buffer->base ;
dbuf = (double *) buffer->base ;
buffer->ivec1 = ibuf + 1 ;
buffer->ivec2 = buffer->ivec1 + size ;
buffer->dvec  = dbuf + 1 + (nper - 1)*size ;
buffer->size  = size ;

return ; }

/*--------------------------------------------------------------------*/
/*
   ----------------------------------------------------------------
   move entries out of the matrix and into the buffer.

   an entry is moved when map[] of its first coordinate equals
   target. the entries that are not moved are shifted toward the
   front of the matrix vectors (their order is kept) and the 
   matrix entry count is reduced accordingly.

   on return buffer->size and the integer header at the start of
   buffer->base both hold the number of entries moved.

   the values are copied only when both the buffer and the matrix
   have a values vector.

   the buffer must have been set up by TriplesBuffer_init() with
   room for all the entries that will be moved, no bound is 
   checked here.

   created -- 97nov12, cca
   ----------------------------------------------------------------
*/
static void
extractTriples (
   DInpMtx         *inpmtx,
   TriplesBuffer   *buffer,
   int             target,
   int             map[]
) {
double   *bufvals, *mtxvals ;
int      iread, nkept, nmoved, nsrc, owner ;
int      *buffirst, *bufsecond, *mtxfirst, *mtxsecond ;
/*
   destination vectors (buffer) and source vectors (matrix)
*/
buffirst  = buffer->ivec1 ;
bufsecond = buffer->ivec2 ;
bufvals   = buffer->dvec  ;
mtxfirst  = DInpMtx_ivec1(inpmtx) ;
mtxsecond = DInpMtx_ivec2(inpmtx) ;
mtxvals   = DInpMtx_dvec(inpmtx) ;
nsrc      = DInpMtx_nent(inpmtx) ;
/*
   one pass : iread is the read position, nmoved the write position
   in the buffer, nkept the write position in the matrix
*/
nmoved = 0 ;
nkept  = 0 ;
for ( iread = 0 ; iread < nsrc ; iread++ ) {
   owner = map[mtxfirst[iread]] ;
#if MYDEBUG > 0
   fprintf(stdout, 
           "\n ivec1[%d] = %d, jproc = %d", iread, mtxfirst[iread], owner) ;
   fflush(stdout) ;
#endif
   if ( owner != target ) {
#if MYDEBUG > 0
      fprintf(stdout, 
              ", keeping, location %d", nkept) ;
      fflush(stdout) ;
#endif
/*
      entry stays, slide it down unless it is already in place
*/
      if ( nkept < iread ) {
         mtxfirst[nkept]  = mtxfirst[iread] ;
         mtxsecond[nkept] = mtxsecond[iread] ;
         if ( mtxvals != NULL ) {
            mtxvals[nkept] = mtxvals[iread] ;
         }
      }
      nkept++ ;
      continue ;
   }
#if MYDEBUG > 0
   fprintf(stdout, 
           ", moving to buffer, location %d", nmoved) ;
   fflush(stdout) ;
#endif
/*
   entry belongs to target, append it to the buffer
*/
   buffirst[nmoved]  = mtxfirst[iread] ;
   bufsecond[nmoved] = mtxsecond[iread] ;
   if ( bufvals != NULL && mtxvals != NULL ) {
      bufvals[nmoved] = mtxvals[iread] ;
   }
   nmoved++ ;
}
#if MYDEBUG > 0
fprintf(stdout, "\n end: ii = %d, jj = %d, kk = %d", iread, nmoved, nkept) ;
fflush(stdout) ;
#endif
/*
   the matrix keeps nkept entries, the buffer holds nmoved,
   the count also goes into the header slot of the buffer
*/
DInpMtx_setNent(inpmtx, nkept) ;
*((int *) buffer->base) = nmoved ;
buffer->size = nmoved ;

return ; }

/*--------------------------------------------------------------------*/
/*
   ----------------------------------------------------------------
   do a send and receive

   the whole storage of outbuffer (header + ivec1 + ivec2 + dvec)
   is sent to destination as a vector of doubles and the storage
   of inbuffer is filled from source, with one MPI_Sendrecv() call.
   the message length is nper*size + 1 doubles, where nper is
   3 when sizeof(int) == sizeof(double) and 2 when
   2*sizeof(int) == sizeof(double), the same rule that
   TriplesBuffer_init() uses to allocate. any other ratio is fatal.

   outbuffer   -- buffer to send, its size field gives the length
   inbuffer    -- buffer to receive into, its size field gives
                  the expected length
   destination -- process to send to, or MPI_PROC_NULL for no send
   source      -- process to receive from, or MPI_PROC_NULL 
                  for no receive
   tag         -- message tag used for both send and receive
   comm        -- communicator
   status      -- filled by MPI_Sendrecv()
   stats       -- statistics, incremented when a message is 
                  sent/received
      stats[0] -- # of messages sent
      stats[1] -- # of messages received
      stats[2] -- # of bytes sent
      stats[3] -- # of bytes received
   msglvl      -- message level, output is written when > 1
   msgFile     -- message file

   created -- 97nov12, cca
   ----------------------------------------------------------------
*/
static void
TriplesBuffer_sendrecv (
   TriplesBuffer   *outbuffer,
   TriplesBuffer   *inbuffer,
   int             destination,
   int             source,
   int             tag,
   MPI_Comm        comm,
   MPI_Status      *status,
   int             stats[],
   int             msglvl,
   FILE            *msgFile
) {
void   *inbase, *outbase ;
int    count, inndouble, nper, outndouble ;

outbase = outbuffer->base ;
inbase  = inbuffer->base  ;
/*
   nper = number of doubles per triple. the lengths below were
   formerly left unset when neither size relation held, 
   now that case is reported and stops the program.
*/
if ( sizeof(int) == sizeof(double) ) {
   nper = 3 ;
} else if ( 2*sizeof(int) == sizeof(double) ) {
   nper = 2 ;
} else {
   fprintf(stderr, "\n fatal error in TriplesBuffer_sendrecv()"
           "\n sizeof(int) = %d, sizeof(double) = %d",
           (int) sizeof(int), (int) sizeof(double)) ;
   exit(-1) ;
}
/*
   message lengths in doubles, zero when there is no partner
*/
if ( destination != MPI_PROC_NULL ) {
   outndouble = nper*outbuffer->size + 1 ;
} else {
   outndouble = 0 ;
}
if ( source != MPI_PROC_NULL ) {
   inndouble = nper*inbuffer->size + 1 ;
} else {
   inndouble = 0 ;
}
if ( msglvl > 1 ) {
   fprintf(msgFile, 
           "\n destination = %d, source = %d"
           "\n outsize = %d, insize = %d"
           "\n inbase  = %p, inndouble  = %d"
           "\n outbase = %p, outndouble = %d",
           destination, source, outbuffer->size, inbuffer->size,
           inbase, inndouble, outbase, outndouble) ;
   fflush(msgFile) ;
}
/*
   update the statistics
*/
if ( inndouble > 0 ) {
   stats[1]++ ;
   stats[3] += (int) (inndouble*sizeof(double)) ;
}
if ( outndouble > 0 ) {
   stats[0]++ ;
   stats[2] += (int) (outndouble*sizeof(double)) ;
}
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n calling MPI_Sendrecv") ;
   fflush(msgFile) ;
}
MPI_Sendrecv(outbase, outndouble, MPI_DOUBLE, destination, tag,
             inbase, inndouble, MPI_DOUBLE, source, tag,
             comm, status) ;
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n return from MPI_Sendrecv") ;
   MPI_Get_count(status, MPI_DOUBLE, &count) ;
   fprintf(msgFile, "\n count = %d", count) ;
   fflush(msgFile) ;
}

return ; }

/*--------------------------------------------------------------------*/
/*
   ------------------------------------------------------------
   print out the triples

   a header line with size and maxsize is always written.
   when the buffer has a values vector the entries are written
   as <first,second,value>, two per line, otherwise as 
   <first,second>, four per line. a shorter last line holds
   the entries that are left over.

   buffer  -- buffer to print
   msgFile -- file to write to
   ------------------------------------------------------------
*/
static void
TriplesBuffer_writeForHumanEye (
   TriplesBuffer   *buffer,
   FILE            *msgFile
) {
double   *vals ;
int      k, nleft, ntriple ;
int      *first, *second ;

ntriple = buffer->size  ;
first   = buffer->ivec1 ;
second  = buffer->ivec2 ;
vals    = buffer->dvec  ;

fprintf(msgFile, "\n\n TriplesBuffer: size = %d, maxsize = %d",
        buffer->size, buffer->maxsize) ;
if ( ntriple <= 0 ) {
   return ;
}
if ( vals != NULL ) {
/*
   full lines of two triples, then at most one leftover triple
*/
   for ( k = 0 ; k + 1 < ntriple ; k += 2 ) {
      fprintf(msgFile, 
              "\n <%6d,%6d,%20.12e> <%6d,%6d,%20.12e>",
              first[k],   second[k],   vals[k],
              first[k+1], second[k+1], vals[k+1]) ;
   }
   if ( k < ntriple ) { 
      fprintf(msgFile, "\n <%6d,%6d,%20.12e>",
              first[k], second[k], vals[k]) ;
   }
   return ;
}
/*
   no values : full lines of four pairs, 
   then one line with the one, two or three pairs left over
*/
for ( k = 0 ; k + 3 < ntriple ; k += 4 ) {
   fprintf(msgFile, 
           "\n <%6d,%6d> <%6d,%6d> <%6d,%6d> <%6d,%6d> ",
           first[k],   second[k], 
           first[k+1], second[k+1], 
           first[k+2], second[k+2], 
           first[k+3], second[k+3]) ;
}
nleft = ntriple - k ;
switch ( nleft ) {
case 3 :
   fprintf(msgFile, "\n <%6d,%6d> <%6d,%6d> <%6d,%6d> ",
           first[k],   second[k], 
           first[k+1], second[k+1], 
           first[k+2], second[k+2]) ;
   break ;
case 2 :
   fprintf(msgFile, "\n <%6d,%6d> <%6d,%6d> ",
           first[k],   second[k], 
           first[k+1], second[k+1]) ;
   break ;
case 1 :
   fprintf(msgFile, "\n <%6d,%6d> ", 
           first[k], second[k]) ;
   break ;
default :
   break ;
}
return ; }

/*--------------------------------------------------------------------*/
