aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bagel/pom.xml11
-rw-r--r--bagel/src/test/scala/bagel/BagelSuite.scala2
-rwxr-xr-xbin/start-master.sh3
-rw-r--r--core/pom.xml16
-rw-r--r--core/src/main/scala/spark/Accumulators.scala3
-rw-r--r--core/src/main/scala/spark/CacheManager.scala65
-rw-r--r--core/src/main/scala/spark/CacheTracker.scala240
-rw-r--r--core/src/main/scala/spark/DaemonThreadFactory.scala18
-rw-r--r--core/src/main/scala/spark/Dependency.scala10
-rw-r--r--core/src/main/scala/spark/HttpFileServer.scala8
-rw-r--r--core/src/main/scala/spark/HttpServer.scala9
-rw-r--r--core/src/main/scala/spark/KryoSerializer.scala5
-rw-r--r--core/src/main/scala/spark/Logging.scala3
-rw-r--r--core/src/main/scala/spark/MapOutputTracker.scala16
-rw-r--r--core/src/main/scala/spark/PairRDDFunctions.scala24
-rw-r--r--core/src/main/scala/spark/ParallelCollection.scala15
-rw-r--r--core/src/main/scala/spark/RDD.scala192
-rw-r--r--core/src/main/scala/spark/RDDCheckpointData.scala19
-rw-r--r--core/src/main/scala/spark/SparkContext.scala188
-rw-r--r--core/src/main/scala/spark/SparkEnv.scala57
-rw-r--r--core/src/main/scala/spark/SparkFiles.java25
-rw-r--r--core/src/main/scala/spark/TaskContext.scala3
-rw-r--r--core/src/main/scala/spark/Utils.scala77
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDDLike.scala29
-rw-r--r--core/src/main/scala/spark/api/java/JavaSparkContext.scala29
-rw-r--r--core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java20
-rw-r--r--core/src/main/scala/spark/api/java/StorageLevels.java11
-rw-r--r--core/src/main/scala/spark/api/python/PythonPartitioner.scala13
-rw-r--r--core/src/main/scala/spark/api/python/PythonRDD.scala60
-rw-r--r--core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala24
-rw-r--r--core/src/main/scala/spark/broadcast/Broadcast.scala6
-rw-r--r--core/src/main/scala/spark/broadcast/BroadcastFactory.scala4
-rw-r--r--core/src/main/scala/spark/broadcast/HttpBroadcast.scala6
-rw-r--r--core/src/main/scala/spark/broadcast/MultiTracker.scala35
-rw-r--r--core/src/main/scala/spark/broadcast/TreeBroadcast.scala52
-rw-r--r--core/src/main/scala/spark/deploy/DeployMessage.scala4
-rw-r--r--core/src/main/scala/spark/deploy/JobDescription.scala3
-rw-r--r--core/src/main/scala/spark/deploy/LocalSparkCluster.scala63
-rw-r--r--core/src/main/scala/spark/deploy/client/Client.scala13
-rw-r--r--core/src/main/scala/spark/deploy/client/ClientListener.scala4
-rw-r--r--core/src/main/scala/spark/deploy/client/TestClient.scala2
-rw-r--r--core/src/main/scala/spark/deploy/master/JobInfo.scala2
-rw-r--r--core/src/main/scala/spark/deploy/master/Master.scala58
-rw-r--r--core/src/main/scala/spark/deploy/master/MasterWebUI.scala28
-rw-r--r--core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala12
-rw-r--r--core/src/main/scala/spark/deploy/worker/Worker.scala81
-rw-r--r--core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala6
-rw-r--r--core/src/main/scala/spark/executor/Executor.scala38
-rw-r--r--core/src/main/scala/spark/executor/MesosExecutorBackend.scala7
-rw-r--r--core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala53
-rw-r--r--core/src/main/scala/spark/network/Connection.scala15
-rw-r--r--core/src/main/scala/spark/network/ConnectionManager.scala8
-rw-r--r--core/src/main/scala/spark/partial/ApproximateActionListener.scala4
-rw-r--r--core/src/main/scala/spark/rdd/BlockRDD.scala6
-rw-r--r--core/src/main/scala/spark/rdd/CartesianRDD.scala13
-rw-r--r--core/src/main/scala/spark/rdd/CheckpointRDD.scala61
-rw-r--r--core/src/main/scala/spark/rdd/CoGroupedRDD.scala16
-rw-r--r--core/src/main/scala/spark/rdd/CoalescedRDD.scala14
-rw-r--r--core/src/main/scala/spark/rdd/MappedRDD.scala6
-rw-r--r--core/src/main/scala/spark/rdd/NewHadoopRDD.scala6
-rw-r--r--core/src/main/scala/spark/rdd/PartitionPruningRDD.scala42
-rw-r--r--core/src/main/scala/spark/rdd/SampledRDD.scala5
-rw-r--r--core/src/main/scala/spark/rdd/ShuffledRDD.scala9
-rw-r--r--core/src/main/scala/spark/rdd/UnionRDD.scala15
-rw-r--r--core/src/main/scala/spark/rdd/ZippedRDD.scala8
-rw-r--r--core/src/main/scala/spark/scheduler/DAGScheduler.scala390
-rw-r--r--core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/JobResult.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/JobWaiter.scala14
-rw-r--r--core/src/main/scala/spark/scheduler/MapStatus.scala6
-rw-r--r--core/src/main/scala/spark/scheduler/ShuffleMapTask.scala49
-rw-r--r--core/src/main/scala/spark/scheduler/Stage.scala14
-rw-r--r--core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala105
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala12
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala4
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala44
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala22
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala102
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala7
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala57
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala4
-rw-r--r--core/src/main/scala/spark/scheduler/local/LocalScheduler.scala12
-rw-r--r--core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala32
-rw-r--r--core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala56
-rw-r--r--core/src/main/scala/spark/storage/BlockManager.scala89
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerId.scala51
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMaster.scala91
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMasterActor.scala77
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMessages.scala11
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerUI.scala76
-rw-r--r--core/src/main/scala/spark/storage/BlockMessage.scala2
-rw-r--r--core/src/main/scala/spark/storage/StorageLevel.scala64
-rw-r--r--core/src/main/scala/spark/storage/StorageUtils.scala82
-rw-r--r--core/src/main/scala/spark/storage/ThreadingTest.scala9
-rw-r--r--core/src/main/scala/spark/util/AkkaUtils.scala19
-rw-r--r--core/src/main/scala/spark/util/MetadataCleaner.scala32
-rw-r--r--core/src/main/scala/spark/util/TimeStampedHashMap.scala4
-rw-r--r--core/src/main/twirl/spark/common/layout.scala.html (renamed from core/src/main/twirl/spark/deploy/common/layout.scala.html)0
-rw-r--r--core/src/main/twirl/spark/deploy/master/index.scala.html2
-rw-r--r--core/src/main/twirl/spark/deploy/master/job_details.scala.html2
-rw-r--r--core/src/main/twirl/spark/deploy/worker/index.scala.html3
-rw-r--r--core/src/main/twirl/spark/storage/index.scala.html40
-rw-r--r--core/src/main/twirl/spark/storage/rdd.scala.html81
-rw-r--r--core/src/main/twirl/spark/storage/rdd_table.scala.html32
-rw-r--r--core/src/main/twirl/spark/storage/worker_table.scala.html24
-rw-r--r--core/src/test/scala/spark/AccumulatorSuite.scala38
-rw-r--r--core/src/test/scala/spark/BroadcastSuite.scala14
-rw-r--r--core/src/test/scala/spark/CacheTrackerSuite.scala131
-rw-r--r--core/src/test/scala/spark/CheckpointSuite.scala40
-rw-r--r--core/src/test/scala/spark/ClosureCleanerSuite.scala73
-rw-r--r--core/src/test/scala/spark/DistributedSuite.scala92
-rw-r--r--core/src/test/scala/spark/DriverSuite.scala33
-rw-r--r--core/src/test/scala/spark/FailureSuite.scala14
-rw-r--r--core/src/test/scala/spark/FileServerSuite.scala29
-rw-r--r--core/src/test/scala/spark/FileSuite.scala16
-rw-r--r--core/src/test/scala/spark/JavaAPISuite.java30
-rw-r--r--core/src/test/scala/spark/LocalSparkContext.scala41
-rw-r--r--core/src/test/scala/spark/MapOutputTrackerSuite.scala75
-rw-r--r--core/src/test/scala/spark/PartitioningSuite.scala15
-rw-r--r--core/src/test/scala/spark/PipedRDDSuite.scala16
-rw-r--r--core/src/test/scala/spark/RDDSuite.scala49
-rw-r--r--core/src/test/scala/spark/ShuffleSuite.scala14
-rw-r--r--core/src/test/scala/spark/SortingSuite.scala13
-rw-r--r--core/src/test/scala/spark/ThreadingSuite.scala14
-rw-r--r--core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala663
-rw-r--r--core/src/test/scala/spark/scheduler/TaskContextSuite.scala14
-rw-r--r--core/src/test/scala/spark/storage/BlockManagerSuite.scala132
-rw-r--r--docs/configuration.md26
-rw-r--r--docs/java-programming-guide.md3
-rw-r--r--docs/python-programming-guide.md11
-rw-r--r--docs/scala-programming-guide.md3
-rw-r--r--docs/spark-standalone.md43
-rw-r--r--examples/pom.xml28
-rw-r--r--pom.xml40
-rw-r--r--project/SparkBuild.scala15
-rw-r--r--python/epydoc.conf2
-rw-r--r--python/pyspark/__init__.py5
-rw-r--r--python/pyspark/accumulators.py48
-rw-r--r--python/pyspark/broadcast.py9
-rw-r--r--python/pyspark/context.py139
-rw-r--r--python/pyspark/files.py38
-rw-r--r--python/pyspark/rdd.py77
-rw-r--r--python/pyspark/shell.py1
-rw-r--r--python/pyspark/tests.py121
-rw-r--r--python/pyspark/worker.py20
-rwxr-xr-xpython/run-tests6
-rwxr-xr-xpython/test_support/hello.txt1
-rwxr-xr-xpython/test_support/userlibrary.py7
-rw-r--r--repl-bin/pom.xml11
-rw-r--r--repl/pom.xml35
-rw-r--r--repl/src/test/scala/spark/repl/ReplSuite.scala2
-rwxr-xr-xrun8
-rwxr-xr-xsbt/sbt2
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar (renamed from streaming/lib/kafka-0.7.2.jar)bin1358063 -> 1358063 bytes
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md51
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha11
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom9
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md51
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha11
-rw-r--r--streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml12
-rw-r--r--streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md51
-rw-r--r--streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha11
-rw-r--r--streaming/pom.xml144
-rw-r--r--streaming/src/main/scala/spark/streaming/DStream.scala8
-rw-r--r--streaming/src/main/scala/spark/streaming/StreamingContext.scala2
-rw-r--r--streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala8
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala4
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala5
-rw-r--r--streaming/src/test/java/spark/streaming/JavaAPISuite.java (renamed from streaming/src/test/java/JavaAPISuite.java)2
-rw-r--r--streaming/src/test/java/spark/streaming/JavaTestUtils.scala (renamed from streaming/src/test/java/JavaTestUtils.scala)0
-rw-r--r--streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala2
-rw-r--r--streaming/src/test/scala/spark/streaming/CheckpointSuite.scala2
-rw-r--r--streaming/src/test/scala/spark/streaming/FailureSuite.scala2
-rw-r--r--streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala2
-rw-r--r--streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala2
177 files changed, 3876 insertions, 2220 deletions
diff --git a/bagel/pom.xml b/bagel/pom.xml
index 5f58347204..a8256a6e8b 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -45,11 +45,6 @@
<profiles>
<profile>
<id>hadoop1</id>
- <activation>
- <property>
- <name>!hadoopVersion</name>
- </property>
- </activation>
<dependencies>
<dependency>
<groupId>org.spark-project</groupId>
@@ -77,12 +72,6 @@
</profile>
<profile>
<id>hadoop2</id>
- <activation>
- <property>
- <name>hadoopVersion</name>
- <value>2</value>
- </property>
- </activation>
<dependencies>
<dependency>
<groupId>org.spark-project</groupId>
diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala
index ca59f46843..3c2f9c4616 100644
--- a/bagel/src/test/scala/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/bagel/BagelSuite.scala
@@ -23,7 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter {
sc = null
}
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ System.clearProperty("spark.driver.port")
}
test("halting by voting") {
diff --git a/bin/start-master.sh b/bin/start-master.sh
index a901b1c260..87feb261fe 100755
--- a/bin/start-master.sh
+++ b/bin/start-master.sh
@@ -26,7 +26,8 @@ fi
# Set SPARK_PUBLIC_DNS so the master report the correct webUI address to the slaves
if [ "$SPARK_PUBLIC_DNS" = "" ]; then
# If we appear to be running on EC2, use the public address by default:
- if [[ `hostname` == *ec2.internal ]]; then
+ # NOTE: ec2-metadata is installed on Amazon Linux AMI. Check based on that and hostname
+ if command -v ec2-metadata > /dev/null || [[ `hostname` == *ec2.internal ]]; then
export SPARK_PUBLIC_DNS=`wget -q -O - http://instance-data.ec2.internal/latest/meta-data/public-hostname`
fi
fi
diff --git a/core/pom.xml b/core/pom.xml
index 862d3ec37a..66c62151fe 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -99,6 +99,11 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.easymock</groupId>
+ <artifactId>easymock</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>com.novocode</groupId>
<artifactId>junit-interface</artifactId>
<scope>test</scope>
@@ -163,11 +168,6 @@
<profiles>
<profile>
<id>hadoop1</id>
- <activation>
- <property>
- <name>!hadoopVersion</name>
- </property>
- </activation>
<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
@@ -220,12 +220,6 @@
</profile>
<profile>
<id>hadoop2</id>
- <activation>
- <property>
- <name>hadoopVersion</name>
- <value>2</value>
- </property>
- </activation>
<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index b644aba5f8..57c6df35be 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -25,8 +25,7 @@ class Accumulable[R, T] (
extends Serializable {
val id = Accumulators.newId
- @transient
- private var value_ = initialValue // Current value on master
+ @transient private var value_ = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
var deserialized = false
diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala
new file mode 100644
index 0000000000..711435c333
--- /dev/null
+++ b/core/src/main/scala/spark/CacheManager.scala
@@ -0,0 +1,65 @@
+package spark
+
+import scala.collection.mutable.{ArrayBuffer, HashSet}
+import spark.storage.{BlockManager, StorageLevel}
+
+
+/** Spark class responsible for passing RDDs split contents to the BlockManager and making
+ sure a node doesn't load two copies of an RDD at once.
+ */
+private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
+ private val loading = new HashSet[String]
+
+ /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
+ def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
+ : Iterator[T] = {
+ val key = "rdd_%d_%d".format(rdd.id, split.index)
+ logInfo("Cache key is " + key)
+ blockManager.get(key) match {
+ case Some(cachedValues) =>
+ // Split is in cache, so just return its values
+ logInfo("Found partition in cache!")
+ return cachedValues.asInstanceOf[Iterator[T]]
+
+ case None =>
+ // Mark the split as loading (unless someone else marks it first)
+ loading.synchronized {
+ if (loading.contains(key)) {
+ logInfo("Loading contains " + key + ", waiting...")
+ while (loading.contains(key)) {
+ try {loading.wait()} catch {case _ =>}
+ }
+ logInfo("Loading no longer contains " + key + ", so returning cached result")
+ // See whether someone else has successfully loaded it. The main way this would fail
+ // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
+ // partition but we didn't want to make space for it. However, that case is unlikely
+ // because it's unlikely that two threads would work on the same RDD partition. One
+ // downside of the current code is that threads wait serially if this does happen.
+ blockManager.get(key) match {
+ case Some(values) =>
+ return values.asInstanceOf[Iterator[T]]
+ case None =>
+ logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
+ loading.add(key)
+ }
+ } else {
+ loading.add(key)
+ }
+ }
+ try {
+ // If we got here, we have to load the split
+ val elements = new ArrayBuffer[Any]
+ logInfo("Computing partition " + split)
+ elements ++= rdd.computeOrReadCheckpoint(split, context)
+ // Try to put this block in the blockManager
+ blockManager.put(key, elements, storageLevel, true)
+ return elements.iterator.asInstanceOf[Iterator[T]]
+ } finally {
+ loading.synchronized {
+ loading.remove(key)
+ loading.notifyAll()
+ }
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
deleted file mode 100644
index 86ad737583..0000000000
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ /dev/null
@@ -1,240 +0,0 @@
-package spark
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-
-import akka.actor._
-import akka.dispatch._
-import akka.pattern.ask
-import akka.remote._
-import akka.util.Duration
-import akka.util.Timeout
-import akka.util.duration._
-
-import spark.storage.BlockManager
-import spark.storage.StorageLevel
-import util.{TimeStampedHashSet, MetadataCleaner, TimeStampedHashMap}
-
-private[spark] sealed trait CacheTrackerMessage
-
-private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
- extends CacheTrackerMessage
-private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
- extends CacheTrackerMessage
-private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage
-private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
-private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
-private[spark] case object GetCacheStatus extends CacheTrackerMessage
-private[spark] case object GetCacheLocations extends CacheTrackerMessage
-private[spark] case object StopCacheTracker extends CacheTrackerMessage
-
-private[spark] class CacheTrackerActor extends Actor with Logging {
- // TODO: Should probably store (String, CacheType) tuples
- private val locs = new TimeStampedHashMap[Int, Array[List[String]]]
-
- /**
- * A map from the slave's host name to its cache size.
- */
- private val slaveCapacity = new HashMap[String, Long]
- private val slaveUsage = new HashMap[String, Long]
-
- private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.clearOldValues)
-
- private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
- private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
- private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
-
- def receive = {
- case SlaveCacheStarted(host: String, size: Long) =>
- slaveCapacity.put(host, size)
- slaveUsage.put(host, 0)
- sender ! true
-
- case RegisterRDD(rddId: Int, numPartitions: Int) =>
- logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
- locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
- sender ! true
-
- case AddedToCache(rddId, partition, host, size) =>
- slaveUsage.put(host, getCacheUsage(host) + size)
- locs(rddId)(partition) = host :: locs(rddId)(partition)
- sender ! true
-
- case DroppedFromCache(rddId, partition, host, size) =>
- slaveUsage.put(host, getCacheUsage(host) - size)
- // Do a sanity check to make sure usage is greater than 0.
- locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
- sender ! true
-
- case MemoryCacheLost(host) =>
- logInfo("Memory cache lost on " + host)
- for ((id, locations) <- locs) {
- for (i <- 0 until locations.length) {
- locations(i) = locations(i).filterNot(_ == host)
- }
- }
- sender ! true
-
- case GetCacheLocations =>
- logInfo("Asked for current cache locations")
- sender ! locs.map{case (rrdId, array) => (rrdId -> array.clone())}
-
- case GetCacheStatus =>
- val status = slaveCapacity.map { case (host, capacity) =>
- (host, capacity, getCacheUsage(host))
- }.toSeq
- sender ! status
-
- case StopCacheTracker =>
- logInfo("Stopping CacheTrackerActor")
- sender ! true
- metadataCleaner.cancel()
- context.stop(self)
- }
-}
-
-private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
- extends Logging {
-
- // Tracker actor on the master, or remote reference to it on workers
- val ip: String = System.getProperty("spark.master.host", "localhost")
- val port: Int = System.getProperty("spark.master.port", "7077").toInt
- val actorName: String = "CacheTracker"
-
- val timeout = 10.seconds
-
- var trackerActor: ActorRef = if (isMaster) {
- val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName)
- logInfo("Registered CacheTrackerActor actor")
- actor
- } else {
- val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
- actorSystem.actorFor(url)
- }
-
- // TODO: Consider removing this HashSet completely as locs CacheTrackerActor already
- // keeps track of registered RDDs
- val registeredRddIds = new TimeStampedHashSet[Int]
-
- // Remembers which splits are currently being loaded (on worker nodes)
- val loading = new HashSet[String]
-
- val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.clearOldValues)
-
- // Send a message to the trackerActor and get its result within a default timeout, or
- // throw a SparkException if this fails.
- def askTracker(message: Any): Any = {
- try {
- val future = trackerActor.ask(message)(timeout)
- return Await.result(future, timeout)
- } catch {
- case e: Exception =>
- throw new SparkException("Error communicating with CacheTracker", e)
- }
- }
-
- // Send a one-way message to the trackerActor, to which we expect it to reply with true.
- def communicate(message: Any) {
- if (askTracker(message) != true) {
- throw new SparkException("Error reply received from CacheTracker")
- }
- }
-
- // Registers an RDD (on master only)
- def registerRDD(rddId: Int, numPartitions: Int) {
- registeredRddIds.synchronized {
- if (!registeredRddIds.contains(rddId)) {
- logInfo("Registering RDD ID " + rddId + " with cache")
- registeredRddIds += rddId
- communicate(RegisterRDD(rddId, numPartitions))
- }
- }
- }
-
- // For BlockManager.scala only
- def cacheLost(host: String) {
- communicate(MemoryCacheLost(host))
- logInfo("CacheTracker successfully removed entries on " + host)
- }
-
- // Get the usage status of slave caches. Each tuple in the returned sequence
- // is in the form of (host name, capacity, usage).
- def getCacheStatus(): Seq[(String, Long, Long)] = {
- askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]]
- }
-
- // For BlockManager.scala only
- def notifyFromBlockManager(t: AddedToCache) {
- communicate(t)
- }
-
- // Get a snapshot of the currently known locations
- def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
- askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
- }
-
- // Gets or computes an RDD split
- def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
- : Iterator[T] = {
- val key = "rdd_%d_%d".format(rdd.id, split.index)
- logInfo("Cache key is " + key)
- blockManager.get(key) match {
- case Some(cachedValues) =>
- // Split is in cache, so just return its values
- logInfo("Found partition in cache!")
- return cachedValues.asInstanceOf[Iterator[T]]
-
- case None =>
- // Mark the split as loading (unless someone else marks it first)
- loading.synchronized {
- if (loading.contains(key)) {
- logInfo("Loading contains " + key + ", waiting...")
- while (loading.contains(key)) {
- try {loading.wait()} catch {case _ =>}
- }
- logInfo("Loading no longer contains " + key + ", so returning cached result")
- // See whether someone else has successfully loaded it. The main way this would fail
- // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
- // partition but we didn't want to make space for it. However, that case is unlikely
- // because it's unlikely that two threads would work on the same RDD partition. One
- // downside of the current code is that threads wait serially if this does happen.
- blockManager.get(key) match {
- case Some(values) =>
- return values.asInstanceOf[Iterator[T]]
- case None =>
- logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
- loading.add(key)
- }
- } else {
- loading.add(key)
- }
- }
- try {
- // If we got here, we have to load the split
- val elements = new ArrayBuffer[Any]
- logInfo("Computing partition " + split)
- elements ++= rdd.compute(split, context)
- // Try to put this block in the blockManager
- blockManager.put(key, elements, storageLevel, true)
- return elements.iterator.asInstanceOf[Iterator[T]]
- } finally {
- loading.synchronized {
- loading.remove(key)
- loading.notifyAll()
- }
- }
- }
- }
-
- // Called by the Cache to report that an entry has been dropped from it
- def dropEntry(rddId: Int, partition: Int) {
- communicate(DroppedFromCache(rddId, partition, Utils.localHostName()))
- }
-
- def stop() {
- communicate(StopCacheTracker)
- registeredRddIds.clear()
- trackerActor = null
- }
-}
diff --git a/core/src/main/scala/spark/DaemonThreadFactory.scala b/core/src/main/scala/spark/DaemonThreadFactory.scala
deleted file mode 100644
index 56e59adeb7..0000000000
--- a/core/src/main/scala/spark/DaemonThreadFactory.scala
+++ /dev/null
@@ -1,18 +0,0 @@
-package spark
-
-import java.util.concurrent.ThreadFactory
-
-/**
- * A ThreadFactory that creates daemon threads
- */
-private object DaemonThreadFactory extends ThreadFactory {
- override def newThread(r: Runnable): Thread = new DaemonThread(r)
-}
-
-private class DaemonThread(r: Runnable = null) extends Thread {
- override def run() {
- if (r != null) {
- r.run()
- }
- }
-} \ No newline at end of file
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index b85d2732db..5eea907322 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -5,6 +5,7 @@ package spark
*/
abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
+
/**
* Base class for dependencies where each partition of the parent RDD is used by at most one
* partition of the child RDD. Narrow dependencies allow for pipelined execution.
@@ -12,12 +13,13 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
/**
* Get the parent partitions for a child partition.
- * @param outputPartition a partition of the child RDD
+ * @param partitionId a partition of the child RDD
* @return the partitions of the parent RDD that the child partition depends upon
*/
- def getParents(outputPartition: Int): Seq[Int]
+ def getParents(partitionId: Int): Seq[Int]
}
+
/**
* Represents a dependency on the output of a shuffle stage.
* @param shuffleId the shuffle id
@@ -32,6 +34,7 @@ class ShuffleDependency[K, V](
val shuffleId: Int = rdd.context.newShuffleId()
}
+
/**
* Represents a one-to-one dependency between partitions of the parent and child RDDs.
*/
@@ -39,6 +42,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
override def getParents(partitionId: Int) = List(partitionId)
}
+
/**
* Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs.
* @param rdd the parent RDD
@@ -48,7 +52,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
*/
class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
extends NarrowDependency[T](rdd) {
-
+
override def getParents(partitionId: Int) = {
if (partitionId >= outStart && partitionId < outStart + length) {
List(partitionId - outStart + inStart)
diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala
index 659d17718f..00901d95e2 100644
--- a/core/src/main/scala/spark/HttpFileServer.scala
+++ b/core/src/main/scala/spark/HttpFileServer.scala
@@ -1,9 +1,7 @@
package spark
-import java.io.{File, PrintWriter}
-import java.net.URL
-import scala.collection.mutable.HashMap
-import org.apache.hadoop.fs.FileUtil
+import java.io.{File}
+import com.google.common.io.Files
private[spark] class HttpFileServer extends Logging {
@@ -40,7 +38,7 @@ private[spark] class HttpFileServer extends Logging {
}
def addFileToDir(file: File, dir: File) : String = {
- Utils.copyFile(file, new File(dir, file.getName))
+ Files.copy(file, new File(dir, file.getName))
return dir + "/" + file.getName
}
diff --git a/core/src/main/scala/spark/HttpServer.scala b/core/src/main/scala/spark/HttpServer.scala
index 0196595ba1..4e0507c080 100644
--- a/core/src/main/scala/spark/HttpServer.scala
+++ b/core/src/main/scala/spark/HttpServer.scala
@@ -4,6 +4,7 @@ import java.io.File
import java.net.InetAddress
import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.server.bio.SocketConnector
import org.eclipse.jetty.server.handler.DefaultHandler
import org.eclipse.jetty.server.handler.HandlerList
import org.eclipse.jetty.server.handler.ResourceHandler
@@ -27,7 +28,13 @@ private[spark] class HttpServer(resourceBase: File) extends Logging {
if (server != null) {
throw new ServerStateException("Server is already started")
} else {
- server = new Server(0)
+ server = new Server()
+ val connector = new SocketConnector
+ connector.setMaxIdleTime(60*1000)
+ connector.setSoLingerTime(-1)
+ connector.setPort(0)
+ server.addConnector(connector)
+
val threadPool = new QueuedThreadPool
threadPool.setDaemon(true)
server.setThreadPool(threadPool)
diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 93d7327324..0bd73e936b 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -206,5 +206,8 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {
kryo
}
- def newInstance(): SerializerInstance = new KryoSerializerInstance(this)
+ def newInstance(): SerializerInstance = {
+ this.kryo.get().setClassLoader(Thread.currentThread().getContextClassLoader)
+ new KryoSerializerInstance(this)
+ }
}
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
index 90bae26202..7c1c1bb144 100644
--- a/core/src/main/scala/spark/Logging.scala
+++ b/core/src/main/scala/spark/Logging.scala
@@ -11,8 +11,7 @@ import org.slf4j.LoggerFactory
trait Logging {
// Make the log field transient so that objects with Logging can
// be serialized and used on another machine
- @transient
- private var log_ : Logger = null
+ @transient private var log_ : Logger = null
// Method to get or create the logger for this object
protected def log: Logger = {
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index ac02f3363a..4735207585 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -38,10 +38,7 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
}
}
-private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging {
- val ip: String = System.getProperty("spark.master.host", "localhost")
- val port: Int = System.getProperty("spark.master.port", "7077").toInt
- val actorName: String = "MapOutputTracker"
+private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolean) extends Logging {
val timeout = 10.seconds
@@ -56,11 +53,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
var cacheGeneration = generation
val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
- var trackerActor: ActorRef = if (isMaster) {
+ val actorName: String = "MapOutputTracker"
+ var trackerActor: ActorRef = if (isDriver) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
logInfo("Registered MapOutputTrackerActor actor")
actor
} else {
+ val ip = System.getProperty("spark.driver.host", "localhost")
+ val port = System.getProperty("spark.driver.port", "7077").toInt
val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
actorSystem.actorFor(url)
}
@@ -114,7 +114,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
var array = mapStatuses(shuffleId)
if (array != null) {
array.synchronized {
- if (array(mapId) != null && array(mapId).address == bmAddress) {
+ if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
}
}
@@ -170,7 +170,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}
}
- def cleanup(cleanupTime: Long) {
+ private def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
@@ -277,7 +277,7 @@ private[spark] object MapOutputTracker {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing an output location for shuffle " + shuffleId))
} else {
- (status.address, decompressSize(status.compressedSizes(reduceId)))
+ (status.location, decompressSize(status.compressedSizes(reduceId)))
}
}
}
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index abb01c387c..cc3cca2571 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -465,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
val res = self.context.runJob(self, process _, Array(index), false)
res(0)
case None =>
- self.filter(_._1 == key).map(_._2).collect
+ self.filter(_._1 == key).map(_._2).collect()
}
}
@@ -493,20 +493,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
path: String,
keyClass: Class[_],
valueClass: Class[_],
- outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) {
- saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration)
- }
-
- /**
- * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
- * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
- */
- def saveAsNewAPIHadoopFile(
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
outputFormatClass: Class[_ <: NewOutputFormat[_, _]],
- conf: Configuration) {
+ conf: Configuration = self.context.hadoopConfiguration) {
val job = new NewAPIHadoopJob(conf)
job.setOutputKeyClass(keyClass)
job.setOutputValueClass(valueClass)
@@ -557,7 +545,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
- conf: JobConf = new JobConf) {
+ conf: JobConf = new JobConf(self.context.hadoopConfiguration)) {
conf.setOutputKeyClass(keyClass)
conf.setOutputValueClass(valueClass)
// conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
@@ -602,7 +590,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
var count = 0
while(iter.hasNext) {
- val record = iter.next
+ val record = iter.next()
count += 1
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
}
@@ -661,9 +649,7 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
}
private[spark]
-class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U)
- extends RDD[(K, U)](prev) {
-
+class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev) {
override def getSplits = firstParent[(K, V)].splits
override val partitioner = firstParent[(K, V)].partitioner
override def compute(split: Split, context: TaskContext) =
diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala
index ede933c9e9..10adcd53ec 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -23,32 +23,28 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
}
private[spark] class ParallelCollection[T: ClassManifest](
- @transient sc : SparkContext,
+ @transient sc: SparkContext,
@transient data: Seq[T],
numSlices: Int,
- locationPrefs : Map[Int,Seq[String]])
+ locationPrefs: Map[Int,Seq[String]])
extends RDD[T](sc, Nil) {
// TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
// cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
// instead.
// UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.
- @transient
- var splits_ : Array[Split] = {
+ @transient var splits_ : Array[Split] = {
val slices = ParallelCollection.slice(data, numSlices).toArray
slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray
}
- override def getSplits = splits_.asInstanceOf[Array[Split]]
+ override def getSplits = splits_
override def compute(s: Split, context: TaskContext) =
s.asInstanceOf[ParallelCollectionSplit[T]].iterator
override def getPreferredLocations(s: Split): Seq[String] = {
- locationPrefs.get(s.index) match {
- case Some(s) => s
- case _ => Nil
- }
+ locationPrefs.getOrElse(s.index, Nil)
}
override def clearDependencies() {
@@ -56,7 +52,6 @@ private[spark] class ParallelCollection[T: ClassManifest](
}
}
-
private object ParallelCollection {
/**
* Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index e0d2eabb1d..9d6ea782bd 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -1,27 +1,17 @@
package spark
-import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream}
import java.net.URL
import java.util.{Date, Random}
import java.util.{HashMap => JHashMap}
-import java.util.concurrent.atomic.AtomicLong
import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
-import org.apache.hadoop.io.Writable
-import org.apache.hadoop.mapred.FileOutputCommitter
-import org.apache.hadoop.mapred.HadoopWriter
-import org.apache.hadoop.mapred.JobConf
-import org.apache.hadoop.mapred.OutputCommitter
-import org.apache.hadoop.mapred.OutputFormat
-import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
@@ -30,7 +20,6 @@ import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator
import spark.partial.PartialResult
-import spark.rdd.BlockRDD
import spark.rdd.CartesianRDD
import spark.rdd.FilteredRDD
import spark.rdd.FlatMappedRDD
@@ -73,11 +62,11 @@ import SparkContext._
* on RDD internals.
*/
abstract class RDD[T: ClassManifest](
- @transient var sc: SparkContext,
- var dependencies_ : List[Dependency[_]]
+ @transient private var sc: SparkContext,
+ @transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
-
+ /** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
@@ -85,14 +74,20 @@ abstract class RDD[T: ClassManifest](
// Methods that should be implemented by subclasses of RDD
// =======================================================================
- /** Function for computing a given partition. */
+ /** Implemented by subclasses to compute a given partition. */
def compute(split: Split, context: TaskContext): Iterator[T]
- /** Set of partitions in this RDD. */
- protected def getSplits(): Array[Split]
+ /**
+ * Implemented by subclasses to return the set of partitions in this RDD. This method will only
+ * be called once, so it is safe to implement a time-consuming computation in it.
+ */
+ protected def getSplits: Array[Split]
- /** How this RDD depends on any parent RDDs. */
- protected def getDependencies(): List[Dependency[_]] = dependencies_
+ /**
+ * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only
+ * be called once, so it is safe to implement a time-consuming computation in it.
+ */
+ protected def getDependencies: Seq[Dependency[_]] = deps
/** Optionally overridden by subclasses to specify placement preferences. */
protected def getPreferredLocations(split: Split): Seq[String] = Nil
@@ -100,7 +95,6 @@ abstract class RDD[T: ClassManifest](
/** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
-
// =======================================================================
// Methods and fields available on all RDDs
// =======================================================================
@@ -108,6 +102,15 @@ abstract class RDD[T: ClassManifest](
/** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
+ /** A friendly name for this RDD */
+ var name: String = null
+
+ /** Assign a name to this RDD */
+ def setName(_name: String) = {
+ name = _name
+ this
+ }
+
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
@@ -119,6 +122,8 @@ abstract class RDD[T: ClassManifest](
"Cannot change storage level of an RDD after it was already assigned a level")
}
storageLevel = newLevel
+ // Register the RDD with the SparkContext
+ sc.persistentRdds(id) = this
this
}
@@ -131,15 +136,24 @@ abstract class RDD[T: ClassManifest](
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
+ // Our dependencies and splits will be gotten by calling subclass's methods below, and will
+ // be overwritten when we're checkpointed
+ private var dependencies_ : Seq[Dependency[_]] = null
+ @transient private var splits_ : Array[Split] = null
+
+ /** An Option holding our checkpoint RDD, if we are checkpointed */
+ private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD)
+
/**
- * Get the preferred location of a split, taking into account whether the
+ * Get the list of dependencies of this RDD, taking into account whether the
* RDD is checkpointed or not.
*/
- final def preferredLocations(split: Split): Seq[String] = {
- if (isCheckpointed) {
- checkpointData.get.getPreferredLocations(split)
- } else {
- getPreferredLocations(split)
+ final def dependencies: Seq[Dependency[_]] = {
+ checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
+ if (dependencies_ == null) {
+ dependencies_ = getDependencies
+ }
+ dependencies_
}
}
@@ -148,22 +162,21 @@ abstract class RDD[T: ClassManifest](
* RDD is checkpointed or not.
*/
final def splits: Array[Split] = {
- if (isCheckpointed) {
- checkpointData.get.getSplits
- } else {
- getSplits
+ checkpointRDD.map(_.splits).getOrElse {
+ if (splits_ == null) {
+ splits_ = getSplits
+ }
+ splits_
}
}
/**
- * Get the list of dependencies of this RDD, taking into account whether the
+ * Get the preferred location of a split, taking into account whether the
* RDD is checkpointed or not.
*/
- final def dependencies: List[Dependency[_]] = {
- if (isCheckpointed) {
- dependencies_
- } else {
- getDependencies
+ final def preferredLocations(split: Split): Seq[String] = {
+ checkpointRDD.map(_.getPreferredLocations(split)).getOrElse {
+ getPreferredLocations(split)
}
}
@@ -173,10 +186,19 @@ abstract class RDD[T: ClassManifest](
* subclasses of RDD.
*/
final def iterator(split: Split, context: TaskContext): Iterator[T] = {
+ if (storageLevel != StorageLevel.NONE) {
+ SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
+ } else {
+ computeOrReadCheckpoint(split, context)
+ }
+ }
+
+ /**
+ * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
+ */
+ private[spark] def computeOrReadCheckpoint(split: Split, context: TaskContext): Iterator[T] = {
if (isCheckpointed) {
- checkpointData.get.iterator(split, context)
- } else if (storageLevel != StorageLevel.NONE) {
- SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel)
+ firstParent[T].iterator(split, context)
} else {
compute(split, context)
}
@@ -363,20 +385,22 @@ abstract class RDD[T: ClassManifest](
val reducePartition: Iterator[T] => Option[T] = iter => {
if (iter.hasNext) {
Some(iter.reduceLeft(cleanF))
- }else {
+ } else {
None
}
}
- val options = sc.runJob(this, reducePartition)
- val results = new ArrayBuffer[T]
- for (opt <- options; elem <- opt) {
- results += elem
- }
- if (results.size == 0) {
- throw new UnsupportedOperationException("empty collection")
- } else {
- return results.reduceLeft(cleanF)
+ var jobResult: Option[T] = None
+ val mergeResult = (index: Int, taskResult: Option[T]) => {
+ if (taskResult != None) {
+ jobResult = jobResult match {
+ case Some(value) => Some(f(value, taskResult.get))
+ case None => taskResult
+ }
+ }
}
+ sc.runJob(this, reducePartition, mergeResult)
+ // Get the final result out of our Option, or throw an exception if the RDD was empty
+ jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
}
/**
@@ -386,9 +410,13 @@ abstract class RDD[T: ClassManifest](
* modify t2.
*/
def fold(zeroValue: T)(op: (T, T) => T): T = {
+ // Clone the zero value since we will also be serializing it as part of tasks
+ var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanOp = sc.clean(op)
- val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp))
- return results.fold(zeroValue)(cleanOp)
+ val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)
+ val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult)
+ sc.runJob(this, foldPartition, mergeResult)
+ jobResult
}
/**
@@ -400,11 +428,14 @@ abstract class RDD[T: ClassManifest](
* allocation.
*/
def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
+ // Clone the zero value since we will also be serializing it as part of tasks
+ var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanSeqOp = sc.clean(seqOp)
val cleanCombOp = sc.clean(combOp)
- val results = sc.runJob(this,
- (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp))
- return results.fold(zeroValue)(cleanCombOp)
+ val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+ val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
+ sc.runJob(this, aggregatePartition, mergeResult)
+ jobResult
}
/**
@@ -415,7 +446,7 @@ abstract class RDD[T: ClassManifest](
var result = 0L
while (iter.hasNext) {
result += 1L
- iter.next
+ iter.next()
}
result
}).sum
@@ -430,7 +461,7 @@ abstract class RDD[T: ClassManifest](
var result = 0L
while (iter.hasNext) {
result += 1L
- iter.next
+ iter.next()
}
result
}
@@ -567,15 +598,15 @@ abstract class RDD[T: ClassManifest](
/**
* Return whether this RDD has been checkpointed or not
*/
- def isCheckpointed(): Boolean = {
- if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false
+ def isCheckpointed: Boolean = {
+ checkpointData.map(_.isCheckpointed).getOrElse(false)
}
/**
* Gets the name of the file to which this RDD was checkpointed
*/
- def getCheckpointFile(): Option[String] = {
- if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None
+ def getCheckpointFile: Option[String] = {
+ checkpointData.flatMap(_.getCheckpointFile)
}
// =======================================================================
@@ -600,31 +631,52 @@ abstract class RDD[T: ClassManifest](
def context = sc
/**
- * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler
+ * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler
* after a job using this RDD has completed (therefore the RDD has been materialized and
* potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
*/
- protected[spark] def doCheckpoint() {
- if (checkpointData.isDefined) checkpointData.get.doCheckpoint()
- dependencies.foreach(_.rdd.doCheckpoint())
+ private[spark] def doCheckpoint() {
+ if (checkpointData.isDefined) {
+ checkpointData.get.doCheckpoint()
+ } else {
+ dependencies.foreach(_.rdd.doCheckpoint())
+ }
}
/**
- * Changes the dependencies of this RDD from its original parents to the new RDD
- * (`newRDD`) created from the checkpoint file.
+ * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`)
+ * created from the checkpoint file, and forget its old dependencies and splits.
*/
- protected[spark] def changeDependencies(newRDD: RDD[_]) {
+ private[spark] def markCheckpointed(checkpointRDD: RDD[_]) {
clearDependencies()
- dependencies_ = List(new OneToOneDependency(newRDD))
+ dependencies_ = null
+ splits_ = null
+ deps = null // Forget the constructor argument for dependencies too
}
/**
* Clears the dependencies of this RDD. This method must ensure that all references
* to the original parent RDDs is removed to enable the parent RDDs to be garbage
* collected. Subclasses of RDD may override this method for implementing their own cleaning
- * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea.
+ * logic. See [[spark.rdd.UnionRDD]] for an example.
*/
- protected[spark] def clearDependencies() {
+ protected def clearDependencies() {
dependencies_ = null
}
+
+ /** A description of this RDD and its recursive dependencies for debugging. */
+ def toDebugString(): String = {
+ def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = {
+ Seq(prefix + rdd + " (" + rdd.splits.size + " splits)") ++
+ rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " "))
+ }
+ debugString(this).mkString("\n")
+ }
+
+ override def toString(): String = "%s%s[%d] at %s".format(
+ Option(name).map(_ + " ").getOrElse(""),
+ getClass.getSimpleName,
+ id,
+ origin)
+
}
diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
index 18df530b7d..a4a4ebaf53 100644
--- a/core/src/main/scala/spark/RDDCheckpointData.scala
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -20,7 +20,7 @@ private[spark] object CheckpointState extends Enumeration {
* of the checkpointed RDD.
*/
private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
-extends Logging with Serializable {
+ extends Logging with Serializable {
import CheckpointState._
@@ -31,7 +31,7 @@ extends Logging with Serializable {
@transient var cpFile: Option[String] = None
// The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD.
- @transient var cpRDD: Option[RDD[T]] = None
+ var cpRDD: Option[RDD[T]] = None
// Mark the RDD for checkpointing
def markForCheckpoint() {
@@ -41,12 +41,12 @@ extends Logging with Serializable {
}
// Is the RDD already checkpointed
- def isCheckpointed(): Boolean = {
+ def isCheckpointed: Boolean = {
RDDCheckpointData.synchronized { cpState == Checkpointed }
}
// Get the file to which this RDD was checkpointed to as an Option
- def getCheckpointFile(): Option[String] = {
+ def getCheckpointFile: Option[String] = {
RDDCheckpointData.synchronized { cpFile }
}
@@ -71,7 +71,7 @@ extends Logging with Serializable {
RDDCheckpointData.synchronized {
cpFile = Some(path)
cpRDD = Some(newRDD)
- rdd.changeDependencies(newRDD)
+ rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and splits
cpState = Checkpointed
RDDCheckpointData.clearTaskCaches()
logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
@@ -79,7 +79,7 @@ extends Logging with Serializable {
}
// Get preferred location of a split after checkpointing
- def getPreferredLocations(split: Split) = {
+ def getPreferredLocations(split: Split): Seq[String] = {
RDDCheckpointData.synchronized {
cpRDD.get.preferredLocations(split)
}
@@ -91,9 +91,10 @@ extends Logging with Serializable {
}
}
- // Get iterator. This is called at the worker nodes.
- def iterator(split: Split, context: TaskContext): Iterator[T] = {
- rdd.firstParent[T].iterator(split, context)
+ def checkpointRDD: Option[RDD[T]] = {
+ RDDCheckpointData.synchronized {
+ cpRDD
+ }
}
}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 7f3259d982..0efc00d5dd 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -1,6 +1,7 @@
package spark
import java.io._
+import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import java.net.{URI, URLClassLoader}
import java.lang.ref.WeakReference
@@ -8,6 +9,7 @@ import java.lang.ref.WeakReference
import scala.collection.Map
import scala.collection.generic.Growable
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.JavaConversions._
import akka.actor.Actor
import akka.actor.Actor._
@@ -42,6 +44,9 @@ import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
+import storage.BlockManagerUI
+import util.{MetadataCleaner, TimeStampedHashMap}
+import storage.{StorageStatus, StorageUtils, RDDInfo}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -57,59 +62,55 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend
class SparkContext(
val master: String,
val jobName: String,
- val sparkHome: String,
- val jars: Seq[String],
- environment: Map[String, String])
+ val sparkHome: String = null,
+ val jars: Seq[String] = Nil,
+ environment: Map[String, String] = Map())
extends Logging {
- /**
- * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
- * @param sparkHome Location where Spark is installed on cluster nodes.
- * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
- * system or HDFS, HTTP, HTTPS, or FTP URLs.
- */
- def this(master: String, jobName: String, sparkHome: String, jars: Seq[String]) =
- this(master, jobName, sparkHome, jars, Map())
-
- /**
- * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
- */
- def this(master: String, jobName: String) = this(master, jobName, null, Nil, Map())
-
// Ensure logging is initialized before we spawn any threads
initLogging()
- // Set Spark master host and port system properties
- if (System.getProperty("spark.master.host") == null) {
- System.setProperty("spark.master.host", Utils.localIpAddress)
+ // Set Spark driver host and port system properties
+ if (System.getProperty("spark.driver.host") == null) {
+ System.setProperty("spark.driver.host", Utils.localIpAddress)
}
- if (System.getProperty("spark.master.port") == null) {
- System.setProperty("spark.master.port", "0")
+ if (System.getProperty("spark.driver.port") == null) {
+ System.setProperty("spark.driver.port", "0")
}
private val isLocal = (master == "local" || master.startsWith("local["))
// Create the Spark execution environment (cache, map output tracker, etc)
private[spark] val env = SparkEnv.createFromSystemProperties(
- System.getProperty("spark.master.host"),
- System.getProperty("spark.master.port").toInt,
+ "<driver>",
+ System.getProperty("spark.driver.host"),
+ System.getProperty("spark.driver.port").toInt,
true,
isLocal)
SparkEnv.set(env)
+ // Start the BlockManager UI
+ private[spark] val ui = new BlockManagerUI(
+ env.actorSystem, env.blockManager.master.driverActor, this)
+ ui.start()
+
// Used to store a URL for each static file/jar together with the file's local timestamp
private[spark] val addedFiles = HashMap[String, Long]()
private[spark] val addedJars = HashMap[String, Long]()
+ // Keeps track of all persisted RDDs
+ private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]()
+ private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
+
+
// Add each JAR given through the constructor
jars.foreach { addJar(_) }
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
+ // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS",
- "SPARK_TESTING")) {
+ "SPARK_TESTING")) {
val value = System.getenv(key)
if (value != null) {
executorEnvs(key) = value
@@ -127,6 +128,8 @@ class SparkContext(
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """(spark://.*)""".r
+ //Regular expression for connection to Mesos cluster
+ val MESOS_REGEX = """(mesos://.*)""".r
master match {
case "local" =>
@@ -167,6 +170,9 @@ class SparkContext(
scheduler
case _ =>
+ if (MESOS_REGEX.findFirstIn(master).isEmpty) {
+ logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
+ }
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
@@ -183,6 +189,26 @@ class SparkContext(
taskScheduler.start()
private var dagScheduler = new DAGScheduler(taskScheduler)
+ dagScheduler.start()
+
+ /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
+ val hadoopConfiguration = {
+ val conf = new Configuration()
+ // Explicitly check for S3 environment variables
+ if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
+ conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ }
+ // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
+ for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
+ conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
+ }
+ val bufferSize = System.getProperty("spark.buffer.size", "65536")
+ conf.set("io.file.buffer.size", bufferSize)
+ conf
+ }
private[spark] var checkpointDir: Option[String] = None
@@ -238,10 +264,8 @@ class SparkContext(
valueClass: Class[V],
minSplits: Int = defaultMinSplits
) : RDD[(K, V)] = {
- val conf = new JobConf()
+ val conf = new JobConf(hadoopConfiguration)
FileInputFormat.setInputPaths(conf, path)
- val bufferSize = System.getProperty("spark.buffer.size", "65536")
- conf.set("io.file.buffer.size", bufferSize)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
}
@@ -282,8 +306,7 @@ class SparkContext(
path,
fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]],
- vm.erasure.asInstanceOf[Class[V]],
- new Configuration)
+ vm.erasure.asInstanceOf[Class[V]])
}
/**
@@ -295,7 +318,7 @@ class SparkContext(
fClass: Class[F],
kClass: Class[K],
vClass: Class[V],
- conf: Configuration): RDD[(K, V)] = {
+ conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
@@ -307,7 +330,7 @@ class SparkContext(
* and extra configuration options to pass to the input format.
*/
def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
- conf: Configuration,
+ conf: Configuration = hadoopConfiguration,
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
@@ -390,14 +413,14 @@ class SparkContext(
/**
* Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
- * to using the `+=` method. Only the master can access the accumulator's `value`.
+ * to using the `+=` method. Only the driver can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param)
/**
* Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`.
- * Only the master can access the accumuable's `value`.
+ * Only the driver can access the accumuable's `value`.
* @tparam T accumulator type
* @tparam R type that can be added to the accumulator
*/
@@ -422,9 +445,10 @@ class SparkContext(
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
/**
- * Add a file to be downloaded into the working directory of this Spark job on every node.
+ * Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
- * filesystems), or an HTTP, HTTPS or FTP URI.
+ * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+ * use `SparkFiles.get(path)` to find its download location.
*/
def addFile(path: String) {
val uri = new URI(path)
@@ -437,7 +461,7 @@ class SparkContext(
// Fetch the file locally in case a job is executed locally.
// Jobs that run through LocalScheduler will already fetch the required dependencies,
// but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
- Utils.fetchFile(path, new File("."))
+ Utils.fetchFile(path, new File(SparkFiles.getRootDirectory))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
@@ -446,13 +470,28 @@ class SparkContext(
* Return a map from the slave to the max memory available for caching and the remaining
* memory available for caching.
*/
- def getSlavesMemoryStatus: Map[String, (Long, Long)] = {
+ def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.ip + ":" + blockManagerId.port, mem)
}
}
/**
+ * Return information about what RDDs are cached, if they are in mem or on disk, how much space
+ * they take, etc.
+ */
+ def getRDDStorageInfo : Array[RDDInfo] = {
+ StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
+ }
+
+ /**
+ * Return information about blocks stored in all of the slaves
+ */
+ def getExecutorStorageStatus : Array[StorageStatus] = {
+ env.blockManager.master.getStorageStatus
+ }
+
+ /**
* Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes.
*/
@@ -486,6 +525,7 @@ class SparkContext(
/** Shut down the SparkContext. */
def stop() {
if (dagScheduler != null) {
+ metadataCleaner.cancel()
dagScheduler.stop()
dagScheduler = null
taskScheduler = null
@@ -521,27 +561,43 @@ class SparkContext(
}
/**
- * Run a function on a given set of partitions in an RDD and return the results. This is the main
- * entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies
- * whether the scheduler can run the computation on the master rather than shipping it out to the
- * cluster, for short actions like first().
+ * Run a function on a given set of partitions in an RDD and pass the results to the given
+ * handler function. This is the main entry point for all actions in Spark. The allowLocal
+ * flag specifies whether the scheduler can run the computation on the driver rather than
+ * shipping it out to the cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
- allowLocal: Boolean
- ): Array[U] = {
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit) {
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
+ val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
}
/**
+ * Run a function on a given set of partitions in an RDD and return the results as an array. The
+ * allowLocal flag specifies whether the scheduler can run the computation on the driver rather
+ * than shipping it out to the cluster, for short actions like first().
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ allowLocal: Boolean
+ ): Array[U] = {
+ val results = new Array[U](partitions.size)
+ runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res)
+ results
+ }
+
+ /**
* Run a job on a given set of partitions of an RDD, but take a function of type
* `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
*/
@@ -569,6 +625,29 @@ class SparkContext(
}
/**
+ * Run a job on all partitions in an RDD and pass the results to a handler function.
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ processPartition: (TaskContext, Iterator[T]) => U,
+ resultHandler: (Int, U) => Unit)
+ {
+ runJob[T, U](rdd, processPartition, 0 until rdd.splits.size, false, resultHandler)
+ }
+
+ /**
+ * Run a job on all partitions in an RDD and pass the results to a handler function.
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ processPartition: Iterator[T] => U,
+ resultHandler: (Int, U) => Unit)
+ {
+ val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
+ runJob[T, U](rdd, processFunc, 0 until rdd.splits.size, false, resultHandler)
+ }
+
+ /**
* Run a job that can return approximate results.
*/
def runApproximateJob[T, U, R](
@@ -628,6 +707,11 @@ class SparkContext(
/** Register a new RDD, returning its RDD ID */
private[spark] def newRddId(): Int = nextRddId.getAndIncrement()
+
+ /** Called by MetadataCleaner to clean up the persistentRdds map periodically */
+ private[spark] def cleanup(cleanupTime: Long) {
+ persistentRdds.clearOldValues(cleanupTime)
+ }
}
/**
@@ -646,6 +730,16 @@ object SparkContext {
def zero(initialValue: Int) = 0
}
+ implicit object LongAccumulatorParam extends AccumulatorParam[Long] {
+ def addInPlace(t1: Long, t2: Long) = t1 + t2
+ def zero(initialValue: Long) = 0l
+ }
+
+ implicit object FloatAccumulatorParam extends AccumulatorParam[Float] {
+ def addInPlace(t1: Float, t2: Float) = t1 + t2
+ def zero(initialValue: Float) = 0f
+ }
+
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 41441720a7..d2193ae72b 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -19,27 +19,23 @@ import spark.util.AkkaUtils
* SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
*/
class SparkEnv (
+ val executorId: String,
val actorSystem: ActorSystem,
val serializer: Serializer,
val closureSerializer: Serializer,
- val cacheTracker: CacheTracker,
+ val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
- val httpFileServer: HttpFileServer
+ val httpFileServer: HttpFileServer,
+ val sparkFilesDir: String
) {
- /** No-parameter constructor for unit tests. */
- def this() = {
- this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
- }
-
def stop() {
httpFileServer.stop()
mapOutputTracker.stop()
- cacheTracker.stop()
shuffleFetcher.stop()
broadcastManager.stop()
blockManager.stop()
@@ -63,17 +59,18 @@ object SparkEnv extends Logging {
}
def createFromSystemProperties(
+ executorId: String,
hostname: String,
port: Int,
- isMaster: Boolean,
- isLocal: Boolean
- ) : SparkEnv = {
+ isDriver: Boolean,
+ isLocal: Boolean): SparkEnv = {
+
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port)
- // Bit of a hack: If this is the master and our port was 0 (meaning bind to any free port),
- // figure out which port number Akka actually bound to and set spark.master.port to it.
- if (isMaster && port == 0) {
- System.setProperty("spark.master.port", boundPort.toString)
+ // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port),
+ // figure out which port number Akka actually bound to and set spark.driver.port to it.
+ if (isDriver && port == 0) {
+ System.setProperty("spark.driver.port", boundPort.toString)
}
val classLoader = Thread.currentThread.getContextClassLoader
@@ -87,23 +84,22 @@ object SparkEnv extends Logging {
val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
- val masterIp: String = System.getProperty("spark.master.host", "localhost")
- val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
+ val driverIp: String = System.getProperty("spark.driver.host", "localhost")
+ val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
val blockManagerMaster = new BlockManagerMaster(
- actorSystem, isMaster, isLocal, masterIp, masterPort)
- val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer)
+ actorSystem, isDriver, isLocal, driverIp, driverPort)
+ val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)
val connectionManager = blockManager.connectionManager
- val broadcastManager = new BroadcastManager(isMaster)
+ val broadcastManager = new BroadcastManager(isDriver)
val closureSerializer = instantiateClass[Serializer](
"spark.closure.serializer", "spark.JavaSerializer")
- val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager)
- blockManager.cacheTracker = cacheTracker
+ val cacheManager = new CacheManager(blockManager)
- val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster)
+ val mapOutputTracker = new MapOutputTracker(actorSystem, isDriver)
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
@@ -112,6 +108,15 @@ object SparkEnv extends Logging {
httpFileServer.initialize()
System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
+ // Set the sparkFiles directory, used when downloading dependencies. In local mode,
+ // this is a temporary directory; in distributed mode, this is the executor's current working
+ // directory.
+ val sparkFilesDir: String = if (isDriver) {
+ Utils.createTempDir().getAbsolutePath
+ } else {
+ "."
+ }
+
// Warn about deprecated spark.cache.class property
if (System.getProperty("spark.cache.class") != null) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@@ -119,15 +124,17 @@ object SparkEnv extends Logging {
}
new SparkEnv(
+ executorId,
actorSystem,
serializer,
closureSerializer,
- cacheTracker,
+ cacheManager,
mapOutputTracker,
shuffleFetcher,
broadcastManager,
blockManager,
connectionManager,
- httpFileServer)
+ httpFileServer,
+ sparkFilesDir)
}
}
diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java
new file mode 100644
index 0000000000..566aec622c
--- /dev/null
+++ b/core/src/main/scala/spark/SparkFiles.java
@@ -0,0 +1,25 @@
+package spark;
+
+import java.io.File;
+
+/**
+ * Resolves paths to files added through `SparkContext.addFile()`.
+ */
+public class SparkFiles {
+
+ private SparkFiles() {}
+
+ /**
+ * Get the absolute path of a file added through `SparkContext.addFile()`.
+ */
+ public static String get(String filename) {
+ return new File(getRootDirectory(), filename).getAbsolutePath();
+ }
+
+ /**
+ * Get the root directory that contains files added through `SparkContext.addFile()`.
+ */
+ public static String getRootDirectory() {
+ return SparkEnv.get().sparkFilesDir();
+ }
+}
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
index d2746b26b3..eab85f85a2 100644
--- a/core/src/main/scala/spark/TaskContext.scala
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -5,8 +5,7 @@ import scala.collection.mutable.ArrayBuffer
class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
- @transient
- val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+ @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
// Add a callback function to be executed on task completion. An example use
// is for HadoopRDD to register a callback to close the input stream.
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index b3421df27c..28d643abca 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -1,7 +1,7 @@
package spark
import java.io._
-import java.net.{NetworkInterface, InetAddress, URL, URI}
+import java.net._
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
@@ -10,6 +10,9 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.io.Source
import com.google.common.io.Files
+import com.google.common.util.concurrent.ThreadFactoryBuilder
+import scala.Some
+import spark.serializer.SerializerInstance
/**
* Various utility methods used by Spark.
@@ -111,20 +114,6 @@ private object Utils extends Logging {
}
}
- /** Copy a file on the local file system */
- def copyFile(source: File, dest: File) {
- val in = new FileInputStream(source)
- val out = new FileOutputStream(dest)
- copyStream(in, out, true)
- }
-
- /** Download a file from a given URL to the local filesystem */
- def downloadFile(url: URL, localPath: String) {
- val in = url.openStream()
- val out = new FileOutputStream(localPath)
- Utils.copyStream(in, out, true)
- }
-
/**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
@@ -201,7 +190,7 @@ private object Utils extends Logging {
Utils.execute(Seq("tar", "-xf", filename), targetDir)
}
// Make the file executable - That's necessary for scripts
- FileUtil.chmod(filename, "a+x")
+ FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
}
/**
@@ -251,7 +240,8 @@ private object Utils extends Logging {
// Address resolves to something like 127.0.1.1, which happens on Debian; try to find
// a better address using the local network interfaces
for (ni <- NetworkInterface.getNetworkInterfaces) {
- for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && !addr.isLoopbackAddress) {
+ for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress &&
+ !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) {
// We've found an address that looks reasonable!
logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" +
" a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress +
@@ -286,29 +276,14 @@ private object Utils extends Logging {
customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
}
- /**
- * Returns a standard ThreadFactory except all threads are daemons.
- */
- private def newDaemonThreadFactory: ThreadFactory = {
- new ThreadFactory {
- def newThread(r: Runnable): Thread = {
- var t = Executors.defaultThreadFactory.newThread (r)
- t.setDaemon (true)
- return t
- }
- }
- }
+ private[spark] val daemonThreadFactory: ThreadFactory =
+ new ThreadFactoryBuilder().setDaemon(true).build()
/**
* Wrapper over newCachedThreadPool.
*/
- def newDaemonCachedThreadPool(): ThreadPoolExecutor = {
- var threadPool = Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory (newDaemonThreadFactory)
-
- return threadPool
- }
+ def newDaemonCachedThreadPool(): ThreadPoolExecutor =
+ Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
/**
* Return the string to tell how long has passed in seconds. The passing parameter should be in
@@ -321,13 +296,8 @@ private object Utils extends Logging {
/**
* Wrapper over newFixedThreadPool.
*/
- def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
- var threadPool = Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory(newDaemonThreadFactory)
-
- return threadPool
- }
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor =
+ Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
/**
* Delete a file or directory and its contents recursively.
@@ -463,4 +433,25 @@ private object Utils extends Logging {
}
"%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
}
+
+ /**
+ * Try to find a free port to bind to on the local host. This should ideally never be needed,
+ * except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray)
+ * don't let users bind to port 0 and then figure out which free port they actually bound to.
+ * We work around this by binding a ServerSocket and immediately unbinding it. This is *not*
+ * necessarily guaranteed to work, but it's the best we can do.
+ */
+ def findFreePort(): Int = {
+ val socket = new ServerSocket(0)
+ val portBound = socket.getLocalPort
+ socket.close()
+ portBound
+ }
+
+ /**
+ * Clone an object using a Spark serializer.
+ */
+ def clone[T](value: T, serializer: SerializerInstance): T = {
+ serializer.deserialize[T](serializer.serialize(value))
+ }
}
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 087270e46d..60025b459c 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -12,7 +12,7 @@ import spark.storage.StorageLevel
import com.google.common.base.Optional
-trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
+trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] {
def wrapRDD(rdd: RDD[T]): This
implicit val classManifest: ClassManifest[T]
@@ -82,10 +82,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Return a new RDD by first applying a function to all elements of this
- * RDD, and then flattening the results.
+ * Part of the workaround for SPARK-668; called in PairFlatMapWorkaround.java.
*/
- def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
+ private[spark] def doFlatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
@@ -307,23 +306,20 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
JavaPairRDD.fromRDD(rdd.keyBy(f))
}
-
+
/**
- * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir`
- * (set using setCheckpointDir()) and all references to its parent RDDs will be removed.
- * This is used to truncate very long lineages. In the current implementation, Spark will save
- * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done.
- * Hence, it is strongly recommended to use checkpoint() on RDDs when
- * (i) checkpoint() is called before the any job has been executed on this RDD.
- * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will
- * require recomputation.
+ * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
+ * directory set with SparkContext.setCheckpointDir() and all references to its parent
+ * RDDs will be removed. This function must be called before any job has been
+ * executed on this RDD. It is strongly recommended that this RDD is persisted in
+ * memory, otherwise saving it on a file will require recomputation.
*/
def checkpoint() = rdd.checkpoint()
/**
* Return whether this RDD has been checkpointed or not
*/
- def isCheckpointed(): Boolean = rdd.isCheckpointed()
+ def isCheckpointed: Boolean = rdd.isCheckpointed
/**
* Gets the name of the file to which this RDD was checkpointed
@@ -334,4 +330,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
case _ => Optional.absent()
}
}
+
+ /** A description of this RDD and its recursive dependencies for debugging. */
+ def toDebugString(): String = {
+ rdd.toDebugString()
+ }
}
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index fa2f14113d..50b8970cd8 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -323,9 +323,10 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def getSparkHome(): Option[String] = sc.getSparkHome()
/**
- * Add a file to be downloaded into the working directory of this Spark job on every node.
+ * Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
- * filesystems), or an HTTP, HTTPS or FTP URI.
+ * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+ * use `SparkFiles.get(path)` to find its download location.
*/
def addFile(path: String) {
sc.addFile(path)
@@ -357,20 +358,28 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
}
/**
- * Set the directory under which RDDs are going to be checkpointed. This method will
- * create this directory and will throw an exception of the path already exists (to avoid
- * overwriting existing files may be overwritten). The directory will be deleted on exit
- * if indicated.
+ * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse.
+ */
+ def hadoopConfiguration(): Configuration = {
+ sc.hadoopConfiguration
+ }
+
+ /**
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be a HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists and useExisting is set to true, then the
+ * exisiting directory will be used. Otherwise an exception will be thrown to
+ * prevent accidental overriding of checkpoint files in the existing directory.
*/
def setCheckpointDir(dir: String, useExisting: Boolean) {
sc.setCheckpointDir(dir, useExisting)
}
/**
- * Set the directory under which RDDs are going to be checkpointed. This method will
- * create this directory and will throw an exception of the path already exists (to avoid
- * overwriting existing files may be overwritten). The directory will be deleted on exit
- * if indicated.
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be a HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists, an exception will be thrown to prevent accidental
+ * overriding of checkpoint files.
*/
def setCheckpointDir(dir: String) {
sc.setCheckpointDir(dir)
diff --git a/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java
new file mode 100644
index 0000000000..68b6fd6622
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java
@@ -0,0 +1,20 @@
+package spark.api.java;
+
+import spark.api.java.JavaPairRDD;
+import spark.api.java.JavaRDDLike;
+import spark.api.java.function.PairFlatMapFunction;
+
+import java.io.Serializable;
+
+/**
+ * Workaround for SPARK-668.
+ */
+class PairFlatMapWorkaround<T> implements Serializable {
+ /**
+ * Return a new RDD by first applying a function to all elements of this
+ * RDD, and then flattening the results.
+ */
+ public <K, V> JavaPairRDD<K, V> flatMap(PairFlatMapFunction<T, K, V> f) {
+ return ((JavaRDDLike <T, ?>) this).doFlatMap(f);
+ }
+}
diff --git a/core/src/main/scala/spark/api/java/StorageLevels.java b/core/src/main/scala/spark/api/java/StorageLevels.java
index 722af3c06c..5e5845ac3a 100644
--- a/core/src/main/scala/spark/api/java/StorageLevels.java
+++ b/core/src/main/scala/spark/api/java/StorageLevels.java
@@ -17,4 +17,15 @@ public class StorageLevels {
public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2);
public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1);
public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2);
+
+ /**
+ * Create a new StorageLevel object.
+ * @param useDisk saved to disk, if true
+ * @param useMemory saved to memory, if true
+ * @param deserialized saved as deserialized objects, if true
+ * @param replication replication factor
+ */
+ public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) {
+ return StorageLevel.apply(useDisk, useMemory, deserialized, replication);
+ }
}
diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
index 648d9402b0..519e310323 100644
--- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -6,8 +6,17 @@ import java.util.Arrays
/**
* A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
+ *
+ * Stores the unique id() of the Python-side partitioning function so that it is incorporated into
+ * equality comparisons. Correctness requires that the id is a unique identifier for the
+ * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning
+ * function). This can be ensured by using the Python id() function and maintaining a reference
+ * to the Python partitioning function so that its id() is not reused.
*/
-private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner {
+private[spark] class PythonPartitioner(
+ override val numPartitions: Int,
+ val pyPartitionFunctionId: Long)
+ extends Partitioner {
override def getPartition(key: Any): Int = {
if (key == null) {
@@ -32,7 +41,7 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends
override def equals(other: Any): Boolean = other match {
case h: PythonPartitioner =>
- h.numPartitions == numPartitions
+ h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
case _ =>
false
}
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 89f7c316dc..ab8351e55e 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -67,6 +67,8 @@ private[spark] class PythonRDD[T: ClassManifest](
val dOut = new DataOutputStream(proc.getOutputStream)
// Split index
dOut.writeInt(split.index)
+ // sparkFilesDir
+ PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
// Broadcast variables
dOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
@@ -101,21 +103,27 @@ private[spark] class PythonRDD[T: ClassManifest](
private def read(): Array[Byte] = {
try {
- val length = stream.readInt()
- if (length != -1) {
- val obj = new Array[Byte](length)
- stream.readFully(obj)
- obj
- } else {
- // We've finished the data section of the output, but we can still read some
- // accumulator updates; let's do that, breaking when we get EOFException
- while (true) {
- val len2 = stream.readInt()
- val update = new Array[Byte](len2)
- stream.readFully(update)
- accumulator += Collections.singletonList(update)
- }
- new Array[Byte](0)
+ stream.readInt() match {
+ case length if length > 0 =>
+ val obj = new Array[Byte](length)
+ stream.readFully(obj)
+ obj
+ case -2 =>
+ // Signals that an exception has been thrown in python
+ val exLength = stream.readInt()
+ val obj = new Array[Byte](exLength)
+ stream.readFully(obj)
+ throw new PythonException(new String(obj))
+ case -1 =>
+ // We've finished the data section of the output, but we can still read some
+ // accumulator updates; let's do that, breaking when we get EOFException
+ while (true) {
+ val len2 = stream.readInt()
+ val update = new Array[Byte](len2)
+ stream.readFully(update)
+ accumulator += Collections.singletonList(update)
+ }
+ new Array[Byte](0)
}
} catch {
case eof: EOFException => {
@@ -135,11 +143,12 @@ private[spark] class PythonRDD[T: ClassManifest](
}
}
- override def checkpoint() { }
-
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
}
+/** Thrown for exceptions in user Python code. */
+private class PythonException(msg: String) extends Exception(msg)
+
/**
* Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
* This is used by PySpark's shuffle operations.
@@ -152,7 +161,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
case Seq(a, b) => (a, b)
case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
}
- override def checkpoint() { }
val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
}
@@ -230,6 +238,11 @@ private[spark] object PythonRDD {
}
def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+ import scala.collection.JavaConverters._
+ writeIteratorToPickleFile(items.asScala, filename)
+ }
+
+ def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
for (item <- items) {
writeAsPickle(item, file)
@@ -237,8 +250,10 @@ private[spark] object PythonRDD {
file.close()
}
- def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] =
- rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head
+ def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
+ implicit val cm : ClassManifest[T] = rdd.elementClassManifest
+ rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
+ }
}
private object Pickle {
@@ -252,11 +267,6 @@ private object Pickle {
val APPENDS: Byte = 'e'
}
-private class ExtractValue extends spark.api.java.function.Function[(Array[Byte],
- Array[Byte]), Array[Byte]] {
- override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2
-}
-
private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}
diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
index 386f505f2a..adcb2d2415 100644
--- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
@@ -31,7 +31,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
@transient var totalBlocks = -1
@transient var hasBlocks = new AtomicInteger(0)
- // Used ONLY by Master to track how many unique blocks have been sent out
+ // Used ONLY by driver to track how many unique blocks have been sent out
@transient var sentBlocks = new AtomicInteger(0)
@transient var listenPortLock = new Object
@@ -42,7 +42,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
@transient var serveMR: ServeMultipleRequests = null
- // Used only in Master
+ // Used only in driver
@transient var guideMR: GuideMultipleRequests = null
// Used only in Workers
@@ -99,14 +99,14 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
}
// Must always come AFTER listenPort is created
- val masterSource =
+ val driverSource =
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
hasBlocksBitVector.synchronized {
- masterSource.hasBlocksBitVector = hasBlocksBitVector
+ driverSource.hasBlocksBitVector = hasBlocksBitVector
}
// In the beginning, this is the only known source to Guide
- listOfSources += masterSource
+ listOfSources += driverSource
// Register with the Tracker
MultiTracker.registerBroadcast(id,
@@ -122,7 +122,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
case None =>
logInfo("Started reading broadcast variable " + id)
- // Initializing everything because Master will only send null/0 values
+ // Initializing everything because driver will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables()
@@ -151,7 +151,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
}
}
- // Initialize variables in the worker node. Master sends everything as 0/null
+ // Initialize variables in the worker node. Driver sends everything as 0/null
private def initializeWorkerVariables() {
arrayOfBlocks = null
hasBlocksBitVector = null
@@ -248,7 +248,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
// Receive source information from Guide
var suitableSources =
oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
- logDebug("Received suitableSources from Master " + suitableSources)
+ logDebug("Received suitableSources from Driver " + suitableSources)
addToListOfSources(suitableSources)
@@ -532,7 +532,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
oosSource.writeObject(blockToAskFor)
oosSource.flush()
- // CHANGED: Master might send some other block than the one
+ // CHANGED: Driver might send some other block than the one
// requested to ensure fast spreading of all blocks.
val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
@@ -982,9 +982,9 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
// Receive which block to send
var blockToSend = ois.readObject.asInstanceOf[Int]
- // If it is master AND at least one copy of each block has not been
+ // If it is driver AND at least one copy of each block has not been
// sent out already, MODIFY blockToSend
- if (MultiTracker.isMaster && sentBlocks.get < totalBlocks) {
+ if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) {
blockToSend = sentBlocks.getAndIncrement
}
@@ -1031,7 +1031,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
private[spark] class BitTorrentBroadcastFactory
extends BroadcastFactory {
- def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
+ def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new BitTorrentBroadcast[T](value_, isLocal, id)
diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala
index 2ffe7f741d..415bde5d67 100644
--- a/core/src/main/scala/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/spark/broadcast/Broadcast.scala
@@ -15,7 +15,7 @@ abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
}
private[spark]
-class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable {
+class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable {
private var initialized = false
private var broadcastFactory: BroadcastFactory = null
@@ -33,7 +33,7 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject
- broadcastFactory.initialize(isMaster)
+ broadcastFactory.initialize(isDriver)
initialized = true
}
@@ -49,5 +49,5 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl
def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
- def isMaster = isMaster_
+ def isDriver = _isDriver
}
diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
index ab6d302827..5c6184c3c7 100644
--- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
@@ -7,7 +7,7 @@ package spark.broadcast
* entire Spark job.
*/
private[spark] trait BroadcastFactory {
- def initialize(isMaster: Boolean): Unit
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T]
+ def initialize(isDriver: Boolean): Unit
+ def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
}
diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
index 8e490e6bad..7e30b8f7d2 100644
--- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
@@ -48,7 +48,7 @@ extends Broadcast[T](id) with Logging with Serializable {
}
private[spark] class HttpBroadcastFactory extends BroadcastFactory {
- def initialize(isMaster: Boolean) { HttpBroadcast.initialize(isMaster) }
+ def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
@@ -69,12 +69,12 @@ private object HttpBroadcast extends Logging {
private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup)
- def initialize(isMaster: Boolean) {
+ def initialize(isDriver: Boolean) {
synchronized {
if (!initialized) {
bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
compress = System.getProperty("spark.broadcast.compress", "true").toBoolean
- if (isMaster) {
+ if (isDriver) {
createServer()
}
serverUri = System.getProperty("spark.httpBroadcast.uri")
diff --git a/core/src/main/scala/spark/broadcast/MultiTracker.scala b/core/src/main/scala/spark/broadcast/MultiTracker.scala
index 5e76dedb94..3fd77af73f 100644
--- a/core/src/main/scala/spark/broadcast/MultiTracker.scala
+++ b/core/src/main/scala/spark/broadcast/MultiTracker.scala
@@ -23,25 +23,24 @@ extends Logging {
var ranGen = new Random
private var initialized = false
- private var isMaster_ = false
+ private var _isDriver = false
private var stopBroadcast = false
private var trackMV: TrackMultipleValues = null
- def initialize(isMaster__ : Boolean) {
+ def initialize(__isDriver: Boolean) {
synchronized {
if (!initialized) {
+ _isDriver = __isDriver
- isMaster_ = isMaster__
-
- if (isMaster) {
+ if (isDriver) {
trackMV = new TrackMultipleValues
trackMV.setDaemon(true)
trackMV.start()
- // Set masterHostAddress to the master's IP address for the slaves to read
- System.setProperty("spark.MultiTracker.MasterHostAddress", Utils.localIpAddress)
+ // Set DriverHostAddress to the driver's IP address for the slaves to read
+ System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
}
initialized = true
@@ -54,10 +53,10 @@ extends Logging {
}
// Load common parameters
- private var MasterHostAddress_ = System.getProperty(
- "spark.MultiTracker.MasterHostAddress", "")
- private var MasterTrackerPort_ = System.getProperty(
- "spark.broadcast.masterTrackerPort", "11111").toInt
+ private var DriverHostAddress_ = System.getProperty(
+ "spark.MultiTracker.DriverHostAddress", "")
+ private var DriverTrackerPort_ = System.getProperty(
+ "spark.broadcast.driverTrackerPort", "11111").toInt
private var BlockSize_ = System.getProperty(
"spark.broadcast.blockSize", "4096").toInt * 1024
private var MaxRetryCount_ = System.getProperty(
@@ -91,11 +90,11 @@ extends Logging {
private var EndGameFraction_ = System.getProperty(
"spark.broadcast.endGameFraction", "0.95").toDouble
- def isMaster = isMaster_
+ def isDriver = _isDriver
// Common config params
- def MasterHostAddress = MasterHostAddress_
- def MasterTrackerPort = MasterTrackerPort_
+ def DriverHostAddress = DriverHostAddress_
+ def DriverTrackerPort = DriverTrackerPort_
def BlockSize = BlockSize_
def MaxRetryCount = MaxRetryCount_
@@ -123,7 +122,7 @@ extends Logging {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
- serverSocket = new ServerSocket(MasterTrackerPort)
+ serverSocket = new ServerSocket(DriverTrackerPort)
logInfo("TrackMultipleValues started at " + serverSocket)
try {
@@ -235,7 +234,7 @@ extends Logging {
try {
// Connect to the tracker to find out GuideInfo
clientSocketToTracker =
- new Socket(MultiTracker.MasterHostAddress, MultiTracker.MasterTrackerPort)
+ new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort)
oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
@@ -276,7 +275,7 @@ extends Logging {
}
def registerBroadcast(id: Long, gInfo: SourceInfo) {
- val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
+ val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
@@ -303,7 +302,7 @@ extends Logging {
}
def unregisterBroadcast(id: Long) {
- val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
+ val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
index f573512835..c55c476117 100644
--- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
@@ -98,7 +98,7 @@ extends Broadcast[T](id) with Logging with Serializable {
case None =>
logInfo("Started reading broadcast variable " + id)
- // Initializing everything because Master will only send null/0 values
+ // Initializing everything because Driver will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables()
@@ -157,55 +157,55 @@ extends Broadcast[T](id) with Logging with Serializable {
listenPortLock.synchronized { listenPortLock.wait() }
}
- var clientSocketToMaster: Socket = null
- var oosMaster: ObjectOutputStream = null
- var oisMaster: ObjectInputStream = null
+ var clientSocketToDriver: Socket = null
+ var oosDriver: ObjectOutputStream = null
+ var oisDriver: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = MultiTracker.MaxRetryCount
do {
- // Connect to Master and send this worker's Information
- clientSocketToMaster = new Socket(MultiTracker.MasterHostAddress, gInfo.listenPort)
- oosMaster = new ObjectOutputStream(clientSocketToMaster.getOutputStream)
- oosMaster.flush()
- oisMaster = new ObjectInputStream(clientSocketToMaster.getInputStream)
+ // Connect to Driver and send this worker's Information
+ clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
+ oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
+ oosDriver.flush()
+ oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)
- logDebug("Connected to Master's guiding object")
+ logDebug("Connected to Driver's guiding object")
// Send local source information
- oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
- oosMaster.flush()
+ oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
+ oosDriver.flush()
- // Receive source information from Master
- var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
+ // Receive source information from Driver
+ var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
totalBytes = sourceInfo.totalBytes
- logDebug("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
+ logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
val time = (System.nanoTime - start) / 1e9
- // Updating some statistics in sourceInfo. Master will be using them later
+ // Updating some statistics in sourceInfo. Driver will be using them later
if (!receptionSucceeded) {
sourceInfo.receptionFailed = true
}
- // Send back statistics to the Master
- oosMaster.writeObject(sourceInfo)
+ // Send back statistics to the Driver
+ oosDriver.writeObject(sourceInfo)
- if (oisMaster != null) {
- oisMaster.close()
+ if (oisDriver != null) {
+ oisDriver.close()
}
- if (oosMaster != null) {
- oosMaster.close()
+ if (oosDriver != null) {
+ oosDriver.close()
}
- if (clientSocketToMaster != null) {
- clientSocketToMaster.close()
+ if (clientSocketToDriver != null) {
+ clientSocketToDriver.close()
}
retriesLeft -= 1
@@ -552,7 +552,7 @@ extends Broadcast[T](id) with Logging with Serializable {
}
private def sendObject() {
- // Wait till receiving the SourceInfo from Master
+ // Wait till receiving the SourceInfo from Driver
while (totalBlocks == -1) {
totalBlocksLock.synchronized { totalBlocksLock.wait() }
}
@@ -576,7 +576,7 @@ extends Broadcast[T](id) with Logging with Serializable {
private[spark] class TreeBroadcastFactory
extends BroadcastFactory {
- def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
+ def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TreeBroadcast[T](value_, isLocal, id)
diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala
index 457122745b..35f40c6e91 100644
--- a/core/src/main/scala/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/spark/deploy/DeployMessage.scala
@@ -4,7 +4,6 @@ import spark.deploy.ExecutorState.ExecutorState
import spark.deploy.master.{WorkerInfo, JobInfo}
import spark.deploy.worker.ExecutorRunner
import scala.collection.immutable.List
-import scala.collection.mutable.HashMap
private[spark] sealed trait DeployMessage extends Serializable
@@ -42,7 +41,8 @@ private[spark] case class LaunchExecutor(
execId: Int,
jobDesc: JobDescription,
cores: Int,
- memory: Int)
+ memory: Int,
+ sparkHome: String)
extends DeployMessage
diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala
index 20879c5f11..7160fc05fc 100644
--- a/core/src/main/scala/spark/deploy/JobDescription.scala
+++ b/core/src/main/scala/spark/deploy/JobDescription.scala
@@ -4,7 +4,8 @@ private[spark] class JobDescription(
val name: String,
val cores: Int,
val memoryPerSlave: Int,
- val command: Command)
+ val command: Command,
+ val sparkHome: String)
extends Serializable {
val user = System.getProperty("user.name", "<unknown>")
diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
index 4211d80596..22319a96ca 100644
--- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
@@ -9,43 +9,32 @@ import spark.{Logging, Utils}
import scala.collection.mutable.ArrayBuffer
+/**
+ * Testing class that creates a Spark standalone process in-cluster (that is, running the
+ * spark.deploy.master.Master and spark.deploy.worker.Workers in the same JVMs). Executors launched
+ * by the Workers still run in separate JVMs. This can be used to test distributed operation and
+ * fault recovery without spinning up a lot of processes.
+ */
private[spark]
-class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging {
+class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
- val localIpAddress = Utils.localIpAddress
+ private val localIpAddress = Utils.localIpAddress
+ private val masterActorSystems = ArrayBuffer[ActorSystem]()
+ private val workerActorSystems = ArrayBuffer[ActorSystem]()
- var masterActor : ActorRef = _
- var masterActorSystem : ActorSystem = _
- var masterPort : Int = _
- var masterUrl : String = _
-
- val slaveActorSystems = ArrayBuffer[ActorSystem]()
- val slaveActors = ArrayBuffer[ActorRef]()
-
- def start() : String = {
- logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.")
+ def start(): String = {
+ logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
- val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
- masterActorSystem = actorSystem
- masterUrl = "spark://" + localIpAddress + ":" + masterPort
- val actor = masterActorSystem.actorOf(
- Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
- masterActor = actor
-
- /* Start the Slaves */
- for (slaveNum <- 1 to numSlaves) {
- /* We can pretend to test distributed stuff by giving the slaves distinct hostnames.
- All of 127/8 should be a loopback, we use 127.100.*.* in hopes that it is
- sufficiently distinctive. */
- val slaveIpAddress = "127.100.0." + (slaveNum % 256)
- val (actorSystem, boundPort) =
- AkkaUtils.createActorSystem("sparkWorker" + slaveNum, slaveIpAddress, 0)
- slaveActorSystems += actorSystem
- val actor = actorSystem.actorOf(
- Props(new Worker(slaveIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)),
- name = "Worker")
- slaveActors += actor
+ val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
+ masterActorSystems += masterSystem
+ val masterUrl = "spark://" + localIpAddress + ":" + masterPort
+
+ /* Start the Workers */
+ for (workerNum <- 1 to numWorkers) {
+ val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
+ memoryPerWorker, masterUrl, null, Some(workerNum))
+ workerActorSystems += workerSystem
}
return masterUrl
@@ -53,10 +42,10 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int)
def stop() {
logInfo("Shutting down local Spark cluster.")
- // Stop the slaves before the master so they don't get upset that it disconnected
- slaveActorSystems.foreach(_.shutdown())
- slaveActorSystems.foreach(_.awaitTermination())
- masterActorSystem.shutdown()
- masterActorSystem.awaitTermination()
+ // Stop the workers before the master so they don't get upset that it disconnected
+ workerActorSystems.foreach(_.shutdown())
+ workerActorSystems.foreach(_.awaitTermination())
+ masterActorSystems.foreach(_.shutdown())
+ masterActorSystems.foreach(_.awaitTermination())
}
}
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index 90fe9508cd..a63eee1233 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -9,6 +9,7 @@ import spark.{SparkException, Logging}
import akka.remote.RemoteClientLifeCycleEvent
import akka.remote.RemoteClientShutdown
import spark.deploy.RegisterJob
+import spark.deploy.master.Master
import akka.remote.RemoteClientDisconnected
import akka.actor.Terminated
import akka.dispatch.Await
@@ -24,26 +25,18 @@ private[spark] class Client(
listener: ClientListener)
extends Logging {
- val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
-
var actor: ActorRef = null
var jobId: String = null
- if (MASTER_REGEX.unapplySeq(masterUrl) == None) {
- throw new SparkException("Invalid master URL: " + masterUrl)
- }
-
class ClientActor extends Actor with Logging {
var master: ActorRef = null
var masterAddress: Address = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
override def preStart() {
- val Seq(masterHost, masterPort) = MASTER_REGEX.unapplySeq(masterUrl).get
- logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
- val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
+ logInfo("Connecting to master " + masterUrl)
try {
- master = context.actorFor(akkaUrl)
+ master = context.actorFor(Master.toAkkaUrl(masterUrl))
masterAddress = master.path.address
master ! RegisterJob(jobDescription)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala
index da6abcc9c2..7035f4b394 100644
--- a/core/src/main/scala/spark/deploy/client/ClientListener.scala
+++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala
@@ -12,7 +12,7 @@ private[spark] trait ClientListener {
def disconnected(): Unit
- def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int): Unit
+ def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit
- def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit
+ def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
}
diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala
index 57a7e123b7..8764c400e2 100644
--- a/core/src/main/scala/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/spark/deploy/client/TestClient.scala
@@ -25,7 +25,7 @@ private[spark] object TestClient {
val url = args(0)
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0)
val desc = new JobDescription(
- "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()))
+ "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home")
val listener = new TestListener
val client = new Client(actorSystem, url, desc, listener)
client.start()
diff --git a/core/src/main/scala/spark/deploy/master/JobInfo.scala b/core/src/main/scala/spark/deploy/master/JobInfo.scala
index 130b031a2a..a274b21c34 100644
--- a/core/src/main/scala/spark/deploy/master/JobInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/JobInfo.scala
@@ -10,7 +10,7 @@ private[spark] class JobInfo(
val id: String,
val desc: JobDescription,
val submitDate: Date,
- val actor: ActorRef)
+ val driver: ActorRef)
{
var state = JobState.WAITING
var executors = new mutable.HashMap[Int, ExecutorInfo]
diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala
index 6ecebe626a..92e7914b1b 100644
--- a/core/src/main/scala/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/spark/deploy/master/Master.scala
@@ -88,7 +88,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
execOption match {
case Some(exec) => {
exec.state = state
- exec.job.actor ! ExecutorUpdated(execId, state, message, exitStatus)
+ exec.job.driver ! ExecutorUpdated(execId, state, message, exitStatus)
if (ExecutorState.isFinished(state)) {
val jobInfo = idToJob(jobId)
// Remove this executor from the worker and job
@@ -97,14 +97,12 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
exec.worker.removeExecutor(exec)
// Only retry certain number of times so we don't go into an infinite loop.
- if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) {
+ if (jobInfo.incrementRetryCount < JobState.MAX_NUM_RETRY) {
schedule()
} else {
- val e = new SparkException("Job %s wth ID %s failed %d times.".format(
+ logError("Job %s with ID %s failed %d times, removing it".format(
jobInfo.desc.name, jobInfo.id, jobInfo.retryCount))
- logError(e.getMessage, e)
- throw e
- //System.exit(1)
+ removeJob(jobInfo)
}
}
}
@@ -173,7 +171,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
for (pos <- 0 until numUsable) {
if (assigned(pos) > 0) {
val exec = job.addExecutor(usableWorkers(pos), assigned(pos))
- launchExecutor(usableWorkers(pos), exec)
+ launchExecutor(usableWorkers(pos), exec, job.desc.sparkHome)
job.state = JobState.RUNNING
}
}
@@ -186,7 +184,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
val coresToUse = math.min(worker.coresFree, job.coresLeft)
if (coresToUse > 0) {
val exec = job.addExecutor(worker, coresToUse)
- launchExecutor(worker, exec)
+ launchExecutor(worker, exec, job.desc.sparkHome)
job.state = JobState.RUNNING
}
}
@@ -195,11 +193,11 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
}
- def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) {
+ def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
- worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory)
- exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
+ worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome)
+ exec.job.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
}
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
@@ -221,19 +219,19 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
actorToWorker -= worker.actor
addressToWorker -= worker.actor.path.address
for (exec <- worker.executors.values) {
- exec.job.actor ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None)
+ exec.job.driver ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None)
exec.job.executors -= exec.id
}
}
- def addJob(desc: JobDescription, actor: ActorRef): JobInfo = {
+ def addJob(desc: JobDescription, driver: ActorRef): JobInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
- val job = new JobInfo(now, newJobId(date), desc, date, actor)
+ val job = new JobInfo(now, newJobId(date), desc, date, driver)
jobs += job
idToJob(job.id) = job
- actorToJob(sender) = job
- addressToJob(sender.path.address) = job
+ actorToJob(driver) = job
+ addressToJob(driver.path.address) = job
return job
}
@@ -242,8 +240,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
logInfo("Removing job " + job.id)
jobs -= job
idToJob -= job.id
- actorToJob -= job.actor
- addressToWorker -= job.actor.path.address
+ actorToJob -= job.driver
+ addressToWorker -= job.driver.path.address
completedJobs += job // Remember it in our history
waitingJobs -= job
for (exec <- job.executors.values) {
@@ -264,11 +262,29 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
private[spark] object Master {
+ private val systemName = "sparkMaster"
+ private val actorName = "Master"
+ private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
+
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
- val actor = actorSystem.actorOf(
- Props(new Master(args.ip, boundPort, args.webUiPort)), name = "Master")
+ val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
actorSystem.awaitTermination()
}
+
+ /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
+ def toAkkaUrl(sparkUrl: String): String = {
+ sparkUrl match {
+ case sparkUrlRegex(host, port) =>
+ "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
+ case _ =>
+ throw new SparkException("Invalid master URL: " + sparkUrl)
+ }
+ }
+
+ def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int) = {
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
+ val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName)
+ (actorSystem, boundPort)
+ }
}
diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
index 458ee2d665..529f72e9da 100644
--- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
@@ -14,12 +14,15 @@ import cc.spray.typeconversion.SprayJsonSupport._
import spark.deploy._
import spark.deploy.JsonProtocol._
+/**
+ * Web UI server for the standalone master.
+ */
private[spark]
class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/master/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
- implicit val timeout = Timeout(1 seconds)
+ implicit val timeout = Timeout(10 seconds)
val handler = {
get {
@@ -42,13 +45,9 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) =>
val future = master ? RequestMasterState
val jobInfo = for (masterState <- future.mapTo[MasterState]) yield {
- masterState.activeJobs.find(_.id == jobId) match {
- case Some(job) => job
- case _ => masterState.completedJobs.find(_.id == jobId) match {
- case Some(job) => job
- case _ => null
- }
- }
+ masterState.activeJobs.find(_.id == jobId).getOrElse({
+ masterState.completedJobs.find(_.id == jobId).getOrElse(null)
+ })
}
respondWithMediaType(MediaTypes.`application/json`) { ctx =>
ctx.complete(jobInfo.mapTo[JobInfo])
@@ -58,14 +57,10 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
val future = master ? RequestMasterState
future.map { state =>
val masterState = state.asInstanceOf[MasterState]
-
- masterState.activeJobs.find(_.id == jobId) match {
- case Some(job) => spark.deploy.master.html.job_details.render(job)
- case _ => masterState.completedJobs.find(_.id == jobId) match {
- case Some(job) => spark.deploy.master.html.job_details.render(job)
- case _ => null
- }
- }
+ val job = masterState.activeJobs.find(_.id == jobId).getOrElse({
+ masterState.completedJobs.find(_.id == jobId).getOrElse(null)
+ })
+ spark.deploy.master.html.job_details.render(job)
}
}
}
@@ -76,5 +71,4 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
getFromResourceDirectory(RESOURCE_DIR)
}
}
-
}
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index beceb55ecd..4ef637090c 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -65,9 +65,9 @@ private[spark] class ExecutorRunner(
}
}
- /** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */
+ /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */
def substituteVariables(argument: String): String = argument match {
- case "{{SLAVEID}}" => workerId
+ case "{{EXECUTOR_ID}}" => execId.toString
case "{{HOSTNAME}}" => hostname
case "{{CORES}}" => cores.toString
case other => other
@@ -106,11 +106,6 @@ private[spark] class ExecutorRunner(
throw new IOException("Failed to create directory " + executorDir)
}
- // Download the files it depends on into it (disabled for now)
- //for (url <- jobDesc.fileUrls) {
- // fetchFile(url, executorDir)
- //}
-
// Launch the process
val command = buildCommandSeq()
val builder = new ProcessBuilder(command: _*).directory(executorDir)
@@ -118,8 +113,7 @@ private[spark] class ExecutorRunner(
for ((key, value) <- jobDesc.command.environment) {
env.put(key, value)
}
- env.put("SPARK_CORES", cores.toString)
- env.put("SPARK_MEMORY", memory.toString)
+ env.put("SPARK_MEM", memory.toString + "m")
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
env.put("SPARK_LAUNCH_WITH_SCALA", "0")
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 7c9e588ea2..38547ec4f1 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -1,19 +1,17 @@
package spark.deploy.worker
import scala.collection.mutable.{ArrayBuffer, HashMap}
-import akka.actor.{ActorRef, Props, Actor}
+import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
import spark.{Logging, Utils}
import spark.util.AkkaUtils
import spark.deploy._
-import akka.remote.RemoteClientLifeCycleEvent
+import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
import java.text.SimpleDateFormat
import java.util.Date
-import akka.remote.RemoteClientShutdown
-import akka.remote.RemoteClientDisconnected
import spark.deploy.RegisterWorker
import spark.deploy.LaunchExecutor
import spark.deploy.RegisterWorkerFailed
-import akka.actor.Terminated
+import spark.deploy.master.Master
import java.io.File
private[spark] class Worker(
@@ -27,7 +25,6 @@ private[spark] class Worker(
extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
- val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
var master: ActorRef = null
var masterWebUiUrl : String = ""
@@ -48,11 +45,7 @@ private[spark] class Worker(
def memoryFree: Int = memory - memoryUsed
def createWorkDir() {
- workDir = if (workDirPath != null) {
- new File(workDirPath)
- } else {
- new File(sparkHome, "work")
- }
+ workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
try {
if (!workDir.exists() && !workDir.mkdirs()) {
logError("Failed to create work directory " + workDir)
@@ -68,8 +61,7 @@ private[spark] class Worker(
override def preStart() {
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
ip, port, cores, Utils.memoryMegabytesToString(memory)))
- val envVar = System.getenv("SPARK_HOME")
- sparkHome = new File(if (envVar == null) "." else envVar)
+ sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
logInfo("Spark home: " + sparkHome)
createWorkDir()
connectToMaster()
@@ -77,24 +69,15 @@ private[spark] class Worker(
}
def connectToMaster() {
- masterUrl match {
- case MASTER_REGEX(masterHost, masterPort) => {
- logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
- val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
- try {
- master = context.actorFor(akkaUrl)
- master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
- } catch {
- case e: Exception =>
- logError("Failed to connect to master", e)
- System.exit(1)
- }
- }
-
- case _ =>
- logError("Invalid master URL: " + masterUrl)
+ logInfo("Connecting to master " + masterUrl)
+ try {
+ master = context.actorFor(Master.toAkkaUrl(masterUrl))
+ master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.watch(master) // Doesn't work with remote actors, but useful for testing
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to master", e)
System.exit(1)
}
}
@@ -119,10 +102,10 @@ private[spark] class Worker(
logError("Worker registration failed: " + message)
System.exit(1)
- case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_) =>
+ case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) =>
logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name))
val manager = new ExecutorRunner(
- jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, sparkHome, workDir)
+ jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
executors(jobId + "/" + execId) = manager
manager.start()
coresUsed += cores_
@@ -134,7 +117,9 @@ private[spark] class Worker(
val fullId = jobId + "/" + execId
if (ExecutorState.isFinished(state)) {
val executor = executors(fullId)
- logInfo("Executor " + fullId + " finished with state " + state)
+ logInfo("Executor " + fullId + " finished with state " + state +
+ message.map(" message " + _).getOrElse("") +
+ exitStatus.map(" exitStatus " + _).getOrElse(""))
finishedExecutors(fullId) = executor
executors -= fullId
coresUsed -= executor.cores
@@ -143,9 +128,13 @@ private[spark] class Worker(
case KillExecutor(jobId, execId) =>
val fullId = jobId + "/" + execId
- val executor = executors(fullId)
- logInfo("Asked to kill executor " + fullId)
- executor.kill()
+ executors.get(fullId) match {
+ case Some(executor) =>
+ logInfo("Asked to kill executor " + fullId)
+ executor.kill()
+ case None =>
+ logInfo("Asked to kill unknown executor " + fullId)
+ }
case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
masterDisconnected()
@@ -177,11 +166,19 @@ private[spark] class Worker(
private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
- val actor = actorSystem.actorOf(
- Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory,
- args.master, args.workDir)),
- name = "Worker")
+ val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
+ args.memory, args.master, args.workDir)
actorSystem.awaitTermination()
}
+
+ def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
+ masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+ // The LocalSparkCluster runs multiple local sparkWorkerX actor systems
+ val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
+ val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory,
+ masterUrl, workDir)), name = "Worker")
+ (actorSystem, boundPort)
+ }
+
}
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
index f9489d99fc..ef81f072a3 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
@@ -13,12 +13,15 @@ import cc.spray.typeconversion.SprayJsonSupport._
import spark.deploy.{WorkerState, RequestWorkerState}
import spark.deploy.JsonProtocol._
+/**
+ * Web UI server for the standalone worker.
+ */
private[spark]
class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/worker/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
- implicit val timeout = Timeout(1 seconds)
+ implicit val timeout = Timeout(10 seconds)
val handler = {
get {
@@ -50,5 +53,4 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
getFromResourceDirectory(RESOURCE_DIR)
}
}
-
}
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index 2552958d27..bd21ba719a 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -30,7 +30,7 @@ private[spark] class Executor extends Logging {
initLogging()
- def initialize(slaveHostname: String, properties: Seq[(String, String)]) {
+ def initialize(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) {
// Make sure the local hostname we report matches the cluster scheduler's name for this host
Utils.setCustomHostname(slaveHostname)
@@ -64,7 +64,7 @@ private[spark] class Executor extends Logging {
)
// Initialize Spark environment (using system properties read above)
- env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
+ env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
SparkEnv.set(env)
// Start worker thread pool
@@ -159,22 +159,24 @@ private[spark] class Executor extends Logging {
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
- // Fetch missing dependencies
- for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
- currentFiles(name) = timestamp
- }
- for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
- currentJars(name) = timestamp
- // Add it to our class loader
- val localName = name.split("/").last
- val url = new File(".", localName).toURI.toURL
- if (!urlClassLoader.getURLs.contains(url)) {
- logInfo("Adding " + url + " to class loader")
- urlClassLoader.addURL(url)
+ synchronized {
+ // Fetch missing dependencies
+ for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentFiles(name) = timestamp
+ }
+ for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val localName = name.split("/").last
+ val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
+ if (!urlClassLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ urlClassLoader.addURL(url)
+ }
}
}
}
diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala
index eeab3959c6..818d6d1dda 100644
--- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala
@@ -29,9 +29,14 @@ private[spark] class MesosExecutorBackend(executor: Executor)
executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) {
+ logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
this.driver = driver
val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
- executor.initialize(slaveInfo.getHostname, properties)
+ executor.initialize(
+ executorInfo.getExecutorId.getValue,
+ slaveInfo.getHostname,
+ properties
+ )
}
override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
index a29bf974d2..224c126fdd 100644
--- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
@@ -4,75 +4,72 @@ import java.nio.ByteBuffer
import spark.Logging
import spark.TaskState.TaskState
import spark.util.AkkaUtils
-import akka.actor.{ActorRef, Actor, Props}
+import akka.actor.{ActorRef, Actor, Props, Terminated}
+import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
import java.util.concurrent.{TimeUnit, ThreadPoolExecutor, SynchronousQueue}
-import akka.remote.RemoteClientLifeCycleEvent
import spark.scheduler.cluster._
-import spark.scheduler.cluster.RegisteredSlave
+import spark.scheduler.cluster.RegisteredExecutor
import spark.scheduler.cluster.LaunchTask
-import spark.scheduler.cluster.RegisterSlaveFailed
-import spark.scheduler.cluster.RegisterSlave
-
+import spark.scheduler.cluster.RegisterExecutorFailed
+import spark.scheduler.cluster.RegisterExecutor
private[spark] class StandaloneExecutorBackend(
executor: Executor,
- masterUrl: String,
- slaveId: String,
+ driverUrl: String,
+ executorId: String,
hostname: String,
cores: Int)
extends Actor
with ExecutorBackend
with Logging {
- var master: ActorRef = null
+ var driver: ActorRef = null
override def preStart() {
- try {
- logInfo("Connecting to master: " + masterUrl)
- master = context.actorFor(masterUrl)
- master ! RegisterSlave(slaveId, hostname, cores)
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
- } catch {
- case e: Exception =>
- logError("Failed to connect to master", e)
- System.exit(1)
- }
+ logInfo("Connecting to driver: " + driverUrl)
+ driver = context.actorFor(driverUrl)
+ driver ! RegisterExecutor(executorId, hostname, cores)
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.watch(driver) // Doesn't work with remote actors, but useful for testing
}
override def receive = {
- case RegisteredSlave(sparkProperties) =>
- logInfo("Successfully registered with master")
- executor.initialize(hostname, sparkProperties)
+ case RegisteredExecutor(sparkProperties) =>
+ logInfo("Successfully registered with driver")
+ executor.initialize(executorId, hostname, sparkProperties)
- case RegisterSlaveFailed(message) =>
+ case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
System.exit(1)
case LaunchTask(taskDesc) =>
logInfo("Got assigned task " + taskDesc.taskId)
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
+
+ case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
+ logError("Driver terminated or disconnected! Shutting down.")
+ System.exit(1)
}
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
- master ! StatusUpdate(slaveId, taskId, state, data)
+ driver ! StatusUpdate(executorId, taskId, state, data)
}
}
private[spark] object StandaloneExecutorBackend {
- def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) {
+ def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
val actor = actorSystem.actorOf(
- Props(new StandaloneExecutorBackend(new Executor, masterUrl, slaveId, hostname, cores)),
+ Props(new StandaloneExecutorBackend(new Executor, driverUrl, executorId, hostname, cores)),
name = "Executor")
actorSystem.awaitTermination()
}
def main(args: Array[String]) {
if (args.length != 4) {
- System.err.println("Usage: StandaloneExecutorBackend <master> <slaveId> <hostname> <cores>")
+ System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores>")
System.exit(1)
}
run(args(0), args(1), args(2), args(3).toInt)
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index c193bf7c8d..cd5b7d57f3 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -12,7 +12,14 @@ import java.net._
private[spark]
-abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging {
+abstract class Connection(val channel: SocketChannel, val selector: Selector,
+ val remoteConnectionManagerId: ConnectionManagerId) extends Logging {
+ def this(channel_ : SocketChannel, selector_ : Selector) = {
+ this(channel_, selector_,
+ ConnectionManagerId.fromSocketAddress(
+ channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
+ ))
+ }
channel.configureBlocking(false)
channel.socket.setTcpNoDelay(true)
@@ -25,7 +32,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
val remoteAddress = getRemoteAddress()
- val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
def key() = channel.keyFor(selector)
@@ -103,8 +109,9 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
}
-private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector)
-extends Connection(SocketChannel.open, selector_) {
+private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
+ remoteId_ : ConnectionManagerId)
+extends Connection(SocketChannel.open, selector_, remoteId_) {
class Outbox(fair: Int = 0) {
val messages = new Queue[Message]()
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index 36c01ad629..c7f226044d 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -52,9 +52,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
val sendMessageRequests = new Queue[(Message, SendingConnection)]
- implicit val futureExecContext = ExecutionContext.fromExecutor(
- Executors.newCachedThreadPool(DaemonThreadFactory))
-
+ implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
+
var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
serverChannel.configureBlocking(false)
@@ -300,7 +299,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
- val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector))
+ val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId,
+ new SendingConnection(inetSocketAddress, selector, connectionManagerId))
newConnection
}
val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
index 42f46e06ed..24b4909380 100644
--- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala
+++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
@@ -32,7 +32,7 @@ private[spark] class ApproximateActionListener[T, U, R](
if (finishedTasks == totalTasks) {
// If we had already returned a PartialResult, set its final value
resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
- // Notify any waiting thread that may have called getResult
+ // Notify any waiting thread that may have called awaitResult
this.notifyAll()
}
}
@@ -49,7 +49,7 @@ private[spark] class ApproximateActionListener[T, U, R](
* Waits for up to timeout milliseconds since the listener was created and then returns a
* PartialResult with the result so far. This may be complete if the whole job is done.
*/
- def getResult(): PartialResult[R] = synchronized {
+ def awaitResult(): PartialResult[R] = synchronized {
val finishTime = startTime + timeout
while (true) {
val time = System.currentTimeMillis()
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index b1095a52b4..2c022f88e0 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -11,13 +11,11 @@ private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
extends RDD[T](sc, Nil) {
- @transient
- var splits_ : Array[Split] = (0 until blockIds.size).map(i => {
+ @transient var splits_ : Array[Split] = (0 until blockIds.size).map(i => {
new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
}).toArray
- @transient
- lazy val locations_ = {
+ @transient lazy val locations_ = {
val blockManager = SparkEnv.get.blockManager
/*val locations = blockIds.map(id => blockManager.getLocations(id))*/
val locations = blockManager.getLocations(blockIds)
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index 79e7c24e7c..0f9ca06531 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -1,7 +1,7 @@
package spark.rdd
import java.io.{ObjectOutputStream, IOException}
-import spark.{OneToOneDependency, NarrowDependency, RDD, SparkContext, Split, TaskContext}
+import spark._
private[spark]
@@ -35,8 +35,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
val numSplitsInRdd2 = rdd2.splits.size
- @transient
- var splits_ = {
+ override def getSplits: Array[Split] = {
// create the cross product split
val array = new Array[Split](rdd1.splits.size * rdd2.splits.size)
for (s1 <- rdd1.splits; s2 <- rdd2.splits) {
@@ -46,8 +45,6 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
array
}
- override def getSplits = splits_
-
override def getPreferredLocations(split: Split) = {
val currSplit = split.asInstanceOf[CartesianSplit]
rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
@@ -59,7 +56,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}
- var deps_ = List(
+ override def getDependencies: Seq[Dependency[_]] = List(
new NarrowDependency(rdd1) {
def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
},
@@ -68,11 +65,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
}
)
- override def getDependencies = deps_
-
override def clearDependencies() {
- deps_ = Nil
- splits_ = null
rdd1 = null
rdd2 = null
}
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
index 6f00f6ac73..96b593ba7c 100644
--- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -9,23 +9,26 @@ import org.apache.hadoop.fs.Path
import java.io.{File, IOException, EOFException}
import java.text.NumberFormat
-private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split {
- override val index: Int = idx
-}
+private[spark] class CheckpointRDDSplit(val index: Int) extends Split {}
/**
* This RDD represents a RDD checkpoint file (similar to HadoopRDD).
*/
private[spark]
-class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
+class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String)
extends RDD[T](sc, Nil) {
- @transient val path = new Path(checkpointPath)
- @transient val fs = path.getFileSystem(new Configuration())
+ @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
@transient val splits_ : Array[Split] = {
- val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted
- splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray
+ val dirContents = fs.listStatus(new Path(checkpointPath))
+ val splitFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
+ val numSplits = splitFiles.size
+ if (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
+ !splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1))) {
+ throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
+ }
+ Array.tabulate(numSplits)(i => new CheckpointRDDSplit(i))
}
checkpointData = Some(new RDDCheckpointData[T](this))
@@ -34,36 +37,34 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
override def getSplits = splits_
override def getPreferredLocations(split: Split): Seq[String] = {
- val status = fs.getFileStatus(path)
+ val status = fs.getFileStatus(new Path(checkpointPath))
val locations = fs.getFileBlockLocations(status, 0, status.getLen)
- locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
+ locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
}
override def compute(split: Split, context: TaskContext): Iterator[T] = {
- CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile, context)
+ val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
+ CheckpointRDD.readFromFile(file, context)
}
override def checkpoint() {
- // Do nothing. Hadoop RDD should not be checkpointed.
+ // Do nothing. CheckpointRDD should not be checkpointed.
}
}
private[spark] object CheckpointRDD extends Logging {
- def splitIdToFileName(splitId: Int): String = {
- val numfmt = NumberFormat.getInstance()
- numfmt.setMinimumIntegerDigits(5)
- numfmt.setGroupingUsed(false)
- "part-" + numfmt.format(splitId)
+ def splitIdToFile(splitId: Int): String = {
+ "part-%05d".format(splitId)
}
- def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) {
+ def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
val outputDir = new Path(path)
val fs = outputDir.getFileSystem(new Configuration())
- val finalOutputName = splitIdToFileName(context.splitId)
+ val finalOutputName = splitIdToFile(ctx.splitId)
val finalOutputPath = new Path(outputDir, finalOutputName)
- val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId)
+ val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
if (fs.exists(tempOutputPath)) {
throw new IOException("Checkpoint failed: temporary path " +
@@ -83,22 +84,22 @@ private[spark] object CheckpointRDD extends Logging {
serializeStream.close()
if (!fs.rename(tempOutputPath, finalOutputPath)) {
- if (!fs.delete(finalOutputPath, true)) {
- throw new IOException("Checkpoint failed: failed to delete earlier output of task "
- + context.attemptId)
- }
- if (!fs.rename(tempOutputPath, finalOutputPath)) {
+ if (!fs.exists(finalOutputPath)) {
+ fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
- + context.attemptId)
+ + ctx.attemptId + " and final output path does not exist")
+ } else {
+ // Some other copy of this task must've finished before us and renamed it
+ logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
+ fs.delete(tempOutputPath, false)
}
}
}
- def readFromFile[T](path: String, context: TaskContext): Iterator[T] = {
- val inputPath = new Path(path)
- val fs = inputPath.getFileSystem(new Configuration())
+ def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
+ val fs = path.getFileSystem(new Configuration())
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
- val fileInputStream = fs.open(inputPath, bufferSize)
+ val fileInputStream = fs.open(path, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream)
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 1d528be2aa..4893fe8d78 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -45,8 +45,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
val aggr = new CoGroupAggregator
- @transient
- var deps_ = {
+ @transient var deps_ = {
val deps = new ArrayBuffer[Dependency[_]]
for ((rdd, index) <- rdds.zipWithIndex) {
if (rdd.partitioner == Some(part)) {
@@ -63,8 +62,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
override def getDependencies = deps_
- @transient
- var splits_ : Array[Split] = {
+ @transient var splits_ : Array[Split] = {
val array = new Array[Split](part.numPartitions)
for (i <- 0 until array.size) {
array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) =>
@@ -86,6 +84,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit]
val numRdds = split.deps.size
+ // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
val seq = map.get(k)
@@ -106,13 +105,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
- def mergePair(pair: (K, Seq[Any])) {
- val mySeq = getSeq(pair._1)
- for (v <- pair._2)
- mySeq(depNum) += v
- }
val fetcher = SparkEnv.get.shuffleFetcher
- fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair)
+ for ((k, vs) <- fetcher.fetch[K, Seq[Any]](shuffleId, split.index)) {
+ getSeq(k)(depNum) ++= vs
+ }
}
}
JavaConversions.mapAsScalaMap(map).iterator
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index 167755bbba..4c57434b65 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -27,11 +27,11 @@ private[spark] case class CoalescedRDDSplit(
* or to avoid having a large number of small tasks when processing a directory with many files.
*/
class CoalescedRDD[T: ClassManifest](
- var prev: RDD[T],
+ @transient var prev: RDD[T],
maxPartitions: Int)
- extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs
+ extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
- @transient var splits_ : Array[Split] = {
+ override def getSplits: Array[Split] = {
val prevSplits = prev.splits
if (prevSplits.length < maxPartitions) {
prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) }
@@ -44,26 +44,20 @@ class CoalescedRDD[T: ClassManifest](
}
}
- override def getSplits = splits_
-
override def compute(split: Split, context: TaskContext): Iterator[T] = {
split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit =>
firstParent[T].iterator(parentSplit, context)
}
}
- var deps_ : List[Dependency[_]] = List(
+ override def getDependencies: Seq[Dependency[_]] = List(
new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] =
splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices
}
)
- override def getDependencies() = deps_
-
override def clearDependencies() {
- deps_ = Nil
- splits_ = null
prev = null
}
}
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index c6ceb272cd..5466c9c657 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -3,13 +3,11 @@ package spark.rdd
import spark.{RDD, Split, TaskContext}
private[spark]
-class MappedRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T],
- f: T => U)
+class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U)
extends RDD[U](prev) {
override def getSplits = firstParent[T].splits
override def compute(split: Split, context: TaskContext) =
firstParent[T].iterator(split, context).map(f)
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index bb22db073c..c3b155fcbd 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -37,11 +37,9 @@ class NewHadoopRDD[K, V](
formatter.format(new Date())
}
- @transient
- private val jobId = new JobID(jobtrackerId, id)
+ @transient private val jobId = new JobID(jobtrackerId, id)
- @transient
- private val splits_ : Array[Split] = {
+ @transient private val splits_ : Array[Split] = {
val inputFormat = inputFormatClass.newInstance
val jobContext = newJobContext(conf, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray
diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
new file mode 100644
index 0000000000..a50ce75171
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
@@ -0,0 +1,42 @@
+package spark.rdd
+
+import spark.{NarrowDependency, RDD, SparkEnv, Split, TaskContext}
+
+
+class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split {
+ override val index = idx
+}
+
+
+/**
+ * Represents a dependency between the PartitionPruningRDD and its parent. In this
+ * case, the child RDD contains a subset of partitions of the parents'.
+ */
+class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
+ extends NarrowDependency[T](rdd) {
+
+ @transient
+ val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index))
+ .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDSplit(idx, split) : Split }
+
+ override def getParents(partitionId: Int) = List(partitions(partitionId).index)
+}
+
+
+/**
+ * A RDD used to prune RDD partitions/splits so we can avoid launching tasks on
+ * all partitions. An example use case: If we know the RDD is partitioned by range,
+ * and the execution DAG has a filter on the key, we can avoid launching tasks
+ * on partitions that don't have the range covering the key.
+ */
+class PartitionPruningRDD[T: ClassManifest](
+ @transient prev: RDD[T],
+ @transient partitionFilterFunc: Int => Boolean)
+ extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
+
+ override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(
+ split.asInstanceOf[PartitionPruningRDDSplit].parentSplit, context)
+
+ override protected def getSplits =
+ getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
+}
diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index 1bc9c96112..e24ad23b21 100644
--- a/core/src/main/scala/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -19,13 +19,12 @@ class SampledRDD[T: ClassManifest](
seed: Int)
extends RDD[T](prev) {
- @transient
- var splits_ : Array[Split] = {
+ @transient var splits_ : Array[Split] = {
val rg = new Random(seed)
firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt))
}
- override def getSplits = splits_.asInstanceOf[Array[Split]]
+ override def getSplits = splits_
override def getPreferredLocations(split: Split) =
firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 1b219473e0..d396478673 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -22,17 +22,10 @@ class ShuffledRDD[K, V](
override val partitioner = Some(part)
- @transient
- var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
-
- override def getSplits = splits_
+ override def getSplits = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
}
-
- override def clearDependencies() {
- splits_ = null
- }
}
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index 24a085df02..26a2d511f2 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -26,10 +26,9 @@ private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIn
class UnionRDD[T: ClassManifest](
sc: SparkContext,
@transient var rdds: Seq[RDD[T]])
- extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs
+ extends RDD[T](sc, Nil) { // Nil since we implement getDependencies
- @transient
- var splits_ : Array[Split] = {
+ override def getSplits: Array[Split] = {
val array = new Array[Split](rdds.map(_.splits.size).sum)
var pos = 0
for (rdd <- rdds; split <- rdd.splits) {
@@ -39,20 +38,16 @@ class UnionRDD[T: ClassManifest](
array
}
- override def getSplits = splits_
-
- @transient var deps_ = {
+ override def getDependencies: Seq[Dependency[_]] = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for (rdd <- rdds) {
deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
pos += rdd.splits.size
}
- deps.toList
+ deps
}
- override def getDependencies = deps_
-
override def compute(s: Split, context: TaskContext): Iterator[T] =
s.asInstanceOf[UnionSplit[T]].iterator(context)
@@ -60,8 +55,6 @@ class UnionRDD[T: ClassManifest](
s.asInstanceOf[UnionSplit[T]].preferredLocations()
override def clearDependencies() {
- deps_ = null
- splits_ = null
rdds = null
}
}
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index 16e6cc0f1b..e5df6d8c72 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -32,10 +32,7 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)))
with Serializable {
- // TODO: FIX THIS.
-
- @transient
- var splits_ : Array[Split] = {
+ override def getSplits: Array[Split] = {
if (rdd1.splits.size != rdd2.splits.size) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
}
@@ -46,8 +43,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
array
}
- override def getSplits = splits_
-
override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = {
val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
@@ -59,7 +54,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
}
override def clearDependencies() {
- splits_ = null
rdd1 = null
rdd2 = null
}
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 59f2099e91..319eef6978 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -23,7 +23,16 @@ import util.{MetadataCleaner, TimeStampedHashMap}
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/
private[spark]
-class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
+class DAGScheduler(
+ taskSched: TaskScheduler,
+ mapOutputTracker: MapOutputTracker,
+ blockManagerMaster: BlockManagerMaster,
+ env: SparkEnv)
+ extends TaskSchedulerListener with Logging {
+
+ def this(taskSched: TaskScheduler) {
+ this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+ }
taskSched.setListener(this)
// Called by TaskScheduler to report task completions or failures.
@@ -35,12 +44,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
eventQueue.put(CompletionEvent(task, reason, result, accumUpdates))
}
- // Called by TaskScheduler when a host fails.
- override def hostLost(host: String) {
- eventQueue.put(HostLost(host))
+ // Called by TaskScheduler when an executor fails.
+ override def executorLost(execId: String) {
+ eventQueue.put(ExecutorLost(execId))
}
- // Called by TaskScheduler to cancel an entier TaskSet due to repeated failures.
+ // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
override def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
}
@@ -54,8 +63,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// resubmit failed stages
val POLL_TIMEOUT = 10L
- private val lock = new Object // Used for access to the entire DAGScheduler
-
private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
val nextRunId = new AtomicInteger(0)
@@ -68,12 +75,13 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
var cacheLocs = new HashMap[Int, Array[List[String]]]
- val env = SparkEnv.get
- val cacheTracker = env.cacheTracker
- val mapOutputTracker = env.mapOutputTracker
-
- val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back;
- // that's not going to be a realistic assumption in general
+ // For tracking failed nodes, we use the MapOutputTracker's generation number, which is
+ // sent with every task. When we detect a node failing, we note the current generation number
+ // and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask
+ // results.
+ // TODO: Garbage collect information about failure generations when we know there are no more
+ // stray messages to detect.
+ val failedGeneration = new HashMap[String, Long]
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now
@@ -87,19 +95,27 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
// Start a thread to run the DAGScheduler event loop
- new Thread("DAGScheduler") {
- setDaemon(true)
- override def run() {
- DAGScheduler.this.run()
- }
- }.start()
+ def start() {
+ new Thread("DAGScheduler") {
+ setDaemon(true)
+ override def run() {
+ DAGScheduler.this.run()
+ }
+ }.start()
+ }
- def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
+ private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
+ if (!cacheLocs.contains(rdd.id)) {
+ val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
+ cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
+ locations => locations.map(_.ip).toList
+ }.toArray
+ }
cacheLocs(rdd.id)
}
- def updateCacheLocs() {
- cacheLocs = cacheTracker.getLocationsSnapshot()
+ private def clearCacheLocs() {
+ cacheLocs.clear()
}
/**
@@ -107,7 +123,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* The priority value passed in will be used if the stage doesn't already exist with
* a lower priority (we assume that priorities always increase across jobs for now).
*/
- def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
+ private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
@@ -122,12 +138,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* as a result stage for the final RDD used directly in an action. The stage will also be given
* the provided priority.
*/
- def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
- // Kind of ugly: need to register RDDs with the cache and map output tracker here
- // since we can't do it in the RDD constructor because # of splits is unknown
- logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
- cacheTracker.registerRDD(rdd.id, rdd.splits.size)
+ private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
if (shuffleDep != None) {
+ // Kind of ugly: need to register RDDs with the cache and map output tracker here
+ // since we can't do it in the RDD constructor because # of splits is unknown
+ logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
}
val id = nextStageId.getAndIncrement()
@@ -140,7 +155,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Get or create the list of parent stages for a given RDD. The stages will be assigned the
* provided priority if they haven't already been created with a lower priority.
*/
- def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
+ private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
val parents = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(r: RDD[_]) {
@@ -148,8 +163,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visited += r
// Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of splits is unknown
- logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")")
- cacheTracker.registerRDD(r.id, r.splits.size)
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
@@ -164,25 +177,22 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
parents.toList
}
- def getMissingParentStages(stage: Stage): List[Stage] = {
+ private def getMissingParentStages(stage: Stage): List[Stage] = {
val missing = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(rdd: RDD[_]) {
if (!visited(rdd)) {
visited += rdd
- val locs = getCacheLocs(rdd)
- for (p <- 0 until rdd.splits.size) {
- if (locs(p) == Nil) {
- for (dep <- rdd.dependencies) {
- dep match {
- case shufDep: ShuffleDependency[_,_] =>
- val mapStage = getShuffleMapStage(shufDep, stage.priority)
- if (!mapStage.isAvailable) {
- missing += mapStage
- }
- case narrowDep: NarrowDependency[_] =>
- visit(narrowDep.rdd)
- }
+ if (getCacheLocs(rdd).contains(Nil)) {
+ for (dep <- rdd.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ val mapStage = getShuffleMapStage(shufDep, stage.priority)
+ if (!mapStage.isAvailable) {
+ missing += mapStage
+ }
+ case narrowDep: NarrowDependency[_] =>
+ visit(narrowDep.rdd)
}
}
}
@@ -192,23 +202,45 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
missing.toList
}
+ /**
+ * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
+ * JobWaiter whose getResult() method will return the result of the job when it is complete.
+ *
+ * The job is assumed to have at least one partition; zero partition jobs should be handled
+ * without a JobSubmitted event.
+ */
+ private[scheduler] def prepareJob[T, U: ClassManifest](
+ finalRdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ callSite: String,
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit)
+ : (JobSubmitted, JobWaiter[U]) =
+ {
+ assert(partitions.size > 0)
+ val waiter = new JobWaiter(partitions.size, resultHandler)
+ val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+ val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)
+ return (toSubmit, waiter)
+ }
+
def runJob[T, U: ClassManifest](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
- allowLocal: Boolean)
- : Array[U] =
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit)
{
if (partitions.size == 0) {
- return new Array[U](0)
+ return
}
- val waiter = new JobWaiter(partitions.size)
- val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter))
- waiter.getResult() match {
- case JobSucceeded(results: Seq[_]) =>
- return results.asInstanceOf[Seq[U]].toArray
+ val (toSubmit, waiter) = prepareJob(
+ finalRdd, func, partitions, callSite, allowLocal, resultHandler)
+ eventQueue.put(toSubmit)
+ waiter.awaitResult() match {
+ case JobSucceeded => {}
case JobFailed(exception: Exception) =>
logInfo("Failed to run " + callSite)
throw exception
@@ -227,90 +259,117 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.splits.size).toArray
eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
- return listener.getResult() // Will throw an exception if the job fails
+ return listener.awaitResult() // Will throw an exception if the job fails
+ }
+
+ /**
+ * Process one event retrieved from the event queue.
+ * Returns true if we should stop the event loop.
+ */
+ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
+ event match {
+ case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
+ val runId = nextRunId.getAndIncrement()
+ val finalStage = newStage(finalRDD, None, runId)
+ val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
+ clearCacheLocs()
+ logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
+ " output partitions (allowLocal=" + allowLocal + ")")
+ logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
+ logInfo("Parents of final stage: " + finalStage.parents)
+ logInfo("Missing parents: " + getMissingParentStages(finalStage))
+ if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
+ // Compute very short actions like first() or take() with no parent stages locally.
+ runLocally(job)
+ } else {
+ activeJobs += job
+ resultStageToJob(finalStage) = job
+ submitStage(finalStage)
+ }
+
+ case ExecutorLost(execId) =>
+ handleExecutorLost(execId)
+
+ case completion: CompletionEvent =>
+ handleTaskCompletion(completion)
+
+ case TaskSetFailed(taskSet, reason) =>
+ abortStage(idToStage(taskSet.stageId), reason)
+
+ case StopDAGScheduler =>
+ // Cancel any active jobs
+ for (job <- activeJobs) {
+ val error = new SparkException("Job cancelled because SparkContext was shut down")
+ job.listener.jobFailed(error)
+ }
+ return true
+ }
+ return false
}
/**
+ * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
+ * the last fetch failure.
+ */
+ private[scheduler] def resubmitFailedStages() {
+ logInfo("Resubmitting failed stages")
+ clearCacheLocs()
+ val failed2 = failed.toArray
+ failed.clear()
+ for (stage <- failed2.sortBy(_.priority)) {
+ submitStage(stage)
+ }
+ }
+
+ /**
+ * Check for waiting or failed stages which are now eligible for resubmission.
+ * Ordinarily run on every iteration of the event loop.
+ */
+ private[scheduler] def submitWaitingStages() {
+ // TODO: We might want to run this less often, when we are sure that something has become
+ // runnable that wasn't before.
+ logTrace("Checking for newly runnable parent stages")
+ logTrace("running: " + running)
+ logTrace("waiting: " + waiting)
+ logTrace("failed: " + failed)
+ val waiting2 = waiting.toArray
+ waiting.clear()
+ for (stage <- waiting2.sortBy(_.priority)) {
+ submitStage(stage)
+ }
+ }
+
+
+ /**
* The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
* events and responds by launching tasks. This runs in a dedicated thread and receives events
* via the eventQueue.
*/
- def run() {
+ private def run() {
SparkEnv.set(env)
while (true) {
val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
- val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
if (event != null) {
logDebug("Got event of type " + event.getClass.getName)
}
- event match {
- case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
- val runId = nextRunId.getAndIncrement()
- val finalStage = newStage(finalRDD, None, runId)
- val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
- updateCacheLocs()
- logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
- " output partitions")
- logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
- logInfo("Parents of final stage: " + finalStage.parents)
- logInfo("Missing parents: " + getMissingParentStages(finalStage))
- if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
- // Compute very short actions like first() or take() with no parent stages locally.
- runLocally(job)
- } else {
- activeJobs += job
- resultStageToJob(finalStage) = job
- submitStage(finalStage)
- }
-
- case HostLost(host) =>
- handleHostLost(host)
-
- case completion: CompletionEvent =>
- handleTaskCompletion(completion)
-
- case TaskSetFailed(taskSet, reason) =>
- abortStage(idToStage(taskSet.stageId), reason)
-
- case StopDAGScheduler =>
- // Cancel any active jobs
- for (job <- activeJobs) {
- val error = new SparkException("Job cancelled because SparkContext was shut down")
- job.listener.jobFailed(error)
- }
+ if (event != null) {
+ if (processEvent(event)) {
return
-
- case null =>
- // queue.poll() timed out, ignore it
+ }
}
+ val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
// Periodically resubmit failed stages if some map output fetches have failed and we have
// waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
// tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
// the same time, so we want to make sure we've identified all the reduce tasks that depend
// on the failed node.
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
- logInfo("Resubmitting failed stages")
- updateCacheLocs()
- val failed2 = failed.toArray
- failed.clear()
- for (stage <- failed2.sortBy(_.priority)) {
- submitStage(stage)
- }
+ resubmitFailedStages()
} else {
- // TODO: We might want to run this less often, when we are sure that something has become
- // runnable that wasn't before.
- logDebug("Checking for newly runnable parent stages")
- logDebug("running: " + running)
- logDebug("waiting: " + waiting)
- logDebug("failed: " + failed)
- val waiting2 = waiting.toArray
- waiting.clear()
- for (stage <- waiting2.sortBy(_.priority)) {
- submitStage(stage)
- }
+ submitWaitingStages()
}
}
}
@@ -320,7 +379,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* We run the operation in a separate thread just in case it takes a bunch of time, so that we
* don't block the DAGScheduler event loop or other concurrent jobs.
*/
- def runLocally(job: ActiveJob) {
+ private def runLocally(job: ActiveJob) {
logInfo("Computing the requested partition locally")
new Thread("Local computation of job " + job.runId) {
override def run() {
@@ -329,9 +388,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val rdd = job.finalStage.rdd
val split = rdd.splits(job.partitions(0))
val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
- val result = job.func(taskContext, rdd.iterator(split, taskContext))
- taskContext.executeOnCompleteCallbacks()
- job.listener.taskSucceeded(0, result)
+ try {
+ val result = job.func(taskContext, rdd.iterator(split, taskContext))
+ job.listener.taskSucceeded(0, result)
+ } finally {
+ taskContext.executeOnCompleteCallbacks()
+ }
} catch {
case e: Exception =>
job.listener.jobFailed(e)
@@ -340,13 +402,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}.start()
}
- def submitStage(stage: Stage) {
+ /** Submits stage, but first recursively submits any missing parents. */
+ private def submitStage(stage: Stage) {
logDebug("submitStage(" + stage + ")")
if (!waiting(stage) && !running(stage) && !failed(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
if (missing == Nil) {
- logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents")
+ logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
submitMissingTasks(stage)
running += stage
} else {
@@ -358,7 +421,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
- def submitMissingTasks(stage: Stage) {
+ /** Called when stage's parents are available and we can now do its task. */
+ private def submitMissingTasks(stage: Stage) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
@@ -379,11 +443,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
if (tasks.size > 0) {
- logInfo("Submitting " + tasks.size + " missing tasks from " + stage)
+ logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority))
+ if (!stage.submissionTime.isDefined) {
+ stage.submissionTime = Some(System.currentTimeMillis())
+ }
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -395,9 +462,18 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
*/
- def handleTaskCompletion(event: CompletionEvent) {
+ private def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
val stage = idToStage(task.stageId)
+
+ def markStageAsFinished(stage: Stage) = {
+ val serviceTime = stage.submissionTime match {
+ case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
+ case _ => "Unkown"
+ }
+ logInfo("%s (%s) finished in %s s".format(stage, stage.origin, serviceTime))
+ running -= stage
+ }
event.reason match {
case Success =>
logInfo("Completed " + task)
@@ -412,13 +488,13 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (!job.finished(rt.outputId)) {
job.finished(rt.outputId) = true
job.numFinished += 1
- job.listener.taskSucceeded(rt.outputId, event.result)
// If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) {
activeJobs -= job
resultStageToJob -= stage
- running -= stage
+ markStageAsFinished(stage)
}
+ job.listener.taskSucceeded(rt.outputId, event.result)
}
case None =>
logInfo("Ignoring result from " + rt + " because its job has finished")
@@ -427,23 +503,32 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
case smt: ShuffleMapTask =>
val stage = idToStage(smt.stageId)
val status = event.result.asInstanceOf[MapStatus]
- val host = status.address.ip
- logInfo("ShuffleMapTask finished with host " + host)
- if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos
+ val execId = status.location.executorId
+ logDebug("ShuffleMapTask finished on " + execId)
+ if (failedGeneration.contains(execId) && smt.generation <= failedGeneration(execId)) {
+ logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
+ } else {
stage.addOutputLoc(smt.partition, status)
}
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
- logInfo(stage + " (" + stage.origin + ") finished; looking for newly runnable stages")
- running -= stage
+ markStageAsFinished(stage)
+ logInfo("looking for newly runnable stages")
logInfo("running: " + running)
logInfo("waiting: " + waiting)
logInfo("failed: " + failed)
if (stage.shuffleDep != None) {
+ // We supply true to increment the generation number here in case this is a
+ // recomputation of the map outputs. In that case, some nodes may have cached
+ // locations with holes (from when we detected the error) and will need the
+ // generation incremented to refetch them.
+ // TODO: Only increment the generation number if this is not the first time
+ // we registered these map outputs.
mapOutputTracker.registerMapOutputs(
stage.shuffleDep.get.shuffleId,
- stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray)
+ stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
+ true)
}
- updateCacheLocs()
+ clearCacheLocs()
if (stage.outputLocs.count(_ == Nil) != 0) {
// Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this
@@ -462,7 +547,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
waiting --= newlyRunnable
running ++= newlyRunnable
for (stage <- newlyRunnable.sortBy(_.id)) {
- logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable")
+ logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
submitMissingTasks(stage)
}
}
@@ -493,9 +578,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// Remember that a fetch failed now; this is used to resubmit the broken
// stages later, after a small wait (to give other tasks the chance to fail)
lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
- // TODO: mark the host as failed only if there were lots of fetch failures on it
+ // TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
- handleHostLost(bmAddress.ip)
+ handleExecutorLost(bmAddress.executorId, Some(task.generation))
}
case other =>
@@ -505,22 +590,31 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
/**
- * Responds to a host being lost. This is called inside the event loop so it assumes that it can
- * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside.
+ * Responds to an executor being lost. This is called inside the event loop, so it assumes it can
+ * modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
+ *
+ * Optionally the generation during which the failure was caught can be passed to avoid allowing
+ * stray fetch failures from possibly retriggering the detection of a node as lost.
*/
- def handleHostLost(host: String) {
- if (!deadHosts.contains(host)) {
- logInfo("Host lost: " + host)
- deadHosts += host
- env.blockManager.master.notifyADeadHost(host)
+ private def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) {
+ val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration)
+ if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) {
+ failedGeneration(execId) = currentGeneration
+ logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration))
+ blockManagerMaster.removeExecutor(execId)
// TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) {
- stage.removeOutputsOnHost(host)
+ stage.removeOutputsOnExecutor(execId)
val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
mapOutputTracker.registerMapOutputs(shuffleId, locs, true)
}
- cacheTracker.cacheLost(host)
- updateCacheLocs()
+ if (shuffleToMapStage.isEmpty) {
+ mapOutputTracker.incrementGeneration()
+ }
+ clearCacheLocs()
+ } else {
+ logDebug("Additional executor lost message for " + execId +
+ "(generation " + currentGeneration + ")")
}
}
@@ -528,7 +622,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
- def abortStage(failedStage: Stage, reason: String) {
+ private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
@@ -544,7 +638,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
/**
* Return true if one of stage's ancestors is target.
*/
- def stageDependsOn(stage: Stage, target: Stage): Boolean = {
+ private def stageDependsOn(stage: Stage, target: Stage): Boolean = {
if (stage == target) {
return true
}
@@ -571,7 +665,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visitedRdds.contains(target.rdd)
}
- def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
+ private def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
// If the partition is cached, return the cache locations
val cached = getCacheLocs(rdd)(partition)
if (cached != Nil) {
@@ -597,7 +691,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
return Nil
}
- def cleanup(cleanupTime: Long) {
+ private def cleanup(cleanupTime: Long) {
var sizeBefore = idToStage.size
idToStage.clearOldValues(cleanupTime)
logInfo("idToStage " + sizeBefore + " --> " + idToStage.size)
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
index 3422a21d9d..b34fa78c07 100644
--- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -28,7 +28,7 @@ private[spark] case class CompletionEvent(
accumUpdates: Map[Long, Any])
extends DAGSchedulerEvent
-private[spark] case class HostLost(host: String) extends DAGSchedulerEvent
+private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala
index c4a74e526f..654131ee84 100644
--- a/core/src/main/scala/spark/scheduler/JobResult.scala
+++ b/core/src/main/scala/spark/scheduler/JobResult.scala
@@ -5,5 +5,5 @@ package spark.scheduler
*/
private[spark] sealed trait JobResult
-private[spark] case class JobSucceeded(results: Seq[_]) extends JobResult
+private[spark] case object JobSucceeded extends JobResult
private[spark] case class JobFailed(exception: Exception) extends JobResult
diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala
index b3d4feebe5..3cc6a86345 100644
--- a/core/src/main/scala/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/spark/scheduler/JobWaiter.scala
@@ -3,10 +3,12 @@ package spark.scheduler
import scala.collection.mutable.ArrayBuffer
/**
- * An object that waits for a DAGScheduler job to complete.
+ * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
+ * results to the given handler function.
*/
-private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
- private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null)
+private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
+ extends JobListener {
+
private var finishedTasks = 0
private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
@@ -17,11 +19,11 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
if (jobFinished) {
throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
}
- taskResults(index) = result
+ resultHandler(index, result.asInstanceOf[T])
finishedTasks += 1
if (finishedTasks == totalTasks) {
jobFinished = true
- jobResult = JobSucceeded(taskResults)
+ jobResult = JobSucceeded
this.notifyAll()
}
}
@@ -38,7 +40,7 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
}
}
- def getResult(): JobResult = synchronized {
+ def awaitResult(): JobResult = synchronized {
while (!jobFinished) {
this.wait()
}
diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala
index 4532d9497f..203abb917b 100644
--- a/core/src/main/scala/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/spark/scheduler/MapStatus.scala
@@ -8,19 +8,19 @@ import java.io.{ObjectOutput, ObjectInput, Externalizable}
* task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
* The map output sizes are compressed using MapOutputTracker.compressSize.
*/
-private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: Array[Byte])
+private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte])
extends Externalizable {
def this() = this(null, null) // For deserialization only
def writeExternal(out: ObjectOutput) {
- address.writeExternal(out)
+ location.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
}
def readExternal(in: ObjectInput) {
- address = new BlockManagerId(in)
+ location = BlockManagerId(in)
compressedSizes = new Array[Byte](in.readInt())
in.readFully(compressedSizes)
}
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 19f5328eee..bed9f1864f 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -32,7 +32,7 @@ private[spark] object ShuffleMapTask {
return old
} else {
val out = new ByteArrayOutputStream
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objOut = ser.serializeStream(new GZIPOutputStream(out))
objOut.writeObject(rdd)
objOut.writeObject(dep)
@@ -48,7 +48,7 @@ private[spark] object ShuffleMapTask {
synchronized {
val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
@@ -81,7 +81,7 @@ private[spark] class ShuffleMapTask(
with Externalizable
with Logging {
- def this() = this(0, null, null, 0, null)
+ protected def this() = this(0, null, null, 0, null)
var split = if (rdd == null) {
null
@@ -117,34 +117,33 @@ private[spark] class ShuffleMapTask(
override def run(attemptId: Long): MapStatus = {
val numOutputSplits = dep.partitioner.numPartitions
- val partitioner = dep.partitioner
val taskContext = new TaskContext(stageId, partition, attemptId)
+ try {
+ // Partition the map output.
+ val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
+ for (elem <- rdd.iterator(split, taskContext)) {
+ val pair = elem.asInstanceOf[(Any, Any)]
+ val bucketId = dep.partitioner.getPartition(pair._1)
+ buckets(bucketId) += pair
+ }
- // Partition the map output.
- val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
- for (elem <- rdd.iterator(split, taskContext)) {
- val pair = elem.asInstanceOf[(Any, Any)]
- val bucketId = partitioner.getPartition(pair._1)
- buckets(bucketId) += pair
- }
- val bucketIterators = buckets.map(_.iterator)
+ val compressedSizes = new Array[Byte](numOutputSplits)
- val compressedSizes = new Array[Byte](numOutputSplits)
+ val blockManager = SparkEnv.get.blockManager
+ for (i <- 0 until numOutputSplits) {
+ val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
+ // Get a Scala iterator from Java map
+ val iter: Iterator[(Any, Any)] = buckets(i).iterator
+ val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+ compressedSizes(i) = MapOutputTracker.compressSize(size)
+ }
- val blockManager = SparkEnv.get.blockManager
- for (i <- 0 until numOutputSplits) {
- val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
- // Get a Scala iterator from Java map
- val iter: Iterator[(Any, Any)] = bucketIterators(i)
- val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
- compressedSizes(i) = MapOutputTracker.compressSize(size)
+ return new MapStatus(blockManager.blockManagerId, compressedSizes)
+ } finally {
+ // Execute the callbacks on task completion.
+ taskContext.executeOnCompleteCallbacks()
}
-
- // Execute the callbacks on task completion.
- taskContext.executeOnCompleteCallbacks()
-
- return new MapStatus(blockManager.blockManagerId, compressedSizes)
}
override def preferredLocations: Seq[String] = locs
diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala
index 4846b66729..374114d870 100644
--- a/core/src/main/scala/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/spark/scheduler/Stage.scala
@@ -32,6 +32,9 @@ private[spark] class Stage(
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0
+ /** When first task was submitted to scheduler. */
+ var submissionTime: Option[Long] = None
+
private var nextAttemptId = 0
def isAvailable: Boolean = {
@@ -51,18 +54,18 @@ private[spark] class Stage(
def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) {
val prevList = outputLocs(partition)
- val newList = prevList.filterNot(_.address == bmAddress)
+ val newList = prevList.filterNot(_.location == bmAddress)
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
numAvailableOutputs -= 1
}
}
- def removeOutputsOnHost(host: String) {
+ def removeOutputsOnExecutor(execId: String) {
var becameUnavailable = false
for (partition <- 0 until numPartitions) {
val prevList = outputLocs(partition)
- val newList = prevList.filterNot(_.address.ip == host)
+ val newList = prevList.filterNot(_.location.executorId == execId)
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
becameUnavailable = true
@@ -70,7 +73,8 @@ private[spark] class Stage(
}
}
if (becameUnavailable) {
- logInfo("%s is now unavailable on %s (%d/%d, %s)".format(this, host, numAvailableOutputs, numPartitions, isAvailable))
+ logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format(
+ this, execId, numAvailableOutputs, numPartitions, isAvailable))
}
}
@@ -82,7 +86,7 @@ private[spark] class Stage(
def origin: String = rdd.origin
- override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]"
+ override def toString = "Stage " + id
override def hashCode(): Int = id
}
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
index fa4de15d0d..9fcef86e46 100644
--- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -12,7 +12,7 @@ private[spark] trait TaskSchedulerListener {
def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit
// A node was lost from the cluster.
- def hostLost(host: String): Unit
+ def executorLost(execId: String): Unit
// The TaskScheduler wants to abort an entire task set.
def taskSetFailed(taskSet: TaskSet, reason: String): Unit
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 20f6e65020..1e4fbdb874 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -27,19 +27,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
- val taskIdToSlaveId = new HashMap[Long, String]
+ val taskIdToExecutorId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
// Incrementing Mesos task IDs
val nextTaskId = new AtomicLong(0)
- // Which hosts in the cluster are alive (contains hostnames)
- val hostsAlive = new HashSet[String]
+ // Which executor IDs we have executors on
+ val activeExecutorIds = new HashSet[String]
- // Which slave IDs we have executors on
- val slaveIdsWithExecutors = new HashSet[String]
+ // The set of executors we have on each host; this is used to compute hostsAlive, which
+ // in turn is used to decide when we can attain data locality on a given host
+ val executorsByHost = new HashMap[String, HashSet[String]]
- val slaveIdToHost = new HashMap[String, String]
+ val executorIdToHost = new HashMap[String, String]
// JAR server, if any JARs were added by the user to the SparkContext
var jarServer: HttpServer = null
@@ -85,7 +86,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
- def submitTasks(taskSet: TaskSet) {
+ override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
@@ -102,7 +103,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
activeTaskSets -= manager.taskSet.id
activeTaskSetsQueue -= manager
taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
- taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id)
+ taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
taskSetTaskIds.remove(manager.taskSet.id)
}
}
@@ -117,8 +118,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
SparkEnv.set(sc.env)
// Mark each slave as alive and remember its hostname
for (o <- offers) {
- slaveIdToHost(o.slaveId) = o.hostname
- hostsAlive += o.hostname
+ executorIdToHost(o.executorId) = o.hostname
}
// Build a list of tasks to assign to each slave
val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
@@ -128,16 +128,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
do {
launchedTask = false
for (i <- 0 until offers.size) {
- val sid = offers(i).slaveId
+ val execId = offers(i).executorId
val host = offers(i).hostname
- manager.slaveOffer(sid, host, availableCpus(i)) match {
+ manager.slaveOffer(execId, host, availableCpus(i)) match {
case Some(task) =>
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetId(tid) = manager.taskSet.id
taskSetTaskIds(manager.taskSet.id) += tid
- taskIdToSlaveId(tid) = sid
- slaveIdsWithExecutors += sid
+ taskIdToExecutorId(tid) = execId
+ activeExecutorIds += execId
+ if (!executorsByHost.contains(host)) {
+ executorsByHost(host) = new HashSet()
+ }
+ executorsByHost(host) += execId
availableCpus(i) -= 1
launchedTask = true
@@ -152,25 +156,21 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
var taskSetToUpdate: Option[TaskSetManager] = None
- var failedHost: Option[String] = None
+ var failedExecutor: Option[String] = None
var taskFailed = false
synchronized {
try {
- if (state == TaskState.LOST && taskIdToSlaveId.contains(tid)) {
- // We lost the executor on this slave, so remember that it's gone
- val slaveId = taskIdToSlaveId(tid)
- val host = slaveIdToHost(slaveId)
- if (hostsAlive.contains(host)) {
- slaveIdsWithExecutors -= slaveId
- hostsAlive -= host
- activeTaskSetsQueue.foreach(_.hostLost(host))
- failedHost = Some(host)
+ if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
+ // We lost this entire executor, so remember that it's gone
+ val execId = taskIdToExecutorId(tid)
+ if (activeExecutorIds.contains(execId)) {
+ removeExecutor(execId)
+ failedExecutor = Some(execId)
}
}
taskIdToTaskSetId.get(tid) match {
case Some(taskSetId) =>
if (activeTaskSets.contains(taskSetId)) {
- //activeTaskSets(taskSetId).statusUpdate(status)
taskSetToUpdate = Some(activeTaskSets(taskSetId))
}
if (TaskState.isFinished(state)) {
@@ -178,7 +178,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (taskSetTaskIds.contains(taskSetId)) {
taskSetTaskIds(taskSetId) -= tid
}
- taskIdToSlaveId.remove(tid)
+ taskIdToExecutorId.remove(tid)
}
if (state == TaskState.FAILED) {
taskFailed = true
@@ -190,12 +190,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
case e: Exception => logError("Exception in statusUpdate", e)
}
}
- // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
+ // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock
if (taskSetToUpdate != None) {
taskSetToUpdate.get.statusUpdate(tid, state, serializedData)
}
- if (failedHost != None) {
- listener.hostLost(failedHost.get)
+ if (failedExecutor != None) {
+ listener.executorLost(failedExecutor.get)
backend.reviveOffers()
}
if (taskFailed) {
@@ -249,27 +249,42 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
- def slaveLost(slaveId: String, reason: ExecutorLossReason) {
- var failedHost: Option[String] = None
+ def executorLost(executorId: String, reason: ExecutorLossReason) {
+ var failedExecutor: Option[String] = None
synchronized {
- val host = slaveIdToHost(slaveId)
- if (hostsAlive.contains(host)) {
- logError("Lost an executor on " + host + ": " + reason)
- slaveIdsWithExecutors -= slaveId
- hostsAlive -= host
- activeTaskSetsQueue.foreach(_.hostLost(host))
- failedHost = Some(host)
+ if (activeExecutorIds.contains(executorId)) {
+ val host = executorIdToHost(executorId)
+ logError("Lost executor %s on %s: %s".format(executorId, host, reason))
+ removeExecutor(executorId)
+ failedExecutor = Some(executorId)
} else {
- // We may get multiple slaveLost() calls with different loss reasons. For example, one
- // may be triggered by a dropped connection from the slave while another may be a report
- // of executor termination from Mesos. We produce log messages for both so we eventually
- // report the termination reason.
- logError("Lost an executor on " + host + " (already removed): " + reason)
+ // We may get multiple executorLost() calls with different loss reasons. For example, one
+ // may be triggered by a dropped connection from the slave while another may be a report
+ // of executor termination from Mesos. We produce log messages for both so we eventually
+ // report the termination reason.
+ logError("Lost an executor " + executorId + " (already removed): " + reason)
}
}
- if (failedHost != None) {
- listener.hostLost(failedHost.get)
+ // Call listener.executorLost without holding the lock on this to prevent deadlock
+ if (failedExecutor != None) {
+ listener.executorLost(failedExecutor.get)
backend.reviveOffers()
}
}
+
+ /** Get a list of hosts that currently have executors */
+ def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet
+
+ /** Remove an executor from all our data structures and mark it as lost */
+ private def removeExecutor(executorId: String) {
+ activeExecutorIds -= executorId
+ val host = executorIdToHost(executorId)
+ val execs = executorsByHost.getOrElse(host, new HashSet)
+ execs -= executorId
+ if (execs.isEmpty) {
+ executorsByHost -= host
+ }
+ executorIdToHost -= executorId
+ activeTaskSetsQueue.foreach(_.executorLost(executorId, host))
+ }
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
index ddcd64d7c6..9ac875de3a 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
@@ -1,5 +1,7 @@
package spark.scheduler.cluster
+import spark.Utils
+
/**
* A backend interface for cluster scheduling systems that allows plugging in different ones under
* ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
@@ -11,5 +13,15 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int
+ // Memory used by each executor (in megabytes)
+ protected val executorMemory = {
+ // TODO: Might need to add some extra memory for the non-heap parts of the JVM
+ Option(System.getProperty("spark.executor.memory"))
+ .orElse(Option(System.getenv("SPARK_MEM")))
+ .map(Utils.memoryStringToMb)
+ .getOrElse(512)
+ }
+
+
// TODO: Probably want to add a killTask too
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala b/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala
deleted file mode 100644
index 96ebaa4601..0000000000
--- a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala
+++ /dev/null
@@ -1,4 +0,0 @@
-package spark.scheduler.cluster
-
-private[spark]
-class SlaveResources(val slaveId: String, val hostname: String, val coresFree: Int) {}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index e2301347e5..59ff8bcb90 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -19,34 +19,25 @@ private[spark] class SparkDeploySchedulerBackend(
var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _
val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
- val executorIdToSlaveId = new HashMap[String, String]
-
- // Memory used by each executor (in megabytes)
- val executorMemory = {
- if (System.getenv("SPARK_MEM") != null) {
- Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
- // TODO: Might need to add some extra memory for the non-heap parts of the JVM
- } else {
- 512
- }
- }
override def start() {
super.start()
- val masterUrl = "akka://spark@%s:%s/user/%s".format(
- System.getProperty("spark.master.host"), System.getProperty("spark.master.port"),
+ // The endpoint for executors to talk to us
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
StandaloneSchedulerBackend.ACTOR_NAME)
- val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}")
+ val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
- val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command)
+ val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone"))
+ val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome)
client = new Client(sc.env.actorSystem, master, jobDesc, this)
client.start()
}
override def stop() {
- stopping = true;
+ stopping = true
super.stop()
client.stop()
if (shutdownCallback != null) {
@@ -54,35 +45,28 @@ private[spark] class SparkDeploySchedulerBackend(
}
}
- def connected(jobId: String) {
+ override def connected(jobId: String) {
logInfo("Connected to Spark cluster with job ID " + jobId)
}
- def disconnected() {
+ override def disconnected() {
if (!stopping) {
logError("Disconnected from Spark cluster!")
scheduler.error("Disconnected from Spark cluster")
}
}
- def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {
- executorIdToSlaveId += id -> workerId
+ override def executorAdded(executorId: String, workerId: String, host: String, cores: Int, memory: Int) {
logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format(
- id, host, cores, Utils.memoryMegabytesToString(memory)))
+ executorId, host, cores, Utils.memoryMegabytesToString(memory)))
}
- def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {
+ override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) {
val reason: ExecutorLossReason = exitStatus match {
case Some(code) => ExecutorExited(code)
case None => SlaveLost(message)
}
- logInfo("Executor %s removed: %s".format(id, message))
- executorIdToSlaveId.get(id) match {
- case Some(slaveId) =>
- executorIdToSlaveId.remove(id)
- scheduler.slaveLost(slaveId, reason)
- case None =>
- logInfo("No slave ID known for executor %s".format(id))
- }
+ logInfo("Executor %s removed: %s".format(executorId, message))
+ scheduler.executorLost(executorId, reason)
}
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
index 1386cd9d44..da7dcf4b6b 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
@@ -6,32 +6,34 @@ import spark.util.SerializableBuffer
private[spark] sealed trait StandaloneClusterMessage extends Serializable
-// Master to slaves
+// Driver to executors
private[spark]
case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage
private[spark]
-case class RegisteredSlave(sparkProperties: Seq[(String, String)]) extends StandaloneClusterMessage
+case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
+ extends StandaloneClusterMessage
private[spark]
-case class RegisterSlaveFailed(message: String) extends StandaloneClusterMessage
+case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage
-// Slaves to master
+// Executors to driver
private[spark]
-case class RegisterSlave(slaveId: String, host: String, cores: Int) extends StandaloneClusterMessage
+case class RegisterExecutor(executorId: String, host: String, cores: Int)
+ extends StandaloneClusterMessage
private[spark]
-case class StatusUpdate(slaveId: String, taskId: Long, state: TaskState, data: SerializableBuffer)
+case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer)
extends StandaloneClusterMessage
private[spark]
object StatusUpdate {
/** Alternate factory method that takes a ByteBuffer directly for the data field */
- def apply(slaveId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = {
- StatusUpdate(slaveId, taskId, state, new SerializableBuffer(data))
+ def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = {
+ StatusUpdate(executorId, taskId, state, new SerializableBuffer(data))
}
}
-// Internal messages in master
+// Internal messages in driver
private[spark] case object ReviveOffers extends StandaloneClusterMessage
-private[spark] case object StopMaster extends StandaloneClusterMessage
+private[spark] case object StopDriver extends StandaloneClusterMessage
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index eeaae23dc8..082022be1c 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -23,13 +23,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
var totalCoreCount = new AtomicInteger(0)
- class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor {
- val slaveActor = new HashMap[String, ActorRef]
- val slaveAddress = new HashMap[String, Address]
- val slaveHost = new HashMap[String, String]
+ class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor {
+ val executorActor = new HashMap[String, ActorRef]
+ val executorAddress = new HashMap[String, Address]
+ val executorHost = new HashMap[String, String]
val freeCores = new HashMap[String, Int]
- val actorToSlaveId = new HashMap[ActorRef, String]
- val addressToSlaveId = new HashMap[Address, String]
+ val actorToExecutorId = new HashMap[ActorRef, String]
+ val addressToExecutorId = new HashMap[Address, String]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
@@ -37,86 +37,86 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
def receive = {
- case RegisterSlave(slaveId, host, cores) =>
- if (slaveActor.contains(slaveId)) {
- sender ! RegisterSlaveFailed("Duplicate slave ID: " + slaveId)
+ case RegisterExecutor(executorId, host, cores) =>
+ if (executorActor.contains(executorId)) {
+ sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
} else {
- logInfo("Registered slave: " + sender + " with ID " + slaveId)
- sender ! RegisteredSlave(sparkProperties)
+ logInfo("Registered executor: " + sender + " with ID " + executorId)
+ sender ! RegisteredExecutor(sparkProperties)
context.watch(sender)
- slaveActor(slaveId) = sender
- slaveHost(slaveId) = host
- freeCores(slaveId) = cores
- slaveAddress(slaveId) = sender.path.address
- actorToSlaveId(sender) = slaveId
- addressToSlaveId(sender.path.address) = slaveId
+ executorActor(executorId) = sender
+ executorHost(executorId) = host
+ freeCores(executorId) = cores
+ executorAddress(executorId) = sender.path.address
+ actorToExecutorId(sender) = executorId
+ addressToExecutorId(sender.path.address) = executorId
totalCoreCount.addAndGet(cores)
makeOffers()
}
- case StatusUpdate(slaveId, taskId, state, data) =>
+ case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
- freeCores(slaveId) += 1
- makeOffers(slaveId)
+ freeCores(executorId) += 1
+ makeOffers(executorId)
}
case ReviveOffers =>
makeOffers()
- case StopMaster =>
+ case StopDriver =>
sender ! true
context.stop(self)
case Terminated(actor) =>
- actorToSlaveId.get(actor).foreach(removeSlave(_, "Akka actor terminated"))
+ actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated"))
case RemoteClientDisconnected(transport, address) =>
- addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client disconnected"))
+ addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected"))
case RemoteClientShutdown(transport, address) =>
- addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client shutdown"))
+ addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown"))
}
- // Make fake resource offers on all slaves
+ // Make fake resource offers on all executors
def makeOffers() {
launchTasks(scheduler.resourceOffers(
- slaveHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
+ executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
}
- // Make fake resource offers on just one slave
- def makeOffers(slaveId: String) {
+ // Make fake resource offers on just one executor
+ def makeOffers(executorId: String) {
launchTasks(scheduler.resourceOffers(
- Seq(new WorkerOffer(slaveId, slaveHost(slaveId), freeCores(slaveId)))))
+ Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId)))))
}
// Launch tasks returned by a set of resource offers
def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
for (task <- tasks.flatten) {
- freeCores(task.slaveId) -= 1
- slaveActor(task.slaveId) ! LaunchTask(task)
+ freeCores(task.executorId) -= 1
+ executorActor(task.executorId) ! LaunchTask(task)
}
}
// Remove a disconnected slave from the cluster
- def removeSlave(slaveId: String, reason: String) {
- logInfo("Slave " + slaveId + " disconnected, so removing it")
- val numCores = freeCores(slaveId)
- actorToSlaveId -= slaveActor(slaveId)
- addressToSlaveId -= slaveAddress(slaveId)
- slaveActor -= slaveId
- slaveHost -= slaveId
- freeCores -= slaveId
- slaveHost -= slaveId
+ def removeExecutor(executorId: String, reason: String) {
+ logInfo("Slave " + executorId + " disconnected, so removing it")
+ val numCores = freeCores(executorId)
+ actorToExecutorId -= executorActor(executorId)
+ addressToExecutorId -= executorAddress(executorId)
+ executorActor -= executorId
+ executorHost -= executorId
+ freeCores -= executorId
+ executorHost -= executorId
totalCoreCount.addAndGet(-numCores)
- scheduler.slaveLost(slaveId, SlaveLost(reason))
+ scheduler.executorLost(executorId, SlaveLost(reason))
}
}
- var masterActor: ActorRef = null
+ var driverActor: ActorRef = null
val taskIdsOnSlave = new HashMap[String, HashSet[String]]
- def start() {
+ override def start() {
val properties = new ArrayBuffer[(String, String)]
val iterator = System.getProperties.entrySet.iterator
while (iterator.hasNext) {
@@ -126,15 +126,15 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
properties += ((key, value))
}
}
- masterActor = actorSystem.actorOf(
- Props(new MasterActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
+ driverActor = actorSystem.actorOf(
+ Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
}
- def stop() {
+ override def stop() {
try {
- if (masterActor != null) {
+ if (driverActor != null) {
val timeout = 5.seconds
- val future = masterActor.ask(StopMaster)(timeout)
+ val future = driverActor.ask(StopDriver)(timeout)
Await.result(future, timeout)
}
} catch {
@@ -143,11 +143,11 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
}
- def reviveOffers() {
- masterActor ! ReviveOffers
+ override def reviveOffers() {
+ driverActor ! ReviveOffers
}
- def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2)
+ override def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2)
}
private[spark] object StandaloneSchedulerBackend {
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
index aa097fd3a2..b41e951be9 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
@@ -5,7 +5,7 @@ import spark.util.SerializableBuffer
private[spark] class TaskDescription(
val taskId: Long,
- val slaveId: String,
+ val executorId: String,
val name: String,
_serializedTask: ByteBuffer)
extends Serializable {
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
index ca84503780..0f975ce1eb 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
@@ -4,7 +4,12 @@ package spark.scheduler.cluster
* Information about a running task attempt inside a TaskSet.
*/
private[spark]
-class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: String) {
+class TaskInfo(
+ val taskId: Long,
+ val index: Int,
+ val launchTime: Long,
+ val executorId: String,
+ val host: String) {
var finishTime: Long = 0
var failed = false
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index a089b71644..3dabdd76b1 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -17,10 +17,7 @@ import java.nio.ByteBuffer
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/
-private[spark] class TaskSetManager(
- sched: ClusterScheduler,
- val taskSet: TaskSet)
- extends Logging {
+private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
@@ -100,7 +97,7 @@ private[spark] class TaskSetManager(
}
// Add a task to all the pending-task lists that it should be on.
- def addPendingTask(index: Int) {
+ private def addPendingTask(index: Int) {
val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
if (locations.size == 0) {
pendingTasksWithNoPrefs += index
@@ -115,7 +112,7 @@ private[spark] class TaskSetManager(
// Return the pending tasks list for a given host, or an empty list if
// there is no map entry for that host
- def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+ private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
@@ -123,7 +120,7 @@ private[spark] class TaskSetManager(
// Return None if the list is empty.
// This method also cleans up any tasks in the list that have already
// been launched, since we want that to happen lazily.
- def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+ private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
@@ -137,11 +134,12 @@ private[spark] class TaskSetManager(
// Return a speculative task for a given host if any are available. The task should not have an
// attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
// task must have a preference for this host (or no preferred locations at all).
- def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
+ private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
+ val hostsAlive = sched.hostsAlive
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
val localTask = speculatableTasks.find {
index =>
- val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
+ val locations = tasks(index).preferredLocations.toSet & hostsAlive
val attemptLocs = taskAttempts(index).map(_.host)
(locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
}
@@ -161,7 +159,7 @@ private[spark] class TaskSetManager(
// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
- def findTask(host: String, localOnly: Boolean): Option[Int] = {
+ private def findTask(host: String, localOnly: Boolean): Option[Int] = {
val localTask = findTaskFromList(getPendingTasksForHost(host))
if (localTask != None) {
return localTask
@@ -183,13 +181,13 @@ private[spark] class TaskSetManager(
// Does a host count as a preferred location for a task? This is true if
// either the task has preferred locations and this host is one, or it has
// no preferred locations (in which we still count the launch as preferred).
- def isPreferredLocation(task: Task[_], host: String): Boolean = {
+ private def isPreferredLocation(task: Task[_], host: String): Boolean = {
val locs = task.preferredLocations
return (locs.contains(host) || locs.isEmpty)
}
// Respond to an offer of a single slave from the scheduler by finding a task
- def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[TaskDescription] = {
+ def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = {
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
val time = System.currentTimeMillis
val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
@@ -206,11 +204,11 @@ private[spark] class TaskSetManager(
} else {
"non-preferred, not one of " + task.preferredLocations.mkString(", ")
}
- logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
- taskSet.id, index, taskId, slaveId, host, prefStr))
+ logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
+ taskSet.id, index, taskId, execId, host, prefStr))
// Do various bookkeeping
copiesRunning(index) += 1
- val info = new TaskInfo(taskId, index, time, host)
+ val info = new TaskInfo(taskId, index, time, execId, host)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
if (preferred) {
@@ -224,7 +222,7 @@ private[spark] class TaskSetManager(
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
- return Some(new TaskDescription(taskId, slaveId, taskName, serializedTask))
+ return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
}
case _ =>
}
@@ -334,7 +332,7 @@ private[spark] class TaskSetManager(
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
- abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES))
+ abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
}
}
} else {
@@ -356,19 +354,22 @@ private[spark] class TaskSetManager(
sched.taskSetFinished(this)
}
- def hostLost(hostname: String) {
- logInfo("Re-queueing tasks for " + hostname + " from TaskSet " + taskSet.id)
- // If some task has preferred locations only on hostname, put it in the no-prefs list
- // to avoid the wait from delay scheduling
- for (index <- getPendingTasksForHost(hostname)) {
- val newLocs = tasks(index).preferredLocations.toSet & sched.hostsAlive
- if (newLocs.isEmpty) {
- pendingTasksWithNoPrefs += index
+ def executorLost(execId: String, hostname: String) {
+ logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
+ val newHostsAlive = sched.hostsAlive
+ // If some task has preferred locations only on hostname, and there are no more executors there,
+ // put it in the no-prefs list to avoid the wait from delay scheduling
+ if (!newHostsAlive.contains(hostname)) {
+ for (index <- getPendingTasksForHost(hostname)) {
+ val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive
+ if (newLocs.isEmpty) {
+ pendingTasksWithNoPrefs += index
+ }
}
}
- // Re-enqueue any tasks that ran on the failed host if this is a shuffle map stage
+ // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
- for ((tid, info) <- taskInfos if info.host == hostname) {
+ for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (finished(index)) {
finished(index) = false
@@ -382,7 +383,7 @@ private[spark] class TaskSetManager(
}
}
// Also re-enqueue any tasks that were running on the node
- for ((tid, info) <- taskInfos if info.running && info.host == hostname) {
+ for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
taskLost(tid, TaskState.KILLED, null)
}
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
index 6b919d68b2..3c3afcbb14 100644
--- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
@@ -1,8 +1,8 @@
package spark.scheduler.cluster
/**
- * Represents free resources available on a worker node.
+ * Represents free resources available on an executor.
*/
private[spark]
-class WorkerOffer(val slaveId: String, val hostname: String, val cores: Int) {
+class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) {
}
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index dff550036d..482d1cc853 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -20,7 +20,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
with Logging {
var attemptId = new AtomicInteger(0)
- var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
+ var threadPool = Utils.newDaemonFixedThreadPool(threads)
val env = SparkEnv.get
var listener: TaskSchedulerListener = null
@@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
}
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
- logInfo("Running task " + idInJob)
+ logInfo("Running " + task)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
try {
@@ -80,7 +80,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val resultToReturn = ser.deserialize[Any](ser.serialize(result))
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
- logInfo("Finished task " + idInJob)
+ logInfo("Finished " + task)
// If the threadpool has not already been shutdown, notify DAGScheduler
if (!Thread.currentThread().isInterrupted)
@@ -116,16 +116,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
- val url = new File(".", localName).toURI.toURL
+ val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
if (!classLoader.getURLs.contains(url)) {
logInfo("Adding " + url + " to class loader")
classLoader.addURL(url)
diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
index c45c7df69c..b481ec0a72 100644
--- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
@@ -35,16 +35,6 @@ private[spark] class CoarseMesosSchedulerBackend(
val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures
- // Memory used by each executor (in megabytes)
- val executorMemory = {
- if (System.getenv("SPARK_MEM") != null) {
- Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
- // TODO: Might need to add some extra memory for the non-heap parts of the JVM
- } else {
- 512
- }
- }
-
// Lock used to wait for scheduler to be registered
var isRegistered = false
val registeredLock = new Object()
@@ -64,13 +54,9 @@ private[spark] class CoarseMesosSchedulerBackend(
val taskIdToSlaveId = new HashMap[Int, String]
val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed
- val sparkHome = sc.getSparkHome() match {
- case Some(path) =>
- path
- case None =>
- throw new SparkException("Spark home is not set; set it through the spark.home system " +
- "property, the SPARK_HOME environment variable or the SparkContext constructor")
- }
+ val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
+ "Spark home is not set; set it through the spark.home system " +
+ "property, the SPARK_HOME environment variable or the SparkContext constructor"))
val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt
@@ -108,11 +94,11 @@ private[spark] class CoarseMesosSchedulerBackend(
def createCommand(offer: Offer, numCores: Int): CommandInfo = {
val runScript = new File(sparkHome, "run").getCanonicalPath
- val masterUrl = "akka://spark@%s:%s/user/%s".format(
- System.getProperty("spark.master.host"), System.getProperty("spark.master.port"),
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
StandaloneSchedulerBackend.ACTOR_NAME)
val command = "\"%s\" spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
- runScript, masterUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)
+ runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)
val environment = Environment.newBuilder()
sc.executorEnvs.foreach { case (key, value) =>
environment.addVariables(Environment.Variable.newBuilder()
@@ -184,7 +170,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Helper function to pull out a resource from a Mesos Resources protobuf */
- def getResource(res: JList[Resource], name: String): Double = {
+ private def getResource(res: JList[Resource], name: String): Double = {
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
@@ -193,7 +179,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Build a Mesos resource protobuf object */
- def createResource(resourceName: String, quantity: Double): Protos.Resource = {
+ private def createResource(resourceName: String, quantity: Double): Protos.Resource = {
Resource.newBuilder()
.setName(resourceName)
.setType(Value.Type.SCALAR)
@@ -202,7 +188,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Check whether a Mesos task state represents a finished task */
- def isFinished(state: MesosTaskState) = {
+ private def isFinished(state: MesosTaskState) = {
state == MesosTaskState.TASK_FINISHED ||
state == MesosTaskState.TASK_FAILED ||
state == MesosTaskState.TASK_KILLED ||
diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
index 8c7a1dfbc0..300766d0f5 100644
--- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
@@ -29,16 +29,6 @@ private[spark] class MesosSchedulerBackend(
with MScheduler
with Logging {
- // Memory used by each executor (in megabytes)
- val EXECUTOR_MEMORY = {
- if (System.getenv("SPARK_MEM") != null) {
- Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
- // TODO: Might need to add some extra memory for the non-heap parts of the JVM
- } else {
- 512
- }
- }
-
// Lock used to wait for scheduler to be registered
var isRegistered = false
val registeredLock = new Object()
@@ -51,7 +41,7 @@ private[spark] class MesosSchedulerBackend(
val taskIdToSlaveId = new HashMap[Long, String]
// An ExecutorInfo for our tasks
- var executorInfo: ExecutorInfo = null
+ var execArgs: Array[Byte] = null
override def start() {
synchronized {
@@ -70,19 +60,14 @@ private[spark] class MesosSchedulerBackend(
}
}.start()
- executorInfo = createExecutorInfo()
waitForRegister()
}
}
- def createExecutorInfo(): ExecutorInfo = {
- val sparkHome = sc.getSparkHome() match {
- case Some(path) =>
- path
- case None =>
- throw new SparkException("Spark home is not set; set it through the spark.home system " +
- "property, the SPARK_HOME environment variable or the SparkContext constructor")
- }
+ def createExecutorInfo(execId: String): ExecutorInfo = {
+ val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
+ "Spark home is not set; set it through the spark.home system " +
+ "property, the SPARK_HOME environment variable or the SparkContext constructor"))
val execScript = new File(sparkHome, "spark-executor").getCanonicalPath
val environment = Environment.newBuilder()
sc.executorEnvs.foreach { case (key, value) =>
@@ -94,14 +79,14 @@ private[spark] class MesosSchedulerBackend(
val memory = Resource.newBuilder()
.setName("mem")
.setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(EXECUTOR_MEMORY).build())
+ .setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build())
.build()
val command = CommandInfo.newBuilder()
.setValue(execScript)
.setEnvironment(environment)
.build()
ExecutorInfo.newBuilder()
- .setExecutorId(ExecutorID.newBuilder().setValue("default").build())
+ .setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
.setCommand(command)
.setData(ByteString.copyFrom(createExecArg()))
.addResources(memory)
@@ -113,17 +98,20 @@ private[spark] class MesosSchedulerBackend(
* containing all the spark.* system properties in the form of (String, String) pairs.
*/
private def createExecArg(): Array[Byte] = {
- val props = new HashMap[String, String]
- val iterator = System.getProperties.entrySet.iterator
- while (iterator.hasNext) {
- val entry = iterator.next
- val (key, value) = (entry.getKey.toString, entry.getValue.toString)
- if (key.startsWith("spark.")) {
- props(key) = value
+ if (execArgs == null) {
+ val props = new HashMap[String, String]
+ val iterator = System.getProperties.entrySet.iterator
+ while (iterator.hasNext) {
+ val entry = iterator.next
+ val (key, value) = (entry.getKey.toString, entry.getValue.toString)
+ if (key.startsWith("spark.")) {
+ props(key) = value
+ }
}
+ // Serialize the map as an array of (String, String) pairs
+ execArgs = Utils.serialize(props.toArray)
}
- // Serialize the map as an array of (String, String) pairs
- return Utils.serialize(props.toArray)
+ return execArgs
}
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
@@ -163,7 +151,7 @@ private[spark] class MesosSchedulerBackend(
def enoughMemory(o: Offer) = {
val mem = getResource(o.getResourcesList, "mem")
val slaveId = o.getSlaveId.getValue
- mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId)
+ mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId)
}
for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) {
@@ -220,7 +208,7 @@ private[spark] class MesosSchedulerBackend(
return MesosTaskInfo.newBuilder()
.setTaskId(taskId)
.setSlaveId(SlaveID.newBuilder().setValue(slaveId).build())
- .setExecutor(executorInfo)
+ .setExecutor(createExecutorInfo(slaveId))
.setName(task.name)
.addResources(cpuResource)
.setData(ByteString.copyFrom(task.serializedTask))
@@ -272,7 +260,7 @@ private[spark] class MesosSchedulerBackend(
synchronized {
slaveIdsWithExecutors -= slaveId.getValue
}
- scheduler.slaveLost(slaveId.getValue, reason)
+ scheduler.executorLost(slaveId.getValue, reason)
}
override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) {
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 7a8ac10cdd..9893e9625d 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -16,7 +16,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils}
+import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils}
import spark.network._
import spark.serializer.Serializer
import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap}
@@ -30,6 +30,7 @@ extends Exception(message)
private[spark]
class BlockManager(
+ executorId: String,
actorSystem: ActorSystem,
val master: BlockManagerMaster,
val serializer: Serializer,
@@ -68,11 +69,8 @@ class BlockManager(
val connectionManager = new ConnectionManager(0)
implicit val futureExecContext = connectionManager.futureExecContext
- val connectionManagerId = connectionManager.id
- val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port)
-
- // TODO: This will be removed after cacheTracker is removed from the code base.
- var cacheTracker: CacheTracker = null
+ val blockManagerId = BlockManagerId(
+ executorId, connectionManager.id.host, connectionManager.id.port)
// Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
// for receiving shuffle outputs)
@@ -93,7 +91,10 @@ class BlockManager(
val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
- @volatile private var shuttingDown = false
+ // Pending reregistration action being executed asynchronously or null if none
+ // is pending. Accesses should synchronize on asyncReregisterLock.
+ var asyncReregisterTask: Future[Unit] = null
+ val asyncReregisterLock = new Object
private def heartBeat() {
if (!master.sendHeartBeat(blockManagerId)) {
@@ -109,8 +110,9 @@ class BlockManager(
/**
* Construct a BlockManager with a memory limit set based on system properties.
*/
- def this(actorSystem: ActorSystem, master: BlockManagerMaster, serializer: Serializer) = {
- this(actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties)
+ def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster,
+ serializer: Serializer) = {
+ this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties)
}
/**
@@ -150,6 +152,8 @@ class BlockManager(
/**
* Reregister with the master and report all blocks to it. This will be called by the heart beat
* thread if our heartbeat to the block amnager indicates that we were not registered.
+ *
+ * Note that this method must be called without any BlockInfo locks held.
*/
def reregister() {
// TODO: We might need to rate limit reregistering.
@@ -159,6 +163,32 @@ class BlockManager(
}
/**
+ * Reregister with the master sometime soon.
+ */
+ def asyncReregister() {
+ asyncReregisterLock.synchronized {
+ if (asyncReregisterTask == null) {
+ asyncReregisterTask = Future[Unit] {
+ reregister()
+ asyncReregisterLock.synchronized {
+ asyncReregisterTask = null
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * For testing. Wait for any pending asynchronous reregistration; otherwise, do nothing.
+ */
+ def waitForAsyncReregister() {
+ val task = asyncReregisterTask
+ if (task != null) {
+ Await.ready(task, Duration.Inf)
+ }
+ }
+
+ /**
* Get storage level of local block. If no info exists for the block, then returns null.
*/
def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
@@ -173,7 +203,7 @@ class BlockManager(
if (needReregister) {
logInfo("Got told to reregister updating block " + blockId)
// Reregistering will report our new block for free.
- reregister()
+ asyncReregister()
}
logDebug("Told master about block " + blockId)
}
@@ -191,7 +221,7 @@ class BlockManager(
case level =>
val inMem = level.useMemory && memoryStore.contains(blockId)
val onDisk = level.useDisk && diskStore.contains(blockId)
- val storageLevel = new StorageLevel(onDisk, inMem, level.deserialized, level.replication)
+ val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication)
val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
(storageLevel, memSize, diskSize, info.tellMaster)
@@ -213,7 +243,7 @@ class BlockManager(
val startTimeMs = System.currentTimeMillis
var managers = master.getLocations(blockId)
val locations = managers.map(_.ip)
- logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs))
+ logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs))
return locations
}
@@ -223,7 +253,7 @@ class BlockManager(
def getLocations(blockIds: Array[String]): Array[Seq[String]] = {
val startTimeMs = System.currentTimeMillis
val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray
- logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
+ logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
return locations
}
@@ -615,7 +645,7 @@ class BlockManager(
var size = 0L
myInfo.synchronized {
- logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
if (level.useMemory) {
@@ -647,8 +677,10 @@ class BlockManager(
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
+
// Replicate block if required
if (level.replication > 1) {
+ val remoteStartTime = System.currentTimeMillis
// Serialize the block if not already done
if (bytesAfterPut == null) {
if (valuesAfterPut == null) {
@@ -658,16 +690,10 @@ class BlockManager(
bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
}
replicate(blockId, bytesAfterPut, level)
+ logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime))
}
-
BlockManager.dispose(bytesAfterPut)
- // TODO: This code will be removed when CacheTracker is gone.
- if (blockId.startsWith("rdd")) {
- notifyCacheTracker(blockId)
- }
- logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs))
-
return size
}
@@ -733,11 +759,6 @@ class BlockManager(
}
}
- // TODO: This code will be removed when CacheTracker is gone.
- if (blockId.startsWith("rdd")) {
- notifyCacheTracker(blockId)
- }
-
// If replication had started, then wait for it to finish
if (level.replication > 1) {
if (replicationFuture == null) {
@@ -760,8 +781,7 @@ class BlockManager(
*/
var cachedPeers: Seq[BlockManagerId] = null
private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
- val tLevel: StorageLevel =
- new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
+ val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
if (cachedPeers == null) {
cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
}
@@ -780,16 +800,6 @@ class BlockManager(
}
}
- // TODO: This code will be removed when CacheTracker is gone.
- private def notifyCacheTracker(key: String) {
- if (cacheTracker != null) {
- val rddInfo = key.split("_")
- val rddId: Int = rddInfo(1).toInt
- val partition: Int = rddInfo(2).toInt
- cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host))
- }
- }
-
/**
* Read a block consisting of a single object.
*/
@@ -940,6 +950,7 @@ class BlockManager(
blockInfo.clear()
memoryStore.clear()
diskStore.clear()
+ metadataCleaner.cancel()
logInfo("BlockManager stopped")
}
}
@@ -968,7 +979,7 @@ object BlockManager extends Logging {
*/
def dispose(buffer: ByteBuffer) {
if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) {
- logDebug("Unmapping " + buffer)
+ logTrace("Unmapping " + buffer)
if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) {
buffer.asInstanceOf[DirectBuffer].cleaner().clean()
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala
index 488679f049..f2f1e77d41 100644
--- a/core/src/main/scala/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerId.scala
@@ -3,38 +3,67 @@ package spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap
+/**
+ * This class represent an unique identifier for a BlockManager.
+ * The first 2 constructors of this class is made private to ensure that
+ * BlockManagerId objects can be created only using the factory method in
+ * [[spark.storage.BlockManager$]]. This allows de-duplication of ID objects.
+ * Also, constructor parameters are private to ensure that parameters cannot
+ * be modified from outside this class.
+ */
+private[spark] class BlockManagerId private (
+ private var executorId_ : String,
+ private var ip_ : String,
+ private var port_ : Int
+ ) extends Externalizable {
-private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
- def this() = this(null, 0) // For deserialization only
+ private def this() = this(null, null, 0) // For deserialization only
- def this(in: ObjectInput) = this(in.readUTF(), in.readInt())
+ def executorId: String = executorId_
+
+ def ip: String = ip_
+
+ def port: Int = port_
override def writeExternal(out: ObjectOutput) {
- out.writeUTF(ip)
- out.writeInt(port)
+ out.writeUTF(executorId_)
+ out.writeUTF(ip_)
+ out.writeInt(port_)
}
override def readExternal(in: ObjectInput) {
- ip = in.readUTF()
- port = in.readInt()
+ executorId_ = in.readUTF()
+ ip_ = in.readUTF()
+ port_ = in.readInt()
}
@throws(classOf[IOException])
private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
- override def toString = "BlockManagerId(" + ip + ", " + port + ")"
+ override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port)
- override def hashCode = ip.hashCode * 41 + port
+ override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port
override def equals(that: Any) = that match {
- case id: BlockManagerId => port == id.port && ip == id.ip
- case _ => false
+ case id: BlockManagerId =>
+ executorId == id.executorId && port == id.port && ip == id.ip
+ case _ =>
+ false
}
}
private[spark] object BlockManagerId {
+ def apply(execId: String, ip: String, port: Int) =
+ getCachedBlockManagerId(new BlockManagerId(execId, ip, port))
+
+ def apply(in: ObjectInput) = {
+ val obj = new BlockManagerId()
+ obj.readExternal(in)
+ getCachedBlockManagerId(obj)
+ }
+
val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index a3d8671834..7389bee150 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -1,6 +1,10 @@
package spark.storage
-import scala.collection.mutable.ArrayBuffer
+import java.io._
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.util.Random
import akka.actor.{Actor, ActorRef, ActorSystem, Props}
@@ -11,52 +15,49 @@ import akka.util.duration._
import spark.{Logging, SparkException, Utils}
-
private[spark] class BlockManagerMaster(
val actorSystem: ActorSystem,
- isMaster: Boolean,
+ isDriver: Boolean,
isLocal: Boolean,
- masterIp: String,
- masterPort: Int)
+ driverIp: String,
+ driverPort: Int)
extends Logging {
- val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
+ val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
- val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager"
- val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager"
- val DEFAULT_MANAGER_IP: String = Utils.localHostName()
+ val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager"
val timeout = 10.seconds
- var masterActor: ActorRef = {
- if (isMaster) {
- val masterActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)),
- name = MASTER_AKKA_ACTOR_NAME)
+ var driverActor: ActorRef = {
+ if (isDriver) {
+ val driverActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)),
+ name = DRIVER_AKKA_ACTOR_NAME)
logInfo("Registered BlockManagerMaster Actor")
- masterActor
+ driverActor
} else {
- val url = "akka://spark@%s:%s/user/%s".format(masterIp, masterPort, MASTER_AKKA_ACTOR_NAME)
+ val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, DRIVER_AKKA_ACTOR_NAME)
logInfo("Connecting to BlockManagerMaster: " + url)
actorSystem.actorFor(url)
}
}
- /** Remove a dead host from the master actor. This is only called on the master side. */
- def notifyADeadHost(host: String) {
- tell(RemoveHost(host))
- logInfo("Removed " + host + " successfully in notifyADeadHost")
+ /** Remove a dead executor from the driver actor. This is only called on the driver side. */
+ def removeExecutor(execId: String) {
+ tell(RemoveExecutor(execId))
+ logInfo("Removed " + execId + " successfully in removeExecutor")
}
/**
- * Send the master actor a heart beat from the slave. Returns true if everything works out,
- * false if the master does not know about the given block manager, which means the block
+ * Send the driver actor a heart beat from the slave. Returns true if everything works out,
+ * false if the driver does not know about the given block manager, which means the block
* manager should re-register.
*/
def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = {
- askMasterWithRetry[Boolean](HeartBeat(blockManagerId))
+ askDriverWithReply[Boolean](HeartBeat(blockManagerId))
}
- /** Register the BlockManager's id with the master. */
+ /** Register the BlockManager's id with the driver. */
def registerBlockManager(
blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
logInfo("Trying to register BlockManager")
@@ -70,25 +71,25 @@ private[spark] class BlockManagerMaster(
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long): Boolean = {
- val res = askMasterWithRetry[Boolean](
+ val res = askDriverWithReply[Boolean](
UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize))
logInfo("Updated info of block " + blockId)
res
}
- /** Get locations of the blockId from the master */
+ /** Get locations of the blockId from the driver */
def getLocations(blockId: String): Seq[BlockManagerId] = {
- askMasterWithRetry[Seq[BlockManagerId]](GetLocations(blockId))
+ askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId))
}
- /** Get locations of multiple blockIds from the master */
+ /** Get locations of multiple blockIds from the driver */
def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
- askMasterWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
+ askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
- /** Get ids of other nodes in the cluster from the master */
+ /** Get ids of other nodes in the cluster from the driver */
def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = {
- val result = askMasterWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
+ val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
if (result.length != numPeers) {
throw new SparkException(
"Error getting peers, only got " + result.size + " instead of " + numPeers)
@@ -98,10 +99,10 @@ private[spark] class BlockManagerMaster(
/**
* Remove a block from the slaves that have it. This can only be used to remove
- * blocks that the master knows about.
+ * blocks that the driver knows about.
*/
def removeBlock(blockId: String) {
- askMasterWithRetry(RemoveBlock(blockId))
+ askDriverWithReply(RemoveBlock(blockId))
}
/**
@@ -111,41 +112,45 @@ private[spark] class BlockManagerMaster(
* amount of remaining memory.
*/
def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
- askMasterWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
+ askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
+ }
+
+ def getStorageStatus: Array[StorageStatus] = {
+ askDriverWithReply[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray
}
- /** Stop the master actor, called only on the Spark master node */
+ /** Stop the driver actor, called only on the Spark driver node */
def stop() {
- if (masterActor != null) {
+ if (driverActor != null) {
tell(StopBlockManagerMaster)
- masterActor = null
+ driverActor = null
logInfo("BlockManagerMaster stopped")
}
}
/** Send a one-way message to the master actor, to which we expect it to reply with true. */
private def tell(message: Any) {
- if (!askMasterWithRetry[Boolean](message)) {
+ if (!askDriverWithReply[Boolean](message)) {
throw new SparkException("BlockManagerMasterActor returned false, expected true.")
}
}
/**
- * Send a message to the master actor and get its result within a default timeout, or
+ * Send a message to the driver actor and get its result within a default timeout, or
* throw a SparkException if this fails.
*/
- private def askMasterWithRetry[T](message: Any): T = {
+ private def askDriverWithReply[T](message: Any): T = {
// TODO: Consider removing multiple attempts
- if (masterActor == null) {
- throw new SparkException("Error sending message to BlockManager as masterActor is null " +
+ if (driverActor == null) {
+ throw new SparkException("Error sending message to BlockManager as driverActor is null " +
"[message = " + message + "]")
}
var attempts = 0
var lastException: Exception = null
- while (attempts < AKKA_RETRY_ATTEMPS) {
+ while (attempts < AKKA_RETRY_ATTEMPTS) {
attempts += 1
try {
- val future = masterActor.ask(message)(timeout)
+ val future = driverActor.ask(message)(timeout)
val result = Await.result(future, timeout)
if (result == null) {
throw new Exception("BlockManagerMaster returned null")
diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
index f4d026da33..2830bc6297 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
@@ -23,9 +23,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private val blockManagerInfo =
new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo]
- // Mapping from host name to block manager id. We allow multiple block managers
- // on the same host name (ip).
- private val blockManagerIdByHost = new HashMap[String, ArrayBuffer[BlockManagerId]]
+ // Mapping from executor ID to block manager ID.
+ private val blockManagerIdByExecutor = new HashMap[String, BlockManagerId]
// Mapping from block id to the set of block managers that have the block.
private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]]
@@ -68,11 +67,14 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
case GetMemoryStatus =>
getMemoryStatus
+ case GetStorageStatus =>
+ getStorageStatus
+
case RemoveBlock(blockId) =>
removeBlock(blockId)
- case RemoveHost(host) =>
- removeHost(host)
+ case RemoveExecutor(execId) =>
+ removeExecutor(execId)
sender ! true
case StopBlockManagerMaster =>
@@ -96,16 +98,12 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
def removeBlockManager(blockManagerId: BlockManagerId) {
val info = blockManagerInfo(blockManagerId)
- // Remove the block manager from blockManagerIdByHost. If the list of block
- // managers belonging to the IP is empty, remove the entry from the hash map.
- blockManagerIdByHost.get(blockManagerId.ip).foreach { managers: ArrayBuffer[BlockManagerId] =>
- managers -= blockManagerId
- if (managers.size == 0) blockManagerIdByHost.remove(blockManagerId.ip)
- }
+ // Remove the block manager from blockManagerIdByExecutor.
+ blockManagerIdByExecutor -= blockManagerId.executorId
// Remove it from blockManagerInfo and remove all the blocks.
blockManagerInfo.remove(blockManagerId)
- var iterator = info.blocks.keySet.iterator
+ val iterator = info.blocks.keySet.iterator
while (iterator.hasNext) {
val blockId = iterator.next
val locations = blockLocations.get(blockId)._2
@@ -117,7 +115,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
def expireDeadHosts() {
- logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.")
+ logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.")
val now = System.currentTimeMillis()
val minSeenTime = now - slaveTimeout
val toRemove = new HashSet[BlockManagerId]
@@ -130,17 +128,15 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
toRemove.foreach(removeBlockManager)
}
- def removeHost(host: String) {
- logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.")
- logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq)
- blockManagerIdByHost.get(host).foreach(_.foreach(removeBlockManager))
- logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq)
+ def removeExecutor(execId: String) {
+ logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
+ blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
sender ! true
}
def heartBeat(blockManagerId: BlockManagerId) {
if (!blockManagerInfo.contains(blockManagerId)) {
- if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
+ if (blockManagerId.executorId == "<driver>" && !isLocal) {
sender ! true
} else {
sender ! false
@@ -177,24 +173,28 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
sender ! res
}
- private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockManagerId + " "
+ private def getStorageStatus() {
+ val res = blockManagerInfo.map { case(blockManagerId, info) =>
+ import collection.JavaConverters._
+ StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap)
+ }
+ sender ! res
+ }
- if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
- logInfo("Got Register Msg from master node, don't register it")
- } else {
- blockManagerIdByHost.get(blockManagerId.ip) match {
- case Some(managers) =>
- // A block manager of the same host name already exists.
- logInfo("Got another registration for host " + blockManagerId)
- managers += blockManagerId
+ private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+ if (id.executorId == "<driver>" && !isLocal) {
+ // Got a register message from the master node; don't register it
+ } else if (!blockManagerInfo.contains(id)) {
+ blockManagerIdByExecutor.get(id.executorId) match {
+ case Some(manager) =>
+ // A block manager of the same host name already exists
+ logError("Got two different block manager registrations on " + id.executorId)
+ System.exit(1)
case None =>
- blockManagerIdByHost += (blockManagerId.ip -> ArrayBuffer(blockManagerId))
+ blockManagerIdByExecutor(id.executorId) = id
}
-
- blockManagerInfo += (blockManagerId -> new BlockManagerMasterActor.BlockManagerInfo(
- blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor))
+ blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo(
+ id, System.currentTimeMillis(), maxMemSize, slaveActor)
}
sender ! true
}
@@ -206,11 +206,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
memSize: Long,
diskSize: Long) {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockManagerId + " " + blockId + " "
-
if (!blockManagerInfo.contains(blockManagerId)) {
- if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
+ if (blockManagerId.executorId == "<driver>" && !isLocal) {
// We intentionally do not register the master (except in local mode),
// so we should not indicate failure.
sender ! true
@@ -342,8 +339,8 @@ object BlockManagerMasterActor {
_lastSeenMs = System.currentTimeMillis()
}
- def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long)
- : Unit = synchronized {
+ def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long,
+ diskSize: Long) {
updateLastSeenMs()
diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
index d73a9b790f..1494f90103 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
@@ -54,11 +54,9 @@ class UpdateBlockInfo(
}
override def readExternal(in: ObjectInput) {
- blockManagerId = new BlockManagerId()
- blockManagerId.readExternal(in)
+ blockManagerId = BlockManagerId(in)
blockId = in.readUTF()
- storageLevel = new StorageLevel()
- storageLevel.readExternal(in)
+ storageLevel = StorageLevel(in)
memSize = in.readInt()
diskSize = in.readInt()
}
@@ -90,7 +88,7 @@ private[spark]
case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
private[spark]
-case class RemoveHost(host: String) extends ToBlockManagerMaster
+case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
private[spark]
case object StopBlockManagerMaster extends ToBlockManagerMaster
@@ -100,3 +98,6 @@ case object GetMemoryStatus extends ToBlockManagerMaster
private[spark]
case object ExpireDeadHosts extends ToBlockManagerMaster
+
+private[spark]
+case object GetStorageStatus extends ToBlockManagerMaster
diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala
new file mode 100644
index 0000000000..9e6721ec17
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala
@@ -0,0 +1,76 @@
+package spark.storage
+
+import akka.actor.{ActorRef, ActorSystem}
+import akka.util.Timeout
+import akka.util.duration._
+import cc.spray.typeconversion.TwirlSupport._
+import cc.spray.Directives
+import spark.{Logging, SparkContext}
+import spark.util.AkkaUtils
+import spark.Utils
+
+
+/**
+ * Web UI server for the BlockManager inside each SparkContext.
+ */
+private[spark]
+class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, sc: SparkContext)
+ extends Directives with Logging {
+
+ val STATIC_RESOURCE_DIR = "spark/deploy/static"
+
+ implicit val timeout = Timeout(10 seconds)
+
+ /** Start a HTTP server to run the Web interface */
+ def start() {
+ try {
+ val port = if (System.getProperty("spark.ui.port") != null) {
+ System.getProperty("spark.ui.port").toInt
+ } else {
+ // TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which
+ // random port it bound to, so we have to try to find a local one by creating a socket.
+ Utils.findFreePort()
+ }
+ AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, handler, "BlockManagerHTTPServer")
+ logInfo("Started BlockManager web UI at http://%s:%d".format(Utils.localHostName(), port))
+ } catch {
+ case e: Exception =>
+ logError("Failed to create BlockManager WebUI", e)
+ System.exit(1)
+ }
+ }
+
+ val handler = {
+ get {
+ path("") {
+ completeWith {
+ // Request the current storage status from the Master
+ val storageStatusList = sc.getExecutorStorageStatus
+ // Calculate macro-level statistics
+ val maxMem = storageStatusList.map(_.maxMem).reduce(_+_)
+ val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_)
+ val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize))
+ .reduceOption(_+_).getOrElse(0L)
+ val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc)
+ spark.storage.html.index.
+ render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList)
+ }
+ } ~
+ path("rdd") {
+ parameter("id") { id =>
+ completeWith {
+ val prefix = "rdd_" + id.toString
+ val storageStatusList = sc.getExecutorStorageStatus
+ val filteredStorageStatusList = StorageUtils.
+ filterStorageStatusByPrefix(storageStatusList, prefix)
+ val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
+ spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList)
+ }
+ }
+ } ~
+ pathPrefix("static") {
+ getFromResourceDirectory(STATIC_RESOURCE_DIR)
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala
index 3f234df654..30d7500e01 100644
--- a/core/src/main/scala/spark/storage/BlockMessage.scala
+++ b/core/src/main/scala/spark/storage/BlockMessage.scala
@@ -64,7 +64,7 @@ private[spark] class BlockMessage() {
val booleanInt = buffer.getInt()
val replication = buffer.getInt()
- level = new StorageLevel(booleanInt, replication)
+ level = StorageLevel(booleanInt, replication)
val dataLength = buffer.getInt()
data = ByteBuffer.allocate(dataLength)
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
index e3544e5aae..3b5a77ab22 100644
--- a/core/src/main/scala/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/spark/storage/StorageLevel.scala
@@ -7,25 +7,30 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
* whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory
* in a serialized format, and whether to replicate the RDD partitions on multiple nodes.
* The [[spark.storage.StorageLevel$]] singleton object contains some static constants for
- * commonly useful storage levels.
+ * commonly useful storage levels. To create your own storage level object, use the factor method
+ * of the singleton object (`StorageLevel(...)`).
*/
-class StorageLevel(
- var useDisk: Boolean,
- var useMemory: Boolean,
- var deserialized: Boolean,
- var replication: Int = 1)
+class StorageLevel private(
+ private var useDisk_ : Boolean,
+ private var useMemory_ : Boolean,
+ private var deserialized_ : Boolean,
+ private var replication_ : Int = 1)
extends Externalizable {
// TODO: Also add fields for caching priority, dataset ID, and flushing.
-
- assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes")
-
- def this(flags: Int, replication: Int) {
+ private def this(flags: Int, replication: Int) {
this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
}
def this() = this(false, true, false) // For deserialization
+ def useDisk = useDisk_
+ def useMemory = useMemory_
+ def deserialized = deserialized_
+ def replication = replication_
+
+ assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes")
+
override def clone(): StorageLevel = new StorageLevel(
this.useDisk, this.useMemory, this.deserialized, this.replication)
@@ -43,13 +48,13 @@ class StorageLevel(
def toInt: Int = {
var ret = 0
- if (useDisk) {
+ if (useDisk_) {
ret |= 4
}
- if (useMemory) {
+ if (useMemory_) {
ret |= 2
}
- if (deserialized) {
+ if (deserialized_) {
ret |= 1
}
return ret
@@ -57,15 +62,15 @@ class StorageLevel(
override def writeExternal(out: ObjectOutput) {
out.writeByte(toInt)
- out.writeByte(replication)
+ out.writeByte(replication_)
}
override def readExternal(in: ObjectInput) {
val flags = in.readByte()
- useDisk = (flags & 4) != 0
- useMemory = (flags & 2) != 0
- deserialized = (flags & 1) != 0
- replication = in.readByte()
+ useDisk_ = (flags & 4) != 0
+ useMemory_ = (flags & 2) != 0
+ deserialized_ = (flags & 1) != 0
+ replication_ = in.readByte()
}
@throws(classOf[IOException])
@@ -75,6 +80,14 @@ class StorageLevel(
"StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
override def hashCode(): Int = toInt * 41 + replication
+ def description : String = {
+ var result = ""
+ result += (if (useDisk) "Disk " else "")
+ result += (if (useMemory) "Memory " else "")
+ result += (if (deserialized) "Deserialized " else "Serialized")
+ result += "%sx Replicated".format(replication)
+ result
+ }
}
@@ -91,6 +104,21 @@ object StorageLevel {
val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false)
val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2)
+ /** Create a new StorageLevel object */
+ def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) =
+ getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication))
+
+ /** Create a new StorageLevel object from its integer representation */
+ def apply(flags: Int, replication: Int) =
+ getCachedStorageLevel(new StorageLevel(flags, replication))
+
+ /** Read StorageLevel object from ObjectInput stream */
+ def apply(in: ObjectInput) = {
+ val obj = new StorageLevel()
+ obj.readExternal(in)
+ getCachedStorageLevel(obj)
+ }
+
private[spark]
val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()
diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala
new file mode 100644
index 0000000000..5f72b67b2c
--- /dev/null
+++ b/core/src/main/scala/spark/storage/StorageUtils.scala
@@ -0,0 +1,82 @@
+package spark.storage
+
+import spark.{Utils, SparkContext}
+import BlockManagerMasterActor.BlockStatus
+
+private[spark]
+case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
+ blocks: Map[String, BlockStatus]) {
+
+ def memUsed(blockPrefix: String = "") = {
+ blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize).
+ reduceOption(_+_).getOrElse(0l)
+ }
+
+ def diskUsed(blockPrefix: String = "") = {
+ blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize).
+ reduceOption(_+_).getOrElse(0l)
+ }
+
+ def memRemaining : Long = maxMem - memUsed()
+
+}
+
+case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
+ numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) {
+ override def toString = {
+ import Utils.memoryBytesToString
+ "RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id,
+ storageLevel.toString, numCachedPartitions, numPartitions, memoryBytesToString(memSize), memoryBytesToString(diskSize))
+ }
+}
+
+/* Helper methods for storage-related objects */
+private[spark]
+object StorageUtils {
+
+ /* Given the current storage status of the BlockManager, returns information for each RDD */
+ def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus],
+ sc: SparkContext) : Array[RDDInfo] = {
+ rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
+ }
+
+ /* Given a list of BlockStatus objets, returns information for each RDD */
+ def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
+ sc: SparkContext) : Array[RDDInfo] = {
+
+ // Group by rddId, ignore the partition name
+ val groupedRddBlocks = infos.groupBy { case(k, v) =>
+ k.substring(0,k.lastIndexOf('_'))
+ }.mapValues(_.values.toArray)
+
+ // For each RDD, generate an RDDInfo object
+ groupedRddBlocks.map { case(rddKey, rddBlocks) =>
+
+ // Add up memory and disk sizes
+ val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
+ val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _)
+
+ // Find the id of the RDD, e.g. rdd_1 => 1
+ val rddId = rddKey.split("_").last.toInt
+ // Get the friendly name for the rdd, if available.
+ val rdd = sc.persistentRdds(rddId)
+ val rddName = Option(rdd.name).getOrElse(rddKey)
+ val rddStorageLevel = rdd.getStorageLevel
+
+ RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.splits.size, memSize, diskSize)
+ }.toArray
+ }
+
+ /* Removes all BlockStatus object that are not part of a block prefix */
+ def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
+ prefix: String) : Array[StorageStatus] = {
+
+ storageStatusList.map { status =>
+ val newBlocks = status.blocks.filterKeys(_.startsWith(prefix))
+ //val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _)
+ StorageStatus(status.blockManagerId, status.maxMem, newBlocks)
+ }
+
+ }
+
+}
diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala
index 689f07b969..a70d1c8e78 100644
--- a/core/src/main/scala/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/spark/storage/ThreadingTest.scala
@@ -75,10 +75,11 @@ private[spark] object ThreadingTest {
System.setProperty("spark.kryoserializer.buffer.mb", "1")
val actorSystem = ActorSystem("test")
val serializer = new KryoSerializer
- val masterIp: String = System.getProperty("spark.master.host", "localhost")
- val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
- val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, masterIp, masterPort)
- val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer, 1024 * 1024)
+ val driverIp: String = System.getProperty("spark.driver.host", "localhost")
+ val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
+ val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, driverIp, driverPort)
+ val blockManager = new BlockManager(
+ "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024)
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
producers.foreach(_.start)
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index e67cb0336d..30aec5a663 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -1,6 +1,6 @@
package spark.util
-import akka.actor.{Props, ActorSystemImpl, ActorSystem}
+import akka.actor.{ActorRef, Props, ActorSystemImpl, ActorSystem}
import com.typesafe.config.ConfigFactory
import akka.util.duration._
import akka.pattern.ask
@@ -18,9 +18,13 @@ import java.util.concurrent.TimeoutException
* Various utility classes for working with Akka.
*/
private[spark] object AkkaUtils {
+
/**
* Creates an ActorSystem ready for remoting, with various Spark features. Returns both the
* ActorSystem itself and its port (which is hard to get from Akka).
+ *
+ * Note: the `name` parameter is important, as even if a client sends a message to right
+ * host + port, if the system name is incorrect, Akka will drop the message.
*/
def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
@@ -30,8 +34,10 @@ private[spark] object AkkaUtils {
val akkaConf = ConfigFactory.parseString("""
akka.daemonic = on
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
+ akka.stdout-loglevel = "ERROR"
akka.actor.provider = "akka.remote.RemoteActorRefProvider"
akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
+ akka.remote.log-remote-lifecycle-events = on
akka.remote.netty.hostname = "%s"
akka.remote.netty.port = %d
akka.remote.netty.connection-timeout = %ds
@@ -40,7 +46,7 @@ private[spark] object AkkaUtils {
akka.actor.default-dispatcher.throughput = %d
""".format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize))
- val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader)
+ val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)
// Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
// hack because Akka doesn't let you figure out the port through the public API yet.
@@ -51,21 +57,22 @@ private[spark] object AkkaUtils {
/**
* Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to
- * handle requests. Throws a SparkException if this fails.
+ * handle requests. Returns the bound port or throws a SparkException on failure.
*/
- def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route) {
+ def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route,
+ name: String = "HttpServer"): ActorRef = {
val ioWorker = new IoWorker(actorSystem).start()
val httpService = actorSystem.actorOf(Props(new HttpService(route)))
val rootService = actorSystem.actorOf(Props(new SprayCanRootService(httpService)))
val server = actorSystem.actorOf(
- Props(new HttpServer(ioWorker, SingletonHandler(rootService))), name = "HttpServer")
+ Props(new HttpServer(ioWorker, SingletonHandler(rootService))), name = name)
actorSystem.registerOnTermination { ioWorker.stop() }
val timeout = 3.seconds
val future = server.ask(HttpServer.Bind(ip, port))(timeout)
try {
Await.result(future, timeout) match {
case bound: HttpServer.Bound =>
- return
+ return server
case other: Any =>
throw new SparkException("Failed to bind web UI to port " + port + ": " + other)
}
diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala
index 139e21d09e..a342d378ff 100644
--- a/core/src/main/scala/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/spark/util/MetadataCleaner.scala
@@ -5,29 +5,29 @@ import java.util.{TimerTask, Timer}
import spark.Logging
+/**
+ * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries)
+ */
class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
+ private val delaySeconds = MetadataCleaner.getDelaySeconds
+ private val periodSeconds = math.max(10, delaySeconds / 10)
+ private val timer = new Timer(name + " cleanup timer", true)
- val delaySeconds = MetadataCleaner.getDelaySeconds
- val periodSeconds = math.max(10, delaySeconds / 10)
- val timer = new Timer(name + " cleanup timer", true)
-
- val task = new TimerTask {
- def run() {
+ private val task = new TimerTask {
+ override def run() {
try {
- if (delaySeconds > 0) {
- cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000))
- logInfo("Ran metadata cleaner for " + name)
- }
+ cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000))
+ logInfo("Ran metadata cleaner for " + name)
} catch {
case e: Exception => logError("Error running cleanup task for " + name, e)
}
}
}
- if (periodSeconds > 0) {
- logInfo(
- "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and "
- + "period of " + periodSeconds + " secs")
+ if (delaySeconds > 0) {
+ logDebug(
+ "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " +
+ "and period of " + periodSeconds + " secs")
timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000)
}
@@ -38,7 +38,7 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging
object MetadataCleaner {
- def getDelaySeconds = (System.getProperty("spark.cleaner.delay", "-100").toDouble * 60).toInt
- def setDelaySeconds(delay: Long) { System.setProperty("spark.cleaner.delay", delay.toString) }
+ def getDelaySeconds = System.getProperty("spark.cleaner.delay", "-1").toInt
+ def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.delay", delay.toString) }
}
diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
index bb7c5c01c8..188f8910da 100644
--- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
@@ -63,9 +63,9 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging {
override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()
- override def size(): Int = internalMap.size()
+ override def size: Int = internalMap.size
- override def foreach[U](f: ((A, B)) => U): Unit = {
+ override def foreach[U](f: ((A, B)) => U) {
val iterator = internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
diff --git a/core/src/main/twirl/spark/deploy/common/layout.scala.html b/core/src/main/twirl/spark/common/layout.scala.html
index b9192060aa..b9192060aa 100644
--- a/core/src/main/twirl/spark/deploy/common/layout.scala.html
+++ b/core/src/main/twirl/spark/common/layout.scala.html
diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html
index 18c32e5a1f..285645c389 100644
--- a/core/src/main/twirl/spark/deploy/master/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/index.scala.html
@@ -2,7 +2,7 @@
@import spark.deploy.master._
@import spark.Utils
-@spark.deploy.common.html.layout(title = "Spark Master on " + state.uri) {
+@spark.common.html.layout(title = "Spark Master on " + state.uri) {
<!-- Cluster Details -->
<div class="row">
diff --git a/core/src/main/twirl/spark/deploy/master/job_details.scala.html b/core/src/main/twirl/spark/deploy/master/job_details.scala.html
index dcf41c28f2..d02a51b214 100644
--- a/core/src/main/twirl/spark/deploy/master/job_details.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/job_details.scala.html
@@ -1,6 +1,6 @@
@(job: spark.deploy.master.JobInfo)
-@spark.deploy.common.html.layout(title = "Job Details") {
+@spark.common.html.layout(title = "Job Details") {
<!-- Job Details -->
<div class="row">
diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html
index b247307dab..1d703dae58 100644
--- a/core/src/main/twirl/spark/deploy/worker/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html
@@ -1,8 +1,7 @@
@(worker: spark.deploy.WorkerState)
-
@import spark.Utils
-@spark.deploy.common.html.layout(title = "Spark Worker on " + worker.uri) {
+@spark.common.html.layout(title = "Spark Worker on " + worker.uri) {
<!-- Worker Details -->
<div class="row">
diff --git a/core/src/main/twirl/spark/storage/index.scala.html b/core/src/main/twirl/spark/storage/index.scala.html
new file mode 100644
index 0000000000..2b337f6133
--- /dev/null
+++ b/core/src/main/twirl/spark/storage/index.scala.html
@@ -0,0 +1,40 @@
+@(maxMem: Long, remainingMem: Long, diskSpaceUsed: Long, rdds: Array[spark.storage.RDDInfo], storageStatusList: Array[spark.storage.StorageStatus])
+@import spark.Utils
+
+@spark.common.html.layout(title = "Storage Dashboard") {
+
+ <!-- High-Level Information -->
+ <div class="row">
+ <div class="span12">
+ <ul class="unstyled">
+ <li><strong>Memory:</strong>
+ @{Utils.memoryBytesToString(maxMem - remainingMem)} Used
+ (@{Utils.memoryBytesToString(remainingMem)} Available) </li>
+ <li><strong>Disk:</strong> @{Utils.memoryBytesToString(diskSpaceUsed)} Used </li>
+ </ul>
+ </div>
+ </div>
+
+ <hr/>
+
+ <!-- RDD Summary -->
+ <div class="row">
+ <div class="span12">
+ <h3> RDD Summary </h3>
+ <br/>
+ @rdd_table(rdds)
+ </div>
+ </div>
+
+ <hr/>
+
+ <!-- Worker Summary -->
+ <div class="row">
+ <div class="span12">
+ <h3> Worker Summary </h3>
+ <br/>
+ @worker_table(storageStatusList)
+ </div>
+ </div>
+
+} \ No newline at end of file
diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html
new file mode 100644
index 0000000000..d85addeb17
--- /dev/null
+++ b/core/src/main/twirl/spark/storage/rdd.scala.html
@@ -0,0 +1,81 @@
+@(rddInfo: spark.storage.RDDInfo, storageStatusList: Array[spark.storage.StorageStatus])
+@import spark.Utils
+
+@spark.common.html.layout(title = "RDD Info ") {
+
+ <!-- High-Level Information -->
+ <div class="row">
+ <div class="span12">
+ <ul class="unstyled">
+ <li>
+ <strong>Storage Level:</strong>
+ @(rddInfo.storageLevel.description)
+ <li>
+ <strong>Cached Partitions:</strong>
+ @(rddInfo.numCachedPartitions)
+ </li>
+ <li>
+ <strong>Total Partitions:</strong>
+ @(rddInfo.numPartitions)
+ </li>
+ <li>
+ <strong>Memory Size:</strong>
+ @{Utils.memoryBytesToString(rddInfo.memSize)}
+ </li>
+ <li>
+ <strong>Disk Size:</strong>
+ @{Utils.memoryBytesToString(rddInfo.diskSize)}
+ </li>
+ </ul>
+ </div>
+ </div>
+
+ <hr/>
+
+ <!-- RDD Summary -->
+ <div class="row">
+ <div class="span12">
+ <h3> RDD Summary </h3>
+ <br/>
+
+
+ <!-- Block Table Summary -->
+ <table class="table table-bordered table-striped table-condensed sortable">
+ <thead>
+ <tr>
+ <th>Block Name</th>
+ <th>Storage Level</th>
+ <th>Size in Memory</th>
+ <th>Size on Disk</th>
+ </tr>
+ </thead>
+ <tbody>
+ @storageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1).map { case (k,v) =>
+ <tr>
+ <td>@k</td>
+ <td>
+ @(v.storageLevel.description)
+ </td>
+ <td>@{Utils.memoryBytesToString(v.memSize)}</td>
+ <td>@{Utils.memoryBytesToString(v.diskSize)}</td>
+ </tr>
+ }
+ </tbody>
+ </table>
+
+
+ </div>
+ </div>
+
+ <hr/>
+
+ <!-- Worker Table -->
+ <div class="row">
+ <div class="span12">
+ <h3> Worker Summary </h3>
+ <br/>
+ @worker_table(storageStatusList, "rdd_" + rddInfo.id )
+ </div>
+ </div>
+
+} \ No newline at end of file
diff --git a/core/src/main/twirl/spark/storage/rdd_table.scala.html b/core/src/main/twirl/spark/storage/rdd_table.scala.html
new file mode 100644
index 0000000000..a51e64aed0
--- /dev/null
+++ b/core/src/main/twirl/spark/storage/rdd_table.scala.html
@@ -0,0 +1,32 @@
+@(rdds: Array[spark.storage.RDDInfo])
+@import spark.Utils
+
+<table class="table table-bordered table-striped table-condensed sortable">
+ <thead>
+ <tr>
+ <th>RDD Name</th>
+ <th>Storage Level</th>
+ <th>Cached Partitions</th>
+ <th>Fraction Partitions Cached</th>
+ <th>Size in Memory</th>
+ <th>Size on Disk</th>
+ </tr>
+ </thead>
+ <tbody>
+ @for(rdd <- rdds) {
+ <tr>
+ <td>
+ <a href="rdd?id=@(rdd.id)">
+ @rdd.name
+ </a>
+ </td>
+ <td>@(rdd.storageLevel.description)
+ </td>
+ <td>@rdd.numCachedPartitions</td>
+ <td>@(rdd.numCachedPartitions / rdd.numPartitions.toDouble)</td>
+ <td>@{Utils.memoryBytesToString(rdd.memSize)}</td>
+ <td>@{Utils.memoryBytesToString(rdd.diskSize)}</td>
+ </tr>
+ }
+ </tbody>
+</table> \ No newline at end of file
diff --git a/core/src/main/twirl/spark/storage/worker_table.scala.html b/core/src/main/twirl/spark/storage/worker_table.scala.html
new file mode 100644
index 0000000000..d54b8de4cc
--- /dev/null
+++ b/core/src/main/twirl/spark/storage/worker_table.scala.html
@@ -0,0 +1,24 @@
+@(workersStatusList: Array[spark.storage.StorageStatus], prefix: String = "")
+@import spark.Utils
+
+<table class="table table-bordered table-striped table-condensed sortable">
+ <thead>
+ <tr>
+ <th>Host</th>
+ <th>Memory Usage</th>
+ <th>Disk Usage</th>
+ </tr>
+ </thead>
+ <tbody>
+ @for(status <- workersStatusList) {
+ <tr>
+ <td>@(status.blockManagerId.ip + ":" + status.blockManagerId.port)</td>
+ <td>
+ @(Utils.memoryBytesToString(status.memUsed(prefix)))
+ (@(Utils.memoryBytesToString(status.memRemaining)) Total Available)
+ </td>
+ <td>@(Utils.memoryBytesToString(status.diskUsed(prefix)))</td>
+ </tr>
+ }
+ </tbody>
+</table> \ No newline at end of file
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
index d8be99dde7..ac8ae7d308 100644
--- a/core/src/test/scala/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -1,6 +1,5 @@
package spark
-import org.scalatest.BeforeAndAfter
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import collection.mutable
@@ -9,18 +8,7 @@ import scala.math.exp
import scala.math.signum
import spark.SparkContext._
-class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
-
- var sc: SparkContext = null
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
test ("basic accumulation"){
sc = new SparkContext("local", "test")
@@ -29,6 +17,12 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val d = sc.parallelize(1 to 20)
d.foreach{x => acc += x}
acc.value should be (210)
+
+
+ val longAcc = sc.accumulator(0l)
+ val maxInt = Integer.MAX_VALUE.toLong
+ d.foreach{x => longAcc += maxInt + x}
+ longAcc.value should be (210l + maxInt * 20)
}
test ("value not assignable from tasks") {
@@ -53,10 +47,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
for (i <- 1 to maxI) {
v should contain(i)
}
- sc.stop()
- sc = null
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ resetSparkContext()
}
}
@@ -86,10 +77,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
x => acc.value += x
}
} should produce [SparkException]
- sc.stop()
- sc = null
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ resetSparkContext()
}
}
@@ -115,10 +103,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
bufferAcc.value should contain(i)
mapAcc.value should contain (i -> i.toString)
}
- sc.stop()
- sc = null
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ resetSparkContext()
}
}
@@ -134,8 +119,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
x => acc.localValue ++= x
}
acc.value should be ( (0 to maxI).toSet)
- sc.stop()
- sc = null
+ resetSparkContext()
}
}
diff --git a/core/src/test/scala/spark/BroadcastSuite.scala b/core/src/test/scala/spark/BroadcastSuite.scala
index 2d3302f0aa..362a31fb0d 100644
--- a/core/src/test/scala/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/spark/BroadcastSuite.scala
@@ -1,20 +1,8 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
-class BroadcastSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class BroadcastSuite extends FunSuite with LocalSparkContext {
test("basic broadcast") {
sc = new SparkContext("local", "test")
diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala
deleted file mode 100644
index 467605981b..0000000000
--- a/core/src/test/scala/spark/CacheTrackerSuite.scala
+++ /dev/null
@@ -1,131 +0,0 @@
-package spark
-
-import org.scalatest.FunSuite
-
-import scala.collection.mutable.HashMap
-
-import akka.actor._
-import akka.dispatch._
-import akka.pattern.ask
-import akka.remote._
-import akka.util.Duration
-import akka.util.Timeout
-import akka.util.duration._
-
-class CacheTrackerSuite extends FunSuite {
- // Send a message to an actor and wait for a reply, in a blocking manner
- private def ask(actor: ActorRef, message: Any): Any = {
- try {
- val timeout = 10.seconds
- val future = actor.ask(message)(timeout)
- return Await.result(future, timeout)
- } catch {
- case e: Exception =>
- throw new SparkException("Error communicating with actor", e)
- }
- }
-
- test("CacheTrackerActor slave initialization & cache status") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 0L)))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("RegisterRDD") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 3)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(getCacheLocations(tracker) === Map(1 -> List(Nil, Nil, Nil), 2 -> List(Nil)))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("AddedToCache") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 2)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true)
- assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true)
- assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L)))
-
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("DroppedFromCache") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 2)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true)
- assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true)
- assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
-
- assert(ask(tracker, DroppedFromCache(1, 1, "host001", 2L << 11)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 68608L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- /**
- * Helper function to get cacheLocations from CacheTracker
- */
- def getCacheLocations(tracker: ActorRef): HashMap[Int, List[List[String]]] = {
- val answer = ask(tracker, GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
- answer.map { case (i, arr) => (i, arr.toList) }
- }
-}
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
index 51573254ca..0b74607fb8 100644
--- a/core/src/test/scala/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -1,34 +1,27 @@
package spark
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.FunSuite
import java.io.File
import spark.rdd._
import spark.SparkContext._
import storage.StorageLevel
-class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
+class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
initLogging()
- var sc: SparkContext = _
var checkpointDir: File = _
val partitioner = new HashPartitioner(2)
- before {
+ override def beforeEach() {
+ super.beforeEach()
checkpointDir = File.createTempFile("temp", "")
checkpointDir.delete()
-
sc = new SparkContext("local", "test")
sc.setCheckpointDir(checkpointDir.toString)
}
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
-
+ override def afterEach() {
+ super.afterEach()
if (checkpointDir != null) {
checkpointDir.delete()
}
@@ -106,7 +99,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
// the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
// Note that this test is very specific to the current implementation of CartesianRDD.
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
- ones.checkpoint // checkpoint that MappedRDD
+ ones.checkpoint() // checkpoint that MappedRDD
val cartesian = new CartesianRDD(sc, ones, ones)
val splitBeforeCheckpoint =
serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
@@ -132,7 +125,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
// the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
// Note that this test is very specific to the current implementation of CoalescedRDDSplits
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
- ones.checkpoint // checkpoint that MappedRDD
+ ones.checkpoint() // checkpoint that MappedRDD
val coalesced = new CoalescedRDD(ones, 2)
val splitBeforeCheckpoint =
serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
@@ -167,7 +160,6 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
// so only the RDD will reduce in serialized size, not the splits.
testParentCheckpointing(
rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
-
}
/**
@@ -183,7 +175,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
testRDDSplitSize: Boolean = false
) {
// Generate the final RDD using given RDD operation
- val baseRDD = generateLongLineageRDD
+ val baseRDD = generateLongLineageRDD()
val operatedRDD = op(baseRDD)
val parentRDD = operatedRDD.dependencies.headOption.orNull
val rddType = operatedRDD.getClass.getSimpleName
@@ -252,12 +244,16 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
testRDDSplitSize: Boolean
) {
// Generate the final RDD using given RDD operation
- val baseRDD = generateLongLineageRDD
+ val baseRDD = generateLongLineageRDD()
val operatedRDD = op(baseRDD)
val parentRDD = operatedRDD.dependencies.head.rdd
val rddType = operatedRDD.getClass.getSimpleName
val parentRDDType = parentRDD.getClass.getSimpleName
+ // Get the splits and dependencies of the parent in case they're lazily computed
+ parentRDD.dependencies
+ parentRDD.splits
+
// Find serialized sizes before and after the checkpoint
val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one
@@ -274,7 +270,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
if (testRDDSize) {
assert(
rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
- "Size of " + rddType + " did not reduce after parent checkpointing parent " + parentRDDType +
+ "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType +
"[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
)
}
@@ -325,10 +321,12 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
}
/**
- * Get serialized sizes of the RDD and its splits
+ * Get serialized sizes of the RDD and its splits, in order to test whether the size shrinks
+ * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
*/
def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
- (Utils.serialize(rdd).size, Utils.serialize(rdd.splits).size)
+ (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length,
+ Utils.serialize(rdd.splits).length)
}
/**
diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala
index dfa2de80e6..b2d0dd4627 100644
--- a/core/src/test/scala/spark/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/spark/ClosureCleanerSuite.scala
@@ -3,6 +3,7 @@ package spark
import java.io.NotSerializableException
import org.scalatest.FunSuite
+import spark.LocalSparkContext._
import SparkContext._
class ClosureCleanerSuite extends FunSuite {
@@ -43,13 +44,10 @@ object TestObject {
def run(): Int = {
var nonSer = new NonSerializable
var x = 5
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + x).reduce(_ + _)
- sc.stop()
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + x).reduce(_ + _)
+ }
}
}
@@ -60,11 +58,10 @@ class TestClass extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + getX).reduce(_ + _)
- sc.stop()
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + getX).reduce(_ + _)
+ }
}
}
@@ -73,11 +70,10 @@ class TestClassWithoutDefaultConstructor(x: Int) extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + getX).reduce(_ + _)
- sc.stop()
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + getX).reduce(_ + _)
+ }
}
}
@@ -89,11 +85,10 @@ class TestClassWithoutFieldAccess {
def run(): Int = {
var nonSer2 = new NonSerializable
var x = 5
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + x).reduce(_ + _)
- sc.stop()
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + x).reduce(_ + _)
+ }
}
}
@@ -102,16 +97,16 @@ object TestObjectWithNesting {
def run(): Int = {
var nonSer = new NonSerializable
var answer = 0
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- var y = 1
- for (i <- 1 to 4) {
- var nonSer2 = new NonSerializable
- var x = i
- answer += nums.map(_ + x + y).reduce(_ + _)
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ var y = 1
+ for (i <- 1 to 4) {
+ var nonSer2 = new NonSerializable
+ var x = i
+ answer += nums.map(_ + x + y).reduce(_ + _)
+ }
+ answer
}
- sc.stop()
- return answer
}
}
@@ -121,14 +116,14 @@ class TestClassWithNesting(val y: Int) extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
var answer = 0
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- for (i <- 1 to 4) {
- var nonSer2 = new NonSerializable
- var x = i
- answer += nums.map(_ + x + getY).reduce(_ + _)
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ for (i <- 1 to 4) {
+ var nonSer2 = new NonSerializable
+ var x = i
+ answer += nums.map(_ + x + getY).reduce(_ + _)
+ }
+ answer
}
- sc.stop()
- return answer
}
}
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index cacc2796b6..0e2585daa4 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -15,41 +15,28 @@ import scala.collection.mutable.ArrayBuffer
import SparkContext._
import storage.StorageLevel
-class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
+class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext {
val clusterUrl = "local-cluster[2,1,512]"
- @transient var sc: SparkContext = _
-
after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
System.clearProperty("spark.reducer.maxMbInFlight")
System.clearProperty("spark.storage.memoryFraction")
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
}
test("local-cluster format") {
sc = new SparkContext("local-cluster[2,1,512]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
+ resetSparkContext()
sc = new SparkContext("local-cluster[2 , 1 , 512]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
+ resetSparkContext()
sc = new SparkContext("local-cluster[2, 1, 512]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
+ resetSparkContext()
sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
- sc = null
+ resetSparkContext()
}
test("simple groupByKey") {
@@ -188,4 +175,73 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect()
assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE"))
}
+
+ test("recover from node failures") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ val data = sc.parallelize(Seq(true, true), 2)
+ assert(data.count === 2) // force executors to start
+ val masterId = SparkEnv.get.blockManager.blockManagerId
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ assert(data.map(failOnMarkedIdentity).collect.size === 2)
+ }
+
+ test("recover from repeated node failures during shuffle-map") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, false), 2)
+ assert(data.count === 2)
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2)
+ }
+ }
+
+ test("recover from repeated node failures during shuffle-reduce") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, true), 2)
+ assert(data.count === 2)
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ // This relies on mergeCombiners being used to perform the actual reduce for this
+ // test to actually be testing what it claims.
+ val grouped = data.map(x => x -> x).combineByKey(
+ x => x,
+ (x: Boolean, y: Boolean) => x,
+ (x: Boolean, y: Boolean) => failOnMarkedIdentity(x)
+ )
+ assert(grouped.collect.size === 1)
+ }
+ }
+}
+
+object DistributedSuite {
+ // Indicates whether this JVM is marked for failure.
+ var mark = false
+
+ // Set by test to remember if we are in the driver program so we can assert
+ // that we are not.
+ var amMaster = false
+
+ // Act like an identity function, but if the argument is true, set mark to true.
+ def markNodeIfIdentity(item: Boolean): Boolean = {
+ if (item) {
+ assert(!amMaster)
+ mark = true
+ }
+ item
+ }
+
+ // Act like an identity function, but if mark was set to true previously, fail,
+ // crashing the entire JVM.
+ def failOnMarkedIdentity(item: Boolean): Boolean = {
+ if (mark) {
+ System.exit(42)
+ }
+ item
+ }
}
diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala
new file mode 100644
index 0000000000..5e84b3a66a
--- /dev/null
+++ b/core/src/test/scala/spark/DriverSuite.scala
@@ -0,0 +1,33 @@
+package spark
+
+import java.io.File
+
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.prop.TableDrivenPropertyChecks._
+import org.scalatest.time.SpanSugar._
+
+class DriverSuite extends FunSuite with Timeouts {
+ test("driver should exit after finishing") {
+ assert(System.getenv("SPARK_HOME") != null)
+ // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
+ val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]"))
+ forAll(masters) { (master: String) =>
+ failAfter(30 seconds) {
+ Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master),
+ new File(System.getenv("SPARK_HOME")))
+ }
+ }
+ }
+}
+
+/**
+ * Program that creates a Spark driver but doesn't call SparkContext.stop() or
+ * Sys.exit() after finishing.
+ */
+object DriverWithoutCleanup {
+ def main(args: Array[String]) {
+ val sc = new SparkContext(args(0), "DriverWithoutCleanup")
+ sc.parallelize(1 to 100, 4).count()
+ }
+}
diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala
index a3454f25f6..8c1445a465 100644
--- a/core/src/test/scala/spark/FailureSuite.scala
+++ b/core/src/test/scala/spark/FailureSuite.scala
@@ -1,7 +1,6 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import org.scalatest.prop.Checkers
import scala.collection.mutable.ArrayBuffer
@@ -23,18 +22,7 @@ object FailureSuiteState {
}
}
-class FailureSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class FailureSuite extends FunSuite with LocalSparkContext {
// Run a 3-task map job in which task 1 deterministically fails once, and check
// whether the job completes successfully and we ran 4 tasks in total.
diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala
index b4283d9604..f1a35bced3 100644
--- a/core/src/test/scala/spark/FileServerSuite.scala
+++ b/core/src/test/scala/spark/FileServerSuite.scala
@@ -2,17 +2,16 @@ package spark
import com.google.common.io.Files
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import java.io.{File, PrintWriter, FileReader, BufferedReader}
import SparkContext._
-class FileServerSuite extends FunSuite with BeforeAndAfter {
+class FileServerSuite extends FunSuite with LocalSparkContext {
- @transient var sc: SparkContext = _
- @transient var tmpFile : File = _
- @transient var testJarFile : File = _
+ @transient var tmpFile: File = _
+ @transient var testJarFile: File = _
- before {
+ override def beforeEach() {
+ super.beforeEach()
// Create a sample text file
val tmpdir = new File(Files.createTempDir(), "test")
tmpdir.mkdir()
@@ -22,17 +21,12 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
pw.close()
}
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
+ override def afterEach() {
+ super.afterEach()
// Clean up downloaded file
if (tmpFile.exists) {
tmpFile.delete()
}
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
}
test("Distributing files locally") {
@@ -40,7 +34,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -54,7 +49,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile((new File(tmpFile.toString)).toURL.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -83,7 +79,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala
index 554bea53a9..91b48c7456 100644
--- a/core/src/test/scala/spark/FileSuite.scala
+++ b/core/src/test/scala/spark/FileSuite.scala
@@ -6,24 +6,12 @@ import scala.io.Source
import com.google.common.io.Files
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import org.apache.hadoop.io._
import SparkContext._
-class FileSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
-
+class FileSuite extends FunSuite with LocalSparkContext {
+
test("text files") {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 01351de4ae..934e4c2f67 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -46,7 +46,7 @@ public class JavaAPISuite implements Serializable {
sc.stop();
sc = null;
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port");
+ System.clearProperty("spark.driver.port");
}
static class ReverseIntComparator implements Comparator<Integer>, Serializable {
@@ -356,6 +356,34 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void mapsFromPairsToPairs() {
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs);
+
+ // Regression test for SPARK-668:
+ JavaPairRDD<String, Integer> swapped = pairRDD.flatMap(
+ new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() {
+ @Override
+ public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) throws Exception {
+ return Collections.singletonList(item.swap());
+ }
+ });
+ swapped.collect();
+
+ // There was never a bug here, but it's worth testing:
+ pairRDD.map(new PairFunction<Tuple2<Integer, String>, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(Tuple2<Integer, String> item) throws Exception {
+ return item.swap();
+ }
+ }).collect();
+ }
+
+ @Test
public void mapPartitions() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
JavaRDD<Integer> partitionSums = rdd.mapPartitions(
diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala
new file mode 100644
index 0000000000..ff00dd05dd
--- /dev/null
+++ b/core/src/test/scala/spark/LocalSparkContext.scala
@@ -0,0 +1,41 @@
+package spark
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterEach
+
+/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */
+trait LocalSparkContext extends BeforeAndAfterEach { self: Suite =>
+
+ @transient var sc: SparkContext = _
+
+ override def afterEach() {
+ resetSparkContext()
+ super.afterEach()
+ }
+
+ def resetSparkContext() = {
+ if (sc != null) {
+ LocalSparkContext.stop(sc)
+ sc = null
+ }
+ }
+
+}
+
+object LocalSparkContext {
+ def stop(sc: SparkContext) {
+ sc.stop()
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.driver.port")
+ }
+
+ /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */
+ def withSpark[T](sc: SparkContext)(f: SparkContext => T) = {
+ try {
+ f(sc)
+ } finally {
+ stop(sc)
+ }
+ }
+
+} \ No newline at end of file
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index d3dd3a8fa4..dd19442dcb 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -1,17 +1,13 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import akka.actor._
import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
import spark.util.AkkaUtils
-class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter {
- after {
- System.clearProperty("spark.master.port")
- }
+class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("compressSize") {
assert(MapOutputTracker.compressSize(0L) === 0)
@@ -47,13 +43,13 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
Array(compressedSize1000, compressedSize10000)))
- tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000)))
val statuses = tracker.getServerStatuses(10, 0)
- assert(statuses.toSeq === Seq((new BlockManagerId("hostA", 1000), size1000),
- (new BlockManagerId("hostB", 1000), size10000)))
+ assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000),
+ (BlockManagerId("b", "hostB", 1000), size10000)))
tracker.stop()
}
@@ -65,48 +61,51 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
- tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
// As if we had two simulatenous fetch failures
- tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
- tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
- // The remaining reduce task might try to grab the output dispite the shuffle failure;
+ // The remaining reduce task might try to grab the output despite the shuffle failure;
// this should cause it to fail, and the scheduler will ignore the failure due to the
// stage already being aborted.
intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) }
}
test("remote fetch") {
- System.clearProperty("spark.master.host")
- val (actorSystem, boundPort) =
- AkkaUtils.createActorSystem("test", "localhost", 0)
- System.setProperty("spark.master.port", boundPort.toString)
- val masterTracker = new MapOutputTracker(actorSystem, true)
- val slaveTracker = new MapOutputTracker(actorSystem, false)
- masterTracker.registerShuffle(10, 1)
- masterTracker.incrementGeneration()
- slaveTracker.updateGeneration(masterTracker.getGeneration)
- intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+ try {
+ System.clearProperty("spark.driver.host") // In case some previous test had set it
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
+ System.setProperty("spark.driver.port", boundPort.toString)
+ val masterTracker = new MapOutputTracker(actorSystem, true)
+ val slaveTracker = new MapOutputTracker(actorSystem, false)
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
- val compressedSize1000 = MapOutputTracker.compressSize(1000L)
- val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
- masterTracker.registerMapOutput(10, 0, new MapStatus(
- new BlockManagerId("hostA", 1000), Array(compressedSize1000)))
- masterTracker.incrementGeneration()
- slaveTracker.updateGeneration(masterTracker.getGeneration)
- assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((new BlockManagerId("hostA", 1000), size1000)))
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+ Seq((BlockManagerId("a", "hostA", 1000), size1000)))
- masterTracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
- masterTracker.incrementGeneration()
- slaveTracker.updateGeneration(masterTracker.getGeneration)
- intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+ masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
- // failure should be cached
- intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+ // failure should be cached
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+ } finally {
+ System.clearProperty("spark.driver.port")
+ }
}
}
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index eb3c8f238f..af1107cd19 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -1,25 +1,12 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import scala.collection.mutable.ArrayBuffer
import SparkContext._
-class PartitioningSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if(sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
-
+class PartitioningSuite extends FunSuite with LocalSparkContext {
test("HashPartitioner equality") {
val p2 = new HashPartitioner(2)
diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala
index 9b84b29227..a6344edf8f 100644
--- a/core/src/test/scala/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/spark/PipedRDDSuite.scala
@@ -1,21 +1,9 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import SparkContext._
-class PipedRDDSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class PipedRDDSuite extends FunSuite with LocalSparkContext {
test("basic pipe") {
sc = new SparkContext("local", "test")
@@ -51,5 +39,3 @@ class PipedRDDSuite extends FunSuite with BeforeAndAfter {
}
}
-
-
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index db217f8482..fe7deb10d6 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -2,32 +2,20 @@ package spark
import scala.collection.mutable.HashMap
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
+import spark.SparkContext._
+import spark.rdd.{CoalescedRDD, PartitionPruningRDD}
-import spark.rdd.CoalescedRDD
-import SparkContext._
-
-class RDDSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class RDDSuite extends FunSuite with LocalSparkContext {
test("basic operations") {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
- assert(dups.distinct.count === 4)
- assert(dups.distinct().collect === dups.distinct.collect)
- assert(dups.distinct(2).collect === dups.distinct.collect)
+ assert(dups.distinct().count() === 4)
+ assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses?
+ assert(dups.distinct.collect === dups.distinct().collect)
+ assert(dups.distinct(2).collect === dups.distinct().collect)
assert(nums.reduce(_ + _) === 10)
assert(nums.fold(0)(_ + _) === 10)
assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
@@ -44,6 +32,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
case(split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
}
assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7)))
+
+ intercept[UnsupportedOperationException] {
+ nums.filter(_ > 5).reduce(_ + _)
+ }
}
test("SparkContext.union") {
@@ -104,7 +96,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
}
test("caching with failures") {
- sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val onlySplit = new Split { override def index: Int = 0 }
var shouldFail = true
val rdd = new RDD[Int](sc, Nil) {
@@ -136,8 +128,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10)))
// Check that the narrow dependency is also specified correctly
- assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4))
- assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList ===
+ List(0, 1, 2, 3, 4))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList ===
+ List(5, 6, 7, 8, 9))
val coalesced2 = new CoalescedRDD(data, 3)
assert(coalesced2.collect().toList === (1 to 10).toList)
@@ -168,4 +162,15 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
nums.zip(sc.parallelize(1 to 4, 1)).collect()
}
}
+
+ test("partition pruning") {
+ sc = new SparkContext("local", "test")
+ val data = sc.parallelize(1 to 10, 10)
+ // Note that split number starts from 0, so > 8 means only 10th partition left.
+ val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
+ assert(prunedRdd.splits.size === 1)
+ val prunedData = prunedRdd.collect()
+ assert(prunedData.size === 1)
+ assert(prunedData(0) === 10)
+ }
}
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index bebb8ebe86..3493b9511f 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -3,7 +3,6 @@ package spark
import scala.collection.mutable.ArrayBuffer
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
@@ -15,18 +14,7 @@ import com.google.common.io.Files
import spark.rdd.ShuffledRDD
import spark.SparkContext._
-class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
test("groupByKey") {
sc = new SparkContext("local", "test")
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index 1ad11ff4c3..edb8c839fc 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -5,18 +5,7 @@ import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import SparkContext._
-class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with Logging {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers with Logging {
test("sortByKey") {
sc = new SparkContext("local", "test")
diff --git a/core/src/test/scala/spark/ThreadingSuite.scala b/core/src/test/scala/spark/ThreadingSuite.scala
index e9b1837d89..ff315b6693 100644
--- a/core/src/test/scala/spark/ThreadingSuite.scala
+++ b/core/src/test/scala/spark/ThreadingSuite.scala
@@ -22,19 +22,7 @@ object ThreadingSuiteState {
}
}
-class ThreadingSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if(sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
-
+class ThreadingSuite extends FunSuite with LocalSparkContext {
test("accessing SparkContext form a different thread") {
sc = new SparkContext("local", "test")
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
new file mode 100644
index 0000000000..83663ac702
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -0,0 +1,663 @@
+package spark.scheduler
+
+import scala.collection.mutable.{Map, HashMap}
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.TimeLimitedTests
+import org.scalatest.mock.EasyMockSugar
+import org.scalatest.time.{Span, Seconds}
+
+import org.easymock.EasyMock._
+import org.easymock.Capture
+import org.easymock.EasyMock
+import org.easymock.{IAnswer, IArgumentMatcher}
+
+import akka.actor.ActorSystem
+
+import spark.storage.BlockManager
+import spark.storage.BlockManagerId
+import spark.storage.BlockManagerMaster
+import spark.{Dependency, ShuffleDependency, OneToOneDependency}
+import spark.FetchFailedException
+import spark.MapOutputTracker
+import spark.RDD
+import spark.SparkContext
+import spark.SparkException
+import spark.Split
+import spark.TaskContext
+import spark.TaskEndReason
+
+import spark.{FetchFailed, Success}
+
+/**
+ * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
+ * rather than spawning an event loop thread as happens in the real code. They use EasyMock
+ * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are
+ * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead
+ * host notifications are sent). In addition, tests may check for side effects on a non-mocked
+ * MapOutputTracker instance.
+ *
+ * Tests primarily consist of running DAGScheduler#processEvent and
+ * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet)
+ * and capturing the resulting TaskSets from the mock TaskScheduler.
+ */
+class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests {
+
+ // impose a time limit on this test in case we don't let the job finish, in which case
+ // JobWaiter#getResult will hang.
+ override val timeLimit = Span(5, Seconds)
+
+ val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite")
+ var scheduler: DAGScheduler = null
+ val taskScheduler = mock[TaskScheduler]
+ val blockManagerMaster = mock[BlockManagerMaster]
+ var mapOutputTracker: MapOutputTracker = null
+ var schedulerThread: Thread = null
+ var schedulerException: Throwable = null
+
+ /**
+ * Set of EasyMock argument matchers that match a TaskSet for a given RDD.
+ * We cache these so we do not create duplicate matchers for the same RDD.
+ * This allows us to easily setup a sequence of expectations for task sets for
+ * that RDD.
+ */
+ val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher]
+
+ /**
+ * Set of cache locations to return from our mock BlockManagerMaster.
+ * Keys are (rdd ID, partition ID). Anything not present will return an empty
+ * list of cache locations silently.
+ */
+ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+
+ /**
+ * JobWaiter for the last JobSubmitted event we pushed. To keep tests (most of which
+ * will only submit one job) from needing to explicitly track it.
+ */
+ var lastJobWaiter: JobWaiter[Int] = null
+
+ /**
+ * Array into which we are accumulating the results from the last job asynchronously.
+ */
+ var lastJobResult: Array[Int] = null
+
+ /**
+ * Tell EasyMockSugar what mock objects we want to be configured by expecting {...}
+ * and whenExecuting {...} */
+ implicit val mocks = MockObjects(taskScheduler, blockManagerMaster)
+
+ /**
+ * Utility function to reset mocks and set expectations on them. EasyMock wants mock objects
+ * to be reset after each time their expectations are set, and we tend to check mock object
+ * calls over a single call to DAGScheduler.
+ *
+ * We also set a default expectation here that blockManagerMaster.getLocations can be called
+ * and will return values from cacheLocations.
+ */
+ def resetExpecting(f: => Unit) {
+ reset(taskScheduler)
+ reset(blockManagerMaster)
+ expecting {
+ expectGetLocations()
+ f
+ }
+ }
+
+ before {
+ taskSetMatchers.clear()
+ cacheLocations.clear()
+ val actorSystem = ActorSystem("test")
+ mapOutputTracker = new MapOutputTracker(actorSystem, true)
+ resetExpecting {
+ taskScheduler.setListener(anyObject())
+ }
+ whenExecuting {
+ scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null)
+ }
+ }
+
+ after {
+ assert(scheduler.processEvent(StopDAGScheduler))
+ resetExpecting {
+ taskScheduler.stop()
+ }
+ whenExecuting {
+ scheduler.stop()
+ }
+ sc.stop()
+ System.clearProperty("spark.master.port")
+ }
+
+ def makeBlockManagerId(host: String): BlockManagerId =
+ BlockManagerId("exec-" + host, host, 12345)
+
+ /**
+ * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
+ * This is a pair RDD type so it can always be used in ShuffleDependencies.
+ */
+ type MyRDD = RDD[(Int, Int)]
+
+ /**
+ * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+ * preferredLocations (if any) that are passed to them. They are deliberately not executable
+ * so we can test that DAGScheduler does not try to execute RDDs locally.
+ */
+ def makeRdd(
+ numSplits: Int,
+ dependencies: List[Dependency[_]],
+ locations: Seq[Seq[String]] = Nil
+ ): MyRDD = {
+ val maxSplit = numSplits - 1
+ return new MyRDD(sc, dependencies) {
+ override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
+ throw new RuntimeException("should not be reached")
+ override def getSplits() = (0 to maxSplit).map(i => new Split {
+ override def index = i
+ }).toArray
+ override def getPreferredLocations(split: Split): Seq[String] =
+ if (locations.isDefinedAt(split.index))
+ locations(split.index)
+ else
+ Nil
+ override def toString: String = "DAGSchedulerSuiteRDD " + id
+ }
+ }
+
+ /**
+ * EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task
+ * is from a particular RDD.
+ */
+ def taskSetForRdd(rdd: MyRDD): TaskSet = {
+ val matcher = taskSetMatchers.getOrElseUpdate(rdd,
+ new IArgumentMatcher {
+ override def matches(actual: Any): Boolean = {
+ val taskSet = actual.asInstanceOf[TaskSet]
+ taskSet.tasks(0) match {
+ case rt: ResultTask[_, _] => rt.rdd.id == rdd.id
+ case smt: ShuffleMapTask => smt.rdd.id == rdd.id
+ case _ => false
+ }
+ }
+ override def appendTo(buf: StringBuffer) {
+ buf.append("taskSetForRdd(" + rdd + ")")
+ }
+ })
+ EasyMock.reportMatcher(matcher)
+ return null
+ }
+
+ /**
+ * Setup an EasyMock expectation to repsond to blockManagerMaster.getLocations() called from
+ * cacheLocations.
+ */
+ def expectGetLocations(): Unit = {
+ EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])).
+ andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] {
+ override def answer(): Seq[Seq[BlockManagerId]] = {
+ val blocks = getCurrentArguments()(0).asInstanceOf[Array[String]]
+ return blocks.map { name =>
+ val pieces = name.split("_")
+ if (pieces(0) == "rdd") {
+ val key = pieces(1).toInt -> pieces(2).toInt
+ if (cacheLocations.contains(key)) {
+ cacheLocations(key)
+ } else {
+ Seq[BlockManagerId]()
+ }
+ } else {
+ Seq[BlockManagerId]()
+ }
+ }.toSeq
+ }
+ }).anyTimes()
+ }
+
+ /**
+ * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
+ * the scheduler not to exit.
+ *
+ * After processing the event, submit waiting stages as is done on most iterations of the
+ * DAGScheduler event loop.
+ */
+ def runEvent(event: DAGSchedulerEvent) {
+ assert(!scheduler.processEvent(event))
+ scheduler.submitWaitingStages()
+ }
+
+ /**
+ * Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be
+ * called from a resetExpecting { ... } block.
+ *
+ * Returns a easymock Capture that will contain the task set after the stage is submitted.
+ * Most tests should use interceptStage() instead of this directly.
+ */
+ def expectStage(rdd: MyRDD): Capture[TaskSet] = {
+ val taskSetCapture = new Capture[TaskSet]
+ taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd)))
+ return taskSetCapture
+ }
+
+ /**
+ * Expect the supplied code snippet to submit a stage for the specified RDD.
+ * Return the resulting TaskSet. First marks all the tasks are belonging to the
+ * current MapOutputTracker generation.
+ */
+ def interceptStage(rdd: MyRDD)(f: => Unit): TaskSet = {
+ var capture: Capture[TaskSet] = null
+ resetExpecting {
+ capture = expectStage(rdd)
+ }
+ whenExecuting {
+ f
+ }
+ val taskSet = capture.getValue
+ for (task <- taskSet.tasks) {
+ task.generation = mapOutputTracker.getGeneration
+ }
+ return taskSet
+ }
+
+ /**
+ * Send the given CompletionEvent messages for the tasks in the TaskSet.
+ */
+ def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
+ assert(taskSet.tasks.size >= results.size)
+ for ((result, i) <- results.zipWithIndex) {
+ if (i < taskSet.tasks.size) {
+ runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()))
+ }
+ }
+ }
+
+ /**
+ * Assert that the supplied TaskSet has exactly the given preferredLocations.
+ */
+ def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
+ assert(locations.size === taskSet.tasks.size)
+ for ((expectLocs, taskLocs) <-
+ taskSet.tasks.map(_.preferredLocations).zip(locations)) {
+ assert(expectLocs === taskLocs)
+ }
+ }
+
+ /**
+ * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+ * below, we do not expect this function to ever be executed; instead, we will return results
+ * directly through CompletionEvents.
+ */
+ def jobComputeFunc(context: TaskContext, it: Iterator[(Int, Int)]): Int =
+ it.next._1.asInstanceOf[Int]
+
+
+ /**
+ * Start a job to compute the given RDD. Returns the JobWaiter that will
+ * collect the result of the job via callbacks from DAGScheduler.
+ */
+ def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): (JobWaiter[Int], Array[Int]) = {
+ val resultArray = new Array[Int](rdd.splits.size)
+ val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int](
+ rdd,
+ jobComputeFunc,
+ (0 to (rdd.splits.size - 1)),
+ "test-site",
+ allowLocal,
+ (i: Int, value: Int) => resultArray(i) = value
+ )
+ lastJobWaiter = waiter
+ lastJobResult = resultArray
+ runEvent(toSubmit)
+ return (waiter, resultArray)
+ }
+
+ /**
+ * Assert that a job we started has failed.
+ */
+ def expectJobException(waiter: JobWaiter[Int] = lastJobWaiter) {
+ waiter.awaitResult() match {
+ case JobSucceeded => fail()
+ case JobFailed(_) => return
+ }
+ }
+
+ /**
+ * Assert that a job we started has succeeded and has the given result.
+ */
+ def expectJobResult(expected: Array[Int], waiter: JobWaiter[Int] = lastJobWaiter,
+ result: Array[Int] = lastJobResult) {
+ waiter.awaitResult match {
+ case JobSucceeded =>
+ assert(expected === result)
+ case JobFailed(_) =>
+ fail()
+ }
+ }
+
+ def makeMapStatus(host: String, reduces: Int): MapStatus =
+ new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
+
+ test("zero split job") {
+ val rdd = makeRdd(0, Nil)
+ var numResults = 0
+ def accumulateResult(partition: Int, value: Int) {
+ numResults += 1
+ }
+ scheduler.runJob(rdd, jobComputeFunc, Seq(), "test-site", false, accumulateResult)
+ assert(numResults === 0)
+ }
+
+ test("run trivial job") {
+ val rdd = makeRdd(1, Nil)
+ val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("local job") {
+ val rdd = new MyRDD(sc, Nil) {
+ override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
+ Array(42 -> 0).iterator
+ override def getSplits() = Array( new Split { override def index = 0 } )
+ override def getPreferredLocations(split: Split) = Nil
+ override def toString = "DAGSchedulerSuite Local RDD"
+ }
+ submitRdd(rdd, true)
+ expectJobResult(Array(42))
+ }
+
+ test("run trivial job w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cache location preferences w/ dependency") {
+ val baseRdd = makeRdd(1, Nil)
+ val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+ cacheLocations(baseRdd.id -> 0) =
+ Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
+ val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
+ expectTaskSetLocations(taskSet, List(Seq("hostA", "hostB")))
+ respondToTaskSet(taskSet, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("trivial job failure") {
+ val rdd = makeRdd(1, Nil)
+ val taskSet = interceptStage(rdd) { submitRdd(rdd) }
+ runEvent(TaskSetFailed(taskSet, "test failure"))
+ expectJobException()
+ }
+
+ test("run trivial shuffle") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val secondStage = interceptStage(reduceRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ respondToTaskSet(secondStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("run trivial shuffle with fetch failure") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val secondStage = interceptStage(reduceRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(secondStage, List(
+ (Success, 42),
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)
+ ))
+ }
+ val thirdStage = interceptStage(shuffleMapRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ val fourthStage = interceptStage(reduceRdd) {
+ respondToTaskSet(thirdStage, List( (Success, makeMapStatus("hostA", 1)) ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ respondToTaskSet(fourthStage, List( (Success, 43) ))
+ expectJobResult(Array(42, 43))
+ }
+
+ test("ignore late map task completions") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(2, List(shuffleDep))
+
+ val taskSet = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ val oldGeneration = mapOutputTracker.getGeneration
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ runEvent(ExecutorLost("exec-hostA"))
+ }
+ val newGeneration = mapOutputTracker.getGeneration
+ assert(newGeneration > oldGeneration)
+ val noAccum = Map[Long, Any]()
+ // We rely on the event queue being ordered and increasing the generation number by 1
+ // should be ignored for being too old
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+ // should work because it's a non-failed host
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum))
+ // should be ignored for being too old
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+ taskSet.tasks(1).generation = newGeneration
+ val secondStage = interceptStage(reduceRdd) {
+ runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+ respondToTaskSet(secondStage, List( (Success, 42), (Success, 43) ))
+ expectJobResult(Array(42, 43))
+ }
+
+ test("run trivial shuffle with out-of-band failure and retry") {
+ val shuffleMapRdd = makeRdd(2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = makeRdd(1, List(shuffleDep))
+
+ val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ runEvent(ExecutorLost("exec-hostA"))
+ }
+ // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
+ // rather than marking it is as failed and waiting.
+ val secondStage = interceptStage(shuffleMapRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ val thirdStage = interceptStage(reduceRdd) {
+ respondToTaskSet(secondStage, List(
+ (Success, makeMapStatus("hostC", 1))
+ ))
+ }
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ respondToTaskSet(thirdStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("recursive shuffle failures") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ val secondStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val thirdStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostC", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(thirdStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ val recomputeOne = interceptStage(shuffleOneRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ val recomputeTwo = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(recomputeOne, List(
+ (Success, makeMapStatus("hostA", 2))
+ ))
+ }
+ val finalStage = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwo, List(
+ (Success, makeMapStatus("hostA", 1))
+ ))
+ }
+ respondToTaskSet(finalStage, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cached post-shuffle") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+ cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+ val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstShuffleStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val reduceStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondShuffleStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(reduceStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
+ val recomputeTwo = interceptStage(shuffleTwoRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ expectTaskSetLocations(recomputeTwo, Seq(Seq("hostD")))
+ val finalRetry = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwo, List(
+ (Success, makeMapStatus("hostD", 1))
+ ))
+ }
+ respondToTaskSet(finalRetry, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+
+ test("cached post-shuffle but fails") {
+ val shuffleOneRdd = makeRdd(2, Nil)
+ val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+ val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+ val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+ val finalRdd = makeRdd(1, List(shuffleDepTwo))
+
+ val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+ cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+ cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+ val secondShuffleStage = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(firstShuffleStage, List(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))
+ ))
+ }
+ val reduceStage = interceptStage(finalRdd) {
+ respondToTaskSet(secondShuffleStage, List(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))
+ ))
+ }
+ resetExpecting {
+ blockManagerMaster.removeExecutor("exec-hostA")
+ }
+ whenExecuting {
+ respondToTaskSet(reduceStage, List(
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
+ ))
+ }
+ val recomputeTwoCached = interceptStage(shuffleTwoRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ expectTaskSetLocations(recomputeTwoCached, Seq(Seq("hostD")))
+ intercept[FetchFailedException]{
+ mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0)
+ }
+
+ // Simulate the shuffle input data failing to be cached.
+ cacheLocations.remove(shuffleTwoRdd.id -> 0)
+ respondToTaskSet(recomputeTwoCached, List(
+ (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null)
+ ))
+
+ // After the fetch failure, DAGScheduler should recheck the cache and decide to resubmit
+ // everything.
+ val recomputeOne = interceptStage(shuffleOneRdd) {
+ scheduler.resubmitFailedStages()
+ }
+ // We use hostA here to make sure DAGScheduler doesn't think it's still dead.
+ val recomputeTwoUncached = interceptStage(shuffleTwoRdd) {
+ respondToTaskSet(recomputeOne, List( (Success, makeMapStatus("hostA", 1)) ))
+ }
+ expectTaskSetLocations(recomputeTwoUncached, Seq(Seq[String]()))
+ val finalRetry = interceptStage(finalRdd) {
+ respondToTaskSet(recomputeTwoUncached, List( (Success, makeMapStatus("hostA", 1)) ))
+
+ }
+ respondToTaskSet(finalRetry, List( (Success, 42) ))
+ expectJobResult(Array(42))
+ }
+}
diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
index ba6f8b588f..a5db7103f5 100644
--- a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
@@ -6,19 +6,9 @@ import spark.TaskContext
import spark.RDD
import spark.SparkContext
import spark.Split
+import spark.LocalSparkContext
-class TaskContextSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
test("Calls executeOnCompleteCallbacks after failure") {
var completed = false
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index 8f86e3170e..2d177bbf67 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -69,33 +69,41 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("StorageLevel object caching") {
- val level1 = new StorageLevel(false, false, false, 3)
- val level2 = new StorageLevel(false, false, false, 3)
+ val level1 = StorageLevel(false, false, false, 3)
+ val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1
+ val level3 = StorageLevel(false, false, false, 2) // this should return a different object
+ assert(level2 === level1, "level2 is not same as level1")
+ assert(level2.eq(level1), "level2 is not the same object as level1")
+ assert(level3 != level1, "level3 is same as level1")
val bytes1 = spark.Utils.serialize(level1)
val level1_ = spark.Utils.deserialize[StorageLevel](bytes1)
val bytes2 = spark.Utils.serialize(level2)
val level2_ = spark.Utils.deserialize[StorageLevel](bytes2)
assert(level1_ === level1, "Deserialized level1 not same as original level1")
- assert(level2_ === level2, "Deserialized level2 not same as original level1")
- assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2")
- assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1")
+ assert(level1_.eq(level1), "Deserialized level1 not the same object as original level2")
+ assert(level2_ === level2, "Deserialized level2 not same as original level2")
+ assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1")
}
test("BlockManagerId object caching") {
- val id1 = new StorageLevel(false, false, false, 3)
- val id2 = new StorageLevel(false, false, false, 3)
+ val id1 = BlockManagerId("e1", "XXX", 1)
+ val id2 = BlockManagerId("e1", "XXX", 1) // this should return the same object as id1
+ val id3 = BlockManagerId("e1", "XXX", 2) // this should return a different object
+ assert(id2 === id1, "id2 is not same as id1")
+ assert(id2.eq(id1), "id2 is not the same object as id1")
+ assert(id3 != id1, "id3 is same as id1")
val bytes1 = spark.Utils.serialize(id1)
- val id1_ = spark.Utils.deserialize[StorageLevel](bytes1)
+ val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1)
val bytes2 = spark.Utils.serialize(id2)
- val id2_ = spark.Utils.deserialize[StorageLevel](bytes2)
- assert(id1_ === id1, "Deserialized id1 not same as original id1")
- assert(id2_ === id2, "Deserialized id2 not same as original id1")
- assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2")
- assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized level1")
+ val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2)
+ assert(id1_ === id1, "Deserialized id1 is not same as original id1")
+ assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1")
+ assert(id2_ === id2, "Deserialized id2 is not same as original id2")
+ assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1")
}
test("master + 1 manager interaction") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -125,8 +133,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("master + 2 managers interaction") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
- store2 = new BlockManager(actorSystem, master, new KryoSerializer, 2000)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000)
+ store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer, 2000)
val peers = master.getPeers(store.blockManagerId, 1)
assert(peers.size === 1, "master did not return the other manager as a peer")
@@ -141,7 +149,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("removing block") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -190,7 +198,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("reregistration on heart beat") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
@@ -198,7 +206,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
assert(store.getSingle("a1") != None, "a1 was not in store")
assert(master.getLocations("a1").size > 0, "master was not told about a1")
- master.notifyADeadHost(store.blockManagerId.ip)
+ master.removeExecutor(store.blockManagerId.executorId)
assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
store invokePrivate heartBeat()
@@ -206,25 +214,63 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("reregistration on block update") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
-
assert(master.getLocations("a1").size > 0, "master was not told about a1")
- master.notifyADeadHost(store.blockManagerId.ip)
+ master.removeExecutor(store.blockManagerId.executorId)
assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY)
+ store.waitForAsyncReregister()
assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
assert(master.getLocations("a2").size > 0, "master was not told about a2")
}
+ test("reregistration doesn't dead lock") {
+ val heartBeat = PrivateMethod[Unit]('heartBeat)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = List(new Array[Byte](400))
+
+ // try many times to trigger any deadlocks
+ for (i <- 1 to 100) {
+ master.removeExecutor(store.blockManagerId.executorId)
+ val t1 = new Thread {
+ override def run() {
+ store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true)
+ }
+ }
+ val t2 = new Thread {
+ override def run() {
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+ }
+ }
+ val t3 = new Thread {
+ override def run() {
+ store invokePrivate heartBeat()
+ }
+ }
+
+ t1.start()
+ t2.start()
+ t3.start()
+ t1.join()
+ t2.join()
+ t3.join()
+
+ store.dropFromMemory("a1", null)
+ store.dropFromMemory("a2", null)
+ store.waitForAsyncReregister()
+ }
+ }
+
test("in-memory LRU storage") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -243,7 +289,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU storage with serialization") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -262,14 +308,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of same RDD") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY)
- // Even though we accessed rdd_0_3 last, it should not have replaced partitiosn 1 and 2
+ // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2
// from the same RDD
assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store")
@@ -281,7 +327,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of multiple RDDs") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
@@ -304,7 +350,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("on-disk storage") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -317,7 +363,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -332,7 +378,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with getLocalBytes") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -347,7 +393,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -362,7 +408,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization and getLocalBytes") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -377,7 +423,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -402,7 +448,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU with streams") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -426,7 +472,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels and streams") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -472,7 +518,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("overly large block") {
- store = new BlockManager(actorSystem, master, serializer, 500)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 500)
store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.getSingle("a1") === None, "a1 was in store")
store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK)
@@ -483,49 +529,49 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("block compression") {
try {
System.setProperty("spark.shuffle.compress", "true")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000)
store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.shuffle.compress", "false")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec2", actorSystem, master, serializer, 2000)
store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed")
store.stop()
store = null
System.setProperty("spark.broadcast.compress", "true")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec3", actorSystem, master, serializer, 2000)
store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.broadcast.compress", "false")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec4", actorSystem, master, serializer, 2000)
store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed")
store.stop()
store = null
System.setProperty("spark.rdd.compress", "true")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec5", actorSystem, master, serializer, 2000)
store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.rdd.compress", "false")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec6", actorSystem, master, serializer, 2000)
store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed")
store.stop()
store = null
// Check that any other block types are also kept uncompressed
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec7", actorSystem, master, serializer, 2000)
store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed")
store.stop()
diff --git a/docs/configuration.md b/docs/configuration.md
index 87cb4a6797..a7054b4321 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -198,25 +198,41 @@ Apart from these, the following properties are also available, and may be useful
</td>
</tr>
<tr>
+ <td>spark.akka.frameSize</td>
+ <td>10</td>
+ <td>
+ Maximum message size to allow in "control plane" communication (for serialized tasks and task
+ results), in MB. Increase this if your tasks need to send back large results to the driver
+ (e.g. using <code>collect()</code> on a large dataset).
+ </td>
+</tr>
+<tr>
<td>spark.akka.threads</td>
<td>4</td>
<td>
Number of actor threads to use for communication. Can be useful to increase on large clusters
- when the master has a lot of CPU cores.
+ when the driver has a lot of CPU cores.
+ </td>
+</tr>
+<tr>
+ <td>spark.akka.timeout</td>
+ <td>20</td>
+ <td>
+ Communication timeout between Spark nodes.
</td>
</tr>
<tr>
- <td>spark.master.host</td>
+ <td>spark.driver.host</td>
<td>(local hostname)</td>
<td>
- Hostname or IP address for the master to listen on.
+ Hostname or IP address for the driver to listen on.
</td>
</tr>
<tr>
- <td>spark.master.port</td>
+ <td>spark.driver.port</td>
<td>(random)</td>
<td>
- Port for the master to listen on.
+ Port for the driver to listen on.
</td>
</tr>
<tr>
diff --git a/docs/java-programming-guide.md b/docs/java-programming-guide.md
index 188ca4995e..37a906ea1c 100644
--- a/docs/java-programming-guide.md
+++ b/docs/java-programming-guide.md
@@ -75,7 +75,8 @@ class has a single abstract method, `call()`, that must be implemented.
## Storage Levels
RDD [storage level](scala-programming-guide.html#rdd-persistence) constants, such as `MEMORY_AND_DISK`, are
-declared in the [spark.api.java.StorageLevels](api/core/index.html#spark.api.java.StorageLevels) class.
+declared in the [spark.api.java.StorageLevels](api/core/index.html#spark.api.java.StorageLevels) class. To
+define your own storage level, you can use StorageLevels.create(...).
# Other Features
diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
index a840b9b34b..4e84d23edf 100644
--- a/docs/python-programming-guide.md
+++ b/docs/python-programming-guide.md
@@ -67,13 +67,20 @@ The script automatically adds the `pyspark` package to the `PYTHONPATH`.
# Interactive Use
-The `pyspark` script launches a Python interpreter that is configured to run PySpark jobs.
-When run without any input files, `pyspark` launches a shell that can be used explore data interactively, which is a simple way to learn the API:
+The `pyspark` script launches a Python interpreter that is configured to run PySpark jobs. To use `pyspark` interactively, first build Spark, then launch it directly from the command line without any options:
+
+{% highlight bash %}
+$ sbt/sbt package
+$ ./pyspark
+{% endhighlight %}
+
+The Python shell can be used explore data interactively and is a simple way to learn the API:
{% highlight python %}
>>> words = sc.textFile("/usr/share/dict/words")
>>> words.filter(lambda w: w.startswith("spar")).take(5)
[u'spar', u'sparable', u'sparada', u'sparadrap', u'sparagrass']
+>>> help(pyspark) # Show all pyspark functions
{% endhighlight %}
By default, the `pyspark` shell creates SparkContext that runs jobs locally.
diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md
index 7350eca837..301b330a79 100644
--- a/docs/scala-programming-guide.md
+++ b/docs/scala-programming-guide.md
@@ -301,7 +301,8 @@ We recommend going through the following process to select one:
* Use the replicated storage levels if you want fast fault recovery (e.g. if using Spark to serve requests from a web
application). *All* the storage levels provide full fault tolerance by recomputing lost data, but the replicated ones
let you continue running tasks on the RDD without waiting to recompute a lost partition.
-
+
+If you want to define your own storage level (say, with replication factor of 3 instead of 2), then use the function factor method `apply()` of the [`StorageLevel`](api/core/index.html#spark.storage.StorageLevel$) singleton object.
# Shared Variables
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index e0ba7c35cb..bf296221b8 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -51,11 +51,11 @@ Finally, the following configuration options can be passed to the master and wor
</tr>
<tr>
<td><code>-c CORES</code>, <code>--cores CORES</code></td>
- <td>Number of CPU cores to use (default: all available); only on worker</td>
+ <td>Total CPU cores to allow Spark jobs to use on the machine (default: all available); only on worker</td>
</tr>
<tr>
<td><code>-m MEM</code>, <code>--memory MEM</code></td>
- <td>Amount of memory to use, in a format like 1000M or 2G (default: your machine's total RAM minus 1 GB); only on worker</td>
+ <td>Total amount of memory to allow Spark jobs to use on the machine, in a format like 1000M or 2G (default: your machine's total RAM minus 1 GB); only on worker</td>
</tr>
<tr>
<td><code>-d DIR</code>, <code>--work-dir DIR</code></td>
@@ -66,9 +66,20 @@ Finally, the following configuration options can be passed to the master and wor
# Cluster Launch Scripts
-To launch a Spark standalone cluster with the deploy scripts, you need to set up two files, `conf/spark-env.sh` and `conf/slaves`. The `conf/spark-env.sh` file lets you specify global settings for the master and slave instances, such as memory, or port numbers to bind to, while `conf/slaves` is a list of slave nodes. The system requires that all the slave machines have the same configuration files, so *copy these files to each machine*.
+To launch a Spark standalone cluster with the deploy scripts, you need to create a file called `conf/slaves` in your Spark directory, which should contain the hostnames of all the machines where you would like to start Spark workers, one per line. The master machine must be able to access each of the slave machines via password-less `ssh` (using a private key). For testing, you can just put `localhost` in this file.
-In `conf/spark-env.sh`, you can set the following parameters, in addition to the [standard Spark configuration settings](configuration.html):
+Once you've set up this fine, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`:
+
+- `bin/start-master.sh` - Starts a master instance on the machine the script is executed on.
+- `bin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file.
+- `bin/start-all.sh` - Starts both a master and a number of slaves as described above.
+- `bin/stop-master.sh` - Stops the master that was started via the `bin/start-master.sh` script.
+- `bin/stop-slaves.sh` - Stops the slave instances that were started via `bin/start-slaves.sh`.
+- `bin/stop-all.sh` - Stops both the master and the slaves as described above.
+
+Note that these scripts must be executed on the machine you want to run the Spark master on, not your local machine.
+
+You can optionally configure the cluster further by setting environment variables in `conf/spark-env.sh`. Create this file by starting with the `conf/spark-env.sh.template`, and _copy it to all your worker machines_ for the settings to take effect. The following settings are available:
<table class="table">
<tr><th style="width:21%">Environment Variable</th><th>Meaning</th></tr>
@@ -89,35 +100,23 @@ In `conf/spark-env.sh`, you can set the following parameters, in addition to the
<td>Start the Spark worker on a specific port (default: random)</td>
</tr>
<tr>
+ <td><code>SPARK_WORKER_DIR</code></td>
+ <td>Directory to run jobs in, which will include both logs and scratch space (default: SPARK_HOME/work)</td>
+ </tr>
+ <tr>
<td><code>SPARK_WORKER_CORES</code></td>
- <td>Number of cores to use (default: all available cores)</td>
+ <td>Total number of cores to allow Spark jobs to use on the machine (default: all available cores)</td>
</tr>
<tr>
<td><code>SPARK_WORKER_MEMORY</code></td>
- <td>How much memory to use, e.g. 1000M, 2G (default: total memory minus 1 GB)</td>
+ <td>Total amount of memory to allow Spark jobs to use on the machine, e.g. 1000M, 2G (default: total memory minus 1 GB); note that each job's <i>individual</i> memory is configured using <code>SPARK_MEM</code></td>
</tr>
<tr>
<td><code>SPARK_WORKER_WEBUI_PORT</code></td>
<td>Port for the worker web UI (default: 8081)</td>
</tr>
- <tr>
- <td><code>SPARK_WORKER_DIR</code></td>
- <td>Directory to run jobs in, which will include both logs and scratch space (default: SPARK_HOME/work)</td>
- </tr>
</table>
-In `conf/slaves`, include a list of all machines where you would like to start a Spark worker, one per line. The master machine must be able to access each of the slave machines via password-less `ssh` (using a private key). For testing purposes, you can have a single `localhost` entry in the slaves file.
-
-Once you've set up these configuration files, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`:
-
-- `bin/start-master.sh` - Starts a master instance on the machine the script is executed on.
-- `bin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file.
-- `bin/start-all.sh` - Starts both a master and a number of slaves as described above.
-- `bin/stop-master.sh` - Stops the master that was started via the `bin/start-master.sh` script.
-- `bin/stop-slaves.sh` - Stops the slave instances that were started via `bin/start-slaves.sh`.
-- `bin/stop-all.sh` - Stops both the master and the slaves as described above.
-
-Note that the scripts must be executed on the machine you want to run the Spark master on, not your local machine.
# Connecting a Job to the Cluster
diff --git a/examples/pom.xml b/examples/pom.xml
index 3355deb6b7..f43af670c6 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -19,6 +19,11 @@
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.twitter4j</groupId>
+ <artifactId>twitter4j-stream</artifactId>
+ <version>3.0.3</version>
+ </dependency>
<dependency>
<groupId>org.scalatest</groupId>
@@ -45,11 +50,6 @@
<profiles>
<profile>
<id>hadoop1</id>
- <activation>
- <property>
- <name>!hadoopVersion</name>
- </property>
- </activation>
<dependencies>
<dependency>
<groupId>org.spark-project</groupId>
@@ -58,6 +58,12 @@
<classifier>hadoop1</classifier>
</dependency>
<dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-streaming</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop1</classifier>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
@@ -77,12 +83,6 @@
</profile>
<profile>
<id>hadoop2</id>
- <activation>
- <property>
- <name>hadoopVersion</name>
- <value>2</value>
- </property>
- </activation>
<dependencies>
<dependency>
<groupId>org.spark-project</groupId>
@@ -91,6 +91,12 @@
<classifier>hadoop2</classifier>
</dependency>
<dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-streaming</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2</classifier>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
diff --git a/pom.xml b/pom.xml
index 751189a9d8..7e06cae052 100644
--- a/pom.xml
+++ b/pom.xml
@@ -41,6 +41,7 @@
<module>core</module>
<module>bagel</module>
<module>examples</module>
+ <module>streaming</module>
<module>repl</module>
<module>repl-bin</module>
</modules>
@@ -104,6 +105,17 @@
<enabled>false</enabled>
</snapshots>
</repository>
+ <repository>
+ <id>twitter4j-repo</id>
+ <name>Twitter4J Repository</name>
+ <url>http://twitter4j.org/maven2/</url>
+ <releases>
+ <enabled>true</enabled>
+ </releases>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
+ </repository>
</repositories>
<pluginRepositories>
<pluginRepository>
@@ -262,6 +274,12 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.easymock</groupId>
+ <artifactId>easymock</artifactId>
+ <version>3.1</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.version}</artifactId>
<version>1.9</version>
@@ -487,11 +505,6 @@
<profiles>
<profile>
<id>hadoop1</id>
- <activation>
- <property>
- <name>!hadoopVersion</name>
- </property>
- </activation>
<properties>
<hadoop.major.version>1</hadoop.major.version>
@@ -509,12 +522,6 @@
<profile>
<id>hadoop2</id>
- <activation>
- <property>
- <name>hadoopVersion</name>
- <value>2</value>
- </property>
- </activation>
<properties>
<hadoop.major.version>2</hadoop.major.version>
</properties>
@@ -530,6 +537,17 @@
<artifactId>hadoop-client</artifactId>
<version>2.0.0-mr1-cdh${cdh.version}</version>
</dependency>
+ <!-- Specify Avro version because Kafka also has it as a dependency -->
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro</artifactId>
+ <version>1.7.1.cloudera.2</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro-ipc</artifactId>
+ <version>1.7.1.cloudera.2</version>
+ </dependency>
</dependencies>
</dependencyManagement>
</profile>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 3dbb993f9c..af8b5ba017 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -21,7 +21,7 @@ object SparkBuild extends Build {
lazy val core = Project("core", file("core"), settings = coreSettings)
- lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core)
+ lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core) dependsOn (streaming)
lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core) dependsOn (streaming)
@@ -93,7 +93,7 @@ object SparkBuild extends Build {
"org.scalatest" %% "scalatest" % "1.8" % "test",
"org.scalacheck" %% "scalacheck" % "1.9" % "test",
"com.novocode" % "junit-interface" % "0.8" % "test",
- "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile"
+ "org.easymock" % "easymock" % "3.1" % "test"
),
parallelExecution := false,
/* Workaround for issue #206 (fixed after SBT 0.11.0) */
@@ -136,8 +136,6 @@ object SparkBuild extends Build {
"com.typesafe.akka" % "akka-slf4j" % "2.0.3",
"it.unimi.dsi" % "fastutil" % "6.4.4",
"colt" % "colt" % "1.2.0",
- "org.twitter4j" % "twitter4j-core" % "3.0.2",
- "org.twitter4j" % "twitter4j-stream" % "3.0.2",
"cc.spray" % "spray-can" % "1.0-M2.1",
"cc.spray" % "spray-server" % "1.0-M2.1",
"cc.spray" %% "spray-json" % "1.1.1",
@@ -156,7 +154,10 @@ object SparkBuild extends Build {
)
def examplesSettings = sharedSettings ++ Seq(
- name := "spark-examples"
+ name := "spark-examples",
+ libraryDependencies ++= Seq(
+ "org.twitter4j" % "twitter4j-stream" % "3.0.3"
+ )
)
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
@@ -164,7 +165,9 @@ object SparkBuild extends Build {
def streamingSettings = sharedSettings ++ Seq(
name := "spark-streaming",
libraryDependencies ++= Seq(
- "com.github.sgroschupf" % "zkclient" % "0.1")
+ "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile",
+ "com.github.sgroschupf" % "zkclient" % "0.1"
+ )
) ++ assemblySettings ++ extraAssemblySettings
def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq(
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 91ac984ba2..45102cd9fe 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -16,4 +16,4 @@ target: docs/
private: no
exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
- pyspark.java_gateway pyspark.examples pyspark.shell
+ pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 00666bc0a3..3e8bca62f0 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -11,6 +11,8 @@ Public classes:
A broadcast variable that gets reused across tasks.
- L{Accumulator<pyspark.accumulators.Accumulator>}
An "add-only" shared variable that tasks can only add values to.
+ - L{SparkFiles<pyspark.files.SparkFiles>}
+ Access files shipped with jobs.
"""
import sys
import os
@@ -19,6 +21,7 @@ sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.eg
from pyspark.context import SparkContext
from pyspark.rdd import RDD
+from pyspark.files import SparkFiles
-__all__ = ["SparkContext", "RDD"]
+__all__ = ["SparkContext", "RDD", "SparkFiles"]
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index c00c3a37af..3e9d7d36da 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -11,6 +11,12 @@
>>> a.value
7
+>>> sc.accumulator(1.0).value
+1.0
+
+>>> sc.accumulator(1j).value
+1j
+
>>> rdd = sc.parallelize([1,2,3])
>>> def f(x):
... global a
@@ -19,7 +25,8 @@
>>> a.value
13
->>> class VectorAccumulatorParam(object):
+>>> from pyspark.accumulators import AccumulatorParam
+>>> class VectorAccumulatorParam(AccumulatorParam):
... def zero(self, value):
... return [0.0] * len(value)
... def addInPlace(self, val1, val2):
@@ -84,8 +91,7 @@ class Accumulator(object):
While C{SparkContext} supports accumulators for primitive data types like C{int} and
C{float}, users can also define accumulators for custom types by providing a custom
- C{AccumulatorParam} object with a C{zero} and C{addInPlace} method. Refer to the doctest
- of this module for an example.
+ L{AccumulatorParam} object. Refer to the doctest of this module for an example.
"""
def __init__(self, aid, value, accum_param):
@@ -124,8 +130,31 @@ class Accumulator(object):
def __str__(self):
return str(self._value)
+ def __repr__(self):
+ return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)
-class AddingAccumulatorParam(object):
+
+class AccumulatorParam(object):
+ """
+ Helper object that defines how to accumulate values of a given type.
+ """
+
+ def zero(self, value):
+ """
+ Provide a "zero value" for the type, compatible in dimensions with the
+ provided C{value} (e.g., a zero vector)
+ """
+ raise NotImplementedError
+
+ def addInPlace(self, value1, value2):
+ """
+ Add two values of the accumulator's data type, returning a new value;
+ for efficiency, can also update C{value1} in place and return it.
+ """
+ raise NotImplementedError
+
+
+class AddingAccumulatorParam(AccumulatorParam):
"""
An AccumulatorParam that uses the + operators to add values. Designed for simple types
such as integers, floats, and lists. Requires the zero value for the underlying type
@@ -145,7 +174,7 @@ class AddingAccumulatorParam(object):
# Singleton accumulator params for some standard types
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
-DOUBLE_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
+FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
@@ -167,12 +196,3 @@ def _start_update_server():
thread.daemon = True
thread.start()
return server
-
-
-def _test():
- import doctest
- doctest.testmod()
-
-
-if __name__ == "__main__":
- _test()
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 93876fa738..def810dd46 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -37,12 +37,3 @@ class Broadcast(object):
def __reduce__(self):
self._pickle_registry.add(self)
return (_from_id, (self.bid, ))
-
-
-def _test():
- import doctest
- doctest.testmod()
-
-
-if __name__ == "__main__":
- _test()
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 1e2f845f9c..657fe6f989 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1,10 +1,13 @@
import os
-import atexit
+import shutil
+import sys
+from threading import Lock
from tempfile import NamedTemporaryFile
from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
+from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length, batched
from pyspark.rdd import RDD
@@ -19,12 +22,13 @@ class SparkContext(object):
broadcast variables on that cluster.
"""
- gateway = launch_gateway()
- jvm = gateway.jvm
- _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
- _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
- _takePartition = jvm.PythonRDD.takePartition
+ _gateway = None
+ _jvm = None
+ _writeIteratorToPickleFile = None
+ _takePartition = None
_next_accum_id = 0
+ _active_spark_context = None
+ _lock = Lock()
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024):
@@ -44,6 +48,18 @@ class SparkContext(object):
Java object. Set 1 to disable batching or -1 to use an
unlimited batch size.
"""
+ with SparkContext._lock:
+ if SparkContext._active_spark_context:
+ raise ValueError("Cannot run multiple SparkContexts at once")
+ else:
+ SparkContext._active_spark_context = self
+ if not SparkContext._gateway:
+ SparkContext._gateway = launch_gateway()
+ SparkContext._jvm = SparkContext._gateway.jvm
+ SparkContext._writeIteratorToPickleFile = \
+ SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+ SparkContext._takePartition = \
+ SparkContext._jvm.PythonRDD.takePartition
self.master = master
self.jobName = jobName
self.sparkHome = sparkHome or None # None becomes null in Py4J
@@ -51,8 +67,8 @@ class SparkContext(object):
self.batchSize = batchSize # -1 represents a unlimited batch size
# Create the Java SparkContext through Py4J
- empty_string_array = self.gateway.new_array(self.jvm.String, 0)
- self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
+ empty_string_array = self._gateway.new_array(self._jvm.String, 0)
+ self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
empty_string_array)
# Create a single Accumulator in Java that we'll send all our updates through;
@@ -60,8 +76,8 @@ class SparkContext(object):
self._accumulatorServer = accumulators._start_update_server()
(host, port) = self._accumulatorServer.server_address
self._javaAccumulator = self._jsc.accumulator(
- self.jvm.java.util.ArrayList(),
- self.jvm.PythonAccumulatorParam(host, port))
+ self._jvm.java.util.ArrayList(),
+ self._jvm.PythonAccumulatorParam(host, port))
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
# Broadcast's __reduce__ method stores Broadcast instances here.
@@ -73,6 +89,13 @@ class SparkContext(object):
# Deploy any code dependencies specified in the constructor
for path in (pyFiles or []):
self.addPyFile(path)
+ SparkFiles._sc = self
+ sys.path.append(SparkFiles.getRootDirectory())
+
+ # Create a temporary directory inside spark.local.dir:
+ local_dir = self._jvm.spark.Utils.getLocalDir()
+ self._temp_dir = \
+ self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()
@property
def defaultParallelism(self):
@@ -83,17 +106,20 @@ class SparkContext(object):
return self._jsc.sc().defaultParallelism()
def __del__(self):
- if self._jsc:
- self._jsc.stop()
- if self._accumulatorServer:
- self._accumulatorServer.shutdown()
+ self.stop()
def stop(self):
"""
Shut down the SparkContext.
"""
- self._jsc.stop()
- self._jsc = None
+ if self._jsc:
+ self._jsc.stop()
+ self._jsc = None
+ if self._accumulatorServer:
+ self._accumulatorServer.shutdown()
+ self._accumulatorServer = None
+ with SparkContext._lock:
+ SparkContext._active_spark_context = None
def parallelize(self, c, numSlices=None):
"""
@@ -103,14 +129,14 @@ class SparkContext(object):
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
- tempFile = NamedTemporaryFile(delete=False)
- atexit.register(lambda: os.unlink(tempFile.name))
+ tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
if self.batchSize != 1:
c = batched(c, self.batchSize)
for x in c:
write_with_length(dump_pickle(x), tempFile)
tempFile.close()
- jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+ readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
+ jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
return RDD(jrdd, self)
def textFile(self, name, minSplits=None):
@@ -123,6 +149,10 @@ class SparkContext(object):
jrdd = self._jsc.textFile(name, minSplits)
return RDD(jrdd, self)
+ def _checkpointFile(self, name):
+ jrdd = self._jsc.checkpointFile(name)
+ return RDD(jrdd, self)
+
def union(self, rdds):
"""
Build the union of a list of RDDs.
@@ -144,16 +174,11 @@ class SparkContext(object):
def accumulator(self, value, accum_param=None):
"""
- Create an C{Accumulator} with the given initial value, using a given
- AccumulatorParam helper object to define how to add values of the data
- type if provided. Default AccumulatorParams are used for integers and
- floating-point numbers if you do not provide one. For other types, the
- AccumulatorParam must implement two methods:
- - C{zero(value)}: provide a "zero value" for the type, compatible in
- dimensions with the provided C{value} (e.g., a zero vector).
- - C{addInPlace(val1, val2)}: add two values of the accumulator's data
- type, returning a new value; for efficiency, can also update C{val1}
- in place and return it.
+ Create an L{Accumulator} with the given initial value, using a given
+ L{AccumulatorParam} helper object to define how to add values of the
+ data type if provided. Default AccumulatorParams are used for integers
+ and floating-point numbers if you do not provide one. For other types,
+ a custom AccumulatorParam can be used.
"""
if accum_param == None:
if isinstance(value, int):
@@ -169,10 +194,26 @@ class SparkContext(object):
def addFile(self, path):
"""
- Add a file to be downloaded into the working directory of this Spark
- job on every node. The C{path} passed can be either a local file,
- a file in HDFS (or other Hadoop-supported filesystems), or an HTTP,
- HTTPS or FTP URI.
+ Add a file to be downloaded with this Spark job on every node.
+ The C{path} passed can be either a local file, a file in HDFS
+ (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
+ FTP URI.
+
+ To access the file in Spark jobs, use
+ L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
+ download location.
+
+ >>> from pyspark import SparkFiles
+ >>> path = os.path.join(tempdir, "test.txt")
+ >>> with open(path, "w") as testFile:
+ ... testFile.write("100")
+ >>> sc.addFile(path)
+ >>> def func(iterator):
+ ... with open(SparkFiles.get("test.txt")) as testFile:
+ ... fileVal = int(testFile.readline())
+ ... return [x * 100 for x in iterator]
+ >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
+ [100, 200, 300, 400]
"""
self._jsc.sc().addFile(path)
@@ -193,5 +234,33 @@ class SparkContext(object):
"""
self.addFile(path)
filename = path.split("/")[-1]
- os.environ["PYTHONPATH"] = \
- "%s:%s" % (filename, os.environ["PYTHONPATH"])
+
+ def setCheckpointDir(self, dirName, useExisting=False):
+ """
+ Set the directory under which RDDs are going to be checkpointed. The
+ directory must be a HDFS path if running on a cluster.
+
+ If the directory does not exist, it will be created. If the directory
+ exists and C{useExisting} is set to true, then the exisiting directory
+ will be used. Otherwise an exception will be thrown to prevent
+ accidental overriding of checkpoint files in the existing directory.
+ """
+ self._jsc.sc().setCheckpointDir(dirName, useExisting)
+
+
+def _test():
+ import atexit
+ import doctest
+ import tempfile
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ globs['tempdir'] = tempfile.mkdtemp()
+ atexit.register(lambda: shutil.rmtree(globs['tempdir']))
+ (failure_count, test_count) = doctest.testmod(globs=globs)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/files.py b/python/pyspark/files.py
new file mode 100644
index 0000000000..001b7a28b6
--- /dev/null
+++ b/python/pyspark/files.py
@@ -0,0 +1,38 @@
+import os
+
+
+class SparkFiles(object):
+ """
+ Resolves paths to files added through
+ L{SparkContext.addFile()<pyspark.context.SparkContext.addFile>}.
+
+ SparkFiles contains only classmethods; users should not create SparkFiles
+ instances.
+ """
+
+ _root_directory = None
+ _is_running_on_worker = False
+ _sc = None
+
+ def __init__(self):
+ raise NotImplementedError("Do not construct SparkFiles objects")
+
+ @classmethod
+ def get(cls, filename):
+ """
+ Get the absolute path of a file added through C{SparkContext.addFile()}.
+ """
+ path = os.path.join(SparkFiles.getRootDirectory(), filename)
+ return os.path.abspath(path)
+
+ @classmethod
+ def getRootDirectory(cls):
+ """
+ Get the root directory that contains files added through
+ C{SparkContext.addFile()}.
+ """
+ if cls._is_running_on_worker:
+ return cls._root_directory
+ else:
+ # This will have to change if we support multiple SparkContexts:
+ return cls._sc._jvm.spark.SparkFiles.getRootDirectory()
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d705f0f9e1..4cda6cf661 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1,4 +1,3 @@
-import atexit
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
@@ -32,7 +31,9 @@ class RDD(object):
def __init__(self, jrdd, ctx):
self._jrdd = jrdd
self.is_cached = False
+ self.is_checkpointed = False
self.ctx = ctx
+ self._partitionFunc = None
@property
def context(self):
@@ -49,6 +50,34 @@ class RDD(object):
self._jrdd.cache()
return self
+ def checkpoint(self):
+ """
+ Mark this RDD for checkpointing. It will be saved to a file inside the
+ checkpoint directory set with L{SparkContext.setCheckpointDir()} and
+ all references to its parent RDDs will be removed. This function must
+ be called before any job has been executed on this RDD. It is strongly
+ recommended that this RDD is persisted in memory, otherwise saving it
+ on a file will require recomputation.
+ """
+ self.is_checkpointed = True
+ self._jrdd.rdd().checkpoint()
+
+ def isCheckpointed(self):
+ """
+ Return whether this RDD has been checkpointed or not
+ """
+ return self._jrdd.rdd().isCheckpointed()
+
+ def getCheckpointFile(self):
+ """
+ Gets the name of the file to which this RDD was checkpointed
+ """
+ checkpointFile = self._jrdd.rdd().getCheckpointFile()
+ if checkpointFile.isDefined():
+ return checkpointFile.get()
+ else:
+ return None
+
# TODO persist(self, storageLevel)
def map(self, f, preservesPartitioning=False):
@@ -234,12 +263,8 @@ class RDD(object):
# Transferring lots of data through Py4J can be slow because
# socket.readline() is inefficient. Instead, we'll dump the data to a
# file and read it back.
- tempFile = NamedTemporaryFile(delete=False)
+ tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
tempFile.close()
- def clean_up_file():
- try: os.unlink(tempFile.name)
- except: pass
- atexit.register(clean_up_file)
self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
@@ -347,6 +372,10 @@ class RDD(object):
items = []
for partition in range(self._jrdd.splits().size()):
iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
+ # Each item in the iterator is a string, Python object, batch of
+ # Python objects. Regardless, it is sufficient to take `num`
+ # of these objects in order to collect `num` Python objects:
+ iterator = iterator.take(num)
items.extend(self._collect_iterator_through_file(iterator))
if len(items) >= num:
break
@@ -377,7 +406,7 @@ class RDD(object):
return (str(x).encode("utf-8") for x in iterator)
keyed = PipelinedRDD(self, func)
keyed._bypass_serializer = True
- keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)
+ keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)
# Pair functions
@@ -497,7 +526,7 @@ class RDD(object):
return python_right_outer_join(self, other, numSplits)
# TODO: add option to control map-side combining
- def partitionBy(self, numSplits, hashFunc=hash):
+ def partitionBy(self, numSplits, partitionFunc=hash):
"""
Return a copy of the RDD partitioned using the specified partitioner.
@@ -514,17 +543,21 @@ class RDD(object):
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
for (k, v) in iterator:
- buckets[hashFunc(k) % numSplits].append((k, v))
+ buckets[partitionFunc(k) % numSplits].append((k, v))
for (split, items) in buckets.iteritems():
yield str(split)
yield dump_pickle(Batch(items))
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
- pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
- jrdd = pairRDD.partitionBy(partitioner)
- jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
- return RDD(jrdd, self.ctx)
+ pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+ partitioner = self.ctx._jvm.PythonPartitioner(numSplits,
+ id(partitionFunc))
+ jrdd = pairRDD.partitionBy(partitioner).values()
+ rdd = RDD(jrdd, self.ctx)
+ # This is required so that id(partitionFunc) remains unique, even if
+ # partitionFunc is a lambda:
+ rdd._partitionFunc = partitionFunc
+ return rdd
# TODO: add control over map-side aggregation
def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
@@ -662,7 +695,7 @@ class PipelinedRDD(RDD):
20
"""
def __init__(self, prev, func, preservesPartitioning=False):
- if isinstance(prev, PipelinedRDD) and not prev.is_cached:
+ if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
prev_func = prev.func
def pipeline_func(split, iterator):
return func(split, prev_func(split, iterator))
@@ -675,6 +708,7 @@ class PipelinedRDD(RDD):
self.preservesPartitioning = preservesPartitioning
self._prev_jrdd = prev._jrdd
self.is_cached = False
+ self.is_checkpointed = False
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
@@ -695,18 +729,21 @@ class PipelinedRDD(RDD):
pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
- self.ctx.gateway._gateway_client)
+ self.ctx._gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
class_manifest = self._prev_jrdd.classManifest()
env = copy.copy(self.ctx.environment)
env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
- env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
- python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
+ env = MapConverter().convert(env, self.ctx._gateway._gateway_client)
+ python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
broadcast_vars, self.ctx._javaAccumulator, class_manifest)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
+ def _is_pipelinable(self):
+ return not (self.is_cached or self.is_checkpointed)
+
def _test():
import doctest
@@ -715,8 +752,10 @@ def _test():
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
- doctest.testmod(globs=globs)
+ (failure_count, test_count) = doctest.testmod(globs=globs)
globs['sc'].stop()
+ if failure_count:
+ exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index f6328c561f..54ff1bf8e7 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -4,6 +4,7 @@ An interactive shell.
This file is designed to be launched as a PYTHONSTARTUP script.
"""
import os
+import pyspark
from pyspark.context import SparkContext
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
new file mode 100644
index 0000000000..6a1962d267
--- /dev/null
+++ b/python/pyspark/tests.py
@@ -0,0 +1,121 @@
+"""
+Unit tests for PySpark; additional tests are implemented as doctests in
+individual modules.
+"""
+import os
+import shutil
+import sys
+from tempfile import NamedTemporaryFile
+import time
+import unittest
+
+from pyspark.context import SparkContext
+from pyspark.files import SparkFiles
+from pyspark.java_gateway import SPARK_HOME
+
+
+class PySparkTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self._old_sys_path = list(sys.path)
+ class_name = self.__class__.__name__
+ self.sc = SparkContext('local[4]', class_name , batchSize=2)
+
+ def tearDown(self):
+ self.sc.stop()
+ sys.path = self._old_sys_path
+ # To avoid Akka rebinding to the same port, since it doesn't unbind
+ # immediately on shutdown
+ self.sc._jvm.System.clearProperty("spark.driver.port")
+
+
+class TestCheckpoint(PySparkTestCase):
+
+ def setUp(self):
+ PySparkTestCase.setUp(self)
+ self.checkpointDir = NamedTemporaryFile(delete=False)
+ os.unlink(self.checkpointDir.name)
+ self.sc.setCheckpointDir(self.checkpointDir.name)
+
+ def tearDown(self):
+ PySparkTestCase.tearDown(self)
+ shutil.rmtree(self.checkpointDir.name)
+
+ def test_basic_checkpointing(self):
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+ flatMappedRDD.checkpoint()
+ result = flatMappedRDD.collect()
+ time.sleep(1) # 1 second
+ self.assertTrue(flatMappedRDD.isCheckpointed())
+ self.assertEqual(flatMappedRDD.collect(), result)
+ self.assertEqual(self.checkpointDir.name,
+ os.path.dirname(flatMappedRDD.getCheckpointFile()))
+
+ def test_checkpoint_and_restore(self):
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: [x])
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+ flatMappedRDD.checkpoint()
+ flatMappedRDD.count() # forces a checkpoint to be computed
+ time.sleep(1) # 1 second
+
+ self.assertIsNotNone(flatMappedRDD.getCheckpointFile())
+ recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+ self.assertEquals([1, 2, 3, 4], recovered.collect())
+
+
+class TestAddFile(PySparkTestCase):
+
+ def test_add_py_file(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this job fails due to `userlibrary` not being on the Python path:
+ def func(x):
+ from userlibrary import UserClass
+ return UserClass().hello()
+ self.assertRaises(Exception,
+ self.sc.parallelize(range(2)).map(func).first)
+ # Add the file, so the job should now succeed:
+ path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+ self.sc.addPyFile(path)
+ res = self.sc.parallelize(range(2)).map(func).first()
+ self.assertEqual("Hello World!", res)
+
+ def test_add_file_locally(self):
+ path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ self.sc.addFile(path)
+ download_path = SparkFiles.get("hello.txt")
+ self.assertNotEqual(path, download_path)
+ with open(download_path) as test_file:
+ self.assertEquals("Hello World!\n", test_file.readline())
+
+ def test_add_py_file_locally(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this fails due to `userlibrary` not being on the Python path:
+ def func():
+ from userlibrary import UserClass
+ self.assertRaises(ImportError, func)
+ path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+ self.sc.addFile(path)
+ from userlibrary import UserClass
+ self.assertEqual("Hello World!", UserClass().hello())
+
+
+class TestIO(PySparkTestCase):
+
+ def test_stdout_redirection(self):
+ import subprocess
+ def func(x):
+ subprocess.check_call('ls', shell=True)
+ self.sc.parallelize([1]).foreach(func)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index b2b9288089..812e7a9da5 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1,20 +1,23 @@
"""
Worker that receives input from Piped RDD.
"""
+import os
import sys
+import traceback
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
+from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, read_with_length, write_int, \
read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
# Redirect stdout to stderr so that users must return values from functions.
-old_stdout = sys.stdout
-sys.stdout = sys.stderr
+old_stdout = os.fdopen(os.dup(1), 'w')
+os.dup2(2, 1)
def load_obj():
@@ -23,6 +26,10 @@ def load_obj():
def main():
split_index = read_int(sys.stdin)
+ spark_files_dir = load_pickle(read_with_length(sys.stdin))
+ SparkFiles._root_directory = spark_files_dir
+ SparkFiles._is_running_on_worker = True
+ sys.path.append(spark_files_dir)
num_broadcast_variables = read_int(sys.stdin)
for _ in range(num_broadcast_variables):
bid = read_long(sys.stdin)
@@ -35,8 +42,13 @@ def main():
else:
dumps = dump_pickle
iterator = read_from_pickle_file(sys.stdin)
- for obj in func(split_index, iterator):
- write_with_length(dumps(obj), old_stdout)
+ try:
+ for obj in func(split_index, iterator):
+ write_with_length(dumps(obj), old_stdout)
+ except Exception as e:
+ write_int(-2, old_stdout)
+ write_with_length(traceback.format_exc(), old_stdout)
+ sys.exit(-1)
# Mark the beginning of the accumulators section of the output
write_int(-1, old_stdout)
for aid, accum in _accumulatorRegistry.items():
diff --git a/python/run-tests b/python/run-tests
index 32470911f9..a3a9ff5dcb 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -8,12 +8,18 @@ FAILED=0
$FWDIR/pyspark pyspark/rdd.py
FAILED=$(($?||$FAILED))
+$FWDIR/pyspark pyspark/context.py
+FAILED=$(($?||$FAILED))
+
$FWDIR/pyspark -m doctest pyspark/broadcast.py
FAILED=$(($?||$FAILED))
$FWDIR/pyspark -m doctest pyspark/accumulators.py
FAILED=$(($?||$FAILED))
+$FWDIR/pyspark -m unittest pyspark.tests
+FAILED=$(($?||$FAILED))
+
if [[ $FAILED != 0 ]]; then
echo -en "\033[31m" # Red
echo "Had test failures; see logs."
diff --git a/python/test_support/hello.txt b/python/test_support/hello.txt
new file mode 100755
index 0000000000..980a0d5f19
--- /dev/null
+++ b/python/test_support/hello.txt
@@ -0,0 +1 @@
+Hello World!
diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py
new file mode 100755
index 0000000000..5bb6f5009f
--- /dev/null
+++ b/python/test_support/userlibrary.py
@@ -0,0 +1,7 @@
+"""
+Used to test shipping of code depenencies with SparkContext.addPyFile().
+"""
+
+class UserClass(object):
+ def hello(self):
+ return "Hello World!"
diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml
index da91c0f3ab..0667b71cc7 100644
--- a/repl-bin/pom.xml
+++ b/repl-bin/pom.xml
@@ -70,11 +70,6 @@
<profiles>
<profile>
<id>hadoop1</id>
- <activation>
- <property>
- <name>!hadoopVersion</name>
- </property>
- </activation>
<properties>
<classifier>hadoop1</classifier>
</properties>
@@ -115,12 +110,6 @@
</profile>
<profile>
<id>hadoop2</id>
- <activation>
- <property>
- <name>hadoopVersion</name>
- <value>2</value>
- </property>
- </activation>
<properties>
<classifier>hadoop2</classifier>
</properties>
diff --git a/repl/pom.xml b/repl/pom.xml
index 38e883c7f8..4a296fa630 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -72,11 +72,6 @@
<profiles>
<profile>
<id>hadoop1</id>
- <activation>
- <property>
- <name>!hadoopVersion</name>
- </property>
- </activation>
<properties>
<classifier>hadoop1</classifier>
</properties>
@@ -102,6 +97,13 @@
<scope>runtime</scope>
</dependency>
<dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-streaming</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop1</classifier>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
@@ -121,12 +123,6 @@
</profile>
<profile>
<id>hadoop2</id>
- <activation>
- <property>
- <name>hadoopVersion</name>
- <value>2</value>
- </property>
- </activation>
<properties>
<classifier>hadoop2</classifier>
</properties>
@@ -152,6 +148,13 @@
<scope>runtime</scope>
</dependency>
<dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-streaming</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2</classifier>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
@@ -161,6 +164,16 @@
<artifactId>hadoop-client</artifactId>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro-ipc</artifactId>
+ <scope>provided</scope>
+ </dependency>
</dependencies>
<build>
<plugins>
diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala
index db78d06d4f..43559b96d3 100644
--- a/repl/src/test/scala/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/spark/repl/ReplSuite.scala
@@ -31,7 +31,7 @@ class ReplSuite extends FunSuite {
if (interp.sparkContext != null)
interp.sparkContext.stop()
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ System.clearProperty("spark.driver.port")
return out.toString
}
diff --git a/run b/run
index 9015fdbff7..37861f1a92 100755
--- a/run
+++ b/run
@@ -92,9 +92,11 @@ if [ -e "$FWDIR/lib_managed" ]; then
CLASSPATH+=":$FWDIR/lib_managed/bundles/*"
fi
CLASSPATH+=":$REPL_DIR/lib/*"
-for jar in `find "$REPL_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
- CLASSPATH+=":$jar"
-done
+if [ -e repl-bin/target ]; then
+ for jar in `find "repl-bin/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
+ CLASSPATH+=":$jar"
+ done
+fi
CLASSPATH+=":$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do
CLASSPATH+=":$jar"
diff --git a/sbt/sbt b/sbt/sbt
index a3055c13c1..8f426d18e8 100755
--- a/sbt/sbt
+++ b/sbt/sbt
@@ -5,4 +5,4 @@ if [ "$MESOS_HOME" != "" ]; then
fi
export SPARK_HOME=$(cd "$(dirname $0)/.."; pwd)
export SPARK_TESTING=1 # To put test classes on classpath
-java -Xmx1200M -XX:MaxPermSize=200m $EXTRA_ARGS -jar $SPARK_HOME/sbt/sbt-launch-*.jar "$@"
+java -Xmx1200M -XX:MaxPermSize=250m $EXTRA_ARGS -jar $SPARK_HOME/sbt/sbt-launch-*.jar "$@"
diff --git a/streaming/lib/kafka-0.7.2.jar b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar
index 65f79925a4..65f79925a4 100644
--- a/streaming/lib/kafka-0.7.2.jar
+++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar
Binary files differ
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5
new file mode 100644
index 0000000000..29f45f4adb
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5
@@ -0,0 +1 @@
+18876b8bc2e4cef28b6d191aa49d963f \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1
new file mode 100644
index 0000000000..e3bd62bac0
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1
@@ -0,0 +1 @@
+06b27270ffa52250a2c08703b397c99127b72060 \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom
new file mode 100644
index 0000000000..082d35726a
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom
@@ -0,0 +1,9 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
+ <modelVersion>4.0.0</modelVersion>
+ <groupId>org.apache.kafka</groupId>
+ <artifactId>kafka</artifactId>
+ <version>0.7.2-spark</version>
+ <description>POM was created from install:install-file</description>
+</project>
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5
new file mode 100644
index 0000000000..92c4132b5b
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5
@@ -0,0 +1 @@
+7bc4322266e6032bdf9ef6eebdd8097d \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1
new file mode 100644
index 0000000000..8a1d8a097a
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1
@@ -0,0 +1 @@
+d0f79e8eff0db43ca7bcf7dce2c8cd2972685c9d \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml
new file mode 100644
index 0000000000..720cd51c2f
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml
@@ -0,0 +1,12 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<metadata>
+ <groupId>org.apache.kafka</groupId>
+ <artifactId>kafka</artifactId>
+ <versioning>
+ <release>0.7.2-spark</release>
+ <versions>
+ <version>0.7.2-spark</version>
+ </versions>
+ <lastUpdated>20130121015225</lastUpdated>
+ </versioning>
+</metadata>
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5 b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5
new file mode 100644
index 0000000000..a4ce5dc9e8
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5
@@ -0,0 +1 @@
+e2b9c7c5f6370dd1d21a0aae5e8dcd77 \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1 b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1
new file mode 100644
index 0000000000..b869eaf2a6
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1
@@ -0,0 +1 @@
+2a4341da936b6c07a09383d17ffb185ac558ee91 \ No newline at end of file
diff --git a/streaming/pom.xml b/streaming/pom.xml
new file mode 100644
index 0000000000..6ee7e59df3
--- /dev/null
+++ b/streaming/pom.xml
@@ -0,0 +1,144 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.spark-project</groupId>
+ <artifactId>parent</artifactId>
+ <version>0.7.0-SNAPSHOT</version>
+ <relativePath>../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-streaming</artifactId>
+ <packaging>jar</packaging>
+ <name>Spark Project Streaming</name>
+ <url>http://spark-project.org/</url>
+
+ <repositories>
+ <!-- A repository in the local filesystem for the Kafka JAR, which we modified for Scala 2.9 -->
+ <repository>
+ <id>lib</id>
+ <url>file://${project.basedir}/lib</url>
+ </repository>
+ </repositories>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-server</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.codehaus.jackson</groupId>
+ <artifactId>jackson-mapper-asl</artifactId>
+ <version>1.9.11</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.kafka</groupId>
+ <artifactId>kafka</artifactId>
+ <version>0.7.2-spark</version> <!-- Comes from our in-project repository -->
+ </dependency>
+ <dependency>
+ <groupId>org.apache.flume</groupId>
+ <artifactId>flume-ng-sdk</artifactId>
+ <version>1.2.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.github.sgroschupf</groupId>
+ <artifactId>zkclient</artifactId>
+ <version>0.1</version>
+ </dependency>
+
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.scalacheck</groupId>
+ <artifactId>scalacheck_${scala.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.novocode</groupId>
+ <artifactId>junit-interface</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-log4j12</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.version}/test-classes</testOutputDirectory>
+ <plugins>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+
+ <profiles>
+ <profile>
+ <id>hadoop1</id>
+ <dependencies>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-core</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop1</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-core</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <classifier>hadoop1</classifier>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>hadoop2</id>
+ <dependencies>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-core</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-core</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <classifier>hadoop2</classifier>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ </profiles>
+</project>
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index 07ecb018ee..0eb6aad187 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -198,10 +198,10 @@ abstract class DStream[T: ClassManifest] (
metadataCleanerDelay < 0 || rememberDuration.milliseconds < metadataCleanerDelay * 1000,
"It seems you are doing some DStream window operation or setting a checkpoint interval " +
"which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " +
- "than " + rememberDuration.milliseconds + " milliseconds. But the Spark's metadata cleanup" +
- "delay is set to " + (metadataCleanerDelay / 60.0) + " minutes, which is not sufficient. Please set " +
- "the Java property 'spark.cleaner.delay' to more than " +
- math.ceil(rememberDuration.milliseconds.toDouble / 60000.0).toInt + " minutes."
+ "than " + rememberDuration.milliseconds / 1000 + " seconds. But Spark's metadata cleanup" +
+ "delay is set to " + metadataCleanerDelay + " seconds, which is not sufficient. Please " +
+ "set the Java property 'spark.cleaner.delay' to more than " +
+ math.ceil(rememberDuration.milliseconds / 1000.0).toInt + " seconds."
)
dependencies.foreach(_.validate())
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index db0461b985..8cfbec51d2 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -406,7 +406,7 @@ object StreamingContext {
// Set the default cleaner delay to an hour if not already set.
// This should be sufficient for even 1 second interval.
if (MetadataCleaner.getDelaySeconds < 0) {
- MetadataCleaner.setDelaySeconds(60)
+ MetadataCleaner.setDelaySeconds(3600)
}
new SparkContext(master, frameworkName)
}
diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
index 70d6bd2b1b..5bbf2b084f 100644
--- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
@@ -34,6 +34,14 @@ class JavaStreamingContext(val ssc: StreamingContext) {
this(new StreamingContext(master, frameworkName, batchDuration))
/**
+ * Creates a StreamingContext.
+ * @param sparkContext The underlying JavaSparkContext to use
+ * @param batchDuration The time interval at which streaming data will be divided into batches
+ */
+ def this(sparkContext: JavaSparkContext, batchDuration: Duration) =
+ this(new StreamingContext(sparkContext.sc, batchDuration))
+
+ /**
* Re-creates a StreamingContext from a checkpoint file.
* @param path Path either to the directory that was specified as the checkpoint directory, or
* to the checkpoint file 'graph' or 'graph.bk'.
diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
index aa6be95f30..8c322dd698 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
@@ -153,8 +153,8 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
/** A helper actor that communicates with the NetworkInputTracker */
private class NetworkReceiverActor extends Actor {
logInfo("Attempting to register with tracker")
- val ip = System.getProperty("spark.master.host", "localhost")
- val port = System.getProperty("spark.master.port", "7077").toInt
+ val ip = System.getProperty("spark.driver.host", "localhost")
+ val port = System.getProperty("spark.driver.port", "7077").toInt
val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port)
val tracker = env.actorSystem.actorFor(url)
val timeout = 5.seconds
diff --git a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala
index 290fab1ce0..04e6b69b7b 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala
@@ -1,6 +1,6 @@
package spark.streaming.dstream
-import spark.{DaemonThread, Logging}
+import spark.Logging
import spark.storage.StorageLevel
import spark.streaming.StreamingContext
@@ -48,7 +48,8 @@ class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel)
val queue = new ArrayBlockingQueue[ByteBuffer](2)
- blockPushingThread = new DaemonThread {
+ blockPushingThread = new Thread {
+ setDaemon(true)
override def run() {
var nextBlockNumber = 0
while (true) {
diff --git a/streaming/src/test/java/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
index 7a189d85b4..fbe4af4597 100644
--- a/streaming/src/test/java/JavaAPISuite.java
+++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
@@ -43,7 +43,7 @@ public class JavaAPISuite implements Serializable {
ssc = null;
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port");
+ System.clearProperty("spark.driver.port");
}
/*
@Test
diff --git a/streaming/src/test/java/JavaTestUtils.scala b/streaming/src/test/java/spark/streaming/JavaTestUtils.scala
index 56349837e5..56349837e5 100644
--- a/streaming/src/test/java/JavaTestUtils.scala
+++ b/streaming/src/test/java/spark/streaming/JavaTestUtils.scala
diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
index d98b840b8e..c031949dd1 100644
--- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
@@ -10,7 +10,7 @@ class BasicOperationsSuite extends TestSuiteBase {
after {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ System.clearProperty("spark.driver.port")
}
test("map") {
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
index 04ccca4c01..7126af62d9 100644
--- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -21,7 +21,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
FileUtils.deleteDirectory(new File(checkpointDir))
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ System.clearProperty("spark.driver.port")
}
var ssc: StreamingContext = null
diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala
index 7493ac1207..c4cfffbfc1 100644
--- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala
@@ -24,7 +24,7 @@ class FailureSuite extends TestSuiteBase with BeforeAndAfter {
FileUtils.deleteDirectory(new File(checkpointDir))
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ System.clearProperty("spark.driver.port")
}
override def framework = "CheckpointSuite"
diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
index aa08ea1141..c442210004 100644
--- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
@@ -29,7 +29,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
after {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ System.clearProperty("spark.driver.port")
}
diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
index 0c6e928835..cd9608df53 100644
--- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
@@ -13,7 +13,7 @@ class WindowOperationsSuite extends TestSuiteBase {
after {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ System.clearProperty("spark.driver.port")
}
val largerSlideInput = Seq(