/*
  de_Deadlock.c
  deadlock detector
  finucane@myr.com (David Finucane)
*/

#include <stdio.h>
#include <stdlib.h>
#include "insist.h"
#include "mt_Host.h"
#include "mt_Cloud.h"
#include "de_Deadlock.h"

/*
  a port is part of a cycle if you can travel from it back to itself
  through ports that have not already been checked.
*/


/* a new edge is unconnected: it has neither a source nor a destination port */
de_Edge::de_Edge ()
{
  to = 0;
  from = 0;
}

/* destination port of this edge; 0 while the edge is unconnected */
de_Port*de_Edge::getTo ()
{
  insist (this);
  return this->to;
exception:
  return 0;
}

/* source port of this edge; 0 while the edge is unconnected */
de_Port*de_Edge::getFrom ()
{
  insist (this);
  return this->from;
exception:
  return 0;
}

/* 1 when the edge has a source port, 0 otherwise */
int de_Edge::connected ()
{
  insist (this);
  return this->from ? 1 : 0;
exception:
  return 0;
}

/*
  point this edge from port `from` to port `to`.
  passing (0, 0) disconnects it. always returns 1 for a valid edge.
*/
int de_Edge::connect (de_Port*from, de_Port*to)
{
  insist (this);

  this->to = to;
  this->from = from;

  /* debugging aid:
  if (from && to)
    printFormat ("connecting %s:%d (%d) to %s:%d (%d)", 
		 from->getNode ()->getName (), from->getIndex (), from->getDirection (), 
		 to->getNode ()->getName (), to->getIndex (), to->getDirection ());
  */
  return 1;
exception:
  return 0;
}

/* an unbound port: no node, invalid index, IN direction, not yet checked */
de_Port::de_Port ()
{
  mark = UNCHECKED;
  direction = IN;
  index = -1;
  node = 0;
}

/* node this port belongs to (set by initialize) */
mt_Node*de_Port::getNode ()
{
  insist (this);
  return this->node;
exception:
  return 0;
}

/* port number on the owning node; -1 until initialize is called */
int de_Port::getIndex ()
{
  insist (this);
  return this->index;
exception:
  return 0;
}

/* IN or OUT, as given to initialize */
int de_Port::getDirection ()
{
  insist (this);
  return this->direction;
exception:
  return 0;
}

/* current search mark (UNCHECKED, SEEN or CHECKED) */
int de_Port::getMark ()
{
  insist (this);
  return this->mark;
exception:
  return 0;
}

/* overwrite the search mark of this port */
void de_Port::setMark (int mark)
{
  insist (this);
  this->mark = mark;
  return;
exception:
  ;
}

/*
  bind this port to port number `index` of `node`, facing `direction`
  (IN or OUT), and drop all of its edges. the port starts out CHECKED, so the
  cycle search ignores it until connect () gives it an outgoing edge.
  returns 1 on success, 0 if the arguments are rejected.
*/
int de_Port::initialize (mt_Node*node, int index, int direction)
{
  insist (this && node && index >= 0 && index < node->getMaxNodes ());
  this->node = node;
  this->index = index;
  this->direction = direction;
  
  mark = CHECKED;
  
  /* clear exactly the slots getEdges () scans, so no stale edge survives */
  for (int i = 0; i < NUM_EDGES; i++)
    edges [i].connect (0, 0);
  
  return 1;
  exception: return 0;
}

/*
  add an edge from this port to port `to`, and mark this port UNCHECKED so the
  cycle search will start from or pass through it.
  an edge to a port on the same node is stored in slot to->index; an edge to a
  port on another node always goes in slot 0, replacing any previous one.
  returns 1 on success, 0 on bad arguments.
*/
int de_Port::connect (de_Port*to)
{
  /* port numbers run from 0 to NUM_PORTS - 1 (see de_Switch::initialize) */
  insist (this && to && to->index >= 0 && to->index < mt_Switch::NUM_PORTS);
  
  mark = UNCHECKED;
  
  if (to->node == node)
  {
    /* to->index selects an edge slot here, so it must lie inside edges [] */
    insist (to->index < NUM_EDGES);
    edges [to->index].connect (this, to);
  }
  else
    edges [0].connect (this, to);
  
  return 1;
  exception: return 0;
}

/*
  copy pointers to this port's connected edges into edges [].
  maxEdges must be at least NUM_EDGES. returns how many were stored.
*/
int de_Port::getEdges (de_Edge*edges[], int maxEdges)
{
  int found = 0;
  
  insist (this && edges);
  insist (maxEdges >= NUM_EDGES);
  
  for (int slot = 0; slot < NUM_EDGES && slot < maxEdges; slot++)
  {
    de_Edge*candidate = &this->edges [slot];
    if (candidate->connected ())
    {
      edges [found] = candidate;
      found++;
    }
  }
  return found;
exception:
  return 0;
}

/* in-port number p of this switch, or 0 if p is out of range */
de_Port*de_Switch::getInPort (int p)
{
  insist (this && p >= 0 && p < mt_Switch::NUM_PORTS);
  return &this->inPorts [p];
exception:
  return 0;
}

/* out-port number p of this switch, or 0 if p is out of range */
de_Port*de_Switch::getOutPort (int p)
{
  insist (this && p >= 0 && p < mt_Switch::NUM_PORTS);
  return &this->outPorts [p];
exception:
  return 0;
}

/*
  (re)initialize every in-port and out-port of this switch, which also
  clears all of their edges. returns 1 on success.
  note: a port whose index is not below getMaxNodes () is rejected by
  de_Port::initialize; that per-port result is not checked here.
*/
int de_Switch::initialize ()
{
  insist (this);
  
  for (int i = 0; i < mt_Switch::NUM_PORTS; i++)
  {
    inPorts [i].initialize (this, i, de_Port::IN);
    outPorts [i].initialize (this, i, de_Port::OUT);
  }
  return 1;
  exception: return 0; 
}

/* build the underlying mt_Switch, then set up this switch's in/out ports */
de_Switch::de_Switch (char*name, char*type) : mt_Switch (name, type)
{
  this->initialize ();
}

int de_Switch::_connect (int mode, int fromPort, de_Switch*toSwitch, int toPort)
{
  insist (this && toSwitch);
  insist (mode == de_Port::WITHIN || mode == de_Port::BETWEEN);
  
  return  mode == de_Port::WITHIN ? 
    inPorts [fromPort].connect (&outPorts [toPort]) :
    outPorts [fromPort].connect (&toSwitch->inPorts [toPort]);
  
  exception: return 0;
}

int de_Deadlock::growPorts ()
{
  insist (this);
  
  if (maxPorts <= 1)
    maxPorts = 100;
  
  if (ports)
    free (ports);
    
  maxPorts *= 2;
  
  ports = (de_Port**) malloc (sizeof (de_Port*) * maxPorts);
  insistp (ports, ("couldn't alloc ports"));

  return maxPorts;
  
  exception: 
  maxPorts = 0;
  return 0;
}

/*
  store p as entry i of the current path, growing ports [] as far as needed.
  returns 1 on success, 0 on bad arguments or if the buffer cannot grow.
*/
int de_Deadlock::addToPorts (int i, de_Port*p)
{
  insist (this);
  insist (i >= 0);
  insist (p);
  
  /* one doubling may not be enough for an arbitrary i; keep growing */
  while (i >= maxPorts)
  {
    insist (growPorts ());
  }
  ports [i] = p;

  return 1;
  exception: return 0;  
}

/*
  an empty detector: no path buffer yet (growPorts allocates one on first
  use) and no route walk in progress.
*/
de_Deadlock::de_Deadlock ()
{
  maxPorts = 0;
  numPorts = 0;
  ports = 0;

  /* step () reads node (and then port) as the previous hop, and check ()
     compares against start; give them defined values instead of leaving
     them indeterminate */
  node = 0;
  port = 0;
  start = 0;
}

/* release the path buffer (free of a null pointer is a no-op) */
de_Deadlock::~de_Deadlock ()
{
  free (ports);
}

/*
  create a node of the requested kind. switches are created as de_Switch,
  so each one carries the in/out ports the cycle search walks; hosts and
  clouds use the stock types. returns 0 for an unknown nodeType.
  (presumably called by the base class while building the network — confirm)
*/
mt_Node*de_Deadlock::newNode (int nodeType, char*name, char*type)
{
  mt_Node*made = 0;

  insist (this);

  switch (nodeType)
  {
    case mt_Node::SWITCH: made = new de_Switch (name, type); break;
    case mt_Node::HOST:   made = new mt_Host (name, type);   break;
    case mt_Node::CLOUD:  made = new mt_Cloud (name, type);  break;
    default:
      insist (0);
  }
  return made;

exception:
  return 0;
}

int de_Deadlock::erase ()
{
  insist (this);
  int numSwitches;
  numSwitches = getNumSwitches ();
  numPorts = 0;
  
  for (int i = 0; i < numSwitches; i++)
    ((de_Switch*)getSwitch (i))->initialize ();
  return 1;
  exception: return 0;
}

/*
  advance the route walk to (node, port), where port is the port of node
  that the route entered through.
  when both the previous node and this one are switches, two dependencies
  are recorded on the previous switch:
    - its entry in-port -> the out-port that leads to node (WITHIN)
    - that out-port -> node's in-port `port` (BETWEEN)
  returns 1 on success or when there is nothing to record, 0 on failure.
*/
int de_Deadlock::step (mt_Node*node, int port)
{
  mt_Node*prevNode;
  int prevPort;
  de_Switch*prevSwitch;

  insist (node);
  insist (port >= 0 && port < node->getMaxNodes ());
  
  /* remember the previous hop, then move on to the new one */
  prevPort = this->port;
  prevNode = this->node;
  this->port = port;
  this->node = node;
  
  /* only switch-to-switch hops create channel dependencies */
  if (prevNode == 0 || !prevNode->isSwitch () || !node->isSwitch ())
    return 1;

  prevSwitch = (de_Switch*) prevNode;

  if (!prevSwitch->_connect (de_Port::WITHIN, prevPort, prevSwitch, node->getOpposite (port)))
    return 0;

  return prevSwitch->_connect (de_Port::BETWEEN, node->getOpposite (port), (de_Switch*) node, port) ? 1 : 0;

exception:
  return 0;
}

/*
  replay route r starting at node `from` and record the dependencies it
  creates. follow () is handed this detector, and presumably reports each
  hop through step () (TODO confirm). the route must end at `to`.
  returns 1 on success, 0 on bad arguments or if the route ends elsewhere.
*/
int de_Deadlock::connect (mt_Node*from, mt_Node*to,  mt_Route*r)
{
  mt_Node*n;
  int in;

  insist (this);
  insist (r);
  insist (from && to);

  /* with no previous hop, step () records nothing for the route's first hop */
  node = 0;

  /* defined value in case follow () never sets it */
  n = 0;
  
  //  printFormat ("following route from %s to %s", from->getName (), to->getName ());
  
  from->follow (r, &n, &in, this);
  
  insist (n == to);
  return 1;
  
  exception: return 0;
}

/*
  record the dependencies of every route the calculator provides, for every
  ordered pair of hosts. returns 1 when all routes were replayed, otherwise 0.
*/
int de_Deadlock::connect (mt_Calculator*calculator)
{
  int hostCount = getNumHosts ();
  insist (calculator);
  
  for (int src = 0; src < hostCount; src++)
  {
    mt_Node*fromHost = getHost (src);
    insist (fromHost);
    
    for (int dst = 0; dst < hostCount; dst++)
    {
      mt_Node*toHost = getHost (dst);
      insist (toHost);
      
      int routeCount = calculator->getNumRoutes (src, dst);
      insist (routeCount >= 0);
      
      for (int k = 0; k < routeCount; k++)
      {
	mt_Route route;
	if (!calculator->getRoute (src, dst, k, &route))
	  return 0;
	if (!connect (fromHost, toHost, &route))
	  return 0;
      }
    }
  }
  return 1;

exception:
  return 0;
}

/*
  first UNCHECKED port found, scanning switches in order and, within each
  switch, port numbers in order (the in-port before the out-port).
  returns 0 when every port has been checked.
*/
de_Port*de_Deadlock::getUncheckedPort ()
{
  int switchCount;

  insist (this);
  
  switchCount = getNumSwitches ();
  
  for (int s = 0; s < switchCount; s++)
  {
    de_Switch*sw = (de_Switch*) getSwitch (s);
    int portCount = sw->getMaxNodes ();
    
    for (int p = 0; p < portCount; p++)
    {
      de_Port*inPort = sw->getInPort (p);
      if (inPort->getMark () == de_Port::UNCHECKED)
	return inPort;

      de_Port*outPort = sw->getOutPort (p);
      if (outPort->getMark () == de_Port::UNCHECKED)
	return outPort;
    }
  }

exception:
  return 0;
}

int de_Deadlock::unsee ()
{
  insist (this);
  
  int numSwitches;
  numSwitches = getNumSwitches ();
  
  for (int i = 0; i < numSwitches; i++)
  {
    de_Switch*s = (de_Switch*) getSwitch (i);
    
    int maxNodes = s->getMaxNodes ();
    
    for (int j = 0; j < maxNodes; j++)
    {
      if (s->getInPort (j)->getMark () == de_Port::SEEN)
	s->getInPort (j)->setMark (de_Port::UNCHECKED);

      if (s->getOutPort (j)->getMark () == de_Port::SEEN)
	s->getOutPort (j)->setMark (de_Port::UNCHECKED);
    }
  }
  return 1;
  exception: return 0;
}

/*
  depth-first search from port p for a way back to `start`.
  p is recorded as ports [depth]. When a cycle is found, ports [0 .. d-1]
  therefore holds the path, where d is the returned value.
  returns:
    depth + 1  when p is start (the cycle is closed),
    the value propagated from a deeper call that found start,
    0          when p was already SEEN/CHECKED, when no edge leads back to
               start, or on error.
  every port visited is left SEEN; the caller resets them with unsee ().
*/
int de_Deadlock::check (int depth, de_Port*p)
{
  de_Edge*edges [de_Port::NUM_EDGES];
  int numEdges;

  insist (this && p);
  //printFormat ("checking port %s:%d (%d)", p->getNode ()->getName (), p->getIndex (), p->getDirection ());

  /* a missing entry here would leave a hole in the path print () reports */
  insist (addToPorts (depth, p));
  
  /* start itself is SEEN, so reaching it lands here and closes the cycle */
  if (p->getMark () != de_Port::UNCHECKED)
    return p == start ? depth + 1 : 0;
  
  p->setMark (de_Port::SEEN);

  numEdges = p->getEdges (edges, de_Port::NUM_EDGES);

  for (int i = 0; i < numEdges; i++)
  {
    insist (edges [i] && edges [i]->getFrom () == p && edges [i]->getTo ());
    
    int d = check (depth + 1, edges [i]->getTo ());
    if (d)
      return d;
  }

  /* no edge out of p leads back to start */
  return 0;

  exception: return 0;  
}

int de_Deadlock::print ()
{
  insist (this);
  insist (numPorts <= maxPorts);
  
  for (int i = 0; i < numPorts; i++)
  {
    insist (ports [i]);  
    printFormat ("%s:%d", ports [i]->getNode ()->getName (), ports [i]->getIndex ());
  }
  return numPorts;
  exception: return 0;
}


/*
  rebuild the dependency graph from every route of `calculator`, then report
  whether it contains a cycle (1) or not (0).
*/
int de_Deadlock::deadlockable (mt_Calculator*calculator)
{  
  insist (this);
  insist (calculator);
  
  if (erase () && connect (calculator))
    return deadlockable ();
  return 0;

exception:
  return 0;
}


/*
  rebuild the dependency graph from the single route r, from host number
  `from` to host number `to`, then report whether it contains a cycle (1)
  or not (0).
*/
int de_Deadlock::deadlockable (mt_Route*r, int from, int to)
{  
  insist (this);
  insist (r);
  insist (to >= 0 && to < getNumHosts ());
  insist (from >= 0 && from < getNumHosts ());
  
  if (erase () && connect (getHost (from), getHost (to), r))
    return deadlockable ();
  return 0;

exception:
  return 0;
}

/*
  search the dependency graph already recorded for a cycle.
  each UNCHECKED port in turn becomes `start`, and check () looks for a way
  from one of its edges back to it.
  If a cycle is found, the function returns 1, numPorts is the cycle length,
  and the cycle is in ports [0 .. numPorts-1] (see print ()).
  A start port with no cycle through it is marked CHECKED, and check ()
  stops there in later searches.
  returns 0 when no cycle exists (or on error).
*/
int de_Deadlock::deadlockable ()
{
  de_Edge*edges [de_Port::NUM_EDGES];
  int numEdges;
  
  insist (this);

  while ((start = getUncheckedPort ()))
  {
    //printFormat ("starting port %s:%d (%d)", start->getNode ()->getName (), start->getIndex (), start->getDirection ());
    
    numEdges = start->getEdges (edges, de_Port::NUM_EDGES);

    start->setMark (de_Port::SEEN);

    /* every reported cycle begins with start */
    insist (addToPorts (0, start));
    
    for (int i = 0; i < numEdges; i++)
    {
      insist (edges [i] && edges [i]->getFrom () == start && edges [i]->getTo ());
      if ((numPorts = check (1, edges [i]->getTo ())))
	return 1;
    }

    /* no cycle through start: clear this search's SEEN marks, retire start */
    unsee ();
    start->setMark (de_Port::CHECKED);
  }
  
  return 0;
  
  exception: return 0;
}
