/*
  dj_Calculator.c
  route calculator that uses Dijkstra's shortest path algorithm
  dmazzoni@myri.com (Dominic Mazzoni)
*/

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "insist.h"

#include "mt_Queue.h"
#include "dj_Calculator.h"

// Index arithmetic for the priority queue, a binary heap kept in a flat
// array: slot 0 is the root and the children of slot x are 2x+1 and 2x+2.
#define parent(x) (((x) - 1) / 2)
#define left_child(x) (((x) * 2) + 1)
#define right_child(x) (((x) * 2) + 2)

// Puts a new calculator in its default state: nothing allocated, no BFS
// roots, node numbers (not edge numbers) in use, one route per host pair.
dj_Calculator::dj_Calculator()
{
  // option defaults; parseArgs() may override them
  split = 1;
  autoNumberFlag = 0;
  useEdgeNumbers = 0;

  // nothing is allocated until parseArgs()/calculateRoutes() run
  queue = 0;
  routes = 0;
  roots = 0;
  numRoots = 0;
}

// Destructor: all release work is delegated to cleanup().
dj_Calculator::~dj_Calculator()
{
  this->cleanup ();
}

// Frees the routes[src][dst][r] tables and the BFS root list, then runs the
// base-class cleanup.  Every pointer released here is reset to 0.
void dj_Calculator::cleanup()
{
  // The host count is read before mt_Calculator::cleanup() runs; the table
  // dimensions were taken from getNumHosts() when calculateRoutes() built it.
  int numHosts = getNumHosts();

  if (routes)
  {
    for (int i = 0; i < numHosts; i++)
    {
      if (!routes[i])
        continue;

      for (int j = 0; j < numHosts; j++)
      {
        if (!routes[i][j])
          continue;

        // deleting a null pointer is a no-op, so unfilled slots are harmless
        for (int r = 0; r < split; r++)
          delete routes[i][j][r];
        delete [] routes[i][j];
      }
      delete [] routes[i];
    }
    delete [] routes;
  }
  routes = 0;

  // roots is created with new char*[numRoots] in parseArgs(), so it has to
  // be released with the array form of delete.  Only the array is freed:
  // its entries point at argv strings that this object does not own.
  delete [] roots;
  roots = 0;

  mt_Calculator::cleanup ();
}

// Sets the calculator up from an existing graph: base-class initialization,
// then switch renumbering, then route calculation.  Stops at the first step
// that fails.  Returns 1 if all three steps succeed, 0 otherwise.
int dj_Calculator::initialize (mt_Graph*graph, mt_Node*root)
{
  if (!mt_Calculator::initialize (graph, root))
    return 0;

  if (!renumber (root))
    return 0;

  if (!calculateRoutes ())
    return 0;

  return 1;
}

// Sets the calculator up from a map file: base-class initialization, then
// switch renumbering with no explicit root, then route calculation.  Stops
// at the first step that fails.  Returns 1 on success, 0 otherwise.
int dj_Calculator::initialize (char *mapFile)
{
  if (!mt_Calculator::initialize (mapFile))
    return 0;

  if (!renumber (0))
    return 0;

  if (!calculateRoutes ())
    return 0;

  return 1;
}

// Checks whether the switch numbers already present in the graph can be used
// as they are.  Returns 1 when every switch has a non-negative number that
// no earlier switch shares, 0 otherwise.  A failed insist also yields 0, so
// an inconsistent graph is never reported as correctly numbered.
// Only meaningful when node numbers (not edge numbers) are in use.
int dj_Calculator::checkNodeNumbers ()
{
  insist (this);
  insist (!useEdgeNumbers);

  // declared without an initializer: the insists above are assumed to jump
  // to the exception label, and such a jump must not skip an initialization
  int numSwitches;
  numSwitches = getNumSwitches();

  for (int i = 0; i < numSwitches; i++)
  {
    mt_Node*s  = getSwitch (i);
    insist (s);

    if (s->number < 0)
      return 0;

    // switches 0..i-1 were already found non-null by earlier iterations
    for (int j = 0; j < i; j++)
      if (getSwitch (j)->number == s->number)
        return 0;
  }
  return 1;

  exception: return 0;
}

// Chooses how the switches get their numbers.  In priority order:
//   1. edge numbers in use      -> nothing to do
//   2. autoNumberFlag set       -> autoNumber()
//   3. roots given to parseArgs -> base-class renumber from those roots
//   4. a root node was passed   -> base-class renumber from that node
//   5. otherwise keep the existing numbers if checkNodeNumbers() accepts
//      them, else fall back to autoNumber()
// `root` may be 0; if it is not, it must be a switch.
// Returns non-zero on success, 0 on failure.
int dj_Calculator::renumber (mt_Node*root)
{
  insist (this);
  insist (!root || root->isSwitch ());

  if (useEdgeNumbers)
    return 1;

  if (autoNumberFlag)
    return autoNumber ();

  if (numRoots && roots)
    return mt_Calculator::renumber (numRoots, roots);

  if (root)
    return mt_Calculator::renumber (1, &root);

  if (checkNodeNumbers ())
    return 1;

  return autoNumber ();

  exception: return 0;
}

// Walks `route` starting at host `src` and charges it to the switches it
// crosses: each output port used gets its load counter incremented and is
// tagged with the current subnet.  All hops except the last one are walked.
// Returns 0 (after printing a message) if a port is already tagged with a
// subnet higher than the current one; returns 1 otherwise.
int dj_Calculator::followRoute(dj_Host *src, mt_Route *route)
{
  insist(src);
  insist(route);

  // locals are declared without initializers because the insists are
  // assumed to jump forward to the exception label
  dj_Switch *sw;
  sw = (dj_Switch *)src->getNode(0);
  insist(sw);

  int hopCount;
  char *hopBytes;
  int arrivalPort;
  hopCount = route->getLength();
  hopBytes = route->getHops();
  arrivalPort = src->getOpposite(0);

  for(int i = 0; i + 1 < hopCount; i++)
  {
    insist(sw);

    // each hop byte is an offset from the port the route arrived on
    int outPort = arrivalPort + hopBytes[i];
    sw->load[outPort]++;

    if (sw->subnet[outPort] < subnet)
      sw->subnet[outPort] = subnet;
    else if (sw->subnet[outPort] > subnet)
    {
      printFormat("%d disjoint subnetworks could not be found.",split);
      return 0;
    }

    // step across the link to the next switch
    arrivalPort = sw->getOpposite(outPort);
    sw = (dj_Switch *)sw->getNode(outPort);
  }

  return 1;
  exception: return 0;
}

// Builds the full routes[src][dst][subnet] table.
//
// Allocates a numHosts x numHosts x split table of route pointers, clears
// the per-port load and subnet tags of every switch, and then, for each
// subnet from split-1 down to 0 and each source host, runs
// calculateRoutesFrom() and charges every resulting route to the switches
// with followRoute().
//
// Returns 1 on success, 0 on any failure.  On failure the table may be
// partially filled; cleanup() releases whatever was stored.
int dj_Calculator::calculateRoutes()
{
  int s,d,r;
  int numHosts = getNumHosts();
  int numSwitches = getNumSwitches();

  // Allocate level by level, zeroing each level before filling it so that
  // cleanup() can always tell which parts of the table exist.
  routes = new mt_Route***[numHosts];
  if (!routes) return 0;
  for(s=0; s<numHosts; s++)
    routes[s] = 0;
  for(s=0; s<numHosts; s++)
  {
    routes[s] = new mt_Route**[numHosts];
    if (!routes[s]) return 0;
    for(d=0; d<numHosts; d++)
      routes[s][d] = 0;
    for(d=0; d<numHosts; d++)
    {
      routes[s][d] = new mt_Route*[split];
      if (!routes[s][d]) return 0;
      for(r=0; r<split; r++)
        routes[s][d][r] = 0;
    }
  }

  // Start with no load and no subnet tag on any switch port.
  for(s=0; s<numSwitches; s++)
  {
    dj_Switch *sw = (dj_Switch *)getSwitch(s);

    for(int p=0; p<dj_Switch::NUM_PORTS; p++)
    {
      sw->load[p] = 0;
      sw->subnet[p] = 0;
    }
  }

  for(subnet=split-1; subnet>=0; subnet--)
    for(s=0; s<numHosts; s++)
    {
      // calculateRoutesFrom() leaves its results in each host's route field
      for(d=0; d<numHosts; d++)
        ((dj_Host *)getHost(d))->route = 0;

      int computed = calculateRoutesFrom(s);

      // Move every route that was produced into the table before anything
      // can fail, so that cleanup() frees them even on an early return.
      for(d=0; d<numHosts; d++)
        routes[s][d][subnet] = ((dj_Host *)getHost(d))->route;

      if (!computed)
        return 0;

      for(d=0; d<numHosts; d++)
      {
        mt_Route *theRoute = routes[s][d][subnet];
        if (!theRoute)
          return 0;

        if (!followRoute((dj_Host *)getHost(s),theRoute))
          return 0;
      }
    }

  return 1;
}

void dj_Calculator::swap(int index1, int index2)
{
  insist(index1>=0 && index1<queueSize && index2>=0 && index2<queueSize);

  dj_QElement *temp;

  temp = queue[index1];
  queue[index1] = queue[index2];
  queue[index2] = temp;

  queue[index1]->qIndex = index1;
  queue[index2]->qIndex = index2;

  exception: return;
}

// Restores heap order after the element in slot `index` became smaller:
// the element is swapped with its parent until it reaches the root or its
// parent no longer compares greater-or-equal (element > parent stops it).
void dj_Calculator::trickle_up(int index)
{
  insist(index>=0 && index<queueSize);

  while (index != 0)
  {
    int up;
    up = parent(index);

    if (*queue[index] > *queue[up])
      break;

    swap(index,up);
    index = up;
  }

  exception: return;
}

// Restores heap order after the element in slot `index` may have become
// larger than its children: while it compares greater than the preferred
// child, the two are swapped and the walk continues from the child's slot.
// The preferred child is the left one when there is no right child or when
// left < right, otherwise the right one.
void dj_Calculator::trickle_down(int index)
{
  insist(index>=0 && index<queueSize);

  for (;;)
  {
    int lc;
    int rc;
    int pick;

    lc = left_child(index);
    rc = right_child(index);

    // no children at all: the element is a leaf
    if (lc >= queueSize)
      break;

    if (rc >= queueSize || *queue[lc] < *queue[rc])
      pick = lc;
    else
      pick = rc;

    if (!(*queue[index] > *queue[pick]))
      break;

    swap(index,pick);
    index = pick;
  }

  exception: return;
}

// Removes and returns the element at the top of the heap.
// Returns 0 when the heap is empty, or when the top element's distance is
// INFINITY or more; in the latter case the heap is left untouched.
// The returned element has qIndex set to -1 to mark it as no longer queued.
dj_QElement *dj_Calculator::extractMin()
{
  if (queueSize==0)
    return 0;

  dj_QElement *top = queue[0];
  if (top->dist >= INFINITY)
    return 0;

  // move the top to the last live slot, shrink the heap past it, and
  // repair the order from the root
  swap(0,queueSize-1);
  --queueSize;
  if (queueSize != 0)
    trickle_down(0);

  top->qIndex = -1;
  return top;
}

// Returns the queue element that is reached by leaving the switch of `from`
// through port `portNum`, or 0 when that step is not possible.
//
// A step is impossible when nothing is attached to the port, when the
// attached node is not a switch, or when the numbering rule forbids it:
//  - edge numbers: from->type must not exceed the port's edge number; the
//    result is the neighbor's element indexed by getOpposite(portNum)
//    (presumably the port at the neighbor's end of the link).
//  - node numbers: a link "points up" when the neighbor's number is lower.
//    An element of type 1 may not take an up link.  Up links lead to the
//    neighbor's element 0, all other links to its element 1.
//    NOTE(review): this looks like an up-then-down ordering rule - confirm.
dj_QElement *dj_Calculator::getNBor(dj_QElement *from, int portNum)
{
  dj_Switch *here = from->sw;
  mt_Node *attached = here->getNode(portNum);

  if (!attached || !attached->isSwitch())
    return 0;

  dj_Switch *nbor = (dj_Switch *)attached;

  if (useEdgeNumbers)
  {
    if (from->type > here->edgeNumber[portNum])
      return 0;

    return &nbor->qElement[here->getOpposite(portNum)];
  }

  int pointsUp = (nbor->number < here->number);

  if (pointsUp && from->type == 1)
    return 0;

  return &nbor->qElement[pointsUp ? 0 : 1];
}

// Relaxation step: tries every port of the switch that element `e` belongs
// to and improves the neighboring elements that can be reached through it.
//
// For each reachable neighbor a candidate copy is filled in with the values
// the neighbor would have if reached via `e` (one more hop, distance
// increased by the load of the outgoing port).  The neighbor is left alone
// if it already compares less than the candidate; otherwise it takes the
// candidate's values, records `e` as its parent and, if it is still queued,
// is moved up the heap.
void dj_Calculator::discover(dj_QElement *e)
{
  // add some insists here

  dj_Switch *origin = e->sw;

  for(int port=0; port<dj_Switch::NUM_PORTS; port++)
  {
    dj_QElement *target = getNBor(e,port);

    if (target)
    {
      dj_QElement candidate;

      candidate = *target;
      candidate.hops = e->hops+1;
      candidate.dist = e->dist + origin->load[port];
      candidate.sameSubnet = (origin->subnet[port] == subnet);
      candidate.time = discoverCount;
      if (useEdgeNumbers)
        candidate.type = origin->edgeNumber[port];

      if (!(*target < candidate))
      {
        target->hops = candidate.hops;
        target->dist = candidate.dist;
        target->sameSubnet = candidate.sameSubnet;
        // the discovery counter only advances when an element is updated
        target->time = discoverCount++;
        if (useEdgeNumbers)
          target->type = candidate.type;

        target->inPort = origin->getOpposite(port);
        target->parent = e;

        // qIndex is -1 for elements that are not in the heap
        if (target->qIndex>=0)
          trickle_up(target->qIndex);
      }
    }
  }
}

int dj_Calculator::calculateRoutesFrom(int srcHost)
{
  int numHosts = getNumHosts();

  dj_Switch *src = (dj_Switch *)getHost(srcHost)->getNode(0);
  if (!src) return 0;

  if (!initQueue(srcHost))
    return 0;

  // Run Dijkstra's algorithm - O(n log n)
  discover(&src->qElement[0]);
  while(queueSize)
  {
    dj_QElement *min = extractMin();
    if (!min)
      break;
    discover(min);
  }

  // give routes to each host
  for(int hi=0; hi<numHosts; hi++)
  {
    dj_Switch *s = (dj_Switch *)getHost(hi)->getNode(0);
    dj_QElement *e;

    // find min of all qElements

    e = &s->qElement[0];

    int numElements;
    if (useEdgeNumbers)
      numElements = dj_Switch::NUM_PORTS;
    else
      numElements = 2;

    for(int m=1; m<numElements; m++)
      if (s->qElement[m] < *e)
	e = &s->qElement[m];

    if (e->dist >= INFINITY)
    {
      printFormat("Error: Couldn't route from host %s to switch %s",
		  getHost(srcHost)->getName(),
		  s->getName());
      return 0;
    }

    int hops = e->hops+1;

    char hop[mt_Route::MAX_ROUTE];
    hop[hops-1] = getHost(hi)->getOpposite(0) - e->inPort;

    int x;
    x = hops-2;
    while(x>=0)
    {
      insist(e->parent);
      dj_Switch *es = e->sw;
      //      e->parent->sw->load[es->getOpposite(e->inPort)] += 5;
      hop[x--] = es->getOpposite(e->inPort) - e->parent->inPort;
      e = e->parent;
    }
    ((dj_Host *)getHost(hi))->route = new mt_Route(hop,hops);
  }

  deleteQueue();

  return 1;
  exception: return 0;
}

int dj_Calculator::initQueue(int srcHost)
{
  int numSwitches = getNumSwitches();

  dj_Switch *src = (dj_Switch *)getHost(srcHost)->getNode(0);
  if (!src) return 0;

  if (useEdgeNumbers)
    queue = new dj_QElement*[dj_Switch::NUM_PORTS*numSwitches];
  else
    queue = new dj_QElement*[2*numSwitches];
  if (!queue) return 0;

  queueSize = 0;
  discoverCount = 0;

  for(int i=0; i<numSwitches; i++)
  {
    dj_Switch *sw = (dj_Switch *)getSwitch(i);

    int numElements;

    if (useEdgeNumbers)
      numElements = dj_Switch::NUM_PORTS;
    else
      numElements = 2;

    dj_QElement *it;

    for(int e=0; e<numElements; e++)
    {
      it = &sw->qElement[e];

      it->sw = sw;
      it->parent = 0;
      it->sameSubnet = 0;
      if (useEdgeNumbers)
	it->type = -1;
      else
	it->type = e;

      if (sw == src)
      {
	it->dist = 0;
	it->inPort = getHost(srcHost)->getOpposite(0);
	it->qIndex = -1;
	it->hops = 0;
	it->time = discoverCount++;
      }
      else
      {
	it->dist = INFINITY;
	it->hops = INFINITY;
	it->time = -1;
	it->qIndex = queueSize;
	queue[queueSize++] = it;
      }
    }
  }
  return 1;
}

void dj_Calculator::deleteQueue()
{
  delete[] queue;
  queue = 0;
}


// Node factory: creates the dj_ flavour of a graph node.
// Returns a new dj_Host for mt_Node::HOST, a new dj_Switch for
// mt_Node::SWITCH, and 0 for any other node type.
mt_Node*dj_Calculator::newNode (int nodeType, char*name, char*type)
{
  insist (this);

  if (nodeType == mt_Node::HOST)
    return new dj_Host (name, type);

  if (nodeType == mt_Node::SWITCH)
    return new dj_Switch (name, type);

  // unknown node type
  insist (0);

  exception: return 0;
}

// Number of routes available between two hosts.  This is `split` for every
// pair; the indices are only checked against the host count.
// Returns 0 if a check fails.
int dj_Calculator::getNumRoutes (int from, int to)
{
  insist (this);
  insist (0 <= from && from < getNumHosts ());
  insist (0 <= to && to < getNumHosts ());

  return split;

  exception: return 0;
}

// Largest number of routes kept for any host pair.  Every pair has exactly
// `split` routes (see getNumRoutes), so this is `split` as well.
// Returns 0 if the insist fails.
int dj_Calculator::getMaxRoutes ()
{
  insist (this);

  return split;

  exception: return 0;
}

int dj_Calculator::getRoute (int from, int to, int routeIndex, mt_Route*route)
{
  insist (this);
  insist (from >= 0 && from < getNumHosts ());
  insist (to >=0 && to < getNumHosts ());
  insist (route);
  insist (routes);
  insist (routes[from]);
  insist (routes[from][to]);
  insist (0 <= routeIndex && routeIndex < split);
  insist (routes[from][to][routeIndex]);

  *route = *routes[from][to][routeIndex];
  
  return 1;
  exception: return 0;
}

// Prints the command-line options that parseArgs() accepts, one per line.
void dj_Calculator::usage ()
{
  printFormat
    (
     "route-args:\n"
     "[-use-node-numbers]\n"
     "[-bfs-node-number <numRoots> <node> <node> ...]\n"
     "[-bottom-up-node-number]\n"
     "[-use-edge-numbers]\n"
     "[-split <numDisjointNetworks>]\n"
     );
}

// Reads the calculator's command-line options from `args`.
//
//   -use-node-numbers        route by node numbers
//   -use-edge-numbers        route by edge numbers
//   -bottom-up-node-number   number the switches with autoNumber()
//   -split <n>               keep n routes per host pair, 1 <= n <= 8
//   -bfs-node-number <k> <node> x k
//                            renumber starting from the k named roots
//
// Options are processed left to right and parsing continues after every
// option, including -bfs-node-number.  The root names are not copied:
// `roots` points into the argv array obtained from `args`.
//
// Returns 1 on success.  Returns 0 after printing a message on a malformed
// or unknown option, or when an insist fails.  After a failed
// -bfs-node-number no root list is left behind.
int dj_Calculator::parseArgs (mt_Args*args)
{
  insist (args);

  // declared without initializers: the insist above is assumed to jump to
  // the exception label
  int argc;
  char**argv;
  argv = args->getArgs (mt_Args::CALCULATOR, &argc);

  for (int i = 0; i < argc; i++)
  {
    insist (argv[i]);

    if (!strcmp (argv[i], "-use-node-numbers"))
    {
      useEdgeNumbers = 0;
    }
    else if (!strcmp (argv[i], "-use-edge-numbers"))
    {
      useEdgeNumbers = 1;
    }
    else if (!strcmp (argv[i], "-bottom-up-node-number"))
    {
      autoNumberFlag = 1;
      useEdgeNumbers = 0;
    }
    else if (!strcmp (argv[i], "-split"))
    {
      if (++i == argc)
      {
        printFormat ("dj_Calculator: missing split count");
        return 0;
      }
      split = atoi (argv[i]);
      if (split<1 || split>8)
      {
        printFormat ("dj_Calculator: bad split count");
        split = 0;
        return 0;
      }
    }
    else if (!strcmp (argv[i], "-bfs-node-number"))
    {
      useEdgeNumbers = 0;
      if (++i == argc)
      {
        printFormat ("dj_Calculator: missing root count");
        return 0;
      }

      // a repeated option replaces the earlier root list instead of
      // leaking it
      delete [] roots;
      roots = 0;

      numRoots = atoi (argv[i]);

      if (numRoots <= 0)
      {
        printFormat ("dj_Calculator: bad root count");
        numRoots = 0;
        return 0;
      }

      roots = new char* [numRoots];
      insistp (roots, ("dj_Calculator::parseArgs: alloc failed"));

      for (int j = 0; j < numRoots; j++)
      {
        if (++i == argc)
        {
          printFormat ("dj_Calculator: missing root");
          // do not leave a partly filled list that renumber() might use
          delete [] roots;
          roots = 0;
          numRoots = 0;
          return 0;
        }
        roots[j] = argv [i];
      }
      // fall out of the branch so that any following options are parsed
    }
    else
    {
      printFormat ("dj_Calculator: bad option \"%s\"", argv[i]);
      return 0;
    }
  }

  return 1;
  exception: return 0;
}
