/*  setLocalIndices.c  */

#include "../spoolesMPI.h"

/*--------------------------------------------------------------------*/
/*
   -----------------------------------------------------------------
   Msg -- bookkeeping for one non-blocking MPI send or receive,
          messages are chained into singly linked lists via next
      info[3] -- { type, frontid, nbytes }, also used directly as
                 the buffer of the notification send/receive
                 type 1 = notification, type 2 = index list
      base    -- int buffer holding an index list, NULL otherwise
      next    -- next message in the list
      req     -- request handle of the pending MPI operation
   -----------------------------------------------------------------
*/
typedef struct _Msg {
   int                info[3] ;
   int                *base   ;
   struct _Msg        *next   ;
   MPI_Request        req     ;
} Msg ;
/*--------------------------------------------------------------------*/
/*
   file-local helpers
*/
static Msg  * Msg_new ( void ) ;
static void   Msg_setDefaultFields ( Msg *msg ) ;
static void   Msg_clearData ( Msg *msg ) ;
static void   Msg_free ( Msg *msg ) ;
static Msg  * wakeup ( DFrontMtx *frontmtx, int J, int myid,
                       int owners[], Msg **pfirstsent, int firsttag,
                       int stats[], MPI_Comm comm, int msglvl,
                       FILE *msgFile ) ;
static Msg  * checkMessages ( DFrontMtx *frontmtx, int J,
                              Msg *firstrecv, int map[], int stats[],
                              MPI_Comm comm, int msglvl,
                              FILE *msgFile ) ;
static Msg  * checkSentMessages ( Msg *firstsent, int msglvl,
                                  FILE *msgFile ) ;
/*--------------------------------------------------------------------*/
/*
   -------------------------------------------------------------------
   purpose --- overwrite row and column indices bnd{J}
      with indices that are local w.r.t. J's parent

   frontmtx      -- front matrix whose index lists are overwritten
   frontOwnersIV -- frontOwners[J] is the process that owns front J
   firsttag -- first tag for messages, will use tag values
      in [firsttag, firsttag + 2*nfront - 1]
         firsttag + J          -- notification for front J
         firsttag + nfront + J -- index list of front J
   stats[] -- statistics vector, entries are incremented
      stats[0] -- # of sends
      stats[1] -- # of bytes sent
      stats[2] -- # of receives
      stats[3] -- # of bytes received
   msglvl  -- diagnostic level, > 1 traces the traversal,
              > 2 also prints ndescLeft[]
   msgFile -- diagnostic output file
   comm    -- MPI communicator

   all working storage allocated here is released before return

   created -- 97nov22, cca
   -------------------------------------------------------------------
*/
void
DFrontMtx_MPI_globalToLocalInd (
   DFrontMtx   *frontmtx,
   IV          *frontOwnersIV,
   int         firsttag,
   int         stats[],
   int         msglvl,
   FILE        *msgFile,
   MPI_Comm    comm
) {
char   *status ;
Ideq   *dequeue ;
int    flag, J, K, myid, ncolK, neqns, nfront, nrowK, ownerJ, tag_bound ;
int    *colindK, *frontOwners, *map, *ndescLeft, *par, *rowindK,
       *tag_ub_ptr ;
Msg    *firstsent ;
Msg    **p_msg ;

MPI_Comm_rank(comm, &myid) ;
frontOwners = IV_entries(frontOwnersIV) ;
par = ETree_par(frontmtx->frontETree) ;
nfront = frontmtx->nfront ;
neqns  = frontmtx->neqns  ;
/*
   ---------------------------------------------------------------
   check the tag range. the MPI_TAG_UB attribute value is a
   POINTER to the int bound, not the bound itself; if it is not
   available, fall back on 32767, the minimum MPI guarantees.
   tags firsttag + J and firsttag + nfront + J are used for
   0 <= J < nfront, so the largest tag is firsttag + 2*nfront - 1.
   (firsttag >= 0 is tested first, so tag_bound - firsttag
   cannot overflow)
   ---------------------------------------------------------------
*/
tag_ub_ptr = NULL ;
flag       = 0    ;
MPI_Attr_get(MPI_COMM_WORLD, MPI_TAG_UB, &tag_ub_ptr, &flag) ;
tag_bound = (flag && tag_ub_ptr != NULL) ? *tag_ub_ptr : 32767 ;
if ( firsttag < 0 || 2*nfront - 1 > tag_bound - firsttag ) {
   fprintf(stderr, "\n fatal error in DFrontMtx_MPI_globalToLocalInd()"
           "\n tag range is [%d,%d], tag_bound = %d",
           firsttag, firsttag + 2*nfront - 1, tag_bound) ;
   exit(-1) ;
}
/*
   ------------------------------------------------------------
   compute the local ndescLeft[] vector that synchronizes the
   post-order traversal: ndescLeft[K] counts the fronts J whose
   nearest ancestor with the same owner as J is K
   ------------------------------------------------------------
*/
ndescLeft = IVinit(nfront, 0) ;
for ( J = 0 ; J < nfront ; J++ ) {
   ownerJ = frontOwners[J] ;
   for ( K = par[J] ; K != -1 ; K = par[K] ) {
      if ( frontOwners[K] == ownerJ ) {
         ndescLeft[K]++ ;
         break ;
      }
   }
}
if ( msglvl > 2 ) {
   fprintf(msgFile, "\n\n ndescLeft[]") ;
   IVfprintf(msgFile, nfront, ndescLeft) ;
   fflush(msgFile) ;
}
/*
   -----------------------------------------------
   set up the dequeue for the post-order traversal
   -----------------------------------------------
*/
status  = DFrontMtx_status(frontmtx, frontOwnersIV, myid) ;
dequeue = DFrontMtx_setUpDequeue(frontmtx, frontOwnersIV, status, myid);
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n\n initial status vector") ;
   CVfprintf(msgFile, nfront, status) ;
   fflush(msgFile) ;
}
/*
   ---------------------------------------------------------
   set up working storage
      p_msg[J] -- list of receives still pending for front J
      map[]    -- scratch vector for the local index mapping
   ---------------------------------------------------------
*/
ALLOCATE(p_msg, struct _Msg *, nfront) ;
for ( J = 0 ; J < nfront ; J++ ) {
   p_msg[J] = NULL ;
}
map = IVinit(neqns, -1) ;
/*
   -------------------------------------------
   execute the post-order traversal and 
   overwrite global indices with local indices
   -------------------------------------------
*/
firstsent = NULL ;
while ( (J = Ideq_removeFromHead(dequeue)) != -1 ) {
   if ( msglvl > 1 ) {
      fprintf(msgFile, "\n\n ### checking out %d, status %c",
              J, status[J]) ;
      fflush(msgFile) ;
   }
   if ( status[J] == 'W' ) {
/*
      ----------------------------------
      wake up front J, post any receives
      ----------------------------------
*/
      p_msg[J] = wakeup(frontmtx, J, myid, frontOwners, &firstsent, 
                        firsttag, stats, comm, msglvl, msgFile) ;
      status[J] = 'A' ;
   }
/*
   ---------------------------
   check for received messages
   ---------------------------
*/
   if ( (K = par[J]) != -1 && frontOwners[K] == myid ) {
/*
      ------------------------------------------
      parent is owned, no message to be received
      ------------------------------------------
*/
      DFrontMtx_columnIndices(frontmtx, K, &ncolK, &colindK) ;
      DFrontMtx_setLocalColumnIndices(frontmtx, J, K, ncolK, colindK,
                                      map, msglvl, msgFile) ;
      if (  frontmtx->pivotingflag == 1
           && frontmtx->symmetryflag == 2 ) {
         DFrontMtx_rowIndices(frontmtx, K, &nrowK, &rowindK) ;
         DFrontMtx_setLocalRowIndices(frontmtx, J, K, nrowK, rowindK,
                                      map, msglvl, msgFile) ;
      }
   } else {
      p_msg[J] = checkMessages(frontmtx, J, p_msg[J], map,
                               stats, comm, msglvl, msgFile) ;
   }
   if ( p_msg[J] == NULL ) {
/*
      --------------------------------
      look for the next owned ancestor
      --------------------------------
*/
      K = par[J] ;
      while ( K != -1 && frontOwners[K] != myid ) {
         K = par[K] ;
      }
      if ( K != -1 && --ndescLeft[K] == 0 ) {
/*
         -----------------------------------------------------------
         next owned ancestor exists and all of its owned descendents
         are finished, place K on the head of the dequeue
         -----------------------------------------------------------
*/
         if ( msglvl > 1 ) {
            fprintf(msgFile, "\n placing %d on head of dequeue", K) ;
            fflush(msgFile) ;
         }
         Ideq_insertAtHead(dequeue, K) ;
      }
   } else {
/*
      -------------------------------------------
      front J is not done, place on tail of queue
      -------------------------------------------
*/
      if ( msglvl > 1 ) {
         fprintf(msgFile, "\n placing %d on tail of dequeue", J) ;
         fflush(msgFile) ;
      }
      Ideq_insertAtTail(dequeue, J) ;
   }
/*
   -------------------------
   check for completed sends
   -------------------------
*/
   firstsent = checkSentMessages(firstsent, msglvl, msgFile) ;
}
/*
   --------------------------------
   all done, clean up sent messages
   --------------------------------
*/
while ( firstsent != NULL ) {
   firstsent = checkSentMessages(firstsent, msglvl, msgFile) ;
}
/*
   ---------------------------------------------------------------
   free the working storage. every p_msg[J] is NULL here: a front
   only leaves the dequeue once its pending receive list is empty.
   NOTE(review): assumes DFrontMtx_status() allocates its vector
   with CVinit() and DFrontMtx_setUpDequeue() with Ideq_new() --
   confirm against their definitions
   ---------------------------------------------------------------
*/
IVfree(ndescLeft) ;
IVfree(map) ;
FREE(p_msg) ;
Ideq_free(dequeue) ;
CVfree(status) ;

return ; }

/*--------------------------------------------------------------------*/
/*
   ---------------------------------------------------
   constructor: allocate one Msg with ALLOCATE and
   give every field its default via
   Msg_setDefaultFields()

   created -- 97nov15, cca
   ---------------------------------------------------
*/
static Msg * 
Msg_new (
   void 
) {
Msg   *fresh ;

ALLOCATE(fresh, struct _Msg, 1) ;
Msg_setDefaultFields(fresh) ;
return fresh ; }

/*--------------------------------------------------------------------*/
/*  
   ---------------------------------------------------------------
   set the fields of the object to default values
      info = { 0, -1, 0 }, base = NULL, next = NULL,
      req  = MPI_REQUEST_NULL

   MPI_Request is an opaque handle type, so it is reset to the
   MPI-defined null handle MPI_REQUEST_NULL rather than to the C
   null pointer constant (which does not convert to an integer
   handle without a diagnostic)

   created -- 97nov22, cca
   ---------------------------------------------------------------
*/
static void
Msg_setDefaultFields (
   Msg   *msg
) {
msg->info[0] =   0  ;
msg->info[1] =  -1  ;
msg->info[2] =   0  ;
msg->base    = NULL ;
msg->next    = NULL ;
msg->req     = MPI_REQUEST_NULL ;

return ; }

/*--------------------------------------------------------------------*/
/*
   ---------------------------------------------------
   clear the data: release the attached index buffer
   (when there is one) with IVfree(), then restore
   every field to its default value

   created -- 97nov22, cca
   ---------------------------------------------------
*/
static void
Msg_clearData (
   Msg   *msg
) {
int   *buffer = msg->base ;

if ( buffer != NULL ) {
   IVfree(buffer) ;
}
Msg_setDefaultFields(msg) ;
}
 
/*--------------------------------------------------------------------*/
/*
   ---------------------------------------------------
   free the object: release its contents first (via
   Msg_clearData), then the Msg storage itself

   created -- 97nov22, cca
   ---------------------------------------------------
*/
static void
Msg_free (
   Msg   *doomed
) {
/* contents before container */
Msg_clearData(doomed) ;
FREE(doomed) ;
}

/*--------------------------------------------------------------------*/
/*
   ----------------------------------------------------------------
   wake up owned front J: post sends and receives

   for every child I of J owned by another process, two Isends
   to owners[I] are posted and pushed onto the list *pfirstsent:
      a notification { 1, J, nbytes } with tag firsttag + J
      the index message with tag firsttag + nfront + J, laid out as
         [ ncolJ, colindJ[0..ncolJ-1] ] followed, when
         pivotingflag == 1 and symmetryflag == 2, by
         [ nrowJ, rowindJ[0..nrowJ-1] ]
   if J's parent K exists and is owned by another process, an
   Irecv for K's notification (tag firsttag + K) is posted.

   send statistics stats[0] and stats[1] are accumulated here.
   receive statistics are accumulated by checkMessages() when a
   receive completes, so posting the Irecv does not touch them
   (counting here as well would count every notification twice).

   return value -- list of receives posted for J, NULL if none

   created -- 97nov22, cca
   ----------------------------------------------------------------
*/
static Msg *
wakeup ( 
   DFrontMtx   *frontmtx,
   int         J,
   int         myid,
   int         owners[],
   Msg         **pfirstsent,
   int         firsttag,
   int         stats[],
   MPI_Comm    comm,
   int         msglvl,
   FILE        *msgFile
) {
Msg   *firstrecv, *firstsent, *msg ;
int   destination, I, K, length, nbytes, ncolJ, nrowJ, source, tag ;
int   *colindJ, *fch, *ivec, *par, *rowindJ, *sib ;

if ( owners[J] != myid ) {
   fprintf(stderr, "\n proc %d : owner [%d] = %d", myid, J, owners[J]) ;
   exit(-1) ;
}
firstsent = *pfirstsent ;
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n on entrance, pfirstsent = %p, *pfirstsent = %p",
           (void *) pfirstsent, (void *) *pfirstsent) ;
   fflush(msgFile) ;
}
firstrecv = NULL ;
rowindJ   = NULL ;
par = ETree_par(frontmtx->frontETree) ;
fch = ETree_fch(frontmtx->frontETree) ;
sib = ETree_sib(frontmtx->frontETree) ;
for ( I = fch[J] ; I != -1 ; I = sib[I] ) {
   if ( (destination = owners[I]) != myid ) {
/*
      ------------------------------------------------------
      get the message together, each message owns its own
      copy of the indices because Msg_free() releases base
      ------------------------------------------------------
*/
      DFrontMtx_columnIndices(frontmtx, J, &ncolJ, &colindJ) ;
      length = ncolJ + 1 ;
      if (  frontmtx->pivotingflag == 1 
         && frontmtx->symmetryflag == 2 ) {
         DFrontMtx_rowIndices(frontmtx, J, &nrowJ, &rowindJ) ;
         length += nrowJ + 1 ;
      } else {
         nrowJ = 0 ;
      }
      nbytes  = length*sizeof(int) ;
      ivec    = IVinit(length, -1) ;
      ivec[0] = ncolJ ;
      IVcopy(ncolJ, ivec + 1, colindJ) ;
      if ( nrowJ > 0 ) {
         ivec[ncolJ+1] = nrowJ ;
         IVcopy(nrowJ, ivec + 2 + ncolJ, rowindJ) ;
      }
/*
      --------------------------
      post the notification send
      --------------------------
*/
      msg = Msg_new() ;
      msg->info[0] =   1  ;
      msg->info[1] =   J  ;
      msg->info[2] = nbytes ;
      msg->next    = firstsent ;
      firstsent    = msg   ;
      tag          = firsttag + J ;
      stats[0]++ ;
      stats[1] += 3*sizeof(int) ;
      if ( msglvl > 1 ) {
         fprintf(msgFile, 
"\n    posting Isend, msg %p, dest = %d, J = %d, nbytes = %d, tag = %d",
(void *) msg, destination, J, nbytes, tag) ;
         fflush(msgFile) ;
      }
      MPI_Isend((void *)msg->info, 3, MPI_INT, destination, tag,
                comm, &msg->req) ;
      if ( msglvl > 1 ) {
         fprintf(msgFile, ", return") ;
         fflush(msgFile) ;
      }
/*
      ---------------------
      post the indices send
      ---------------------
*/
      msg = Msg_new() ;
      msg->info[0] =   2  ;
      msg->info[1] =   J  ;
      msg->info[2] = nbytes ;
      msg->base    = ivec ;
      msg->next    = firstsent ;
      firstsent    = msg   ;
      tag          = firsttag + frontmtx->nfront + J ;
      stats[0]++ ;
      stats[1] += nbytes ;
      if ( msglvl > 1 ) {
         fprintf(msgFile, 
"\n    posting Isend, msg %p, dest = %d, J = %d, nbytes = %d, tag = %d",
(void *) msg, destination, J, nbytes, tag) ;
         fflush(msgFile) ;
      }
      MPI_Isend((void *)msg->base, length, MPI_INT, destination, tag,
                comm, &msg->req) ;
      if ( msglvl > 1 ) {
         fprintf(msgFile, ", return") ;
         fflush(msgFile) ;
      }
   }
}
if ( (K = par[J]) != -1 && (source = owners[K]) != myid ) {
/*
   ----------------------------------------------------------
   post receive for the parent's notification, its statistics
   are recorded by checkMessages() once it completes
   ----------------------------------------------------------
*/
   msg = Msg_new() ;
   msg->info[0] = 1 ;
   msg->info[1] = K ;
   msg->info[2] = 0 ;
   msg->next    = firstrecv ;
   firstrecv    = msg   ;
   tag          = firsttag + K ;
   if ( msglvl > 1 ) {
      fprintf(msgFile, 
        "\n    posting Irecv, msg %p, source = %d, K = %d, tag = %d",
        (void *) msg, source, K, tag) ;
      fflush(msgFile) ;
   }
   MPI_Irecv((void *)msg->info, 3, MPI_INT, source, tag,
             comm, &msg->req) ;
   if ( msglvl > 1 ) {
      fprintf(msgFile, ", return") ;
      fflush(msgFile) ;
   }
}
*pfirstsent = firstsent ;
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n on exit, pfirstsent = %p, *pfirstsent = %p",
           (void *) pfirstsent, (void *) *pfirstsent) ;
   fflush(msgFile) ;
}

return(firstrecv) ; }

/*--------------------------------------------------------------------*/
/*
   ----------------------------------------------------------------
   check the pending receives of owned front J, whose parent K is
   owned by another process

   each message on the list firstrecv is tested with MPI_Test
      not complete -- the message stays on the list
      type 1 (notification { 1, K, nbytes }) complete --
         an Irecv for K's index list (tag + nfront) is posted into
         a buffer of nbytes/sizeof(int) ints; the message becomes
         type 2 and is re-tested later in this same pass
      type 2 (index list of K) complete --
         J's column indices (and row indices, when pivotingflag
         == 1 and symmetryflag == 2) are made local w.r.t. K,
         then the message is freed
      any other type -- the completed message is freed

   receive statistics stats[2] and stats[3] are accumulated on
   completion; stats[3] is in bytes, matching the sender's stats[1]

   return value -- list of receives still pending,
                   NULL when all of J's receives are done

   created -- 97nov22, cca
   ----------------------------------------------------------------
*/
static Msg *
checkMessages (
   DFrontMtx   *frontmtx,
   int         J,
   Msg         *firstrecv,
   int         map[],
   int         stats[],
   MPI_Comm    comm,
   int         msglvl,
   FILE        *msgFile
) {
int          error, flag, K, length, nbytes, ncolK, nrowK, 
             source, tag, type ;
int          *colindK, *ivec, *rowindK ;
MPI_Status   status ;
Msg          *msg, *nextmsg ;

for ( msg = firstrecv, firstrecv = NULL ; 
      msg != NULL ; 
      msg = nextmsg ) {
/*
   ------------------------
   set link to next message
   ------------------------
*/
   nextmsg   = msg->next ;
   msg->next = NULL ;
/*
   --------------------------------------------
   test to see if the message has been received
   --------------------------------------------
*/
   type = msg->info[0] ;
   if ( msglvl > 1 ) {
      fprintf(msgFile,
              "\n    checking message %p : type %d, frontid %d",
              (void *) msg, type, msg->info[1]) ;
      fflush(msgFile) ;
   }
   flag = 0 ;
   MPI_Test(&msg->req, &flag, &status) ;
   if ( msglvl > 1 ) {
      fprintf(msgFile, ", flag %d", flag) ;
      fflush(msgFile) ;
   }
   if ( !flag ) {
/*
      -----------------------------------------
      message not received, keep it on the list
      -----------------------------------------
*/
      msg->next = firstrecv ;
      firstrecv = msg ;
      continue ;
   }
/*
   ------------------------------------------------
   message has been received, extract source, tag,
   and # of bytes
   ------------------------------------------------
*/
   K      = msg->info[1]      ;
   nbytes = msg->info[2]      ;
   source = status.MPI_SOURCE ;
   tag    = status.MPI_TAG    ;
   error  = status.MPI_ERROR  ;
   if ( msglvl > 1 ) {
      fprintf(msgFile,
              "\n    message received, source %d, tag %d, error %d"
              "\n    info = [ %d %d %d ] ",
              source, tag, error, type, K, nbytes) ;
      fflush(msgFile) ;
   }
   switch ( type ) {
   case 1 :
/*
      --------------------------------------------------
      notification message, post the receive for the
      index list and re-test it later in this same pass
      --------------------------------------------------
*/
      stats[2]++ ;
      stats[3] += 3*sizeof(int) ;
      msg->info[0] = 2 ;
      length = nbytes / sizeof(int) ;
      msg->base = IVinit(length, -1) ;
      msg->next = nextmsg ;
      nextmsg   = msg ;
      tag += frontmtx->nfront ;
      if ( msglvl > 1 ) {
         fprintf(msgFile,
   "\n    posting Irecv, msg %p, type %d, length = %d, K %d, tag %d",
   (void *) msg, 2, length, K, tag) ;
         fflush(msgFile) ;
      }
      MPI_Irecv((void *)msg->base, length, MPI_INT, source, tag,
                comm, &msg->req) ;
      if ( msglvl > 1 ) {
         fprintf(msgFile, ", return") ;
         fflush(msgFile) ;
      }
      break ;
   case 2 :
/*
      ----------------------------------------------------
      message with indices of K, nbytes is already a byte
      count so it is added to stats[3] as is
      ----------------------------------------------------
*/
      stats[2]++ ;
      stats[3] += nbytes ;
      ivec    = msg->base ;
      ncolK   = ivec[0] ;
      colindK = ivec + 1 ;
      nrowK   = 0 ;
      rowindK = NULL ;
      if (  frontmtx->pivotingflag == 1
         && frontmtx->symmetryflag == 2 ) {
         nrowK   = ivec[ncolK + 1] ;
         rowindK = ivec + 2 + ncolK ;
      }
      DFrontMtx_setLocalColumnIndices(frontmtx, J, K, ncolK, colindK,
                                      map, msglvl, msgFile) ;
      if ( nrowK > 0 ) {
         DFrontMtx_setLocalRowIndices(frontmtx, J, K, nrowK, rowindK,
                                      map, msglvl, msgFile) ;
      }
      Msg_free(msg) ;
      break ;
   default :
/*
      ------------------------------------------------
      unknown type, the operation has completed so the
      message can be released instead of leaked
      ------------------------------------------------
*/
      Msg_free(msg) ;
      break ;
   }
}
return(firstrecv) ; }
   
/*--------------------------------------------------------------------*/
/*
   ---------------------------------------------------------------
   check sent messages

   every message on the list firstsent is tested with MPI_Test;
   completed sends are freed (releasing any index buffer they
   own), the rest are kept. the surviving messages come back in
   reverse order.

   return value -- list of sends still pending,
                   NULL when every send has completed

   created -- 97nov22, cca
   ---------------------------------------------------------------
*/
static Msg *
checkSentMessages (
   Msg    *firstsent,
   int    msglvl,
   FILE   *msgFile
) {
int          flag, frontid, nbytes, type ;
MPI_Status   status ;
Msg          *msg, *nextmsg ;

for ( msg = firstsent, firstsent = NULL ;
      msg != NULL ; 
      msg = nextmsg ) {
   nextmsg   = msg->next ;
   msg->next = NULL ;
   type      = msg->info[0] ;
   frontid   = msg->info[1] ;
   nbytes    = msg->info[2] ;
   if ( msglvl > 1 ) {
      fprintf(msgFile,
              "\n checking sent msg %p : type %d, front %d, nbytes %d",
              (void *) msg, type, frontid, nbytes) ;
      fflush(msgFile) ;
   }
   flag = 0 ;
   MPI_Test(&msg->req, &flag, &status) ;
   if ( msglvl > 1 ) {
      fprintf(msgFile, ", flag %d", flag) ;
      fflush(msgFile) ;
   }
   if ( flag ) {
/*
      -----------------------------------------
      send complete, the buffers can be reused
      -----------------------------------------
*/
      if ( msglvl > 1 ) {
         fprintf(msgFile, ", tag %d", status.MPI_TAG) ;
         fflush(msgFile) ;
      }
      Msg_free(msg) ;
   } else {
      msg->next = firstsent ;
      firstsent = msg   ;
   }
}
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n\n head of sent messages = %p", (void *) firstsent) ;
   fflush(msgFile) ;
}
return(firstsent) ; }

/*--------------------------------------------------------------------*/
