/*
  trouble.c
  send a mapper message along a route and print reply.
  finucane@myri.com (David Finucane)
*/

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "mt_Network.h"
#include "mt_Host.h"
#include "mt_Message.h"

#include "insist.h"

/*
  Output callback handed to mt_Component::initialize (): writes one
  line of text to stdout and flushes so output appears immediately.
*/
static void printLine (char*s)
{
  fputs (s, stdout);
  fputc ('\n', stdout);
  fflush (stdout);
}

/*
  Trouble: a job that opens a network unit, sends a number of scout
  or probe messages along a fixed route, and prints whatever comes
  back (or a timeout notice) for each one.
*/
class Trouble : public mt_Job
{
  private:
  int size;           /* bytes handed to send (); argv [3], raised to at least the message struct size */
  int count;          /* number of messages to send; argv [2] */
  int unit;           /* unit number passed to the network open (); argv [1], checked to be 0..15 */
  int scout;          /* non-zero when argv [5] is "scout"; otherwise probes are sent */
  int timeout;        /* value passed to setTimer () before each wait; argv [4] (units not visible here) */
  mt_Route route;     /* hops taken from argv [6] onwards */
  mt_Address address; /* filled in by the network open () call in start () */
  
  public:
  /* Fills the fields above from the command line; returns 1 on success, 0 otherwise. */
  int parseArgs (int argc, char*argv []);
  /* Prints the expected command-line syntax. */
  void usage ();
  Trouble (mt_Node*node, mt_Network*network);
  /* Send `count` scout messages, waiting for a reply after each. */
  int sendScouts ();
  /* Send `count` probe messages, waiting for each to come back. */
  int sendProbes ();
  /* Arm the timer and process events until a REPLY/PROBE arrives or the timer expires. */
  int wait (int phase, int type);

  /* NOTE(review): these four are presumably overrides of mt_Job virtuals -- confirm against mt_Job. */
  virtual void dump (FILE*fp);
  virtual int willWait ();
  virtual int start ();
  virtual void receive (int event, char*p, int length);
};


/*
  Construct the job on top of mt_Job.  The integer settings start at
  zero so that they never hold indeterminate values when parseArgs ()
  returns early or is not called at all.
*/
Trouble::Trouble (mt_Node*node, mt_Network*network) :
  mt_Job (node, network),
  size (0),
  count (0),
  unit (0),
  scout (0),
  timeout (0)
{
}

/*
  Read the job settings from the command line:

    argv [1]   unit number, must be 0..15
    argv [2]   count, number of messages to send
    argv [3]   size in bytes, must be 1..4000
    argv [4]   timeout, must be positive
    argv [5]   "scout" selects scouts; anything else selects probes
    argv [6..] route hops, one per argument

  Returns 1 when every value is acceptable.  Returns 0 when there are
  too few arguments, when count or size parse to zero (atoi () also
  yields zero for non-numeric text), or when a range check fails; the
  range checks report the offending argument through insistp.
*/
int Trouble::parseArgs (int argc, char*argv [])
{
  if (argc < 6)
    return 0;

  unit = atoi (argv [1]);

  if (!(count = atoi (argv [2])))
    return 0;

  if (!(size = atoi (argv [3])))
    return 0;

  timeout = atoi (argv [4]);

  scout = !strcmp (argv [5], "scout");

  route.empty ();

  for (int i = 6; i < argc; i++)
    route.append ((char) atoi (argv [i]));

  insistp (unit >= 0 && unit < 16, ("bad unit number %s", argv [1]));
  /* the timeout is argv [4]; the message used to echo the unit argument */
  insistp (timeout > 0, ("bad timeout %s", argv [4]));
  insistp (count >= 0, ("bad count %s", argv [2]));
  insistp (size > 0 && size <= 4000, ("bad size %s", argv [3]));

  return 1;
  exception: return 0;
}

/*
  Print the command-line syntax expected by parseArgs ().
*/
void Trouble::usage ()
{
  printFormat ("usage: <unit> <count> <size> <timeout> "
               "<scout | probe> <hop> <hop> <hop> ...");
}

/*
  No-op: nothing is written to fp.
*/
void Trouble::dump (FILE*fp)
{
}

/*
  Always returns 1.
  NOTE(review): presumably tells the framework that this job blocks
  waiting for network events -- confirm against mt_Job.
*/
int Trouble::willWait ()
{
  return 1;
}


/*
  Run the job: open the configured unit, send the scouts or probes
  selected on the command line, then close the unit again.

  Returns 0 when the network cannot be opened (or the insist check
  fails); otherwise returns the result of the send routine.  The
  success path used to fall through to the exception label and so
  reported failure unconditionally.
*/
int Trouble::start () 
{
  /* name and type are outputs of open (); they are not used further here */
  char name [mt_Network::HOSTNAME_LENGTH + 1];
  int type;
  /* no initialiser: insist () may jump to the exception label below */
  int sent;

  insist (this);
  
  if (!getNetwork ()->open (this, &address, name, &type, unit))
    return 0;  

  if (scout)
    sent = sendScouts ();
  else
    sent = sendProbes ();

  getNetwork ()->close (this);
  return sent;
  exception: return 0;
}

/*
  Send `count` scout messages along `route`, waiting after each one
  for a REPLY (or for the timer to expire).  Each scout carries the
  inverted route and this host's address, and uses the loop index as
  its phase.

  `size` is raised to at least sizeof (mt_ScoutMessage).  When it is
  larger, the message is sent from a zero-padded buffer of `size`
  bytes; the bytes used to be read straight past the end of the
  message object on the stack.

  Returns 1 once all messages have been attempted, 0 if the insist
  check fails.  Individual send or wait failures are only reported.
*/
int Trouble::sendScouts ()
{
  /* declared before insist () and without an initialiser so that the
     jump to the exception label does not cross an initialisation */
  char*buffer;

  insist (this);

  if (size < (int) sizeof (mt_ScoutMessage))
    size = sizeof (mt_ScoutMessage);

  /* the padding beyond the message only needs to be cleared once */
  buffer = new char [size];
  memset (buffer, 0, size);
  
  for (int i = 0; i < count; i++)
  {
    mt_Route rr (route);
    rr.invert ();
    
    mt_ScoutMessage m (0, i, &rr, &address, 1);
    m.swap ();

    memcpy (buffer, (char*) &m, sizeof (mt_ScoutMessage));
    
    printFormat ("sending scout on route %s", route.toString ());

    if (!getNetwork ()->send (this, &route, buffer, size))
      printFormat ("send failed");
    
    wait (i, mt_Message::REPLY);
  }

  delete [] buffer;
  return 1;
  exception: return 0;
}

/*
  Send `count` probe messages along `route`, waiting after each one
  for a PROBE to arrive (or for the timer to expire).  The loop index
  is used as the phase of each probe.

  `size` is raised to at least sizeof (mt_ProbeMessage).  When it is
  larger, the message is sent from a zero-padded buffer of `size`
  bytes; the bytes used to be read straight past the end of the
  message object on the stack.

  Returns 1 once all messages have been attempted, 0 if the insist
  check fails.  Individual send or wait failures are only reported.
*/
int Trouble::sendProbes ()
{
  /* declared before insist () and without an initialiser so that the
     jump to the exception label does not cross an initialisation */
  char*buffer;

  insist (this);

  if (size < (int) sizeof (mt_ProbeMessage))
    size = sizeof (mt_ProbeMessage);

  /* the padding beyond the message only needs to be cleared once */
  buffer = new char [size];
  memset (buffer, 0, size);
  
  for (int i = 0; i < count; i++)
  {
    mt_ProbeMessage m (0, i);
    m.swap ();

    memcpy (buffer, (char*) &m, sizeof (mt_ProbeMessage));
    
    printFormat ("sending probe on route %s", route.toString ());

    if (!getNetwork ()->send (this, &route, buffer, size))
      printFormat ("send failed");
    
    wait (i, mt_Message::PROBE);
  }

  delete [] buffer;
  return 1;
  exception: return 0;
}

/*
  Arm the timer with `timeout`, then process network events until a
  REPLY or PROBE message has been handled or the timer expires.

    phase  phase number the incoming message is expected to carry
    type   subtype the caller expects (mt_Message::REPLY or PROBE)

  A phase or subtype mismatch is reported but still ends the wait
  successfully.  Messages with any other subtype, and messages too
  short to hold the structure they are about to be read as, are
  reported and the wait continues.

  Returns 1 after a REPLY or PROBE has been handled (the timer is
  cleared first).  Returns 0 on timeout, when the network wait ()
  yields 0, or when an insist check fails.
*/
int Trouble::wait (int phase, int type)
{
  insist (this);
  
  /* no initialisers: insist () may jump to the exception label below */
  char*p;
  int length;
  int e;
      
  getNetwork ()->setTimer (this, timeout);
      
  while ((e = getNetwork ()->wait (&p, &length)))
  {
    if (e == mt_Network::SEND_DONE)
    {
      printFormat ("send done");
    }
    else if (e == mt_Network::TIMEOUT)
    {
      printFormat ("timeout expired");
      return 0;
    }
    else if (e == mt_Network::RECEIVE)
    {
      insist (p);
      insist (length);

      printFormat ("received a message");

      /* the common header must be complete before its fields are read */
      if (length < (int) sizeof (mt_Message))
      {
        printFormat ("short message (%d bytes)", length);
        continue;
      }
      
      mt_Message*m = (mt_Message*) p;
      int stype = mt_htons (m->subtype);
      
      switch (stype)
      {
        case mt_Message::REPLY:
          {
            if (length < (int) sizeof (mt_ReplyMessage))
            {
              printFormat ("short reply (%d bytes)", length);
              break;
            }

            mt_ReplyMessage*rm = (mt_ReplyMessage*) m;
            rm->swap ();
            mt_Address a;
            a.fromBytes (rm->address);
            
            printFormat ("reply from %s address %s. phase %d", rm->getHostname (), a.toString (), rm->phase);
            if (rm->phase != phase)
              printFormat ("bad phase (expected %d)", phase);
            if (stype != type)
              printFormat ("bad type %d expected %d", stype, type);
              
            getNetwork ()->clearTimer (this);
            return 1;
          }
        case mt_Message::PROBE:
          {
            if (length < (int) sizeof (mt_ProbeMessage))
            {
              printFormat ("short probe (%d bytes)", length);
              break;
            }

            mt_ProbeMessage*pm = (mt_ProbeMessage*) m;
            pm->swap ();
            printFormat ("probe phase %d", pm->phase);
            if (pm->phase != phase)
              printFormat ("bad phase (expected %d)", phase);
            if (stype != type)
              printFormat ("bad type %d expected %d", stype, type);
            getNetwork ()->clearTimer (this);
            return 1;
          }
        default:
          printFormat ("bad type %d expected %d", stype, type);
          break;
      }
    }
    else
    {
      insist (0);
    }
  }

  printFormat ("timeout expired");
  exception: return 0;
}

/*
  No-op: the arguments are ignored.  This job pulls its events itself
  through getNetwork ()->wait () inside Trouble::wait ().
*/
void Trouble::receive (int event, char*p, int length)
{
}

/*
  Entry point: set up the library's output callback, create a host
  and the Trouble job, then either run the job or print the usage
  text when the command line is not acceptable.

  Exit status is 1 for a command-line error (previously 0, which made
  a bad invocation look like success) and 0 otherwise.
*/
int main (int argc, char*argv [])
{
  mt_Component::initialize (printLine);
  mt_Host host ("fake", "-");
  
  /* NOTE(review): the job is never deleted; who owns it after
     construction is not visible here, so it is left to process exit. */
  Trouble*b = new Trouble (&host, mt_getNetwork ());

  if (!b->parseArgs (argc, argv))
  {
    b->usage ();
    return 1;
  }

  b->start ();
  return 0;
}
