/*
 *   slush		 SSL remote shell
 *   Copyright (c) 1999 Damien Miller <damien@ibs.com.au>
 *				 All Rights Reserved
 *
 *   Based on stunnel-2.1:
 *   Copyright (c) 1998 Michal Trojnara <mtrojnar@ddc.daewoo.com.pl>
 *				 All Rights Reserved
 *   Author:	   Michal Trojnara  <mtrojnar@ddc.daewoo.com.pl>
 *   SSL support:  Adam Hernik	  <adas@infocentrum.com>
 *				 Pawel Krawczyk   <kravietz@ceti.com.pl>
 *
 *   This program is free software; you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation; either version 2 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program; if not, write to the Free Software
 *   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/* Size in bytes of the stack buffers used for terminal reads, SSL reads and
 * the control-string block sent at the start of a session. */
#define BUFFSIZE 8192

#include "config.h"

#include <stdio.h>
#include <errno.h>
#include <signal.h>
#include <string.h>
#include <stdlib.h>
#include <getopt.h>
#include <termios.h>
#include <pwd.h>
#include <netdb.h>
#include <sys/ioctl.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>

#ifdef HAVE_UNISTD_H
#include <unistd.h>
#endif
#ifdef HAVE_SYS_SELECT_H
#include <sys/select.h>
#endif

#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/bio.h>
#include <openssl/pem.h>

#include "common.h"

/* Forward declarations for the functions defined later in this file. */
void usage(void);
void save_terminal(void);
void raw_tty(void);
void restore_terminal(void);
void send_control_strings(SSL *ssl);
void transfer_loop(char *hostname, short port, SSL_CTX *ctx);
void parse_user_host_port(int argc, char **argv, int opt_offset, char **user, char **host, int *port);
static char *find_certificate(char *certfile);

/* Terminal attributes captured by save_terminal() and reapplied by
 * restore_terminal(). */
static struct termios	orig;

/* Remote account (from -l or user@host) and target host taken from the
 * command line; both stay NULL until set. */
static char				*username = NULL;
static char				*hostname = NULL;

#ifdef HAVE_GETOPT_LONG
/* Long spellings accepted by getopt_long(); each maps onto the short option
 * character in the last field. */
static struct option long_options[] =
{
	{ "login",       required_argument, NULL, 'l' },
	{ "certificate", required_argument, NULL, 'C' },
	{ "help",        no_argument,       NULL, 'h' },
	{ "version",     no_argument,       NULL, 'V' },
	{ NULL,          0,                 NULL, 0   }
};
#endif /* HAVE_GETOPT_LONG */

/*
 * slush entry point: slush [OPTIONS] [user@]host [port]
 *
 * Resolves the port (the "slush"/tcp entry in the services database
 * overrides DEFAULT_SLUSH_PORT; a port on the command line overrides both),
 * parses options and operands, builds the SSL client context and runs the
 * interactive session.
 *
 * Exits with 0 after -h/-V or when the session ends, and with 1 on invalid
 * usage or SSL context setup failure.
 */
int main(int argc, char **argv)
{
	SSL_CTX			*ctx;
	char				error[256];
	char				certfile[256];
	char				*certpath;
	int				c;
	int				port = DEFAULT_SLUSH_PORT;
	struct servent *s;

	/* Use port defined in /etc/services in preference to default.
	 * s_port is stored in network byte order, so convert it to host order. */
	s = getservbyname("slush", "tcp");
	if (s != NULL)
		port = ntohs(s->s_port);

	certfile[0] = '\0';

	/* Parse command-line options */
	while (1)
	{
#ifdef HAVE_GETOPT_LONG
		c = getopt_long(argc, argv, "C:l:hV", long_options, NULL);
#else /* HAVE_GETOPT_LONG */
		c = getopt(argc, argv, "C:l:hV");
#endif /* HAVE_GETOPT_LONG */

		if (c == -1)
			break;

		switch (c)
		{
			case 'C':
				/* snprintf() always NUL-terminates, truncating over-long names */
				snprintf(certfile, sizeof(certfile), "%s", optarg);
				break;
			case 'l':
				free(username);		/* a repeated -l replaces the earlier one */
				username = strdup(optarg);
				if (username == NULL)
				{
					perror("strdup");
					exit(1);
				}
				break;
			case 'h':
				usage();
				exit(0);
			case 'V':
				/* Must match the upper-case 'V' in the option string above. */
				fprintf(stderr, "slush %s\n", VERSION);
				exit(0);
			default:
				fprintf(stderr, "Invalid commandline options.\n");
				usage();
				exit(1);
		}
	}

	/* Exactly one or two operands must remain: [user@]host and an optional port */
	if ((optind != (argc - 1)) && (optind != (argc - 2)))
	{
		fprintf(stderr, "Invalid commandline options.\n");
		usage();
		exit(1);
	}

	parse_user_host_port(argc, argv, optind, &username, &hostname, &port);

	init_ssl();

	/* certpath is NULL when the certificate file does not exist. The same
	 * path is passed twice -- presumably as certificate and key file, since
	 * the help text calls it a "certificate/key"; confirm against
	 * setup_ssl_context(). */
	certpath = find_certificate(certfile);
	ctx = setup_ssl_context(certpath, certpath, 0, CLIENT, error, sizeof(error));
	if (ctx == NULL)
	{
		fprintf(stderr, "%s\n", error);
		exit(1);
	}

	transfer_loop(hostname, port, ctx);

	/* close SSL */
	SSL_CTX_free(ctx);

	return(0);
}

/* Print command-line help to stderr, listing both long and short spellings
 * of every option accepted by main(). */
void usage(void)
{
	fprintf(stderr, "slush [OPTIONS] [user@]host [port]\n\n");
	fprintf(stderr, "Options:\n");
	fprintf(stderr, "  --login, -l [account]    Specify account to login to.\n");
	fprintf(stderr, "  --certificate, -C [file] Use this certificate/key instead of default.\n");
	fprintf(stderr, "  --help, -h               Display this text.\n");
	fprintf(stderr, "  --version, -V            Display program version.\n");
}

/* Write exactly len bytes from buf to fd, retrying after short writes and
 * EINTR. Returns 0 on success, -1 on failure (errno as left by write()). */
static int write_all_bytes(int fd, const char *buf, size_t len)
{
	while (len > 0)
	{
		ssize_t n = write(fd, buf, len);

		if (n < 0)
		{
			if (errno == EINTR)
				continue;
			return -1;
		}
		if (n == 0)
			return -1;	/* no progress; do not spin forever */
		buf += n;
		len -= (size_t)n;
	}
	return 0;
}

/*
 * Connect to hostname:port, negotiate SSL using ctx, send the session control
 * strings, then relay bytes between the local terminal (switched to raw
 * mode) and the SSL connection until either side reaches end-of-file.
 *
 * Exits the process (status 1 or 3) on connection, SSL or I/O errors,
 * restoring the terminal first once it has been put into raw mode.
 */
void transfer_loop(char *hostname, short port, SSL_CTX *ctx)
{
	int				net_fd;
	SSL 				*ssl;
	char 				buffer[BUFFSIZE];
	fd_set			in;
	fd_set			out;
	int				max_fd;
	int				ssl_fd;
	int				bytes_read;
	int				ssl_err;
	int				session_open = 1;
	const	int		on = 1;
	char				error[256];

	/* Open socket connection to remote host */
	net_fd = connect_to_remote(hostname, port, error, sizeof(error));
	if (net_fd == -1)
	{
		fprintf(stderr, "%s\n", error);
		exit(1);
	}

	/* Turn off Nagle algorithm to maximise interactive performance */
	if (setsockopt(net_fd, IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == -1)
	{
		perror("setsockopt");
		exit(3);
	}

	/* Open SSL on the connection */
	ssl = SSL_new(ctx);
	if (ssl == NULL)
	{
		fprintf(stderr, "SSL_new:");
		ERR_print_errors_fp(stderr);
		exit(3);
	}

	if (SSL_set_fd(ssl, net_fd) != 1)
	{
		fprintf(stderr, "SSL_set_fd:");
		ERR_print_errors_fp(stderr);
		exit(3);
	}

	if (SSL_connect(ssl) <= 0)
	{
		fprintf(stderr, "SSL_connect:");
		ERR_print_errors_fp(stderr);
		exit(3);
	}

	fprintf(stderr, "Connected to %s using %s (%s)\n",
		hostname, SSL_get_version(ssl), SSL_get_cipher(ssl));

	send_control_strings(ssl);

	ssl_fd = SSL_get_fd(ssl);

	save_terminal();
	raw_tty();

	FD_ZERO(&in);
	FD_SET(ssl_fd, &in);
	FD_SET(STDIN_FILENO, &in);

	max_fd = (ssl_fd > STDIN_FILENO ? ssl_fd : STDIN_FILENO) + 1;

	while (session_open)
	{
		out = in;

		if (select(max_fd, &out, NULL, NULL, NULL) < 0)
		{
			if (errno == EINTR)
				continue;	/* interrupted by a signal: wait again */
			restore_terminal();
			perror("select");
			exit(3);
		}

		/* Local keystrokes -> remote side */
		if (FD_ISSET(STDIN_FILENO, &out))
		{
			bytes_read = read(STDIN_FILENO, buffer, BUFFSIZE);
			if (bytes_read == 0)
				break;		/* EOF on the local terminal ends the session */

			if (bytes_read < 0)
			{
				if (errno != EINTR)
				{
					restore_terminal();
					perror("read");
					exit(3);
				}
			}
			else if (SSL_write(ssl, buffer, bytes_read) != bytes_read)
			{
				restore_terminal();
				fprintf(stderr, "SSL_write:");
				ERR_print_errors_fp(stderr);
				exit(3);
			}
		}

		/* Remote output -> local terminal */
		if (FD_ISSET(ssl_fd, &out))
		{
			/*
			 * SSL_read() returns data from at most one record and may leave
			 * further decrypted bytes buffered inside the SSL object, where
			 * select() cannot see them.  Keep reading while SSL_pending()
			 * reports buffered data so that output never stalls.
			 */
			do
			{
				bytes_read = SSL_read(ssl, buffer, BUFFSIZE);
				if (bytes_read == 0)
				{
					session_open = 0;	/* remote side closed the connection */
					break;
				}

				if (bytes_read < 0)
				{
					ssl_err = SSL_get_error(ssl, bytes_read);
					if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE)
						break;		/* no application data yet; wait in select() */
					restore_terminal();
					fprintf(stderr, "SSL_read:");
					ERR_print_errors_fp(stderr);
					exit(3);
				}

				if (write_all_bytes(STDOUT_FILENO, buffer, (size_t)bytes_read) == -1)
				{
					restore_terminal();
					perror("write");
					exit(3);
				}
			} while (SSL_pending(ssl) > 0);
		}
	}

	restore_terminal();
	fprintf(stderr, "\nConnection closed.\n");

	SSL_free(ssl);
	/* The descriptor attached with SSL_set_fd() is not closed by SSL_free() */
	close(net_fd);
}

void save_terminal(void)
{
	/* Snapshot the current stdin terminal attributes into orig so that
	 * restore_terminal() can reapply them; abort with status 2 if the
	 * attributes cannot be read. */
	if (tcgetattr(STDIN_FILENO, &orig) >= 0)
		return;

	perror("tcgetattr");
	exit(2);
}

void raw_tty(void)
{
	/* Switch the stdin terminal to raw mode: read its current attributes,
	 * let cfmakeraw() clear the cooked-mode flags and apply the result
	 * immediately.  Any failure aborts with status 2. */
	struct termios raw_attrs;

	if (tcgetattr(STDIN_FILENO, &raw_attrs) < 0)
	{
		perror("tcgetattr");
		exit(2);
	}

	cfmakeraw(&raw_attrs);

	if (tcsetattr(STDIN_FILENO, TCSANOW, &raw_attrs) >= 0)
		return;

	perror("tcsetattr");
	exit(2);
}

void restore_terminal(void)
{
	/* Reapply the attributes saved in orig by save_terminal(), effective
	 * immediately; abort with status 2 if the terminal rejects them. */
	int status = tcsetattr(STDIN_FILENO, TCSANOW, &orig);

	if (status < 0)
	{
		perror("tcsetattr");
		exit(2);
	}
}

/* Abort (status 3) unless the snprintf() result c shows that the formatted
 * text plus its NUL fitted into the room bytes that were available.
 * C99 snprintf() returns the untruncated length, so c >= room means the
 * output was cut short. */
static void check_control_fit(int c, size_t room)
{
	if (c < 0 || (size_t)c >= room)
	{
		fprintf(stderr, "Control strings too long.\n");
		exit(3);
	}
}

/*
 * Build the session control block and send it over ssl in one SSL_write().
 *
 * The block consists of newline-terminated lines:
 *   TERM=<$TERM>                     only if TERM is set
 *   DISPLAY=<display>                only if DISPLAY is set; a display that
 *                                    starts with ':' is prefixed with this
 *                                    host's name
 *   WINSIZE:<rows> <cols> <xpix> <ypix>
 *   USER:<account>                   the -l / user@ name, else the local user
 *   END
 * followed by the terminating NUL byte, which is sent as well.
 *
 * Exits with status 3 if a value cannot be determined, the block does not
 * fit in BUFFSIZE bytes, or the write fails.
 */
void send_control_strings(SSL *ssl)
{
	char	buffer[BUFFSIZE];
	size_t	offset = 0;	/* invariant: offset < sizeof(buffer) */
	int	c;
	char	*p;
	const char *account;
	struct winsize w;
	struct passwd *pw;
	char myname[256];

	/* Copy terminal type to control buffer */
	p = getenv("TERM");
	if (p != NULL)
	{
		c = snprintf(buffer + offset, sizeof(buffer) - offset, "TERM=%s\n", p);
		check_control_fit(c, sizeof(buffer) - offset);
		offset += (size_t)c;
	}

	/* Copy X display to control buffer */
	p = getenv("DISPLAY");
	if (p != NULL)
	{
		if (p[0] == ':')
		{
			/* Qualify a local display such as ":0" with our host name,
			 * presumably so the remote end can reach it. */
			if (gethostname(myname, sizeof(myname)) == -1)
			{
				fprintf(stderr, "Could not get local host name: %s\n", strerror(errno));
				exit(3);
			}
			/* gethostname() need not terminate a truncated name */
			myname[sizeof(myname) - 1] = '\0';
			c = snprintf(buffer + offset, sizeof(buffer) - offset, "DISPLAY=%s%s\n", myname, p);
		}
		else
		{
			c = snprintf(buffer + offset, sizeof(buffer) - offset, "DISPLAY=%s\n", p);
		}
		check_control_fit(c, sizeof(buffer) - offset);
		offset += (size_t)c;
	}

	/* Copy terminal window size to control buffer */
	if (ioctl(STDIN_FILENO, TIOCGWINSZ, &w) == -1)
	{
		fprintf(stderr, "Couldn't read terminal window size: %s\n", strerror(errno));
		exit(3);
	}
	c = snprintf(buffer + offset, sizeof(buffer) - offset,
					 "WINSIZE:%hu %hu %hu %hu\n", w.ws_row, w.ws_col,
					 w.ws_xpixel, w.ws_ypixel);
	check_control_fit(c, sizeof(buffer) - offset);
	offset += (size_t)c;

	/* Username: the one given on the command line, else the current user's */
	if ((username != NULL) && (username[0] != '\0'))
	{
		account = username;
	}
	else
	{
		pw = getpwuid(getuid());
		if (pw == NULL)
		{
			fprintf(stderr, "Cannot determine username: %s\n", strerror(errno));
			exit(3);
		}
		account = pw->pw_name;
	}
	c = snprintf(buffer + offset, sizeof(buffer) - offset, "USER:%s\n", account);
	check_control_fit(c, sizeof(buffer) - offset);
	offset += (size_t)c;

	/* Write 'end' control word to buffer */
	c = snprintf(buffer + offset, sizeof(buffer) - offset, "END\n");
	check_control_fit(c, sizeof(buffer) - offset);
	offset += (size_t)c;

	/* Send the text plus the NUL that snprintf() left at buffer[offset];
	 * the invariant above keeps that byte inside the buffer. */
	if (SSL_write(ssl, buffer, (int)(offset + 1)) != (int)(offset + 1))
	{
		/* No restore_terminal() here: in this file the terminal is saved
		 * and made raw only after this function returns, so restoring now
		 * would apply the zero-filled, never-saved attributes. */
		fprintf(stderr, "SSL_write:");
		ERR_print_errors_fp(stderr);
		exit(3);
	}
}
	
/*
 * Parse the "[user@]host[:port]" operand at argv[opt_offset], plus an
 * optional separate port operand at argv[opt_offset + 1].
 *
 *   argc, argv  - the program's argument vector; argv[opt_offset] is modified
 *                 in place (the '@' is overwritten with a NUL)
 *   opt_offset  - index of the host operand (main() passes optind)
 *   user        - set to a static copy of "user" only when "user@" is
 *                 present, so an account given earlier with -l is kept
 *                 otherwise
 *   host        - set to a static copy of the host name
 *   port        - overwritten only when a port is given ("host:port" wins
 *                 over the separate operand); otherwise keeps the caller's
 *                 default
 *
 * A port must be a decimal number in 1..65535; anything else prints
 * "Invalid port" with the usage text and exits with status 1.
 */
void parse_user_host_port(int argc, char **argv, int opt_offset, char **user, char **host, int *port)
{
	char				*p;
	const char		*port_str = NULL;
	char				*end;
	long				value;
	static char hostname_l[256];
	static char username_l[32];

	/* Test for username@host syntax */
	p = strchr(argv[opt_offset], '@');
	if (p != NULL)
	{
		/* Split at '@', keep the (possibly truncated) user name and */
		/* advance to the first character of the host name */
		*p = '\0';
		snprintf(username_l, sizeof(username_l), "%s", argv[opt_offset]);
		*user = username_l;
		p++;
	}
	else
	{
		p = argv[opt_offset];
	}

	/* Grab hostname[:port] */
	snprintf(hostname_l, sizeof(hostname_l), "%s", p);

	/* Locate the port text: "hostname:port" first, then a separate operand */
	p = strchr(hostname_l, ':');
	if (p != NULL)
	{
		*p = '\0';
		port_str = p + 1;
	}
	else if (opt_offset == (argc - 2))
	{
		port_str = argv[opt_offset + 1];
	}

	if (port_str != NULL)
	{
		/* strtol() lets us reject trailing junk and out-of-range values,
		 * which atoi() silently accepted */
		errno = 0;
		value = strtol(port_str, &end, 10);
		if (end == port_str || *end != '\0' || errno == ERANGE ||
			 value < 1 || value > 65535)
			*port = 0;
		else
			*port = (int)value;
	}

	if (*port == 0)
	{
		fprintf(stderr, "Invalid port\n");
		usage();
		exit(1);
	}

	*host = hostname_l;
}

/*
 * Resolve the client certificate/key file to use.
 *
 * certfile is the name given with -C, or an empty string to select
 * DEFAULT_CLIENT_CERT (which is then copied into certfile).  A name that
 * contains '/' is used as given; a bare file name is looked up as
 * <home>/USER_DIRECTORY/<name>, where <home> is the current user's home
 * directory from the password database.
 *
 * Returns a pointer to a static buffer holding the path, or NULL if the file
 * does not exist.  Prints a warning if group or others have any access to
 * the file.  Exits with status 2 if the password entry cannot be found, the
 * path is too long, or stat() fails for a reason other than ENOENT.
 */
static char *find_certificate(char *certfile)
{
	static char	certpath[256];
	struct stat	st;
	struct passwd *pw;
	int			len;

	pw = getpwuid(getuid());
	if (pw == NULL)
	{
		fprintf(stderr, "Couldn't find password entry: %s\n", strerror(errno));
		exit(2);
	}

	/* If certificate file not specified, use default.
	 * NOTE(review): assumes the caller's buffer can hold DEFAULT_CLIENT_CERT
	 * (main() passes a 256-byte array). */
	if (certfile[0] == '\0')
		strcpy(certfile, DEFAULT_CLIENT_CERT);

	/* A name containing '/' is a path; a bare file name lives in the
	 * per-user directory */
	if (strchr(certfile, '/') != NULL)
		len = snprintf(certpath, sizeof(certpath), "%s", certfile);
	else
		len = snprintf(certpath, sizeof(certpath), "%s/%s/%s", pw->pw_dir,
						  USER_DIRECTORY, certfile);

	/* Refuse a truncated path rather than silently using a different file */
	if (len < 0 || (size_t)len >= sizeof(certpath))
	{
		fprintf(stderr, "Certificate path too long: %s\n", certfile);
		exit(2);
	}

	/* Check existence and permissions of the certificate file */
	if (stat(certpath, &st) == -1)
	{
		if (errno == ENOENT)
			return(NULL);

		fprintf(stderr, "Couldn't stat cert file %s: %s\n", certpath, strerror(errno));
		exit(2);
	}

	/* The file holds a private key: any group or other access is unsafe */
	if (st.st_mode & (S_IRWXG | S_IRWXO))
		fprintf(stderr, "WARNING: Wrong permissions on %s\n", certpath);

	fprintf(stderr, "Using certificate %s\n", certpath);
	return(certpath);
}

/* End of slush.c */
