From 6919a28d51c416ff4bb647b03eae2070cf87f039 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Fri, 17 May 2013 17:10:47 -0700 Subject: Construct shell commands as sequences for safety and composability --- ec2/spark_ec2.py | 45 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 11 deletions(-) (limited to 'ec2') diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 932e70db96..75dd0ffa61 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -23,6 +23,7 @@ from __future__ import with_statement import logging import os +import pipes import random import shutil import subprocess @@ -536,18 +537,41 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): dest.write(text) dest.close() # rsync the whole directory over to the master machine - command = (("rsync -rv -e 'ssh -o StrictHostKeyChecking=no -i %s' " + - "'%s/' '%s@%s:/'") % (opts.identity_file, tmp_dir, opts.user, active_master)) - subprocess.check_call(command, shell=True) + command = [ + 'rsync', '-rv', + '-e', stringify_command(ssh_command(opts)), + "%s/" % tmp_dir, + "%s@%s:/" % (opts.user, active_master) + ] + subprocess.check_call(command) # Remove the temp directory we created above shutil.rmtree(tmp_dir) +def stringify_command(parts): + if isinstance(parts, str): + return parts + else: + return ' '.join(map(pipes.quote, parts)) + + +def ssh_args(opts): + parts = ['-o', 'StrictHostKeyChecking=no', '-i', opts.identity_file] + return parts + + +def ssh_command(opts): + return ['ssh'] + ssh_args(opts) + + +def scp_command(opts): + return ['scp', '-q'] + ssh_args(opts) + + # Copy a file to a given host through scp, throwing an exception if scp fails def scp(host, opts, local_file, dest_file): subprocess.check_call( - "scp -q -o StrictHostKeyChecking=no -i %s '%s' '%s@%s:%s'" % - (opts.identity_file, local_file, opts.user, host, dest_file), shell=True) + scp_command(opts) + [local_file, "%s@%s:%s" % (opts.user, host, dest_file)]) # Run a command on a host through ssh, retrying up to two times @@ -557,8 +581,7 @@ def ssh(host, opts, command): while True: try: return subprocess.check_call( - "ssh -t -o StrictHostKeyChecking=no -i %s %s@%s '%s'" % - (opts.identity_file, opts.user, host, command), shell=True) + ssh_command(opts) + ['-t', '%s@%s' % (opts.user, host), stringify_command(command)]) except subprocess.CalledProcessError as e: if (tries > 2): raise e @@ -670,11 +693,11 @@ def main(): conn, opts, cluster_name) master = master_nodes[0].public_dns_name print "Logging into master " + master + "..." - proxy_opt = "" + proxy_opt = [] if opts.proxy_port != None: - proxy_opt = "-D " + opts.proxy_port - subprocess.check_call("ssh -o StrictHostKeyChecking=no -i %s %s %s@%s" % - (opts.identity_file, proxy_opt, opts.user, master), shell=True) + proxy_opt = ['-D', opts.proxy_port] + subprocess.check_call( + ssh_command(opts) + proxy_opt + ['-t', "%s@%s" % (opts.user, master)]) elif action == "get-master": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) -- cgit v1.2.3 From b98572c70ad3932381a55f23f82600d7e435d2eb Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Wed, 3 Jul 2013 16:57:22 -0700 Subject: Generate new SSH key for the cluster, make "--identity-file" optional --- ec2/spark_ec2.py | 58 ++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 37 insertions(+), 21 deletions(-) (limited to 'ec2') diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 75dd0ffa61..0858b126c5 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -104,11 +104,7 @@ def parse_args(): parser.print_help() sys.exit(1) (action, cluster_name) = args - if opts.identity_file == None and action in ['launch', 'login', 'start']: - print >> stderr, ("ERROR: The -i or --identity-file argument is " + - "required for " + action) - sys.exit(1) - + # Boto config check # http://boto.cloudhackers.com/en/latest/boto_config_tut.html home_dir = os.getenv('HOME') @@ -392,10 +388,18 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): master = master_nodes[0].public_dns_name if deploy_ssh_key: - print "Copying SSH key %s to master..." % opts.identity_file - ssh(master, opts, 'mkdir -p ~/.ssh') - scp(master, opts, opts.identity_file, '~/.ssh/id_rsa') - ssh(master, opts, 'chmod 600 ~/.ssh/id_rsa') + print "Generating cluster's SSH key on master..." + key_setup = """ + [ -f ~/.ssh/id_rsa ] || + (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa && + cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys) + """ + ssh(master, opts, key_setup) + dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) + print "Transferring cluster's SSH key to slaves..." + for slave in slave_nodes: + print slave.public_dns_name + ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs', 'mapreduce', 'spark-standalone'] @@ -556,7 +560,9 @@ def stringify_command(parts): def ssh_args(opts): - parts = ['-o', 'StrictHostKeyChecking=no', '-i', opts.identity_file] + parts = ['-o', 'StrictHostKeyChecking=no'] + if opts.identity_file is not None: + parts += ['-i', opts.identity_file] return parts @@ -564,16 +570,6 @@ def ssh_command(opts): return ['ssh'] + ssh_args(opts) -def scp_command(opts): - return ['scp', '-q'] + ssh_args(opts) - - -# Copy a file to a given host through scp, throwing an exception if scp fails -def scp(host, opts, local_file, dest_file): - subprocess.check_call( - scp_command(opts) + [local_file, "%s@%s:%s" % (opts.user, host, dest_file)]) - - # Run a command on a host through ssh, retrying up to two times # and then throwing an exception if ssh continues to fail. def ssh(host, opts, command): @@ -585,13 +581,33 @@ def ssh(host, opts, command): except subprocess.CalledProcessError as e: if (tries > 2): raise e - print "Couldn't connect to host {0}, waiting 30 seconds".format(e) + print "Error connecting to host, sleeping 30: {0}".format(e) time.sleep(30) tries = tries + 1 +def ssh_read(host, opts, command): + return subprocess.check_output( + ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)]) +def ssh_write(host, opts, command, input): + tries = 0 + while True: + proc = subprocess.Popen( + ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)], + stdin=subprocess.PIPE) + proc.stdin.write(input) + proc.stdin.close() + if proc.wait() == 0: + break + elif (tries > 2): + raise RuntimeError("ssh_write error %s" % proc.returncode) + else: + print "Error connecting to host, sleeping 30" + time.sleep(30) + tries = tries + 1 + # Gets a list of zones to launch instances in def get_zones(conn, opts): -- cgit v1.2.3 From e86d1d4a52147fe52feeda74ca3558f6bc109285 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Wed, 11 Sep 2013 14:59:42 -0700 Subject: Clarify error messages on SSH failure --- ec2/spark_ec2.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) (limited to 'ec2') diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 0858b126c5..f4babba9b9 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -37,6 +37,9 @@ import boto from boto.ec2.blockdevicemapping import BlockDeviceMapping, EBSBlockDeviceType from boto import ec2 +class UsageError(Exception): + pass + # A URL prefix from which to fetch AMI information AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list" @@ -580,8 +583,12 @@ def ssh(host, opts, command): ssh_command(opts) + ['-t', '%s@%s' % (opts.user, host), stringify_command(command)]) except subprocess.CalledProcessError as e: if (tries > 2): - raise e - print "Error connecting to host, sleeping 30: {0}".format(e) + # If this was an ssh failure, provide the user with hints. + if e.returncode == 255: + raise UsageError("Failed to SSH to remote host {0}.\nPlease check that you have provided the correct --identity-file and --key-pair parameters and try again.".format(host)) + else: + raise e + print >> stderr, "Error executing remote command, retrying after 30 seconds: {0}".format(e) time.sleep(30) tries = tries + 1 @@ -599,12 +606,13 @@ def ssh_write(host, opts, command, input): stdin=subprocess.PIPE) proc.stdin.write(input) proc.stdin.close() - if proc.wait() == 0: + status = proc.wait() + if status == 0: break elif (tries > 2): - raise RuntimeError("ssh_write error %s" % proc.returncode) + raise RuntimeError("ssh_write failed with error %s" % proc.returncode) else: - print "Error connecting to host, sleeping 30" + print >> stderr, "Error {0} while executing remote command, retrying after 30 seconds".format(status) time.sleep(30) tries = tries + 1 @@ -626,7 +634,7 @@ def get_partition(total, num_partitions, current_partitions): return num_slaves_this_zone -def main(): +def real_main(): (opts, action, cluster_name) = parse_args() try: conn = ec2.connect_to_region(opts.region) @@ -755,6 +763,13 @@ def main(): sys.exit(1) +def main(): + try: + real_main() + except UsageError, e: + print >> stderr, "\nError:\n", e + + if __name__ == "__main__": logging.basicConfig() main() -- cgit v1.2.3