/*
  sc_Calculator.c
  simple calculator
  finucane@myri.com (David Finucane)
*/

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "insist.h"

#include "mt_Queue.h"
#include "sc_Calculator.h"
#include "mt_Cloud.h"

/*
  Construct a calculator with default routing options.
  Every member read elsewhere in this file is given a value here, so the object
  is usable even if parseArgs () is never called.
*/
sc_Calculator::sc_Calculator () : mt_Calculator ()
{
  /*route-selection options; parseArgs () resets these
    (NOTE(review): parseArgs () uses margin = 0, not 2 -- confirm which default is intended)*/
  weight = 1;
  prune = 1;
  margin = 2;
  useNumbers = 0;
  saveRoutes = 0;
  disjoin = 0;
  spread = 0;

  /*newNode (), getNumRoutes () and computeManyRoutes () read maxRoutes, but it was
    only ever assigned by parseArgs (); 1 matches parseArgs ()'s default*/
  maxRoutes = 1;

  /*explicit -roots list; the array is freed with delete [] in the destructor*/
  numRoots = 0;
  roots = 0;

  /*precomputed route table, consulted by getRoute () only when useTable is set*/
  routeTable = 0;
  useTable = 0;

  /*last source host / route index computed by getRoute ()*/
  index = -1;
  root = 0;
}

/*
  Release the -roots array allocated by parseArgs ().
  delete [] on a null pointer is a no-op, so no guard is needed.
*/
sc_Calculator::~sc_Calculator ()
{
  delete [] roots;
}

void sc_Calculator::cleanup ()
{
  mt_Calculator::cleanup ();
  if (routeTable)
  {
    delete routeTable;
    routeTable = 0;
  }
}

/*
  Initialize from an in-memory graph.
  Unless -use-numbers was given, nodes are renumbered from the -roots list if
  one was supplied, else from the given root node, else automatically.
  With -save-routes the full route table is then built.
  Returns nonzero on success, 0 on failure.
*/
int sc_Calculator::initialize (mt_Graph*graph, mt_Node*root)
{
  /*the parameter shadows the member: this clears the cached source host*/
  this->root = 0;
  
  if (!mt_Calculator::initialize (graph, root))
    return 0;

  if (!useNumbers)
  {
    int ok;

    if (!numRoots || !roots)
      ok = root ? renumber (1, &root) : autoNumber ();
    else
      ok = renumber (numRoots, roots);

    insist (ok);
  }

  return saveRoutes ? makeRouteTable () : 1;
  exception: return 0;
}

/*
  Build routeTable from this calculator and, if that succeeds, answer later
  getRoute () calls from the table.
  Must not be called while a table already exists.
  Returns the (nonzero on success) result of fromCalculator (), or 0.
*/
int sc_Calculator::makeRouteTable ()
{
  insist (this);
  insist (!routeTable);

  /*keep getRoute () computing routes while the table is filled
    (presumably fromCalculator () queries getRoute () -- confirm)*/
  useTable = 0;

  routeTable = new mt_RouteTable ();
  insist (routeTable);

  useTable = routeTable->fromCalculator (this, this);
  return useTable;

  exception: return 0;
}

/*
  Initialize from a map file.
  Nodes are renumbered from the -roots list when one was given. Otherwise,
  unless -use-numbers was given, they are renumbered from the first host.
  With -save-routes the full route table is then built.
  Returns nonzero on success, 0 on failure.
*/
int sc_Calculator::initialize (char *mapFile)
{
  mt_Node*firstHost;

  /*forget the cached source host from any previous map*/
  root = 0;

  if (!mt_Calculator::initialize (mapFile))
    return 0;

  if (numRoots && roots)
  {
    if (!renumber (numRoots, roots))
      return 0;
  }

  firstHost = getHost (0);
  insist (firstHost);

  /*prevent lack of node numbering*/
  if (!numRoots && !useNumbers)
  {
    if (!renumber (1, &firstHost))
      return 0;
  }

  if (!saveRoutes)
    return 1;
  return makeRouteTable ();

  exception: return 0;
}

/*
  Create the calculator-specific node object for nodeType.
  HOST gives an sc_Host, constructed with maxRoutes (presumably its route
  capacity). SWITCH gives an sc_Switch, and CLOUD gives an mt_Cloud.
  Returns 0 (after an insist failure) for any other node type.
*/
mt_Node*sc_Calculator::newNode (int nodeType, char*name, char*type)
{
  insist (this);

  switch (nodeType)
  {
    case mt_Node::HOST:
      return new sc_Host (name, type, maxRoutes);
    case mt_Node::SWITCH:
      return new sc_Switch (name, type);
    case mt_Node::CLOUD:
      return new mt_Cloud (name, type);
    default:
      insist (0);
  }
  exception: return 0;
}

/*
  Number of routes requested per host pair (-max-routes, default 1 in parseArgs ()).
*/
int sc_Calculator::getMaxRoutes ()
{
  insist (this);
  return maxRoutes;
  exception: return 0;
}

/*
  Number of routes reported between host indices from and to.
  Both indices must be valid host indices; 0 is returned otherwise.
  Every pair is currently reported as having maxRoutes routes. A per-destination
  count obtained through getRoute () was once used here but is disabled.
*/
int sc_Calculator::getNumRoutes (int from, int to)
{
  insist (this);
  insist (from >= 0 && from < getNumHosts ());
  /*the destination index was previously compared against from by mistake*/
  insist (to >= 0 && to < getNumHosts ());

  return maxRoutes;

  exception: return 0;
}

/*
  Per-source route count: nothing calculator-specific, so defer to mt_Calculator.
*/
int sc_Calculator::getNumRoutes (int from)
{
  int count = mt_Calculator::getNumRoutes (from);
  return count;
}

/*
  Remove the load that `to`'s current route (route slot 0) placed on the switch
  ports between from and to.
  `to` is dereferenced here, before the 4-argument overload can validate it,
  so it is checked first.
*/
void sc_Calculator::removeRoute (mt_Node*from, mt_Node*to, int weight)
{
  insist (to);
  removeRoute (from, to, to->getRoute (0), weight);
  exception: return;
}

/*
  Subtract weight from the load of every switch port along route r from `from`
  to `to`; the inverse of addRoute ().
  All three pointers are validated up front, matching addRoute ().
*/
void sc_Calculator::removeRoute (mt_Node*from, mt_Node*to, mt_Route*r, int weight)
{
  insist (this);
  insist (from && to && r);
  followRoute (from, to, r, -weight);
  exception: return;
}

/*
  Add weight to the load of every switch port along route r from `from` to `to`.
*/
void sc_Calculator::addRoute (mt_Node*from, mt_Node*to, mt_Route*r, int weight)
{
  insist (this);
  insist (from && to && r);

  followRoute (from, to, r, weight);

  exception: return;
}

/*
  Decide whether candidate route r should replace `to`'s current route (slot 0).
  r is rejected when it is more than `margin` hops longer than the current
  route. Otherwise it wins when the summed port load along it is strictly
  lower than along the current route.
  Returns 1 if r is better, 0 otherwise (or on failure).
*/
int sc_Calculator::isBetterRoute (mt_Node*from, mt_Node*to, mt_Route*r, int margin)
{
  int current;
  int candidate;

  insist (this);
  insist (from && to && r);

  if (r->getLength () - to->getRoute (0)->getLength() > margin)
    return 0;

  /*followRoute () without d uses the header's default
    (presumably 0, so these calls only sum loads -- confirm)*/
  current = followRoute (from, to, to->getRoute (0));
  candidate = followRoute (from, to, r);

  return candidate < current;
  exception: return 0;
}
 

/*
  Walk route r from host `from`, which must end at host `to`.
  At each switch the outgoing port is the entry port plus the hop's relative
  offset. The load already on that port is added to the returned total, and the
  port's load is then changed by d.
  Once the walk reaches a non-switch node it stops advancing, and the remaining
  hops are ignored.
  Returns the summed pre-update load, or 0 if an insist fails (including the
  walk not ending at `to`).
*/
int sc_Calculator::followRoute (mt_Node*from, mt_Node*to, mt_Route*r, int d)
{
  int hop;
  int numHops;
  char*hops;
  mt_Node*node;
  int inPort;
  int total;

  insist (this);
  insist (from && to && from->isHost () && to->isHost ());
  insist (r);

  numHops = r->getLength ();
  hops = r->getHops ();
  node = from->getNode (0);
  inPort = from->getOpposite (0);
  total = 0;

  for (hop = 0; hop < numHops; hop++)
  {
    insist (node);

    if (!node->isSwitch ())
      continue;

    sc_Switch*sw = (sc_Switch*) node;
    int outPort = inPort + hops[hop];
    int load = sw->getLoad (outPort);

    total += load;
    sw->setLoad (outPort, load + d);

    inPort = sw->getOpposite (outPort);
    node = sw->getNode (outPort);
  }
  insist (node == to);

  return total;
  exception: return 0;
}

void sc_Calculator::clearMarks ()
{
  insist (this);

  int numNodes;
  numNodes = getNumNodes ();
  
  for (int i = 0; i < numNodes; i++)
    (getNode (i))->setMark (mt_Node::_NONE);

  exception: return;
}


void sc_Calculator::clearLoads ()
{
  insist (this);

  int numSwitches;
  numSwitches = getNumSwitches ();
  
  for (int i = 0; i < numSwitches; i++)
    ((sc_Switch*) getSwitch (i))->clearLoads ();
  
  exception: return;
}

void sc_Calculator::reverseNegativeLoads ()
{
  insist (this);

  int numSwitches;
  numSwitches = getNumSwitches ();  
  
  for (int i = 0; i < numSwitches; i++)
    ((sc_Switch*) getSwitch (i))->reverseNegativeLoads ();
  
  exception: return;
}

void sc_Calculator::clearRoutes ()
{
  insist (this);

  int numHosts;
  numHosts = getNumHosts ();
  
  for (int i = 0; i < numHosts; i++)
    ((sc_Host*) getHost (i))->clearRoutes ();
  
  exception: return;
}

/*
  Copy route number routeIndex from host index `from` to host index `to`
  into *route.

  - With a route table in use (useTable), the table answers directly.
  - Otherwise routes from `from` are computed and cached on the hosts. `root`
    remembers the source whose routes are cached; 0 means nothing is cached,
    and port loads are cleared before computing.
  - With -save-routes only one route per host is kept. It is recomputed when
    the source or routeIndex changes, and is then read from slot 0. With
    -disjoin, weight is forced to -1, and negative loads are reversed when the
    index changes.
  - Without -save-routes, all maxRoutes routes are computed once per source.

  On a failed computation the cache is invalidated (root = 0) in BOTH modes.
  Previously the -save-routes path kept root = h and index = routeIndex after a
  failure, so the next identical query skipped recomputation and served
  partially built routes.

  Returns 1 on success, 0 when the route does not exist or on failure.
*/
int sc_Calculator::getRoute (int from, int to, int routeIndex, mt_Route*route)  
{
  insist (this);
  insist (from >= 0 && from < getNumHosts ());
  insist (to >= 0 && to < getNumHosts ());
  insist (routeIndex >= 0 && route);

  if (useTable)
  {
    insist (routeTable);
    return routeTable->getRoute (from, to, routeIndex, route);
  }
  
  /*nothing cached yet: start from a clean load state*/
  if (!root)
    clearLoads ();

  sc_Host*h;
  h = (sc_Host*) getHost (from);
  insist (h);

  if (saveRoutes)
  {
    if (root != h || index != routeIndex)
    {
      clearRoutes ();
  
      if (disjoin)
      {
        weight = -1;
        if (index != routeIndex)
          reverseNegativeLoads ();
      }
      index = routeIndex;
      
      if (!computeRoutes (root = h) || (spread && !spreadRoutes (root)))
      {
        /*don't let the next query reuse a half-built cache*/
        root = 0;
        return 0;
      }
    }
    /*only one route per host is stored in this mode*/
    routeIndex = 0;
  }
  else
  {
    if (root != h && (!computeManyRoutes (root = h) || (spread && !spreadRoutes (root))))
    {
      root = 0;
      return 0;
    }
  }
  
  h = (sc_Host*) getHost (to);
  insist (h);
  
  if (routeIndex >= h->getNumRoutes ())
    return 0;
  
  mt_Route*r;
  r = h->getRoute (routeIndex);
  insist (r);
  
  *route = *r;
  
  return 1;
  exception: return 0;
}

/*
  Discard all stored host routes, then run computeRoutes (from) maxRoutes times.
  Loads are not cleared between passes, so loads charged by earlier passes
  presumably steer later ones -- confirm against sc_Host's route storage.
  Returns 1 if every pass succeeds, 0 at the first failure.
*/
int sc_Calculator::computeManyRoutes (sc_Host*from)
{
  insist (this && from);

  clearRoutes ();

  for (int pass = 0; pass < maxRoutes; pass++)
    if (!computeRoutes (from))
      return 0;

  return 1;
  exception: return 0;
}

/*
  For every destination host i and each stored route j to it, replace that
  route with the one getSpreadRoute () returns for `from`'s type index.
  `from` is dereferenced for its type index, so it is validated the same way
  computeManyRoutes () validates it.
  Returns 1 on success, 0 on the first failure.
*/
int sc_Calculator::spreadRoutes (sc_Host*from)
{
  insist (this && from);
  
  for (int i = 0; i < getNumHosts (); i++)
  {
    sc_Host*to = (sc_Host*)getHost (i);
    insist (to);
    
    for (int j = 0; j < to->getNumRoutes (); j++)
    {
      mt_Route r;
      if (!getSpreadRoute (from->getTypeIndex (), i, j, &r) || !to->replaceRoute (j, &r))
        return 0;
    }
  }
  return 1;
  exception: return 0;
}


// computeRoutes: breadth-first search outward from host `from`. Every node the
// search reaches gets, in its route slot 0, the route from `from` to it.
//
// - A neighbouring switch is marked UP when its number is lower than the
//   current switch's number, DOWN otherwise. With `prune` set, a switch marked
//   DOWN never hands the search to an UP neighbour, so routes may go up and
//   then down but never down and then up. -shortest-path clears `prune`.
// - Each host reached through a switch has its new route charged to the
//   switch port loads via addRoute (root, ..., weight).
// - When `weight` is nonzero, reaching an already-marked switch again lets the
//   hosts cabled to it take the alternative route through the current switch,
//   if isBetterRoute () accepts it within `margin` extra hops. The old route's
//   load is removed and the new one's added.
// - Hosts attached to a cloud simply inherit the cloud's route, with no load
//   accounting. Switches behind a cloud are not explored (that code is #if 0).
//
// The load bookkeeping and the sanity insist on each switch's entry node use
// the member `root`, not `from`, so callers must set root == from first
// (getRoute () calls computeRoutes (root = h)).
//
// Returns 1 when `from` has no neighbour, when it is cabled to itself, or when
// the search completes; 2 when `from` is cabled directly to another host;
// 0 when an insist fails (including not reaching every host in a graph
// without clouds).
int sc_Calculator::computeRoutes (sc_Host*from)
{
  mt_Queue queue;
  
  insist (this);
  insist (from);
  
  insist (from->isHost ());

  mt_Node*first;
  first = from->getNode (0);
  /*unconnected host: nothing to compute*/
  if (!first) return 1;

  /*host cabled straight to a host: only the empty route exists*/
  if (first->isHost ())
  {
    mt_Route loopback; // NOTE(review): unused
    from->setRoute (0, mt_Route ());
    first->setRoute (0, mt_Route ());
    return from == first ? 1 : 2;
  }
  
  clearMarks ();

  /*the first switch is reached by the empty route and seeds the search*/
  first->setMark (mt_Node::UP);
  first->setRoute (0, mt_Route ());

  int in;
  in = from->getOpposite (0);
  insist (in >= 0 && in < sc_Switch::NUM_PORTS);

  first->setIn (in);
  queue.put (first);

  int hostsFound;
  hostsFound = 0;
  
  mt_Node*m;
  
  //printFormat ("computing routes from %s", from->getName ());
  //printFormat ("weight = %d, margin = %d, shortest-path = %d", weight, margin, !prune);
  
  while ((m = (mt_Node*)queue.get ()))
  {
    sc_Switch*s;  

    //printFormat ("exploring node %s", m->getName ());

    /*only switches and clouds are ever queued*/
    insist (!m->isHost ());
    
    /*port through which the search entered m*/
    in = m->getIn ();
    int p;
    if (m->getNodeType () == mt_Node::CLOUD)
    {
      /*every host on the cloud inherits the cloud's route*/
      int numNodes = m->getNumNodes ();
      mt_Route*r = m->getRoute (0);
      
      insist (numNodes > 0);
      
      for (int i = 0; i < numNodes; i++)
      {
	mt_Node*c = m->getNode (i);
	insist (c);
	
	if (c->isHost ())
	{
	  if (c == from)
	    c->setRoute (0, mt_Route (r, 0));
	  else
	    c->setRoute (0, *r);
	  hostsFound++;
	}
#if 0	
	else if (!c->isMarked ())
	{
	  insist (c->isSwitch ());
	  sc_Switch*s = (sc_Switch*) c;
	  s->setIn (m->getOpposite (i));
	  s->setMark (c->number < m->number ? mt_Node::UP : mt_Node::DOWN);
	  queue.put (s);
	}
#endif
      }
      continue;
    }
    s = (sc_Switch*) m;
    
    /*a switch is entered either from the source host or from a marked node*/
    insist (s->getNode (in) == root ||
	    ((s->getNode (in))->isMarked ()));
 
    int ports [sc_Switch::NUM_PORTS];
    int numPorts = s->getPorts (ports);

    for (int i = 0; i < numPorts; i++)
    {
      p = ports [i];
      mt_Node*n =  s->getNode (p);
      insist (n);

      //printFormat ("found node %s", n->getName ());

      if (n->isHost ())
      {
	/*first arrival wins; hosts are never re-routed from here*/
	if (n->isMarked ())
	  continue;
	
	insist (n == root || p != in);
	/*extend s's route by the relative port offset p - in*/
	n->setRoute (0, mt_Route (s->getRoute (0), p - in));
	n->setMark (mt_Node::UP);
	addRoute (root, (sc_Host*)n, n->getRoute(0), weight);
	hostsFound++;

	//printFormat ("saved route to %s as %s", n->getName (), n->getRoute (0)->toString ());
      }
      else
      {
	int mark = n->number < s->number ? mt_Node::UP : mt_Node::DOWN;
	/*with prune, never turn from DOWN back to UP*/
	if (prune && (s->getMark () == mt_Node::DOWN && mark == mt_Node::UP))
	  continue;

	if (!n->isMarked ())
	{
	  insist (p != in);

	  n->setMark (mark);
	  n->setRoute (0, mt_Route (s->getRoute (0), p - in));
	  n->setIn (s->getOpposite (p));

	  //printFormat ("saved route to %s as %s", n->getName (), n->getRoute (0)->toString ());

	  queue.put (n);
	}
	else if (n->isSwitch ())
	{
	  /*already-reached switch: try re-routing the hosts cabled to it through s*/
	  if (p == in)
	    continue;
	  if (weight)
	  {
	    int enter = s->getOpposite (p);
	    mt_Route rt (s->getRoute (0), p - in);

	    // NOTE(review): j only takes values 0..NUM_PORTS, so q >= enter and
	    // the `q < 0` test below can never fire. Hosts on ports below
	    // `enter` (a negative relative offset, which p - in can produce
	    // elsewhere in this function) are never considered here. Confirm
	    // whether j was meant to start at -NUM_PORTS.
	    for (int j = 0; j <= sc_Switch::NUM_PORTS; j++)
	    {
	      int q = enter + j;
	      if (q < 0 || q >= sc_Switch::NUM_PORTS)
		continue;

	      mt_Node*h =  n->getNode (q);
	      if (!h || !h->isHost ())
		continue;
	      
	      /*only hosts that already hold a route can be improved*/
	      if (!h->isMarked ())
		continue;

	      mt_Route r (&rt, j);

	      if (isBetterRoute (root, h, &r, margin))
	      {
		/*move the load from the old route to the new one*/
		removeRoute (root, h, weight);
		addRoute (root, h, &r, weight);
		
		h->setRoute (0, r);
		//printFormat ("saved better route to %s as %s", h->getName (), r.toString ());
	      }
	    }
	  }
	}
      }
    }
  }
  
  /*without clouds every host must have been reached*/
  insist (getNumNodes (mt_Node::CLOUD) || hostsFound == getNumHosts ());
  
  return 1;
  exception: return 0;
}

/*
  Print the options accepted by parseArgs ().
*/
void sc_Calculator::usage ()
{
  printFormat ("route-args:\n"
               "[-weight <weight>]\n"
               "[-margin <margin>]\n"
               "[-roots <numRoots> <node> <node> ...]\n"
               "[-shortest-path]\n"
               "[-use-numbers]\n"
               "[-save-routes]\n"
               "[-max-routes <number>]\n"
               "[-spread]\n"
               "[-disjoin]\n");
}

/*
  Parse the calculator's command-line options (see usage ()).
  Option members are reset to their defaults first. The -roots list stores
  pointers into args' argv strings; they are not copied.
  A repeated -roots replaces the previous list instead of leaking it. A
  truncated -roots list is discarded, so initialize () never renumbers from
  uninitialized entries.
  Returns 1 on success, 0 on any bad or missing argument.
*/
int sc_Calculator::parseArgs (mt_Args*args)
{
  weight = 1;
  prune = 1;
  margin = 0;
  useNumbers = 0;
  saveRoutes = 0;
  maxRoutes = 1;
  disjoin = 0;
  spread = 0;
  
  insist (args);

  int argc;  
  char**argv;
  argv = args->getArgs (mt_Args::CALCULATOR, &argc);
  insist (argv);
  
  for (int i = 0; i < argc; i++)
  {
    insist (argv[i]);
    
    if (!strcmp (argv[i], "-shortest-path"))
      prune = 0;
    else if (!strcmp (argv[i], "-use-numbers"))
      useNumbers = 1;
    else if (!strcmp (argv[i], "-save-routes"))
      saveRoutes = 1;
    else if (!strcmp (argv[i], "-spread"))
      spread = 1;
    else if (!strcmp (argv[i], "-disjoin"))
      disjoin = 1;
    else if (!strcmp (argv [i], "-weight"))
    {
      if (++i == argc)
      {
        printFormat ("sc_Calculator: missing weight");
        return 0;
      }
      weight = atoi (argv [i]);
    }
    else if (!strcmp (argv [i], "-max-routes"))
    {
      if (++i == argc)
      {
        printFormat ("sc_Calculator: missing max-routes");
        return 0;
      }
      maxRoutes = atoi (argv [i]);
      if (maxRoutes < 1)
      {
        printFormat ("sc_Calculator: %s is a bad max-routes", argv [i]);
        return 0;
      }
    }
    else if (!strcmp (argv [i], "-margin"))
    {
      if (++i == argc)
      {
        printFormat ("sc_Calculator: missing margin");
        return 0;
      }
      margin = atoi (argv [i]);
    }
    else if (!strcmp (argv[i], "-roots"))
    {
      if (++i == argc)
      {
        printFormat ("sc_Calculator: missing root count");
        return 0;
      }

      /*a repeated -roots replaces the earlier list rather than leaking it*/
      delete [] roots;
      roots = 0;
      numRoots = atoi (argv[i]);

      if (numRoots <= 0)
      {
        printFormat ("sc_Calculator: bad root count");
        numRoots = 0;
        return 0;
      }
      
      roots = new char* [numRoots];
      insistp (roots, ("sc_Calculator::parseArgs: alloc failed"));
        
      for (int j = 0; j < numRoots; j++)
      {
        if (++i == argc)
        {
          printFormat ("sc_Calculator: missing root");
          /*drop the partially filled list so nothing renumbers from it*/
          delete [] roots;
          roots = 0;
          numRoots = 0;
          return 0;
        }
        roots[j] = argv [i];
      }
    }
    else 
    {
      printFormat ("sc_Calculator: bad option \"%s\"", argv[i]);
      return 0;
    }
  }
  return 1;
  exception: return 0;
}
