#! /usr/bin/python2
#
#    ssh-multiadd - add multiple ssh keys, maybe some with the same passphrase 
#    Copyright (C) 2001-2002  Matthew Mueller <donut@azstarnet.com>
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program; if not, write to the Free Software
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

import os,errno,pty,sys,getopt,re,commands

def pathfind(pl, path=os.environ.get('PATH',os.defpath).split(os.pathsep), notfound=None, pathadd=None):
	"""Return the path of the first existing file among the candidate names pl.

	pl is a single name or a sequence of names.  Names are tried in order; each
	one is looked up in every directory of path (by default the PATH captured
	when this module was loaded) and then in the extra directories pathadd.
	The first existing candidate is returned.

	If nothing is found, notfound is returned, or the bare first name pl[0]
	when notfound is None.
	"""
	if isinstance(pl, str):
		pl = (pl,)
	dirs = list(path) + list(pathadd or [])
	for name in pl:
		for d in dirs:
			candidate = os.path.join(d, name)
			if os.path.exists(candidate):
				return candidate
	# if we don't find it, return with no explicit path so that we get a useful
	# "file not found" error rather than some "bad operand for +" if we
	# returned (and tried to use) None
	if notfound is None:
		return pl[0]
	return notfound

# default configuration (the user's rc file, executed below, may override any of it):
# key file names (joined to sshdir) used when no keyfiles are given on the command line
keys = ('identity', 'id_dsa')
# programs to run; pathfind() falls back to the bare name when it finds nothing
ssh = pathfind('ssh')
sshadd = pathfind('ssh-add')
# $SSH_ASKPASS if set, else the first of the listed askpass programs found on PATH or
# in /usr/lib/ssh (bare 'ssh-askpass' if none is found)
sshaskpass = os.environ.get('SSH_ASKPASS', pathfind(('ssh-askpass','x11-ssh-askpass','ssh-askpass2','ssh-askpass1','gnome-ssh-askpass'),pathadd=[os.path.join(os.sep,'usr','lib','ssh')]))
# ~/.ssh, else ~/.ssh2, else '' (key names are then used as given, relative to the cwd)
sshdir = pathfind(('.ssh','.ssh2'), [os.path.expanduser('~')], '')
# 'yes', 'no' or 'auto': whether to run sshaskpass for passphrases; 'auto' = only when
# stdin is not a tty
useaskpass = 'auto'
# verbosity: >0 also prints debug output, <0 hides informational messages,
# <-1 hides error messages too
verbose = 0
# after adding: 0 = print nothing, 1 = list all agent identities, 2 = list only the changes
listidentities = 0
# how to detect already-loaded keys: -1 = auto-detect from `ssh -V`, None = don't check
# (add every key), otherwise a (shell command template with %s for the key path,
# regex whose group 1 is searched for in `ssh-add -l` output) tuple
sshgetfingerprint = -1

# user configuration: ~/.ssh-multiadd.rc.py, if present, is executed as Python code in
# this module's namespace, so it can reassign any of the settings above.
conffile = os.path.join(os.path.expanduser('~'),'.ssh-multiadd.rc.py')
if os.path.exists(conffile):
	execfile(conffile)

def p(s,nl=1,outf=sys.stdout):
	"""Write s to outf (stdout unless given), followed by a newline when nl is true."""
	outf.write(s)
	if not nl:
		return
	outf.write('\n')

def pverbose(s,nl=1):
	"""Print s to stdout, but only in debug mode (verbose > 0)."""
	if verbose <= 0:
		return
	p(s, nl)

def pinfo(s,nl=1):
	"""Print an informational message to stdout unless quieted (verbose < 0)."""
	if verbose < 0:
		return
	p(s, nl)

def perror(s,nl=1):
	"""Print an error message to stderr unless silenced (verbose < -1)."""
	if verbose < -1:
		return
	p(s, nl, outf=sys.stderr)


def ptyopen(blah,args):
	"""Run the program at path blah with argument list args on a new pseudo-terminal.

	Kinda like popen: the child is started via pty.fork(), so its controlling
	terminal is a pty, and the parent gets back the master file descriptor to
	read the child's output from and write its input to.

	The forked child never returns into the caller: if execv fails it writes
	the error to its stderr (the pty, so the parent sees it) and exits with
	status 127.
	"""
	pverbose('ptyopen: '+blah+' '+str(args))
	pid,fd=pty.fork()
	if pid==0:
		# child: execv only returns on failure.  Whatever happens, don't fall
		# back into the parent's code path (which would run it twice and run the
		# parent's exit handlers in the child).
		try:
			try:
				os.execv(blah,[blah]+args)
			except OSError:
				sys.stderr.write('%s: %s\n'%(blah,sys.exc_info()[1]))
				sys.stderr.flush()
		finally:
			os._exit(127)
	return fd
def writen(fd, data):
	"""Write all of data to the file descriptor fd.

	os.write may accept only part of the buffer, so keep writing the
	remainder until nothing is left.  data is a byte string (str on Python 2,
	bytes on Python 3); an empty buffer writes nothing.
	"""
	# test emptiness by truthiness rather than `!= ''`, which is always true
	# for a bytes buffer on Python 3 and would spin forever on os.write(fd, b'')
	while data:
		n = os.write(fd, data)
		data = data[n:]

def exitstatus_str(st):
	"""Describe a wait()-style status word st as readable text.

	Stopped and signal-killed statuses are checked before a normal exit; a
	status matching none of them is shown raw as '(??: <st>)'.
	"""
	checks = (
		(os.WIFSTOPPED, 'stopped: sig %i', os.WSTOPSIG),
		(os.WIFSIGNALED, 'killed: sig %i', os.WTERMSIG),
		(os.WIFEXITED, 'exit status: %i', os.WEXITSTATUS),
	)
	for matches, template, decode in checks:
		if matches(st):
			return template % decode(st)
	return '(??: %i)'%st

def getoutput(cmd):
	"""Run the shell command cmd and return its output as given by commands.getstatusoutput.

	A non-zero exit status is tolerated.  Only when cmd did not exit normally
	(e.g. it was killed by a signal) is the failure reported on stderr and the
	program terminated with status 1.
	"""
	pverbose('running %r...'%cmd)
	status, output = commands.getstatusoutput(cmd)
	if os.WIFEXITED(status):
		return output
	perror('%r %s (%s)'%(cmd, exitstatus_str(status), output))
	sys.exit(1)

def getoutput_checkstatus(cmd):
	"""Run the shell command cmd and return its output as given by commands.getstatusoutput.

	Any non-zero status, including a normal non-zero exit, is fatal: the
	failure is reported on stderr and the program exits with status 1.
	"""
	pverbose('running %r...'%cmd)
	status, output = commands.getstatusoutput(cmd)
	if not status:
		return output
	perror('%r %s (%s)'%(cmd, exitstatus_str(status), output))
	sys.exit(1)

def askpass(prompt=''):
	"""Ask the user for a passphrase and return it; an empty answer is asked again.

	When useaskpass is 'yes', or 'auto' and stdin is not a tty, the external
	sshaskpass program is run with prompt as its single argument (a failing
	askpass terminates the program via getoutput_checkstatus).  Otherwise the
	passphrase is read from the terminal with getpass.
	"""
	p=''
	while p=='':
		if useaskpass=='yes' or (useaskpass=='auto' and not sys.stdin.isatty()):
			cmd=sshaskpass
			if prompt:
				# POSIX shell quoting: wrap in single quotes and turn each embedded '
				# into '\'' -- repr() is not shell quoting (it may emit double
				# quotes, in which $ and ` are expanded, or Python-only escapes).
				# The prompt contains key names taken from ssh-add's output.
				cmd=cmd+" '"+prompt.replace("'","'\\''")+"'"
			p=getoutput_checkstatus(cmd)
		else:
			import getpass
			p=getpass.getpass(prompt or 'enter pass: ')
	return p

version = "1.3.2"  # reported by --version and in the --help banner

def printusage(err=0):
	"""Print the usage summary to stderr and exit with status err.

	The value in parentheses after an option is its current setting (built-in
	default, rc-file value, or whatever earlier command-line options set), and
	'(default)' marks the currently selected mode.
	"""
	def isdef(c): return c and ' (default)' or ''
	perror('Usage: ssh-multiadd [opts] [keyfiles (%s)]'%', '.join(keys))
	# show the current setting, like the other options, instead of a hard-coded 'auto'
	perror(' -a <a>    use ssh-askpass: yes, no or auto (%s)'%useaskpass)
	perror(' -A <p>    ssh-askpass (%s)'%sshaskpass)
	perror(' -s <p>    ssh-add (%s)'%sshadd)
	perror(' -S <p>    ssh (%s)'%ssh)
	perror(' -d <d>    dir to look for keyfiles (%s)'%sshdir)
	perror(' -f        force all keys to be added'+isdef(sshgetfingerprint is None))
	perror(' -l        lists all identities represented by the agent when done'+isdef(listidentities==1))
	perror(' -L        lists differences in represented identities when done'+isdef(listidentities==2))
	perror(" --nolist  disable -l/-L"+isdef(listidentities==0))
	perror(' --debug   show debugging output')
	perror(' -h/--help show help')
	perror(' --version show version')
	sys.exit(err)
def printhelp():
	"""Print the version/copyright banner and the usage text to stderr, then exit with status 0."""
	banner = 'ssh-multiadd v%s - Copyright (C) 2001-2002 Matthew Mueller - GPL license'%version
	perror(banner)
	printusage(err=0)

def main(argv):
	"""Parse the options in argv, then add the requested keys to the ssh agent.

	argv is the argument list without the program name.  Non-option arguments
	are key file names (joined to sshdir); without any, the configured keys are
	used.  Keys whose fingerprint already appears in `ssh-add -l` output are
	skipped, unless fingerprint checking is disabled (-f).  The remaining keys
	go to a single ssh-add run on a pty.  Every passphrase the user types is
	remembered, and for each new key the remembered passphrases are tried first,
	in order, before the user is asked again.  So keys sharing a passphrase only
	need it typed once.

	Exits via sys.exit on usage errors, when a helper command fails, and when
	ssh-add itself fails.
	"""
	global useaskpass,sshaskpass,sshadd,sshdir,verbose,listidentities,keys,sshgetfingerprint,ssh
	try:
		optlist, args = getopt.getopt(argv, 'a:A:s:d:S:flLh?', ['help','version','debug','nolist'])
	except getopt.error, a:
		perror("ssh-multiadd: %s"%a)
		printusage(1)

	for o,a in optlist:
		if o=='-a':
			if a in ('yes','no','auto'):
				useaskpass=a
			else:
				perror("invalid -a arg '%s'"%a)
				printusage(1)
		elif o=='-A':
			sshaskpass=a
		elif o=='-s':
			sshadd=a
		elif o=='-d':
			sshdir=a
		elif o=='-S':
			ssh=a
		elif o=='-f':
			sshgetfingerprint=None
		elif o=='-l':
			listidentities=1
		elif o=='-L':
			listidentities=2
		elif o=='--nolist':
			listidentities=0
		elif o=='--debug':
			verbose=1
		elif o=='-h' or o=='-?' or o=='--help':
			printhelp()
		elif o=='--version':
			print version
			sys.exit(0)

	keys = [os.path.join(sshdir,d) for d in args or keys]

	if sshgetfingerprint == -1:
		# auto-detect how to fingerprint a key for this ssh flavour: a tuple of
		# (shell command template taking the key path, regex whose group 1 is
		# the text searched for in `ssh-add -l` output)
		sver = getoutput_checkstatus(ssh + ' -V').splitlines()[0]
		if sver.find('OpenSSH')>=0:
			sdir = os.path.split(ssh)[0]
			sshgetfingerprint = (os.path.join(sdir,'ssh-keygen -l -f "%s"'), r'(\S+\s+\S+)')
		elif sver.find('SSH Version 1')>=0:
			sshgetfingerprint = (pathfind('cat')+' "%s.pub"', r'(\S+\s+\S+\s+\S+)')
		elif sver.find('SSH Version 2')>=0:
			sshgetfingerprint = (pathfind('basename')+' "%s"', '(.+)') #kludge since sshv2's output is minimal
		else:
			sshgetfingerprint = None
		pverbose('ssh ver %s, sshgetfingerprint=%s'%(sver,sshgetfingerprint))

	if listidentities==2 or sshgetfingerprint:
		pre_identities=getoutput(sshadd + ' -l').split(os.linesep)
		pverbose('pre_identities=%s'%(pre_identities))
	if sshgetfingerprint:
		for k in keys[:]: # iterate over a copy: already-loaded keys are removed from keys
			fp = getoutput_checkstatus(sshgetfingerprint[0]%k)
			pverbose('%s = %s'%(k,fp))
			x = re.match(sshgetfingerprint[1],fp)
			if not x:
				# unexpected fingerprint output: we can't tell whether the key is
				# loaded, so keep it and let ssh-add deal with it
				perror("can't parse fingerprint of %s: %s"%(k,fp))
				continue
			kid = x.group(1)
			for pid in pre_identities:
				if pid.find(kid)>=0:
					pinfo('key %s already loaded'%(k))
					keys.remove(k)
					break
		if not keys:
			pinfo('all keys loaded already, exiting')
			return

	# we have to use a pty since ssh-add will always call ssh_askpass if its stdin is not a tty.
	fd=ptyopen(sshadd, keys)
	s='' # output received since the last passphrase was sent
	passphrases=[] # every passphrase the user entered, in order
	readded=re.compile(r"added: (.*)$",re.M|re.I)#commercial ssh2 doesn't say this at all. oh well.
	rebadpass=re.compile(r"Bad passphrase",re.I)
	#commercial ssh doesn't say "enter passphrase for ...", so get the key name from the "need" message
	reneed=re.compile(r"Need passphrase for (.*)",re.I)
	# the " for <key>" part is optional (see reneed above); when it's absent group(1) is None
	reenterpass=re.compile(r"Enter passphrase(?: for (.*?))?:? ",re.I)
	# state of the prompt currently being answered; initialized here so that an
	# unexpected output order can't hit unbound names
	curkey='(unknown key)' # key name shown when asking the user
	curpass=0 # index into passphrases of the next one to send
	curaskextra='' # prefix for the user prompt after a typed passphrase was rejected
	askedyet=0 # whether the user was already asked for the current key
	while 1:
		try:
			r=os.read(fd,1024)
		except OSError, err:
			if err[0]==errno.EAGAIN or err[0]==errno.EINTR:
				continue #dunno if these can happen here, the docs don't say.. be safe. :)
			if err[0]==errno.EIO:
				break #seems to return oserror on eof :)
			raise
		if not r:
			# an empty read is end of file as well (some platforms report the
			# child closing the pty this way instead of EIO); looping on it
			# would never terminate
			break
		pverbose('read: {'+r+'}')
		s += r
		while 1: #multiple added messages could be in the same read if no passphrase was used..
			x=readded.search(s)
			if not x:
				break
			pinfo('added: %s'%x.group(1))
			s=s[:x.start(0)]+s[x.end(0):] #cut the added part out so we don't see it again

		x=reneed.search(s)
		if x:
			curkey=x.group(1)
		x=reenterpass.search(s)
		if x:
			# new key: start again with the first remembered passphrase
			curpass=0
			if x.group(1):
				curkey=x.group(1)
			curaskextra=''
			askedyet=0
		else:
			x=rebadpass.search(s)
			if not x:
				continue
			if askedyet and not curaskextra: #if we haven't asked yet for this key, any bad pass messages will just be from trying previous passphrases.
				curaskextra='Bad Pass. '
		if curpass>=len(passphrases):
			# all remembered passphrases were tried: ask the user for a new one
			passphrases.append(askpass('%sEnter passphrase for %s: '%(curaskextra,curkey)))
			askedyet=1
		writen(fd,passphrases[curpass]+'\n')
		curpass+=1
		pverbose('<pass>')
		s=''
	os.close(fd)
	w=os.wait()
	if w[1]:
		perror('%s[%i] %s'%(sshadd,w[0],exitstatus_str(w[1])))
		x=re.search('(.*)$',s) #. does not match \n without re.DOTALL
		if x:
			perror('last line of output was: %s'%x.group(1)) #we only want to print the last line in case a passphrase was in that output somewhere... and just because its cleaner :)
		sys.exit(1)
	pverbose('wait: '+str(w))

	if listidentities==2:
		# "agent has no identities" / "agent has N ..." summary lines aren't identities
		excl=re.compile('agent has (no|[0-9]+) ',re.I)
		post_identities=getoutput(sshadd + ' -l').split(os.linesep)
		added=[x for x in post_identities if x not in pre_identities and not excl.search(x)]
		delled=[x for x in pre_identities if x not in post_identities and not excl.search(x)]
		if added or delled:
			for x in added: print '+'+x
			for x in delled: print '-'+x
		else:
			pverbose("no change in agent's identities")
	elif listidentities:
		os.system(sshadd + ' -l')

if __name__=='__main__':
	try:
		main(sys.argv[1:])
	except KeyboardInterrupt:
		# in debug mode print the traceback; otherwise exit quietly on ^C, but
		# with a failure status (128+SIGINT, as a shell reports it) so that
		# callers don't mistake an interrupted run for success
		if verbose>0: raise
		sys.exit(130)
