/*  permuteFactorMPI.c  */

#include "../spoolesMPI.h"

/*--------------------------------------------------------------------*/
/*
   one non-blocking MPI message (sent or to be received).
   messages are kept on singly linked lists through the next field.

   info[] is itself the buffer of the 3-int notification message:
      info[0] -- type : 1 = notification, 2 = message with indices
      info[1] -- id of the front whose indices are communicated
      info[2] -- # of bytes in the message with indices
*/
typedef struct _Msg   Msg ;
struct _Msg {
   int           info[3] ; /* type, frontid, nbytes */
   int           *base   ; /* buffer holding the indices, NULL if none;
                              released with IVfree() in Msg_clearData() */
   Msg           *next   ; /* next message on the list, NULL at the end */
   MPI_Request   req     ; /* request handle of the pending Isend/Irecv,
                              polled with MPI_Test() */
} ;
/*--------------------------------------------------------------------*/
/*
   Msg object life cycle : allocate, reset fields, 
   release owned buffer, release object
*/
static Msg * Msg_new ( void ) ;
static void Msg_setDefaultFields ( Msg *msg ) ;
static void Msg_clearData ( Msg *msg ) ;
static void Msg_free ( Msg *msg ) ;
/*
   wakeup -- front J becomes active, returns the list of posted receives
*/
static Msg * wakeup ( DFrontMtx *frontmtx, int J, int myid,
   int owners[], int firsttag, int stats[],
   MPI_Comm comm, int msglvl, FILE *msgFile ) ;
/*
   checkMessages -- test the receives posted for front J, 
   returns the list of receives that are still pending
*/
static Msg * checkMessages ( DFrontMtx *frontmtx, int J, Msg *firstrecv,
   int map[], int stats[], MPI_Comm comm, int msglvl, FILE *msgFile ) ;
/*
   sendMessages -- send the indices of front J to the children of J
   that are not owned by this process, sent messages are pushed
   onto the list whose head is *pfirstsent
*/
static Msg * sendMessages ( DFrontMtx *frontmtx, int J, int myid,
   int owners[], Msg **pfirstsent, int firsttag, int stats[],
   MPI_Comm comm, int msglvl, FILE *msgFile ) ;
/*
   checkSentMessages -- free the sends that have completed,
   returns the list of sends that are still pending
*/
static Msg * checkSentMessages ( Msg *firstsent, 
   int msglvl, FILE *msgFile ) ;
/*
   loadNearestOwnedDescendents -- put on the dequeue the closest
   descendents of seed that are owned by this process
*/
static void loadNearestOwnedDescendents ( Tree *tree, int seed,
   int owners[], int myid, Ideq *dequeue, int msglvl, FILE *msgFile ) ;
/*--------------------------------------------------------------------*/
/*
   --------------------------------------------------------
   purpose ---

     (1) permute the columns of U_{J,bnd{J}} so they are in
         ascending order w.r.t. the global column numbering
     (2) permute the rows of L_{bnd{J},J} so they are in
         ascending order w.r.t. the global row numbering

     the rows of L are permuted only when
     pivotingflag == 1 and symmetryflag == 2.

   the fronts owned by this process are visited top-down
   (a front is handled after the indices of its parent are
   available).  when the parent K of an owned front J is
   owned by another process, the indices of K arrive in two
   messages from the owner of K :
      tag firsttag + K          -- notification (3 ints)
      tag firsttag + nfront + K -- the indices themselves

   frontmtx      -- front matrix object
   frontOwnersIV -- map from fronts to owning processes
   firsttag -- first tag for messages, will use tag values
      in [firsttag, firsttag + 2*nfront - 1]
   stats[] -- statistics vector, entries are incremented
      stats[0] -- # of sends
      stats[1] -- # of bytes sent
      stats[2] -- # of receives
      stats[3] -- # of bytes received
   msglvl   -- message level, output is written if msglvl > 1
   msgFile  -- message file
   comm     -- MPI communicator

   error checking --
      the program exits if the tag range is not valid
      or if the dequeue overflows

   created -- 97nov22, cca
   --------------------------------------------------------
*/
void
DFrontMtx_MPI_permuteFactor (
   DFrontMtx   *frontmtx,
   IV          *frontOwnersIV,
   int         firsttag,
   int         stats[],
   int         msglvl,
   FILE        *msgFile,
   MPI_Comm    comm
) {
char   *status ;
Ideq   *dequeue ;
int    flag, I, J, K, myid, ncolK, neqns, nfront, npath, 
       nrowK, rc, tag_bound ;
int    *colindK, *fch, *map, *frontOwners, *par, *ptag_bound,
       *rowindK, *sib ;
Msg    *firstsent ;
Msg    **p_msg ;
Tree   *tree ;

MPI_Comm_rank(comm, &myid) ;
frontOwners = IV_entries(frontOwnersIV) ;
par    = ETree_par(frontmtx->frontETree) ;
fch    = ETree_fch(frontmtx->frontETree) ;
sib    = ETree_sib(frontmtx->frontETree) ;
tree   = ETree_tree(frontmtx->frontETree) ;
nfront = frontmtx->nfront ;
neqns  = frontmtx->neqns  ;
/*
   ---------------------------------------------------------------
   get the upper bound on tag values.
   note: for MPI_TAG_UB the attribute value handed back by MPI is
   a POINTER to the int that holds the bound, not the bound itself
   ---------------------------------------------------------------
*/
ptag_bound = NULL ;
flag       = 0 ;
MPI_Attr_get(MPI_COMM_WORLD, MPI_TAG_UB, &ptag_bound, &flag) ;
if ( flag != 0 && ptag_bound != NULL ) {
   tag_bound = *ptag_bound ;
} else {
/*
   attribute not available, fall back to the smallest
   upper bound that the MPI standard guarantees
*/
   tag_bound = 32767 ;
}
/*
   ---------------------------------------------------------
   the largest tag used is firsttag + 2*nfront - 1 (indices
   message of the last front). the test is written so that
   no sum can overflow once firsttag is known to be >= 0
   ---------------------------------------------------------
*/
if ( firsttag < 0 || 2*nfront - 1 > tag_bound - firsttag ) {
   fprintf(stderr, "\n fatal error in DFrontMtx_MPI_permuteFactor()"
           "\n tag range is [%d,%d], tag_bound = %d",
           firsttag, firsttag + 2*nfront - 1, tag_bound) ;
   exit(-1) ;
}
map = IVinit(neqns, -1) ;
/*
   -----------------------------------------------------
   create and fill the status vector
      'W' -- front is owned or has an owned descendent
      'F' -- otherwise
   during the traversal an owned front goes 
      'W' (waiting) --> 'A' (active) --> 'F' (finished)
   -----------------------------------------------------
*/
status = CVinit(nfront, 'F') ;
for ( J = 0 ; J < nfront ; J++ ) {
   if ( frontOwners[J] == myid ) {
      for ( K = J ; K != -1 && status[K] == 'F' ; K = par[K] ) {
         status[K] = 'W' ;
      }
   }
}
/*
   ------------------------------------------------------------
   count the owned fronts that have no child marked 'W',
   i.e., the owned fronts with no owned descendent.
   this is the number of entries the dequeue is given room for
   ------------------------------------------------------------
*/
npath = 0 ;
for ( J = Tree_postOTfirst(tree) ;
      J != -1 ;
      J = Tree_postOTnext(tree, J) ) {
   if ( frontOwners[J] == myid ) {
      for ( I = fch[J] ; I != -1 ; I = sib[I] ) {
         if ( status[I] == 'W' ) {
            break ;
         }
      }
      if ( I == -1 ) {
         npath++ ;
      }
   }
}
/*
   -----------------------------------------------
   set up the dequeue for the top-down traversal,
   start with the owned fronts nearest to the root
   -----------------------------------------------
*/
dequeue = Ideq_new() ;
Ideq_resize(dequeue, npath) ;
loadNearestOwnedDescendents(tree, -1, frontOwners, myid, dequeue, 
                            msglvl, msgFile) ;
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n\n initial status vector") ;
   CVfprintf(msgFile, nfront, status) ;
   fflush(msgFile) ;
}
/*
   ---------------------------------------------------
   set up working storage,
   p_msg[J] is the list of pending receives for front J
   ---------------------------------------------------
*/
ALLOCATE(p_msg, struct _Msg *, nfront) ;
for ( J = 0 ; J < nfront ; J++ ) {
   p_msg[J] = NULL ;
}
/*
   -------------------------------------------
   execute the top-down traversal and 
   overwrite global indices with local indices
   -------------------------------------------
*/
firstsent = NULL ;
while ( (J = Ideq_removeFromHead(dequeue)) != -1 ) {
   if ( msglvl > 1 ) {
      fprintf(msgFile, "\n\n ### checking out %d, status %c",
              J, status[J]) ;
      fflush(msgFile) ;
   }
   if ( status[J] == 'W' ) {
/*
      ----------------------------------
      wake up front J, post any receives
      ----------------------------------
*/
      p_msg[J] = wakeup(frontmtx, J, myid, frontOwners, 
                        firsttag, stats, comm, msglvl, msgFile) ;
      status[J] = 'A' ;
   }
   if ( status[J] == 'A' ) {
      if ( (K = par[J]) != -1 && frontOwners[K] == myid ) {
/*
         ------------------------------------------
         parent is owned, no message to be received
         ------------------------------------------
*/
         DFrontMtx_columnIndices(frontmtx, K, &ncolK, &colindK) ;
         DFrontMtx_permuteColumnsOfU(frontmtx, J, K, ncolK, colindK,
                                     map, msglvl, msgFile) ;
         if (  frontmtx->pivotingflag == 1
            && frontmtx->symmetryflag == 2 ) {
            DFrontMtx_rowIndices(frontmtx, K, &nrowK, &rowindK) ;
            DFrontMtx_permuteRowsOfL(frontmtx, J, K, nrowK, rowindK,
                                     map, msglvl, msgFile) ;
         }
         status[J] = 'F' ;
      } else {
/*
         ---------------------------------------------------
         no parent or parent is not owned, the front is done
         when no receive is pending (a root has none at all)
         ---------------------------------------------------
*/
         p_msg[J] = checkMessages(frontmtx, J, p_msg[J], map,
                                  stats, comm, msglvl, msgFile) ;
         if ( p_msg[J] == NULL ) {
            status[J] = 'F' ;
         }
      }
   }
   if ( status[J] == 'F' ) {
/*
      ------------------------------------------------
      front J is done, its indices will not change any
      more, send them to the children owned elsewhere
      ------------------------------------------------
*/
      sendMessages(frontmtx, J, myid, frontOwners, &firstsent, 
                   firsttag, stats, comm, msglvl, msgFile) ;
/*
      --------------------------------------------
      place nearest owned descendents on the queue
      --------------------------------------------
*/
      if ( msglvl > 1 ) {
         fprintf(msgFile, 
                 "\n\n loading nearest owned descendents of %d", J) ;
         fflush(msgFile) ;
      }
      loadNearestOwnedDescendents(tree, J, frontOwners, myid,
                                  dequeue, msglvl, msgFile) ;
   } else {
/*
      -------------------------------------------
      front J is not done, place on tail of queue
      -------------------------------------------
*/
      if ( msglvl > 1 ) {
         fprintf(msgFile, "\n placing %d on tail of dequeue", J) ;
         fflush(msgFile) ;
      }
      rc = Ideq_insertAtTail(dequeue, J) ;
      if ( rc == -1 ) {
         fprintf(stderr, "\n fatal error, not enough room in dequeue") ;
         fprintf(stdout, "\n fatal error, not enough room in dequeue") ;
         exit(-1) ;
      }
   }
/*
   -------------------------
   check for completed sends
   -------------------------
*/
   firstsent = checkSentMessages(firstsent, msglvl, msgFile) ;
}
/*
   --------------------------------
   all done, clean up sent messages
   --------------------------------
*/
while ( firstsent != NULL ) {
   firstsent = checkSentMessages(firstsent, msglvl, msgFile) ;
}
/*
   ------------------------------------------------------------
   free the working storage. every p_msg[J] is NULL here: a 
   front is marked 'F' only when its list of receives is empty
   ------------------------------------------------------------
*/
IVfree(map) ;
CVfree(status) ;
Ideq_free(dequeue) ;
FREE(p_msg) ;

return ; }

/*--------------------------------------------------------------------*/
/*
   -------------------------------------------------
   constructor

   allocate one Msg object and hand it back with all
   of its fields set to their default values

   return value -- pointer to the new object
 
   created -- 97nov15, cca
   -------------------------------------------------
*/
static Msg * 
Msg_new (
   void 
) {
Msg   *newmsg ;

ALLOCATE(newmsg, struct _Msg, 1) ;
Msg_setDefaultFields(newmsg) ;
return(newmsg) ; }

/*--------------------------------------------------------------------*/
/*  
   ----------------------------------------------------------
   set the fields of the object to default values

      info[] -- type 0, front id -1, zero bytes
      base   -- no buffer
      next   -- not on a list
      req    -- null request handle

   note : the buffer pointed to by base is NOT released here,
          use Msg_clearData() when the object owns a buffer

   created -- 97nov22, cca
   ----------------------------------------------------------
*/
static void
Msg_setDefaultFields (
   Msg   *msg
) {
msg->info[0] =   0  ;
msg->info[1] =  -1  ;
msg->info[2] =   0  ;
msg->base    = NULL ;
msg->next    = NULL ;
/*
   MPI_Request is an opaque handle and need not be a pointer type,
   MPI_REQUEST_NULL is the portable "no request" value
*/
msg->req     = MPI_REQUEST_NULL ;

return ; }

/*--------------------------------------------------------------------*/
/*
   ------------------------------------------------
   clear the data

   release the buffer of indices, if there is one,
   then set all fields back to their default values.
   the object itself is not released.
 
   created -- 97nov22, cca
   ------------------------------------------------
*/
static void
Msg_clearData (
   Msg   *msg
) {
int   *buffer = msg->base ;

if ( buffer != NULL ) {
   IVfree(buffer) ;
}
Msg_setDefaultFields(msg) ;
 
return ; }
 
/*--------------------------------------------------------------------*/
/*
   ------------------------------------------------
   destructor

   release the storage owned by the object (via
   Msg_clearData()) and then the object itself.
   msg must not be used after this call.
 
   created -- 97nov22, cca
   ------------------------------------------------
*/
static void
Msg_free (
   Msg   *msg
) {
/*
   the buffer of indices goes first, then the object
*/
Msg_clearData(msg) ;
FREE(msg) ;
 
return ; }

/*--------------------------------------------------------------------*/
/*
   ------------------------------------------------------------
   wakeup a front, post the receive for its notification message

   if the parent K of front J exists and is owned by another 
   process, a non-blocking receive is posted for the 3-int 
   notification message that the owner of K sends with tag 
   firsttag + K. otherwise nothing is posted.

   frontmtx -- front matrix object
   J        -- front to wake up, must be owned by this process
   myid     -- rank of this process
   owners[] -- map from fronts to owning processes
   firsttag -- first tag of the range used by the caller
   stats[]  -- statistics vector, not modified. a receive is 
               counted when it completes (in checkMessages()),
               not when it is posted, so that no receive is
               counted twice
   comm     -- MPI communicator
   msglvl   -- message level, output is written if msglvl > 1
   msgFile  -- message file

   return value -- 
      head of the list of posted receives, NULL if none posted

   error checking -- the program exits if J is not owned

   created -- 97nov22, cca
   ------------------------------------------------------------
*/
static Msg *
wakeup ( 
   DFrontMtx   *frontmtx,
   int         J,
   int         myid,
   int         owners[],
   int         firsttag,
   int         stats[],
   MPI_Comm    comm,
   int         msglvl,
   FILE        *msgFile
) {
Msg   *firstrecv, *msg ;
int   K, source, tag ;
int   *par ;

(void) stats ;   /* kept in the interface, see the note above */

if ( owners[J] != myid ) {
   fprintf(stderr, "\n proc %d : owner [%d] = %d", myid, J, owners[J]) ;
   exit(-1) ;
}
firstrecv = NULL ;
par = ETree_par(frontmtx->frontETree) ;
if ( (K = par[J]) != -1 && (source = owners[K]) != myid ) {
/*
      ---------------------------------------------------------
      post receive from parent. info[] is preset to what the 
      notification will contain, except for the number of bytes
      ---------------------------------------------------------
*/
   msg = Msg_new() ;
   msg->info[0] = 1 ;
   msg->info[1] = K ;
   msg->info[2] = 0 ;
   msg->next    = firstrecv ;
   firstrecv    = msg   ;
   tag          = firsttag + K ;
   if ( msglvl > 1 ) {
      fprintf(msgFile, 
        "\n    posting Irecv, msg %p, source = %d, K = %d, tag = %d",
        (void *) msg, source, K, tag) ;
      fflush(msgFile) ;
   }
   MPI_Irecv((void *)msg->info, 3, MPI_INT, source, tag,
             comm, &msg->req) ;
   if ( msglvl > 1 ) {
      fprintf(msgFile, ", return") ;
      fflush(msgFile) ;
   }
}
return(firstrecv) ; }

/*--------------------------------------------------------------------*/
/*
   ------------------------------------------------------------------
   check for messages for a front

   every message on the list is tested once with MPI_Test().

   (a) message not complete -- it is kept, and is pushed onto the 
       head of the list that is returned
   (b) completed notification (type 1) -- a buffer of 
       info[2]/sizeof(int) ints is allocated, the object is changed
       to type 2 and a receive for the indices is posted with the 
       same source and with tag = (tag of the notification) + nfront.
       the object is then linked in front of the remaining messages,
       so the new receive is tested in the next pass of the loop
   (c) completed message with indices (type 2) -- the indices of 
       front K are used to permute the columns of U and, when row 
       indices are present, the rows of L for front J, then the 
       message is freed

   layout of the message with indices (see case 2 below) :
      ivec[0]                 -- ncolK
      ivec[1 .. ncolK]        -- column indices of K
      when pivotingflag == 1 and symmetryflag == 2
      ivec[ncolK+1]           -- nrowK
      ivec[ncolK+2 .. ]       -- row indices of K

   frontmtx  -- front matrix object
   J         -- front that the messages are for
   firstrecv -- head of the list of posted receives for J
   map[]     -- work vector handed on to the permute methods
   stats[]   -- statistics vector, stats[2] and stats[3] are
                incremented for each completed receive
   comm      -- MPI communicator
   msglvl    -- message level, output is written if msglvl > 1
   msgFile   -- message file

   return value -- 
      head of the list of receives that are still pending,
      NULL when there are none

   created -- 97nov22, cca
   ------------------------------------------------------------------
*/
static Msg *
checkMessages (
   DFrontMtx   *frontmtx,
   int         J,
   Msg         *firstrecv,
   int         map[],
   int         stats[],
   MPI_Comm    comm,
   int         msglvl,
   FILE        *msgFile
) {
int          error, flag, K, length, nbytes, ncolK, nrowK, 
             source, tag, type ;
int          *colindK, *ivec, *rowindK ;
MPI_Status   status ;
Msg          *msg, *nextmsg ;

/*
   walk the incoming list, firstrecv is reset and 
   becomes the head of the list of pending messages
*/
for ( msg = firstrecv, firstrecv = NULL ; 
      msg != NULL ; 
      msg = nextmsg ) {
/*
   ------------------------
   set link to next message
   ------------------------
*/
   nextmsg   = msg->next ;
   msg->next = NULL ;
/*
   --------------------------------------------
   test to see if the message has been received
   --------------------------------------------
*/
   type = msg->info[0] ;
   if ( msglvl > 1 ) {
      fprintf(msgFile,
              "\n    checking message %p : type %d, frontid %d",
              msg, type, msg->info[1]) ;
      fflush(msgFile) ;
   }
   MPI_Test(&msg->req, &flag, &status) ;
   if ( msglvl > 1 ) {
      fprintf(msgFile, ", flag %d", flag) ;
      fflush(msgFile) ;
   }
   if ( flag != 1 ) {
/*
      -----------------------------------------
      message not received, keep it on the list
      -----------------------------------------
*/
      msg->next = firstrecv ;
      firstrecv = msg ;
   } else {
/*
      ------------------------------------------------
      message has been received, increment statistics,
      extract source, tag, and # of bytes
      ------------------------------------------------
*/
      K      = msg->info[1]      ;
      nbytes = msg->info[2]      ;
      source = status.MPI_SOURCE ;
      tag    = status.MPI_TAG    ;
      error  = status.MPI_ERROR  ;
      if ( msglvl > 1 ) {
         fprintf(msgFile,
                 "\n    message received, source %d, tag %d, error %d"
                 "\n    info = [ %d %d %d ] ",
                 source, tag, error, type, K, nbytes) ;
         fflush(msgFile) ;
      }
      switch ( type ) {
      case 1 :
/*
         ------------------------------------------------------
         notification message, nbytes is now the size of the
         message with indices. reuse the object for the second
         receive and put it back in front of the unvisited 
         messages so that it is tested in the next pass
         ------------------------------------------------------
*/
         stats[2]++ ;
         stats[3] += 3*sizeof(int) ;
         msg->info[0] = 2 ;
         length = nbytes / sizeof(int) ;
         msg->base = IVinit(length, -1) ;
         msg->next = nextmsg ;
         nextmsg   = msg ;
         tag += frontmtx->nfront ;
         if ( msglvl > 1 ) {
            fprintf(msgFile,
      "\n    posting Irecv, msg %p, type %d, length = %d, K %d, tag %d",
      msg, 2, length, K, tag) ;
            fflush(msgFile) ;
         }
         MPI_Irecv((void *)msg->base, length, MPI_INT, source, tag,
                   comm, &msg->req) ;
         if ( msglvl > 1 ) {
            fprintf(msgFile, ", return") ;
            fflush(msgFile) ;
         }
         break ;
      case 2 :
/*
         -----------------------------------------------------
         message with indices of K. nbytes is still the value
         that came with the notification, info[] has not been
         overwritten by the second receive
         -----------------------------------------------------
*/
         stats[2]++ ;
         stats[3] += nbytes ;
         if ( msglvl > 1 ) {
            fprintf(msgFile, "\n received message") ;
            IVfprintf(msgFile, nbytes/sizeof(int), msg->base) ;
            fflush(msgFile) ;
         }
         ivec    = msg->base ;
         ncolK   = ivec[0] ;
         colindK = ivec + 1 ;
/*
         rowindK is set only in the first branch below,
         it is read only when nrowK > 0
*/
         if (  frontmtx->pivotingflag == 1
            && frontmtx->symmetryflag == 2 ) {
            nrowK = ivec[ncolK + 1] ;
            rowindK = ivec + 2 + ncolK ;
         } else {
            nrowK = 0 ;
         }
         if ( msglvl > 1 ) {
            fprintf(msgFile, "\n ncolK = %d, nrowK = %d", ncolK, nrowK);
            fflush(msgFile) ;
         }
         DFrontMtx_permuteColumnsOfU(frontmtx, J, K, ncolK, colindK,
                                     map, msglvl, msgFile) ;
         if ( nrowK > 0 ) {
            DFrontMtx_permuteRowsOfL(frontmtx, J, K, nrowK, rowindK,
                                     map, msglvl, msgFile) ;
         }
/*
         the indices have been used, release buffer and object
*/
         Msg_free(msg) ;
         break ;
      default :
/*
         NOTE(review): a completed message of any other type is 
         neither kept on the list nor freed -- confirm that only 
         types 1 and 2 can occur
*/
         break ;
      }
   }
}
return(firstrecv) ; }
   
/*--------------------------------------------------------------------*/
/*
   ------------------------------------------------------------------
   send messages to children

   for each child I of front J that is owned by another process,
   two non-blocking sends are posted to the owner of I
      (1) notification, 3 ints : [ 1, J, nbytes ],
          tag = firsttag + J
      (2) indices, nbytes/sizeof(int) ints,
          tag = firsttag + nfront + J
             [ ncolJ, column indices of J ]
          followed, when pivotingflag == 1 and symmetryflag == 2, by
             [ nrowJ, row indices of J ]
   each child gets its own copy of the indices. both message objects
   are pushed onto the list of sent messages.

   frontmtx   -- front matrix object
   J          -- front whose indices are sent, must be owned
   myid       -- rank of this process
   owners[]   -- map from fronts to owning processes
   pfirstsent -- address of the head of the list of sent messages,
                 updated on return
   firsttag   -- first tag of the range used by the caller
   stats[]    -- statistics vector, stats[0] and stats[1] are
                 incremented for each send
   comm       -- MPI communicator
   msglvl     -- message level, output is written if msglvl > 1
   msgFile    -- message file

   return value -- always NULL

   error checking -- the program exits if J is not owned

   created -- 97nov22, cca
   ------------------------------------------------------------------
*/
static Msg *
sendMessages ( 
   DFrontMtx   *frontmtx,
   int         J,
   int         myid,
   int         owners[],
   Msg         **pfirstsent,
   int         firsttag,
   int         stats[],
   MPI_Comm    comm,
   int         msglvl,
   FILE        *msgFile
) {
Msg   *head, *note, *payload ;
int   child, count, datatag, dest, nbytes, ncolJ, notetag, nrowJ,
      withrows ;
int   *buffer, *colindJ, *fch, *rowindJ, *sib ;

if ( owners[J] != myid ) {
   fprintf(stderr, "\n proc %d : owner [%d] = %d", myid, J, owners[J]) ;
   exit(-1) ;
}
head = *pfirstsent ;
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n on entrance, pfirstsent = %p, *pfirstsent = %p",
           pfirstsent, *pfirstsent) ;
   fflush(msgFile) ;
}
fch = ETree_fch(frontmtx->frontETree) ;
sib = ETree_sib(frontmtx->frontETree) ;
/*
   row indices travel with the message only in this case
*/
withrows = (  frontmtx->pivotingflag == 1 
           && frontmtx->symmetryflag == 2 ) ;
/*
   the two tags depend on J only
*/
notetag = firsttag + J ;
datatag = firsttag + frontmtx->nfront + J ;

for ( child = fch[J] ; child != -1 ; child = sib[child] ) {
   dest = owners[child] ;
   if ( dest == myid ) {
/*
      child is owned, it reads the indices of J directly
*/
      continue ;
   }
/*
   ----------------------------------------------
   pack the indices into a buffer for this child,
   the buffer is owned by the second message
   ----------------------------------------------
*/
   DFrontMtx_columnIndices(frontmtx, J, &ncolJ, &colindJ) ;
   count = 1 + ncolJ ;
   nrowJ = 0 ;
   if ( withrows ) {
      DFrontMtx_rowIndices(frontmtx, J, &nrowJ, &rowindJ) ;
      count += 1 + nrowJ ;
   }
   nbytes    = count*sizeof(int) ;
   buffer    = IVinit(count, -1) ;
   buffer[0] = ncolJ ;
   IVcopy(ncolJ, buffer + 1, colindJ) ;
   if ( nrowJ > 0 ) {
      buffer[ncolJ+1] = nrowJ ;
      IVcopy(nrowJ, buffer + 2 + ncolJ, rowindJ) ;
   }
/*
   ----------------------------------------------------
   first message, the notification. the info[] field of
   the object is the send buffer
   ----------------------------------------------------
*/
   note = Msg_new() ;
   note->info[0] = 1 ;
   note->info[1] = J ;
   note->info[2] = nbytes ;
   note->next    = head ;
   head          = note ;
   stats[0]++ ;
   stats[1] += 3*sizeof(int) ;
   if ( msglvl > 1 ) {
      fprintf(msgFile, 
"\n    posting Isend, msg %p, dest = %d, J = %d, nbytes = %d, tag = %d",
              note, dest, J, nbytes, notetag) ;
      fflush(msgFile) ;
   }
   MPI_Isend((void *)note->info, 3, MPI_INT, dest, notetag,
             comm, &note->req) ;
   if ( msglvl > 1 ) {
      fprintf(msgFile, ", return") ;
      fflush(msgFile) ;
   }
/*
   -------------------------------
   second message, the indices
   -------------------------------
*/
   payload = Msg_new() ;
   payload->info[0] = 2 ;
   payload->info[1] = J ;
   payload->info[2] = nbytes ;
   payload->base    = buffer ;
   payload->next    = head ;
   head             = payload ;
   stats[0]++ ;
   stats[1] += nbytes ;
   if ( msglvl > 1 ) {
      fprintf(msgFile, 
"\n    posting Isend, msg %p, dest = %d, J = %d, nbytes = %d, tag = %d",
              payload, dest, J, nbytes, datatag) ;
      fprintf(msgFile, "\n    msg->base message") ;
      IVfprintf(msgFile, count, payload->base) ;
      fflush(msgFile) ;
   }
   MPI_Isend((void *)payload->base, count, MPI_INT, dest, datatag,
             comm, &payload->req) ;
   if ( msglvl > 1 ) {
      fprintf(msgFile, ", return") ;
      fflush(msgFile) ;
   }
}
/*
   hand the new head of the list back to the caller
*/
*pfirstsent = head ;
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n on exit, pfirstsent = %p, *pfirstsent = %p",
           pfirstsent, *pfirstsent) ;
   fflush(msgFile) ;
}

return(NULL) ; }

/*--------------------------------------------------------------------*/
/*
   ---------------------------------------------------------------
   check sent messages

   every message on the list is tested once with MPI_Test().
   a message whose send has completed is freed (together with its
   buffer of indices). a message that has not completed is pushed
   onto the head of the list that is returned.

   firstsent -- head of the list of sent messages
   msglvl    -- message level, output is written if msglvl > 1
   msgFile   -- message file

   return value -- 
      head of the list of sends that are still pending,
      NULL when there are none

   created -- 97nov22, cca
   ---------------------------------------------------------------
*/
static Msg *
checkSentMessages (
   Msg    *firstsent,
   int    msglvl,
   FILE   *msgFile
) {
int          done ;
MPI_Status   status ;
Msg          *cursor, *follow, *pending ;

pending = NULL ;
cursor  = firstsent ;
while ( cursor != NULL ) {
/*
   unlink the message, remember where to go next
*/
   follow       = cursor->next ;
   cursor->next = NULL ;
   if ( msglvl > 1 ) {
      fprintf(msgFile,
              "\n checking sent msg %p : type %d, front %d, nbytes %d",
              cursor, cursor->info[0], cursor->info[1], 
              cursor->info[2]) ;
      fflush(msgFile) ;
   }
   MPI_Test(&cursor->req, &done, &status) ;
   if ( msglvl > 1 ) {
      fprintf(msgFile, ", flag %d", done) ;
      fflush(msgFile) ;
   }
   if ( done != 1 ) {
/*
      send is not complete, keep the message
*/
      cursor->next = pending ;
      pending      = cursor  ;
   } else {
/*
      send is complete, the buffers may be released
*/
      if ( msglvl > 1 ) {
         fprintf(msgFile, ", tag %d", status.MPI_TAG) ;
         fflush(msgFile) ;
      }
      Msg_free(cursor) ;
   }
   cursor = follow ;
}
if ( msglvl > 1 ) {
   fprintf(msgFile, "\n\n head of sent messages = %p", pending) ;
   fflush(msgFile) ;
}
return(pending) ; }

/*--------------------------------------------------------------------*/
/*
   ------------------------------------------------------------
   load the dequeue with the closest descendents
   of seed that are owned by process myid

   the subtree below seed is searched from the first child.
   the search does not go below an owned front, so for each
   path down from seed only the first owned front is loaded.
   when seed is -1 the search starts at the root of the tree
   and covers the whole forest.

   tree     -- tree object, its par, fch and sib vectors are used
   seed     -- front whose descendents are wanted, or -1
   owners[] -- map from fronts to owning processes
   myid     -- rank of this process
   dequeue  -- dequeue to be loaded, fronts are inserted at 
               the head
   msglvl   -- message level, output is written if msglvl > 1
   msgFile  -- message file

   error checking -- 
      the program exits if there is no room in the dequeue. 
      (silently dropping a front would mean that the front is 
      never processed and that its children are never sent 
      their messages.)
 
   created -- 97jun27, cca
   ------------------------------------------------------------
*/
static void
loadNearestOwnedDescendents (
   Tree   *tree,
   int    seed,
   int    owners[],
   int    myid,
   Ideq   *dequeue,
   int    msglvl,
   FILE   *msgFile
) {
int    I, rc ;
int    *fch = tree->fch ;
int    *par = tree->par ;
int    *sib = tree->sib ;

if ( seed != -1 ) {
   I = fch[seed] ;
} else {
   I = tree->root ;
}
while ( I != -1 ) {
   if ( owners[I] == myid ) {
/*
      ---------------------------------------------
      owned front, load it and do not look below it
      ---------------------------------------------
*/
      if ( msglvl > 1 ) {
         fprintf(msgFile, "\n loading descendent %d onto queue", I) ;
         fflush(msgFile) ;
      }
      rc = Ideq_insertAtHead(dequeue, I) ;
      if ( rc == -1 ) {
         fprintf(stderr, 
                 "\n fatal error in loadNearestOwnedDescendents()"
                 "\n not enough room in dequeue for front %d", I) ;
         exit(-1) ;
      }
   } else if ( fch[I] != -1 ) {
/*
      -------------------------------------
      front is not owned, look at its children
      -------------------------------------
*/
      I = fch[I] ;
      continue ;
   }
/*
   ---------------------------------------------------------
   this subtree is finished. climb until a front that has a 
   sibling, or a child of seed, is found, then move to the
   sibling (-1 when the search is over)
   ---------------------------------------------------------
*/
   while ( sib[I] == -1 && par[I] != seed ) {
      I = par[I] ;
   }
   I = sib[I] ;
}
return ; }
 
/*--------------------------------------------------------------------*/
