/*
 * Defined elsewhere (extern declaration).
 * NOTE(review): not referenced by the code visible here; presumably used by
 * other parts of the program -- confirm.
 * NOTE(review): identifiers starting with an underscore followed by an
 * uppercase letter are reserved for the C implementation.
 */
extern int _MSG_;		/* debug messages */

#include <stdio.h>
#include <assert.h>
#include <math.h>
#include "pblas0.h"
#include "import.h"
#include "decide.h"

/*
 * Cube of x.  The argument is evaluated three times, so never pass an
 * expression with side effects.
 * NOTE(review): not used by the functions visible in this file chunk.
 */
#define CUBE(x) ((x)*(x)*(x))

/*
 * DecideSwitch assumes flist is already sorted in decreasing order of size
 * (the Task key field, which is used as the problem size N).
 */

/* Cost models and the packing helper used by DecideSwitch; defined below. */
static double TaskPack ( int P, int len, double * time );
static double TaskTime ( int N );
static double DataParallelTime ( int N, int P );

/*
 * When Add_ is defined, refer to the Fortran routine by its
 * trailing-underscore symbol name.
 */
#ifdef Add_
#define dlamch dlamch_
#endif

/*
 * LAPACK DLAMCH (double-precision machine parameters), declared without a
 * prototype (K&R style), so arguments are not type-checked.
 * NOTE(review): a Fortran routine taking a CHARACTER argument normally
 * expects a hidden string-length argument when called from C; confirm this
 * call convention is safe with the Fortran compiler in use.
 */
extern double dlamch();

int DecideSwitch ( int P, int fsize, Task * flist )
/*
 * Decide whether the remaining tasks should now be processed in task
 * parallel mode.
 *
 * For every split point k (0 <= k <= fsize) the model estimates the cost of
 * running the first k tasks of flist one after another in data parallel
 * mode (DataParallelTime on P processors, summed), followed by packing the
 * remaining fsize-k tasks onto the P processors as serial solves (TaskTime,
 * packed greedily by TaskPack).  The split with the smallest estimated
 * total wins; ties go to the smallest k.
 *
 * P     - number of processors (TaskPack needs P >= 1).
 * fsize - number of entries in flist.
 * flist - tasks, assumed sorted in decreasing order of size; flist[k].key
 *         is used as the problem size N for the cost models.
 *
 * Returns nonzero when the best split is k == 0, i.e. switching to task
 * parallel mode right away is predicted to be fastest (this includes an
 * empty list).  Returns 0 when it is predicted better to process at least
 * the first task in data parallel mode.
 */
{
  int i, bestSplit;
  double bestTotal;
  double * prefix;
  double * suffix;

  /* Nothing left to schedule: the model picks k == 0 in that case.  This
   * also avoids sizing the buffers from a negative count. */
  if ( fsize <= 0 )
    return 1;

  /* prefix[k]: summed data parallel time of tasks 0..k-1.
   * suffix[k]: first the serial time of task k, later (see below) the packed
   * task parallel makespan of tasks k..fsize-1. */
  prefix = MALLOC((1+fsize) * sizeof(double));
  suffix = MALLOC((1+fsize) * sizeof(double));

  prefix[0] = 0.0;
  for ( i = 1; i <= fsize; i++ ) {
    const int N = flist[i-1].key;
    prefix[i] = prefix[i-1] + DataParallelTime(N,P);
    suffix[i-1] = TaskTime(N);
  }
  suffix[fsize] = 0.0;   /* zero-time sentinel: adds nothing to any load */

  /* Overwrite suffix[i] in place with the makespan of packing
   * suffix[i..fsize].  This is safe because TaskPack only reads indices >= i,
   * which still hold raw serial times at that point. */
  for ( i = 0; i <= fsize; i++ )
    suffix[i] = TaskPack(P, fsize+1-i, suffix+i);

  /* Pick the cheapest split.  Seeding with k == 0 removes the need for an
   * "infinite" start value from the Fortran DLAMCH routine, and for a
   * separate total[] buffer. */
  bestSplit = 0;
  bestTotal = prefix[0] + suffix[0];
  for ( i = 1; i <= fsize; i++ ) {
    const double candidate = prefix[i] + suffix[i];
    if ( candidate < bestTotal ) {
      bestTotal = candidate;
      bestSplit = i;
    }
  }

  FREE(suffix);
  FREE(prefix);
  return bestSplit == 0;
}


static double TaskTime(int N)
/*
 * Cost model for solving one task serially with LAPACK (DGEEV/DGEES):
 * a term cubic in N plus a fixed per-task overhead.
 */
{
  const double cubicCoeff = 7.359e-7;
  const double overhead   = 6.23e-3;
  const double order = N;
  double cost = cubicCoeff * order;

  cost *= order;
  cost *= order;
  return cost + overhead;
}


static double DataParallelTime(int N, int P)
/*
 * Models the time to split the matrix and solve each piece (half), over a
 * fixed number of iterations.  The per-iteration cost (cubic in N plus a
 * constant) is scaled by 1/P plus a term sigma/N^2 that does not shrink
 * as P grows.
 */
{
  const double nIter = 15.0;
  const double sigma = 6e4;
  const double order = N;
  const double procs = P;
  double work, scale;

  work  = 8e-8 * order * order * order + 2.45e-1;
  work  = nIter * work;
  scale = 1.0 / procs + sigma / order / order;
  return work * scale;
}


static double TaskPack(int P, int len, double *time)
/*
 * Greedy list scheduling.  Take time[0..len-1] in order and assign each
 * one to the currently least-loaded of P processors (ties go to the lowest
 * processor index).  Return the largest resulting processor load, which is
 * the makespan.  When time[] is sorted in decreasing order, this is the
 * LPT (longest processing time first) heuristic.
 *
 * P    - number of processors; must be >= 1.
 * len  - number of entries of time[] to place; len <= 0 yields 0.0.
 * time - per-task costs; only read.
 */
{
  int i, proc;
  double maxLoad;
  double * load;

  /* Each placement writes load[minProc] with minProc starting at 0, so
   * P < 1 would write past an empty buffer. */
  assert(P > 0);
  load = MALLOC(P * sizeof(double));

  for ( proc = 0; proc < P; proc++ )
    load[proc] = 0.0;

  for ( i = 0; i < len; i++ ) {
    int minProc = 0;

    /* Seed the search with processor 0 rather than an "infinite" sentinel.
     * Loads are finite sums, so the same processor is chosen, and there is
     * no per-task call into the Fortran DLAMCH routine. */
    for ( proc = 1; proc < P; proc++ )
      if ( load[proc] < load[minProc] )
        minProc = proc;

    load[minProc] += time[i];
  }

  maxLoad = 0.0;
  for ( proc = 0; proc < P; proc++ )
    if ( maxLoad < load[proc] )
      maxLoad = load[proc];

  FREE(load);
  return maxLoad;
}
