about | summary | refs | log | tree | commit | diff
path: root/ec2/spark_ec2.py
diff options
context:
space:
mode:
author: Jey Kottalam <jey@cs.berkeley.edu> 2013-07-03 16:57:22 -0700
committer: Jey Kottalam <jey@cs.berkeley.edu> 2013-09-06 14:51:47 -0700
commit: b98572c70ad3932381a55f23f82600d7e435d2eb (patch)
tree: 055327eb12aa3b205f7695422243eef33e9f4d58 /ec2/spark_ec2.py
parent: 6919a28d51c416ff4bb647b03eae2070cf87f039 (diff)
download: spark-b98572c70ad3932381a55f23f82600d7e435d2eb.tar.gz
spark-b98572c70ad3932381a55f23f82600d7e435d2eb.tar.bz2
spark-b98572c70ad3932381a55f23f82600d7e435d2eb.zip
Generate new SSH key for the cluster, make "--identity-file" optional
Diffstat (limited to 'ec2/spark_ec2.py')
-rwxr-xr-x ec2/spark_ec2.py 58
1 files changed, 37 insertions, 21 deletions
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 75dd0ffa61..0858b126c5 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -104,11 +104,7 @@ def parse_args():
parser.print_help()
sys.exit(1)
(action, cluster_name) = args
- if opts.identity_file == None and action in ['launch', 'login', 'start']:
- print >> stderr, ("ERROR: The -i or --identity-file argument is " +
- "required for " + action)
- sys.exit(1)
-
+
# Boto config check
# http://boto.cloudhackers.com/en/latest/boto_config_tut.html
home_dir = os.getenv('HOME')
@@ -392,10 +388,18 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
master = master_nodes[0].public_dns_name
if deploy_ssh_key:
- print "Copying SSH key %s to master..." % opts.identity_file
- ssh(master, opts, 'mkdir -p ~/.ssh')
- scp(master, opts, opts.identity_file, '~/.ssh/id_rsa')
- ssh(master, opts, 'chmod 600 ~/.ssh/id_rsa')
+ print "Generating cluster's SSH key on master..."
+ key_setup = """
+ [ -f ~/.ssh/id_rsa ] ||
+ (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa &&
+ cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys)
+ """
+ ssh(master, opts, key_setup)
+ dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh'])
+ print "Transferring cluster's SSH key to slaves..."
+ for slave in slave_nodes:
+ print slave.public_dns_name
+ ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar)
modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs',
'mapreduce', 'spark-standalone']
@@ -556,7 +560,9 @@ def stringify_command(parts):
def ssh_args(opts):
- parts = ['-o', 'StrictHostKeyChecking=no', '-i', opts.identity_file]
+ parts = ['-o', 'StrictHostKeyChecking=no']
+ if opts.identity_file is not None:
+ parts += ['-i', opts.identity_file]
return parts
@@ -564,16 +570,6 @@ def ssh_command(opts):
return ['ssh'] + ssh_args(opts)
-def scp_command(opts):
- return ['scp', '-q'] + ssh_args(opts)
-
-
-# Copy a file to a given host through scp, throwing an exception if scp fails
-def scp(host, opts, local_file, dest_file):
- subprocess.check_call(
- scp_command(opts) + [local_file, "%s@%s:%s" % (opts.user, host, dest_file)])
-
-
# Run a command on a host through ssh, retrying up to two times
# and then throwing an exception if ssh continues to fail.
def ssh(host, opts, command):
@@ -585,13 +581,33 @@ def ssh(host, opts, command):
except subprocess.CalledProcessError as e:
if (tries > 2):
raise e
- print "Couldn't connect to host {0}, waiting 30 seconds".format(e)
+ print "Error connecting to host, sleeping 30: {0}".format(e)
time.sleep(30)
tries = tries + 1
+def ssh_read(host, opts, command):
+ return subprocess.check_output(
+ ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)])
+def ssh_write(host, opts, command, input):
+ tries = 0
+ while True:
+ proc = subprocess.Popen(
+ ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)],
+ stdin=subprocess.PIPE)
+ proc.stdin.write(input)
+ proc.stdin.close()
+ if proc.wait() == 0:
+ break
+ elif (tries > 2):
+ raise RuntimeError("ssh_write error %s" % proc.returncode)
+ else:
+ print "Error connecting to host, sleeping 30"
+ time.sleep(30)
+ tries = tries + 1
+
# Gets a list of zones to launch instances in
def get_zones(conn, opts):