/*********************************************************************/
/*        LINEAR TREE for Supervised Learning                        */
/*        Versao 1.0 (10/12/1997)                                    */
/*        Developed by: Joao Gama                                    */
/*                LIACC - Uni.do Porto                               */
/*                jgama@ncc.up.pt                                    */
/*-------------------------------------------------------------------*/
/*   FILE:discrim.c     Release 1.0    3/10/97                       */  
/*********************************************************************/
#include <stdio.h>
#include <stdlib.h>
#include <values.h>
#include <math.h>
#include "Ci_instances.h"
#include "discrim.h"
#include "utils.h"
#include "externs.i"

/* Factor applied to the quadratic term of each discriminant intercept.
 * Parenthesised so the sign cannot bind to a neighbouring operator. */
#define HALF (-0.5)
/* Coefficients whose magnitude is below this limit are stored as 0. */
#define TRESH_LIM 1.0e-6
/* NOTE: the function-like macros below evaluate their arguments more than
 * once; only pass side-effect-free expressions. */
#define SQR(a)          ((a)*(a))
/* Magnitude of a with the sign of b. */
#define SIGN(a,b)       ((b) >= 0.0 ? fabs(a) : -fabs(a))
#define MAX(a,b)        ((a) > (b)  ? (a) : (b))
#define ABS(x)          ((x) > 0.0 ? (x) : -1.0 * (x))
#define MIN(a,b)        ((a) > (b)  ? (b) : (a))

/* Relative cut-off for singular values (w < wmax * TRESH is zeroed);
 * changed through svd_treshold(). */
static double TRESH = 1.0e-6;
/* CLASS_USED[c] is the 1-based "virtual" index of class c among the classes
 * kept by the last call to discriminant(), or 0 when the class is dropped. */
static int *CLASS_USED = NULL;
/* Number of classes kept by the last call to discriminant(). */
static int NR_CLASS;

/* CONT_ATT[a] is the 1-based position of attribute a among the
 * integer/ordered/continuous attributes; filled once by discriminant(). */
static int *CONT_ATT = NULL;
/* Count of integer/ordered/continuous attributes in the domain. */
static int NR_CONT_ATT = 0;
/* Count of nominal attributes in the domain. */
static int NR_NON_CONT = 0;
/* Column of attribute att in the covariance/coefficient matrices.
 * Attributes beyond NrAttrs(D) are shifted down by the number of nominal
 * attributes; the others are looked up in CONT_ATT.  The argument and the
 * whole expansion are parenthesised so the macro is safe inside larger
 * expressions. */
#define POSITION(D, att) ((att) > NrAttrs(D) ? (att) - NR_NON_CONT : CONT_ATT[(att)])
/************************************************/
/*    Prototypes for local functions            */
/************************************************/
/* Covariance storage and estimation. */
static void free_covar(double ***mcov, int nr_cl, int nr_at);
static void projection(DomainInfo *domain, AttrVal *instance, double *mcoef, int att, int nr_att);
static double ***init_matriz_cov(int nr_cl, int nr_att);
static double ***compute_covars(CiDs *ds, double ***STATIS, double *class_freq, long int low, long int high, int nr_att, int nr_cl);
/* Construction of the discriminant coefficients. */
static double **coeficientes(DomainInfo *domain, double ***STATIS, double **MINV, double *class_freq, int nratt, int nrcla);
static double **matmul(DomainInfo *domain, double ***STATIS, double **MINV, int nratt, int nrcla);
static double *inner(DomainInfo *domain, double ***STATIS, double **beta, int nratt, int nrcla, double *class_freq);
static void corrige_beta(double **beta, int nratt, int nrcla);

/* Singular value decomposition helpers (svdcmp has external linkage). */
static double **svd_inv(double **a, int m, int n);
static void svbksb(double **u, double *w, double **v, int m, int n, double *b, double *x);
int svdcmp(double **a, int m, int n, double *w, double **v);
static double pythag(double a, double b);
/************************************************/
/*           Public functions                   */
/************************************************/
/*
 * Returns the current value of the file-level counter NR_NON_CONT, i.e. the
 * number of nominal attributes counted by discriminant() (0 until then).
 */
int Nr_Att_Non_Cont()
{
  const int non_cont = NR_NON_CONT;

  return non_cont;
}

/*
 * Builds the linear discriminant coefficients for examples low..high of ds.
 *
 *   ds          data set (its domain supplies attribute types and counts)
 *   low, high   inclusive range of example indices used
 *   STATIS      per-attribute / per-class statistics, indexed
 *               [attribute][class][slot]
 *   class_freq  class_freq[c] is the frequency of class c; index 0 is used
 *               as the total by the helpers
 *   nr_att      number of attributes currently in use
 *   nr_cl       number of classes
 *   hidden      out: number of discriminant functions (NR_CLASS - 1), or 0
 *               when the pooled covariance could not be inverted
 *
 * Returns a matrix coef[1..NR_CLASS-1][1..n_cont+1] (intercept in column 1),
 * or NULL when fewer than two classes are kept or the SVD fails.
 *
 * Side effects: (re)computes CLASS_USED / NR_CLASS on every call and fills
 * CONT_ATT, NR_CONT_ATT and NR_NON_CONT on the first call only.
 */
double **discriminant(CiDs *ds, long int low, long int high, double ***STATIS, double *class_freq, int nr_att, int nr_cl, int *hidden)
{
  register int i;
  int n_cont;
  double **coef = NULL, **MINV = NULL, ***MCOV = NULL;

  /* A class is kept only when it has more than KH * nr_att examples. */
  if (!CLASS_USED) CLASS_USED = ivector(1, Ci_NrClasses(ds->domain));
  for (NR_CLASS = 0, i = 1; i <= Ci_NrClasses(ds->domain); i++)
    CLASS_USED[i] = (class_freq[i] > KH * nr_att) ? ++NR_CLASS : 0;

  if (!CONT_ATT) {
    CONT_ATT = ivector(1, NrAttrs(ds->domain));
    for (i = 1; i <= NrAttrs(ds->domain); i++) {
      switch (CiTypeAttr(ds->domain, i)) {
      case integer:
      case ordered:
      case continuous:
        CONT_ATT[i] = ++NR_CONT_ATT;
        break;
      case nominal:
        ++NR_NON_CONT;
        /* POSITION() > 0 is the test used to skip nominal attributes, so
         * the entry must be an explicit 0 rather than whatever ivector()
         * left in the slot. */
        CONT_ATT[i] = 0;
        break;
      default:
        CONT_ATT[i] = 0;
        break;
      }
    }
  }

  n_cont = nr_att - NR_NON_CONT;
  *hidden = NR_CLASS - 1;
  if (NR_CLASS > 1) {
    MCOV = compute_covars(ds, STATIS, class_freq, low, high, nr_att, nr_cl);
    /* MCOV[0] is the pooled matrix; svd_inv() overwrites it in place. */
    if ((MINV = svd_inv(MCOV[0], n_cont, n_cont)) != NULL) {
      coef = coeficientes(ds->domain, STATIS, MINV, class_freq, n_cont, nr_cl);
      free_dmatrix(MINV, 1, n_cont, 1, n_cont);
    }
    else
      *hidden = 0;
    free_covar(MCOV, NR_CLASS, n_cont);
  }
  return coef;
}

/*
 * Projects every example with index Low..High (inclusive) of ds through the
 * given coefficient matrix by calling project_example() on each of them.
 * The number of discriminant functions is taken from the file-level
 * NR_CLASS (NR_CLASS - 1), i.e. from the last call to discriminant().
 */
void apply_discriminant(CiDs *ds, double **coeficientes, long int Low, long int High, int nr_att)
{
  const int n_funcs = NR_CLASS - 1;
  long int idx = Low;

  while (idx <= High) {
    project_example(ds->domain, Ci_Example(ds, idx), coeficientes, nr_att, n_funcs);
    idx++;
  }
}

/*
 * Appends `hidden` new attribute values to one example.
 *
 *   domain        domain description (attribute types)
 *   exemplo       example to extend; its instance is rebuilt to hold
 *                 1 + hidden + nr_att slots
 *   coeficientes  coeficientes[i] is the coefficient row of function i
 *   nr_att        number of attributes already present
 *   hidden        number of discriminant functions to evaluate
 *
 * Slot nr_att + i receives the linear score of function i.  When more than
 * one function is evaluated the scores are turned into normalised weights:
 * the largest score is subtracted, the result is clamped at -25, passed
 * through exp() and divided by the sum.
 *
 * Nothing is done when Ci_ReBuildInstance() reports failure.
 */
void project_example(DomainInfo *domain, CiExample *exemplo, double **coeficientes, int nr_att, int hidden)
{
  register int i;
  double probmax, sumprob = 0.0;
  AttrVal *instance;

  if (!Ci_ReBuildInstance(exemplo, 1 + hidden + nr_att))
    return;

  instance = exemplo->instance;
  for (i = 1; i <= hidden; i++)
    projection(domain, instance, coeficientes[i], nr_att + i, nr_att);

  if (hidden <= 1)
    return;

  /* Seed the maximum with the first score.  Seeding it with MINFLOAT (the
   * smallest POSITIVE float) never finds the maximum when every score is
   * negative, so scores below -25 would all be clamped to the same value. */
  probmax = CValAttEx(instance, 1 + nr_att);
  for (i = 2; i <= hidden; i++)
    if (probmax < CValAttEx(instance, i + nr_att))
      probmax = CValAttEx(instance, i + nr_att);

  for (i = 1; i <= hidden; i++) {
    CValAttEx(instance, i + nr_att) -= probmax;
    /* Clamp so that exp() stays well away from underflow. */
    if (CValAttEx(instance, i + nr_att) < -25.0)
      CValAttEx(instance, i + nr_att) = -25.0;
    CValAttEx(instance, i + nr_att) = exp(CValAttEx(instance, i + nr_att));

    sumprob += CValAttEx(instance, i + nr_att);
  }
  for (i = 1; i <= hidden; i++)
    CValAttEx(instance, i + nr_att) /= sumprob;
}
/************************************************/
/*           Private functions                  */
/************************************************/
/*
 * Evaluates one linear function on an instance and stores the result in
 * slot `att` of that instance.
 *
 *   mcoef[1]     intercept
 *   mcoef[k+1]   coefficient of the k-th integer/ordered/continuous
 *                attribute among attributes 1..nr_att
 *
 * Nominal attributes have no coefficient.  A missing (non-normal) value
 * contributes nothing to the sum, but its coefficient slot is still
 * consumed: advancing the index only for present values would apply every
 * later coefficient to the wrong attribute.
 */
static void projection(DomainInfo *domain, AttrVal *instance, double *mcoef, int att, int nr_att)
{
  register int j;
  int k = 2;   /* index into mcoef of the next numeric attribute */

  CValAttEx(instance, att) = mcoef[1];
  for (j = 1; j <= nr_att; j++) {
    switch (CiTypeAttr(domain, j)) {
    case integer:
    case ordered:
      if (NormalVal(instance[j]))
        CValAttEx(instance, att) += (DValAttEx(instance, j) * mcoef[k]);
      k++;
      break;
    case continuous:
      if (NormalVal(instance[j]))
        CValAttEx(instance, att) += (CValAttEx(instance, j) * mcoef[k]);
      k++;
      break;
    default:
      /* nominal attributes are not part of the linear combination */
      break;
    }
  }
}
/****************************************/
/*   Covariance functions               */
/****************************************/
/*
 * Allocates nr_cl + 1 square matrices (index 0..nr_cl), each indexed
 * [1..nr_att][1..nr_att], and sets every element to 0.
 *
 * The matrices are later filled by accumulation (+=), so they are zeroed
 * here explicitly instead of relying on what dmatrix() leaves in the
 * freshly allocated storage.
 *
 * Terminates the program with a message on stderr when the array of matrix
 * pointers cannot be allocated.  Release the result with free_covar().
 */
static double ***init_matriz_cov(int nr_cl, int nr_att)
{
  register int i, r, c;
  double ***tz = malloc((nr_cl + 1) * sizeof *tz);

  if (tz == NULL) {
    fprintf(stderr, "E)init_matriz_cov: out of memory\n");
    exit(EXIT_FAILURE);
  }
  for (i = 0; i <= nr_cl; i++) {
    tz[i] = dmatrix(1, nr_att, 1, nr_att);
    for (r = 1; r <= nr_att; r++)
      for (c = 1; c <= nr_att; c++)
        tz[i][r][c] = 0.0;
  }
  return tz;
}

/*
 * Releases the nr_cl + 1 matrices (index 0..nr_cl, each indexed
 * [1..nr_att][1..nr_att]) held in mcov, then the pointer array itself.
 */
static void free_covar(double ***mcov, int nr_cl, int nr_att)
{
  int slot = 0;

  while (slot <= nr_cl) {
    free_dmatrix(mcov[slot], 1, nr_att, 1, nr_att);
    ++slot;
  }

  free(mcov);
}

static double ***compute_covars\
(CiDs *ds, double ***STATIS, double *class_freq, long int low, long int high, int nr_att, int nr_cl)
{
  register int att, atti, pos, pos1, classe, cl;
  register long int i;
  double x = 0.0, y = 0.0, mediax = 0.0, mediay = 0.0, weight, ***MCOV;
  AttrVal *instance;

  MCOV = init_matriz_cov(NR_CLASS, nr_att - NR_NON_CONT);

  for(i = low; i <= high; i++) {
    instance = Ci_AttVal(ds, i);
    classe = Ci_Classe(Ci_Example(ds, i));
    weight = Ci_Weight(Ci_Example(ds, i));
    if ((cl = CLASS_USED[classe]) > 0) {
      for(att = 1; att <= nr_att; att++) {
	if ((pos = POSITION(ds->domain, att)) > 0) {
	  if (!NormalVal(instance[att])) {
	    if (CiTypeAttr(ds->domain, att) == continuous) {
	      x =  STATIS[att][classe][1];
	      mediax = STATIS[att][0][1];
	    }
	    else {
	      x = STATIS[att][classe][1+NValsAttr(ds->domain, att)];
	      mediax = STATIS[att][0][1+NValsAttr(ds->domain, att)];
	    }
	  }
	  else {
	    switch (CiTypeAttr(ds->domain, att)) {
	    case continuous:
	      x = CValAttEx(instance, att);
	      mediax = STATIS[att][classe][1];
	      break;
	    case integer:
	    case ordered:
	      x = DValAttEx(instance, att);
	      mediax = STATIS[att][classe][1+NValsAttr(ds->domain, att)];
	      break;
	    }
	  }
	  
	  for(atti = att; atti <= nr_att; atti++) {
	    if ((pos1 = POSITION(ds->domain, atti)) > 0) {
	      if (!NormalVal(instance[atti])) {
		if (CiTypeAttr(ds->domain, atti) == continuous) {
		  y =  STATIS[atti][classe][1];
		  mediay = STATIS[atti][0][1];
		}
		else {
		  y = STATIS[atti][classe][1+NValsAttr(ds->domain, atti)];
		  mediay = STATIS[atti][0][1+NValsAttr(ds->domain, atti)];
		}
	      }
	      else {
		switch (CiTypeAttr(ds->domain, atti)) {
		case continuous:
		  y = CValAttEx(instance, atti);
		  mediay = STATIS[atti][classe][0];
		  break;
		case integer:
		case ordered:
		  y = DValAttEx(instance, atti);
		  mediay = STATIS[atti][classe][1+NValsAttr(ds->domain, atti)];
		  break;
		}	
	      }
	      MCOV[cl][pos][pos1] += (x - mediax) * (y - mediay) * weight;
	    }
	  }
	}
      }
    }
  }

  for(i = 1; i <= Ci_NrClasses(ds->domain); i++) 
    if ((cl = CLASS_USED[i]) > 0) {
      for(att = 1; att <= nr_att - NR_NON_CONT; att++) 
	for(atti = att; atti <= nr_att- NR_NON_CONT; atti++) {
	  MCOV[cl][att][atti] /= (class_freq[i] -1.0);
	  MCOV[cl][atti][att] = MCOV[cl][att][atti];
	  MCOV[0][att][atti] += (MCOV[cl][att][atti] * (class_freq[i] -1.0));
	}
    }
  for(att = 1; att <= nr_att - NR_NON_CONT; att++) 
    for(atti = att; atti <= nr_att- NR_NON_CONT; atti++) {
      MCOV[0][att][atti] /= (class_freq[0] - (double) NR_CLASS);
      MCOV[0][atti][att] = MCOV[0][att][atti];
    }
  return MCOV;
}

/*
 * Assembles the coefficient matrix of the NR_CLASS - 1 discriminant
 * functions.
 *
 * beta (from matmul) and alfa (from inner) are computed for every kept
 * class, then corrige_beta() makes the beta rows relative to the last kept
 * class.  Row i of the result holds alfa[i] in column 1 and beta[i][j] in
 * column j + 1; entries with magnitude below TRESH_LIM are stored as 0.
 *
 * Returns coef[1..NR_CLASS-1][1..nratt+1]; ownership passes to the caller.
 */
static double **coeficientes(DomainInfo *domain, double ***STATIS, double **MINV, double *class_freq, int nratt, int nrcla)
{
  double **coef = dmatrix(1, NR_CLASS - 1, 1, nratt+1);
  double **beta = matmul(domain, STATIS, MINV, nratt, nrcla);
  double *alfa = inner(domain, STATIS, beta, nratt, nrcla, class_freq);
  int func, col;

  corrige_beta(beta, nratt, nrcla);

  for (func = 1; func < NR_CLASS; func++) {
    double *row = coef[func];

    row[1] = alfa[func];
    for (col = 1; col <= nratt; col++) {
      if (ABS(beta[func][col]) < TRESH_LIM)
        row[col+1] = 0.0;
      else
        row[col+1] = beta[func][col];
    }
  }

  free_dvector(alfa, 1, NR_CLASS);
  free_dmatrix(beta, 1, NR_CLASS, 1, nratt);
  return coef;
}

/*
 * Multiplies the inverse pooled covariance by the vector of class means.
 *
 *   MINV    inverse matrix, indexed by attribute POSITION() [1..nratt]
 *   STATIS  statistics indexed by ATTRIBUTE number [att][class][slot]
 *   nratt   number of columns (attributes with POSITION() > 0)
 *   nrcla   number of classes
 *
 * Returns beta[1..NR_CLASS][1..nratt], one row per kept (virtual) class;
 * rows are addressed through CLASS_USED[].  Ownership passes to the caller.
 *
 * MINV is addressed by column position while STATIS and the attribute type
 * are addressed by attribute number, so the inner loop walks the attributes
 * and maps each one to its column with POSITION(), the same convention used
 * when the covariance matrix is built.  Using the column number directly as
 * an attribute number reads the wrong statistics as soon as the domain
 * contains a nominal attribute.  With no nominal attributes both numberings
 * coincide.
 */
static double **matmul(DomainInfo *domain, double ***STATIS, double **MINV, int nratt, int nrcla)
{
  register int att, att1, pos1, cl, virtual_class;
  const int nr_total = nratt + NR_NON_CONT;   /* attributes incl. nominal */
  double temp, mean, **beta;

  beta = dmatrix(1, NR_CLASS, 1, nratt);

  for (att = 1; att <= nratt; att++) {
    for (cl = 1; cl <= nrcla; cl++) {
      virtual_class = CLASS_USED[cl];
      if (!virtual_class)
        continue;

      temp = 0.0;
      for (att1 = 1; att1 <= nr_total; att1++) {
        if ((pos1 = POSITION(domain, att1)) <= 0)
          continue;   /* nominal attribute: no column */
        if (CiTypeAttr(domain, att1) == continuous)
          mean = STATIS[att1][cl][1];
        else
          mean = STATIS[att1][cl][1+NValsAttr(domain, att1)];
        temp += MINV[att][pos1] * mean;
      }
      beta[virtual_class][att] = temp;
    }
  }
  return beta;
}

/*
 * Computes the intercept of every kept class.
 *
 * For a kept class c with virtual index v:
 *   alfa[v] = HALF * sum_k beta[v][k] * mean_c(k)
 *             + log(class_freq[c] / class_freq[0])
 * and finally alfa[NR_CLASS] is subtracted from every entry, so the last
 * kept class ends up with intercept 0.
 *
 * Returns alfa[1..NR_CLASS]; ownership passes to the caller.
 *
 * Two defects of the earlier version are avoided here:
 *  - beta is addressed by column position but STATIS by attribute number;
 *    the attributes are walked and mapped with POSITION() so the right
 *    class mean is paired with each coefficient when nominal attributes
 *    exist (both numberings coincide when there are none);
 *  - the final subtraction runs over 1..NR_CLASS, the real length of alfa,
 *    not over 1..nrcla, which wrote past the end of the vector whenever
 *    some class was dropped.
 */
static double *inner(DomainInfo *domain, double ***STATIS, double **beta, int nratt, int nrcla, double *class_freq)
{
  register int att, pos, cl, virtual_class;
  const int nr_total = nratt + NR_NON_CONT;   /* attributes incl. nominal */
  double temp, reference, *alfa;

  alfa = dvector(1, NR_CLASS);

  for (cl = 1; cl <= nrcla; cl++) {
    virtual_class = CLASS_USED[cl];
    if (!virtual_class)
      continue;

    temp = 0.0;
    for (att = 1; att <= nr_total; att++) {
      if ((pos = POSITION(domain, att)) <= 0)
        continue;   /* nominal attribute: no coefficient */
      if (CiTypeAttr(domain, att) == continuous)
        temp += beta[virtual_class][pos] * STATIS[att][cl][1];
      else
        temp += beta[virtual_class][pos] * STATIS[att][cl][1+NValsAttr(domain, att)];
    }
    alfa[virtual_class] = HALF * temp + log((double) class_freq[cl] / class_freq[0]);
  }

  /* Express every intercept relative to the last kept class. */
  reference = alfa[NR_CLASS];
  for (cl = 1; cl <= NR_CLASS; cl++)
    alfa[cl] -= reference;
  return alfa;
}

/*
 * Subtracts row NR_CLASS of beta from the row of every kept class, in
 * increasing order of the original class index (so row NR_CLASS itself is
 * zeroed when its class is reached).  beta is modified in place; columns
 * 1..nratt are processed.
 */
static void corrige_beta(double **beta, int nratt, int nrcla)
{
  int klass, col;

  for (klass = 1; klass <= nrcla; klass++) {
    const int row = CLASS_USED[klass];

    if (row == 0)
      continue;
    for (col = 1; col <= nratt; col++)
      beta[row][col] -= beta[NR_CLASS][col];
  }
}

/********************************************************/
/*	Singular Value Decomposition	A = U.W.Vt	*/
/*	Input 	a[1..m][1..n]				*/
/*	Output	a[1..m][1..n]	as U			*/
/*		w[1..n]					*/
/*		v[1..n][1..n]				*/
/********************************************************/
/*
 * Sets the relative cut-off TRESH: singular values smaller than
 * (largest singular value) * TRESH are replaced by 0 in svbksb()/svd_inv().
 */
void svd_treshold(double tresh)
{
  const double cutoff = tresh;

  TRESH = cutoff;
}

/*
 * Singular value decomposition a = U * diag(w) * V^T (layout follows the
 * Numerical Recipes routine of the same name; all arrays are 1-based).
 *
 *   a  in:  matrix a[1..m][1..n]
 *      out: U, stored in place
 *   w  out: singular values w[1..n]
 *   v  out: V (not transposed), v[1..n][1..n]
 *
 * Returns 1 on success.  When an singular value has not converged after 30
 * iterations, prints "E)SVD: No Convergence" on stderr and returns 0; the
 * contents of a, w and v are then unspecified.  The work vector is released
 * on both paths.
 */
int svdcmp(double **a, int m, int n, double *w, double **v)
{
  int flag, i, its, j, jj, k, l, nm;
  double anorm, c, f, g, h, s, scale, x, y, z, *rv1;

  rv1 = dvector(1, n);
  g = scale = anorm = 0.0;
  /* Phase 1: reduce a to bidiagonal form with Householder reflections. */
  for(i=1;i<=n;i++) {
    l = i + 1;
    rv1[i]=scale * g;
    g = s = scale = 0.0;
    if (i <= m) {
      for(k=i;k<=m;k++)	scale += fabs(a[k][i]);
      if (scale) {
	for(k=i;k<=m;k++){
	  a[k][i] /= scale;
	  s += SQR(a[k][i]);
	}
	f = a[i][i];
	g = -SIGN(sqrt(s), f);
	h = f * g - s;
	a[i][i] = f- g;
	for(j=l;j<=n;j++) {
	  for(s=0.0, k= i;k<=m;k++)	s += a[k][i]*a[k][j];
	  f = s/h;
	  for(k=i;k<=m;k++)	a[k][j] += f*a[k][i];
	}
	for(k=i;k<=m;k++)	a[k][i] *= scale;
      }
    }
    w[i] = scale * g;
    g = s = scale = 0.0;
    if (i <= m && i != n) {
      for(k=l; k <=n ; k++)	scale += fabs(a[i][k]);
      if (scale) {
	for(k=l; k<=n; k++) {
	  a[i][k] /= scale;
	  s += SQR(a[i][k]);
	}
	f = a[i][l];
	g = -SIGN(sqrt(s), f);
	h = f * g - s;
	a[i][l] = f - g;
	for(k=l;k<=n;k++)	rv1[k]=a[i][k]/h;
	for(j=l;j<=m;j++) {
	  for(s=0.0, k=l; k<=n; k++)	s+=a[j][k]*a[i][k];
	  for(k=l;k<=n;k++)		a[j][k] += s*rv1[k];
	}
	for(k=l; k<=n; k++)	a[i][k] *= scale;
      }
    }
    anorm = MAX(anorm,(fabs(w[i])+fabs(rv1[i])));
  }
  /* Phase 2: accumulate the right-hand transformations into v. */
  for(i=n;i>=1;i--) {
    if(i < n) {
      if (g) {
	/* double division avoids a possible underflow */
	for(j=l; j<=n; j++)	v[j][i]=(a[i][j]/a[i][l])/g;
	for(j=l; j<=n; j++) {
	  for(s=0.0,k=l;k<=n;k++)	s += a[i][k]*v[k][j];
	  for(k=l;k<=n;k++)	v[k][j] += s*v[k][i];
	}
      }
      for(j=l;j<=n;j++)	v[i][j] = v[j][i] = 0.0;
    }
    v[i][i] = 1.0;
    g = rv1[i];
    l = i;
  }
  /* Phase 3: accumulate the left-hand transformations into a. */
  for(i=MIN(m,n);i>=1;i--) {
    l = i+1;
    g=w[i];
    for(j=l; j<=n; j++)	a[i][j] = 0.0;
    if(g) {
      g = 1.0 / g;
      for(j=l; j<=n; j++) {
	for(s=0.0,k=l; k <= m; k++)	s += a[k][i]*a[k][j];
	f = (s/a[i][i])*g;
	for(k=i; k<=m; k++)	a[k][j] += f*a[k][i];
      }
      for(j=i;j<=m;j++)	a[j][i] *= g;
    }
    else	for(j=i;j<=m;j++)	a[j][i] = 0.0;
    ++a[i][i];
  }
  /* Phase 4: diagonalise the bidiagonal form, one singular value at a
   * time, with at most 30 iterations each. */
  for(k=n;k>=1;k--) {
    for(its=1; its<=30; its++) {
      flag = 1;
      /* Look for a negligible off-diagonal element (test for splitting);
       * rv1[1] is always zero, so the loop terminates before nm is used
       * with l == 1. */
      for(l=k;l>=1;l--) {
	nm = l-1;
	if((double)(fabs(rv1[l])+anorm) == anorm) {
	  flag = 0;
	  break;
	}
	if ((double)(fabs(w[nm])+anorm) == anorm)	break;
      }
      if (flag) {
	/* cancel rv1[l] when l > 1 */
	c = 0.0;
	s = 1.0;
	for(i=l;i<=k;i++) {
	  f = s * rv1[i];
	  rv1[i] = c * rv1[i];
	  if ((double)(fabs(f)+anorm) == anorm) break;
	  g = w[i];
	  h = pythag(f,g);
	  w[i] = h;
	  h = 1.0 / h;
	  c = g*h;
	  s = -f*h;
	  for(j=1; j<=m; j++) {
	    y=a[j][nm];
	    z=a[j][i];
	    a[j][nm] = y*c+z*s;
	    a[j][i] = z*c-y*s;
	  }
	}
      }
      z=w[k];
      if (l == k) {
	/* converged: make the singular value non-negative */
	if (z < 0.0) {
	  w[k] = -z;
	  for(j=1;j<=n;j++)	v[j][k] = -v[j][k];
	}
	break;
      }
      if(its == 30) {
	fprintf(stderr,"E)SVD: No Convergence\n");
	/* release the work vector on the failure path as well */
	free_dvector(rv1, 1, n);
	return 0;
      }
      /* shift from the bottom 2-by-2 minor */
      x = w[l];
      nm = k-1;
      y = w[nm];
      g = rv1[nm];
      h = rv1[k];
      f = ((y-z)*(y+z)+(g-h)*(g+h))/(2.0*h*y);
      g = pythag(f, 1.0);	
      f = ((x-z)*(x+z)+h*((y/(f+SIGN(g,f)))-h))/x;
      /* next QR transformation */
      c = s = 1.0;
      for(j=l; j<=nm; j++) {
	i = j+1;
	g = rv1[i];
	y = w[i];
	h = s*g;
	g = c*g;
	z = pythag(f,h);
	rv1[j] = z;
	c = f/z;
	s = h/z;
	f = x*c+g*s;
	g = g*c-x*s;
	h = y*s;
	y *= c;
	for(jj=1; jj<=n; jj++) {
	  x = v[jj][j];
	  z = v[jj][i];
	  v[jj][j] = x*c+z*s;
	  v[jj][i] = z*c-x*s;
	}
	z = pythag(f, h);
	w[j] = z;
	/* the rotation can be arbitrary when z is zero */
	if (z) {
	  z = 1.0/z;
	  c = f*z;
	  s = h*z;
	}
	f = c*g+s*y;
	x = c*y-s*g;
	for(jj = 1; jj<=m; jj++) {
	  y = a[jj][j];
	  z = a[jj][i];
	  a[jj][j] = y*c+z*s;
	  a[jj][i] = z*c-y*s;
	}
      }
      rv1[l] = 0.0;
      rv1[k] = f;
      w[k] = x;
    }
  }
  free_dvector(rv1, 1, n);
  return 1;
}

/*
 * Solves A x = b given the decomposition A = u * diag(w) * v^T produced by
 * svdcmp (1-based arrays: u[1..m][1..n], w[1..n], v[1..n][1..n],
 * b[1..m], x[1..n]).
 *
 * Singular values smaller than (largest value) * TRESH are first set to 0
 * IN w ITSELF, and the corresponding directions are dropped from the
 * solution.  b is read completely before x is written, so the two may be
 * the same vector.
 */
static void svbksb(double **u, double *w, double **v, int m, int n, double *b, double *x)
{
  double *scaled;
  double largest = 0.0, cutoff, acc;
  int col, row, t;

  /* Zero the singular values that are negligible relative to the largest. */
  col = 1;
  while (col <= n) {
    if (w[col] > largest)
      largest = w[col];
    col++;
  }
  cutoff = largest * TRESH;
  for (col = 1; col <= n; col++) {
    if (w[col] < cutoff)
      w[col] = 0.0;
  }

  /* scaled = diag(1/w) * u^T * b, skipping the zeroed singular values. */
  scaled = dvector(1, n);
  for (col = 1; col <= n; col++) {
    acc = 0.0;
    if (w[col] != 0.0) {
      for (row = 1; row <= m; row++)
        acc += u[row][col] * b[row];
      acc /= w[col];
    }
    scaled[col] = acc;
  }

  /* x = v * scaled */
  for (row = 1; row <= n; row++) {
    acc = 0.0;
    for (t = 1; t <= n; t++)
      acc += v[row][t] * scaled[t];
    x[row] = acc;
  }

  free_dvector(scaled, 1, n);
}
/********************************************************/
/*	Matrix inversion via SVD			*/
/*	Input:	a[1..m][1..n]				*/
/*	Output	y[1..m][1..n]				*/
/********************************************************/
/*
 * Inverts (pseudo-inverts) a matrix through its singular value
 * decomposition.
 *
 *   a     in: matrix a[1..m][1..n]; overwritten by svdcmp()
 *   m, n  dimensions; the only caller in this file passes m == n, and the
 *         shared work vector `col` (length n) is indexed up to m, so the
 *         routine should be treated as square-only
 *
 * Singular values below (largest value) * TRESH are zeroed before the
 * back-substitution.  Column j of the result is the solution of
 * a * x = e_j.
 *
 * Returns a newly allocated matrix y[1..m][1..n] owned by the caller, or
 * NULL when svdcmp() does not converge.  All work storage is released on
 * both paths.
 */
static double **svd_inv(double **a, int m, int n)
{
  double **y = NULL, **v, *w;
  double *col = dvector(1, n);
  double wmax, wmin;
  int i, j;

  v = dmatrix(1, n, 1, n);
  w = dvector(1, n);

  if (svdcmp(a, m, n, w, v)) {
    wmax = w[1];
    for (j = 1; j <= n; j++) if (w[j] > wmax) wmax = w[j];
    wmin = wmax * TRESH;
    for (j = 1; j <= n; j++) if (w[j] < wmin) w[j] = 0.0;

    y = dmatrix(1, m, 1, n);
    for (j = 1; j <= m; j++) {
      /* col is the j-th unit vector on input and the solution on output */
      for (i = 1; i <= n; i++)	col[i] = 0.0;
      col[j] = 1.0;
      svbksb(a, w, v, m, n, col, col);
      for (i = 1; i <= n; i++)	y[i][j] = col[i];
    }
  }

  /* Released unconditionally so a failed decomposition does not leak. */
  free_dvector(w, 1, n);
  free_dmatrix(v, 1, n, 1, n);
  free_dvector(col, 1, n);
  return(y);
}

static double pythag(double a, double b)
{
  double absa, absb;
  absa = fabs(a);
  absb = fabs(b);
  if (absa > absb)      return(absa*sqrt(1.0+SQR(absb/absa)));
  return(absb == 0.0 ? 0.0 : absb*sqrt(1.0+SQR(absa/absb)));
}
