-rw-r--r--  core/src/main/scala/spark/MapOutputTracker.scala       |  31
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala               |   1
-rw-r--r--  core/src/main/scala/spark/Utils.scala                  |   8
-rw-r--r--  core/src/main/scala/spark/executor/Executor.scala      |  15
-rw-r--r--  core/src/test/scala/spark/MapOutputTrackerSuite.scala  |  55
-rwxr-xr-x  ec2/spark_ec2.py                                       | 126
6 files changed, 185 insertions(+), 51 deletions(-)
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 6f80f6ac90..50c4183c0e 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -2,6 +2,10 @@ package spark
import java.io._
import java.util.concurrent.ConcurrentHashMap
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
import akka.actor._
import akka.dispatch._
@@ -11,16 +15,13 @@ import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-
-import scheduler.MapStatus
+import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
-import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String)
- extends MapOutputTrackerMessage
+ extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
@@ -88,14 +89,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}
mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}
-
+
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
var array = mapStatuses.get(shuffleId)
array.synchronized {
array(mapId) = status
}
}
-
+
def registerMapOutputs(
shuffleId: Int,
statuses: Array[MapStatus],
@@ -110,7 +111,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
var array = mapStatuses.get(shuffleId)
if (array != null) {
array.synchronized {
- if (array(mapId).address == bmAddress) {
+ if (array(mapId) != null && array(mapId).address == bmAddress) {
array(mapId) = null
}
}
@@ -119,10 +120,10 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
}
-
+
// Remembers which map output locations are currently being fetched on a worker
val fetching = new HashSet[Int]
-
+
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId)
@@ -149,7 +150,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
val host = System.getProperty("spark.hostname", Utils.localHostName)
val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]]
val fetchedStatuses = deserializeStatuses(fetchedBytes)
-
+
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
fetching.synchronized {
@@ -258,8 +259,10 @@ private[spark] object MapOutputTracker {
* sizes up to 35 GB with at most 10% error.
*/
def compressSize(size: Long): Byte = {
- if (size <= 1L) {
+ if (size == 0) {
0
+ } else if (size <= 1L) {
+ 1
} else {
math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
}
@@ -270,7 +273,7 @@ private[spark] object MapOutputTracker {
*/
def decompressSize(compressedSize: Byte): Long = {
if (compressedSize == 0) {
- 1
+ 0
} else {
math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong
}
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 4c6ec6cc6e..9f2b0c42c7 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -68,7 +68,6 @@ object SparkEnv extends Logging {
isMaster: Boolean,
isLocal: Boolean
) : SparkEnv = {
-
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port)
// Bit of a hack: If this is the master and our port was 0 (meaning bind to any free port),
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 567c4b1475..c8799e6de3 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -199,7 +199,13 @@ private object Utils extends Logging {
/**
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
*/
- def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress
+ def localIpAddress(): String = {
+ val defaultIpOverride = System.getenv("SPARK_LOCAL_IP")
+ if (defaultIpOverride != null)
+ defaultIpOverride
+ else
+ InetAddress.getLocalHost.getHostAddress
+ }
private var customHostname: Option[String] = None
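
SPARK_LOCAL_IP lets a deployment pin the address Spark advertises instead of trusting whatever the JVM resolves for the local hostname, which can point at an unroutable interface on multi-homed or misconfigured hosts. A minimal sketch of the same override-then-fallback lookup; the object name is illustrative:

    import java.net.InetAddress

    // Environment override wins; otherwise fall back to the JVM's view of
    // the local host, mirroring the Utils change above.
    object LocalIp {
      def localIpAddress(): String =
        Option(System.getenv("SPARK_LOCAL_IP"))
          .getOrElse(InetAddress.getLocalHost.getHostAddress)
    }
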
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index dfdb22024e..cb29a6b8b4 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -43,6 +43,21 @@ private[spark] class Executor extends Logging {
urlClassLoader = createClassLoader()
Thread.currentThread.setContextClassLoader(urlClassLoader)
+ // Make any thread terminations due to uncaught exceptions kill the entire
+ // executor process to avoid surprising stalls.
+ Thread.setDefaultUncaughtExceptionHandler(
+ new Thread.UncaughtExceptionHandler {
+ override def uncaughtException(thread: Thread, exception: Throwable) {
+ try {
+ logError("Uncaught exception in thread " + thread, exception)
+ System.exit(1)
+ } catch {
+ case t: Throwable => System.exit(2)
+ }
+ }
+ }
+ )
+
// Initialize Spark environment (using system properties read above)
env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
SparkEnv.set(env)
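
The new handler makes the executor fail fast: a worker thread that dies from an uncaught exception now takes the whole JVM down, so the cluster manager can restart the process, instead of leaving a silently degraded executor behind. A standalone sketch of the same JDK hook, runnable outside Spark; names are illustrative:

    // Any thread that dies with an uncaught exception exits the process.
    object FailFastDemo {
      def main(args: Array[String]) {
        Thread.setDefaultUncaughtExceptionHandler(
          new Thread.UncaughtExceptionHandler {
            override def uncaughtException(thread: Thread, exception: Throwable) {
              try {
                System.err.println("Uncaught exception in " + thread + ": " + exception)
                System.exit(1)  // fail fast; a supervisor restarts the process
              } catch {
                case t: Throwable => System.exit(2)  // even the logging failed
              }
            }
          })
        new Thread(new Runnable {
          def run() { throw new RuntimeException("boom") }
        }).start()  // the JVM exits with status 1 instead of stalling
      }
    }
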
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index 4e9717d871..5b4b198960 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -2,10 +2,14 @@ package spark
import org.scalatest.FunSuite
+import akka.actor._
+import spark.scheduler.MapStatus
+import spark.storage.BlockManagerId
+
class MapOutputTrackerSuite extends FunSuite {
test("compressSize") {
assert(MapOutputTracker.compressSize(0L) === 0)
- assert(MapOutputTracker.compressSize(1L) === 0)
+ assert(MapOutputTracker.compressSize(1L) === 1)
assert(MapOutputTracker.compressSize(2L) === 8)
assert(MapOutputTracker.compressSize(10L) === 25)
assert((MapOutputTracker.compressSize(1000000L) & 0xFF) === 145)
@@ -15,11 +19,58 @@ class MapOutputTrackerSuite extends FunSuite {
}
test("decompressSize") {
- assert(MapOutputTracker.decompressSize(0) === 1)
+ assert(MapOutputTracker.decompressSize(0) === 0)
for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) {
val size2 = MapOutputTracker.decompressSize(MapOutputTracker.compressSize(size))
assert(size2 >= 0.99 * size && size2 <= 1.11 * size,
"size " + size + " decompressed to " + size2 + ", which is out of range")
}
}
+
+ test("master start and stop") {
+ val actorSystem = ActorSystem("test")
+ val tracker = new MapOutputTracker(actorSystem, true)
+ tracker.stop()
+ }
+
+ test("master register and fetch") {
+ val actorSystem = ActorSystem("test")
+ val tracker = new MapOutputTracker(actorSystem, true)
+ tracker.registerShuffle(10, 2)
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val compressedSize10000 = MapOutputTracker.compressSize(10000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
+ tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+ Array(compressedSize1000, compressedSize10000)))
+ tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+ Array(compressedSize10000, compressedSize1000)))
+ val statuses = tracker.getServerStatuses(10, 0)
+ assert(statuses.toSeq === Seq((new BlockManagerId("hostA", 1000), size1000),
+ (new BlockManagerId("hostB", 1000), size10000)))
+ tracker.stop()
+ }
+
+ test("master register and unregister and fetch") {
+ val actorSystem = ActorSystem("test")
+ val tracker = new MapOutputTracker(actorSystem, true)
+ tracker.registerShuffle(10, 2)
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val compressedSize10000 = MapOutputTracker.compressSize(10000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
+ tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+ Array(compressedSize1000, compressedSize1000, compressedSize1000)))
+ tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+ Array(compressedSize10000, compressedSize1000, compressedSize1000)))
+
+ // As if we had two simultaneous fetch failures
+ tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
+
+ // The remaining reduce task might try to grab the output despite the shuffle failure;
+ // this should cause it to fail, and the scheduler will ignore the failure due to the
+ // stage already being aborted.
+ intercept[Exception] { tracker.getServerStatuses(10, 1) }
+ tracker.stop()
+ }
}
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 2ca4d8020d..2ab11dbd34 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -19,7 +19,6 @@
from __future__ import with_statement
-import boto
import logging
import os
import random
@@ -32,7 +31,7 @@ import urllib2
from optparse import OptionParser
from sys import stderr
from boto.ec2.blockdevicemapping import BlockDeviceMapping, EBSBlockDeviceType
-
+from boto import ec2
# A static URL from which to figure out the latest Mesos EC2 AMI
LATEST_AMI_URL = "https://s3.amazonaws.com/mesos-images/ids/latest-spark-0.6"
@@ -61,7 +60,9 @@ def parse_args():
parser.add_option("-r", "--region", default="us-east-1",
help="EC2 region zone to launch instances in")
parser.add_option("-z", "--zone", default="",
- help="Availability zone to launch instances in")
+ help="Availability zone to launch instances in, or 'all' to spread " +
+ "slaves across multiple (an additional $0.01/Gb for bandwidth" +
+ "between zones applies)")
parser.add_option("-a", "--ami", default="latest",
help="Amazon Machine Image ID to use, or 'latest' to use latest " +
"available AMI (default: latest)")
@@ -97,14 +98,20 @@ def parse_args():
if opts.cluster_type not in ["mesos", "standalone"] and action == "launch":
print >> stderr, ("ERROR: Invalid cluster type: " + opts.cluster_type)
sys.exit(1)
- if os.getenv('AWS_ACCESS_KEY_ID') == None:
- print >> stderr, ("ERROR: The environment variable AWS_ACCESS_KEY_ID " +
- "must be set")
- sys.exit(1)
- if os.getenv('AWS_SECRET_ACCESS_KEY') == None:
- print >> stderr, ("ERROR: The environment variable AWS_SECRET_ACCESS_KEY " +
- "must be set")
- sys.exit(1)
+
+ # Boto config check
+ # http://boto.cloudhackers.com/en/latest/boto_config_tut.html
+ home_dir = os.getenv('HOME')
+ if home_dir is None or not os.path.isfile(home_dir + '/.boto'):
+ if not os.path.isfile('/etc/boto.cfg'):
+ if os.getenv('AWS_ACCESS_KEY_ID') is None:
+ print >> stderr, ("ERROR: The environment variable AWS_ACCESS_KEY_ID " +
+ "must be set")
+ sys.exit(1)
+ if os.getenv('AWS_SECRET_ACCESS_KEY') is None:
+ print >> stderr, ("ERROR: The environment variable AWS_SECRET_ACCESS_KEY " +
+ "must be set")
+ sys.exit(1)
return (opts, action, cluster_name)
@@ -217,17 +224,25 @@ def launch_cluster(conn, opts, cluster_name):
# Launch spot instances with the requested price
print ("Requesting %d slaves as spot instances with price $%.3f" %
(opts.slaves, opts.spot_price))
- slave_reqs = conn.request_spot_instances(
- price = opts.spot_price,
- image_id = opts.ami,
- launch_group = "launch-group-%s" % cluster_name,
- placement = opts.zone,
- count = opts.slaves,
- key_name = opts.key_pair,
- security_groups = [slave_group],
- instance_type = opts.instance_type,
- block_device_map = block_map)
- my_req_ids = [req.id for req in slave_reqs]
+ zones = get_zones(conn, opts)
+ num_zones = len(zones)
+ i = 0
+ my_req_ids = []
+ for zone in zones:
+ num_slaves_this_zone = get_partition(opts.slaves, num_zones, i)
+ slave_reqs = conn.request_spot_instances(
+ price = opts.spot_price,
+ image_id = opts.ami,
+ launch_group = "launch-group-%s" % cluster_name,
+ placement = zone,
+ count = num_slaves_this_zone,
+ key_name = opts.key_pair,
+ security_groups = [slave_group],
+ instance_type = opts.instance_type,
+ block_device_map = block_map)
+ my_req_ids += [req.id for req in slave_reqs]
+ i += 1
+
print "Waiting for spot instances to be granted..."
try:
while True:
@@ -262,20 +277,30 @@ def launch_cluster(conn, opts, cluster_name):
sys.exit(0)
else:
# Launch non-spot instances
- slave_res = image.run(key_name = opts.key_pair,
- security_groups = [slave_group],
- instance_type = opts.instance_type,
- placement = opts.zone,
- min_count = opts.slaves,
- max_count = opts.slaves,
- block_device_map = block_map)
- slave_nodes = slave_res.instances
- print "Launched slaves, regid = " + slave_res.id
+ zones = get_zones(conn, opts)
+ num_zones = len(zones)
+ i = 0
+ slave_nodes = []
+ for zone in zones:
+ num_slaves_this_zone = get_partition(opts.slaves, num_zones, i)
+ slave_res = image.run(key_name = opts.key_pair,
+ security_groups = [slave_group],
+ instance_type = opts.instance_type,
+ placement = zone,
+ min_count = num_slaves_this_zone,
+ max_count = num_slaves_this_zone,
+ block_device_map = block_map)
+ slave_nodes += slave_res.instances
+ print "Launched %d slaves in %s, regid = %s" % (num_slaves_this_zone,
+ zone, slave_res.id)
+ i += 1
# Launch masters
master_type = opts.master_instance_type
if master_type == "":
master_type = opts.instance_type
+ if opts.zone == 'all':
+ opts.zone = random.choice(conn.get_all_zones()).name
master_res = image.run(key_name = opts.key_pair,
security_groups = [master_group],
instance_type = master_type,
@@ -284,7 +309,7 @@ def launch_cluster(conn, opts, cluster_name):
max_count = 1,
block_device_map = block_map)
master_nodes = master_res.instances
- print "Launched master, regid = " + master_res.id
+ print "Launched master in %s, regid = %s" % (zone, master_res.id)
zoo_nodes = []
@@ -474,9 +499,30 @@ def ssh(host, opts, command):
(opts.identity_file, opts.user, host, command), shell=True)
+# Gets a list of zones to launch instances in
+def get_zones(conn, opts):
+ if opts.zone == 'all':
+ zones = [z.name for z in conn.get_all_zones()]
+ else:
+ zones = [opts.zone]
+ return zones
+
+
+# Gets the number of slaves to allocate to the partition (zone) at the given index
+def get_partition(total, num_partitions, current_partition):
+ num_slaves_this_zone = total / num_partitions
+ if (total % num_partitions) - current_partition > 0:
+ num_slaves_this_zone += 1
+ return num_slaves_this_zone
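+# Example: with 10 slaves over 3 zones, get_partition(10, 3, i) yields
+# 4, 3, 3 for i = 0, 1, 2 -- each zone gets the base share total / num_partitions,
+# and the first (total % num_partitions) zones absorb one extra slave each.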
+
+
def main():
(opts, action, cluster_name) = parse_args()
- conn = boto.ec2.connect_to_region(opts.region)
+ try:
+ conn = ec2.connect_to_region(opts.region)
+ except Exception as e:
+ print >> stderr, (e)
+ sys.exit(1)
# Select an AZ at random if it was not specified.
if opts.zone == "":
@@ -509,6 +555,20 @@ def main():
print "Terminating zoo..."
for inst in zoo_nodes:
inst.terminate()
+ # Delete security groups as well
+ group_names = [cluster_name + "-master", cluster_name + "-slaves", cluster_name + "-zoo"]
+ groups = conn.get_all_security_groups()
+ for group in groups:
+ if group.name in group_names:
+ print "Deleting security group " + group.name
+ # Delete individual rules before deleting group to remove dependencies
+ for rule in group.rules:
+ for grant in rule.grants:
+ group.revoke(ip_protocol=rule.ip_protocol,
+ from_port=rule.from_port,
+ to_port=rule.to_port,
+ src_group=grant)
+ conn.delete_security_group(group.name)
elif action == "login":
(master_nodes, slave_nodes, zoo_nodes) = get_existing_cluster(