/*
 *             Automatically Tuned Linear Algebra Software v3.3.0Dev
 **************** THIS IS AN UNSUPPORTED DEVELOPER RELEASE *****************
 *                    (C) Copyright 1999 R. Clint Whaley                     
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *   1. Redistributions of source code must retain the above copyright
 *      notice, this list of conditions and the following disclaimer.
 *   2. Redistributions in binary form must reproduce the above copyright
 *      notice, this list of conditions, and the following disclaimer in the
 *      documentation and/or other materials provided with the distribution.
 *   3. The name of the University of Tennessee, the ATLAS group,
 *      or the names of its contributers may not be used to endorse
 *      or promote products derived from this software without specific
 *      written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE. 
 *
 */

#include <assert.h>
#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "atlas_fopen.h"

/*
 * Bit tests on a kernel's iflag word (each expands to 0 or 1):
 *   ATL_MVIsMM    -- the bit with value 8 is set
 *   ATL_MVIsAxpy  -- the bit with value 16 is set
 *   ATL_MVIsDot   -- neither the 8 nor the 16 bit is set
 *   ATL_MVNoBlock -- the bit with value 32 is set; callers in this file clear
 *                    it (flag - 32) to obtain the blocked variant
 */
#define ATL_MVIsMM(iflag_) ( ((iflag_) & 8) != 0 )
#define ATL_MVIsAxpy(iflag_) ( ((iflag_) & 16) != 0 )
#define ATL_MVIsDot(iflag_) ( !(ATL_MVIsMM(iflag_) || ATL_MVIsAxpy(iflag_)) )
#define ATL_MVNoBlock(iflag_) ( ((iflag_) & 32) != 0 )

/*
 * Sorts mflop[0..n-1] in place into descending order, then averages the
 * three largest timings while discarding outliers:
 *   - if mflop[0] exceeds tolerance*mflop[1], the largest value is dropped and
 *     mflop[1] and mflop[2] are averaged -- unless mflop[1] also exceeds
 *     tolerance*mflop[2], in which case -1.0 is returned (caller should rerun);
 *   - else if mflop[0] exceeds tolerance*mflop[2], mflop[0] and mflop[1] are
 *     averaged;
 *   - otherwise all three are averaged.
 * n must be at least 3; entries below the third largest do not affect the
 * result.
 */
double GetAvg(int n, double tolerance, double *mflop)
{
   int i, j;
   double t0, tavg;

   assert(n >= 3);  /* the tolerance logic below always reads mflop[0..2] */
/*
 * Sort results, largest first
 */
   for (i=0; i < n; i++)
   {
      for (j=i+1; j < n; j++)
      {
         if (mflop[i] < mflop[j])
         {
            t0 = mflop[i];
            mflop[i] = mflop[j];
            mflop[j] = t0;
         }
      }
   }
/*
 * Throw out the largest result if it is outside tolerance; request a rerun
 * (return -1.0) if the remaining two are not within tolerance of each other
 */
   if (tolerance*mflop[1] < mflop[0])  /* too big a range in results */
   {
      if (tolerance*mflop[2] < mflop[1]) return(-1.0);
      tavg = (mflop[1] + mflop[2]) / 2.0;
   }
   else if (tolerance*mflop[2] < mflop[0]) tavg = (mflop[0] + mflop[1]) / 2.0;
   else tavg = (mflop[0] + mflop[1] + mflop[2]) / 3.0;

   return(tavg);
}

/*
 * Returns the integer stored in res/L1CacheSize (reported as the L1 cache
 * size in KB).  If the file does not exist yet, it is built with
 * "make res/L1CacheSize" first.  The command and the read are performed
 * outside of assert() so they still happen when NDEBUG is defined.
 */
int GetL1CacheSize(void)
{
   FILE *L1f;
   int L1Size = 0, nread, rc;

   if (!FileExists("res/L1CacheSize"))
   {
      rc = system("make res/L1CacheSize\n");
      assert(rc == 0);
   }
   L1f = fopen("res/L1CacheSize", "r");
   assert(L1f != NULL);
   nread = fscanf(L1f, "%d", &L1Size);
   fclose(L1f);
   assert(nread == 1);  /* otherwise L1Size was never filled in */
   fprintf(stderr, "\n      Read in L1 Cache size as = %dKB.\n",L1Size);
   return(L1Size);
}

/*
 * Runs "make <pre>mvtstcase<TA> mvrout=<mvnam>" and returns system()'s status.
 * Callers in this file treat 0 as "the kernel passed its tests".  If the
 * command line would not fit in the local buffer, nothing is run and -1 is
 * returned (i.e. the kernel is treated as failing) instead of overflowing.
 */
int mvtstcase(char pre, char TA, char *mvnam)
{
   char ln[256];
   int len;

   len = snprintf(ln, sizeof(ln), "make %cmvtstcase%c mvrout=%s\n",
                  pre, TA, mvnam);
   if (len < 0 || len >= (int) sizeof(ln))
   {
      fprintf(stderr, "\n\nTEST COMMAND FOR ROUTINE %s TOO LONG\n\n", mvnam);
      return(-1);
   }
   return(system(ln));
}

/*
 * Returns the averaged MFLOP rate of GEMV kernel mvnam for transpose case TA,
 * case number cas, unrollings mu/nu, flag word flag and L1 fraction l1mul.
 *
 * Timings are cached in res/<pre>gemv<TA>_<cas>_<pct>, where <pct> is l1mul
 * truncated to an integer percentage (0 when ATL_MVNoBlock(flag)).  If that
 * file is missing, "make <pre>mvcase ..." is run with gmvout="-o <file>"
 * (presumably the make target writes the timings there -- not visible here).
 *
 * Returns -1.0 if the make command fails.  If the three cached timings are
 * rejected by GetAvg (tolerance 1.20), the cache file is removed and the
 * program exits with status -1.
 */
double mvcase(char pre, char *mvnam, char TA, int flag, int mu, int nu, 
              int cas, double l1mul)
{
   char nTA;
   char ln[512], fnam[64];  /* ln holds ~90 fixed chars plus both names */
   const int imul = l1mul*100.0;  /* truncated percentage of L1 to use */
   double mfs[3], mf;
   int len, nread;
   FILE *fp;

   if (TA == 'n' || TA == 'N') nTA = 'T';
   else nTA = 'N';

   if (ATL_MVNoBlock(flag))
      snprintf(fnam, sizeof(fnam), "res/%cgemv%c_%d_0", pre, TA, cas);
   else snprintf(fnam, sizeof(fnam), "res/%cgemv%c_%d_%d", pre, TA, cas, imul);
   if (!FileExists(fnam))
   {
      len = snprintf(ln, sizeof(ln),
"make %cmvcase ta=%c nta=%c mvrout=%s cas=%d xu=%d yu=%d l1mul=%f iflag=%d gmvout=\"-o %s\"\n", 
                     pre, TA, nTA, mvnam, cas, nu, mu, imul*0.01, flag, fnam);
      assert(len > 0 && len < (int) sizeof(ln));  /* never run a cut command */
      fprintf(stderr, "%s", ln);
      if (system(ln)) return(-1.0);  /* won't compile here */
   }
   fp = fopen(fnam, "r");
   assert(fp);
   nread = fscanf(fp, " %lf %lf %lf", mfs, mfs+1, mfs+2);  /* not in assert */
   fclose(fp);
   assert(nread == 3);
   mf = GetAvg(3, 1.20, mfs);
   if (mf == -1.0)
   {
      fprintf(stderr, 
"\n\n%s : VARIATION EXCEEDS TOLERENCE, RERUN WITH HIGHER REPS.\n\n", fnam);
      remove(fnam);  /* discard the bad timings so a rerun re-times */
      exit(-1);
   }
   return(mf);
}

/*
 * Bisects the fraction of L1 cache (between 0.5 and 1.0) used for blocking
 * kernel mvnam, timing both interval ends with mvcase on every pass.  The
 * lower end moves up when its rate is below 1.005 times the upper end's rate;
 * otherwise the upper end moves down.  The search stops after a pass whose
 * two ends truncated to the same integer percentage.  The final lower end is
 * printed and returned.
 * NOTE(review): if the interval closes in on an integer percentage from one
 * side only, the truncated ends can stay one apart until floating-point
 * convergence, costing many (cached) extra passes.
 */
double FindL1Mul(char pre, int cas, char *mvnam, char TA, int flag,
                 int mu, int nu)
{
   double lo = 0.5, hi = 1.0;
   double rateLo, rateHi;
   int pctLo, pctHi;

   /* timings are always taken with blocking turned on */
   if (ATL_MVNoBlock(flag)) flag -= 32;
   for (;;)
   {
      pctLo = lo * 100.0;
      pctHi = hi * 100.0;
      rateLo = mvcase(pre, mvnam, TA, flag, mu, nu, cas, lo);
      rateHi = mvcase(pre, mvnam, TA, flag, mu, nu, cas, hi);
      fprintf(stdout, "      %.2f%% %.2fMFLOP  ---  %.2f%% %.2fMFLOP\n",
              lo*100.0, rateLo, hi*100.0, rateHi);
      if (rateLo < 1.005*rateHi)
         lo += 0.5*(hi-lo);
      else
         hi -= 0.5*(hi-lo);
      if (pctHi == pctLo) break;  /* checked against this pass's start */
   }
   fprintf(stdout, "\n\nBEST %% of L1 cache: %.2f\n", lo*100.0);
   return(lo);
}

/*
 * For a kernel whose flag requests no blocking (ATL_MVNoBlock), times it both
 * with the 32 bit cleared (blocked) and as given, prints both rates, and
 * returns the blocked flag when blocking is at least as fast.  Any other flag
 * is returned unchanged without timing anything.
 */
int ConfirmBlock(char pre, char *mvnam, char TA, int flag, int mu, int nu,
                 int cas, double l1mul)
{
   int blockedFlag;
   double rateBlocked, rateUnblocked;

   if (!ATL_MVNoBlock(flag)) return(flag);

   blockedFlag = flag - 32;
   rateBlocked   = mvcase(pre, mvnam, TA, blockedFlag, mu, nu, cas, l1mul);
   rateUnblocked = mvcase(pre, mvnam, TA, flag, mu, nu, cas, l1mul);
   fprintf(stdout, "\nWith blocking=%lf, without=%lf\n\n", 
           rateBlocked, rateUnblocked);
   return( (rateBlocked >= rateUnblocked) ? blockedFlag : flag );
}
/*
 * Reads one table of candidate kernels from fp.  The first line holds the
 * count n (1..99); each of the next n lines holds
 *     <flag> <mu> <nu> <filename> "<author>"
 * On return *N = n and *fnams/*auths/*flags/*mus/*nus point to newly
 * malloc'd arrays of n entries (each name/author string is a 64-byte
 * malloc'd buffer); the caller owns and must free all of them.
 * Names and authors longer than 63 characters are rejected by the assert on
 * the sscanf count instead of overflowing their buffers.  All reads and
 * allocations happen outside assert() so they still run under NDEBUG.
 */
void GetCases(FILE *fp, int *N, char ***fnams, char ***auths, int **flags, 
              int **mus, int **nus)
{
   int i, n = 0, nread;
   int *mu, *nu, *flag;
   char **fnam, **auth;
   char ln[256], *got;

   got = fgets(ln, sizeof(ln), fp);
   assert(got);
   nread = sscanf(ln, " %d", &n);
   assert(nread == 1);
   assert(n < 100 && n > 0);
   fnam = malloc(n * sizeof(*fnam));
   auth = malloc(n * sizeof(*auth));
   assert(fnam && auth);
   for (i=0; i < n; i++)
   {
      fnam[i] = malloc(64 * sizeof(char));
      auth[i] = malloc(64 * sizeof(char));
      assert(fnam[i] && auth[i]);
   }
   mu = malloc(n * sizeof(*mu));
   nu = malloc(n * sizeof(*nu));
   flag = malloc(n * sizeof(*flag));
   assert(mu && nu && flag);
   for (i=0; i < n; i++)
   {
      got = fgets(ln, sizeof(ln), fp);
      assert(got);
      /* widths of 63 keep the strings inside their 64-byte buffers */
      nread = sscanf(ln, " %d %d %d %63s \"%63[^\"]", 
                     flag+i, mu+i, nu+i, fnam[i], auth[i]);
      assert(nread == 5);
      assert(mu[i] >= 0 && nu[i] >= 0 && fnam[i][0] != '\0');
   }
   
   *N = n;
   *fnams = fnam;
   *auths = auth;
   *flags = flag;
   *mus = mu;
   *nus = nu;
}

/*
 * Times every candidate kernel for transpose case TA (case numbers 1..ncases,
 * L1 fraction 0.75) and returns the 0-based index of the fastest one that
 * also passes mvtstcase.  Kernels are only tested when they beat the current
 * best rate; a failing kernel is reported and skipped.  Asserts that at least
 * one kernel qualified.
 */
int RunTransCases(char pre, char TA, int ncases, char **fnams, 
                  int *flags, int *mus, int *nus)
{
   int k, best = -1;
   double rate, bestRate = 0.0;

   for (k=0; k < ncases; k++)
   {
      rate = mvcase(pre, fnams[k], TA, flags[k], mus[k], nus[k], k+1, 0.75);
      fprintf(stdout, "%s : %.2f\n", fnams[k], rate);
      if (!(rate > bestRate)) continue;      /* not faster: no need to test */
      if (mvtstcase(pre, TA, fnams[k]) != 0) /* ensure it passes test */
      {
         fprintf(stderr, "\n\nROUTINE %s FAILED TESTS!!!\n\n", fnams[k]);
         continue;
      }
      bestRate = rate;
      best = k;
   }
   assert(best >= 0);
   fprintf(stdout, 
           "\nbest %cgemv%c : case %d, mu=%d, nu=%d at %.2f MFLOPS\n\n", 
           pre, TA, best+1, mus[best], nus[best], bestRate);
   return(best);
}

/*
 * Reads the two-line summary res/<pre>MVRES (N-kernel line, then T-kernel
 * line), each formatted as
 *     <flag> <mu> <nu> <l1mul> <mflop> <filename> "<author>"
 * into the output arguments.  Both lines carry l1mul; the value from the
 * second (T) line is the one left in *l1mul.  Each name/author buffer must
 * hold at least 64 bytes: at most 63 characters are stored, and longer
 * fields trip the assert on the sscanf count instead of overflowing.
 */
void ReadSum(char pre, double *l1mul, char *fnamN, char *authN, int *flagN,
             int *muN, int *nuN, double *mfN, char *fnamT, char *authT,
             int *flagT, int *muT, int *nuT, double *mfT)
{
   char ln[256], *got;
   int nread;
   FILE *fp;

   sprintf(ln, "res/%cMVRES", pre);
   fp = fopen(ln, "r");
   assert(fp);

   got = fgets(ln, sizeof(ln), fp);  /* read outside assert(): NDEBUG-safe */
   assert(got);
   nread = sscanf(ln, " %d %d %d %lf %lf %63s \"%63[^\"]", 
                  flagN, muN, nuN, l1mul, mfN, fnamN, authN);
   assert(nread == 7);
   got = fgets(ln, sizeof(ln), fp);
   assert(got);
   nread = sscanf(ln, " %d %d %d %lf %lf %63s \"%63[^\"]", 
                  flagT, muT, nuT, l1mul, mfT, fnamT, authT);
   assert(nread == 7);
   fclose(fp);
}

/*
 * Writes the two-line summary res/<pre>MVRES (N kernel first, then T kernel),
 * each line formatted as
 *     <flag> <mu> <nu> <l1mul> <mflop> <filename> "<author>"
 * which is the format ReadSum parses.  l1mul and the MFLOP rates are written
 * with two decimals, so a later ReadSum gets l1mul rounded to 0.01.
 * Asserts that every write, including the final flush done by fclose,
 * succeeded.
 */
void CreateSum(char pre, double l1mul, char *fnamN, char *authN, int flagN,
               int muN, int nuN, double mfN, char *fnamT, char *authT,
               int flagT, int muT, int nuT, double mfT)
{
   char fnam[32];
   int werr;
   FILE *fp;

   sprintf(fnam, "res/%cMVRES", pre);
   fp = fopen(fnam, "w");
   assert(fp);
   werr  = fprintf(fp, "%d %d %d %.2f %.2f %s \"%s\"\n", 
                   flagN, muN, nuN, l1mul, mfN, fnamN, authN) < 0;
   werr |= fprintf(fp, "%d %d %d %.2f %.2f %s \"%s\"\n",
                   flagT, muT, nuT, l1mul, mfT, fnamT, authT) < 0;
   werr |= fclose(fp) != 0;  /* fclose flushes: late write errors show here */
   assert(!werr);
}

/*
 * Installs the chosen N and T kernels.  It re-times both kernels with case
 * number 100 at fraction l1mul, then builds and runs xemit_head with the
 * chosen parameters.  Next it runs "make <pre>install" with both routine
 * names.  Finally it rewrites res/<pre>MVRES with the measured rates.
 * The commands are run outside assert() so they still execute when NDEBUG is
 * defined; the asserts only check their exit status.
 */
void mvinstall(char pre, double l1mul, char *fnamN, char *authN, int flagN,
               int muN, int nuN, char *fnamT, char *authT, int flagT,
               int muT, int nuT)
{
   char ln[512];
   double mfN, mfT;
   int len, rc;

   mfN = mvcase(pre, fnamN, 'N', flagN, muN, nuN, 100, l1mul);
   mfT = mvcase(pre, fnamT, 'T', flagT, muT, nuT, 100, l1mul);

   rc = system("make xemit_head \n");
   assert(rc == 0);

   len = snprintf(ln, sizeof(ln),
      "./xemit_head -p %c -l %f -N -f %d -y %d -x %d -T -f %d -y %d -x %d\n",
                  pre, l1mul, flagN, muN, nuN, flagT, muT, nuT);
   assert(len > 0 && len < (int) sizeof(ln));
   rc = system(ln);
   assert(rc == 0);

   len = snprintf(ln, sizeof(ln), "make %cinstall mvNrout=%s mvTrout=%s\n",
                  pre, fnamN, fnamT);
   assert(len > 0 && len < (int) sizeof(ln));  /* never run a cut command */
   fprintf(stderr, "%s", ln);
   rc = system(ln);
   assert(rc == 0);

   CreateSum(pre, l1mul, fnamN, authN, flagN, muN, nuN, mfN, 
             fnamT, authT, flagT, muT, nuT, mfT);
}

/*
 * Runs the full kernel search for precision pre:
 *   1. reads the N and T candidate tables from ../CASES/<pre>cases.dsc,
 *   2. picks the fastest passing kernel of each (RunTransCases),
 *   3. tunes the L1 fraction on the best T kernel (FindL1Mul),
 *   4. decides blocking for both winners (ConfirmBlock),
 *   5. records the choice with CreateSum (MFLOP fields set to -1.0).
 */
void RunCases(char pre)
{
   char fnam[128];
   char Nfnam[64], Tfnam[64], Nauth[64], Tauth[64];
   char **fnamN, **fnamT, **authN, **authT;
   int i, nNTcases, nTcases, Nbest, Tbest;
   int Nflag, Nmu, Nnu, Tflag, Tmu, Tnu;
   int *flagN, *muN, *nuN, *flagT, *muT, *nuT;
   double l1mul;
   FILE *fp;

/*
 * Read in cases to try
 */
   sprintf(fnam, "../CASES/%ccases.dsc", pre);
   fp = fopen(fnam, "r");
   assert(fp);
   GetCases(fp, &nNTcases, &fnamN, &authN, &flagN, &muN, &nuN);
   GetCases(fp, &nTcases,  &fnamT, &authT, &flagT, &muT, &nuT);
   fclose(fp);
/*
 * Try all cases for each trans case
 */
   Nbest = RunTransCases(pre, 'N', nNTcases, fnamN, flagN, muN, nuN);
   Tbest = RunTransCases(pre, 'T', nTcases, fnamT, flagT, muT, nuT);

/*
 * Copy the winners out (GetCases strings are 64-byte buffers, same as the
 * locals) before releasing the case tables
 */
   Nflag = flagN[Nbest]; Tflag = flagT[Tbest];
   Nmu = muN[Nbest]; Nnu = nuN[Nbest];
   strcpy(Nfnam, fnamN[Nbest]); strcpy(Nauth, authN[Nbest]);
   Tmu = muT[Tbest]; Tnu = nuT[Tbest];
   strcpy(Tfnam, fnamT[Tbest]); strcpy(Tauth, authT[Tbest]);

   free(flagN); free(flagT); free(muN); free(muT); free(nuN); free(nuT);
   for (i=0; i < nNTcases; i++)
   {
      free(fnamN[i]);
      free(authN[i]);
   }
   free(fnamN);
   free(authN);   /* the author pointer arrays were previously leaked */
   for (i=0; i < nTcases; i++)
   {
      free(fnamT[i]);
      free(authT[i]);
   }
   free(fnamT);
   free(authT);

   l1mul = FindL1Mul(pre, Tbest+1, Tfnam, 'T', Tflag, Tmu, Tnu);
   Tflag = ConfirmBlock(pre, Tfnam, 'T', Tflag, Tmu, Tnu, Tbest+1, l1mul);
   Nflag = ConfirmBlock(pre, Nfnam, 'N', Nflag, Nmu, Nnu, Nbest+1, l1mul);
   CreateSum(pre, l1mul, Nfnam, Nauth, Nflag, Nmu, Nnu, -1.0, 
             Tfnam, Tauth, Tflag, Tmu, Tnu, -1.0);
}

/*
 * Top-level driver for precision pre.  If res/<pre>MVRES is missing, runs the
 * full search.  If it exists but either recorded kernel now fails
 * mvtstcase, deletes it and reruns the search.  Then it installs the kernels
 * recorded in the summary.
 */
void GoToTown(char pre)
{
   char fnamN[128], authN[64], fnamT[128], authT[64], ln[128];
   int flagN, muN, nuN, flagT, muT, nuT, rc;
   double l1mul, mfN, mfT;

   sprintf(ln, "res/%cMVRES", pre);
   if (!FileExists(ln)) RunCases(pre);
   else  /* if default does not pass tester, rerun cases */
   {
      ReadSum(pre, &l1mul, fnamN, authN, &flagN, &muN, &nuN, &mfN,
              fnamT, authT, &flagT, &muT, &nuT, &mfT);
      if ( (mvtstcase(pre, 'N', fnamN) != 0) ||
           (mvtstcase(pre, 'T', fnamT) != 0) )
      {
         rc = remove(ln);  /* outside assert() so it still runs under NDEBUG */
         assert(rc == 0);
         RunCases(pre);
      }
   }
   ReadSum(pre, &l1mul, fnamN, authN, &flagN, &muN, &nuN, &mfN,
           fnamT, authT, &flagT, &muT, &nuT, &mfT);
   mvinstall(pre, l1mul, fnamN, authN, flagN, muN, nuN, fnamT, authT,
             flagT, muT, nuT);
}

/*
 * Writes the command-line synopsis for program fnam to stderr, then
 * terminates the process with exit status -1.  Never returns.
 */
void PrintUsage(char *fnam)
{
   (void) fprintf(stderr, "USAGE: %s [-p <s,d,c,z>]\n", fnam);
   exit(-1);
}

/*
 * Parses the command line into the output arguments:
 *   -p <s,d,c,z> : precision prefix, case-insensitive (default 'd')
 *   -m <mu>      : positive value for mu
 *   -n <nu>      : positive value for nu
 *   -A <N,T>     : transpose case, case-insensitive
 * If only one of mu/nu is given, the other is set equal to it.  If either is
 * given and -A is not, TA becomes 'N'; otherwise TA stays ' ' unless -A set
 * it.  Every flag takes a value.  A flag missing its value, an unknown flag or
 * an invalid value calls PrintUsage, which exits.
 */
void GetFlags(int nargs, char **args, char *pre, char *TA, int *mu, int *nu)
{
   char ctmp;
   int i;
   *pre = 'd';
   *mu = *nu = 0;
   *TA = ' ';
   for (i=1; i < nargs; i++)
   {
      if (args[i][0] != '-') PrintUsage(args[0]);
      switch(args[i][1])
      {
      case 'p':
         if (++i >= nargs) PrintUsage(args[0]);  /* value missing */
         /* cast: <ctype.h> functions need a value representable as uchar */
         ctmp = tolower((unsigned char) args[i][0]);
         if (ctmp == 's' || ctmp == 'd' || ctmp == 'c' || ctmp == 'z')
            *pre = ctmp;
         else PrintUsage(args[0]);
         break;
      case 'm':
         if (++i >= nargs) PrintUsage(args[0]);
         *mu = atoi(args[i]);
         assert(*mu > 0);
         break;
      case 'n':
         if (++i >= nargs) PrintUsage(args[0]);
         *nu = atoi(args[i]);
         assert(*nu > 0);
         break;
      case 'A':
         if (++i >= nargs) PrintUsage(args[0]);
         ctmp = toupper((unsigned char) args[i][0]);
         if (ctmp == 'N' || ctmp == 'T') *TA = ctmp;
         else PrintUsage(args[0]);
         break;
      default:
         fprintf(stderr, "Unknown flag : %s\n", args[i]);
         PrintUsage(args[0]);
      }
   }
   if (*mu == 0 && *nu) *mu = *nu;
   else if (*nu == 0 && *mu) *nu = *mu;
   if ( (*mu || *nu) && *TA == ' ') *TA = 'N';
}


/*
 * Program entry: parses the command line, then runs the search/install for
 * the selected precision.
 * NOTE(review): GetFlags also parses -A, -m and -n, but TA, mu and nu are not
 * used here.
 */
int main(int nargs, char **args)
{
   char pre, TA;
   int mu, nu;
   GetFlags(nargs, args, &pre, &TA, &mu, &nu);
   GoToTown(pre);
   return(0);
}
