/* vim:tw=78:ts=8:sw=4:set ft=c:  */
/*
    Copyright (C) 2006-2009 Ben Kibbey <bjk@luxsci.net>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02110-1301  USA
*/
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
#include <err.h>
#include <pwd.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <ctype.h>

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#ifdef WITH_LIBPTH
#include <pth.h>
#endif

#include "misc.h"
#include "ssh.h"

/* Defined below; called from _setup_ssh_shell() once the remote shell is open. */
static gpg_error_t ssh_connect_finalize(pwm_t *pwm);

/*
 * Shut down the libssh2 channel and session held by CONN (each only if
 * present), clear both handles and hand CONN to _free_ssh_conn() for the
 * remaining cleanup.  A NULL CONN is ignored.
 */
static void ssh_deinit(pwmd_tcp_conn_t *conn)
{
    if (conn == NULL)
	return;

    /* The channel rides on the session, so it goes first. */
    if (conn->channel != NULL) {
	libssh2_channel_close(conn->channel);
	libssh2_channel_free(conn->channel);
	conn->channel = NULL;
    }

    if (conn->session != NULL) {
	libssh2_session_disconnect(conn->session, N_("libpwmd saying bye!"));
	libssh2_session_free(conn->session);
	conn->session = NULL;
    }

    /* With session == NULL, _free_ssh_conn() closes the fd and frees conn. */
    _free_ssh_conn(conn);
}

/*
 * libassuan read hook.  When the context belongs to an SSH connection the
 * data is pulled from the libssh2 channel; otherwise it is read straight
 * from FD with recv() (pth_recv() under GNU Pth).  The byte count (or a
 * negative error) is stored in *RET; returns 1 if *RET >= 0, else 0.
 */
static int read_hook(assuan_context_t ctx, assuan_fd_t fd, void *data,
	size_t len, ssize_t *ret)
{
    pwm_t *pwm = assuan_get_pointer(ctx);
    ssize_t got;

    if (pwm && pwm->tcp_conn)
	got = libssh2_channel_read(pwm->tcp_conn->channel, data, len);
    else
#ifdef WITH_LIBPTH
	got = pth_recv((int)fd, data, len, 0);
#else
	got = recv((int)fd, data, len, 0);
#endif

    *ret = got;
    return got < 0 ? 0 : 1;
}

/*
 * libassuan write hook; the mirror image of read_hook().  SSH connections
 * write through the libssh2 channel, plain ones through send() (pth_send()
 * under GNU Pth).  Stores the result in *RET and returns 1 if it is >= 0.
 */
static int write_hook(assuan_context_t ctx, assuan_fd_t fd, const void *data,
	size_t len, ssize_t *ret)
{
    pwm_t *pwm = assuan_get_pointer(ctx);
    ssize_t sent;

    if (pwm && pwm->tcp_conn)
	sent = libssh2_channel_write(pwm->tcp_conn->channel, data, len);
    else
#ifdef WITH_LIBPTH
	sent = pth_send((int)fd, data, len, 0);
#else
	sent = send((int)fd, data, len, 0);
#endif

    *ret = sent;
    return sent < 0 ? 0 : 1;
}

/*
 * Release everything owned by CONN: the strings filled in by
 * init_tcp_conn()/verify_hostkey(), the c-ares channel and hostent, and
 * finally CONN itself.  If an SSH session is still attached, the teardown
 * is delegated to ssh_deinit(), which calls back here with the session
 * cleared.  Otherwise a non-negative fd is closed before CONN is freed.
 */
void _free_ssh_conn(pwmd_tcp_conn_t *conn)
{
    if (conn == NULL)
	return;

    if (conn->username != NULL) {
	pwmd_free(conn->username);
	conn->username = NULL;
    }
    if (conn->known_hosts != NULL) {
	pwmd_free(conn->known_hosts);
	conn->known_hosts = NULL;
    }
    if (conn->identity != NULL) {
	pwmd_free(conn->identity);
	conn->identity = NULL;
    }
    if (conn->identity_pub != NULL) {
	pwmd_free(conn->identity_pub);
	conn->identity_pub = NULL;
    }
    if (conn->host != NULL) {
	pwmd_free(conn->host);
	conn->host = NULL;
    }
    if (conn->hostkey != NULL) {
	pwmd_free(conn->hostkey);
	conn->hostkey = NULL;
    }

    /* Resolver state. */
    if (conn->chan != NULL) {
	ares_destroy(conn->chan);
	conn->chan = NULL;
    }
    if (conn->he != NULL) {
	ares_free_hostent(conn->he);
	conn->he = NULL;
    }

    if (conn->session != NULL) {
	/* ssh_deinit() drops the session, then re-enters here to finish. */
	ssh_deinit(conn);
	return;
    }

    if (conn->fd >= 0) {
	close(conn->fd);
	conn->fd = -1;
    }

    pwmd_free(conn);
}

/*
 * libassuan finish handler, registered in ssh_connect_finalize(); per the
 * original note it runs only after the BYE command.  The fd is marked -1
 * first so the SSH teardown does not close() it.
 */
static void ssh_assuan_deinit(assuan_context_t ctx)
{
    pwm_t *pwm = assuan_get_pointer(ctx);
    pwmd_tcp_conn_t *conn = pwm->tcp_conn;

    if (conn == NULL)
	return;

    conn->fd = -1;
    ssh_deinit(conn);
    pwm->tcp_conn = NULL;
}

/*
 * Sets common options from both pwmd_ssh_connect() and
 * pwmd_ssh_connect_async().
 *
 * Allocates a new connection object for HOST:PORT (PORT == -1 means 22)
 * and stores it in *DST on success.  When GET is non-zero (host-key
 * retrieval only) just HOST is required.  Otherwise IDENTITY and
 * KNOWN_HOSTS are mandatory and have "~" expanded against the current
 * user's passwd entry; USER defaults to that entry's login name.  The
 * public key path is IDENTITY with ".pub" appended.
 *
 * Returns 0, GPG_ERR_INV_ARG for missing arguments, or an errno-based
 * error.  On failure *DST is left untouched and nothing is leaked.
 */
static gpg_error_t init_tcp_conn(pwmd_tcp_conn_t **dst, const char *host,
	int port, const char *identity, const char *user,
	const char *known_hosts, int get)
{
    pwmd_tcp_conn_t *conn;
    gpg_error_t rc = 0;
    char *pwbuf = NULL;

    if (get) {
	if (!host || !*host)
	    return GPG_ERR_INV_ARG;
    }
    else {
	if (!host || !*host || !identity || !*identity || !known_hosts ||
		!*known_hosts)
	    return GPG_ERR_INV_ARG;
    }

    conn = pwmd_calloc(1, sizeof(pwmd_tcp_conn_t));

    if (!conn)
	return gpg_error_from_errno(ENOMEM);

    /*
     * calloc() leaves fd == 0.  _free_ssh_conn() treats any fd >= 0 as an
     * open socket and would close() stdin, so mark it as "no socket yet".
     */
    conn->fd = -1;
    conn->port = port == -1 ? 22 : port;
    conn->host = pwmd_strdup(host);

    if (!conn->host) {
	rc = gpg_error_from_errno(ENOMEM);
	goto fail;
    }

    if (!get) {
	struct passwd pw;

	errno = 0;
	pwbuf = _getpwuid(&pw);

	if (!pwbuf) {
	    /*
	     * A missing passwd entry may leave errno at 0.  Never let the
	     * error path below report success with *dst unset.
	     */
	    rc = gpg_error_from_errno(errno ? errno : ENOENT);
	    goto fail;
	}

	conn->username = pwmd_strdup(user ? user : pw.pw_name);

	if (!conn->username) {
	    rc = gpg_error_from_errno(ENOMEM);
	    goto fail;
	}

	conn->identity = _expand_homedir((char *)identity, &pw);

	if (!conn->identity) {
	    rc = gpg_error_from_errno(ENOMEM);
	    goto fail;
	}

	conn->identity_pub = pwmd_strdup_printf("%s.pub", conn->identity);

	if (!conn->identity_pub) {
	    rc = gpg_error_from_errno(ENOMEM);
	    goto fail;
	}

	conn->known_hosts = _expand_homedir((char *)known_hosts, &pw);

	if (!conn->known_hosts) {
	    rc = gpg_error_from_errno(ENOMEM);
	    goto fail;
	}

	pwmd_free(pwbuf);
    }

    *dst = conn;
    return 0;

fail:
    if (pwbuf)
	pwmd_free(pwbuf);

    _free_ssh_conn(conn);
    return rc;
}

/*
 * Create a TCP socket of address family PROT and connect it to ADDR at
 * pwm->tcp_conn->port.  ADDR points to a struct in_addr for AF_INET or a
 * struct in6_addr for AF_INET6 (as found in a c-ares hostent).  For async
 * connections the socket is made non-blocking first.  The new descriptor
 * is stored in pwm->tcp_conn->fd, and pwm->cmd is set to
 * ASYNC_CMD_CONNECT once the socket exists.
 *
 * Returns 0 or the connect()/socket() error.  A non-blocking connect
 * surfaces EINPROGRESS as an error here, exactly as before.
 */
static gpg_error_t do_connect(pwm_t *pwm, int prot, void *addr)
{
    struct sockaddr_storage their_addr;
    socklen_t addrlen;

    memset(&their_addr, 0, sizeof(their_addr));

    /*
     * The AAAA path in dns_resolve_cb() yields AF_INET6 addresses, which
     * need a sockaddr_in6; a sockaddr_in cannot carry them.
     */
    if (prot == AF_INET) {
	struct sockaddr_in *sin = (struct sockaddr_in *)&their_addr;

	sin->sin_family = AF_INET;
	sin->sin_port = htons(pwm->tcp_conn->port);
	memcpy(&sin->sin_addr, addr, sizeof(sin->sin_addr));
	addrlen = sizeof(*sin);
    }
    else if (prot == AF_INET6) {
	struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&their_addr;

	sin6->sin6_family = AF_INET6;
	sin6->sin6_port = htons(pwm->tcp_conn->port);
	memcpy(&sin6->sin6_addr, addr, sizeof(sin6->sin6_addr));
	addrlen = sizeof(*sin6);
    }
    else
	return GPG_ERR_EAFNOSUPPORT;

    pwm->tcp_conn->fd = socket(prot, SOCK_STREAM, 0);

    if (pwm->tcp_conn->fd == -1)
	return gpg_error_from_syserror();

    if (pwm->tcp_conn->async)
	fcntl(pwm->tcp_conn->fd, F_SETFL, O_NONBLOCK);

    pwm->cmd = ASYNC_CMD_CONNECT;

#ifdef WITH_LIBPTH
    if (pth_connect(pwm->tcp_conn->fd, (struct sockaddr *)&their_addr,
		addrlen) == -1)
#else
    if (connect(pwm->tcp_conn->fd, (struct sockaddr *)&their_addr,
		addrlen) == -1)
#endif
	return gpg_error_from_syserror();

    return 0;
}

/*
 * Translate a c-ares status code into a gpg_error_t, printing the c-ares
 * message with warnx() for anything other than ARES_SUCCESS.  Codes without
 * an obvious counterpart (including ARES_SUCCESS itself) fall back to
 * GPG_ERR_EHOSTUNREACH.
 */
static gpg_error_t ares_error_to_pwmd(int status)
{
    gpg_error_t mapped;

    switch (status) {
	case ARES_ENODATA:
	case ARES_EFORMERR:
	case ARES_ENOTFOUND:
	    mapped = GPG_ERR_UNKNOWN_HOST;
	    break;
	case ARES_ESERVFAIL:
	    mapped = GPG_ERR_EHOSTDOWN;
	    break;
	case ARES_ETIMEOUT:
	    mapped = GPG_ERR_TIMEOUT;
	    break;
	case ARES_ENOMEM:
	    mapped = gpg_error_from_errno(ENOMEM);
	    break;
	case ARES_ECONNREFUSED:
	    mapped = GPG_ERR_ECONNREFUSED;
	    break;
	default:
	    /* FIXME ??? */
	    mapped = GPG_ERR_EHOSTUNREACH;
	    break;
    }

    if (status != ARES_SUCCESS)
	warnx("%s", ares_strerror(status));

    return mapped;
}

/*
 * c-ares completion callback for the query issued by the synchronous and
 * asynchronous connect paths (ARG is the pwm_t).
 *
 * When IPv6 is allowed (PWMD_IP_ANY or PWMD_IPV6) the reply is parsed for
 * AAAA records first; PWMD_IP_ANY falls back to A records if there are
 * none.  On success the hostent is kept in pwm->tcp_conn->he and
 * do_connect() is started with its first address.  Any failure is stored
 * in pwm->tcp_conn->rc.
 */
static void dns_resolve_cb(void *arg, int status, int timeouts,
	unsigned char *abuf, int alen)
{
    pwm_t *pwm = arg;
    int rc;
    struct hostent *he;

    (void)timeouts;

    /* The channel is being destroyed (see _free_ssh_conn()); nothing to do. */
    if (status == ARES_EDESTRUCTION)
	return;

    if (status != ARES_SUCCESS) {
	pwm->tcp_conn->rc = ares_error_to_pwmd(status);
	return;
    }

    /* Check for an IPv6 address first. */
    if (pwm->prot == PWMD_IP_ANY || pwm->prot == PWMD_IPV6)
	rc = ares_parse_aaaa_reply(abuf, alen, &he, NULL, NULL);
    else
	rc = ares_parse_a_reply(abuf, alen, &he, NULL, NULL);

    if (rc != ARES_SUCCESS) {
	if (pwm->prot != PWMD_IP_ANY || rc != ARES_ENODATA) {
	    /*
	     * Report the parse failure itself.  status is ARES_SUCCESS here,
	     * so mapping it would always yield the generic fallback error.
	     */
	    pwm->tcp_conn->rc = ares_error_to_pwmd(rc);
	    return;
	}

	/* PWMD_IP_ANY with no AAAA records: try IPv4. */
	rc = ares_parse_a_reply(abuf, alen, &he, NULL, NULL);

	if (rc != ARES_SUCCESS) {
	    pwm->tcp_conn->rc = ares_error_to_pwmd(rc);
	    return;
	}
    }

    pwm->tcp_conn->he = he;
    pwm->tcp_conn->rc = do_connect(pwm, he->h_addrtype, he->h_addr);
}

/*
 * Start a non-blocking SSH connection (or, with WHICH == ASYNC_CMD_HOSTKEY,
 * a host-key retrieval) to HOST.  This only creates the connection object
 * and submits the DNS query.  The remaining steps are driven by
 * dns_resolve_cb() and the caller's processing loop, with pwm->cmd set to
 * ASYNC_CMD_DNS and pwm->state to ASYNC_PROCESS.
 *
 * Returns 0, GPG_ERR_INV_ARG, GPG_ERR_ASS_NESTED_COMMANDS if another
 * command is pending, or an error from init_tcp_conn()/ares_init().  On
 * error, pwm is left unchanged.
 */
gpg_error_t _do_pwmd_ssh_connect_async(pwm_t *pwm, const char *host,
	int port, const char *identity, const char *user,
	const char *known_hosts, pwmd_async_cmd_t which)
{
    pwmd_tcp_conn_t *conn;
    gpg_error_t rc;
    int n;

    if (!pwm)
	return GPG_ERR_INV_ARG;

    if (pwm->cmd != ASYNC_CMD_NONE)
	return GPG_ERR_ASS_NESTED_COMMANDS;

    rc = init_tcp_conn(&conn, host, port, identity, user, known_hosts,
	    which == ASYNC_CMD_HOSTKEY ? 1 : 0);

    if (rc)
	return rc;

    /* No socket exists yet; keep _free_ssh_conn() away from fd 0. */
    conn->fd = -1;

    /*
     * Set up the resolver before publishing conn.  ares_query() on a
     * channel that failed to initialize is undefined.
     */
    n = ares_init(&conn->chan);

    if (n != ARES_SUCCESS) {
	conn->chan = NULL;
	_free_ssh_conn(conn);
	return ares_error_to_pwmd(n);
    }

    conn->async = 1;
    pwm->tcp_conn = conn;
    pwm->tcp_conn->cmd = which;
    pwm->cmd = ASYNC_CMD_DNS;
    pwm->state = ASYNC_PROCESS;
    /* dns_resolve_cb() reads pwm->tcp_conn, so it must be published first. */
    ares_query(pwm->tcp_conn->chan, pwm->tcp_conn->host, ns_c_any, ns_t_any,
	    dns_resolve_cb, pwm);
    return 0;
}

/*
 * Allocation callbacks handed to libssh2_session_init_ex() so that libssh2
 * uses libpwmd's allocator.  The abstract DATA pointer is not needed.
 */
static void *ssh_malloc(size_t size, void **data)
{
    (void)data;
    return pwmd_malloc(size);
}

static void ssh_free(void *ptr, void **data)
{
    (void)data;
    pwmd_free(ptr);
}

static void *ssh_realloc(void *ptr, size_t size, void **data)
{
    (void)data;
    return pwmd_realloc(ptr, size);
}

/*
 * Look for pwm->tcp_conn->hostkey as a complete line of the known_hosts
 * file.  Lines starting with '#' or whitespace are skipped.
 *
 * Returns 0 when a matching line is found and 1 otherwise, including when
 * the file cannot be opened or the line buffer cannot be allocated.
 */
static int verify_host_key(pwm_t *pwm)
{
    FILE *fp = fopen(pwm->tcp_conn->known_hosts, "r");
    char *buf, *p;
    int result = 1;

    if (!fp)
	return 1;

    buf = pwmd_malloc(LINE_MAX);

    if (buf) {
	while ((p = fgets(buf, LINE_MAX, fp))) {
	    /* isspace() on a negative char is undefined, hence the cast. */
	    if (*p == '#' || isspace((unsigned char)*p))
		continue;

	    /*
	     * Strip the line terminator.  strcspn() also copes with an empty
	     * line (leading NUL), where p[strlen(p)-1] would underflow.
	     */
	    p[strcspn(p, "\r\n")] = 0;

	    if (!strcmp(p, pwm->tcp_conn->hostkey)) {
		result = 0;
		break;
	    }
	}

	pwmd_free(buf);
    }

    fclose(fp);
    return result;
}

/*
 * SSH_AUTH step: authenticate with the configured public/private key pair.
 * Returns GPG_ERR_EAGAIN if libssh2 would block.  On success it continues
 * with the channel setup.  Any other failure frees the connection,
 * clears pwm->tcp_conn and returns GPG_ERR_BAD_SECKEY.
 */
gpg_error_t _setup_ssh_auth(pwm_t *pwm)
{
    pwmd_tcp_conn_t *conn = pwm->tcp_conn;
    int status;

    conn->state = SSH_AUTH;
    status = libssh2_userauth_publickey_fromfile(conn->session,
	    conn->username, conn->identity_pub, conn->identity, NULL);

    if (status == 0)
	return _setup_ssh_channel(pwm);

    if (status == LIBSSH2_ERROR_EAGAIN)
	return GPG_ERR_EAGAIN;

    _free_ssh_conn(conn);
    pwm->tcp_conn = NULL;
    return GPG_ERR_BAD_SECKEY;
}

/*
 * SSH_AUTHLIST step: ask the server which authentication methods it offers
 * for the user.  Returns GPG_ERR_EAGAIN if libssh2 would block.  If
 * "publickey" is not offered (or the list cannot be obtained), the
 * connection is freed and GPG_ERR_BAD_PIN_METHOD is returned.
 */
gpg_error_t _setup_ssh_authlist(pwm_t *pwm)
{
    pwmd_tcp_conn_t *conn = pwm->tcp_conn;
    char *methods;
    int last_err;

    conn->state = SSH_AUTHLIST;
    methods = libssh2_userauth_list(conn->session, conn->username,
	    strlen(conn->username));
    last_err = libssh2_session_last_errno(conn->session);

    if (methods == NULL && last_err == LIBSSH2_ERROR_EAGAIN)
	return GPG_ERR_EAGAIN;

    if (methods != NULL && strstr(methods, "publickey") != NULL)
	return _setup_ssh_auth(pwm);

    _free_ssh_conn(conn);
    pwm->tcp_conn = NULL;
    return GPG_ERR_BAD_PIN_METHOD;
}

/*
 * Obtain the server's SHA1 host-key fingerprint and store it as hex in
 * pwm->tcp_conn->hostkey.
 *
 * For an ASYNC_CMD_HOSTKEY request, a copy is placed in pwm->result and
 * the function stops there.  Otherwise the fingerprint must appear in the
 * known_hosts file (verify_host_key()) before authentication continues
 * with _setup_ssh_authlist().
 *
 * Returns GPG_ERR_BAD_CERT when no fingerprint is available or it is
 * unknown, or an ENOMEM error.
 */
static gpg_error_t verify_hostkey(pwm_t *pwm)
{
    const char *fp = libssh2_hostkey_hash(pwm->tcp_conn->session,
	    LIBSSH2_HOSTKEY_HASH_SHA1);

    /*
     * libssh2 returns NULL when no hash is available.  Check before
     * _to_hex() reads 20 bytes from it.
     */
    if (!fp)
	return GPG_ERR_BAD_CERT;

    pwm->tcp_conn->hostkey = _to_hex(fp, 20);

    if (!pwm->tcp_conn->hostkey)
	return gpg_error_from_errno(ENOMEM);

    if (pwm->tcp_conn->cmd == ASYNC_CMD_HOSTKEY) {
	pwm->result = pwmd_strdup(pwm->tcp_conn->hostkey);

	if (!pwm->result)
	    return gpg_error_from_errno(ENOMEM);

	return 0;
    }

    if (verify_host_key(pwm))
	return GPG_ERR_BAD_CERT;

    return _setup_ssh_authlist(pwm);
}

/*
 * SSH_CHANNEL step: switch the session to blocking mode and open a session
 * channel.  Returns GPG_ERR_EAGAIN if libssh2 would block.  On success it
 * continues with the shell request.  Any other failure frees the
 * connection and returns GPG_ERR_ASSUAN_SERVER_FAULT.
 */
gpg_error_t _setup_ssh_channel(pwm_t *pwm)
{
    pwmd_tcp_conn_t *conn = pwm->tcp_conn;
    int last_err;

    conn->state = SSH_CHANNEL;
    libssh2_session_set_blocking(conn->session, 1);
    conn->channel = libssh2_channel_open_session(conn->session);
    last_err = libssh2_session_last_errno(conn->session);

    if (conn->channel)
	return _setup_ssh_shell(pwm);

    if (last_err == LIBSSH2_ERROR_EAGAIN)
	return GPG_ERR_EAGAIN;

    _free_ssh_conn(conn);
    pwm->tcp_conn = NULL;
    return GPG_ERR_ASSUAN_SERVER_FAULT;
}

/*
 * SSH_SHELL step: request a shell on the open channel.  Returns
 * GPG_ERR_EAGAIN if libssh2 would block.  On success it hands over to
 * ssh_connect_finalize().  Any other failure frees the connection and
 * returns GPG_ERR_ASSUAN_SERVER_FAULT.
 */
gpg_error_t _setup_ssh_shell(pwm_t *pwm)
{
    pwmd_tcp_conn_t *conn = pwm->tcp_conn;
    int status;

    conn->state = SSH_SHELL;
    status = libssh2_channel_shell(conn->channel);

    if (status == 0)
	return ssh_connect_finalize(pwm);

    if (status == LIBSSH2_ERROR_EAGAIN)
	return GPG_ERR_EAGAIN;

    _free_ssh_conn(conn);
    pwm->tcp_conn = NULL;
    return GPG_ERR_ASSUAN_SERVER_FAULT;
}

/*
 * Final step: install read_hook()/write_hook() as libassuan I/O hooks,
 * create the assuan context on the connection's fd, register
 * ssh_assuan_deinit() as its finish handler and run _connect_finalize().
 * On any failure the SSH connection is freed and the error code returned.
 */
static gpg_error_t ssh_connect_finalize(pwm_t *pwm)
{
    struct assuan_io_hooks hooks = {read_hook, write_hook};
    assuan_context_t ctx;
    gpg_error_t rc;

    assuan_set_io_hooks(&hooks);
    rc = assuan_socket_connect_fd(&ctx, pwm->tcp_conn->fd, 0, pwm);

    if (!rc) {
	assuan_set_finish_handler(ctx, ssh_assuan_deinit);
	pwm->ctx = ctx;
	rc = _connect_finalize(pwm);

	if (!rc)
	    return 0;
    }

    _free_ssh_conn(pwm->tcp_conn);
    pwm->tcp_conn = NULL;
    return gpg_err_code(rc);
}

/*
 * SSH_INIT step: run the SSH handshake over the connected socket.
 * Returns GPG_ERR_EAGAIN if libssh2 would block.  On success it continues
 * with host-key verification.  Any other failure frees the connection
 * and returns GPG_ERR_ASSUAN_SERVER_FAULT.
 */
gpg_error_t _setup_ssh_init(pwm_t *pwm)
{
    pwmd_tcp_conn_t *conn = pwm->tcp_conn;
    int status;

    conn->state = SSH_INIT;
    status = libssh2_session_startup(conn->session, conn->fd);

    if (status == 0)
	return verify_hostkey(pwm);

    if (status == LIBSSH2_ERROR_EAGAIN)
	return GPG_ERR_EAGAIN;

    _free_ssh_conn(conn);
    pwm->tcp_conn = NULL;
    return GPG_ERR_ASSUAN_SERVER_FAULT;
}

/*
 * Create the libssh2 session using libpwmd's allocators.  The session is
 * blocking unless the connection is async, and the handshake
 * (_setup_ssh_init()) is started.  If the session cannot be created,
 * the connection is freed and an ENOMEM error code is returned.
 */
gpg_error_t _setup_ssh_session(pwm_t *pwm)
{
    pwmd_tcp_conn_t *conn = pwm->tcp_conn;

    conn->session = libssh2_session_init_ex(ssh_malloc, ssh_free,
	    ssh_realloc, NULL);

    if (conn->session == NULL) {
	_free_ssh_conn(conn);
	pwm->tcp_conn = NULL;
	return gpg_err_code(gpg_error_from_errno(ENOMEM));
    }

    libssh2_session_set_blocking(conn->session, !conn->async);
    return _setup_ssh_init(pwm);
}

/*
 * Blocking SSH connect (or host-key retrieval when GET is non-zero).
 * Resolves HOST with c-ares, waiting in select() until dns_resolve_cb()
 * has either failed or connected the socket, then runs the libssh2 setup
 * chain via _setup_ssh_session().
 *
 * Returns 0 or an error.  If DNS resolution or the TCP connect fails, the
 * half-built connection is freed, pwm->tcp_conn is cleared and pwm->cmd
 * is reset, so a later connect is not refused with GPG_ERR_INV_STATE.
 */
gpg_error_t _do_pwmd_ssh_connect(pwm_t *pwm, const char *host, int port,
	const char *identity, const char *user, const char *known_hosts, int get)
{
    pwmd_tcp_conn_t *conn;
    gpg_error_t rc;
    int n;

    if (!pwm)
	return GPG_ERR_INV_ARG;

    if (pwm->cmd != ASYNC_CMD_NONE)
	return GPG_ERR_INV_STATE;

    rc = init_tcp_conn(&conn, host, port, identity, user, known_hosts, get);

    if (rc)
	return rc;

    /* No socket exists yet; keep _free_ssh_conn() away from fd 0. */
    conn->fd = -1;
    n = ares_init(&conn->chan);

    if (n != ARES_SUCCESS) {
	conn->chan = NULL;
	_free_ssh_conn(conn);
	return ares_error_to_pwmd(n);
    }

    pwm->tcp_conn = conn;
    pwm->tcp_conn->cmd = get ? ASYNC_CMD_HOSTKEY : ASYNC_CMD_NONE;
    pwm->cmd = ASYNC_CMD_DNS;
    ares_query(pwm->tcp_conn->chan, pwm->tcp_conn->host, ns_c_any, ns_t_any,
	    dns_resolve_cb, pwm);

    /*
     * Fake a blocking DNS lookup. libcares does a better job than
     * getaddrinfo().  dns_resolve_cb() may already have run, so test
     * before waiting.  It either sets tcp_conn->rc or moves pwm->cmd on
     * to ASYNC_CMD_CONNECT.
     */
    while (!pwm->tcp_conn->rc && pwm->cmd == ASYNC_CMD_DNS) {
	fd_set rfds, wfds;
	struct timeval tv, *tvp;

	FD_ZERO(&rfds);
	FD_ZERO(&wfds);
	n = ares_fds(pwm->tcp_conn->chan, &rfds, &wfds);
	/* ares_timeout() may return NULL; never hand select() an unset tv. */
	tvp = ares_timeout(pwm->tcp_conn->chan, NULL, &tv);
#ifdef WITH_LIBPTH
	n = pth_select(n, &rfds, &wfds, NULL, tvp);
#else
	n = select(n, &rfds, &wfds, NULL, tvp);
#endif

	if (n == -1) {
	    if (errno == EINTR)
		continue;

	    rc = gpg_error_from_syserror();
	    goto fail;
	}

	if (n == 0) {
	    /*
	     * c-ares' per-try timer expired.  ares_process() with empty sets
	     * lets c-ares retry or report ARES_ETIMEOUT through the callback.
	     */
	    FD_ZERO(&rfds);
	    FD_ZERO(&wfds);
	}

	ares_process(pwm->tcp_conn->chan, &rfds, &wfds);
    }

    if (pwm->tcp_conn->rc) {
	rc = pwm->tcp_conn->rc;
	goto fail;
    }

    rc = _setup_ssh_session(pwm);
    pwm->cmd = ASYNC_CMD_NONE;

    if (pwm->tcp_conn)
	pwm->tcp_conn->cmd = ASYNC_CMD_NONE;

    return rc;

fail:
    _free_ssh_conn(pwm->tcp_conn);
    pwm->tcp_conn = NULL;
    pwm->cmd = ASYNC_CMD_NONE;
    return rc;
}

/*
 * ssh://[username@]hostname[:port],identity,known_hosts
 *
 * Splits STR (the part after "ssh://") into newly allocated *USER, *HOST,
 * *IDENTITY and *KNOWN_HOSTS strings and the integer *PORT (-1 if absent).
 * Missing fields are left NULL.  The username/host/port part ends at the
 * first ','.  The '@' and ':' separators are only looked for there, so
 * identity or known_hosts paths containing those characters are parsed
 * intact.
 *
 * Any missing parameters are checked for in init_tcp_conn().
 * Returns 0 or an ENOMEM error.  Fields already allocated are left in the
 * output pointers on error.
 */
gpg_error_t _parse_ssh_url(char *str, char **host, int *port, char **user,
	char **identity, char **known_hosts)
{
    char *p;
    char *t;
    char *hostend;
    int len;

    *host = *user = *identity = *known_hosts = NULL;
    *port = -1;

    hostend = strchr(str, ',');

    if (!hostend)
	hostend = str + strlen(str);

    /* Find the last '@' of the host part; p ends just after it. */
    for (p = hostend; p > str && p[-1] != '@'; p--)
	;

    if (p > str) {
	/* Room for the username plus its terminator (the '@' slot). */
	len = (int)(p - str);
	*user = pwmd_malloc(len);

	if (!*user)
	    return gpg_error_from_errno(ENOMEM);

	snprintf(*user, len, "%s", str);
    }
    else
	p = str;

    t = memchr(p, ':', (size_t)(hostend - p));

    if (t) {
	len = (int)(t - p) + 1;
	*host = pwmd_malloc(len);

	if (!*host)
	    return gpg_error_from_errno(ENOMEM);

	snprintf(*host, len, "%s", p);
	t++;
	*port = atoi(t);

	/* isdigit() on a negative char is undefined, hence the cast. */
	while (*t && isdigit((unsigned char)*t))
	    t++;

	p = t;
    }

    t = strchr(p, ',');

    if (t) {
	char *t2;

	if (!*host) {
	    len = strlen(p)-strlen(t)+1;
	    *host = pwmd_malloc(len);

	    if (!*host)
		return gpg_error_from_errno(ENOMEM);

	    snprintf(*host, len, "%s", p);
	}

	t++;
	t2 = strchr(t, ',');

	if (t2)
	    len = strlen(t)-strlen(t2)+1;
	else
	    len = strlen(t)+1;

	*identity = pwmd_malloc(len);

	if (!*identity)
	    return gpg_error_from_errno(ENOMEM);

	snprintf(*identity, len, "%s", t);

	if (t2) {
	    t2++;
	    len = strlen(t2)+1;
	    *known_hosts = pwmd_malloc(len);

	    if (!*known_hosts)
		return gpg_error_from_errno(ENOMEM);

	    snprintf(*known_hosts, len, "%s", t2);
	}
    }
    else {
	if (!*host) {
	    len = strlen(p)+1;
	    *host = pwmd_malloc(len);

	    if (!*host)
		return gpg_error_from_errno(ENOMEM);

	    snprintf(*host, len, "%s", p);
	}
    }

    return 0;
}

/*
 * Tear down the SSH connection attached to PWM (if any) through
 * ssh_deinit() and forget it.
 */
void _ssh_disconnect(pwm_t *pwm)
{
    pwmd_tcp_conn_t *conn = pwm->tcp_conn;

    ssh_deinit(conn);
    pwm->tcp_conn = NULL;
}
