#!/usr/bin/env python
# ssh-installkeys -- prepare a remote account for ssh access
"""This script tries to export ssh public keys to a specified site.  It will
walk you through generating key pairs if it doesn't find any to export.
It handles all the fiddly details like making sure local and remote
permissions are correct, and tells you what it's doing if it has to change
anything.

Options:
-h        Print usage summary and exit
-c        Check local and remote configuration only, don't try to fix problems.
-d        Delete remote .ssh configuration.
-v        Report all commands sent to host and output received.
-p <port> Connect on a particular port in place of 22

Author: Eric S. Raymond <esr@thyrsus.com>"""

# $Id: ssh-installkeys,v 1.18 2004/06/21 12:53:48 esr Exp $

import os, sys, getopt, string, getpass, tempfile

###############################################################################
#
# pexpect 0.94 is inlined here because it's not in the Python library yet.
# Some comments and trivial unused methods have been removed.
#

"""
Pexpect is a Python module for spawning child applications;
controlling them; and responding to expected patterns in their output.
Pexpect can be used for automating interactive applications such as
ssh, ftp, passwd, telnet, etc. It can be used to automate setup scripts
for duplicating software package installations on different servers.
It can be used for automated software testing. Pexpect is in the spirit of
Don Libes' Expect, but Pexpect is pure Python. Other Expect-like
modules for Python require TCL and Expect or require C extensions to
be compiled. Pexpect does not use C, Expect, or TCL extensions. It
should work on any platform that supports the standard Python pty
module. The Pexpect interface focuses on ease of use so that simple
tasks are easy.

Pexpect is Open Source, free, and all that stuff.
License: Python Software Foundation License
         http://www.opensource.org/licenses/PythonSoftFoundation.html

Noah Spurrier
"""
import os, sys
import select
# import signal
import traceback
import re
import struct
import resource
from types import *
try:
    import pty
    import tty
    import termios
    import fcntl
except ImportError, e:
    raise ImportError, str(e) + """
A critical module was not found. Probably this OS does not support it.
Currently pexpect is intended for UNIX operating systems (including OS-X)."""

# Version of the inlined pexpect module.
__version__ = '0.94'
# Revision keyword string; presumably filled in by the revision control
# system on checkin.
__revision__ = '$Revision: 1.18 $'
# Names exported by "from <module> import *".
__all__ = ['ExceptionPexpect', 'EOF', 'TIMEOUT', 'spawn',
    '__version__', '__revision__']

# Exception classes used by this module.
class ExceptionPexpect(Exception):
    """Base class for all exceptions raised by this module.

    The object given to the constructor is kept in the 'value' attribute.
    str() of the exception is the repr() of that object.
    """
    def __init__(self, value):
        # Hand the value to Exception too, so that 'args' is populated and
        # generic exception handling code can see it.
        Exception.__init__(self, value)
        self.value = value
    def __str__(self):
        return repr(self.value)
class EOF(ExceptionPexpect):
    """Raised when EOF is read from a child.

    The class object itself may also be used as an entry in the pattern
    list given to spawn.expect() or spawn.expect_list(); reaching end of
    file then returns the index of that entry instead of raising.
    """
class TIMEOUT(ExceptionPexpect):
    """Raised when a read time exceeds the timeout.

    spawn.read() raises this when select() reports nothing readable on
    the child's file descriptor within the timeout it was given.
    """
##class MAXBUFFER(ExceptionPexpect):
##    """Raised when a scan buffer fills before matching an expected pattern."""

class spawn:
    """This is the main class interface for Pexpect. Use this class to
    start and control child applications.

    Instance attributes:
        command  -- the program that was started
        args     -- the full argument vector, command included
        timeout  -- default number of seconds the expect calls wait
        pid      -- process id of the child
        child_fd -- file descriptor of the child's pty, -1 once closed
        before, after, match -- results of the most recent expect call
        flag_eof -- set to 1 once read() has reported end of file
    """

    def __init__(self, command, args=None, timeout=30):
        """This is the constructor. The command parameter is a string
        that includes the path and any arguments to the command. For example:
            p = pexpect.spawn ('/usr/bin/ftp')
            p = pexpect.spawn ('/usr/bin/ssh some@host.com')
            p = pexpect.spawn ('/bin/ls -latr /tmp')
        You may also construct it with a list of arguments like so:
            p = pexpect.spawn ('/usr/bin/ftp', [])
            p = pexpect.spawn ('/usr/bin/ssh', ['some@host.com'])
            p = pexpect.spawn ('/bin/ls', ['-latr', '/tmp'])
        After this the child application will be created and
        will be ready for action. See expect() and send()/sendline().

        The args list passed in is copied, not modified.
        Raises ExceptionPexpect if the command is empty, cannot be found,
        or the pty fork fails.
        """
        # Set these first so that close() (and therefore __del__) is safe
        # even if something below raises.
        self.child_fd = -1
        self.pid = None
        self.flag_eof = 0

        self.STDIN_FILENO = sys.stdin.fileno()
        self.STDOUT_FILENO = sys.stdout.fileno()
        self.STDERR_FILENO = sys.stderr.fileno()

        # TODO: implement a maximum buffer size / maximum search size so
        # that very long output is preserved but not searched in full.

        self.timeout = timeout
        self.before = None
        self.after = None
        self.match = None

        if args is None:
            self.args = split_command_line(command)
            if not self.args:
                raise ExceptionPexpect('The command was empty.')
            self.command = self.args[0]
        else:
            # Build a new list; inserting into the caller's list would
            # change it behind the caller's back.
            self.args = [command] + list(args)
            self.command = command

        self.__spawn()

    def __del__(self):
        """This makes sure that no system resources are left open.
        Python only garbage collects Python objects. OS file descriptors
        are not Python objects, so they must be handled manually.
        """
        self.close()

    def __spawn(self):
        """This starts the given command in a child process. This does
        all the fork/exec type of stuff for a pty. This is called by
        __init__. It sets the pid and child_fd members.
        """
        # Note that it is difficult for this method to fail.
        # You cannot detect if the child process cannot start.
        # So the only way you can tell if the child process started
        # or not is to try to read from the file descriptor. If you get
        # EOF immediately then it means that the child is already dead.
        # That may not necessarily be bad, because you may spawn a child
        # that performs some operation, creates no output, and then dies.

        assert self.pid is None, 'The pid member is not None.'
        assert self.command is not None, 'The command member is None.'

        if which(self.command) is None:
            raise ExceptionPexpect ('The command was not found or was not executable: %s.' % self.command)

        try:
            self.pid, self.child_fd = pty.fork()
        except OSError:
            raise ExceptionPexpect('Pexpect: pty.fork() failed: ' + str(sys.exc_info()[1]))

        if self.pid == 0: # Child
            try:
                # Do not allow child to inherit open file descriptors from parent.
                max_fd = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
                for fd in range(3, max_fd):
                    try:
                        os.close(fd)
                    except OSError:
                        pass
                os.execvp(self.command, self.args)
            finally:
                # Only reached when the exec failed.  The forked child must
                # never return into the caller's code, so leave at once
                # (127 is the conventional "could not execute" status).
                os._exit(127)

        # Parent: nothing more to do.

    def fileno (self):
        """This returns the file descriptor of the pty for the child."""
        return self.child_fd

    def close (self):
        """This is experimental.
        This closes the file descriptor of the child application.
        It makes no attempt to actually kill the child or wait for its status.
        Calling it again after the descriptor is closed does nothing.
        """
        if self.child_fd != -1:
            os.close(self.child_fd)
            self.child_fd = -1

    def set_echo (self, on):
        """This sets the terminal echo-mode on or off."""
        attrs = termios.tcgetattr(self.child_fd)
        if on:
            attrs[3] = attrs[3] | termios.ECHO # lflags
        else:
            attrs[3] = attrs[3] & ~termios.ECHO # lflags
        termios.tcsetattr(self.child_fd, termios.TCSADRAIN, attrs)

    def compile_pattern_list(self, pattern):
        """This compiles a pattern-string or a list of pattern-strings.
        This is used by expect() when calling expect_list().
        Thus expect() is nothing more than::
             cpl = self.compile_pattern_list(pl)
             return self.expect_list(cpl, timeout)

        If you are using expect() within a loop it may be more
        efficient to compile the patterns first and call expect_list():
             cpl = self.compile_pattern_list(my_pattern)
             while some_condition:
                ...
                i = self.expect_list(cpl, timeout)
                ...

        The pattern may be EOF, a string, or a list or tuple whose entries
        are strings or EOF. EOF entries are passed through uncompiled.
        Returns a list. Raises TypeError for any other kind of argument.
        """
        if pattern is EOF:
            return [EOF]
        if isinstance(pattern, str):
            return [re.compile(pattern)]
        if isinstance(pattern, (list, tuple)):
            compiled = []
            for entry in pattern:
                if entry is EOF:
                    compiled.append(EOF)
                else:
                    compiled.append(re.compile(entry))
            return compiled
        raise TypeError('Pattern argument must be a string or a list of strings.')

    def expect(self, pattern, timeout = None):
        """This seeks through the stream looking for the given
        pattern. The 'pattern' can be a string or a list of strings.
        The strings are regular expressions (see module 're'). This
        returns the index into the pattern list or raises an exception
        on error. If timeout is None the instance's default is used.

        After a match is found the instance attributes
        'before', 'after' and 'match' will be set.
        You can see all the data read before the match in 'before'.
        You can see the data that was matched in 'after'.
        The re.MatchObject used in the re match will be in 'match'.
        If an error occured then 'before' will be set to all the
        data read so far and 'after' and 'match' will be None.

        Note: A list entry may be EOF instead of a string.
        This will catch EOF exceptions and return the index
        of the EOF entry instead of raising the EOF exception.
        The attribute 'after' will then be the EOF class and
        'match' will be None.
        This allows you to write code like this:
                index = p.expect (['good', 'bad', pexpect.EOF])
                if index == 0:
                        do_something()
                elif index == 1:
                        do_something_else()
                elif index == 2:
                        do_some_other_thing()
        instead of code like this:
                try:
                        index = p.expect (['good', 'bad'])
                        if index == 0:
                                do_something()
                        elif index == 1:
                                do_something_else()
                except EOF:
                        do_some_other_thing()
        These two forms are equivalent. It all depends on what you want.
        You can also just expect the EOF if you are waiting for all output
        of a child to finish. For example:
                p = pexpect.spawn('/bin/ls')
                p.expect (pexpect.EOF)
                print p.before
        """
        compiled_pattern_list = self.compile_pattern_list(pattern)
        return self.expect_list(compiled_pattern_list, timeout)

    def expect_exact (self, pattern_list, local_timeout = None):
        """This is similar to expect() except that it takes
        list of regular strings instead of compiled regular expressions.
        The idea is that this should be much faster. It could also be
        useful when you don't want to have to worry about escaping
        regular expression characters that you want to match.
        You may also pass just a string without a list and the single
        string will be converted to a list.

        A list entry may be EOF, as for expect(). If local_timeout is
        None the instance's default timeout is used. 'match' is always
        None after calling this.
        """
        if isinstance(pattern_list, str):
            pattern_list = [pattern_list]

        def locate(target, text):
            # Plain substring search; there is no match object to report.
            start = text.find(target)
            if start >= 0:
                return (start, None)
            return None

        return self._expect_loop(pattern_list, local_timeout, locate)

    def expect_list(self, pattern_list, local_timeout = None):
        """This is called by expect(). This takes a list of compiled
        regular expressions (entries may also be EOF). This returns the
        index into the pattern_list that matched the child's output.
        If local_timeout is None the instance's default timeout is used.
        """
        def locate(cre, text):
            found = cre.search(text)
            if found is not None:
                return (found.start(), found)
            return None

        return self._expect_loop(pattern_list, local_timeout, locate)

    def _expect_loop(self, pattern_list, local_timeout, locate):
        """Shared read-and-match loop behind expect_exact() and expect_list().

        Reads the child's output one character at a time and, after each
        character, tries every non-EOF entry of pattern_list in order.
        locate(pattern, text) must return None for no match, or a tuple
        (start_offset, match_object_or_None).

        Returns the index of the first entry that matches, or the index of
        the EOF entry if end of file is reached and EOF is in the list.
        Any other exception from read() (including EOF when it is not
        listed, and TIMEOUT) is re-raised after 'before' has been set to
        the data read so far and 'after' and 'match' to None.
        """
        if local_timeout is None:
            local_timeout = self.timeout
        # Work on a list so that "in" and index() behave the same for
        # lists and tuples.
        pattern_list = list(pattern_list)

        incoming = ''
        try:
            while 1: # Keep reading until exception or return.
                incoming = incoming + self.read(1, local_timeout)

                # Sequence through the list of patterns and look for a match.
                for index, pattern in enumerate(pattern_list):
                    if pattern is EOF:
                        continue # The EOF entry is not something to search for.
                    found = locate(pattern, incoming)
                    if found is not None:
                        start, match = found
                        self.before = incoming[ : start]
                        self.after = incoming[start : ]
                        self.match = match
                        return index
        except EOF:
            self.before = incoming
            self.match = None # never leave a stale match from an earlier call
            if EOF in pattern_list:
                self.after = EOF
                return pattern_list.index(EOF)
            self.after = None
            raise
        except Exception:
            self.before = incoming
            self.after = None
            self.match = None
            raise

    def read(self, n, timeout = None):
        """This reads up to n characters from the child application.
        It includes a timeout. If the read does not complete within the
        timeout period then a TIMEOUT exception is raised.
        If the end of file is read then an EOF exception will be raised
        and the flag_eof member is set to 1.

        Notice that if this method is called with timeout=None
        then it actually may block.

        This is a non-blocking wrapper around os.read().
        It uses select.select() to supply a timeout.
        """
        # Note that some systems like Solaris don't seem to ever give
        # an EOF when the child dies. In fact, you can still try to read
        # from the child_fd -- it will block forever or until TIMEOUT.
        # For this case, I test isalive() before doing any reading.
        # If isalive() is false and nothing is left to read, then I
        # pretend that this is the same as EOF.
        if not self.isalive():
            ready, w, e = select.select([self.child_fd], [], [], 0)
            if not ready:
                self.flag_eof = 1
                raise EOF ('End Of File (EOF) in read(). Braindead platform.')

        ready, w, e = select.select([self.child_fd], [], [], timeout)
        if not ready:
            raise TIMEOUT('Timeout exceeded in read().')

        if self.child_fd in ready:
            try:
                data = os.read(self.child_fd, n)
            except OSError:
                self.flag_eof = 1
                raise EOF('End Of File (EOF) in read(). Exception style platform.')
            if data == '':
                self.flag_eof = 1
                raise EOF('End Of File (EOF) in read(). Empty string style platform.')
            return data

        raise ExceptionPexpect('Reached an unexpected state in read().')

    def send(self, text):
        """This sends a string to the child process.
        This returns the number of bytes actually written.
        """
        return os.write(self.child_fd, text)

    def sendline(self, text=''):
        """This is like send(), but it adds a line separator.
        This returns the number of bytes actually written.
        """
        written = self.send(text)
        return written + self.send(os.linesep)

    def send_eof(self):
        """This sends an EOF to the child.

        More precisely: this sends a character which causes the pending
        child buffer to be sent to the waiting user program without
        waiting for end-of-line. If it is the first character of the
        line, the read() in the user program returns 0, which
        signifies end-of-file.

        This means: to make this work as expected a send_eof() has to be
        called at the begining of a line. A newline character is not
        sent here to avoid problems.
        """
        # termios.CEOF is an integer; os.write() needs the character.
        # Fall back to CTRL-D (4) where the platform does not define it.
        eof_char = '%c' % getattr(termios, 'CEOF', 4)

        # NOTE(review): the terminal whose mode is adjusted here is our own
        # stdin, not the child's pty -- kept as in the original; confirm
        # whether child_fd was intended.
        fd = sys.stdin.fileno()
        old = termios.tcgetattr(fd) # remember current state
        new = termios.tcgetattr(fd)
        new[3] = new[3] | termios.ICANON        # lflags
        # use try/finally to ensure state gets restored
        try:
            # EOF is recognized when ICANON is set, thus ensure it is:
            termios.tcsetattr(fd, termios.TCSADRAIN, new)
            os.write(self.child_fd, eof_char)
        finally:
            termios.tcsetattr(fd, termios.TCSADRAIN, old) # restore state

    def isalive(self):
        """This tests if the child process is running or not.
        This returns 1 if the child process appears to be running or 0 if not.

        The test is a non-blocking os.waitpid(). A result of (0, 0) means
        no state change was reported, which is taken as "still running".
        According to the original author Solaris returns (0, 0) for a
        defunct child as well, so there a dead child may be reported alive.
        Signals are deliberately not used.
        """
        try:
            pid, status = os.waitpid(self.pid, os.WNOHANG)
        except OSError:
            # No such child (for instance it has already been waited for).
            return 0

        if pid == 0:
            # Nothing to report yet.
            return 1

        # The child changed state; a clean exit has status 0, so the status
        # value alone must not be read as "alive".
        # I didn't OR this together because I want hooks for debugging.
        if os.WIFEXITED (status):
            return 0
        elif os.WIFSTOPPED (status):
            return 0
        elif os.WIFSIGNALED (status):
            return 0
        else:
            return 1

    def kill(self, sig):
        """This sends the given signal to the child application.
        In keeping with UNIX tradition it has a misleading name.
        It does not necessarily kill the child unless
        you send the right signal.
        """
        # Same as os.kill, but the pid is given for you.
        os.kill(self.pid, sig)

def which (filename):
    """This takes a given filename and tries to find it in the
    environment path and check if it is executable.

    A filename that contains a directory part is checked as given and is
    not looked up along the path, which is how os.execvp() treats it.
    Otherwise each directory of $PATH (or os.defpath when $PATH is unset
    or empty) is tried in order.

    Returns the path of the first regular file found that is executable,
    or None if there is none.
    """
    def runnable(candidate):
        # os.access() alone also says yes for searchable directories.
        return os.path.isfile(candidate) and os.access(candidate, os.X_OK)

    # Special case where filename already contains a path.
    if os.path.dirname(filename) != '':
        if runnable(filename):
            return filename
        return None

    search_path = os.environ.get('PATH') or os.defpath

    for directory in search_path.split(os.pathsep):
        candidate = os.path.join(directory, filename)
        if runnable(candidate):
            return candidate
    return None

def split_command_line(command_line):
    """This splits a command line into a list of arguments.
    It splits arguments on spaces, but handles
    embedded quotes, doublequotes, and escaped characters.

    Rules:
      * A backslash makes the next character literal (the backslash
        itself is dropped), inside or outside quotes.
      * Text between a pair of single quotes or a pair of double quotes
        is kept together, spaces included; the quotes are dropped. The
        other kind of quote inside a quoted section is an ordinary
        character.
      * Runs of spaces separate arguments and never produce empty
        arguments; an explicitly quoted empty string ('' or "") does.
      * An unterminated quote simply runs to the end of the line.

    Returns the list of arguments; an empty or all-space line gives [].
    """
    arg_list = []
    pieces = []     # characters of the argument being collected
    started = 0     # true once the current argument has begun
    quote = None    # the quote character we are inside of, if any
    escaped = 0     # true right after a backslash

    for c in command_line:
        if escaped:
            # Escape mode lasts for exactly one character.
            pieces.append(c)
            escaped = 0
        elif c == '\\':
            escaped = 1
            started = 1
        elif quote is not None:
            if c == quote:
                quote = None
            else:
                pieces.append(c)
        elif c == "'" or c == '"':
            quote = c
            started = 1
        elif c == ' ':
            if started:
                arg_list.append(''.join(pieces))
                pieces = []
                started = 0
        else:
            pieces.append(c)
            started = 1

    # Handle last argument.
    if started:
        arg_list.append(''.join(pieces))
    return arg_list

#
#
###############################################################################

class ssh_session:
    """Session with extra state including the password to be used.

    Holds the remote user, host and port, a verbosity flag, and the
    password, which is asked for once (on the first password prompt) and
    then reused for later commands.
    """
    def __init__(self, user, port, host, password=None, verbose=0):
        self.user = user
        self.host = host
        self.verbose = verbose
        self.password = password
        self.port = port
    def __exec(self, command):
        """Execute a local ssh/scp command line.  Return the output.

        Answers password prompts with the stored password (prompting the
        user for it if none is stored yet) and answers the host-key
        "continue connecting" question with "yes". Returns what the child
        printed after the last prompt that was answered.
        """
        if self.verbose:
            sys.stderr.write("-> " + command + "\n")
        child = spawn(command)
        # These are regular expressions, so the parenthesis of the host-key
        # question has to be escaped to be matched literally.
        prompts = [EOF, 'assword:', r'continue connecting \(yes/no']
        while True:
            seen = child.expect(prompts)
            if self.verbose:
                sys.stderr.write("<- " + child.before + "\n")
            if seen == 0:
                break
            elif seen == 1:
                if self.verbose:
                    sys.stderr.write("Saw password prompt...\n")
                if not self.password:
                    self.password = getpass.getpass('Remote password: ')
                child.sendline(self.password)
            elif seen == 2:
                if self.verbose:
                    sys.stderr.write("Saw 'continue connecting' prompt...\n")
                child.sendline("yes")
        return child.before
    def ssh(self, command):
        "Run command on the remote host via ssh.  Return its output."
        # Use this session's own settings, not those of some global object.
        return self.__exec("ssh -p %s -l %s %s \"%s\""
                           % (self.port, self.user, self.host, command))
    def scp(self, src, dst):
        "Copy local file src to dst on the remote host.  Return scp's output."
        return self.__exec("scp -P %s %s %s@%s:%s"
                           % (self.port, src, self.user, self.host, dst))
    def exists(self, file):
        """Retrieve file permissions of specified remote file.

        Runs "/bin/ls -ld" on the remote host. Returns None if the output
        says "No such file", otherwise the first whitespace-separated
        field of the output (the permission field of the listing).
        """
        seen = self.ssh("/bin/ls -ld %s" % file)
        if self.verbose:
            sys.stderr.write("Existence check on %s gave %s\n" % (file, repr(seen)))
        if "No such file" in seen:
            return None                 # File doesn't exist
        return seen.split()[0]          # Return permission field of listing.

def die(msg):
    "Print msg followed by a farewell line, then exit with status 1."
    for line in (msg, "Goodbye."):
        print(line)
    raise SystemExit(1)

def fixperms(session, file, checkonly=None):
    """Force perms of remote file to user read on, group/other write off.

    session   -- object providing exists(file) and ssh(command)
    file      -- remote path to examine
    checkonly -- if true, only report; do not change anything remotely

    Returns true if the permissions were already acceptable, false
    otherwise (also when they had to be corrected, or when the file's
    permissions could not be read at all).
    """
    perms = session.exists(file)
    # exists() gives None for a missing file; anything shorter than the
    # nine permission characters plus type flag cannot be judged either.
    if perms is None or len(perms) < 9:
        print("Remote %s doesn't exist or its permissions could not be read."
              % file)
        return False
    ok = (perms[1] == 'r' and perms[5] == '-' and perms[8] == '-')
    print("Remote %s permissions are %s which is %s."
          % (file, perms, ("wrong", "OK")[ok]))
    if not ok:
        print("User read should be on, Group and Other write permissions should be off.")
        if not checkonly:
            print("Correcting permissions...")
            session.ssh("/bin/chmod u+r,g-w,o-w %s" % file)
    return ok

def forcefile(session, file, command, checkonly=None):
    """Create file with command if it does not already exist.

    session   -- object providing exists(file) and ssh(command)
    file      -- remote path that must exist
    command   -- remote command that creates it
    checkonly -- if true, do not create anything; bail out via die()
                 when the file is missing

    Returns the permission string reported by session.exists() for the
    file (after creation, if it had to be created).
    """
    perms = session.exists(file)
    if perms is not None:
        print("%s already exists." % file)
        return perms

    print("Remote %s doesn't exist." % file)
    # Honor the caller's flag rather than whatever a global happens to say.
    if checkonly:
        die("Bailing out...")
    print("Creating remote %s" % file)
    session.ssh(command)
    return session.exists(file)

if __name__ == '__main__':
    # Process options.  Only -p takes an argument; -v is a plain flag.
    try:
        (options, arguments) = getopt.getopt(sys.argv[1:], "cdhvp:")
    except getopt.GetoptError:
        print(str(sys.exc_info()[1]))
        print(__doc__)
        raise SystemExit(2)
    check = 0
    verbose = 0
    delete = 0
    port = 22

    for (switch, val) in options:
        if switch == '-c':
            check = 1
        elif switch == '-d':
            delete = 1
        elif switch == '-p':
            port = val
        elif switch == '-h':
            print(__doc__)
            raise SystemExit(0)
        elif switch == '-v':
            verbose = 1

    if not arguments:
        print("No remote host specified.")
        print(__doc__)
        raise SystemExit(1)
    host = arguments[0]
    if "@" in host:
        (login, host) = host.split("@", 1)
    else:
        login = os.environ.get("LOGNAME") or os.environ.get("USER")
        if not login:
            die("Can't determine your login name; give the host as user@host.")

    # Verify that key pairs are present on the local side
    print("Checking your local configuration...")
    try:
        os.chdir(os.getenv("HOME"))
    except (OSError, TypeError):
        # TypeError covers $HOME being unset (os.getenv gives None).
        die("Can't find your home directory!")
    if not os.access(".", os.R_OK):
        print("Your home directory is not readable by you. That's weird.")
        die("This probably means $HOME is incorrect.")
    if not os.path.exists(".ssh"):
        print("You have no .ssh directory.")
        if check:
            # Check-only mode must not change anything.
            print("Creation of .ssh suppressed.")
            die("Try re-running this script without the check-only option.")
        print("Creating your .ssh directory.")
        os.mkdir(".ssh", int("700", 8))     # private to the user
    if not os.access(".ssh", os.R_OK):
        print("Your .ssh directory is not readable by you. That's weird.")
        die("This probably means $HOME is incorrect.")
    try:
        os.chdir(".ssh")
    except OSError:
        die("Can't chdir into your .ssh directory!")
    public_keys = [name for name in ("id_dsa.pub", "id_rsa.pub", "identity.pub")
                   if os.access(name, os.R_OK)]
    if not public_keys:
        if check:
            print("I don't see any key pairs.  You need to generate some.")
            die("Try re-running this script without the check-only option.")
        public_keys = []
        # NOTE(review): these are the key types the script has always
        # generated; confirm the installed ssh-keygen still accepts them.
        for (keytype, file) in (("rsa1", "identity"), ("dsa", "id_dsa")):
            print("About to generate a %s key. You will be prompted for a passphrase," % keytype)
            print("Pressing enter at that prompt will create a key with no passphrase")
            os.system("ssh-keygen -t %s -f %s" % (keytype, file))
            public_keys.append(file + ".pub")

    print("I see the following public keys: " + ", ".join(public_keys))
    for file in public_keys:
        private_key = file[:-4]             # strip the ".pub" suffix
        legend = private_key + " corresponding to " + file + " "
        if os.access(private_key, os.R_OK):
            print(legend + "exists and is readable.")
        else:
            die(legend + "does not exist or is unreadable.")

    print("Local configuration looks correct.")

    if delete:
        status = os.system('ssh -p %s -l %s %s "rm -fr .ssh"'
                           % (port, login, host))
        if status != 0:
            die("Could not delete keys for %s on %s." % (login, host))
        print("Keys for %s on %s have been deleted." % (login, host))
        raise SystemExit(0)

    # Update or check
    # mkstemp creates the file itself, so nobody can slip one in under the
    # same name between choosing the name and writing to it.
    (keyfd, keyfile) = tempfile.mkstemp()
    installed = 0
    try:
        consolidated = os.fdopen(keyfd, "wb")
        try:
            for file in public_keys:
                source = open(file, "rb")
                try:
                    consolidated.write(source.read())
                finally:
                    source.close()
        finally:
            consolidated.close()
        if verbose:
            print("Public keys consolidated in %s" % keyfile)

        # Get session to remote host.
        print("Starting session to %s..." % host)
        session = ssh_session(login, port, host, verbose=verbose)

        # Permissions checks
        fixperms(session, "~", check)
        forcefile(session, "~/.ssh", "mkdir ~/.ssh", check)
        fixperms(session, "~/.ssh", check)

        # OK, install keys remotely if they don't exist
        if not check:
            if session.exists("~/.ssh/authorized_keys"):
                print("Remote ~/.ssh/authorized_keys already exists; leaving it alone.")
            else:
                session.scp(keyfile, "~/.ssh/authorized_keys")
                installed = 1
    finally:
        # Always clean up, but let errors and die() carry on upward so a
        # failed run is never reported as a success.
        os.remove(keyfile)
    # It's all over now...
    if check:
        print("Check is complete.")
    elif installed:
        print("Keys for %s on %s have been installed." % (login, host))

# End
