/*
  sw_Sender.c
  thing to send messages into a switch
  finucane@myri.com (David Finucane)
*/

#include <string.h>

#include "sw_Sender.h"
#include "sw_TestPort.h"
#include "insist.h"

/* class-wide tallies shared by every sw_Sender: send () bumps totalSent,
   the message handlers bump totalReceived */
int sw_Sender::totalSent = 0;
int sw_Sender::totalReceived = 0;

/*
  construct a sender bound to job and network.
  receiveWait: if nonzero, send () throttles on packets that have not
               yet come back
  sendWait:    if nonzero, send () waits for each send's SEND_DONE
  all counters, the per-port counts, and the handler table start cleared.
*/
sw_Sender::sw_Sender (mt_Job*job, mt_Network*network, mt_Graph*graph, int receiveWait, int sendWait)
{
  this->job = job;
  this->graph = graph;
  this->receiveWait = receiveWait;
  this->sendWait = sendWait;
  this->network = network;
  numReceived = numSent = numSendDones = 0;
  memset (counts, 0, sizeof (int) * sw_TestPort::NUM_PORTS);

  /* wait () tests and calls through handlers [subtype], so every slot
     must start null; otherwise a subtype with no registered handler would
     dispatch through an indeterminate member-function pointer. */
  for (unsigned i = 0; i < sizeof (handlers) / sizeof (handlers [0]); i++)
    handlers [i] = 0;
}


/*
  install the default handlers: PROBE messages go to handleProbe.
  returns 1, or 0 if an insist fails.
*/
int sw_Sender::startThread ()
{
  insist (this);

  setHandler (mt_Message::PROBE, mt_fp (sw_Sender::handleProbe));
  return 1;

  exception:
  return 0;
}


/*
  register handler for received messages whose subtype is type.
  returns 1 on success, 0 if type is outside [0, NUM_TYPES).
*/
int sw_Sender::setHandler (int type, sw_SenderHandler handler)
{
  insist (this);
  insist (0 <= type && type < mt_Message::NUM_TYPES);

  handlers [type] = handler;
  return 1;

  exception:
  return 0;
}

/*
  remove any handler for subtype type; wait () then skips dispatch for it.
  returns 1 on success, 0 if type is outside [0, NUM_TYPES).
*/
int sw_Sender::clearHandler (int type)
{
  insist (this);
  insist (0 <= type && type < mt_Message::NUM_TYPES);

  handlers [type] = 0;
  return 1;

  exception:
  return 0;
}

/*
  idle for timeout: arm the network timer and pump events until the
  TIMEOUT event arrives. packets received meanwhile are passed to their
  handlers and freed. returns 1 when the timer fires, 0 on failure.
*/
int sw_Sender::wait (int timeout)
{
  insist (this);

  char*packet;
  int packetLength;

  network->setTimer (job, timeout);
  return wait (mt_Network::TIMEOUT, 0, &packet, &packetLength);

  exception:
  return 0;
}

/*
  pump network events until the requested one arrives.
  event:     the mt_Network event to wait for (e.g. RECEIVE, SEND_DONE,
             TIMEOUT)
  type:      for RECEIVE, the message subtype that must also match
  p, length: filled by network->wait () with the packet and its length
  every SEND_DONE adds network->getSentDoneLength () to numSendDones.
  every accepted received packet is passed to the handler registered for
  its subtype, if any. a packet matching (event, type) is left in *p for
  the caller to free; every other received packet is freed here.
  returns 1 on the requested event; 0 on an unrequested TIMEOUT, a
  network abort, or failure.
*/
int sw_Sender::wait (int event, int type, char**p, int*length)
{
  insist (this);
  insist (p && length);

  int e;

  while ((e = network->wait (p, length)))
  {
    if (e == mt_Network::SEND_DONE)
      numSendDones += network->getSentDoneLength ();

    if (e == mt_Network::RECEIVE)
    {
      insist (*p);

      /* an empty packet has no header to read; drop it like any other
	 malformed packet rather than bailing out and leaking it */
      if (*length <= 0)
      {
	printFormat ("bad length %d", *length);
	network->freePacket (*p);
	continue;
      }

      mt_Message*m;
      m = (mt_Message*)*p;

      int mtype;
      int stype;

      mtype = mt_htons (m->type);
      stype = mt_htons (m->subtype);

      /* stype indexes handlers [], so it must be range checked */
      if (mtype != mt_Message::GM_TYPE || stype < 0 || stype >= mt_Message::NUM_TYPES)
      {
	printFormat ("bad type %d", mtype);
	network->freePacket (*p);
	continue;
      }

      /* handlers such as handleProbe call swap () on the message, so the
	 packet contents may change before it is returned */
      if (handlers [stype])
	(this->*handlers [stype]) (*p, *length);

      if (e == event && type == stype)
	return 1;

      network->freePacket (*p);
    }
    else if (e == event)
      return 1;
    else if (e == mt_Network::TIMEOUT)
      return 0;
  }

  if (network->isAborted ())
  {
    printFormat ("network aborted. is it a deadlock?");
    return 0;
  }

  /* network->wait () returned 0 without an abort: should not happen */
  insist (0);
  exception: return 0;
}

/*
  handler for REPLY messages.
  rejects packets shorter than an mt_ReplyMessage, then calls swap () on
  the message (presumably converting it to host byte order in place), and
  insists the header is a GM_TYPE / REPLY. counts the reply in
  numReceived and totalReceived.
*/
void sw_Sender::handleReply (char*p, int length)
{
  insist (this);
  insist (p);

  /* test the sign first: a negative length cast to unsigned would wrap
     to a huge value and slip past the size check */
  if (length < 0 || (unsigned) length < sizeof (mt_ReplyMessage))
  {
    printFormat ("bad length");
    return;
  }

  mt_ReplyMessage*m;
  m = (mt_ReplyMessage*) p;
  m->swap ();

  insist (m->type == mt_Message::GM_TYPE && m->subtype == mt_Message::REPLY);
  numReceived++;
  totalReceived++;

  exception: return;
}

/*
  handler for PROBE messages (registered by startThread).
  rejects packets shorter than an mt_ProbeMessage, then calls swap () on
  the message (presumably converting it to host byte order in place), and
  insists the header is a GM_TYPE / PROBE. counts the probe in
  numReceived and totalReceived.
*/
void sw_Sender::handleProbe (char*p, int length)
{
  insist (this);
  insist (p);

  /* test the sign first: a negative length cast to unsigned would wrap
     to a huge value and slip past the size check */
  if (length < 0 || (unsigned) length < sizeof (mt_ProbeMessage))
  {
    printFormat ("bad length");
    return;
  }

  mt_ProbeMessage*m;
  m = (mt_ProbeMessage*) p;
  m->swap ();

  insist (m->type == mt_Message::GM_TYPE && m->subtype == mt_Message::PROBE);
  numReceived++;
  totalReceived++;

  exception: return;
}

/*
  copy the probe m into this sender's buffer and send it along route.
  length is the requested packet length. send () raises it to at least
  sizeof (mt_ProbeMessage), rounds it up to a multiple of 4, and caps it
  at the receive MTU. any bytes past the probe header carry whatever the
  buffer last held.
  returns send ()'s result, or 0 if m is null.
*/
int sw_Sender::sendProbe (mt_Route*route, mt_ProbeMessage*m, int length)
{
  insist (this);
  insist (m);

  memcpy (buffer, m, sizeof (mt_ProbeMessage));
  return send (route, buffer, length);

  exception: return 0;
}

/*
  send the packet p, of length bytes, along route.
  the length is raised to at least sizeof (mt_ProbeMessage), rounded up
  to a multiple of 4, then capped at network->getReceiveMtu ().
  receiveWait: when numSent - numReceived has reached the network's
    receive-buffer count, first wait up to TIMEOUT for one PROBE to return.
  if network->send () refuses the packet, wait up to TIMEOUT for a
  SEND_DONE and retry.
  sendWait: after sending, wait up to TIMEOUT for a SEND_DONE.
  receiveWait && sendWait: also wait up to TIMEOUT for the returning PROBE,
    so numSent == numReceived on exit.
  returns 1 on success, 0 on any timeout or failure.
*/
int sw_Sender::send (mt_Route*route, char*p, int length)
{
  insist (this);
  insist (p);
  insist (length > 0);  

  if (length < (int) sizeof (mt_ProbeMessage))
    length = sizeof (mt_ProbeMessage);
  
  if (length % 4) length += 4 - length % 4;
  /* NOTE(review): if the MTU is not a multiple of 4, this cap undoes the
     rounding above */
  if (length > network->getReceiveMtu ())
    length = network->getReceiveMtu ();  

  char*pp;
  int ll;  

  if (receiveWait)
  {  
    int numReceiveBuffers;
    numReceiveBuffers = network->getNumReceiveBuffers ();

    insist (numSent >= numReceived);
    /* throttle: never have more packets awaiting return than there are
       receive buffers; wait for one PROBE to come back first */
    if (numSent - numReceived >= numReceiveBuffers && !wait (TIMEOUT, mt_Network::RECEIVE, 1))
      return 0;
  
    insist (numSent - numReceived < numReceiveBuffers);
  }
  
  //printFormat ("sending to route %s", route->toString ());
  
  /* a refused send presumably means the send queue is full: drain one
     SEND_DONE and try again */
  while (!network->send (job, route, p, length))
  {
    /* NOTE(review): this checks packets awaiting return, not sends awaiting
       SEND_DONE (numSendDones); confirm it holds when receiveWait is off */
    insist (numSent > numReceived);
    
    network->setTimer (job, TIMEOUT);
  
    if (!wait (mt_Network::SEND_DONE, 0, &pp, &ll))
    { 
      printFormat ("send failed (dropped send done event after blocking on send. numSendDones %d)", numSendDones);
      return 0;
    }
    network->clearTimer (job);
  }

  numSent++;
  totalSent++;  
  
  if (sendWait)
  { 
    network->setTimer (job, TIMEOUT);
  
    if (!wait (mt_Network::SEND_DONE, 0,&pp, &ll))
    { 
      printFormat ("send failed (dropped send done event. numSendDones %d)", numSendDones);
      return 0;
    }
    network->clearTimer (job);
  }

  /* the PROBE may already have been handled during the SEND_DONE wait
     above, in which case the counts match and nothing is left to do */
  if (receiveWait && sendWait && numSent != numReceived)
  {  
    insist (numSent == numReceived + 1);
    
    network->setTimer (job, TIMEOUT);
    
    if (!wait (mt_Network::RECEIVE, mt_Message::PROBE, &pp, &ll))
    {
      printFormat ("receive failed (dropped receive event)");
      return 0;
    }
    /* wait () leaves the matching packet for us to release */
    else network->freePacket (pp);
    network->clearTimer (job);
    insist (!receiveWait || numSent == numReceived);
  }

  return 1;
  exception: return 0;
}

/*
  build into route a path that leaves this host, goes out port of _switch,
  and comes back to this host. the final follow () check insists that the
  route ends at job->getNode ().
  the node attached to port decides the shape:
    nothing, or a host other than this one: no route, returns 0
    this host:              mt_Route (&r1, 0) on an empty r1
    _switch itself (a cable between two of its ports): out port and back
    the port we entered on: route 0 followed by its inversion
    any other switch:       out port, a 0 hop, then back
  hops are relative to _switch->getIn (). the mt_Route constructor forms
  used here are presumably (base, hop[, tail]) and (base, tail); confirm
  against mt_Route.
  returns 1 on success, 0 otherwise. *route may already be overwritten
  when the follow () check fails.
*/
int sw_Sender::getSwitchRoute (mt_Route*route, mt_Node*_switch, int port)
{
  insist (this);
  insist (route && _switch && port >= 0 && port < _switch->getMaxNodes ());  

  mt_Node*n;
  n = _switch->getNode (port);
    
  if (!n)
    return 0;
  
  if (n == job->getNode ())
  {
    mt_Route r1;
    mt_Route r2 (&r1, 0);
    *route = r2;
  }
  else if (n->isHost ())
    return 0;
  else if (n == _switch)
  {
    /* r1 is the path out port, inverted to serve as the return leg */
    mt_Route r1 (_switch->getRoute (0), port - _switch->getIn ());
    r1.invert ();
    mt_Route r2 (_switch->getRoute (0), port - _switch->getIn (), &r1);
    *route = r2;
  }
  else if (port == _switch->getIn ())
  {
    /* port - getIn () is 0 here */
    mt_Route r1 (_switch->getRoute (0), port - _switch->getIn ());
    r1.invert ();
    mt_Route r2 (_switch->getRoute (0), &r1);
    *route = r2;
  }  
  else
  {
    mt_Route r1 (_switch->getRoute (0), port - _switch->getIn ());
    r1.invert ();
    mt_Route r2 (_switch->getRoute (0), port - _switch->getIn ());
    /* presumably the 0 hop turns around at the neighbouring switch */
    mt_Route r3 (&r2, 0, &r1);
    *route = r3;
  }
  
  mt_Node*fn;
  int fin;
  
  /* sanity check: the built route must lead back to this host */
  insist (job->getNode ()->follow (route, &fn, &fin) && fn == job->getNode ());
  return 1;
  exception: return 0;
}

/*
  build a route that starts with _switch's route 0 and then "snakes"
  through the switch, presumably using relative port hops:
    - one hop of 1 for each switch-attached port from getIn () upward
    - one hop of (first - last), where last is the highest such port (or
      getIn () if there is none) and first is the lowest connected port
      below getIn () (or getIn () if there is none)
    - one hop of 1 for each switch-attached port from first up to
      getIn ()
  returns 1 on success, 0 if an insist fails.
*/
int sw_Sender::getSnakeyRoute (mt_Route*route, mt_Node*_switch)
{
  insist (this);
  insist (route && _switch);

  int maxNodes;
  maxNodes = _switch->getMaxNodes ();
  
  int in;
  in = _switch->getIn ();
  insist (in >= 0 && in < maxNodes);

  *route = *_switch->getRoute (0);

  int last, first;

  last = in;
  
  for (int i = in; i < maxNodes; i++)
  {
    if (_switch->getNode (i) && _switch->getNode (i)->isSwitch ())
    {
      route->append (1);
      last = i;
    }
  }

  /* NOTE(review): this stops at any connected port (host or switch), while
     both loops only count switch ports; confirm that a host-attached port
     below getIn () is meant to be the wrap target */
  for (first = 0; first < in && !_switch->getNode (first) ; first++);
  
  route->append (first - last);

  for (int i = first; i < in; i++)
  {
    if (_switch->getNode (i) && _switch->getNode (i)->isSwitch ())
      route->append (1);
  }
  return 1;
  exception: return 0;
}



/*
  build into route a one-way path from this host out port of _switch.
  hops are relative to _switch->getIn ().
    port outside [0, NUM_PORTS):   route 0 plus the hop, with no node lookup
    port is the one we came in on: route 0 unchanged
    port is empty or a switch:     route 0 plus the hop
    port is this host:             an empty route
    port is any other host:        no route, returns 0
  returns 1 when route was filled, 0 otherwise.
*/
int sw_Sender::getPortRoute (mt_Route*route, mt_Node*_switch, int port)
{
  insist (this);
  insist (route && _switch);

  mt_Node*neighbor;

  if (!(port >= 0 && port < sw_TestPort::NUM_PORTS))
  {
    mt_Route outOfRange (_switch->getRoute (0), port - _switch->getIn ());
    *route = outOfRange;
    return 1;
  }

  neighbor = _switch->getNode (port);

  if (port == _switch->getIn ())
  {
    *route = *_switch->getRoute (0);
    return 1;
  }

  if (neighbor == 0 || neighbor->isSwitch ())
  {
    mt_Route outward (_switch->getRoute (0), port - _switch->getIn ());
    *route = outward;
    return 1;
  }

  if (neighbor == job->getNode ())
  {
    mt_Route empty;
    *route = empty;
    return 1;
  }

  return 0;

  exception: return 0;
}


/*
  build into route a path from route 0 out port of _switch, using the
  four-argument mt_Route form with trailing (0, 0).
  only ports attached to a switch, other than the input port, qualify.
  returns 1 when route was filled, 0 otherwise.
*/
int sw_Sender::getDeadlockRoute (mt_Route*route, mt_Node*_switch, int port)
{
  insist (this);
  insist (route && _switch && port >= 0 && port < _switch->getMaxNodes ());

  mt_Node*peer;
  peer = _switch->getNode (port);

  if (peer == 0 || !peer->isSwitch () || port == _switch->getIn ())
    return 0;

  {
    mt_Route path (_switch->getRoute (0), port - _switch->getIn (), 0, 0);
    *route = path;
  }
  return 1;

  exception: return 0;
}

/* number of packets this sender has handed to the network */
int sw_Sender::getNumSent ()
{
  insist (this);
  return this->numSent;

  exception:
  return 0;
}

/* number of messages this sender's handlers have accepted */
int sw_Sender::getNumReceived ()
{
  insist (this);
  return this->numReceived;

  exception:
  return 0;
}

/* packets sent by all sw_Sender instances combined */
int sw_Sender::getTotalSent ()
{
  return sw_Sender::totalSent;
}

/* messages accepted by all sw_Sender instances combined */
int sw_Sender::getTotalReceived ()
{
  return sw_Sender::totalReceived;
}

/*
  drain SEND_DONE events until wait () fails. The timer is armed once for
  the whole drain, so this lasts at most about TIMEOUT.
  returns 1 unless an insist fails.
*/
int sw_Sender::flushSends ()
{
  insist (this);

  char*packet;
  int packetLength;

  network->setTimer (job, TIMEOUT);
  for (;;)
  {
    if (!wait (mt_Network::SEND_DONE, 0, &packet, &packetLength))
      break;
  }
  return 1;

  exception:
  return 0;
}

/* wait, under one TIMEOUT, for every PROBE not yet received to come back */
int sw_Sender::waitForReceives ()
{
  int outstanding = numSent - numReceived;
  return wait (TIMEOUT, mt_Network::RECEIVE, outstanding);
}

/* wait, under one TIMEOUT, for SEND_DONE on every send not yet completed */
int sw_Sender::waitForSends ()
{
  int outstanding = numSent - numSendDones;
  return wait (TIMEOUT, mt_Network::SEND_DONE, outstanding);
}

/*
  wait for up to count events of kind type (mt_Network::RECEIVE or
  mt_Network::SEND_DONE) under a single timer of length timeout.
  RECEIVE waits match PROBE messages; each matched packet is freed here.
  SEND_DONE waits stop early once numSendDones has caught up with numSent.
  returns 1 on success, 0 if an event is missed or an insist fails.
*/
int sw_Sender::wait (int timeout, int type, int count)
{
  int ll;
  char*pp;

  insist (this);
  insist (timeout >= 0);

  network->setTimer (job, timeout);
  insist (numSent >= numReceived);

  for (int i = 0; i < count; i++)
  {
    /* one SEND_DONE event may retire several sends (numSendDones grows by
       getSentDoneLength ()), so the tally can jump past numSent; stop as
       soon as it catches up instead of waiting for events that never come */
    if (type == mt_Network::SEND_DONE && numSendDones >= numSent)
      break;

    if (!wait (type, mt_Message::PROBE, &pp, &ll))
    {
      printFormat ("receive failed (dropped %s event)", type == mt_Network::RECEIVE ? "receive" : "send done");
      return 0;
    }

    /* only a RECEIVE hands back a packet that we own; send () likewise
       frees nothing after a SEND_DONE wait */
    if (type == mt_Network::RECEIVE)
      network->freePacket (pp);
  }

  network->clearTimer (job);
  return 1;
  exception: return 0;
}

/* give sends still in flight up to TIMEOUT to report SEND_DONE */
sw_Sender::~sw_Sender ()
{
  waitForSends ();
}

  
/*
  copy the per-port counts (sw_TestPort::NUM_PORTS ints) into the
  caller's array. returns 1, or 0 if counts is null.
*/
int sw_Sender::getCounts (int*counts)
{
  insist (this);
  insist (counts);

  /* the parameter shadows the member, hence this-> */
  memcpy (counts, this->counts, sw_TestPort::NUM_PORTS * sizeof (int));
  return 1;

  exception:
  return 0;
}
