/*  split.c  */

#include "../spoolesMPI.h"

#define MYDEBUG 0

/*--------------------------------------------------------------------*/
/*
   forward declaration: move the rows of sourcemtx owned by process
   destination into destmtx (defined at the end of this file)
*/
static void extractRows ( DDenseMtx *sourcemtx, DDenseMtx *destmtx,
                          int rowmap[], int destination,
                          int msglvl, FILE *msgFile ) ;
/*--------------------------------------------------------------------*/
/*
   ------------------------------------------------------------------
   initialize dmtx to have nrow rows and ncol columns. The last two
   DDenseMtx_init arguments (presumably the row/column increments)
   are (1, nrow) when inc1 == 1 and (ncol, 1) otherwise, so the new
   object follows the same layout choice as the matrix being split.
   ------------------------------------------------------------------
*/
static void
initRowBlock (
   DDenseMtx   *dmtx,
   int         inc1,
   int         id,
   int         nrow,
   int         ncol
) {
if ( inc1 == 1 ) {
   DDenseMtx_init(dmtx, id, -1, nrow, ncol, 1, nrow) ;
} else {
   DDenseMtx_init(dmtx, id, -1, nrow, ncol, ncol, 1) ;
}
return ; }
/*
   -----------------------------------------------------------------
   purpose -- to split a DDenseMtx object by rows

   mtx         -- DDenseMtx object. It is modified: every row is
                  extracted from it (kept locally or sent to its
                  owner), so on return it has no rows left.
   rowmapIV    -- map from rows to owning processes. Every row index
                  in mtx must lie in [0, size of rowmapIV). Every map
                  value used must lie in [0, nproc). Otherwise this
                  is a fatal error.
   firsttag    -- tag used for every message in this call; must lie
                  in [0, MPI_TAG_UB]
   stats[4]    -- statistics vector, incremented (not reset)
      stats[0] -- # of messages sent
      stats[1] -- # of messages received
      stats[2] -- # of bytes sent
      stats[3] -- # of bytes received
   msglvl      -- message level
   msgFile     -- message file, must be non-NULL when msglvl > 0
   comm        -- MPI communicator

   return value -- a new DDenseMtx object filled with the owned rows,
                   sorted by DDenseMtx_sort()

   fatal errors are reported on stderr, because msgFile may be NULL

   created  -- 97jul13, cca
   modified -- 97oct17, cca
      stats added
   modified -- 97nov08, cca
      send/recv replaces send and receive
   -----------------------------------------------------------------
*/
DDenseMtx *
DDenseMtx_MPI_splitByRows (
   DDenseMtx   *mtx,
   IV          *rowmapIV,
   int         firsttag,
   int         stats[],
   int         msglvl,
   FILE        *msgFile,
   MPI_Comm    comm
) {
DDenseMtx    *inmtx, *keepmtx, *outmtx ;
double       *inbuffer, *outbuffer ;
int          destination, flag, ii, inbuffersize, inc1, incount, iproc, 
             irow, left, myid, ncol, ndouble, neqns, nkeep, nowned, 
             nproc, nrecv, nrow, nsend, offset, outbuffersize, 
             outcount, right, source, tag, tag_bound ;
int          *rowind, *rowmap, *rowsToRecv, *rowsToSend, *tag_bound_ptr ;
MPI_Status   status ;
/*
   --------------
   check the data
   --------------
*/
if ( mtx == NULL || rowmapIV == NULL 
   || (msglvl > 0 && msgFile == NULL) ) {
/*
   msgFile may itself be the bad argument, so report on stderr
*/
   fprintf(stderr, "\n fatal error in DDenseMtx_MPI_splitByRows()"
          "\n mtx %p, rowmapIV %p, firsttag %d, msglvl %d, msgFile %p"
          "\n bad input\n", (void *) mtx, (void *) rowmapIV, firsttag, 
          msglvl, (void *) msgFile) ;
   exit(-1) ;
}
tag = firsttag ;
/*
   the MPI_TAG_UB attribute value is a POINTER to an int holding the
   bound, so it must be fetched into an int pointer and dereferenced
*/
tag_bound_ptr = NULL ;
MPI_Attr_get(MPI_COMM_WORLD, MPI_TAG_UB, (void *) &tag_bound_ptr, &flag) ;
if ( flag == 0 || tag_bound_ptr == NULL ) {
   fprintf(stderr, "\n fatal error in DDenseMtx_MPI_splitByRows()"
           "\n unable to retrieve the MPI_TAG_UB attribute\n") ;
   exit(-1) ;
}
tag_bound = *tag_bound_ptr ;
if ( tag < 0 || tag > tag_bound ) {
   fprintf(stderr, "\n fatal error in DDenseMtx_MPI_splitByRows()"
           "\n tag = %d, tag_bound = %d", tag, tag_bound) ;
   exit(-1) ;
}
/*
   -------------------------------------------------
   get id of self, # of processes and # of equations
   -------------------------------------------------
*/
MPI_Comm_rank(comm, &myid) ;
MPI_Comm_size(comm, &nproc) ;
IV_sizeAndEntries(rowmapIV, &neqns, &rowmap) ;
DDenseMtx_dimensions(mtx, &nrow, &ncol) ;
inc1 = mtx->inc1 ;
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n\n inside DDenseMtx_MPI_splitByRows"
           "\n nproc = %d, myid = %d, neqns = %d, nrow = %d, ncol = %d",
           nproc, myid, neqns, nrow, ncol) ;
   fflush(msgFile) ;
}
/*
   -------------------------------------------------------------
   get the counts of the entries to send to the other processors,
   validating every row index and owner before using it as an
   array index
   -------------------------------------------------------------
*/
DDenseMtx_rowIndices(mtx, &nrow, &rowind) ;
rowsToSend = IVinit(2*nproc, 0) ;
rowsToRecv = rowsToSend + nproc ;
for ( ii = 0, nkeep = 0 ; ii < nrow ; ii++ ) {
   irow = rowind[ii] ;
   if ( irow < 0 || irow >= neqns ) {
      fprintf(stderr, "\n fatal error in DDenseMtx_MPI_splitByRows()"
              "\n process %d : local row %d, global row %d, neqns = %d\n",
              myid, ii, irow, neqns) ;
      exit(-1) ;
   }
   iproc = rowmap[irow] ;
   if ( iproc < 0 || iproc >= nproc ) {
      fprintf(stderr, "\n fatal error in DDenseMtx_MPI_splitByRows()"
              "\n process %d : global row %d mapped to %d, nproc = %d\n",
              myid, irow, iproc, nproc) ;
      exit(-1) ;
   }
   if ( iproc != myid ) {
      rowsToSend[iproc]++ ;
   } else {
      nkeep++ ;
   }
}
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n nkeep = %d, row send counts ", nkeep) ;
   IVfprintf(msgFile, nproc, rowsToSend) ;
   fflush(msgFile) ;
}
/*
   -------------------------------
   do an all-to-all gather/scatter
   -------------------------------
*/
MPI_Alltoall((void *) rowsToSend, 1, MPI_INT,
             (void *) rowsToRecv, 1, MPI_INT, comm) ;
nowned = nkeep + IVsum(nproc, rowsToRecv) ;
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n nkeep = %d, row receive counts ", nkeep) ;
   IVfprintf(msgFile, nproc, rowsToRecv) ;
   fflush(msgFile) ;
}
/*
   ----------------------------------------------------
   determine the buffer size: the largest single message
   ----------------------------------------------------
*/
nsend = IVmax(nproc, rowsToSend, &iproc) ;
nrecv = IVmax(nproc, rowsToRecv, &iproc) ;
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n nsend = %d, nrecv = %d", nsend, nrecv) ;
   fflush(msgFile) ;
}
/*
   -------------------------------------------
   allocate the send/receive DDenseMtx objects
   -------------------------------------------
*/
outmtx = DDenseMtx_new() ;
initRowBlock(outmtx, inc1, myid, nsend, ncol) ;
inmtx = DDenseMtx_new() ;
initRowBlock(inmtx, inc1, myid, nrecv, ncol) ;
/*
   -------------------------------------
   allocate the DDenseMtx object to keep
   -------------------------------------
*/
keepmtx = DDenseMtx_new() ;
initRowBlock(keepmtx, inc1, myid, nowned, ncol) ;
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n keepmtx object allocated") ;
   fflush(msgFile) ;
}
/*
   ----------------------------------------------------------------
   copy all rows to keep from the input matrix into the keep matrix
   ----------------------------------------------------------------
*/
extractRows(mtx, keepmtx, rowmap, myid, msglvl, msgFile) ;
if ( msglvl > 2 ) {
   fprintf(msgFile, "\n\n keepmtx") ;
   DDenseMtx_writeForHumanEye(keepmtx, msgFile) ;
   fflush(msgFile) ;
}
/*
   --------------------------------------------------------------
   loop over the processes, gather their values and send them off.
   at step offset, send to myid + offset and receive from
   myid - offset (both modulo nproc)
   --------------------------------------------------------------
*/
for ( offset = 1 ; offset < nproc ; offset++ ) {
   right = (myid + offset) % nproc ;
   left  = (myid - offset + nproc) % nproc ;
   outcount = rowsToSend[right] ;
   incount  = rowsToRecv[left] ;
   if ( msglvl > 1 ) {
      fprintf(msgFile, 
       "\n\n ### process %d, send %d to right %d, recv %d from left %d",
       myid, outcount, right, incount, left) ;
      fflush(msgFile) ;
   }
   if ( outcount > 0 ) {
/*
      --------------------------
      load the out matrix object
      --------------------------
*/
      initRowBlock(outmtx, inc1, myid, outcount, ncol) ;
      extractRows(mtx, outmtx, rowmap, right, msglvl, msgFile) ;
      destination = right ;
      if ( msglvl > 2 ) {
         fprintf(msgFile, "\n\n outmtx for process %d", destination) ;
         DDenseMtx_writeForHumanEye(outmtx, msgFile) ;
         fflush(msgFile) ;
      }
      stats[0]++ ;
      stats[2] += sizeof(double)*DV_size(&outmtx->wrkDV) ;
   } else {
/*
      ------------------------------------------
      set the destination to be the NULL process
      ------------------------------------------
*/
      destination = MPI_PROC_NULL ;
   }
   if ( incount > 0 ) {
/*
      ----------------------------------
      initialize the input matrix object
      ----------------------------------
*/
      initRowBlock(inmtx, inc1, myid, incount, ncol) ;
      if ( msglvl > 2 ) {
         fprintf(msgFile, "\n\n inmtx initialized to have %d rows",
                 incount) ;
         fflush(msgFile) ;
      }
      source = left ;
      stats[1]++ ;
      stats[3] += sizeof(double)*DV_size(&inmtx->wrkDV) ;
   } else {
      source = MPI_PROC_NULL ;
   }
/*
   -----------------------------------------------------------
   do a send/receive; a MPI_PROC_NULL peer makes that half a
   no-op, so stale buffer sizes are harmless in that case
   -----------------------------------------------------------
*/
   inbuffersize = outbuffersize = 0 ;
   inbuffer     = outbuffer     = NULL ;
   if ( outmtx != NULL ) {
      outbuffersize = DV_size(&outmtx->wrkDV) ;
      outbuffer     = DV_entries(&outmtx->wrkDV) ;
   }
   if ( inmtx != NULL ) {
      inbuffersize = DV_size(&inmtx->wrkDV) ;
      inbuffer     = DV_entries(&inmtx->wrkDV) ;
   }
   MPI_Sendrecv((void*) outbuffer, outbuffersize, MPI_DOUBLE, 
                destination, tag, (void*) inbuffer, inbuffersize, 
                MPI_DOUBLE, source, tag, comm, &status) ;
   if ( msglvl > 3 ) {
      MPI_Get_count(&status, MPI_DOUBLE, &ndouble) ;
      fprintf(msgFile, 
            "\n\n message received, source %d, tag %d, double count %d",
            status.MPI_SOURCE, status.MPI_TAG, ndouble) ;
      fprintf(msgFile, "\n MPI_ERROR = %d", status.MPI_ERROR) ;
      fflush(msgFile) ;
   }
   if ( source != MPI_PROC_NULL ) {
/*
      -------------------------------------
      initialize the object from its buffer
      -------------------------------------
*/
      DDenseMtx_initFromBuffer(inmtx) ;
      if ( msglvl > 3 ) {
         fprintf(msgFile,
                 "\n DDenseMtx object intialized from its buffer") ;
         fflush(msgFile) ;
      }
      if ( msglvl > 4 ) {
         DDenseMtx_writeForHumanEye(inmtx, msgFile) ;
         fflush(msgFile) ;
      }
   }
   if ( incount > 0 ) {
/*
      --------------------------------------------------
      append the received rows after the rows kept so far
      --------------------------------------------------
*/
      if ( nkeep + incount > nowned ) {
         fprintf(stderr, "\n fatal error in DDenseMtx_MPI_splitByRows()"
              "\n nkeep = %d, incount = %d, nowned = %d",
              nkeep, incount, nowned) ;
         exit(-1) ;
      }
      for ( irow = 0 ; irow < incount ; irow++, nkeep++ ) {
         DDenseMtx_copyRow(keepmtx, nkeep, inmtx, irow) ;
      }
   }
}
/*
   -------------------------
   sort the rows and columns
   -------------------------
*/
DDenseMtx_sort(keepmtx) ;
/*
   ------------------------------------------------------
   check that the matrix contains only the rows it should
   ------------------------------------------------------
*/
nrow   = keepmtx->nrow ;
rowind = keepmtx->rowind ;
for ( ii = 0 ; ii < nrow ; ii++ ) {
   irow = rowind[ii] ;
   if ( irow < 0 || irow >= neqns ) {
      fprintf(stderr, 
            "\n process %d : local row %d, global row %d, neqns = %d\n",
            myid, ii, irow, neqns) ;
      exit(-1) ;
   }
   if ( rowmap[irow] != myid ) {
      fprintf(stderr, 
            "\n process %d : local row %d, global row %d, map = %d\n",
            myid, ii, irow, rowmap[irow]) ;
      exit(-1) ;
   }
}
/*
   ------------------------
   free the working storage
   ------------------------
*/
DDenseMtx_free(outmtx) ;
DDenseMtx_free(inmtx) ;
IVfree(rowsToSend) ;

return(keepmtx) ; }

/*--------------------------------------------------------------------*/
/*
   -----------------------------------------------------------------
   move the rows of sourcemtx owned by process destination into
   destmtx, and compact the remaining rows of sourcemtx to the front

   sourcemtx   -- matrix whose rows are scanned. On return it holds
                  only the rows whose owner is not destination, in
                  their original relative order, and sourcemtx->nrow
                  is reset to that count.
   destmtx     -- receives the extracted rows in rows 0, 1, ... .
                  It must already have room for them (destmtx->nrow
                  rows). Its nrow is NOT changed here, so callers may
                  append further rows after the extracted ones.
   rowmap      -- map from global row index to owning process
   destination -- process whose rows are extracted
   msglvl      -- message level
   msgFile     -- message file, written only when msglvl > 1

   fatal error (reported on stderr, since msgFile may be NULL) if
   destmtx does not have room for all rows owned by destination

   created -- 97nov08, cca
   -----------------------------------------------------------------
*/
static void
extractRows (
   DDenseMtx   *sourcemtx,
   DDenseMtx   *destmtx,
   int         rowmap[],
   int         destination,
   int         msglvl,
   FILE        *msgFile
) {
int   idest, ii, irow, isource, jproc, ndest, nsource ;
int   *rowind ;
/*
   -------------------------------------
   copy all rows to keep from the source 
   matrix into the destination matrix
   -------------------------------------
*/
nsource = sourcemtx->nrow ;
rowind  = sourcemtx->rowind ;
ndest   = destmtx->nrow ;
idest   = isource = 0 ;
for ( ii = 0 ; ii < nsource ; ii++ ) {
   irow  = rowind[ii] ;
   jproc = rowmap[irow] ;
   if ( msglvl > 3 ) {
      fprintf(msgFile, "\n row %d, row id %d, owner %d",
              ii, irow, jproc) ;
      fflush(msgFile) ;
   }
   if ( jproc == destination ) {
/*
      -----------------------------------------------------------
      row irow belongs to this process, move row into new storage
      -----------------------------------------------------------
*/
      if ( idest == ndest ) {
         fprintf(stderr, 
           "\n fatal error in extractRows()"
           "\n ii = %d, idest = %d, ndest = %d", ii, idest, ndest) ;
         exit(-1) ;
      }
      DDenseMtx_copyRow(destmtx, idest, sourcemtx, ii) ;
      idest++ ;
      if ( msglvl > 3 ) {
         fprintf(msgFile, "\n    row kept, idest = %d", idest) ;
         fflush(msgFile) ;
      }
   } else {
/*
      --------------------------------------------------------
      row irow does not belong to this process, slide row down.
      isource <= ii always, so rows not yet scanned are never
      overwritten
      --------------------------------------------------------
*/
      if ( ii != isource ) {
         DDenseMtx_copyRow(sourcemtx, isource, sourcemtx, ii) ;
      }
      isource++ ;
      if ( msglvl > 3 ) {
         fprintf(msgFile, 
                 "\n    row slid, isource = %d", isource) ;
         fflush(msgFile) ;
      }
   }
}
/*
   ------------------------------------------------------------
   reset the number of rows in the source object; the destination
   keeps its allocated row count (see the header comment)
   ------------------------------------------------------------
*/
sourcemtx->nrow = isource ;
if ( msglvl > 1 ) {
   fprintf(msgFile, 
           "\n\n %d rows remaining in mtx object", sourcemtx->nrow) ;
   fflush(msgFile) ;
}
return ; }

/*--------------------------------------------------------------------*/
