/*********************************************************************/
/*        LINEAR TREE for Supervised Learning                        */
/*        Versao 1.0 (10/12/1997)                                    */
/*        Developed by: Joao Gama                                    */
/*                LIACC - Uni.do Porto                               */
/*                jgama@ncc.up.pt                                    */
/*-------------------------------------------------------------------*/
/*   FILE:classify.c                                                 */  
/*********************************************************************/
#include <stdio.h>
#include <stdlib.h>
#include "Ci_instances.h"
#include "BuildTree.h"
#include "Combine.h"
#include "tree.h"
#include "discrim.h"
#include "utils.h"
#include "externs.i"

/* Prototypes of the helpers private to this file; all three are defined below. */
static double *classify_example(DomainInfo *domain, Tree *tree, CiExample *exemplo);
static double *minimize_cost(double **cost, double *distrib, int nrcla);
static int dminimize(double *class_freq, int nr_class);
/*************************************/
/*      Classify a Test Set          */
/*************************************/
/*
 * Runs every example of ds through tree and tallies the outcome.
 *
 *   ds       data set to classify
 *   tree     tree used to obtain a per-class score vector for each example
 *   cost     optional cost matrix; when non-NULL the scores are turned into
 *            per-class costs and the class with the lowest cost is chosen,
 *            otherwise dmajority() picks the class
 *   errados  out: number of examples whose chosen class differs from the
 *            observed one (reset to 0 on entry)
 *
 * Returns the matrix obtained from imatrix(), incremented at
 * [observed class][chosen class] for each classified example, or NULL when
 * the data set has no examples.  Examples for which classify_example()
 * yields NULL are reported on stderr and are counted neither in *errados
 * nor in the matrix.
 *
 * NOTE(review): the matrix is never cleared here, so this relies on
 * imatrix() handing back zeroed storage -- confirm.
 */
int **classify(CiDs *ds, Tree *tree, double **cost, int *errados)
{
  int **mconf;
  int nr_classes, predicted;
  long ex_nr, c;
  double *scores;
  CiExample *ex;

  *errados = 0;
  if (!Ci_NrExs(ds))
    return NULL;

  nr_classes = Ci_NrClasses(ds->domain);
  mconf = imatrix(1, nr_classes, 1, nr_classes);

  for (ex_nr = 1; ex_nr <= Ci_NrExs(ds); ex_nr++) {
    ex = Ci_Example(ds, ex_nr);
    scores = classify_example(ds->domain, tree, ex);
    if (scores == NULL) {
      fprintf(stderr, "\nExample %ld Observed %d Not Classified",
              ex_nr, Ci_Classe(ex));
      continue;
    }

    if (cost != NULL) {
      /* From here on scores holds costs, not the class distribution. */
      scores = minimize_cost(cost, scores, nr_classes);
      predicted = dminimize(scores, nr_classes);
    }
    else {
      predicted = dmajority(scores, nr_classes);
    }

    VERBOSE(3) {
      printf("Example %ld Observed %s Classified %s \t[ ",
             ex_nr,
             LblValId(ds->domain, 1+NrAttrs(ds->domain), Ci_Classe(ex)),
             LblValId(ds->domain, 1+NrAttrs(ds->domain), predicted));
      for (c = 1; c <= nr_classes; c++)
        printf("%.3f ", scores[c]);
      printf("]\n");
    }

    if (predicted != Ci_Classe(ex))
      ++(*errados);
    if (mconf)
      ++mconf[Ci_Classe(ex)][predicted];
  }
  return mconf;
}

static double *average_descendents(DomainInfo *domain, Tree *tree, CiExample *exemplo,
                                   int nr_children, double divisor);

/*
 * Obtains the class distribution of one example by descending the tree.
 *
 *   leaf              returns the distribution stored in the node.
 *   split_discrete    follows the descendent selected by the value of the
 *                     split attribute; integer attributes are first mapped
 *                     through IdValLbl() using their decimal text form.
 *   split_continuous,
 *   split_linear      follows descendent 1 when the attribute value is
 *                     <= SPLIT_VALUE, descendent 2 otherwise.
 *
 * When the node carries coefficients the example is projected with
 * project_example() before the split attribute is read.  When the split
 * attribute does not hold a normal value, the distributions of the
 * descendents are averaged instead (see average_descendents()).
 *
 * Whenever no distribution comes back from below, the node's own
 * ACUM_CLASS_DIST is returned.  A node of an unrecognised type yields NULL.
 *
 * The returned vector is indexed 1..Ci_NrClasses(domain).  It is either
 * storage reached through the tree or a static buffer that the next
 * averaging step overwrites, so the caller must not free it and must use it
 * before classifying another example.
 */
static double *classify_example(DomainInfo *domain, Tree *tree, CiExample *exemplo)
{
  int att, val;
  char dummy[128];
  double *d;
  Data *data;

  data = DATA(tree);
  att = SPLIT_ATT(data);

  switch(tree->tipo) {
  case leaf:
    return ACUM_CLASS_DIST(data);

  case split_discrete:
    if (COEFICIENTS(data))
      project_example(domain, exemplo, COEFICIENTS(data), NR_ATT(data), ROWS(data));
    if (NormalVal(exemplo->instance[att])) {
      val = DValAttEx(exemplo->instance, att);
      if (TypeAttr(domain, att) == integer) {
        /* 128 bytes is far more than any "%d" rendering of an int needs. */
        sprintf(dummy, "%d", val);
        val = IdValLbl(domain, att, dummy);
      }
      d = classify_example(domain, DESCENDENT(tree, val), exemplo);
    }
    else {
      d = average_descendents(domain, tree, exemplo,
                              NR_DESCENDENTS(tree), tree->nr_descendents);
    }
    break;

  case split_continuous:
  case split_linear:
    if (COEFICIENTS(data))
      project_example(domain, exemplo, COEFICIENTS(data), NR_ATT(data), ROWS(data));
    if (NormalVal(exemplo->instance[att])) {
      if (CValAttEx(exemplo->instance, att) <= SPLIT_VALUE(data))
        d = classify_example(domain, DESCENDENT(tree, 1), exemplo);
      else
        d = classify_example(domain, DESCENDENT(tree, 2), exemplo);
    }
    else {
      d = average_descendents(domain, tree, exemplo, 2, 2.0);
    }
    break;

  default:
    /* Unknown node type: previously control fell off the end of the
       function and the caller read an indeterminate pointer. */
    return NULL;
  }

  /* Nothing usable from below: fall back on this node's distribution. */
  return d ? d : ACUM_CLASS_DIST(data);
}

/*
 * Sums the distributions returned for descendents 1..nr_children of tree
 * (descendents that return NULL contribute nothing) and divides each class
 * entry by divisor.
 *
 * The sum is built in a buffer private to this call and copied into the
 * shared static result only after every recursive call has finished.  The
 * recursion may itself return that static result; accumulating directly
 * into it, as was done before, let a nested averaging step wipe out or
 * double-count the partial sum of the outer one.
 *
 * Returns the static result vector (indexed 1..Ci_NrClasses(domain),
 * allocated with dvector() on first use and sized for that first domain),
 * or NULL when nr_children < 1 or the temporary buffer cannot be allocated.
 */
static double *average_descendents(DomainInfo *domain, Tree *tree, CiExample *exemplo,
                                   int nr_children, double divisor)
{
  static double *result = NULL;
  int nrcla = Ci_NrClasses(domain);
  double *acc, *d;
  int j, k;

  if (nr_children < 1) return NULL;   /* nothing to average; avoids 0/0 */
  if (!result) result = dvector(1, nrcla);

  /* One spare slot so that acc can be indexed 1..nrcla like the others. */
  acc = malloc(((size_t) nrcla + 1) * sizeof *acc);
  if (!acc) return NULL;
  for (j = 1; j <= nrcla; j++) acc[j] = 0.0;

  for (k = 1; k <= nr_children; k++) {
    d = classify_example(domain, DESCENDENT(tree, k), exemplo);
    if (d != NULL)
      for (j = 1; j <= nrcla; j++) acc[j] += d[j];
  }

  for (j = 1; j <= nrcla; j++) result[j] = acc[j] / divisor;
  free(acc);
  return result;
}

/*
 * Weights a cost matrix by a class distribution.
 *
 * For every col in 1..nrcla computes
 *     out[col] = sum over row in 1..nrcla of distrib[row] * cost[row][col]
 *
 * The result lives in a static vector obtained from dvector(1, nrcla) on the
 * first call and reused (overwritten) by every later call; it is sized for
 * the nrcla of that first call.
 */
static double *minimize_cost(double **cost, double *distrib, int nrcla)
{
  static double *weighted = NULL;
  int row, col;

  if (!weighted)
    weighted = dvector(1, nrcla);

  for (col = 1; col <= nrcla; ++col) {
    weighted[col] = 0.0;
    for (row = 1; row <= nrcla; ++row)
      weighted[col] += distrib[row] * cost[row][col];
  }

  return weighted;
}

/*
 * Returns the 1-based index of the smallest entry among
 * class_freq[1..nr_class].  On ties the lowest index wins, because only a
 * strictly smaller value replaces the current best.  Returns 1 when
 * nr_class < 2 without comparing anything.
 */
static int dminimize(double *class_freq, int nr_class)
{
  int best = 1;
  int idx;

  /* Entry 1 is the initial candidate, so the scan can start at entry 2. */
  for (idx = 2; idx <= nr_class; ++idx) {
    if (class_freq[idx] < class_freq[best])
      best = idx;
  }

  return best;
}
