/*
 *             Automatically Tuned Linear Algebra Software v3.3.6
 **************** THIS IS AN UNSUPPORTED DEVELOPER RELEASE *****************
 *                    (C) Copyright 1999 R. Clint Whaley
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *   1. Redistributions of source code must retain the above copyright
 *      notice, this list of conditions and the following disclaimer.
 *   2. Redistributions in binary form must reproduce the above copyright
 *      notice, this list of conditions, and the following disclaimer in the
 *      documentation and/or other materials provided with the distribution.
 *   3. The name of the University of Tennessee, the ATLAS group,
 *      or the names of its contributers may not be used to endorse
 *      or promote products derived from this software without specific
 *      written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 */

#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>
#include <assert.h>
#include "atlas_fopen.h"

/*
 * Evaluates to 1 when the bit of value 32 is set in iflag_, and to 0
 * otherwise.  FindL1Mul and ConfirmBlock subtract 32 from a flag for which
 * this is true in order to time the routine with blocking.
 */
#define ATL_NoBlock(iflag_) ( ((iflag_) & 32) == 32 )

/*
 * Sorts the n entries of mflop into descending order (in place), then
 * averages the first three of them, discarding an outlier:
 *   - if the largest exceeds tolerance * the middle value, the largest is
 *     dropped and the lower two are averaged, unless the middle value also
 *     exceeds tolerance * the smallest, in which case -1.0 is returned;
 *   - otherwise, if the largest exceeds tolerance * the smallest, the
 *     smallest is dropped and the upper two are averaged;
 *   - otherwise all three are averaged.
 * The selection logic always reads mflop[0..2], i.e. it assumes n == 3.
 */
double GetAvg(int n, double tolerance, double *mflop)
{
   int top, k;
   double swap, best, mid, worst;
/*
 * Exchange sort: after pass 'top', mflop[top] holds the largest remaining value
 */
   top = 0;
   while (top != n)
   {
      for (k=top+1; k < n; k++)
      {
         if (mflop[k] > mflop[top])
         {
            swap = mflop[top];
            mflop[top] = mflop[k];
            mflop[k] = swap;
         }
      }
      top++;
   }

   best  = mflop[0];
   mid   = mflop[1];
   worst = mflop[2];
/*
 * Largest value is an outlier: use the lower pair, or give up if they
 * do not agree with each other either
 */
   if (tolerance*mid < best)
   {
      if (tolerance*worst < mid) return(-1.0);
      return((mid + worst) / 2.0);
   }
/*
 * Smallest value is an outlier: use the upper pair
 */
   if (tolerance*worst < best) return((best + mid) / 2.0);

   return((best + mid + worst) / 3.0);
}

/*
 * Returns the integer stored in res/L1CacheSize (reported as KB), running
 * "make res/L1CacheSize" first if that file does not exist yet.
 * Aborts via assert if the make fails, the file cannot be opened, or no
 * integer can be read from it.
 */
int GetL1CacheSize()
{
   FILE *L1f;
   int L1Size = 0;
   int ierr, nread;

   if (!FileExists("res/L1CacheSize"))
   {
/*
 *    The command is run outside of assert so that it is still executed
 *    when the file is compiled with NDEBUG
 */
      ierr = system("make res/L1CacheSize\n");
      assert(ierr == 0);
   }
   L1f = fopen("res/L1CacheSize", "r");
   assert(L1f != NULL);
   nread = fscanf(L1f, "%d", &L1Size);
   fclose(L1f);
   assert(nread == 1);  /* file must actually contain a number */
   fprintf(stderr, "\n      Read in L1 Cache size as = %dKB.\n",L1Size);
   return(L1Size);
}


/*
 * Runs the ./xemit_r1h program with the given precision prefix (-p),
 * L1 multiplier (-l), flag (-f) and the mu/nu values (-x / -y).
 * Aborts via assert if the command line does not fit in the local buffer
 * or if the command returns a nonzero status.
 */
void emit_r1head(char pre, double l1mul, int flag, int mu, int nu)
{
   char ln[256];
   int len, ierr;

/*
 * %f of a very large double can expand to hundreds of characters, so the
 * formatting is bounded and truncation is treated as an error
 */
   len = snprintf(ln, sizeof(ln), "./xemit_r1h -p %c -l %f -f %d -x %d -y %d\n",
                  pre, l1mul, flag, mu, nu);
   assert(len > 0 && len < (int)sizeof(ln));
/*
 * Run the command outside of assert so it is not compiled away by NDEBUG
 */
   ierr = system(ln);
   assert(ierr == 0);
}

/*
 * Invokes "make <pre>r1tstcase r1rout=<r1nam>" through system() and
 * returns the status that system() reports (callers treat 0 as success).
 */
int r1tstcase(char pre, char *r1nam)
{
   char cmd[256];
   int status;

   sprintf(cmd, "make %cr1tstcase r1rout=%s\n", pre, r1nam);
   status = system(cmd);
   return(status);
}

/*
 * Returns the averaged MFLOP figure for one timing case.
 *
 * Results are cached in the file res/<pre>ger1_<cas>_<pct>, where <pct> is 0
 * when ATL_NoBlock(flag) is true and (int)(100*l1mul) otherwise.  If that
 * file does not exist, "make <pre>r1case ..." is run to produce it; a nonzero
 * status from that command makes this routine return -1.0.
 *
 * The file must hold three numbers, which are combined by GetAvg with a
 * tolerance of 1.20.  If GetAvg rejects them (-1.0), a message is printed,
 * the result file is deleted and the program exits with status -1.
 */
double r1case(char pre, char *r1nam, int flag, int mu, int nu, int cas,
              double l1mul)
{
   char fnam[128], ln[512];
   double mf, mfs[3];
   int len, nread;
   FILE *fp;

   if (ATL_NoBlock(flag)) sprintf(fnam, "res/%cger1_%d_0", pre, cas);
   else sprintf(fnam, "res/%cger1_%d_%d", pre, cas, (int)(100.0*l1mul));
   if (!FileExists(fnam))
   {
/*
 *    The fixed text plus the numeric fields already take about 65 bytes, so
 *    a long routine name would overrun a 128 byte buffer; format with a
 *    bound and refuse to run a truncated command
 */
      len = snprintf(ln, sizeof(ln),
              "make %cr1case r1rout=%s cas=%d l1mul=%f iflag=%d xu=%d yu=%d\n",
              pre, r1nam, cas, l1mul, flag, mu, nu);
      assert(len > 0 && len < (int)sizeof(ln));
      fprintf(stderr, "%s", ln);
      if (system(ln)) return(-1.0);  /* won't compile here */
   }
   fp = fopen(fnam, "r");
   assert(fp);
/*
 * Read outside of assert: with NDEBUG the assert argument is never evaluated,
 * which would leave mfs uninitialized
 */
   nread = fscanf(fp, " %lf %lf %lf", mfs, mfs+1, mfs+2);
   fclose(fp);
   assert(nread == 3);
   mf = GetAvg(3, 1.20, mfs);
   if (mf == -1.0)
   {
      fprintf(stderr,
"\n\n%s : VARIATION EXCEEDS TOLERENCE, RERUN WITH HIGHER REPS.\n\n", fnam);
/*
 *    Remove the rejected timings so that the next run regenerates them
 */
      sprintf(ln, "rm -f %s\n", fnam);
      system(ln);
      exit(-1);
   }
   return(mf);
}

/*
 * Writes the one-line summary file res/<pre>R1RES, overwriting any previous
 * contents.  The line holds, in order: flag, mu, nu, l1mul and mf (both with
 * two decimals), the routine name, and the author string in double quotes.
 * Aborts via assert if the file cannot be opened for writing.
 */
void CreateSum(char pre, double l1mul, char *r1nam, char *auth, int flag,
               int mu, int nu, double mf)
{
   FILE *out;
   char path[64];

   sprintf(path, "res/%cR1RES", pre);

   out = fopen(path, "w");
   assert(out);

   fprintf(out, "%d %d %d %.2f %.2f %s \"%s\"\n",
           flag, mu, nu, l1mul, mf, r1nam, auth);
   fclose(out);
}

/*
 * Reads the summary file res/<pre>R1RES (the format written by CreateSum)
 * and stores flag, mu, nu, l1mul, mf, the routine name and the quoted author
 * string through the corresponding pointers.
 *
 * At most 127 characters are stored in each of r1nam and auth (plus the
 * terminating NUL), so both must point to at least 128 bytes, which is what
 * GoToTown supplies.  Aborts via assert if the file cannot be opened or does
 * not yield all seven fields.
 */
void ReadSum(char pre, double *l1mul, char *r1nam, char *auth, int *flag,
             int *mu, int *nu, double *mf)
{
   char fnam[64];
   int nread;
   FILE *fp;

   sprintf(fnam, "res/%cR1RES", pre);
   fp = fopen(fnam, "r");
   assert(fp);
/*
 * The read is done outside of assert so it still happens under NDEBUG, and
 * the string conversions carry a field width so an over-long entry in the
 * file cannot write past the callers' buffers
 */
   nread = fscanf(fp, " %d %d %d %lf %lf %127s \"%127[^\"]",
                  flag, mu, nu, l1mul, mf, r1nam, auth);
   fclose(fp);
   assert(nread == 7);
}

/*
 * Installs the chosen routine: times it through r1case, emits the header
 * via emit_r1head, runs "make <pre>install r1rout=<r1nam>", and finally
 * records the result with CreateSum.  Aborts via assert if the make command
 * does not fit in the local buffer or returns a nonzero status.
 *
 * NOTE(review): the case number handed to r1case here is (int)(l1mul*100),
 * not the case index used by RunCases -- confirm that this is intended.
 * NOTE(review): a -1.0 return from r1case (build failure) is not checked
 * before being written to the summary -- confirm.
 */
void r1install(char pre, char *r1nam, char *auth, double l1mul, int flag,
               int mu, int nu)
{
   char ln[256];
   double mf;
   int len, ierr;

   mf = r1case(pre, r1nam, flag, mu, nu, (int)(l1mul*100), l1mul);
   emit_r1head(pre, l1mul, flag, mu, nu);
/*
 * r1nam may be up to 127 characters long (see GoToTown), so the command is
 * formatted with a bound and a truncated command is never executed
 */
   len = snprintf(ln, sizeof(ln), "make %cinstall r1rout=%s", pre, r1nam);
   assert(len > 0 && len < (int)sizeof(ln));
   fprintf(stderr, "%s\n", ln);  /* command has no newline of its own */
/*
 * Run the command outside of assert so it is not compiled away by NDEBUG
 */
   ierr = system(ln);
   assert(ierr == 0);
   CreateSum(pre, l1mul, r1nam, auth, flag, mu, nu, mf);
}

/*
 * Bisection search for the fraction of the L1 cache to use, over the
 * interval [0.5, 1.0].  Each step times the routine (via r1case) at both
 * ends of the current interval and moves the slower end to the midpoint.
 * The search stops after a step in which both ends, expressed as a whole
 * percentage, were already equal when the step began.  Returns the final
 * lower end of the interval.
 *
 * If the flag has the no-blocking bit set, that bit is removed first so the
 * timings are performed with blocking.
 */
double FindL1Mul(char pre, int cas, char *r1nam, int flag, int mu, int nu)
{
   double low = .5, high = 1.0;
   double mflow, mfhigh;
   int pctlow, pcthigh;

   if (ATL_NoBlock(flag)) flag -= 32;  /* always actually block these times */
   for (;;)
   {
/*
 *    Percentages are taken before the interval is narrowed below, so the
 *    exit test at the bottom refers to the interval that was just timed
 */
      pctlow  = (int)(low  * 100.0);
      pcthigh = (int)(high * 100.0);
      mflow  = r1case(pre, r1nam, flag, mu, nu, cas, low);
      mfhigh = r1case(pre, r1nam, flag, mu, nu, cas, high);
      fprintf(stdout, "      %.2f%% %.2fMFLOP  ---  %.2f%% %.2fMFLOP\n",
              low*100.0, mflow, high*100.0, mfhigh);
      if (mfhigh > mflow) low += 0.5*(high-low);
      else high -= 0.5 * (high-low);
      if (pcthigh == pctlow) break;
   }
   fprintf(stdout, "\n\nBEST %% of L1 cache: %.2f\n", low*100.0);
   return(low);
}

/*
 * For a routine whose flag requests no blocking, times it both with the
 * no-blocking bit removed and with the flag as given, and prints both
 * results.  Returns the flag without the bit if the blocked timing is at
 * least as fast; otherwise, or if the bit was not set to begin with,
 * returns flag unchanged.
 */
int ConfirmBlock(char pre, char *r1nam, int flag, int mu, int nu, int cas,
                 double l1mul)
{
   int blkflag;
   double mfblock, mfnoblock;

   if ( !ATL_NoBlock(flag) ) return(flag);  /* nothing to confirm */

   blkflag = flag - 32;
   mfblock   = r1case(pre, r1nam, blkflag, mu, nu, cas, l1mul);
   mfnoblock = r1case(pre, r1nam,    flag, mu, nu, cas, l1mul);
   fprintf(stdout, "\nWith blocking=%lf, without=%lf\n\n",
           mfblock, mfnoblock);

   return( (mfblock >= mfnoblock) ? blkflag : flag );
}

/*
 * Reads the case description file ../CASES/<pre>cases.dsc.
 *
 * The first line holds the number of cases n (0 < n < 100).  Each of the
 * next n lines holds: flag, mu, nu, a routine name, and an author string in
 * double quotes.  Names and author strings are limited to 63 characters;
 * a longer entry fails the field-count check instead of overrunning the
 * 64 byte buffers.
 *
 * On return *ncases is n, and *r1nams, *auths, *flags, *mus, *nus point to
 * freshly malloc'ed arrays of n entries each (every name/author string is a
 * separate 64 byte allocation); the caller releases them with free().
 * Any read, parse or allocation failure aborts via assert.
 */
void GetCases(char pre, int *ncases, char ***r1nams, char ***auths,
              int **flags, int **mus, int **nus)
{
   int n = 0, i, nread;
   char fnam[64], ln[128];
   char **r1nam, **auth, *lp;
   int *flag, *mu, *nu;
   FILE *fp;

   sprintf(fnam, "../CASES/%ccases.dsc", pre);
   fp = fopen(fnam, "r");
   assert(fp);

/*
 * Every call with a side effect is made outside of assert, so the file is
 * still read and the buffers are still allocated when NDEBUG is defined
 */
   lp = fgets(ln, 128, fp);
   assert(lp);
   nread = sscanf(ln, " %d", &n);
   assert(nread == 1);
   assert(n < 100 && n > 0);

   r1nam = malloc(n * sizeof(char*));
   auth = malloc(n * sizeof(char*));
   assert(r1nam && auth);
   for (i=0; i < n; i++)
   {
      r1nam[i] = malloc(64*sizeof(char));
      auth[i] = malloc(64*sizeof(char));
      assert(r1nam[i] && auth[i]);
   }
   mu = malloc(n * sizeof(int));
   nu = malloc(n * sizeof(int));
   flag = malloc(n * sizeof(int));
   assert(mu && nu && flag);
   for (i=0; i < n; i++)
   {
      lp = fgets(ln, 128, fp);
      assert(lp);
/*
 *    Field widths keep each string within its 64 byte buffer
 */
      nread = sscanf(ln, " %d %d %d %63s \"%63[^\"]",
                     flag+i, mu+i, nu+i, r1nam[i], auth[i]);
      assert(nread == 5);
      assert(mu[i] >= 0 && nu[i] >= 0 && r1nam[i][0] != '\0');
   }
   fclose(fp);
   *ncases = n;
   *r1nams = r1nam;
   *auths = auth;
   *flags = flag;
   *mus = mu;
   *nus = nu;
}

/*
 * Times every case listed by GetCases (with an L1 multiplier of .75) and
 * keeps the fastest one that also passes r1tstcase; the tester is only run
 * for a case that beats the best timing so far.  Aborts via assert if no
 * case qualifies.  For the winner it then searches for the L1 multiplier
 * (FindL1Mul), settles the blocking flag (ConfirmBlock), prints the choice
 * and writes the summary file with CreateSum.
 */
void RunCases(char pre)
{
   char **r1nams, **auths;
   int ncases, *flags, *mus, *nus;
   int k, ibest, bestcase=0;   /* bestcase is 1-based; 0 means none yet */
   double l1mul, mf, mfbest=0.0;

   GetCases(pre, &ncases, &r1nams, &auths, &flags, &mus, &nus);

   for (k=0; k < ncases; k++)
   {
      mf = r1case(pre, r1nams[k], flags[k], mus[k], nus[k], k+1, .75);
      fprintf(stdout, "%s : %.2f\n", r1nams[k], mf);
/*
 *    A faster case only replaces the current best if it passes the tester
 */
      if (mf > mfbest && r1tstcase(pre, r1nams[k]) == 0)
      {
         mfbest = mf;
         bestcase = k+1;
      }
   }

   assert(bestcase);
   ibest = bestcase - 1;
   l1mul = FindL1Mul(pre, bestcase, r1nams[ibest], flags[ibest], mus[ibest],
                     nus[ibest]);
   flags[ibest] = ConfirmBlock(pre, r1nams[ibest], flags[ibest], mus[ibest],
                               nus[ibest], bestcase, l1mul);
   fprintf(stdout, "\nBEST: %s, case %d, mu=%d, nu=%d; at %.2f MFLOPS\n\n",
           r1nams[ibest], bestcase, mus[ibest], nus[ibest], mfbest);
   CreateSum(pre, l1mul, r1nams[ibest], auths[ibest], flags[ibest],
             mus[ibest], nus[ibest], mfbest);

/*
 * Release everything GetCases allocated
 */
   for (k=0; k < ncases; k++)
   {
      free(r1nams[k]);
      free(auths[k]);
   }
   free(r1nams);
   free(auths);
   free(flags);
   free(mus);
   free(nus);
}

/*
 * Top-level driver for one precision prefix.
 *
 * If the summary file res/<pre>R1RES does not exist, the search is run with
 * RunCases.  If it does exist, the routine it names is checked with
 * r1tstcase; on failure the summary is deleted and the search is rerun.
 * The (possibly new) summary is then read, reported and installed with
 * r1install.
 *
 * pre is declared int, the type the original untyped parameter defaulted to,
 * so callers without a prototype in scope keep working.
 */
void GoToTown(int pre)
{
   char r1nam[128], auth[128], ln[128];
   int flag, mu, nu, ierr;
   double l1mul, mf;

   sprintf(ln, "res/%cR1RES", pre);
   if (!FileExists(ln)) RunCases(pre);
   else /* if default does not pass tester, rerun cases */
   {
      ReadSum(pre, &l1mul, r1nam, auth, &flag, &mu, &nu, &mf);
      if (r1tstcase(pre, r1nam) != 0)
      {
/*
 *       Delete outside of assert so the stale summary is removed even when
 *       compiled with NDEBUG
 */
         ierr = remove(ln);
         assert(ierr == 0);
         RunCases(pre);
      }
   }
/*
 * ReadSum opens the summary itself and asserts that the open succeeded, so
 * no separate existence check is needed here
 */
   ReadSum(pre, &l1mul, r1nam, auth, &flag, &mu, &nu, &mf);
   fprintf(stdout, "\nBEST: %s, mu=%d, nu=%d; at %.2f MFLOPS\n\n",
           r1nam, mu, nu, mf);
   r1install(pre, r1nam, auth, l1mul, flag, mu, nu);
}

/*
 * Prints the command-line synopsis for program name fnam on stderr and
 * terminates the program with exit status -1.  Does not return.
 */
void PrintUsage(char *fnam)
{
   int status = -1;

   fprintf(stderr, "USAGE: %s [-p <d/s/z/c>]\n", fnam);
   exit(status);
}

/*
 * Parses the command line.  The only option is "-p <c>", where the first
 * character of <c>, lower-cased, must be one of s, d, c or z; it is stored
 * in *pre, which defaults to 'd'.  Any argument that does not start with
 * '-', an unknown flag, a bad precision character, or a -p with no value
 * after it leads to PrintUsage, which exits.
 */
void GetFlags(int nargs, char **args, char *pre)
{
   char ctmp;
   int i;

   *pre = 'd';
   for (i=1; i < nargs; i++)
   {
      if (args[i][0] != '-') PrintUsage(&args[0][0]);
      switch(args[i][1])
      {
      case 'p':
/*
 *       -p as the final argument has no value; args[nargs] is a null
 *       pointer and must not be dereferenced
 */
         if (i+1 >= nargs) PrintUsage(&args[0][0]);
/*
 *       tolower requires a value representable as unsigned char
 */
         ctmp = (char) tolower((unsigned char) args[++i][0]);
         if (ctmp == 's' || ctmp == 'd' || ctmp == 'c' || ctmp == 'z')
            *pre = ctmp;
         else PrintUsage(&args[0][0]);
         break;
      default:
         fprintf(stderr, "Unknown flag : %s\n", args[i]);
         PrintUsage(&args[0][0]);
         break;
      }
   }
}

/*
 * Program entry point: reads the precision prefix from the command line
 * (GetFlags) and runs the search/installation for it (GoToTown).
 * The return type is spelled out because implicit int is not valid C99.
 */
int main(int nargs, char **args)
{
   char pre;

   GetFlags(nargs, args, &pre);
   GoToTown(pre);
   return(0);
}
