author     Matei Zaharia <matei@eecs.berkeley.edu>    2012-09-07 17:08:36 -0700
committer  Denny <dennybritz@gmail.com>               2012-09-10 12:48:58 -0700
commit     a13780670d8810a9fb52d8cc4e42d3c5155a8d1d (patch)
tree       9589113ae4afbedce2c2d99d2886306c79d62312
parent     f2ac55840c8b56d5ab677b6b5d37458c7ddc83a9 (diff)
Added a unit test for local-cluster mode and simplified some of the code involved in that
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala                                    8
-rw-r--r--  core/src/main/scala/spark/broadcast/HttpBroadcast.scala                         9
-rw-r--r--  core/src/main/scala/spark/deploy/LocalSparkCluster.scala                       43
-rw-r--r--  core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala                   29
-rw-r--r--  core/src/main/scala/spark/deploy/worker/Worker.scala                            4
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala   6
-rw-r--r--  core/src/test/scala/spark/DistributedSuite.scala                               68
-rwxr-xr-x  run                                                                             1
8 files changed, 120 insertions, 48 deletions
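
For context, the new test suite added below shows how the "local-cluster[N,cores,memory]" master URL is meant to be used: it starts N workers inside the driver JVM, each with the given number of cores and amount of memory (512 in the test, presumably MB). A minimal usage sketch along those lines; the job itself is illustrative and not taken from the commit:

    import spark.SparkContext

    // "local-cluster[2,1,512]" = 2 workers, 1 core each, 512 memory each (same URL as the test below)
    val sc = new SparkContext("local-cluster[2,1,512]", "example")
    val sum = sc.parallelize(1 to 10, 10).reduce(_ + _)   // runs on the in-process cluster
    println(sum)                                          // 55
    sc.stop()                                             // also tears down the local cluster via shutdownCallback
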
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index d7bd832e52..5d0f2950d6 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -67,7 +67,7 @@ class SparkContext(
System.setProperty("spark.master.port", "0")
}
- private val isLocal = (master == "local" || master.startsWith("local["))
+ private val isLocal = (master == "local" || master.startsWith("local\\["))
// Create the Spark execution environment (cache, map output tracker, etc)
val env = SparkEnv.createFromSystemProperties(
@@ -84,7 +84,7 @@ class SparkContext(
// Regular expression for local[N, maxRetries], used in tests with failing tasks
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+),([0-9]+)\]""".r
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
- val SPARK_LOCALCLUSTER_REGEX = """local-cluster\[([0-9]+)\,([0-9]+),([0-9]+)]""".r
+ val LOCAL_CLUSTER_REGEX = """local-cluster\[([0-9]+),([0-9]+),([0-9]+)]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """(spark://.*)""".r
@@ -104,13 +104,13 @@ class SparkContext(
scheduler.initialize(backend)
scheduler
- case SPARK_LOCALCLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerlave) =>
+ case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerlave) =>
val scheduler = new ClusterScheduler(this)
val localCluster = new LocalSparkCluster(numSlaves.toInt, coresPerSlave.toInt, memoryPerlave.toInt)
val sparkUrl = localCluster.start()
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, frameworkName)
scheduler.initialize(backend)
- backend.shutdownHook = (backend: SparkDeploySchedulerBackend) => {
+ backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
}
scheduler
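
As a side note, the renamed LOCAL_CLUSTER_REGEX is an ordinary Scala Regex used as a pattern-match extractor. A small standalone sketch, reusing the pattern from the hunk above with an assumed example URL:

    val LOCAL_CLUSTER_REGEX = """local-cluster\[([0-9]+),([0-9]+),([0-9]+)]""".r

    "local-cluster[2,1,512]" match {
      case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
        println((numSlaves.toInt, coresPerSlave.toInt, memoryPerSlave.toInt))  // prints (2,1,512)
      case _ =>
        println("not a local-cluster master URL")
    }
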
diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
index 03986ea756..eacf237508 100644
--- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
@@ -76,9 +76,12 @@ private object HttpBroadcast extends Logging {
}
def stop() {
- if (server != null) {
- server.stop()
- server = null
+ synchronized {
+ if (server != null) {
+ server.stop()
+ server = null
+ initialized = false
+ }
}
}
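
The point of guarding stop() and resetting initialized is that one JVM can now create several SparkContexts in sequence (as the new test suite does), so the broadcast server has to be safely re-initializable. A minimal sketch of that guarded start/stop pattern, with names assumed rather than copied from HttpBroadcast:

    object ReusableService {
      private var server: Thread = null     // stand-in for the real HttpServer
      private var initialized = false

      def start(): Unit = synchronized {
        if (!initialized) {
          server = new Thread() { override def run() { /* serve files */ } }
          server.start()
          initialized = true
        }
      }

      def stop(): Unit = synchronized {
        if (server != null) {
          server.interrupt()                // the real code calls server.stop()
          server = null
          initialized = false               // so a later start() can initialize again
        }
      }
    }
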
diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
index da74df4dcf..1591bfdeb6 100644
--- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
@@ -9,10 +9,8 @@ import spark.{Logging, Utils}
import scala.collection.mutable.ArrayBuffer
-class LocalSparkCluster(numSlaves : Int, coresPerSlave : Int,
- memoryPerSlave : Int) extends Logging {
+class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging {
- val threadPool = Utils.newDaemonFixedThreadPool(numSlaves + 1)
val localIpAddress = Utils.localIpAddress
var masterActor : ActorRef = _
@@ -24,35 +22,25 @@ class LocalSparkCluster(numSlaves : Int, coresPerSlave : Int,
val slaveActors = ArrayBuffer[ActorRef]()
def start() : String = {
-
logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.")
/* Start the Master */
- val (masterActorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
+ val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
+ masterActorSystem = actorSystem
masterUrl = "spark://" + localIpAddress + ":" + masterPort
- threadPool.execute(new Runnable {
- def run() {
- val actor = masterActorSystem.actorOf(
- Props(new Master(localIpAddress, masterPort, 8080)), name = "Master")
- masterActor = actor
- masterActorSystem.awaitTermination()
- }
- })
+ val actor = masterActorSystem.actorOf(
+ Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
+ masterActor = actor
/* Start the Slaves */
- (1 to numSlaves).foreach { slaveNum =>
+ for (slaveNum <- 1 to numSlaves) {
val (actorSystem, boundPort) =
AkkaUtils.createActorSystem("sparkWorker" + slaveNum, localIpAddress, 0)
slaveActorSystems += actorSystem
- threadPool.execute(new Runnable {
- def run() {
- val actor = actorSystem.actorOf(
- Props(new Worker(localIpAddress, boundPort, 8080 + slaveNum, coresPerSlave, memoryPerSlave, masterUrl)),
- name = "Worker")
- slaveActors += actor
- actorSystem.awaitTermination()
- }
- })
+ val actor = actorSystem.actorOf(
+ Props(new Worker(localIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)),
+ name = "Worker")
+ slaveActors += actor
}
return masterUrl
@@ -60,9 +48,10 @@ class LocalSparkCluster(numSlaves : Int, coresPerSlave : Int,
def stop() {
logInfo("Shutting down local Spark cluster.")
- masterActorSystem.shutdown()
+ // Stop the slaves before the master so they don't get upset that it disconnected
slaveActorSystems.foreach(_.shutdown())
+ slaveActorSystems.foreach(_.awaitTermination())
+ masterActorSystem.shutdown()
+ masterActorSystem.awaitTermination()
}
-
-
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index 393f4a3ee6..1740a42a7e 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -29,6 +29,7 @@ class ExecutorRunner(
val fullId = jobId + "/" + execId
var workerThread: Thread = null
var process: Process = null
+ var shutdownHook: Thread = null
def start() {
workerThread = new Thread("ExecutorRunner for " + fullId) {
@@ -37,17 +38,16 @@ class ExecutorRunner(
workerThread.start()
// Shutdown hook that kills actors on shutdown.
- Runtime.getRuntime.addShutdownHook(
- new Thread() {
- override def run() {
- if(process != null) {
- logInfo("Shutdown Hook killing process.")
- process.destroy()
- process.waitFor()
- }
+ shutdownHook = new Thread() {
+ override def run() {
+ if (process != null) {
+ logInfo("Shutdown hook killing child process.")
+ process.destroy()
+ process.waitFor()
}
- })
-
+ }
+ }
+ Runtime.getRuntime.addShutdownHook(shutdownHook)
}
/** Stop this executor runner, including killing the process it launched */
@@ -58,8 +58,10 @@ class ExecutorRunner(
if (process != null) {
logInfo("Killing process!")
process.destroy()
+ process.waitFor()
}
worker ! ExecutorStateChanged(jobId, execId, ExecutorState.KILLED, None)
+ Runtime.getRuntime.removeShutdownHook(shutdownHook)
}
}
@@ -114,7 +116,12 @@ class ExecutorRunner(
val out = new FileOutputStream(file)
new Thread("redirect output to " + file) {
override def run() {
- Utils.copyStream(in, out, true)
+ try {
+ Utils.copyStream(in, out, true)
+ } catch {
+ case e: IOException =>
+ logInfo("Redirection to " + file + " closed: " + e.getMessage)
+ }
}
}.start()
}
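
The change above follows the standard JDK pattern: keep a reference to the shutdown-hook thread so it can be deregistered once the child process has been stopped deliberately, rather than leaking one hook per executor. A stripped-down sketch (the process handling is hypothetical, the Runtime calls are the real API):

    var process: Process = null

    val hook = new Thread() {
      override def run() {
        if (process != null) {              // last-resort cleanup if the JVM exits first
          process.destroy()
          process.waitFor()
        }
      }
    }
    Runtime.getRuntime.addShutdownHook(hook)

    // ... later, after killing the process in an orderly way:
    Runtime.getRuntime.removeShutdownHook(hook)
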
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 0a80463c0b..175464d40d 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -153,6 +153,10 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas
def generateWorkerId(): String = {
"worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port)
}
+
+ override def postStop() {
+ executors.values.foreach(_.kill())
+ }
}
object Worker {
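
postStop() is Akka's actor-lifecycle hook, run after an actor stops (including when its ActorSystem is shut down), which is what lets LocalSparkCluster.stop() above reap any executors a worker launched. A tiny illustrative actor, not taken from the commit:

    import akka.actor.Actor

    class ChildReaper extends Actor {
      def receive = {
        case msg =>  // normal work here
      }
      override def postStop() {
        // clean up external resources, as Worker does with executors.values.foreach(_.kill())
      }
    }
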
diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index ec3ff38d5c..9093a329a3 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -16,7 +16,7 @@ class SparkDeploySchedulerBackend(
var client: Client = null
var stopping = false
- var shutdownHook : (SparkDeploySchedulerBackend) => Unit = _
+ var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _
val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
@@ -62,8 +62,8 @@ class SparkDeploySchedulerBackend(
stopping = true;
super.stop()
client.stop()
- if (shutdownHook != null) {
- shutdownHook(this)
+ if (shutdownCallback != null) {
+ shutdownCallback(this)
}
}
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
new file mode 100644
index 0000000000..b7b8a79327
--- /dev/null
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -0,0 +1,68 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import com.google.common.io.Files
+
+import scala.collection.mutable.ArrayBuffer
+
+import SparkContext._
+
+class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
+
+ val clusterUrl = "local-cluster[2,1,512]"
+
+ var sc: SparkContext = _
+
+ after {
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ }
+
+ test("simple groupByKey") {
+ sc = new SparkContext(clusterUrl, "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 5)
+ val groups = pairs.groupByKey(5).collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("accumulators") {
+ sc = new SparkContext(clusterUrl, "test")
+ val accum = sc.accumulator(0)
+ sc.parallelize(1 to 10, 10).foreach(x => accum += x)
+ assert(accum.value === 55)
+ }
+
+ test("broadcast variables") {
+ sc = new SparkContext(clusterUrl, "test")
+ val array = new Array[Int](100)
+ val bv = sc.broadcast(array)
+ array(2) = 3 // Change the array -- this should not be seen on workers
+ val rdd = sc.parallelize(1 to 10, 10)
+ val sum = rdd.map(x => bv.value.sum).reduce(_ + _)
+ assert(sum === 0)
+ }
+
+ test("repeatedly failing task") {
+ sc = new SparkContext(clusterUrl, "test")
+ val accum = sc.accumulator(0)
+ val thrown = intercept[SparkException] {
+ sc.parallelize(1 to 10, 10).foreach(x => println(x / 0))
+ }
+ assert(thrown.getClass === classOf[SparkException])
+ assert(thrown.getMessage.contains("more than 4 times"))
+ }
+}
+
diff --git a/run b/run
index 8f7256b4e5..2946a04d3f 100755
--- a/run
+++ b/run
@@ -52,6 +52,7 @@ CLASSPATH="$SPARK_CLASSPATH"
CLASSPATH+=":$MESOS_CLASSPATH"
CLASSPATH+=":$FWDIR/conf"
CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/classes"
+CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/test-classes"
CLASSPATH+=":$CORE_DIR/src/main/resources"
CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes"
CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"