#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "cmd.h"

int pvmrsh_debug = 0;
static int num_children = 0;
static void (*child_exit_hook)();
static void (*child_stdout_hook)();
static void (*child_stderr_hook)();

static unsigned char *server_status;
#define STATUS_UP	01	/* server is up, no procs running */
#define STATUS_RUNNING	02	/* server is up, procs are running */
#define STATUS_DIED	03	/* server is no longer running */
static int num_servers = 0;

#define CLNT_ERR(x) (SRC_CLNT | (x))

/*
 * pvm_error_code -- turn a pvm library return value into a universal
 * status code.  a negative value (a pvm failure) has the low 8 bits of
 * its magnitude tagged with SRC_CLNT_PVM; anything else is returned
 * unchanged.
 */
int
pvm_error_code (x)
int x;
{
    return x >= 0 ? x : (((-x) & 0xff) | SRC_CLNT_PVM);
}


/***********************************************************************
 *                                                                     *
 *                               pvm glue                              *
 *                                                                     *
 ***********************************************************************/

/* unpack a single int via getnint(); its status is passed straight back */
static int
getint (ip)
int *ip;
{
    int rc;

    rc = getnint (ip, 1);
    return rc;
}


/* unpack a single long via getnlong(); its status is passed straight back */
static int
getlongint (lp)
long *lp;
{
    int rc;

    rc = getnlong (lp, 1);
    return rc;
}

/*
 * pack the single int i into the outgoing message via putnint().
 * returns putnint()'s status.
 *
 * i is declared explicitly: an undeclared identifier-list parameter
 * (implicit int) is a constraint violation since C99.
 */
static int
putint (i)
int i;
{
    return putnint (&i, 1);
}

/* pack the single long l into the outgoing message via putnlong() */
static int
putlongint (l)
long l;
{
    long tmp = l;

    return putnlong (&tmp, 1);
}

/*
 * wait on a reply from a command.  while waiting, process any
 * asynchronous "signal" messages that arrive (child stdout/stderr,
 * process exit, server exit).  a command reply always ends the wait.
 * if wait is non-NULL, a process-exit or server-exit message ends it
 * too, and the exiting instance is described in *wait.
 *
 * return value
 * 0 for a plain REPL_OK reply, -1 if *wait was filled in, otherwise the
 * reply type itself: an OK reply that carries a payload (REPL_STAT,
 * REPL_READ, ...) which the caller must unpack, or an error code.
 */

/* describes the exit that made get_reply() return -1 */
struct pvmwait {
    char proc[33];		/* process name reported by rcvinfo() */
    int inum;			/* instance number reported by rcvinfo() */
    int status;			/* exit status; 0 for a server exit */
};

/*
 * fill in *w for an exit reported by instance inum.  proc must point to
 * at least sizeof w->proc bytes (get_reply()'s proc[33]); the copy is
 * always NUL-terminated.
 */
static void
fill_wait (w, proc, inum, status)
struct pvmwait *w;
char *proc;
int inum;
int status;
{
    memcpy (w->proc, proc, sizeof w->proc);
    w->proc[sizeof w->proc - 1] = '\0';
    w->inum = inum;
    w->status = status;
}

/*
 * unpack the payload of a MSG_STDOUT / MSG_STDERR message (an int byte
 * count followed by that many bytes) and pass it to hook, if set.
 * what names the message type in diagnostics.
 */
static void
deliver_output (what, hook, proc, inum)
char *what;
void (*hook)();
char *proc;
int inum;
{
    char buf[MSGBUFSIZE];
    int nbytes;

    if (getint (&nbytes) < 0) {
	fprintf (stderr, "bad %s from %d: end of buffer\n", what, inum);
	return;
    }
    /* the count comes from the remote side; never trust it to fit buf */
    if (nbytes < 0 || nbytes > MSGBUFSIZE) {
	fprintf (stderr, "bad %s from %d: byte count %d out of range\n",
		 what, inum, nbytes);
	return;
    }
    if (getbytes (buf, nbytes) < 0) {
	fprintf (stderr, "bad %s from %d: end of buffer\n", what, inum);
	return;
    }
    if (pvmrsh_debug)
	fprintf (stderr, "<<< %s (%d, \"%.*s\")\n", what, nbytes, nbytes, buf);
    if (hook)
	(*hook) (proc, inum, buf, nbytes);
}

static int
get_reply (wait)
struct pvmwait *wait;
{
    char proc[33];
    int inum;
    int msgtype;
    int status;

    while (1) {
	/* NOTE(review): return values of rcv()/rcvinfo() are not checked */
	rcv (0);
	rcvinfo (0, 0, proc, &inum);
	if (getint (&msgtype) < 0)
	    continue;

	if ((msgtype & MSG_REPLY_MASK) != MSG_SIGNAL) {
	    /* a reply to the command we sent: hand it back to the caller */
	    if ((msgtype & MSG_TYPE_MASK) == MSG_OK_REPLY) {
		/*
		 * don't unpack anything here: any payload (REPL_STAT,
		 * REPL_READ, ...) belongs to the caller, and turning on
		 * debug output must not change what the caller reads.
		 */
		if (pvmrsh_debug)
		    fprintf (stderr, "<<< MSG_OK_REPLY [%d] (%d)\n",
			     inum, msgtype & 0xff);
		return msgtype == REPL_OK ? 0 : msgtype; /* success */
	    }
	    if (pvmrsh_debug)
		fprintf (stderr, "<<< MSG %08x [%d]\n", (unsigned) msgtype, inum);
	    return msgtype;	/* error */
	}

	switch (msgtype) {
	case MSG_STDOUT:
	    deliver_output ("MSG_STDOUT", child_stdout_hook, proc, inum);
	    break;
	case MSG_STDERR:
	    deliver_output ("MSG_STDERR", child_stderr_hook, proc, inum);
	    break;
	case MSG_SRVR_EXIT:
	    /* this message carries no status, so none is printed */
	    if (pvmrsh_debug)
		fprintf (stderr, "<<< MSG_SRVR_EXIT [%d]\n", inum);
	    if (inum >= 0 && inum < num_servers) {
		/*
		 * if server dies, any children it has also are
		 * assumed to be dead, so decrement the child count.
		 */
		if (server_status[inum] == STATUS_RUNNING) {
		    /* XXX should call child_exit_hook here? */
		    --num_children;
		}
		server_status[inum] = STATUS_DIED;
	    }
	    if (wait) {
		fill_wait (wait, proc, inum, 0);
		return -1;
	    }
	    break;
	case MSG_PROC_EXIT:
	    if (getint (&status) < 0) {
		fprintf (stderr,
			 "bad MSG_PROC_EXIT from %d: end of buffer\n",
			 inum);
		break;
	    }
	    if (pvmrsh_debug)
		fprintf (stderr, "<<< MSG_PROC_EXIT [%d] (%d)\n", inum, status);
	    if (inum >= 0 && inum < num_servers) {
		/*
		 * sanity check.  don't decrement count of running
		 * children if we've done it already.  This might
		 * happen if we send a quit msg while a child is
		 * running, for instance.
		 */
		if (server_status[inum] == STATUS_RUNNING)
		    --num_children;
		server_status[inum] = STATUS_UP;
	    }
	    if (child_exit_hook)
		(*child_exit_hook) (proc, inum, status);
	    if (wait) {
		fill_wait (wait, proc, inum, status);
		return -1;
	    }
	    break;
	default:
	    /* unknown asynchronous message: ignore it, but say so */
	    if (pvmrsh_debug)
		fprintf (stderr, "<<< unknown signal %08x [%d]\n",
			 (unsigned) msgtype, inum);
	    break;
	}
    }
}


/*
 * send the message packed since the last initsend() to the pvmrshd
 * server with pvm instance number instno.
 * returns: 0 on success
 * otherwise, a universal error code
 */

static int
send_cmd (instno)
int instno;
{
    /* reject instance numbers cmd_start() could never have handed out */
    if (instno < 0 || instno >= num_servers)
	return CLNT_ERR (ERR_BADINST);
    return pvm_error_code (snd ("pvmrshd", instno, 0));
}


/***********************************************************************
 *                                                                     *
 *                          callable routines                          *
 *                                                                     *
 ***********************************************************************/

/*
 * start up a pvmrshd server on the named machine
 * if result == 0, call succeeded, and the pvm instance number to be used
 *                 in future calls to cmd_xxxx () routines is placed in
 *                 *instno.
 * otherwise, the universal error code is returned
 */

int
cmd_start (machine, instno)
char *machine;
int *instno;
{
    int x;

    if ((x = initiateM ("pvmrshd", machine)) < 0) {
	/* couldn't start server */
	return pvm_error_code (x);
    }
    if (x >= num_servers) {
	int new_num_servers;
	int i;

#define max(a,b) ((a) > (b) ? (a) : (b))

	if (num_servers == 0) {
	    new_num_servers = max (100, x);
	    server_status = (unsigned char *) malloc (new_num_servers);
	}
	else {
	    new_num_servers = max (num_servers * 2, x);
	    server_status = (unsigned char *)
		realloc (server_status, new_num_servers);
	}
	for (i = num_servers ; i < new_num_servers ; ++i)
	    server_status[i] = '\0';
	num_servers = new_num_servers;
    }
    server_status[x] = STATUS_UP;
    *instno = x;
    return 0;
}


/*
 * create a file on the remote machine
 * returns: 0 if successful
 *          otherwise, a universal error code
 */

int
cmd_create (instno, filename, filesize, modtime)
int instno;
char *filename;
long filesize;
long modtime;
{
    int x;

    /* both filesize and modtime are long: %ld for each */
    if (pvmrsh_debug)
	fprintf (stderr, ">>> CMD_CREATE (%d, %s, %ld, %ld)\n",
		 instno, filename, filesize, modtime);
    if ((x = initsend ()) < 0 ||
	(x = putint (CMD_CREATE)) < 0 ||
	(x = putstring (filename)) < 0 ||
	(x = putlongint (filesize)) < 0 ||
	(x = putlongint (modtime)) < 0)
	return pvm_error_code (x);
    if ((x = send_cmd (instno)) != 0)
	return x;
    return (get_reply (0));
}

/*
 * cmd_open -- ask the server on instance instno to open filename for
 * reading.
 * returns: 0 if successful
 *          otherwise, a universal error code
 */

int
cmd_open (instno, filename)
int instno;
char *filename;
{
    int rc;

    if (pvmrsh_debug)
	fprintf (stderr, ">>> CMD_OPEN (%d, %s)\n", instno, filename);

    /* pack the request; stop at the first pvm failure */
    rc = initsend ();
    if (rc >= 0)
	rc = putint (CMD_OPEN);
    if (rc >= 0)
	rc = putstring (filename);
    if (rc < 0)
	return pvm_error_code (rc);

    rc = send_cmd (instno);
    return rc != 0 ? rc : get_reply (0);
}

/*
 * stat a file on the remote machine, and return its size and modtime
 * returns: 0 if successful,
 *          otherwise, a universal error code
 */

int
cmd_stat (instno, filename, filesize, modtime)
int instno;
char *filename;
long *filesize;
unsigned long *modtime;
{
    int x;

    if (pvmrsh_debug)
	fprintf (stderr, ">>> CMD_STAT (%d, %s)\n", instno, filename);
    if ((x = initsend ()) < 0 ||
	(x = putint (CMD_STAT)) < 0 ||
	(x = putstring (filename)) < 0)
	return pvm_error_code (x);
    if ((x = send_cmd (instno)) != 0)
	return x;
    /*
     * NOTE(review): a plain REPL_OK makes get_reply() return 0, which is
     * passed through as success without *filesize / *modtime being set.
     */
    if ((x = get_reply (0)) != REPL_STAT)
	return x;
    /* getlongint() takes a long *, so convert the modtime pointer explicitly */
    if ((x = getlongint (filesize)) < 0 ||
	(x = getlongint ((long *) modtime)) < 0) {
	fprintf (stderr, "bad REPL_STAT from %d: end of buffer\n", instno);
	return pvm_error_code (x);
    }
    if (pvmrsh_debug)
	fprintf (stderr, "<<< REPL_STAT (%ld, %lu)\n", *filesize, *modtime);
    return 0;
}

/*
 * write nbytes bytes from buf to the currently open file on the remote
 * machine
 * returns a universal status code (0 on success)
 */

int
cmd_write (instno, buf, nbytes)
int instno;
char *buf;
int nbytes;
{
    int x;

    if (pvmrsh_debug)
	fprintf (stderr, ">>> CMD_WRITE (%d, %d)\n", instno, nbytes);
    /*
     * each assignment is parenthesized before the comparison so that x
     * holds the pvm status, not the result of "< 0".
     */
    if ((x = initsend ()) < 0 || (x = putint (CMD_WRITE)) < 0 ||
	(x = putint (nbytes)) < 0 || (x = putbytes (buf, nbytes)) < 0)
	return pvm_error_code (x);
    if ((x = send_cmd (instno)) != 0)
	return x;
    return (get_reply (0));
}

/*
 * read up to nbytes bytes from the currently open file on the remote
 * machine.
 *
 * returns: if return value > 0, the number of bytes read
 *          otherwise, a universal status code
 */

int
cmd_read (instno, buf, nbytes)
int instno;
char *buf;
int *nbytes;
{
    int x;

    if (pvmrsh_debug)
	fprintf (stderr, ">>> CMD_READ (%d, %d)\n", instno, nbytes);
    if ((x = initsend ()) < 0 || (x = putint (CMD_READ)) < 0 ||
	(x = putint (*nbytes)) < 0)
	return pvm_error_code (x);
    if ((x = send_cmd (instno)) < 0)
	return x;
    if ((x = get_reply (0)) != REPL_READ)
	return x;
    if ((x = getint (nbytes)) < 0)
	return pvm_error_code (x);
    if (*nbytes > 0) {
	if ((x = getbytes (buf, *nbytes)) < 0)
	    return pvm_error_code (x);
	if (pvmrsh_debug)
	    fprintf (stderr, "<<< REPL_READ (%d bytes)\n", *nbytes);
	return 0;
    }
    return 0;
}

/*
 * start up a command on the indicated host, passing argument vector
 * argv[0] .. argv[argc - 1]
 *
 * returns: 0 on success, universal status code on failure.
 *
 * XXX needs to get a reply from server, in case its fork/exec fails
 * due to resource limitations or whatever.
 */

int
cmd_spawn (instno, argc, argv)
int instno;
int argc;
char **argv;
{
    int i, x;

    if (pvmrsh_debug) {
	fprintf (stderr, ">>> CMD_SPAWN (%d, %d, argv[]= { ", instno, argc);
	for (i = 0; i < argc; ++i)
	    fprintf (stderr, "\"%s\"%s", argv[i],
		     i == argc - 1 ? "" : ", ");
	fprintf (stderr, "} );\n");
    }
    /* validate here: server_status[instno] is written below */
    if (instno < 0 || instno >= num_servers)
	return CLNT_ERR (ERR_BADINST);
    if ((x = initsend ()) < 0 || (x = putint (CMD_SPAWN)) < 0 ||
	(x = putint (argc)) < 0)
	return pvm_error_code (x);
    for (i = 0; i < argc; ++i)
	if ((x = putstring (argv[i])) < 0)
	    return pvm_error_code (x);
    if ((x = send_cmd (instno)) != 0)
	return x;
    /*
     * NOTE(review): spawning on a server already in STATUS_RUNNING bumps
     * num_children again, but get_reply() decrements at most once per
     * RUNNING -> UP transition, so cmd_wait() could then wait forever.
     * confirm whether a server may run more than one child.
     */
    server_status[instno] = STATUS_RUNNING;
    num_children++;
    return 0;
}

/*
 * tell a remote server to go away
 * returns 0 on success, status code on failure.
 */

int
cmd_quit (instno)
int instno;
{
    int x;

    if (pvmrsh_debug)
	fprintf (stderr, ">>> CMD_QUIT (%d)\n", instno);
    /* validate here: server_status[instno] is read and written below */
    if (instno < 0 || instno >= num_servers)
	return CLNT_ERR (ERR_BADINST);
    if ((x = initsend ()) < 0 || (x = putint (CMD_QUIT)) < 0)
	return pvm_error_code (x);
    if ((x = send_cmd (instno)) != 0)
	return x;
    /*
     * stop counting any child still running here; because the slot is
     * no longer STATUS_RUNNING, a later MSG_PROC_EXIT for it will not
     * decrement num_children a second time.
     */
    if (server_status[instno] == STATUS_RUNNING)
	num_children--;
    server_status[instno] = 0;	/* no usable server in this slot */
    return 0;
}

/*
 * wait for a process to finish
 * return the status code, and fill in inst number of the process
 * that died.
 */

int
cmd_wait (instno)
int *instno;
{
    struct pvmwait w;
    while (num_children > 0)
	if (get_reply (&w) == -1) {
	    *instno = w.inum;
	    return w.status;
	}
    return -1;			/* no more children */
}

/*
 * register hook to be called as (*hook)(proc, inum, status) each time a
 * MSG_PROC_EXIT arrives; pass 0 to disable.
 */
void
cmd_exit_hook (hook)
void (*hook)();
{
    child_exit_hook = hook;
}

/*
 * register hook to be called as (*hook)(proc, inum, buf, nbytes) with
 * each chunk of child stdout; pass 0 to disable.
 */
void
cmd_stdout_hook (hook)
void (*hook)();
{
    child_stdout_hook = hook;
}

/*
 * register hook to be called as (*hook)(proc, inum, buf, nbytes) with
 * each chunk of child stderr; pass 0 to disable.
 */
void
cmd_stderr_hook (hook)
void (*hook)();
{
    child_stderr_hook = hook;
}
