/*  test_nd3Upd.c  */

#include "../DA2.h"
#include "../../timings.h"

/*--------------------------------------------------------------------*/

/*
   ------------------------------------------------------------
   purpose -- return the megaflop rate for ops floating point
              operations performed in elapsed time units.
              when no elapsed time was observed (elapsed <= 0),
              return 0.0 instead of dividing by zero.

   NOTE(review): elapsed is a difference of two MARKTIME values,
                 presumably seconds -- confirm in timings.h
   ------------------------------------------------------------
*/
static double
megaflops ( double ops, double elapsed )
{
if ( elapsed > 0.0 ) {
   return(1.0e-6*ops/elapsed) ;
}
return(0.0) ;
}

/*--------------------------------------------------------------------*/

int
main ( int argc, char *argv[] )
/*
   ----------------------------------------------------------------
   test the dense update routine DA2_nd3Upd

   A -= B * D * C^T

   A is nrowA x ncolA, B is nrowA x ncolD,
   C is ncolA x ncolD, D is ncolD x ncolD

   the update is computed twice, once with a simple reference
   kernel built from row/column dot products (result in eDA2) 
   and once with DA2_nd3Upd() (result in aDA2). the frobenius 
   norm of the difference is written to the message file.

   command line arguments (exactly ten) :
      msglvl msgFile majorA majorB majorC majorD 
      nrowA ncolA ncolD seed

   return value --
      0 on normal completion. the process exits with status -1 on
      a usage error, invalid input, or failure to open or close 
      the message file.

   created -- 96nov18, cca
   ----------------------------------------------------------------
*/
{
DA2      *aDA2, *bDA2, *cDA2, *dDA2, *eDA2 ;
double   ops, t1, t2, value ;
DV       *tmpDV ;
FILE     *msgFile ;
int      ii, jj, majorA, majorB, majorC, majorD, 
         msglvl, nrowA, ncolA, ncolD, seed ;

if ( argc != 11 ) {
   fprintf(stdout, 
"\n\n usage : %s msglvl msgFile majorA majorB majorC majorD"
"\n         nrowA ncolA ncolD seed"
"\n    msglvl  -- message level"
"\n    msgFile -- message file"
"\n    majorA  -- 0 if A is row major, 1 if column major"
"\n    majorB  -- 0 if B is row major, 1 if column major"
"\n    majorC  -- 0 if C is row major, 1 if column major"
"\n    majorD  -- 0 if D is row major, 1 if column major"
"\n    nrowA   -- # of rows in A"
"\n    ncolA   -- # of columns in A"
"\n    ncolD   -- # of rows and columns in D"
"\n    seed    -- random number seed"
"\n", argv[0]) ;
   exit(-1) ;
}
if ( (msglvl = atoi(argv[1])) < 0 ) {
   fprintf(stderr, "\n message level must be non-negative\n") ;
   exit(-1) ;
}
/*
   ---------------------------------------------------------
   "stdout" selects standard output, any other name is 
   opened in append mode so successive runs share one log
   ---------------------------------------------------------
*/
if ( strcmp(argv[2], "stdout") == 0 ) {
   msgFile = stdout ;
} else if ( (msgFile = fopen(argv[2], "a")) == NULL ) {
   fprintf(stderr, "\n unable to open file %s\n", argv[2]) ;
   exit(-1) ;
}
majorA = atoi(argv[3]) ;
majorB = atoi(argv[4]) ;
majorC = atoi(argv[5]) ;
majorD = atoi(argv[6]) ;
nrowA  = atoi(argv[7]) ;
ncolA  = atoi(argv[8]) ;
ncolD  = atoi(argv[9]) ;
if (  nrowA <= 0 || ncolA <= 0 || ncolD <= 0 ) {
   fprintf(stderr, "\n invalid matrix dimensions"
           "\n nrowA = %d, ncolA = %d, ncolD = %d\n",
           nrowA, ncolA, ncolD) ;
   exit(-1) ;
}
if ( majorA < 0 || majorA > 1
  || majorB < 0 || majorB > 1
  || majorC < 0 || majorC > 1
  || majorD < 0 || majorD > 1 ) {
   fprintf(stderr, "\n invalid major values"
           "\n majorA = %d, majorB = %d, majorC = %d, majorD = %d"
           "\n 0 --> row major, 1 --> column major\n",
           majorA, majorB, majorC, majorD) ;
   exit(-1) ;
}
seed = atoi(argv[10]) ;
fprintf(msgFile, "\n nrowA = %d, ncolA = %d, ncolD = %d",
        nrowA, ncolA, ncolD) ;
fprintf(msgFile, 
        "\n majorA = %d, majorB = %d, majorC = %d, majorD = %d",
        majorA, majorB, majorC, majorD) ;
fflush(msgFile) ;
/*
   ------------------------------------------------------------
   initialize the matrix objects

   the two arguments after the dimensions are the row and 
   column increments : (ncol, 1) for row major storage,
   (1, nrow) for column major storage. eDA2 mirrors the shape 
   and storage of aDA2 and holds the reference result.
   ------------------------------------------------------------
*/
MARKTIME(t1) ;
aDA2 = DA2_new() ;
bDA2 = DA2_new() ;
cDA2 = DA2_new() ;
dDA2 = DA2_new() ;
eDA2 = DA2_new() ;
if ( majorA == 0 ) {
   DA2_init(aDA2, nrowA, ncolA, ncolA, 1, NULL) ;
   DA2_init(eDA2, nrowA, ncolA, ncolA, 1, NULL) ;
} else {
   DA2_init(aDA2, nrowA, ncolA, 1, nrowA, NULL) ;
   DA2_init(eDA2, nrowA, ncolA, 1, nrowA, NULL) ;
}
if ( majorB == 0 ) {
   DA2_init(bDA2, nrowA, ncolD, ncolD, 1, NULL) ;
} else {
   DA2_init(bDA2, nrowA, ncolD, 1, nrowA, NULL) ;
}
if ( majorC == 0 ) {
   DA2_init(cDA2, ncolA, ncolD, ncolD, 1, NULL) ;
} else {
   DA2_init(cDA2, ncolA, ncolD, 1, ncolA, NULL) ;
}
if ( majorD == 0 ) {
   DA2_init(dDA2, ncolD, ncolD, ncolD, 1, NULL) ;
} else {
   DA2_init(dDA2, ncolD, ncolD, 1, ncolD, NULL) ;
}
MARKTIME(t2) ;
fprintf(msgFile, "\n CPU : %.3f to initialize matrix objects",
        t2 - t1) ;
/*
   ------------------------------------------------------------
   fill A, B, C and D with random entries and copy A into E

   NOTE(review): all four matrices are filled with the same 
                 seed, presumably giving them the same random 
                 stream -- confirm this is intended
   ------------------------------------------------------------
*/
MARKTIME(t1) ;
DA2_fillRandomUniform(aDA2, 0, 1, seed) ;
DA2_fillRandomUniform(bDA2, 0, 1, seed) ;
DA2_fillRandomUniform(cDA2, 0, 1, seed) ;
DA2_fillRandomUniform(dDA2, 0, 1, seed) ;
DA2_copy(eDA2, aDA2) ;
MARKTIME(t2) ;
fprintf(msgFile, 
        "\n CPU : %.3f to fill matrix with random numbers", t2 - t1) ;
if ( msglvl > 3 ) {
   fprintf(msgFile, "\n matrix A") ;
   DA2_writeForHumanEye(aDA2, msgFile) ;
   fprintf(msgFile, "\n matrix B") ;
   DA2_writeForHumanEye(bDA2, msgFile) ;
   fprintf(msgFile, "\n matrix C") ;
   DA2_writeForHumanEye(cDA2, msgFile) ;
   fprintf(msgFile, "\n matrix D") ;
   DA2_writeForHumanEye(dDA2, msgFile) ;
}
/*
   ------------------------------------------------------------
   compute the reference update E -= B * D * C^T
   using the simplest kernel (row/column dot products)

   operation count : 2*nrowA*ncolD*ncolD for T = B * D plus
   2*nrowA*ncolA*ncolD for E -= T * C^T. the product is formed 
   in double precision so that large dimensions cannot 
   overflow int arithmetic.
   ------------------------------------------------------------
*/
ops = 2.0*nrowA*ncolD*((double) ncolD + (double) ncolA) ;
MARKTIME(t1) ;
{
DA2      *tDA2 ;
double   *col, *row ;
DV       *colDV, *rowDV ;

/*
   ---------------------------------------------------------
   row and col point at the entries of the two work vectors 
   of length ncolD that the extract calls below fill
   ---------------------------------------------------------
*/
rowDV = DV_new() ;
DV_init(rowDV, ncolD, NULL) ;
row = DV_entries(rowDV) ;
colDV = DV_new() ;
DV_init(colDV, ncolD, NULL) ;
col = DV_entries(colDV) ;
tDA2 = DA2_new() ;
DA2_init(tDA2, nrowA, ncolD, 1, nrowA, NULL) ;
/*
   -----------------------------------------
   T = B * D, T(ii,jj) = B(ii,:) . D(:,jj)
   -----------------------------------------
*/
for ( jj = 0 ; jj < ncolD ; jj++ ) {
   DA2_extractColumnDV(dDA2, colDV, jj) ;
   for ( ii = 0 ; ii < nrowA ; ii++ ) {
      DA2_extractRowDV(bDA2, rowDV, ii) ;
      value = DVdot(ncolD, row, col) ;
      DA2_setEntry(tDA2, ii, jj, value) ;
   }
}
/*
   ---------------------------------------------------------
   E -= T * C^T, E(ii,jj) -= T(ii,:) . C(jj,:)
   column jj of C^T is row jj of C, so a row of C is 
   extracted into colDV
   ---------------------------------------------------------
*/
for ( jj = 0 ; jj < ncolA ; jj++ ) {
   DA2_extractRowDV(cDA2, colDV, jj) ;
   for ( ii = 0 ; ii < nrowA ; ii++ ) {
      DA2_extractRowDV(tDA2, rowDV, ii) ;
      value = - DVdot(ncolD, row, col) ;
      DA2_addEntry(eDA2, ii, jj, value) ;
   }
}
DA2_free(tDA2) ;
DV_free(colDV) ;
DV_free(rowDV) ;
}
MARKTIME(t2) ;
fprintf(msgFile, 
"\n CPU : %.3f compute simplest mvm, %.0f ops, %.3f megaflops", 
    t2 - t1, ops, megaflops(ops, t2 - t1)) ;
if ( msglvl > 3 ) {
   fprintf(msgFile, "\n eDA2") ;
   DA2_writeForHumanEye(eDA2, msgFile) ;
}
/*
   ------------------------------------------------------------
   compute the update with the other kernel

   aDA2 receives the DA2_nd3Upd() result. eDA2 then has aDA2 
   subtracted from it (presumably E := E - A -- confirm in 
   DA2.h) so that its frobenius norm measures the discrepancy 
   between the two kernels.
   ------------------------------------------------------------
*/
tmpDV = DV_new() ;
MARKTIME(t1) ;
DA2_nd3Upd(aDA2, bDA2, dDA2, cDA2, tmpDV) ;
MARKTIME(t2) ;
DA2_sub(eDA2, aDA2) ;
fprintf(msgFile, 
        "\n CPU : %.3f DA2_nd3Upd(), %.3f megaflops, %12.4e error", 
        t2 - t1, megaflops(ops, t2 - t1), DA2_frobNorm(eDA2)) ;
if ( msglvl > 3 ) {
   fprintf(msgFile, "\n after the update") ;
   fprintf(msgFile, "\n matrix aDA2") ;
   DA2_writeForHumanEye(aDA2, msgFile) ;
}

fprintf(msgFile, "\n") ;
/*
   ------------------------------------------------
   free the working storage and close the message 
   file if this program opened it
   ------------------------------------------------
*/
DV_free(tmpDV) ;
DA2_free(aDA2) ;
DA2_free(bDA2) ;
DA2_free(cDA2) ;
DA2_free(dDA2) ;
DA2_free(eDA2) ;
if ( msgFile != stdout ) {
   if ( fclose(msgFile) != 0 ) {
      fprintf(stderr, "\n error closing file %s\n", argv[2]) ;
      exit(-1) ;
   }
} else {
   fflush(stdout) ;
}

return(0) ; }

/*--------------------------------------------------------------------*/
