/*  solveDA2.c  */

#include "../DFrontMtx.h"
#include "../../timings.h"

/*
   compile-time switches for this file
   MYDEBUG  -- when > 0, the #if MYDEBUG > 0 sections below write
               intermediate matrices and vectors to stdout
   CAUTIOUS -- not referenced by the code visible in this file;
               presumably reserved for extra consistency checks
*/
#define MYDEBUG  0
#define CAUTIOUS 0
/*--------------------------------------------------------------------*/
/*
   ---------------------------------------------------------------
   solve one of three linear systems with the factored front
   matrix object frontmtx. Which system applies depends on how
   frontmtx was factored (presumably its symmetryflag field --
   confirm in DFrontMtx.h).

   symmetryflag = 0 
      (U^T + I) * D * (I + U) sol = rhs
      D has 1x1 and/or 2x2 pivots
   symmetryflag = 2 
      (L + I) * D * (I + U) sol = rhs
      D is diagonal
   symmetryflag = 3 
      (U^T + D) * (D + U) sol = rhs
      D is diagonal

   frontmtx -- factored front matrix object
   solDA2   -- on return, holds the solution entries
   rhsDA2   -- right hand side, nrow = # of equations,
               ncol = # of right hand sides.
               NOTE: rhsDA2 is overwritten. The forward solve
               stores its intermediate solution into it, and the
               backward solve reads from it.
   manager  -- supplies and recycles the DDenseMtx working objects
   cpus     -- vector of at least 8 cpu time breakdowns,
               zeroed and then filled on return
      cpus[0] -- initialize rhs matrices
      cpus[1] -- load rhs matrices with rhs 
      cpus[2] -- assemble from children and parent
      cpus[3] -- solve and update
      cpus[4] -- store entries
      cpus[5] -- link and free matrices
      cpus[6] -- miscellaneous (total minus the sum of the above)
      cpus[7] -- total time

   The forward solve visits fronts in post-order, so every child's
   update matrix is linked into the parent's list before the parent
   is processed. The backward solve visits fronts in pre-order, so
   a parent's solution is linked into each child's list before that
   child is processed.

   created -- 97jun18, cca
   ---------------------------------------------------------------
*/
void
DFrontMtx_solveDA2 (
   DFrontMtx          *frontmtx,
   DA2                *solDA2,
   DA2                *rhsDA2,
   DDenseMtxManager   *manager, 
   double             cpus[] 
) {
DDenseMtx       *firstI, *mtx, *mtxBJ, *mtxJ, *mtxK ;
DDenseMtxList   *matrixList ;
double          t0, t1, t2, t3 ;
int             I, J, K, nbytesNeeded, ncolJ, nDJ, 
                neqns, nfront, nrhs, nrowJ ;
int             *fch, *par, *rowindJ, *sib ;
Tree            *tree ;
/*
   ---------------------------------------------------------
   check the input. cpus is written unconditionally below,
   so a NULL cpus vector is rejected here as well.
   %p requires a (void *) argument, hence the casts.
   ---------------------------------------------------------
*/
MARKTIME(t0) ;
if ( frontmtx == NULL || solDA2 == NULL || rhsDA2 == NULL 
   || manager == NULL || cpus == NULL ) {
   fprintf(stderr, "\n fatal error in DFrontMtx_solveDA2(%p,%p,%p,%p,%p)"
           "\n bad input\n", 
           (void *) frontmtx, (void *) solDA2, (void *) rhsDA2, 
           (void *) manager, (void *) cpus) ;
   exit(-1) ;
}
nfront       = frontmtx->nfront ;
neqns        = DA2_nrow(rhsDA2) ;
nrhs         = DA2_ncol(rhsDA2) ;
tree         = frontmtx->frontETree->tree ;
par          = tree->par ;
fch          = tree->fch ;
sib          = tree->sib ;
/*
   one list per front: during the forward solve list J collects the
   children's update matrices, during the backward solve it holds
   the parent's solution matrix
*/
matrixList = DDenseMtxList_new() ;
DDenseMtxList_init(matrixList, nfront, NULL, 0, NULL) ;
DVzero(8, cpus) ;
#if MYDEBUG > 0
fprintf(stdout, "\n\n rhsDA2") ;
DVfprintf(stdout, neqns*nrhs, DA2_entries(rhsDA2)) ;
fflush(stdout) ;
#endif
/*
   ----------------------------------------------
   forward solve: 
   loop over the fronts in a post-order traversal
   ----------------------------------------------
*/
for ( J = Tree_postOTfirst(tree) ;
      J != -1 ;
      J = Tree_postOTnext(tree, J) ) {
   K   = par[J] ;
   nDJ = DFrontMtx_frontSize(frontmtx, J) ;
#if MYDEBUG > 0
   fprintf(stdout, 
           "\n\n ### forward solve with front %d, nDJ = %d, parent %d", 
           J, nDJ, K) ;
   fflush(stdout) ;
#endif
/*
   -------------------------------------
   initialize the mtxJ and mtxBJ objects
   -------------------------------------
*/
   MARKTIME(t1) ;
   DFrontMtx_forwInit(frontmtx, J, nrhs, manager, &mtxJ, &mtxBJ) ;
   MARKTIME(t2) ;
   cpus[0] += t2 - t1 ;
#if MYDEBUG > 0
   if ( mtxJ != NULL ) {
      fprintf(stdout, "\n after initialization, mtxJ") ;
      DDenseMtx_writeForHumanEye(mtxJ, stdout) ;
      fflush(stdout) ;
   }
   if ( mtxBJ != NULL ) {
      fprintf(stdout, "\n after initialization, mtxBJ") ;
      DDenseMtx_writeForHumanEye(mtxBJ, stdout) ;
      fflush(stdout) ;
   }
#endif
/*
   --------------------------------
   load the right hand side entries
   --------------------------------
*/
   if ( mtxJ != NULL ) {
      MARKTIME(t1) ;
      DFrontMtx_forwLoadRHS(mtxJ, rhsDA2) ;
      MARKTIME(t2) ;
      cpus[1] += t2 - t1 ;
#if MYDEBUG > 0
      fprintf(stdout, "\n after loading rhs, mtxJ") ;
      DDenseMtx_writeForHumanEye(mtxJ, stdout) ;
      fflush(stdout) ;
#endif
   }
/*
   ------------------------------------------------------
   assemble the right hand side updates from the children,
   then return the children's update matrices to manager
   ------------------------------------------------------
*/
   if ( (firstI = DDenseMtxList_getList(matrixList, J)) != NULL ) {
      MARKTIME(t1) ;
      DFrontMtx_forwLoadFromChildren(mtxJ, mtxBJ, firstI) ;
#if MYDEBUG > 0
      if ( mtxJ != NULL ) {
         fprintf(stdout, "\n\n mtxJ after assembling child updates") ;
         DDenseMtx_writeForHumanEye(mtxJ, stdout) ;
         fflush(stdout) ;
      }
      if ( mtxBJ != NULL ) {
         fprintf(stdout, "\n\n mtxBJ after assembling child updates") ;
         DDenseMtx_writeForHumanEye(mtxBJ, stdout) ;
         fflush(stdout) ;
      }
#endif
      DDenseMtxManager_releaseListOfObjects(manager, firstI) ;
      MARKTIME(t2) ;
      cpus[2] += t2 - t1 ;
   }
   if ( mtxJ != NULL ) {
/*
      -----------------------
      do the solve and update
      -----------------------
*/
      MARKTIME(t1) ;
      DFrontMtx_forwSolveAndUpdate(frontmtx, mtxJ, mtxBJ) ;
      MARKTIME(t2) ;
      cpus[3] += t2 - t1 ;
#if MYDEBUG > 0
      fprintf(stdout, "\n\n AFTER forward solve, mtxJ") ;
      DDenseMtx_writeForHumanEye(mtxJ, stdout) ;
      if ( mtxBJ != NULL ) {
         fprintf(stdout, "\n\n AFTER forward solve, mtxBJ") ;
         DDenseMtx_writeForHumanEye(mtxBJ, stdout) ;
      }
      fflush(stdout) ;
#endif
/*
      ---------------------------------------------------------
      store the solution entries back into rhsDA2 (overwriting
      it in place), then release the internal entries 
      ---------------------------------------------------------
*/
      MARKTIME(t1) ;
      DFrontMtx_forwStore(rhsDA2, mtxJ) ;
      MARKTIME(t2) ;
      cpus[4] += t2 - t1 ;
      MARKTIME(t1) ;
      DDenseMtxManager_releaseObject(manager, mtxJ) ;
      MARKTIME(t2) ;
      cpus[5] += t2 - t1 ;
   }
   if ( mtxBJ != NULL ) {
/*
      -------------------------------------------------
      link the update object in the list for the parent
      -------------------------------------------------
*/
      MARKTIME(t1) ;
      DDenseMtxList_addObjectToList(matrixList, mtxBJ, K) ;
      MARKTIME(t2) ;
      cpus[5] += t2 - t1 ;
   }
}
#if MYDEBUG > 0
fprintf(stdout, "\n\n after forward solve, maxabs(rhs) = %12.4e",
        DA2_maxabs(rhsDA2)) ;
fflush(stdout) ;
fprintf(stdout, "\n\n after forward solve, rhs ") ;
DA2_writeForHumanEye(rhsDA2, stdout) ;
fflush(stdout) ;
#endif
/*
   ------------------------------------------------------
   backward solve: (I + U) sol = rhs or (D + U) sol = rhs
   loop over the fronts in a pre-order traversal
   ------------------------------------------------------
*/
for ( J = Tree_preOTfirst(tree) ;
      J != -1 ;
      J = Tree_preOTnext(tree, J) ) {
   K = par[J] ;
#if MYDEBUG > 0
   fprintf(stdout, "\n\n ### backward solve with front %d, parent %d", 
           J, K) ;
#endif
/*
   ------------------------------
   initialize the solution object
   ------------------------------
*/
   MARKTIME(t1) ;
   mtxJ = DFrontMtx_backInit(frontmtx, J, nrhs, manager) ;
   MARKTIME(t2) ;
   cpus[0] += t2 - t1 ;
#if MYDEBUG > 0
   fprintf(stdout, "\n after initialization") ;
   DDenseMtx_writeForHumanEye(mtxJ, stdout) ;
   fflush(stdout) ;
#endif
   nDJ = DFrontMtx_frontSize(frontmtx, J) ;
/*
   ------------------------------------------------------------
   load the right hand side entries (the forward solution that
   was stored into rhsDA2)
   ------------------------------------------------------------
*/
   MARKTIME(t1) ;
   DFrontMtx_rowIndices(frontmtx, J, &nrowJ, &rowindJ) ;
   DFrontMtx_backLoadSolution(mtxJ, nDJ, rowindJ, rhsDA2) ;
   MARKTIME(t2) ;
   cpus[1] += t2 - t1 ;
#if MYDEBUG > 0
   fprintf(stdout, "\n after loading rhs") ;
   DDenseMtx_writeForHumanEye(mtxJ, stdout) ;
   fflush(stdout) ;
#endif
   if ( K != -1 ) {
/*
      --------------------------------------------------------
      load entries from the parent and release parent's object
      (the parent linked it into list J earlier in pre-order)
      --------------------------------------------------------
*/
      MARKTIME(t1) ;
      mtxK = DDenseMtxList_getList(matrixList, J) ;
      DFrontMtx_backLoadFromParent(mtxJ, nDJ, mtxK) ;
      DDenseMtxManager_releaseObject(manager, mtxK) ;
      MARKTIME(t2) ;
      cpus[2] += t2 - t1 ;
#if MYDEBUG > 0
      fprintf(stdout, "\n after parent's solution gathered") ;
      DDenseMtx_writeForHumanEye(mtxJ, stdout) ;
      fflush(stdout) ;
#endif
   }
   if ( nDJ > 0 ) {
/*
      -------------------------------------
      perform the update and backward solve
      -------------------------------------
*/
      MARKTIME(t1) ;
      DFrontMtx_backSolveAndUpdate(frontmtx, mtxJ) ;
      MARKTIME(t2) ;
      cpus[3] += t2 - t1 ;
#if MYDEBUG > 0
      fprintf(stdout, "\n after update and solve") ;
      DDenseMtx_writeForHumanEye(mtxJ, stdout) ;
      fflush(stdout) ;
#endif
/*
      --------------------------
      store the solution entries
      --------------------------
*/
      MARKTIME(t1) ;
      DFrontMtx_backStore(mtxJ, nDJ, solDA2) ;
      MARKTIME(t2) ;
      cpus[4] += t2 - t1 ;
   }
   MARKTIME(t1) ;
   if ( (I = fch[J]) != -1 ) {
/*
      -----------------------------------------------------------
      every child needs its own copy of the solution matrix: the
      first child receives mtxJ itself, each remaining sibling
      receives a freshly allocated copy. Each child releases the
      object it receives after gathering from it.
      -----------------------------------------------------------
*/
      ncolJ        = mtxJ->nrow ;
      nbytesNeeded = DDenseMtx_nbytesNeeded(ncolJ, nrhs) ;
      for ( I = sib[I] ; I != -1 ; I = sib[I] ) {
         mtx = DDenseMtxManager_newObjectOfSizeNbytes(manager,
                                                      nbytesNeeded) ;
         DDenseMtx_init(mtx, J, -1, ncolJ, nrhs, 1, ncolJ) ;
         DVcopy(ncolJ*nrhs, mtx->entries, mtxJ->entries) ;
         DDenseMtxList_addObjectToList(matrixList, mtx, I) ;
      }
      DDenseMtxList_addObjectToList(matrixList, mtxJ, fch[J]) ;
   } else {
/*
      ---------------------------
      no children, release matrix
      ---------------------------
*/
      DDenseMtxManager_releaseObject(manager, mtxJ) ;
   }
   MARKTIME(t2) ;
   cpus[5] += t2 - t1 ;
#if MYDEBUG > 0
   fprintf(stdout, "\n solution stored and linked") ;
   fflush(stdout) ;
#endif
}
#if MYDEBUG > 0
fprintf(stdout, "\n\n after backward solve, maxabs(sol) = %12.4e",
        DA2_maxabs(solDA2)) ;
fflush(stdout) ;
fprintf(stdout, "\n\n after backward solve, sol ") ;
DA2_writeForHumanEye(solDA2, stdout) ;
fflush(stdout) ;
#endif
/*
   ------------------------
   free the working storage
   ------------------------
*/
DDenseMtxList_free(matrixList) ;
MARKTIME(t3) ;
cpus[7] = t3 - t0 ;
cpus[6] = cpus[7] -
         (cpus[0] + cpus[1] + cpus[2] + cpus[3] + cpus[4] + cpus[5]) ;

return ; }

/*--------------------------------------------------------------------*/
