/* server_gssapi.c
 *
 * GSSAPI authentication method
 */

/* lsh, an implementation of the ssh protocol
 *
 * Copyright (C) 2003 Simon Josefsson
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#if HAVE_CONFIG_H
#include "config.h"
#endif

#include "charset.h"
#include "format.h"
#include "parse.h"
#include "ssh.h"
#include "server_userauth.h"
#include "werror.h"
#include "xalloc.h"

#if WITH_GSS

#if WITH_GSS_K5
#include <krb5.h> /* for guserok */
#ifdef HAVE_GSSAPI_H
#include <gssapi.h>
#endif
#ifdef HAVE_GSSAPI_GSSAPI_H
#include <gssapi/gssapi.h>
#endif
#else /* !WITH_GSS_K5 */
#include <gss.h>
#endif /* WITH_GSS_K5 */

#include "server_gssapi.c.x"

/* Collect the textual description(s) GSS-API offers for status CODE.
 *
 * TYPE is handed unchanged to gss_display_status; the caller in this
 * file passes GSS_C_GSS_CODE or GSS_C_MECH_CODE.  gss_display_status is
 * called repeatedly until it reports no further messages (msg_ctx back
 * to zero) and the messages are joined with newlines.
 *
 * Returns a newly allocated string owned by the caller.  The result is
 * never NULL: if no message could be obtained an empty string is
 * returned, so callers can pass it directly to ssh_format or verbose.
 */
static struct lsh_string *
get_status_1 (OM_uint32 code, int type)
{
  OM_uint32 maj_stat, min_stat;
  struct lsh_string *str = NULL;
  OM_uint32 msg_ctx = 0;

  do
    {
      /* Start from an empty buffer so it is well defined even if the
         call below fails without touching it. */
      gss_buffer_desc msg = GSS_C_EMPTY_BUFFER;

      maj_stat = gss_display_status (&min_stat, code, type, GSS_C_NULL_OID,
                                     &msg_ctx, &msg);
      /* On failure there is nothing to append or release, and msg_ctx
         cannot be trusted to terminate the loop. */
      if (GSS_ERROR(maj_stat))
        break;

      if (str)
        str = ssh_format("%flS\n%ls", str, msg.length, (char *) msg.value);
      else
        str = ssh_format("%ls", msg.length, (char *) msg.value);

      gss_release_buffer (&min_stat, &msg);
    }
  while (msg_ctx);

  if (!str)
    str = ssh_format("");

  return str;
}

/* Build a two-line description of a GSS-API failure: the text for the
 * major status MAJ_STAT on the first line and the text for the
 * mechanism-specific minor status MIN_STAT on the second.
 *
 * Both intermediate strings are handed to ssh_format with the `f'
 * modifier (presumably transferring ownership to it); the caller owns
 * the returned string.
 */
static struct lsh_string *
get_status (OM_uint32 maj_stat, OM_uint32 min_stat)
{
  struct lsh_string *major_text = get_status_1 (maj_stat, GSS_C_GSS_CODE);
  struct lsh_string *minor_text = get_status_1 (min_stat, GSS_C_MECH_CODE);

  return ssh_format ("%flS\n%flS", major_text, minor_text);
}

static struct packet_handler *
make_gssapi_token_handler(struct gssapi_server_instance *gssapi);

/* GABA:
   (class
     (name userauth_gssapi)
     (super userauth)
     (vars
       (db object user_db)))
*/

/* GABA:
   (class
     (name gssapi_server_instance)
     (super userauth_gssapi)
     (vars
       (user object lsh_string)
       (cont object command_continuation)
       (db object user_db)
       (e object exception_handler)
       (cred . gss_cred_id_t)
       (ctx . gss_ctx_id_t)
       (mech . gss_OID_desc)
       (client . gss_name_t)))
*/


/* Release the GSS-API resources held by GSSAPI: the security context,
 * the acceptor credential and the client name, each only if set.
 *
 * Failures are reported through verbose() and otherwise ignored, so
 * every resource gets a release attempt regardless of earlier errors.
 */
static void
do_gc_gssapi (struct gssapi_server_instance *gssapi)
{
  OM_uint32 maj_stat, min_stat;

  if (gssapi->ctx)
    {
      maj_stat = gss_delete_sec_context (&min_stat, &gssapi->ctx,
                                         GSS_C_NO_BUFFER);
      if (GSS_ERROR(maj_stat))
        verbose("GSSAPI error deleting security context: %fS\n",
                get_status(maj_stat, min_stat));
    }

  if (gssapi->cred)
    {
      maj_stat = gss_release_cred (&min_stat, &gssapi->cred);
      if (GSS_ERROR(maj_stat))
        verbose("GSSAPI error releasing credential: %fS\n",
                get_status(maj_stat, min_stat));
    }

  if (gssapi->client)
    {
      maj_stat = gss_release_name(&min_stat, &gssapi->client);
      if (GSS_ERROR(maj_stat))
        verbose("GSSAPI error releasing client name: %fS\n",
                get_status(maj_stat, min_stat));
    }
}

/* Validate the DER header of an encoded OBJECT IDENTIFIER.
 *
 * DER points at LENGTH octets that should consist of the tag 0x06, a
 * definite length (short form, or long form with 1-4 length octets)
 * and exactly that many content octets.  Returns the size of the
 * header (tag plus length octets), or 0 if the encoding is malformed,
 * empty, or its encoded length disagrees with LENGTH.
 */
static unsigned
gssapi_der_oid_header_length(const uint8_t *der, uint32_t length)
{
  unsigned header;
  uint32_t content;

  if (length < 2 || der[0] != 0x06)
    return 0;

  if (der[1] & 0x80)
    {
      /* Long form: the low seven bits count the length octets that
         follow.  Zero would mean indefinite length, which DER forbids. */
      unsigned n = der[1] & 0x7f;
      unsigned k;

      if (n == 0 || n > 4 || length < 2 + n)
        return 0;

      content = 0;
      for (k = 0; k < n; k++)
        content = (content << 8) | der[2 + k];

      header = 2 + n;
    }
  else
    {
      /* Short form: the octet itself is the content length. */
      content = der[1];
      header = 2;
    }

  if (content == 0 || content != length - header)
    return 0;

  return header;
}

/* Handle a "gssapi" userauth request.
 *
 * ARGS holds a uint32 mechanism count followed by one string per
 * mechanism, each a DER-encoded OID.  The mechanisms are tried in
 * order; for the first one for which an acceptor credential can be
 * acquired, a token handler is installed on CONNECTION and a
 * SSH_MSG_USERAUTH_GSSAPI_RESPONSE naming that mechanism is raised
 * through E as a userauth special exception.
 *
 * Handlers left installed by an earlier GSS attempt on the same
 * connection are cleaned up before the new token handler is installed.
 *
 * If no mechanism is usable an EXC_USERAUTH exception is raised; a
 * malformed request raises a protocol error.
 */
static void
do_authenticate(struct userauth *s,
                struct ssh_connection *connection,
                struct lsh_string *username,
                uint32_t service UNUSED,
                struct simple_buffer *args,
                struct command_continuation *c,
                struct exception_handler *e)
{
  CAST(userauth_gssapi, self, s);
  NEW(gssapi_server_instance, gssapi);
  uint32_t number_of_mechanisms;
  OM_uint32 maj_stat, min_stat;
  static const struct exception gssapi_acquire_cred
    = STATIC_EXCEPTION(EXC_USERAUTH,
                       "Cannot acquire credential for any mechanism.");

  gssapi->user = username;
  gssapi->db = self->db;
  gssapi->cont = c;
  gssapi->e = e;
  gssapi->ctx = GSS_C_NO_CONTEXT;
  gssapi->cred = GSS_C_NO_CREDENTIAL;
  /* do_gc_gssapi tests this field, so it must start out well defined. */
  gssapi->client = GSS_C_NO_NAME;

  if (parse_uint32(args, &number_of_mechanisms))
    {
      uint32_t i;

      verbose("Client requests %i GSS mechanism(s).\n", number_of_mechanisms);

      for (i = 0; i < number_of_mechanisms; i++)
        {
          gss_OID_set desired_mechs = GSS_C_NO_OID_SET;
          gss_OID_desc tmp;
          OM_uint32 junk;
          unsigned header;

          if (!parse_string(args, &gssapi->mech.length,
                            (const uint8_t**)&gssapi->mech.elements))
            goto fail;

          header = gssapi_der_oid_header_length
            ((const uint8_t *) gssapi->mech.elements, gssapi->mech.length);
          if (!header)
            goto fail;

          /* GSS-API wants the bare OID contents, without the DER
             tag and length octets. */
          tmp.length = gssapi->mech.length - header;
          tmp.elements = (char*)gssapi->mech.elements + header;

          verbose("Acquiring GSS credentials for GSS mechanism %fS (%i).\n",
                  ssh_format("%lxs", tmp.length, (uint8_t*)tmp.elements), i);

          maj_stat =  gss_create_empty_oid_set (&min_stat, &desired_mechs);
          if (GSS_ERROR(maj_stat))
            {
              verbose("GSSAPI error creating OID set: %fS\n",
                      get_status(maj_stat, min_stat));
              continue;
            }

          maj_stat = gss_add_oid_set_member (&min_stat, &tmp, &desired_mechs);
          if (GSS_ERROR(maj_stat))
            {
              verbose("GSSAPI error adding OID to set: %fS\n",
                      get_status(maj_stat, min_stat));
              gss_release_oid_set (&junk, &desired_mechs);
              continue;
            }

          maj_stat = gss_acquire_cred (&min_stat, GSS_C_NO_NAME, 0,
                                       desired_mechs, GSS_C_ACCEPT,
                                       &gssapi->cred, NULL, NULL);
          gss_release_oid_set (&junk, &desired_mechs);
          if (GSS_ERROR(maj_stat))
            {
              verbose("GSSAPI error acquiring credential: %fS\n",
                      get_status(maj_stat, min_stat));
              continue;
            }

          verbose("Ready to continue with mechanism %fS (%i).\n",
                  ssh_format("%lxs", gssapi->mech.length,
                             (uint8_t*)gssapi->mech.elements), i);

          if (connection->dispatch[SSH_MSG_USERAUTH_GSSAPI_TOKEN] !=
              &connection_unimplemented_handler)
            {
              CAST(gssapi_token_handler, token_handler,
                   connection->dispatch[SSH_MSG_USERAUTH_GSSAPI_TOKEN]);
              verbose("Cleaning up attempted GSS authentication.\n");
              do_gc_gssapi(token_handler->gssapi);
              KILL(token_handler);
              /* The slot is overwritten with the new handler below. */
            }

          if (connection->dispatch[SSH_MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE]
              != &connection_unimplemented_handler)
            {
              CAST(gssapi_finish_handler, finish_handler,
                   connection->dispatch
                   [SSH_MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE]);
              verbose("Cleaning up finished GSS authentication.\n");
              do_gc_gssapi(finish_handler->gssapi);
              KILL(finish_handler);
              /* Do not leave the dispatch table pointing at the
                 handler that was just killed. */
              connection->dispatch[SSH_MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE]
                = &connection_unimplemented_handler;
            }

          connection->dispatch[SSH_MSG_USERAUTH_GSSAPI_TOKEN] =
            make_gssapi_token_handler (gssapi);
          EXCEPTION_RAISE(e, make_userauth_special_exception
                          (ssh_format("%c%s", SSH_MSG_USERAUTH_GSSAPI_RESPONSE,
                                      gssapi->mech.length,
                                      (uint8_t*)gssapi->mech.elements),
                           NULL));
          return;
        }
    }

  EXCEPTION_RAISE(e, &gssapi_acquire_cred);
  return;

 fail:
  PROTOCOL_ERROR(e, "Invalid gssapi USERAUTH message.");
}

struct userauth *
make_userauth_gssapi(struct user_db *db)
{
  NEW(userauth_gssapi, self);
  self->super.authenticate = do_authenticate;
  self->db = db;

  return &self->super;
}

/* GABA:
   (class
     (name gssapi_finish_handler)
     (super packet_handler)
     (vars
       (gssapi object gssapi_server_instance)))
*/

/* Decide whether the GSS principal CLIENT_NAME may log in as the local
 * account NAME.
 *
 * Returns 0 if access is allowed, 1 if it is denied, and -1 if the
 * check itself could not be carried out (Kerberos build only).
 *
 * With WITH_GSS_K5 the decision is delegated to krb5_kuserok;
 * otherwise the principal must be byte-for-byte equal to NAME.
 */
static int
gss_userok (gss_buffer_t client_name, const char *name)
{
#if WITH_GSS_K5
  int rc = -1;
  krb5_principal p;
  krb5_context kcontext;

  if (krb5_init_context (&kcontext) != 0)
    return -1;

  /* NOTE(review): client_name->value is used as a C string here; this
     assumes the GSS implementation NUL-terminates displayed names. */
  if (krb5_parse_name (kcontext, client_name->value, &p) == 0)
    {
      rc = krb5_kuserok (kcontext, p, name) ? 0 : 1;
      krb5_free_principal (kcontext, p);
    }

  /* The context is private to this call; release it on every path. */
  krb5_free_context (kcontext);
  return rc;
#else
  return (strlen(name) == client_name->length &&
          memcmp(name, client_name->value, client_name->length) == 0) ? 0 : 1;
#endif
}

/* Packet handler for SSH_MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE.
 *
 * Uninstalls itself from the dispatch table, then maps the GSS client
 * name established by the token exchange to the requested user:
 *
 *   - the client name is rendered with gss_display_name,
 *   - gss_userok decides whether that principal may act as the user,
 *   - the user is looked up in the user database.
 *
 * On success the user object is passed to the stored continuation; on
 * any failure an EXC_USERAUTH exception is raised on the stored
 * exception handler.  The GSS resources of this attempt are released
 * in either case.
 */
static void
do_handle_gssapi_finish(struct packet_handler *s,
                        struct ssh_connection *connection,
                        struct lsh_string *packet UNUSED)
{
  CAST(gssapi_finish_handler, self, s);
  OM_uint32 maj_stat, min_stat;
  gss_buffer_desc client_name;
  const char *requested;
  int authorized;

  verbose("Finishing GSS.\n");

  connection->dispatch[SSH_MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE] =
    &connection_unimplemented_handler;

  /* Serialize handling of userauth requests */
  connection_lock(connection);

  /* Find authenticated user. */
  maj_stat = gss_display_name (&min_stat, self->gssapi->client,
                               &client_name, NULL);
  if (GSS_ERROR(maj_stat))
    {
      static const struct exception gssapi_display_name
        = STATIC_EXCEPTION(EXC_USERAUTH,
                           "Cannot extract username from GSS.");

      verbose("GSSAPI error getting client name: %fS\n",
              get_status(maj_stat, min_stat));

      EXCEPTION_RAISE(self->gssapi->e, &gssapi_display_name);
      goto done;
    }

  /* Check authorization: is GSS user authorized to log on as user?
     A user name that cannot be obtained as a C string (NULL result)
     is treated as not authorized rather than passed on. */
  requested = lsh_get_cstring(self->gssapi->user);
  authorized = requested && gss_userok (&client_name, requested) == 0;

  if (authorized)
    {
      connection->user = USER_LOOKUP(self->gssapi->db,
                                     self->gssapi->user, 1);
      if (connection->user)
        verbose("GSS user %s authorized to log on as %S.\n",
                client_name.length, (char*)client_name.value,
                self->gssapi->user);
      else
        verbose("GSS user %s requested unknown user %S.\n",
                client_name.length, (char*)client_name.value,
                self->gssapi->user);
    }
  else
    verbose("GSS user %s not authorized to log on as %S.\n",
            client_name.length, (char*)client_name.value,
            self->gssapi->user);

  /* The displayed name is only needed for the log messages above;
     release it on every path. */
  maj_stat = gss_release_buffer(&min_stat, &client_name);
  if (GSS_ERROR(maj_stat))
    verbose("GSSAPI error releasing client name: %fS\n",
            get_status(maj_stat, min_stat));

  if (!authorized)
    {
      static const struct exception gssapi_not_auth
        = STATIC_EXCEPTION(EXC_USERAUTH,
                           "GSS user not authorized to log on.");

      EXCEPTION_RAISE(self->gssapi->e, &gssapi_not_auth);
      goto done;
    }

  if (!connection->user)
    {
      static const struct exception no_such_user
        = STATIC_EXCEPTION(EXC_USERAUTH, "No such user");
      EXCEPTION_RAISE(self->gssapi->e, &no_such_user);
      goto done;
    }

  COMMAND_RETURN(self->gssapi->cont, connection->user);

 done:
  do_gc_gssapi (self->gssapi);
}

static struct packet_handler *
make_gssapi_finish_handler(struct gssapi_server_instance *gssapi)
{
  NEW(gssapi_finish_handler, self);
  self->super.handler = do_handle_gssapi_finish;
  self->gssapi = gssapi;

  return &self->super;
}

/* GABA:
   (class
     (name gssapi_token_handler)
     (super packet_handler)
     (vars
       (gssapi object gssapi_server_instance)))
*/

static void
do_handle_gssapi_token(struct packet_handler *s,
		       struct ssh_connection *connection,
		       struct lsh_string *packet)
{
  CAST(gssapi_token_handler, self, s);
  OM_uint32 maj_stat, min_stat;
  OM_uint32 retflags;
  gss_buffer_desc inbuf, outbuf;

  verbose("Received GSS token.\n");

  /* Serialize handling of userauth requests */
  connection_lock(connection);

  /* 1 for the SSH message code and 4 for the uint32_t length */
#define GSS_TOKEN_OFFSET (1 + 4)
  inbuf.value = packet->data + GSS_TOKEN_OFFSET;
  inbuf.length = packet->length - GSS_TOKEN_OFFSET;
  maj_stat = gss_accept_sec_context (&min_stat,
				     &self->gssapi->ctx,
				     self->gssapi->cred,
				     &inbuf,
				     GSS_C_NO_CHANNEL_BINDINGS,
				     &self->gssapi->client,
				     NULL,
				     &outbuf,
				     &retflags,
				     NULL,
				     NULL);
  if (maj_stat != GSS_S_COMPLETE && maj_stat != GSS_S_CONTINUE_NEEDED)
    {
      static const struct exception gssapi_accept_sec_context
	= STATIC_EXCEPTION(EXC_USERAUTH, "Authentication failed.");

      verbose("GSSAPI error accept_sec_context: %fS\n",
	      get_status(maj_stat, min_stat));

      connection->dispatch[SSH_MSG_USERAUTH_GSSAPI_TOKEN] =
	&connection_unimplemented_handler;

      EXCEPTION_RAISE(self->gssapi->e, make_userauth_special_exception
		      (ssh_format("%c%li%li%S%S",
				  SSH_MSG_USERAUTH_GSSAPI_ERROR,
				  maj_stat, min_stat,
				  get_status(maj_stat, min_stat),
				  /* XXX GSS is multilingual, but how
				     do we get the RFC 1766 tag it uses? */
				  ssh_format("en")),
		       NULL));
      /* XXX ERRTOK */
      /* Serialize handling of userauth requests */
      connection_lock(connection);
      EXCEPTION_RAISE(self->gssapi->e, &gssapi_accept_sec_context);
      do_gc_gssapi (self->gssapi);
      return;
    }

  if (maj_stat == GSS_S_COMPLETE)
    {
      verbose("Preparing to finish GSS authentication.\n");
      connection->dispatch[SSH_MSG_USERAUTH_GSSAPI_TOKEN] =
	&connection_unimplemented_handler;
      connection->dispatch[SSH_MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE] =
	make_gssapi_finish_handler(self->gssapi);
    }

  verbose("Sending GSS token.\n");
  EXCEPTION_RAISE(self->gssapi->e, make_userauth_special_exception
		  (ssh_format("%c%s", SSH_MSG_USERAUTH_GSSAPI_TOKEN,
			      outbuf.length, outbuf.value),
		   NULL));
}

static struct packet_handler *
make_gssapi_token_handler(struct gssapi_server_instance *gssapi)
{
  NEW(gssapi_token_handler, self);
  self->super.handler = do_handle_gssapi_token;
  self->gssapi = gssapi;

  return &self->super;
}
#endif /* WITH_GSS */
