author    Stephen Haberman <stephen@exigencecorp.com>  2013-02-02 01:57:18 -0600
committer Stephen Haberman <stephen@exigencecorp.com>  2013-02-02 01:57:18 -0600
commit    103c375ba044b4fb1061298d6375587ed30832a4 (patch)
tree      45b1a45260b0d266965caed1f4612f0b2eb49246 /core
parent    fdec42385a1a8f10f9dd803525cb3c132a25ba53 (diff)
parent    ae26911ec0d768dcdae8b7d706ca4544e36535e6 (diff)
download  spark-103c375ba044b4fb1061298d6375587ed30832a4.tar.gz
          spark-103c375ba044b4fb1061298d6375587ed30832a4.tar.bz2
          spark-103c375ba044b4fb1061298d6375587ed30832a4.zip
Merge branch 'master' into sparkmem
Diffstat (limited to 'core')
-rw-r--r--  core/pom.xml | 11
-rw-r--r--  core/src/main/scala/spark/Accumulators.scala | 3
-rw-r--r--  core/src/main/scala/spark/CacheManager.scala | 65
-rw-r--r--  core/src/main/scala/spark/CacheTracker.scala | 232
-rw-r--r--  core/src/main/scala/spark/DaemonThreadFactory.scala | 18
-rw-r--r--  core/src/main/scala/spark/Dependency.scala | 10
-rw-r--r--  core/src/main/scala/spark/HttpFileServer.scala | 8
-rw-r--r--  core/src/main/scala/spark/HttpServer.scala | 9
-rw-r--r--  core/src/main/scala/spark/KryoSerializer.scala | 5
-rw-r--r--  core/src/main/scala/spark/Logging.scala | 3
-rw-r--r--  core/src/main/scala/spark/MapOutputTracker.scala | 72
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala | 49
-rw-r--r--  core/src/main/scala/spark/ParallelCollection.scala | 24
-rw-r--r--  core/src/main/scala/spark/RDD.scala | 275
-rw-r--r--  core/src/main/scala/spark/RDDCheckpointData.scala | 106
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala | 247
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala | 57
-rw-r--r--  core/src/main/scala/spark/SparkFiles.java | 25
-rw-r--r--  core/src/main/scala/spark/TaskContext.scala | 3
-rw-r--r--  core/src/main/scala/spark/Utils.scala | 90
-rw-r--r--  core/src/main/scala/spark/api/java/JavaRDDLike.scala | 37
-rw-r--r--  core/src/main/scala/spark/api/java/JavaSparkContext.scala | 39
-rw-r--r--  core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java | 20
-rw-r--r--  core/src/main/scala/spark/api/java/StorageLevels.java | 11
-rw-r--r--  core/src/main/scala/spark/api/python/PythonPartitioner.scala | 13
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala | 112
-rw-r--r--  core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala | 24
-rw-r--r--  core/src/main/scala/spark/broadcast/Broadcast.scala | 6
-rw-r--r--  core/src/main/scala/spark/broadcast/BroadcastFactory.scala | 4
-rw-r--r--  core/src/main/scala/spark/broadcast/HttpBroadcast.scala | 32
-rw-r--r--  core/src/main/scala/spark/broadcast/MultiTracker.scala | 35
-rw-r--r--  core/src/main/scala/spark/broadcast/TreeBroadcast.scala | 52
-rw-r--r--  core/src/main/scala/spark/deploy/DeployMessage.scala | 4
-rw-r--r--  core/src/main/scala/spark/deploy/JobDescription.scala | 3
-rw-r--r--  core/src/main/scala/spark/deploy/LocalSparkCluster.scala | 39
-rw-r--r--  core/src/main/scala/spark/deploy/client/ClientListener.scala | 4
-rw-r--r--  core/src/main/scala/spark/deploy/client/TestClient.scala | 2
-rw-r--r--  core/src/main/scala/spark/deploy/master/JobInfo.scala | 2
-rw-r--r--  core/src/main/scala/spark/deploy/master/Master.scala | 34
-rw-r--r--  core/src/main/scala/spark/deploy/master/MasterWebUI.scala | 28
-rw-r--r--  core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala | 9
-rw-r--r--  core/src/main/scala/spark/deploy/worker/Worker.scala | 18
-rw-r--r--  core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala | 6
-rw-r--r--  core/src/main/scala/spark/executor/Executor.scala | 38
-rw-r--r--  core/src/main/scala/spark/executor/MesosExecutorBackend.scala | 7
-rw-r--r--  core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala | 41
-rw-r--r--  core/src/main/scala/spark/network/Connection.scala | 15
-rw-r--r--  core/src/main/scala/spark/network/ConnectionManager.scala | 8
-rw-r--r--  core/src/main/scala/spark/partial/ApproximateActionListener.scala | 4
-rw-r--r--  core/src/main/scala/spark/rdd/BlockRDD.scala | 20
-rw-r--r--  core/src/main/scala/spark/rdd/CartesianRDD.scala | 43
-rw-r--r--  core/src/main/scala/spark/rdd/CheckpointRDD.scala | 129
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 54
-rw-r--r--  core/src/main/scala/spark/rdd/CoalescedRDD.scala | 47
-rw-r--r--  core/src/main/scala/spark/rdd/FilteredRDD.scala | 14
-rw-r--r--  core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 10
-rw-r--r--  core/src/main/scala/spark/rdd/GlommedRDD.scala | 14
-rw-r--r--  core/src/main/scala/spark/rdd/HadoopRDD.scala | 15
-rw-r--r--  core/src/main/scala/spark/rdd/MapPartitionsRDD.scala | 14
-rw-r--r--  core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala | 14
-rw-r--r--  core/src/main/scala/spark/rdd/MappedRDD.scala | 19
-rw-r--r--  core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 19
-rw-r--r--  core/src/main/scala/spark/rdd/PartitionPruningRDD.scala | 42
-rw-r--r--  core/src/main/scala/spark/rdd/PipedRDD.scala | 18
-rw-r--r--  core/src/main/scala/spark/rdd/SampledRDD.scala | 29
-rw-r--r--  core/src/main/scala/spark/rdd/ShuffledRDD.scala | 24
-rw-r--r--  core/src/main/scala/spark/rdd/UnionRDD.scala | 43
-rw-r--r--  core/src/main/scala/spark/rdd/ZippedRDD.scala | 57
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGScheduler.scala | 210
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala | 2
-rw-r--r--  core/src/main/scala/spark/scheduler/JobResult.scala | 2
-rw-r--r--  core/src/main/scala/spark/scheduler/JobWaiter.scala | 14
-rw-r--r--  core/src/main/scala/spark/scheduler/MapStatus.scala | 6
-rw-r--r--  core/src/main/scala/spark/scheduler/ResultTask.scala | 102
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 73
-rw-r--r--  core/src/main/scala/spark/scheduler/Stage.scala | 14
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala | 2
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala | 105
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala | 4
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 34
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala | 22
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala | 102
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala | 2
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala | 7
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 57
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala | 4
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalScheduler.scala | 20
-rw-r--r--  core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala | 22
-rw-r--r--  core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala | 42
-rw-r--r--  core/src/main/scala/spark/storage/BlockManager.scala | 89
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerId.scala | 51
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerMaster.scala | 85
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerMasterActor.scala | 77
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerMessages.scala | 11
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerUI.scala | 85
-rw-r--r--  core/src/main/scala/spark/storage/BlockMessage.scala | 2
-rw-r--r--  core/src/main/scala/spark/storage/MemoryStore.scala | 1
-rw-r--r--  core/src/main/scala/spark/storage/StorageLevel.scala | 64
-rw-r--r--  core/src/main/scala/spark/storage/StorageUtils.scala | 78
-rw-r--r--  core/src/main/scala/spark/storage/ThreadingTest.scala | 9
-rw-r--r--  core/src/main/scala/spark/util/AkkaUtils.scala | 13
-rw-r--r--  core/src/main/scala/spark/util/MetadataCleaner.scala | 37
-rw-r--r--  core/src/main/scala/spark/util/RateLimitedOutputStream.scala | 62
-rw-r--r--  core/src/main/scala/spark/util/TimeStampedHashMap.scala | 28
-rw-r--r--  core/src/main/scala/spark/util/TimeStampedHashSet.scala | 69
-rw-r--r--  core/src/main/twirl/spark/common/layout.scala.html (renamed from core/src/main/twirl/spark/deploy/common/layout.scala.html) | 0
-rw-r--r--  core/src/main/twirl/spark/deploy/master/index.scala.html | 2
-rw-r--r--  core/src/main/twirl/spark/deploy/master/job_details.scala.html | 2
-rw-r--r--  core/src/main/twirl/spark/deploy/worker/index.scala.html | 3
-rw-r--r--  core/src/main/twirl/spark/storage/index.scala.html | 40
-rw-r--r--  core/src/main/twirl/spark/storage/rdd.scala.html | 77
-rw-r--r--  core/src/main/twirl/spark/storage/rdd_table.scala.html | 30
-rw-r--r--  core/src/main/twirl/spark/storage/worker_table.scala.html | 24
-rw-r--r--  core/src/test/resources/log4j.properties | 4
-rw-r--r--  core/src/test/scala/spark/AccumulatorSuite.scala | 38
-rw-r--r--  core/src/test/scala/spark/BroadcastSuite.scala | 14
-rw-r--r--  core/src/test/scala/spark/CacheTrackerSuite.scala | 131
-rw-r--r--  core/src/test/scala/spark/CheckpointSuite.scala | 355
-rw-r--r--  core/src/test/scala/spark/ClosureCleanerSuite.scala | 71
-rw-r--r--  core/src/test/scala/spark/DistributedSuite.scala | 92
-rw-r--r--  core/src/test/scala/spark/DriverSuite.scala | 32
-rw-r--r--  core/src/test/scala/spark/FailureSuite.scala | 14
-rw-r--r--  core/src/test/scala/spark/FileServerSuite.scala | 29
-rw-r--r--  core/src/test/scala/spark/FileSuite.scala | 16
-rw-r--r--  core/src/test/scala/spark/JavaAPISuite.java | 61
-rw-r--r--  core/src/test/scala/spark/LocalSparkContext.scala | 41
-rw-r--r--  core/src/test/scala/spark/MapOutputTrackerSuite.scala | 58
-rw-r--r--  core/src/test/scala/spark/PartitioningSuite.scala | 15
-rw-r--r--  core/src/test/scala/spark/PipedRDDSuite.scala | 16
-rw-r--r--  core/src/test/scala/spark/RDDSuite.scala | 74
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala | 14
-rw-r--r--  core/src/test/scala/spark/SortingSuite.scala | 13
-rw-r--r--  core/src/test/scala/spark/ThreadingSuite.scala | 14
-rw-r--r--  core/src/test/scala/spark/scheduler/TaskContextSuite.scala | 32
-rw-r--r--  core/src/test/scala/spark/storage/BlockManagerSuite.scala | 132
-rw-r--r--  core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala | 23
136 files changed, 3738 insertions, 1920 deletions
diff --git a/core/pom.xml b/core/pom.xml
index 862d3ec37a..873e8a1d0f 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -163,11 +163,6 @@
<profiles>
<profile>
<id>hadoop1</id>
- <activation>
- <property>
- <name>!hadoopVersion</name>
- </property>
- </activation>
<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
@@ -220,12 +215,6 @@
</profile>
<profile>
<id>hadoop2</id>
- <activation>
- <property>
- <name>hadoopVersion</name>
- <value>2</value>
- </property>
- </activation>
<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index b644aba5f8..57c6df35be 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -25,8 +25,7 @@ class Accumulable[R, T] (
extends Serializable {
val id = Accumulators.newId
- @transient
- private var value_ = initialValue // Current value on master
+ @transient private var value_ = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
var deserialized = false
diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala
new file mode 100644
index 0000000000..711435c333
--- /dev/null
+++ b/core/src/main/scala/spark/CacheManager.scala
@@ -0,0 +1,65 @@
+package spark
+
+import scala.collection.mutable.{ArrayBuffer, HashSet}
+import spark.storage.{BlockManager, StorageLevel}
+
+
+/** Spark class responsible for passing RDDs split contents to the BlockManager and making
+ sure a node doesn't load two copies of an RDD at once.
+ */
+private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
+ private val loading = new HashSet[String]
+
+ /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
+ def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
+ : Iterator[T] = {
+ val key = "rdd_%d_%d".format(rdd.id, split.index)
+ logInfo("Cache key is " + key)
+ blockManager.get(key) match {
+ case Some(cachedValues) =>
+ // Split is in cache, so just return its values
+ logInfo("Found partition in cache!")
+ return cachedValues.asInstanceOf[Iterator[T]]
+
+ case None =>
+ // Mark the split as loading (unless someone else marks it first)
+ loading.synchronized {
+ if (loading.contains(key)) {
+ logInfo("Loading contains " + key + ", waiting...")
+ while (loading.contains(key)) {
+ try {loading.wait()} catch {case _ =>}
+ }
+ logInfo("Loading no longer contains " + key + ", so returning cached result")
+ // See whether someone else has successfully loaded it. The main way this would fail
+ // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
+ // partition but we didn't want to make space for it. However, that case is unlikely
+ // because it's unlikely that two threads would work on the same RDD partition. One
+ // downside of the current code is that threads wait serially if this does happen.
+ blockManager.get(key) match {
+ case Some(values) =>
+ return values.asInstanceOf[Iterator[T]]
+ case None =>
+ logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
+ loading.add(key)
+ }
+ } else {
+ loading.add(key)
+ }
+ }
+ try {
+ // If we got here, we have to load the split
+ val elements = new ArrayBuffer[Any]
+ logInfo("Computing partition " + split)
+ elements ++= rdd.computeOrReadCheckpoint(split, context)
+ // Try to put this block in the blockManager
+ blockManager.put(key, elements, storageLevel, true)
+ return elements.iterator.asInstanceOf[Iterator[T]]
+ } finally {
+ loading.synchronized {
+ loading.remove(key)
+ loading.notifyAll()
+ }
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
deleted file mode 100644
index 04c26b2e40..0000000000
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ /dev/null
@@ -1,232 +0,0 @@
-package spark
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-
-import akka.actor._
-import akka.dispatch._
-import akka.pattern.ask
-import akka.remote._
-import akka.util.Duration
-import akka.util.Timeout
-import akka.util.duration._
-
-import spark.storage.BlockManager
-import spark.storage.StorageLevel
-
-private[spark] sealed trait CacheTrackerMessage
-
-private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
- extends CacheTrackerMessage
-private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
- extends CacheTrackerMessage
-private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage
-private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
-private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
-private[spark] case object GetCacheStatus extends CacheTrackerMessage
-private[spark] case object GetCacheLocations extends CacheTrackerMessage
-private[spark] case object StopCacheTracker extends CacheTrackerMessage
-
-private[spark] class CacheTrackerActor extends Actor with Logging {
- // TODO: Should probably store (String, CacheType) tuples
- private val locs = new HashMap[Int, Array[List[String]]]
-
- /**
- * A map from the slave's host name to its cache size.
- */
- private val slaveCapacity = new HashMap[String, Long]
- private val slaveUsage = new HashMap[String, Long]
-
- private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
- private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
- private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
-
- def receive = {
- case SlaveCacheStarted(host: String, size: Long) =>
- slaveCapacity.put(host, size)
- slaveUsage.put(host, 0)
- sender ! true
-
- case RegisterRDD(rddId: Int, numPartitions: Int) =>
- logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
- locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
- sender ! true
-
- case AddedToCache(rddId, partition, host, size) =>
- slaveUsage.put(host, getCacheUsage(host) + size)
- locs(rddId)(partition) = host :: locs(rddId)(partition)
- sender ! true
-
- case DroppedFromCache(rddId, partition, host, size) =>
- slaveUsage.put(host, getCacheUsage(host) - size)
- // Do a sanity check to make sure usage is greater than 0.
- locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
- sender ! true
-
- case MemoryCacheLost(host) =>
- logInfo("Memory cache lost on " + host)
- for ((id, locations) <- locs) {
- for (i <- 0 until locations.length) {
- locations(i) = locations(i).filterNot(_ == host)
- }
- }
- sender ! true
-
- case GetCacheLocations =>
- logInfo("Asked for current cache locations")
- sender ! locs.map{case (rrdId, array) => (rrdId -> array.clone())}
-
- case GetCacheStatus =>
- val status = slaveCapacity.map { case (host, capacity) =>
- (host, capacity, getCacheUsage(host))
- }.toSeq
- sender ! status
-
- case StopCacheTracker =>
- logInfo("Stopping CacheTrackerActor")
- sender ! true
- context.stop(self)
- }
-}
-
-private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
- extends Logging {
-
- // Tracker actor on the master, or remote reference to it on workers
- val ip: String = System.getProperty("spark.master.host", "localhost")
- val port: Int = System.getProperty("spark.master.port", "7077").toInt
- val actorName: String = "CacheTracker"
-
- val timeout = 10.seconds
-
- var trackerActor: ActorRef = if (isMaster) {
- val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName)
- logInfo("Registered CacheTrackerActor actor")
- actor
- } else {
- val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
- actorSystem.actorFor(url)
- }
-
- val registeredRddIds = new HashSet[Int]
-
- // Remembers which splits are currently being loaded (on worker nodes)
- val loading = new HashSet[String]
-
- // Send a message to the trackerActor and get its result within a default timeout, or
- // throw a SparkException if this fails.
- def askTracker(message: Any): Any = {
- try {
- val future = trackerActor.ask(message)(timeout)
- return Await.result(future, timeout)
- } catch {
- case e: Exception =>
- throw new SparkException("Error communicating with CacheTracker", e)
- }
- }
-
- // Send a one-way message to the trackerActor, to which we expect it to reply with true.
- def communicate(message: Any) {
- if (askTracker(message) != true) {
- throw new SparkException("Error reply received from CacheTracker")
- }
- }
-
- // Registers an RDD (on master only)
- def registerRDD(rddId: Int, numPartitions: Int) {
- registeredRddIds.synchronized {
- if (!registeredRddIds.contains(rddId)) {
- logInfo("Registering RDD ID " + rddId + " with cache")
- registeredRddIds += rddId
- communicate(RegisterRDD(rddId, numPartitions))
- }
- }
- }
-
- // For BlockManager.scala only
- def cacheLost(host: String) {
- communicate(MemoryCacheLost(host))
- logInfo("CacheTracker successfully removed entries on " + host)
- }
-
- // Get the usage status of slave caches. Each tuple in the returned sequence
- // is in the form of (host name, capacity, usage).
- def getCacheStatus(): Seq[(String, Long, Long)] = {
- askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]]
- }
-
- // For BlockManager.scala only
- def notifyFromBlockManager(t: AddedToCache) {
- communicate(t)
- }
-
- // Get a snapshot of the currently known locations
- def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
- askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
- }
-
- // Gets or computes an RDD split
- def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
- : Iterator[T] = {
- val key = "rdd_%d_%d".format(rdd.id, split.index)
- logInfo("Cache key is " + key)
- blockManager.get(key) match {
- case Some(cachedValues) =>
- // Split is in cache, so just return its values
- logInfo("Found partition in cache!")
- return cachedValues.asInstanceOf[Iterator[T]]
-
- case None =>
- // Mark the split as loading (unless someone else marks it first)
- loading.synchronized {
- if (loading.contains(key)) {
- logInfo("Loading contains " + key + ", waiting...")
- while (loading.contains(key)) {
- try {loading.wait()} catch {case _ =>}
- }
- logInfo("Loading no longer contains " + key + ", so returning cached result")
- // See whether someone else has successfully loaded it. The main way this would fail
- // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
- // partition but we didn't want to make space for it. However, that case is unlikely
- // because it's unlikely that two threads would work on the same RDD partition. One
- // downside of the current code is that threads wait serially if this does happen.
- blockManager.get(key) match {
- case Some(values) =>
- return values.asInstanceOf[Iterator[T]]
- case None =>
- logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
- loading.add(key)
- }
- } else {
- loading.add(key)
- }
- }
- try {
- // If we got here, we have to load the split
- val elements = new ArrayBuffer[Any]
- logInfo("Computing partition " + split)
- elements ++= rdd.compute(split, context)
- // Try to put this block in the blockManager
- blockManager.put(key, elements, storageLevel, true)
- return elements.iterator.asInstanceOf[Iterator[T]]
- } finally {
- loading.synchronized {
- loading.remove(key)
- loading.notifyAll()
- }
- }
- }
- }
-
- // Called by the Cache to report that an entry has been dropped from it
- def dropEntry(rddId: Int, partition: Int) {
- communicate(DroppedFromCache(rddId, partition, Utils.localHostName()))
- }
-
- def stop() {
- communicate(StopCacheTracker)
- registeredRddIds.clear()
- trackerActor = null
- }
-}
diff --git a/core/src/main/scala/spark/DaemonThreadFactory.scala b/core/src/main/scala/spark/DaemonThreadFactory.scala
deleted file mode 100644
index 56e59adeb7..0000000000
--- a/core/src/main/scala/spark/DaemonThreadFactory.scala
+++ /dev/null
@@ -1,18 +0,0 @@
-package spark
-
-import java.util.concurrent.ThreadFactory
-
-/**
- * A ThreadFactory that creates daemon threads
- */
-private object DaemonThreadFactory extends ThreadFactory {
- override def newThread(r: Runnable): Thread = new DaemonThread(r)
-}
-
-private class DaemonThread(r: Runnable = null) extends Thread {
- override def run() {
- if (r != null) {
- r.run()
- }
- }
-}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index b85d2732db..5eea907322 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -5,6 +5,7 @@ package spark
*/
abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
+
/**
* Base class for dependencies where each partition of the parent RDD is used by at most one
* partition of the child RDD. Narrow dependencies allow for pipelined execution.
@@ -12,12 +13,13 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
/**
* Get the parent partitions for a child partition.
- * @param outputPartition a partition of the child RDD
+ * @param partitionId a partition of the child RDD
* @return the partitions of the parent RDD that the child partition depends upon
*/
- def getParents(outputPartition: Int): Seq[Int]
+ def getParents(partitionId: Int): Seq[Int]
}
+
/**
* Represents a dependency on the output of a shuffle stage.
* @param shuffleId the shuffle id
@@ -32,6 +34,7 @@ class ShuffleDependency[K, V](
val shuffleId: Int = rdd.context.newShuffleId()
}
+
/**
* Represents a one-to-one dependency between partitions of the parent and child RDDs.
*/
@@ -39,6 +42,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
override def getParents(partitionId: Int) = List(partitionId)
}
+
/**
* Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs.
* @param rdd the parent RDD
@@ -48,7 +52,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
*/
class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
extends NarrowDependency[T](rdd) {
-
+
override def getParents(partitionId: Int) = {
if (partitionId >= outStart && partitionId < outStart + length) {
List(partitionId - outStart + inStart)
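For illustration only (not part of this commit): a worked example of the RangeDependency.getParents arithmetic above, with invented partition numbers and a local SparkContext.

import spark.{RangeDependency, SparkContext}

object RangeDependencyExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "RangeDependencyExample")
    val parent = sc.parallelize(1 to 8, 4)        // a parent RDD with splits 0..3
    // Suppose this parent's splits occupy child partitions 10..13 (inStart = 0, outStart = 10).
    val dep = new RangeDependency(parent, 0, 10, 4)
    println(dep.getParents(12))                   // List(2): 12 - 10 + 0
    println(dep.getParents(3))                    // List(): partition 3 is outside this parent's range
    sc.stop()
  }
}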
diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala
index 659d17718f..00901d95e2 100644
--- a/core/src/main/scala/spark/HttpFileServer.scala
+++ b/core/src/main/scala/spark/HttpFileServer.scala
@@ -1,9 +1,7 @@
package spark
-import java.io.{File, PrintWriter}
-import java.net.URL
-import scala.collection.mutable.HashMap
-import org.apache.hadoop.fs.FileUtil
+import java.io.{File}
+import com.google.common.io.Files
private[spark] class HttpFileServer extends Logging {
@@ -40,7 +38,7 @@ private[spark] class HttpFileServer extends Logging {
}
def addFileToDir(file: File, dir: File) : String = {
- Utils.copyFile(file, new File(dir, file.getName))
+ Files.copy(file, new File(dir, file.getName))
return dir + "/" + file.getName
}
diff --git a/core/src/main/scala/spark/HttpServer.scala b/core/src/main/scala/spark/HttpServer.scala
index 0196595ba1..4e0507c080 100644
--- a/core/src/main/scala/spark/HttpServer.scala
+++ b/core/src/main/scala/spark/HttpServer.scala
@@ -4,6 +4,7 @@ import java.io.File
import java.net.InetAddress
import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.server.bio.SocketConnector
import org.eclipse.jetty.server.handler.DefaultHandler
import org.eclipse.jetty.server.handler.HandlerList
import org.eclipse.jetty.server.handler.ResourceHandler
@@ -27,7 +28,13 @@ private[spark] class HttpServer(resourceBase: File) extends Logging {
if (server != null) {
throw new ServerStateException("Server is already started")
} else {
- server = new Server(0)
+ server = new Server()
+ val connector = new SocketConnector
+ connector.setMaxIdleTime(60*1000)
+ connector.setSoLingerTime(-1)
+ connector.setPort(0)
+ server.addConnector(connector)
+
val threadPool = new QueuedThreadPool
threadPool.setDaemon(true)
server.setThreadPool(threadPool)
diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 93d7327324..0bd73e936b 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -206,5 +206,8 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {
kryo
}
- def newInstance(): SerializerInstance = new KryoSerializerInstance(this)
+ def newInstance(): SerializerInstance = {
+ this.kryo.get().setClassLoader(Thread.currentThread().getContextClassLoader)
+ new KryoSerializerInstance(this)
+ }
}
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
index 90bae26202..7c1c1bb144 100644
--- a/core/src/main/scala/spark/Logging.scala
+++ b/core/src/main/scala/spark/Logging.scala
@@ -11,8 +11,7 @@ import org.slf4j.LoggerFactory
trait Logging {
// Make the log field transient so that objects with Logging can
// be serialized and used on another machine
- @transient
- private var log_ : Logger = null
+ @transient private var log_ : Logger = null
// Method to get or create the logger for this object
protected def log: Logger = {
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 70eb9f702e..4735207585 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -17,6 +17,7 @@ import akka.util.duration._
import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
+import spark.util.{MetadataCleaner, TimeStampedHashMap}
private[spark] sealed trait MapOutputTrackerMessage
@@ -37,14 +38,11 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
}
}
-private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging {
- val ip: String = System.getProperty("spark.master.host", "localhost")
- val port: Int = System.getProperty("spark.master.port", "7077").toInt
- val actorName: String = "MapOutputTracker"
+private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolean) extends Logging {
val timeout = 10.seconds
- var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
+ var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
@@ -53,17 +51,22 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
- val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]
+ val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
- var trackerActor: ActorRef = if (isMaster) {
+ val actorName: String = "MapOutputTracker"
+ var trackerActor: ActorRef = if (isDriver) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
logInfo("Registered MapOutputTrackerActor actor")
actor
} else {
+ val ip = System.getProperty("spark.driver.host", "localhost")
+ val port = System.getProperty("spark.driver.port", "7077").toInt
val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
actorSystem.actorFor(url)
}
+ val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
+
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
@@ -84,14 +87,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.get(shuffleId) != null) {
+ if (mapStatuses.get(shuffleId) != None) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
- var array = mapStatuses.get(shuffleId)
+ var array = mapStatuses(shuffleId)
array.synchronized {
array(mapId) = status
}
@@ -108,10 +111,10 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var array = mapStatuses.get(shuffleId)
+ var array = mapStatuses(shuffleId)
if (array != null) {
array.synchronized {
- if (array(mapId) != null && array(mapId).address == bmAddress) {
+ if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
}
}
@@ -126,7 +129,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
- val statuses = mapStatuses.get(shuffleId)
+ val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
fetching.synchronized {
@@ -139,8 +142,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
case e: InterruptedException =>
}
}
- return mapStatuses.get(shuffleId).map(status =>
- (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId))))
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId))
} else {
fetching += shuffleId
}
@@ -156,27 +158,27 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
- if (fetchedStatuses.contains(null)) {
- throw new FetchFailedException(null, shuffleId, -1, reduceId,
- new Exception("Missing an output location for shuffle " + shuffleId))
- }
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
- return fetchedStatuses.map(s =>
- (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
} else {
- return statuses.map(s =>
- (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
}
}
+ private def cleanup(cleanupTime: Long) {
+ mapStatuses.clearOldValues(cleanupTime)
+ cachedSerializedStatuses.clearOldValues(cleanupTime)
+ }
+
def stop() {
communicate(StopMapOutputTracker)
mapStatuses.clear()
+ metadataCleaner.cancel()
trackerActor = null
}
@@ -202,7 +204,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
- mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
+ mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
generation = newGen
}
}
@@ -220,7 +222,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
case Some(bytes) =>
return bytes
case None =>
- statuses = mapStatuses.get(shuffleId)
+ statuses = mapStatuses(shuffleId)
generationGotten = generation
}
}
@@ -258,6 +260,28 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
private[spark] object MapOutputTracker {
private val LOG_BASE = 1.1
+ // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
+ // any of the statuses is null (indicating a missing location due to a failed mapper),
+ // throw a FetchFailedException.
+ def convertMapStatuses(
+ shuffleId: Int,
+ reduceId: Int,
+ statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
+ if (statuses == null) {
+ throw new FetchFailedException(null, shuffleId, -1, reduceId,
+ new Exception("Missing all output locations for shuffle " + shuffleId))
+ }
+ statuses.map {
+ status =>
+ if (status == null) {
+ throw new FetchFailedException(null, shuffleId, -1, reduceId,
+ new Exception("Missing an output location for shuffle " + shuffleId))
+ } else {
+ (status.location, decompressSize(status.compressedSizes(reduceId)))
+ }
+ }
+ }
+
/**
* Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
* We do this by encoding the log base 1.1 of the size as an integer, which can support
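A rough sketch of the timestamp-and-sweep idea behind the TimeStampedHashMap/MetadataCleaner substitution above; the class and field names here are invented, not the commit's spark.util implementation.

import java.util.concurrent.ConcurrentHashMap

// Sketch only: each entry carries its insertion time, and clearOldValues(threshTime)
// sweeps out everything older, which is what the MetadataCleaner-driven cleanup() above invokes.
class TimestampedMapSketch[A, B] {
  private val internal = new ConcurrentHashMap[A, (B, Long)]

  def put(key: A, value: B) { internal.put(key, (value, System.currentTimeMillis)) }

  def get(key: A): Option[B] = Option(internal.get(key)).map(_._1)

  def clearOldValues(threshTime: Long) {
    val iter = internal.entrySet().iterator()
    while (iter.hasNext) {
      if (iter.next().getValue._2 < threshTime) iter.remove()
    }
  }
}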
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index ce48cea903..cc3cca2571 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -199,9 +199,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
/**
- * Merge the values for each key using an associative reduce function. This will also perform
- * the merging locally on each mapper before sending results to a reducer, similarly to a
- * "combiner" in MapReduce.
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
*/
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
this.cogroup(other, partitioner).flatMapValues {
@@ -465,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
val res = self.context.runJob(self, process _, Array(index), false)
res(0)
case None =>
- self.filter(_._1 == key).map(_._2).collect
+ self.filter(_._1 == key).map(_._2).collect()
}
}
@@ -493,20 +493,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
path: String,
keyClass: Class[_],
valueClass: Class[_],
- outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) {
- saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration)
- }
-
- /**
- * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
- * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
- */
- def saveAsNewAPIHadoopFile(
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
outputFormatClass: Class[_ <: NewOutputFormat[_, _]],
- conf: Configuration) {
+ conf: Configuration = self.context.hadoopConfiguration) {
val job = new NewAPIHadoopJob(conf)
job.setOutputKeyClass(keyClass)
job.setOutputValueClass(valueClass)
@@ -557,7 +545,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
- conf: JobConf = new JobConf) {
+ conf: JobConf = new JobConf(self.context.hadoopConfiguration)) {
conf.setOutputKeyClass(keyClass)
conf.setOutputValueClass(valueClass)
// conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
@@ -602,7 +590,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
var count = 0
while(iter.hasNext) {
- val record = iter.next
+ val record = iter.next()
count += 1
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
}
@@ -661,24 +649,21 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
}
private[spark]
-class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override val partitioner = prev.partitioner
- override def compute(split: Split, taskContext: TaskContext) =
- prev.iterator(split, taskContext).map{case (k, v) => (k, f(v))}
+class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev) {
+ override def getSplits = firstParent[(K, V)].splits
+ override val partitioner = firstParent[(K, V)].partitioner
+ override def compute(split: Split, context: TaskContext) =
+ firstParent[(K, V)].iterator(split, context).map{ case (k, v) => (k, f(v)) }
}
private[spark]
class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U])
- extends RDD[(K, U)](prev.context) {
-
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override val partitioner = prev.partitioner
+ extends RDD[(K, U)](prev) {
- override def compute(split: Split, taskContext: TaskContext) = {
- prev.iterator(split, taskContext).flatMap { case (k, v) => f(v).map(x => (k, x)) }
+ override def getSplits = firstParent[(K, V)].splits
+ override val partitioner = firstParent[(K, V)].partitioner
+ override def compute(split: Split, context: TaskContext) = {
+ firstParent[(K, V)].iterator(split, context).flatMap { case (k, v) => f(v).map(x => (k, x)) }
}
}
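For reference, a small usage example of the join semantics described in the corrected scaladoc above ((k, (v1, v2)) for matching keys); the object name and sample data are invented.

import spark.SparkContext
import spark.SparkContext._   // brings the pair-RDD operations (including join) into scope

object JoinExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "JoinExample")
    val left  = sc.parallelize(Seq(("x", 1), ("y", 2)))
    val right = sc.parallelize(Seq(("x", "one"), ("z", "three")))
    // Only keys present on both sides survive; each match becomes (k, (v1, v2)).
    left.join(right).collect().foreach(println)   // prints (x,(1,one))
    sc.stop()
  }
}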
diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala
index a27f766e31..10adcd53ec 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -2,6 +2,7 @@ package spark
import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer
+import scala.collection.Map
private[spark] class ParallelCollectionSplit[T: ClassManifest](
val rddId: Long,
@@ -22,28 +23,33 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
}
private[spark] class ParallelCollection[T: ClassManifest](
- sc: SparkContext,
+ @transient sc: SparkContext,
@transient data: Seq[T],
- numSlices: Int)
- extends RDD[T](sc) {
+ numSlices: Int,
+ locationPrefs: Map[Int,Seq[String]])
+ extends RDD[T](sc, Nil) {
// TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
// cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
// instead.
+ // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.
- @transient
- val splits_ = {
+ @transient var splits_ : Array[Split] = {
val slices = ParallelCollection.slice(data, numSlices).toArray
slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray
}
- override def splits = splits_.asInstanceOf[Array[Split]]
+ override def getSplits = splits_
- override def compute(s: Split, taskContext: TaskContext) =
+ override def compute(s: Split, context: TaskContext) =
s.asInstanceOf[ParallelCollectionSplit[T]].iterator
- override def preferredLocations(s: Split): Seq[String] = Nil
+ override def getPreferredLocations(s: Split): Seq[String] = {
+ locationPrefs.getOrElse(s.index, Nil)
+ }
- override val dependencies: List[Dependency[_]] = Nil
+ override def clearDependencies() {
+ splits_ = null
+ }
}
private object ParallelCollection {
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 3b9ced1946..9d6ea782bd 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -1,12 +1,8 @@
package spark
-import java.io.EOFException
-import java.io.ObjectInputStream
import java.net.URL
-import java.util.Random
-import java.util.Date
+import java.util.{Date, Random}
import java.util.{HashMap => JHashMap}
-import java.util.concurrent.atomic.AtomicLong
import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
@@ -16,13 +12,6 @@ import scala.collection.mutable.HashMap
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
-import org.apache.hadoop.io.Writable
-import org.apache.hadoop.mapred.FileOutputCommitter
-import org.apache.hadoop.mapred.HadoopWriter
-import org.apache.hadoop.mapred.JobConf
-import org.apache.hadoop.mapred.OutputCommitter
-import org.apache.hadoop.mapred.OutputFormat
-import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
@@ -31,7 +20,6 @@ import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator
import spark.partial.PartialResult
-import spark.rdd.BlockRDD
import spark.rdd.CartesianRDD
import spark.rdd.FilteredRDD
import spark.rdd.FlatMappedRDD
@@ -73,40 +61,55 @@ import SparkContext._
* [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details
* on RDD internals.
*/
-abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serializable {
+abstract class RDD[T: ClassManifest](
+ @transient private var sc: SparkContext,
+ @transient private var deps: Seq[Dependency[_]]
+ ) extends Serializable with Logging {
- // Methods that must be implemented by subclasses:
+ /** Construct an RDD with just a one-to-one dependency on one parent */
+ def this(@transient oneParent: RDD[_]) =
+ this(oneParent.context , List(new OneToOneDependency(oneParent)))
- /** Set of partitions in this RDD. */
- def splits: Array[Split]
+ // =======================================================================
+ // Methods that should be implemented by subclasses of RDD
+ // =======================================================================
- /** Function for computing a given partition. */
+ /** Implemented by subclasses to compute a given partition. */
def compute(split: Split, context: TaskContext): Iterator[T]
- /** How this RDD depends on any parent RDDs. */
- @transient val dependencies: List[Dependency[_]]
+ /**
+ * Implemented by subclasses to return the set of partitions in this RDD. This method will only
+ * be called once, so it is safe to implement a time-consuming computation in it.
+ */
+ protected def getSplits: Array[Split]
- // Methods available on all RDDs:
+ /**
+ * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only
+ * be called once, so it is safe to implement a time-consuming computation in it.
+ */
+ protected def getDependencies: Seq[Dependency[_]] = deps
- /** Record user function generating this RDD. */
- private[spark] val origin = Utils.getSparkCallSite
+ /** Optionally overridden by subclasses to specify placement preferences. */
+ protected def getPreferredLocations(split: Split): Seq[String] = Nil
/** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
- /** Optionally overridden by subclasses to specify placement preferences. */
- def preferredLocations(split: Split): Seq[String] = Nil
-
- /** The [[spark.SparkContext]] that this RDD was created on. */
- def context = sc
-
- private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+ // =======================================================================
+ // Methods and fields available on all RDDs
+ // =======================================================================
/** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
- // Variables relating to persistence
- private var storageLevel: StorageLevel = StorageLevel.NONE
+ /** A friendly name for this RDD */
+ var name: String = null
+
+ /** Assign a name to this RDD */
+ def setName(_name: String) = {
+ name = _name
+ this
+ }
/**
* Set this RDD's storage level to persist its values across operations after the first time
@@ -119,6 +122,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
"Cannot change storage level of an RDD after it was already assigned a level")
}
storageLevel = newLevel
+ // Register the RDD with the SparkContext
+ sc.persistentRdds(id) = this
this
}
@@ -131,22 +136,47 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
- private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
- if (!level.useDisk && level.replication < 2) {
- throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
- }
+ // Our dependencies and splits will be gotten by calling subclass's methods below, and will
+ // be overwritten when we're checkpointed
+ private var dependencies_ : Seq[Dependency[_]] = null
+ @transient private var splits_ : Array[Split] = null
- // This is a hack. Ideally this should re-use the code used by the CacheTracker
- // to generate the key.
- def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)
+ /** An Option holding our checkpoint RDD, if we are checkpointed */
+ private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD)
- persist(level)
- sc.runJob(this, (iter: Iterator[T]) => {} )
+ /**
+ * Get the list of dependencies of this RDD, taking into account whether the
+ * RDD is checkpointed or not.
+ */
+ final def dependencies: Seq[Dependency[_]] = {
+ checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
+ if (dependencies_ == null) {
+ dependencies_ = getDependencies
+ }
+ dependencies_
+ }
+ }
- val p = this.partitioner
+ /**
+ * Get the array of splits of this RDD, taking into account whether the
+ * RDD is checkpointed or not.
+ */
+ final def splits: Array[Split] = {
+ checkpointRDD.map(_.splits).getOrElse {
+ if (splits_ == null) {
+ splits_ = getSplits
+ }
+ splits_
+ }
+ }
- new BlockRDD[T](sc, splits.map(getSplitKey).toArray) {
- override val partitioner = p
+ /**
+ * Get the preferred location of a split, taking into account whether the
+ * RDD is checkpointed or not.
+ */
+ final def preferredLocations(split: Split): Seq[String] = {
+ checkpointRDD.map(_.getPreferredLocations(split)).getOrElse {
+ getPreferredLocations(split)
}
}
@@ -157,7 +187,18 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
*/
final def iterator(split: Split, context: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
- SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel)
+ SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
+ } else {
+ computeOrReadCheckpoint(split, context)
+ }
+ }
+
+ /**
+ * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
+ */
+ private[spark] def computeOrReadCheckpoint(split: Split, context: TaskContext): Iterator[T] = {
+ if (isCheckpointed) {
+ firstParent[T].iterator(split, context)
} else {
compute(split, context)
}
@@ -344,20 +385,22 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
val reducePartition: Iterator[T] => Option[T] = iter => {
if (iter.hasNext) {
Some(iter.reduceLeft(cleanF))
- }else {
+ } else {
None
}
}
- val options = sc.runJob(this, reducePartition)
- val results = new ArrayBuffer[T]
- for (opt <- options; elem <- opt) {
- results += elem
- }
- if (results.size == 0) {
- throw new UnsupportedOperationException("empty collection")
- } else {
- return results.reduceLeft(cleanF)
+ var jobResult: Option[T] = None
+ val mergeResult = (index: Int, taskResult: Option[T]) => {
+ if (taskResult != None) {
+ jobResult = jobResult match {
+ case Some(value) => Some(f(value, taskResult.get))
+ case None => taskResult
+ }
+ }
}
+ sc.runJob(this, reducePartition, mergeResult)
+ // Get the final result out of our Option, or throw an exception if the RDD was empty
+ jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
}
/**
@@ -367,9 +410,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
* modify t2.
*/
def fold(zeroValue: T)(op: (T, T) => T): T = {
+ // Clone the zero value since we will also be serializing it as part of tasks
+ var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanOp = sc.clean(op)
- val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp))
- return results.fold(zeroValue)(cleanOp)
+ val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)
+ val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult)
+ sc.runJob(this, foldPartition, mergeResult)
+ jobResult
}
/**
@@ -381,11 +428,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
* allocation.
*/
def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
+ // Clone the zero value since we will also be serializing it as part of tasks
+ var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanSeqOp = sc.clean(seqOp)
val cleanCombOp = sc.clean(combOp)
- val results = sc.runJob(this,
- (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp))
- return results.fold(zeroValue)(cleanCombOp)
+ val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+ val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
+ sc.runJob(this, aggregatePartition, mergeResult)
+ jobResult
}
/**
@@ -396,7 +446,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
var result = 0L
while (iter.hasNext) {
result += 1L
- iter.next
+ iter.next()
}
result
}).sum
@@ -411,7 +461,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
var result = 0L
while (iter.hasNext) {
result += 1L
- iter.next
+ iter.next()
}
result
}
@@ -528,4 +578,105 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
private[spark] def collectPartitions(): Array[Array[T]] = {
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
}
+
+ /**
+ * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
+ * directory set with SparkContext.setCheckpointDir() and all references to its parent
+ * RDDs will be removed. This function must be called before any job has been
+ * executed on this RDD. It is strongly recommended that this RDD is persisted in
+ * memory, otherwise saving it on a file will require recomputation.
+ */
+ def checkpoint() {
+ if (context.checkpointDir.isEmpty) {
+ throw new Exception("Checkpoint directory has not been set in the SparkContext")
+ } else if (checkpointData.isEmpty) {
+ checkpointData = Some(new RDDCheckpointData(this))
+ checkpointData.get.markForCheckpoint()
+ }
+ }
+
+ /**
+ * Return whether this RDD has been checkpointed or not
+ */
+ def isCheckpointed: Boolean = {
+ checkpointData.map(_.isCheckpointed).getOrElse(false)
+ }
+
+ /**
+ * Gets the name of the file to which this RDD was checkpointed
+ */
+ def getCheckpointFile: Option[String] = {
+ checkpointData.flatMap(_.getCheckpointFile)
+ }
+
+ // =======================================================================
+ // Other internal methods and fields
+ // =======================================================================
+
+ private var storageLevel: StorageLevel = StorageLevel.NONE
+
+ /** Record user function generating this RDD. */
+ private[spark] val origin = Utils.getSparkCallSite
+
+ private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+
+ private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
+
+ /** Returns the first parent RDD */
+ protected[spark] def firstParent[U: ClassManifest] = {
+ dependencies.head.rdd.asInstanceOf[RDD[U]]
+ }
+
+ /** The [[spark.SparkContext]] that this RDD was created on. */
+ def context = sc
+
+ /**
+ * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler
+ * after a job using this RDD has completed (therefore the RDD has been materialized and
+ * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
+ */
+ private[spark] def doCheckpoint() {
+ if (checkpointData.isDefined) {
+ checkpointData.get.doCheckpoint()
+ } else {
+ dependencies.foreach(_.rdd.doCheckpoint())
+ }
+ }
+
+ /**
+ * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`)
+ * created from the checkpoint file, and forget its old dependencies and splits.
+ */
+ private[spark] def markCheckpointed(checkpointRDD: RDD[_]) {
+ clearDependencies()
+ dependencies_ = null
+ splits_ = null
+ deps = null // Forget the constructor argument for dependencies too
+ }
+
+ /**
+ * Clears the dependencies of this RDD. This method must ensure that all references
+ * to the original parent RDDs is removed to enable the parent RDDs to be garbage
+ * collected. Subclasses of RDD may override this method for implementing their own cleaning
+ * logic. See [[spark.rdd.UnionRDD]] for an example.
+ */
+ protected def clearDependencies() {
+ dependencies_ = null
+ }
+
+ /** A description of this RDD and its recursive dependencies for debugging. */
+ def toDebugString(): String = {
+ def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = {
+ Seq(prefix + rdd + " (" + rdd.splits.size + " splits)") ++
+ rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " "))
+ }
+ debugString(this).mkString("\n")
+ }
+
+ override def toString(): String = "%s%s[%d] at %s".format(
+ Option(name).map(_ + " ").getOrElse(""),
+ getClass.getSimpleName,
+ id,
+ origin)
+
}
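As an aside, a minimal hypothetical subclass under the reworked contract above (compute() plus getSplits(), dependencies supplied by the one-parent auxiliary constructor), in the same style as MappedValuesRDD earlier in this diff; the class name is invented.

import spark.{RDD, Split, TaskContext}

// Sketch of the new subclass contract: getSplits() reuses the parent's splits,
// compute() transforms the parent's iterator for one split.
class DoubledRDD(prev: RDD[Int]) extends RDD[Int](prev) {
  override def getSplits: Array[Split] = firstParent[Int].splits
  override def compute(split: Split, context: TaskContext): Iterator[Int] =
    firstParent[Int].iterator(split, context).map(_ * 2)
}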
diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
new file mode 100644
index 0000000000..a4a4ebaf53
--- /dev/null
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -0,0 +1,106 @@
+package spark
+
+import org.apache.hadoop.fs.Path
+import rdd.{CheckpointRDD, CoalescedRDD}
+import scheduler.{ResultTask, ShuffleMapTask}
+
+/**
+ * Enumeration to manage state transitions of an RDD through checkpointing
+ * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
+ */
+private[spark] object CheckpointState extends Enumeration {
+ type CheckpointState = Value
+ val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value
+}
+
+/**
+ * This class contains all the information related to RDD checkpointing. Each instance of this class
+ * is associated with a RDD. It manages process of checkpointing of the associated RDD, as well as,
+ * manages the post-checkpoint state by providing the updated splits, iterator and preferred locations
+ * of the checkpointed RDD.
+ */
+private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
+ extends Logging with Serializable {
+
+ import CheckpointState._
+
+ // The checkpoint state of the associated RDD.
+ var cpState = Initialized
+
+ // The file to which the associated RDD has been checkpointed to
+ @transient var cpFile: Option[String] = None
+
+ // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD.
+ var cpRDD: Option[RDD[T]] = None
+
+ // Mark the RDD for checkpointing
+ def markForCheckpoint() {
+ RDDCheckpointData.synchronized {
+ if (cpState == Initialized) cpState = MarkedForCheckpoint
+ }
+ }
+
+ // Is the RDD already checkpointed
+ def isCheckpointed: Boolean = {
+ RDDCheckpointData.synchronized { cpState == Checkpointed }
+ }
+
+ // Get the file to which this RDD was checkpointed, as an Option
+ def getCheckpointFile: Option[String] = {
+ RDDCheckpointData.synchronized { cpFile }
+ }
+
+ // Do the checkpointing of the RDD. Called after the first job using that RDD is over.
+ def doCheckpoint() {
+ // If it is marked for checkpointing AND checkpointing is not already in progress,
+ // then set it to be in progress, else return
+ RDDCheckpointData.synchronized {
+ if (cpState == MarkedForCheckpoint) {
+ cpState = CheckpointingInProgress
+ } else {
+ return
+ }
+ }
+
+ // Save to file, and reload it as an RDD
+ val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString
+ rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
+ val newRDD = new CheckpointRDD[T](rdd.context, path)
+
+ // Change the dependencies and splits of the RDD
+ RDDCheckpointData.synchronized {
+ cpFile = Some(path)
+ cpRDD = Some(newRDD)
+ rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and splits
+ cpState = Checkpointed
+ RDDCheckpointData.clearTaskCaches()
+ logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
+ }
+ }
+
+ // Get preferred location of a split after checkpointing
+ def getPreferredLocations(split: Split): Seq[String] = {
+ RDDCheckpointData.synchronized {
+ cpRDD.get.preferredLocations(split)
+ }
+ }
+
+ def getSplits: Array[Split] = {
+ RDDCheckpointData.synchronized {
+ cpRDD.get.splits
+ }
+ }
+
+ def checkpointRDD: Option[RDD[T]] = {
+ RDDCheckpointData.synchronized {
+ cpRDD
+ }
+ }
+}
+
+private[spark] object RDDCheckpointData {
+ def clearTaskCaches() {
+ ShuffleMapTask.clearCache()
+ ResultTask.clearCache()
+ }
+}
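RDDCheckpointData walks an RDD through Initialized -> MarkedForCheckpoint -> CheckpointingInProgress -> Checkpointed: checkpoint() only marks the RDD, the first job run afterwards triggers doCheckpoint(), and the lineage is then replaced by a CheckpointRDD. A usage sketch, assuming a local SparkContext and a writable checkpoint directory (the path is a placeholder):

    import spark.SparkContext

    object CheckpointLifecycleExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "checkpoint-example")
        sc.setCheckpointDir("/tmp/spark-checkpoints")   // placeholder directory

        val rdd = sc.parallelize(1 to 1000, 4).map(_ * 2)
        rdd.checkpoint()                 // marks the RDD; nothing is written yet
        println(rdd.isCheckpointed)      // false

        rdd.count()                      // runJob() calls doCheckpoint() when the job finishes
        println(rdd.isCheckpointed)      // true
        println(rdd.getCheckpointFile)   // Some(<checkpointDir>/rdd-<id>)
        sc.stop()
      }
    }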
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 402355bd52..86ed293bae 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -1,12 +1,15 @@
package spark
import java.io._
+import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import java.net.{URI, URLClassLoader}
+import java.lang.ref.WeakReference
import scala.collection.Map
import scala.collection.generic.Growable
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.JavaConversions._
import akka.actor.Actor
import akka.actor.Actor._
@@ -36,15 +39,13 @@ import spark.broadcast._
import spark.deploy.LocalSparkCluster
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
-import spark.rdd.HadoopRDD
-import spark.rdd.NewHadoopRDD
-import spark.rdd.UnionRDD
-import spark.scheduler.ShuffleMapTask
-import spark.scheduler.DAGScheduler
-import spark.scheduler.TaskScheduler
+import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD}
+import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
+import storage.BlockManagerUI
+import util.{MetadataCleaner, TimeStampedHashMap}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -58,54 +59,49 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend
* @param environment Environment variables to set on worker nodes.
*/
class SparkContext(
- master: String,
- jobName: String,
- val sparkHome: String,
- jars: Seq[String],
- environment: Map[String, String])
+ val master: String,
+ val jobName: String,
+ val sparkHome: String = null,
+ val jars: Seq[String] = Nil,
+ environment: Map[String, String] = Map())
extends Logging {
- /**
- * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
- * @param sparkHome Location where Spark is installed on cluster nodes.
- * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
- * system or HDFS, HTTP, HTTPS, or FTP URLs.
- */
- def this(master: String, jobName: String, sparkHome: String, jars: Seq[String]) =
- this(master, jobName, sparkHome, jars, Map())
-
- /**
- * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
- */
- def this(master: String, jobName: String) = this(master, jobName, null, Nil, Map())
-
// Ensure logging is initialized before we spawn any threads
initLogging()
- // Set Spark master host and port system properties
- if (System.getProperty("spark.master.host") == null) {
- System.setProperty("spark.master.host", Utils.localIpAddress)
+ // Set Spark driver host and port system properties
+ if (System.getProperty("spark.driver.host") == null) {
+ System.setProperty("spark.driver.host", Utils.localIpAddress)
}
- if (System.getProperty("spark.master.port") == null) {
- System.setProperty("spark.master.port", "0")
+ if (System.getProperty("spark.driver.port") == null) {
+ System.setProperty("spark.driver.port", "0")
}
private val isLocal = (master == "local" || master.startsWith("local["))
// Create the Spark execution environment (cache, map output tracker, etc)
private[spark] val env = SparkEnv.createFromSystemProperties(
- System.getProperty("spark.master.host"),
- System.getProperty("spark.master.port").toInt,
+ "<driver>",
+ System.getProperty("spark.driver.host"),
+ System.getProperty("spark.driver.port").toInt,
true,
isLocal)
SparkEnv.set(env)
+ // Start the BlockManager UI
+ private[spark] val ui = new BlockManagerUI(
+ env.actorSystem, env.blockManager.master.driverActor, this)
+ ui.start()
+
// Used to store a URL for each static file/jar together with the file's local timestamp
private[spark] val addedFiles = HashMap[String, Long]()
private[spark] val addedJars = HashMap[String, Long]()
+ // Keeps track of all persisted RDDs
+ private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]()
+ private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
+
+
// Add each JAR given through the constructor
jars.foreach { addJar(_) }
@@ -131,6 +127,8 @@ class SparkContext(
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """(spark://.*)""".r
+ // Regular expression for connecting to a Mesos cluster
+ val MESOS_REGEX = """(mesos://.*)""".r
master match {
case "local" =>
@@ -171,6 +169,9 @@ class SparkContext(
scheduler
case _ =>
+ if (MESOS_REGEX.findFirstIn(master).isEmpty) {
+ logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
+ }
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
@@ -188,11 +189,32 @@ class SparkContext(
private var dagScheduler = new DAGScheduler(taskScheduler)
+ /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
+ val hadoopConfiguration = {
+ val conf = new Configuration()
+ // Explicitly check for S3 environment variables
+ if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
+ conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ }
+ // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
+ for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
+ conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
+ }
+ val bufferSize = System.getProperty("spark.buffer.size", "65536")
+ conf.set("io.file.buffer.size", bufferSize)
+ conf
+ }
+
+ private[spark] var checkpointDir: Option[String] = None
+
// Methods for creating RDDs
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
- new ParallelCollection[T](this, seq, numSlices)
+ new ParallelCollection[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
/** Distribute a local Scala collection to form an RDD. */
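The new hadoopConfiguration above is seeded from the AWS environment variables and from any "spark.hadoop.*" system property (with the prefix stripped), and is then reused by the Hadoop-based input methods. A small sketch of relying on that behaviour, assuming a local SparkContext; the key values are placeholders, and the properties must be set before the context is constructed:

    import spark.SparkContext

    object HadoopConfExample {
      def main(args: Array[String]) {
        // "spark.hadoop.foo" system properties are copied into hadoopConfiguration as "foo".
        System.setProperty("spark.hadoop.fs.s3n.awsAccessKeyId", "MY_ACCESS_KEY")      // placeholder
        System.setProperty("spark.hadoop.fs.s3n.awsSecretAccessKey", "MY_SECRET_KEY")  // placeholder

        val sc = new SparkContext("local", "hadoop-conf-example")
        println(sc.hadoopConfiguration.get("fs.s3n.awsAccessKeyId"))   // MY_ACCESS_KEY
        sc.stop()
      }
    }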
@@ -200,6 +222,14 @@ class SparkContext(
parallelize(seq, numSlices)
}
+ /** Distribute a local Scala collection to form an RDD, with one or more
+ * location preferences (hostnames of Spark nodes) for each object.
+ * Create a new partition for each collection item. */
+ def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
+ new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs)
+ }
+
/**
* Read a text file from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI, and return it as an RDD of Strings.
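The makeRDD overload added above takes (element, preferred hosts) pairs and creates one partition per element, handing the host lists to ParallelCollection as locality hints. A minimal sketch, assuming a local SparkContext; the hostnames are placeholders and only influence scheduling on a real cluster:

    import spark.SparkContext

    object MakeRDDWithPreferencesExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "make-rdd-example")
        val data = Seq(
          ("block-0", Seq("host1.example.com")),   // element -> preferred hosts
          ("block-1", Seq("host2.example.com"))
        )
        val rdd = sc.makeRDD(data)
        println(rdd.splits.size)   // 2, one split per element
        sc.stop()
      }
    }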
@@ -232,10 +262,8 @@ class SparkContext(
valueClass: Class[V],
minSplits: Int = defaultMinSplits
) : RDD[(K, V)] = {
- val conf = new JobConf()
+ val conf = new JobConf(hadoopConfiguration)
FileInputFormat.setInputPaths(conf, path)
- val bufferSize = System.getProperty("spark.buffer.size", "65536")
- conf.set("io.file.buffer.size", bufferSize)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
}
@@ -276,8 +304,7 @@ class SparkContext(
path,
fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]],
- vm.erasure.asInstanceOf[Class[V]],
- new Configuration)
+ vm.erasure.asInstanceOf[Class[V]])
}
/**
@@ -289,7 +316,7 @@ class SparkContext(
fClass: Class[F],
kClass: Class[K],
vClass: Class[V],
- conf: Configuration): RDD[(K, V)] = {
+ conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
@@ -301,7 +328,7 @@ class SparkContext(
* and extra configuration options to pass to the input format.
*/
def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
- conf: Configuration,
+ conf: Configuration = hadoopConfiguration,
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
@@ -366,6 +393,13 @@ class SparkContext(
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes))
}
+
+ protected[spark] def checkpointFile[T: ClassManifest](
+ path: String
+ ): RDD[T] = {
+ new CheckpointRDD[T](this, path)
+ }
+
/** Build the union of a list of RDDs. */
def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
@@ -377,14 +411,14 @@ class SparkContext(
/**
* Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
- * to using the `+=` method. Only the master can access the accumulator's `value`.
+ * to using the `+=` method. Only the driver can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param)
/**
* Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`.
- * Only the master can access the accumuable's `value`.
+ * Only the driver can access the accumulable's `value`.
* @tparam T accumulator type
* @tparam R type that can be added to the accumulator
*/
@@ -409,9 +443,10 @@ class SparkContext(
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
/**
- * Add a file to be downloaded into the working directory of this Spark job on every node.
+ * Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
- * filesystems), or an HTTP, HTTPS or FTP URI.
+ * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+ * use `SparkFiles.get(path)` to find its download location.
*/
def addFile(path: String) {
val uri = new URI(path)
@@ -424,7 +459,7 @@ class SparkContext(
// Fetch the file locally in case a job is executed locally.
// Jobs that run through LocalScheduler will already fetch the required dependencies,
// but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
- Utils.fetchFile(path, new File("."))
+ Utils.fetchFile(path, new File(SparkFiles.getRootDirectory))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
@@ -472,17 +507,23 @@ class SparkContext(
/** Shut down the SparkContext. */
def stop() {
- dagScheduler.stop()
- dagScheduler = null
- taskScheduler = null
- // TODO: Cache.stop()?
- env.stop()
- // Clean up locally linked files
- clearFiles()
- clearJars()
- SparkEnv.set(null)
- ShuffleMapTask.clearCache()
- logInfo("Successfully stopped SparkContext")
+ if (dagScheduler != null) {
+ metadataCleaner.cancel()
+ dagScheduler.stop()
+ dagScheduler = null
+ taskScheduler = null
+ // TODO: Cache.stop()?
+ env.stop()
+ // Clean up locally linked files
+ clearFiles()
+ clearJars()
+ SparkEnv.set(null)
+ ShuffleMapTask.clearCache()
+ ResultTask.clearCache()
+ logInfo("Successfully stopped SparkContext")
+ } else {
+ logInfo("SparkContext already stopped")
+ }
}
/**
@@ -503,26 +544,43 @@ class SparkContext(
}
/**
- * Run a function on a given set of partitions in an RDD and return the results. This is the main
- * entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies
- * whether the scheduler can run the computation on the master rather than shipping it out to the
- * cluster, for short actions like first().
+ * Run a function on a given set of partitions in an RDD and pass the results to the given
+ * handler function. This is the main entry point for all actions in Spark. The allowLocal
+ * flag specifies whether the scheduler can run the computation on the driver rather than
+ * shipping it out to the cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
- allowLocal: Boolean
- ): Array[U] = {
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit) {
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
+ val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
+ rdd.doCheckpoint()
result
}
/**
+ * Run a function on a given set of partitions in an RDD and return the results as an array. The
+ * allowLocal flag specifies whether the scheduler can run the computation on the driver rather
+ * than shipping it out to the cluster, for short actions like first().
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ allowLocal: Boolean
+ ): Array[U] = {
+ val results = new Array[U](partitions.size)
+ runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res)
+ results
+ }
+
+ /**
* Run a job on a given set of partitions of an RDD, but take a function of type
* `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
*/
@@ -550,6 +608,29 @@ class SparkContext(
}
/**
+ * Run a job on all partitions in an RDD and pass the results to a handler function.
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ processPartition: (TaskContext, Iterator[T]) => U,
+ resultHandler: (Int, U) => Unit)
+ {
+ runJob[T, U](rdd, processPartition, 0 until rdd.splits.size, false, resultHandler)
+ }
+
+ /**
+ * Run a job on all partitions in an RDD and pass the results to a handler function.
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ processPartition: Iterator[T] => U,
+ resultHandler: (Int, U) => Unit)
+ {
+ val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
+ runJob[T, U](rdd, processFunc, 0 until rdd.splits.size, false, resultHandler)
+ }
+
+ /**
* Run a job that can return approximate results.
*/
def runApproximateJob[T, U, R](
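The refactored runJob now streams each partition's result into a handler as it arrives on the driver; the array-returning overload is just a handler that fills an array. A sketch using the Iterator-based handler overload added above, assuming a local SparkContext:

    import spark.SparkContext

    object RunJobHandlerExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "run-job-handler-example")
        val nums = sc.parallelize(1 to 100, 4)

        // Collect one value per partition through the handler instead of an Array[U].
        val sizes = new Array[Int](4)
        sc.runJob(nums,
                  (iter: Iterator[Int]) => iter.size,
                  (index: Int, size: Int) => sizes(index) = size)

        println(sizes.mkString(", "))   // 25, 25, 25, 25
        sc.stop()
      }
    }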
@@ -575,6 +656,26 @@ class SparkContext(
return f
}
+ /**
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be an HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists and useExisting is set to true, then the
+ * existing directory will be used. Otherwise an exception will be thrown to
+ * prevent accidental overriding of checkpoint files in the existing directory.
+ */
+ def setCheckpointDir(dir: String, useExisting: Boolean = false) {
+ val path = new Path(dir)
+ val fs = path.getFileSystem(new Configuration())
+ if (!useExisting) {
+ if (fs.exists(path)) {
+ throw new Exception("Checkpoint directory '" + path + "' already exists.")
+ } else {
+ fs.mkdirs(path)
+ }
+ }
+ checkpointDir = Some(dir)
+ }
+
/** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */
def defaultParallelism: Int = taskScheduler.defaultParallelism
@@ -589,6 +690,11 @@ class SparkContext(
/** Register a new RDD, returning its RDD ID */
private[spark] def newRddId(): Int = nextRddId.getAndIncrement()
+
+ /** Called by MetadataCleaner to clean up the persistentRdds map periodically */
+ private[spark] def cleanup(cleanupTime: Long) {
+ persistentRdds.clearOldValues(cleanupTime)
+ }
}
/**
@@ -596,6 +702,7 @@ class SparkContext(
* various Spark features.
*/
object SparkContext {
+
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
@@ -606,6 +713,16 @@ object SparkContext {
def zero(initialValue: Int) = 0
}
+ implicit object LongAccumulatorParam extends AccumulatorParam[Long] {
+ def addInPlace(t1: Long, t2: Long) = t1 + t2
+ def zero(initialValue: Long) = 0l
+ }
+
+ implicit object FloatAccumulatorParam extends AccumulatorParam[Float] {
+ def addInPlace(t1: Float, t2: Float) = t1 + t2
+ def zero(initialValue: Float) = 0f
+ }
+
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
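The new LongAccumulatorParam and FloatAccumulatorParam let sc.accumulator(0L) and sc.accumulator(0f) pick up their implicit parameter automatically. A small sketch with a Long accumulator, assuming a local SparkContext (the implicits come into scope through `import spark.SparkContext._`):

    import spark.SparkContext
    import spark.SparkContext._

    object LongAccumulatorExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "long-accumulator-example")
        val totalChars = sc.accumulator(0L)   // resolves LongAccumulatorParam

        val words = sc.parallelize(Seq("spark", "checkpoint", "driver"), 2)
        words.foreach(word => totalChars += word.length.toLong)

        // Only the driver can read the accumulated value.
        println("total characters: " + totalChars.value)
        sc.stop()
      }
    }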
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 41441720a7..d2193ae72b 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -19,27 +19,23 @@ import spark.util.AkkaUtils
* SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
*/
class SparkEnv (
+ val executorId: String,
val actorSystem: ActorSystem,
val serializer: Serializer,
val closureSerializer: Serializer,
- val cacheTracker: CacheTracker,
+ val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
- val httpFileServer: HttpFileServer
+ val httpFileServer: HttpFileServer,
+ val sparkFilesDir: String
) {
- /** No-parameter constructor for unit tests. */
- def this() = {
- this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
- }
-
def stop() {
httpFileServer.stop()
mapOutputTracker.stop()
- cacheTracker.stop()
shuffleFetcher.stop()
broadcastManager.stop()
blockManager.stop()
@@ -63,17 +59,18 @@ object SparkEnv extends Logging {
}
def createFromSystemProperties(
+ executorId: String,
hostname: String,
port: Int,
- isMaster: Boolean,
- isLocal: Boolean
- ) : SparkEnv = {
+ isDriver: Boolean,
+ isLocal: Boolean): SparkEnv = {
+
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port)
- // Bit of a hack: If this is the master and our port was 0 (meaning bind to any free port),
- // figure out which port number Akka actually bound to and set spark.master.port to it.
- if (isMaster && port == 0) {
- System.setProperty("spark.master.port", boundPort.toString)
+ // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port),
+ // figure out which port number Akka actually bound to and set spark.driver.port to it.
+ if (isDriver && port == 0) {
+ System.setProperty("spark.driver.port", boundPort.toString)
}
val classLoader = Thread.currentThread.getContextClassLoader
@@ -87,23 +84,22 @@ object SparkEnv extends Logging {
val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
- val masterIp: String = System.getProperty("spark.master.host", "localhost")
- val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
+ val driverIp: String = System.getProperty("spark.driver.host", "localhost")
+ val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
val blockManagerMaster = new BlockManagerMaster(
- actorSystem, isMaster, isLocal, masterIp, masterPort)
- val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer)
+ actorSystem, isDriver, isLocal, driverIp, driverPort)
+ val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)
val connectionManager = blockManager.connectionManager
- val broadcastManager = new BroadcastManager(isMaster)
+ val broadcastManager = new BroadcastManager(isDriver)
val closureSerializer = instantiateClass[Serializer](
"spark.closure.serializer", "spark.JavaSerializer")
- val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager)
- blockManager.cacheTracker = cacheTracker
+ val cacheManager = new CacheManager(blockManager)
- val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster)
+ val mapOutputTracker = new MapOutputTracker(actorSystem, isDriver)
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
@@ -112,6 +108,15 @@ object SparkEnv extends Logging {
httpFileServer.initialize()
System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
+ // Set the sparkFiles directory, used when downloading dependencies. In local mode,
+ // this is a temporary directory; in distributed mode, this is the executor's current working
+ // directory.
+ val sparkFilesDir: String = if (isDriver) {
+ Utils.createTempDir().getAbsolutePath
+ } else {
+ "."
+ }
+
// Warn about deprecated spark.cache.class property
if (System.getProperty("spark.cache.class") != null) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@@ -119,15 +124,17 @@ object SparkEnv extends Logging {
}
new SparkEnv(
+ executorId,
actorSystem,
serializer,
closureSerializer,
- cacheTracker,
+ cacheManager,
mapOutputTracker,
shuffleFetcher,
broadcastManager,
blockManager,
connectionManager,
- httpFileServer)
+ httpFileServer,
+ sparkFilesDir)
}
}
diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java
new file mode 100644
index 0000000000..566aec622c
--- /dev/null
+++ b/core/src/main/scala/spark/SparkFiles.java
@@ -0,0 +1,25 @@
+package spark;
+
+import java.io.File;
+
+/**
+ * Resolves paths to files added through `SparkContext.addFile()`.
+ */
+public class SparkFiles {
+
+ private SparkFiles() {}
+
+ /**
+ * Get the absolute path of a file added through `SparkContext.addFile()`.
+ */
+ public static String get(String filename) {
+ return new File(getRootDirectory(), filename).getAbsolutePath();
+ }
+
+ /**
+ * Get the root directory that contains files added through `SparkContext.addFile()`.
+ */
+ public static String getRootDirectory() {
+ return SparkEnv.get().sparkFilesDir();
+ }
+}
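Because added files are no longer guaranteed to land in a task's working directory, tasks should resolve them through SparkFiles.get() rather than a relative path. A minimal sketch, assuming a local SparkContext; /tmp/lookup.txt is a placeholder file that must exist on the driver:

    import spark.{SparkContext, SparkFiles}

    object SparkFilesExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "spark-files-example")
        sc.addFile("/tmp/lookup.txt")   // placeholder path

        sc.parallelize(1 to 4, 2).foreach { _ =>
          // On every node, resolve the downloaded copy instead of assuming the CWD.
          println("reading from " + SparkFiles.get("lookup.txt"))
        }
        sc.stop()
      }
    }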
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
index d2746b26b3..eab85f85a2 100644
--- a/core/src/main/scala/spark/TaskContext.scala
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -5,8 +5,7 @@ import scala.collection.mutable.ArrayBuffer
class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
- @transient
- val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+ @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
// Add a callback function to be executed on task completion. An example use
// is for HadoopRDD to register a callback to close the input stream.
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 0e7007459d..28d643abca 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -1,7 +1,7 @@
package spark
import java.io._
-import java.net.{NetworkInterface, InetAddress, URL, URI}
+import java.net._
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
@@ -10,6 +10,9 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.io.Source
import com.google.common.io.Files
+import com.google.common.util.concurrent.ThreadFactoryBuilder
+import scala.Some
+import spark.serializer.SerializerInstance
/**
* Various utility methods used by Spark.
@@ -111,20 +114,6 @@ private object Utils extends Logging {
}
}
- /** Copy a file on the local file system */
- def copyFile(source: File, dest: File) {
- val in = new FileInputStream(source)
- val out = new FileOutputStream(dest)
- copyStream(in, out, true)
- }
-
- /** Download a file from a given URL to the local filesystem */
- def downloadFile(url: URL, localPath: String) {
- val in = url.openStream()
- val out = new FileOutputStream(localPath)
- Utils.copyStream(in, out, true)
- }
-
/**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
@@ -134,7 +123,7 @@ private object Utils extends Logging {
*/
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
- val tempDir = System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))
+ val tempDir = getLocalDir
val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
val targetFile = new File(targetDir, filename)
val uri = new URI(url)
@@ -201,7 +190,16 @@ private object Utils extends Logging {
Utils.execute(Seq("tar", "-xf", filename), targetDir)
}
// Make the file executable - That's necessary for scripts
- FileUtil.chmod(filename, "a+x")
+ FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
+ }
+
+ /**
+ * Get a temporary directory using Spark's spark.local.dir property, if set. This will always
+ * return a single directory, even though the spark.local.dir property might be a list of
+ * multiple paths.
+ */
+ def getLocalDir: String = {
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0)
}
/**
@@ -242,7 +240,8 @@ private object Utils extends Logging {
// Address resolves to something like 127.0.1.1, which happens on Debian; try to find
// a better address using the local network interfaces
for (ni <- NetworkInterface.getNetworkInterfaces) {
- for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && !addr.isLoopbackAddress) {
+ for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress &&
+ !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) {
// We've found an address that looks reasonable!
logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" +
" a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress +
@@ -277,48 +276,28 @@ private object Utils extends Logging {
customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
}
- /**
- * Returns a standard ThreadFactory except all threads are daemons.
- */
- private def newDaemonThreadFactory: ThreadFactory = {
- new ThreadFactory {
- def newThread(r: Runnable): Thread = {
- var t = Executors.defaultThreadFactory.newThread (r)
- t.setDaemon (true)
- return t
- }
- }
- }
+ private[spark] val daemonThreadFactory: ThreadFactory =
+ new ThreadFactoryBuilder().setDaemon(true).build()
/**
* Wrapper over newCachedThreadPool.
*/
- def newDaemonCachedThreadPool(): ThreadPoolExecutor = {
- var threadPool = Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory (newDaemonThreadFactory)
-
- return threadPool
- }
+ def newDaemonCachedThreadPool(): ThreadPoolExecutor =
+ Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
/**
* Return a string reporting how much time has passed since startTimeMs, in milliseconds. The
* parameter should be a start time in milliseconds.
*/
def getUsedTimeMs(startTimeMs: Long): String = {
- return " " + (System.currentTimeMillis - startTimeMs) + " ms "
+ return " " + (System.currentTimeMillis - startTimeMs) + " ms"
}
/**
* Wrapper over newFixedThreadPool.
*/
- def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
- var threadPool = Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory(newDaemonThreadFactory)
-
- return threadPool
- }
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor =
+ Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
/**
* Delete a file or directory and its contents recursively.
@@ -454,4 +433,25 @@ private object Utils extends Logging {
}
"%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
}
+
+ /**
+ * Try to find a free port to bind to on the local host. This should ideally never be needed,
+ * except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray)
+ * don't let users bind to port 0 and then figure out which free port they actually bound to.
+ * We work around this by binding a ServerSocket and immediately closing it. This is *not*
+ * guaranteed to work (another process may grab the port first), but it's the best we can do.
+ */
+ def findFreePort(): Int = {
+ val socket = new ServerSocket(0)
+ val portBound = socket.getLocalPort
+ socket.close()
+ portBound
+ }
+
+ /**
+ * Clone an object using a Spark serializer.
+ */
+ def clone[T](value: T, serializer: SerializerInstance): T = {
+ serializer.deserialize[T](serializer.serialize(value))
+ }
}
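The hand-rolled daemon ThreadFactory is replaced with Guava's ThreadFactoryBuilder. The following standalone sketch shows the same pattern outside Spark's private Utils object; the pool and thread-name format are arbitrary:

    import java.util.concurrent.Executors
    import com.google.common.util.concurrent.ThreadFactoryBuilder

    object DaemonPoolSketch {
      // Daemon threads do not keep the JVM alive once the main thread exits,
      // which is why Spark uses them for its internal pools.
      private val daemonFactory =
        new ThreadFactoryBuilder().setDaemon(true).setNameFormat("sketch-pool-%d").build()

      val pool = Executors.newCachedThreadPool(daemonFactory)

      def main(args: Array[String]) {
        pool.submit(new Runnable {
          def run() {
            val t = Thread.currentThread
            println(t.getName + " daemon=" + t.isDaemon)
          }
        })
        pool.shutdown()
      }
    }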
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index d15f6dd02f..60025b459c 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -9,9 +9,10 @@ import spark.api.java.JavaPairRDD._
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
import spark.partial.{PartialResult, BoundedDouble}
import spark.storage.StorageLevel
+import com.google.common.base.Optional
-trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
+trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] {
def wrapRDD(rdd: RDD[T]): This
implicit val classManifest: ClassManifest[T]
@@ -81,10 +82,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Return a new RDD by first applying a function to all elements of this
- * RDD, and then flattening the results.
+ * Part of the workaround for SPARK-668; called in PairFlatMapWorkaround.java.
*/
- def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
+ private[spark] def doFlatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
@@ -306,4 +306,33 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
JavaPairRDD.fromRDD(rdd.keyBy(f))
}
+
+ /**
+ * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
+ * directory set with SparkContext.setCheckpointDir() and all references to its parent
+ * RDDs will be removed. This function must be called before any job has been
+ * executed on this RDD. It is strongly recommended that this RDD is persisted in
+ * memory, otherwise saving it to a file will require recomputation.
+ */
+ def checkpoint() = rdd.checkpoint()
+
+ /**
+ * Return whether this RDD has been checkpointed or not
+ */
+ def isCheckpointed: Boolean = rdd.isCheckpointed
+
+ /**
+ * Gets the name of the file to which this RDD was checkpointed
+ */
+ def getCheckpointFile(): Optional[String] = {
+ rdd.getCheckpointFile match {
+ case Some(file) => Optional.of(file)
+ case _ => Optional.absent()
+ }
+ }
+
+ /** A description of this RDD and its recursive dependencies for debugging. */
+ def toDebugString(): String = {
+ rdd.toDebugString()
+ }
}
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index 88ab2846be..50b8970cd8 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -323,9 +323,10 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def getSparkHome(): Option[String] = sc.getSparkHome()
/**
- * Add a file to be downloaded into the working directory of this Spark job on every node.
+ * Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
- * filesystems), or an HTTP, HTTPS or FTP URI.
+ * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+ * use `SparkFiles.get(path)` to find its download location.
*/
def addFile(path: String) {
sc.addFile(path)
@@ -355,6 +356,40 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def clearFiles() {
sc.clearFiles()
}
+
+ /**
+ * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse.
+ */
+ def hadoopConfiguration(): Configuration = {
+ sc.hadoopConfiguration
+ }
+
+ /**
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be an HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists and useExisting is set to true, then the
+ * existing directory will be used. Otherwise an exception will be thrown to
+ * prevent accidental overriding of checkpoint files in the existing directory.
+ */
+ def setCheckpointDir(dir: String, useExisting: Boolean) {
+ sc.setCheckpointDir(dir, useExisting)
+ }
+
+ /**
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be an HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists, an exception will be thrown to prevent accidental
+ * overriding of checkpoint files.
+ */
+ def setCheckpointDir(dir: String) {
+ sc.setCheckpointDir(dir)
+ }
+
+ protected def checkpointFile[T](path: String): JavaRDD[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ new JavaRDD(sc.checkpointFile(path))
+ }
}
object JavaSparkContext {
diff --git a/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java
new file mode 100644
index 0000000000..68b6fd6622
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java
@@ -0,0 +1,20 @@
+package spark.api.java;
+
+import spark.api.java.JavaPairRDD;
+import spark.api.java.JavaRDDLike;
+import spark.api.java.function.PairFlatMapFunction;
+
+import java.io.Serializable;
+
+/**
+ * Workaround for SPARK-668.
+ */
+class PairFlatMapWorkaround<T> implements Serializable {
+ /**
+ * Return a new RDD by first applying a function to all elements of this
+ * RDD, and then flattening the results.
+ */
+ public <K, V> JavaPairRDD<K, V> flatMap(PairFlatMapFunction<T, K, V> f) {
+ return ((JavaRDDLike <T, ?>) this).doFlatMap(f);
+ }
+}
diff --git a/core/src/main/scala/spark/api/java/StorageLevels.java b/core/src/main/scala/spark/api/java/StorageLevels.java
index 722af3c06c..5e5845ac3a 100644
--- a/core/src/main/scala/spark/api/java/StorageLevels.java
+++ b/core/src/main/scala/spark/api/java/StorageLevels.java
@@ -17,4 +17,15 @@ public class StorageLevels {
public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2);
public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1);
public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2);
+
+ /**
+ * Create a new StorageLevel object.
+ * @param useDisk saved to disk, if true
+ * @param useMemory saved to memory, if true
+ * @param deserialized saved as deserialized objects, if true
+ * @param replication replication factor
+ */
+ public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) {
+ return StorageLevel.apply(useDisk, useMemory, deserialized, replication);
+ }
}
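StorageLevels.create() forwards its four arguments to StorageLevel.apply(), so Java callers can build levels that have no predefined constant. The equivalent from Scala, as a sketch against a local SparkContext (the replication factor only has an effect on a cluster):

    import spark.SparkContext
    import spark.storage.StorageLevel

    object CustomStorageLevelExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "storage-level-example")
        // (useDisk, useMemory, deserialized, replication) -- the same four arguments
        // StorageLevels.create() takes.
        val memoryTwoReplicas = StorageLevel(false, true, true, 2)

        val rdd = sc.parallelize(1 to 1000).persist(memoryTwoReplicas)
        println(rdd.count())
        sc.stop()
      }
    }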
diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
index 648d9402b0..519e310323 100644
--- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -6,8 +6,17 @@ import java.util.Arrays
/**
* A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
+ *
+ * Stores the unique id() of the Python-side partitioning function so that it is incorporated into
+ * equality comparisons. Correctness requires that the id is a unique identifier for the
+ * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning
+ * function). This can be ensured by using the Python id() function and maintaining a reference
+ * to the Python partitioning function so that its id() is not reused.
*/
-private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner {
+private[spark] class PythonPartitioner(
+ override val numPartitions: Int,
+ val pyPartitionFunctionId: Long)
+ extends Partitioner {
override def getPartition(key: Any): Int = {
if (key == null) {
@@ -32,7 +41,7 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends
override def equals(other: Any): Boolean = other match {
case h: PythonPartitioner =>
- h.numPartitions == numPartitions
+ h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
case _ =>
false
}
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index f431ef28d3..39758e94f4 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -1,7 +1,8 @@
package spark.api.python
import java.io._
-import java.util.{List => JList}
+import java.net._
+import java.util.{List => JList, ArrayList => JArrayList, Collections}
import scala.collection.JavaConversions._
import scala.io.Source
@@ -10,29 +11,28 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import spark.broadcast.Broadcast
import spark._
import spark.rdd.PipedRDD
-import java.util
private[spark] class PythonRDD[T: ClassManifest](
- parent: RDD[T],
- command: Seq[String],
- envVars: java.util.Map[String, String],
- preservePartitoning: Boolean,
- pythonExec: String,
- broadcastVars: java.util.List[Broadcast[Array[Byte]]])
- extends RDD[Array[Byte]](parent.context) {
+ parent: RDD[T],
+ command: Seq[String],
+ envVars: java.util.Map[String, String],
+ preservePartitoning: Boolean,
+ pythonExec: String,
+ broadcastVars: JList[Broadcast[Array[Byte]]],
+ accumulator: Accumulator[JList[Array[Byte]]])
+ extends RDD[Array[Byte]](parent) {
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
- preservePartitoning: Boolean, pythonExec: String,
- broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+ preservePartitoning: Boolean, pythonExec: String,
+ broadcastVars: JList[Broadcast[Array[Byte]]],
+ accumulator: Accumulator[JList[Array[Byte]]]) =
this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
- broadcastVars)
+ broadcastVars, accumulator)
- override def splits = parent.splits
-
- override val dependencies = List(new OneToOneDependency(parent))
+ override def getSplits = parent.splits
override val partitioner = if (preservePartitoning) parent.partitioner else None
@@ -67,6 +67,8 @@ private[spark] class PythonRDD[T: ClassManifest](
val dOut = new DataOutputStream(proc.getOutputStream)
// Split index
dOut.writeInt(split.index)
+ // sparkFilesDir
+ PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
// Broadcast variables
dOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
@@ -93,18 +95,36 @@ private[spark] class PythonRDD[T: ClassManifest](
// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(proc.getInputStream)
return new Iterator[Array[Byte]] {
- def next() = {
+ def next(): Array[Byte] = {
val obj = _nextObj
_nextObj = read()
obj
}
- private def read() = {
+ private def read(): Array[Byte] = {
try {
- val length = stream.readInt()
- val obj = new Array[Byte](length)
- stream.readFully(obj)
- obj
+ stream.readInt() match {
+ case length if length > 0 =>
+ val obj = new Array[Byte](length)
+ stream.readFully(obj)
+ obj
+ case -2 =>
+ // Signals that an exception has been thrown in python
+ val exLength = stream.readInt()
+ val obj = new Array[Byte](exLength)
+ stream.readFully(obj)
+ throw new PythonException(new String(obj))
+ case -1 =>
+ // We've finished the data section of the output, but we can still read some
+ // accumulator updates; let's do that, breaking when we get EOFException
+ while (true) {
+ val len2 = stream.readInt()
+ val update = new Array[Byte](len2)
+ stream.readFully(update)
+ accumulator += Collections.singletonList(update)
+ }
+ new Array[Byte](0)
+ }
} catch {
case eof: EOFException => {
val exitStatus = proc.waitFor()
@@ -126,14 +146,16 @@ private[spark] class PythonRDD[T: ClassManifest](
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
}
+/** Thrown for exceptions in user Python code. */
+private class PythonException(msg: String) extends Exception(msg)
+
/**
* Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
* This is used by PySpark's shuffle operations.
*/
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
- RDD[(Array[Byte], Array[Byte])](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
+ RDD[(Array[Byte], Array[Byte])](prev) {
+ override def getSplits = prev.splits
override def compute(split: Split, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (a, b)
@@ -238,11 +260,43 @@ private object Pickle {
val APPENDS: Byte = 'e'
}
-private class ExtractValue extends spark.api.java.function.Function[(Array[Byte],
- Array[Byte]), Array[Byte]] {
- override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2
-}
-
private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}
+
+/**
+ * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * collects a list of pickled strings that we pass to Python through a socket.
+ */
+class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+ extends AccumulatorParam[JList[Array[Byte]]] {
+
+ override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+
+ override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
+ : JList[Array[Byte]] = {
+ if (serverHost == null) {
+ // This happens on the worker node, where we just want to remember all the updates
+ val1.addAll(val2)
+ val1
+ } else {
+ // This happens on the master, where we pass the updates to Python through a socket
+ val socket = new Socket(serverHost, serverPort)
+ val in = socket.getInputStream
+ val out = new DataOutputStream(socket.getOutputStream)
+ out.writeInt(val2.size)
+ for (array <- val2) {
+ out.writeInt(array.length)
+ out.write(array)
+ }
+ out.flush()
+ // Wait for a byte from the Python side as an acknowledgement
+ val byteRead = in.read()
+ if (byteRead == -1) {
+ throw new SparkException("EOF reached before Python server acknowledged")
+ }
+ socket.close()
+ null
+ }
+ }
+}
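The worker protocol above frames each record with a 32-bit length, uses -2 to signal a pickled Python exception, and -1 to end the data section before streaming accumulator updates until EOF. The standalone sketch below reads a stream framed the same way; it illustrates the framing only and is not PySpark's actual reader:

    import java.io.{DataInputStream, EOFException}
    import scala.collection.mutable.ArrayBuffer

    object FramedStreamSketch {
      /** Returns (records, accumulator updates) read from a PythonRDD-style stream. */
      def readAll(in: DataInputStream): (Seq[Array[Byte]], Seq[Array[Byte]]) = {
        val records = new ArrayBuffer[Array[Byte]]
        val accumUpdates = new ArrayBuffer[Array[Byte]]
        def readPayload(length: Int): Array[Byte] = {
          val buf = new Array[Byte](length)
          in.readFully(buf)
          buf
        }
        var done = false
        while (!done) {
          in.readInt() match {
            case length if length > 0 => records += readPayload(length)
            case 0 => records += new Array[Byte](0)
            case -2 =>
              // Error: the next frame carries the exception message.
              throw new RuntimeException(new String(readPayload(in.readInt())))
            case -1 =>
              // End of data; drain accumulator updates until EOF.
              try { while (true) accumUpdates += readPayload(in.readInt()) }
              catch { case _: EOFException => }
              done = true
            case other => throw new RuntimeException("bad frame length: " + other)
          }
        }
        (records, accumUpdates)
      }
    }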
diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
index 386f505f2a..adcb2d2415 100644
--- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
@@ -31,7 +31,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
@transient var totalBlocks = -1
@transient var hasBlocks = new AtomicInteger(0)
- // Used ONLY by Master to track how many unique blocks have been sent out
+ // Used ONLY by the driver to track how many unique blocks have been sent out
@transient var sentBlocks = new AtomicInteger(0)
@transient var listenPortLock = new Object
@@ -42,7 +42,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
@transient var serveMR: ServeMultipleRequests = null
- // Used only in Master
+ // Used only in the driver
@transient var guideMR: GuideMultipleRequests = null
// Used only in Workers
@@ -99,14 +99,14 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
}
// Must always come AFTER listenPort is created
- val masterSource =
+ val driverSource =
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
hasBlocksBitVector.synchronized {
- masterSource.hasBlocksBitVector = hasBlocksBitVector
+ driverSource.hasBlocksBitVector = hasBlocksBitVector
}
// In the beginning, this is the only known source to Guide
- listOfSources += masterSource
+ listOfSources += driverSource
// Register with the Tracker
MultiTracker.registerBroadcast(id,
@@ -122,7 +122,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
case None =>
logInfo("Started reading broadcast variable " + id)
- // Initializing everything because Master will only send null/0 values
+ // Initializing everything because the driver will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables()
@@ -151,7 +151,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
}
}
- // Initialize variables in the worker node. Master sends everything as 0/null
+ // Initialize variables in the worker node. Driver sends everything as 0/null
private def initializeWorkerVariables() {
arrayOfBlocks = null
hasBlocksBitVector = null
@@ -248,7 +248,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
// Receive source information from Guide
var suitableSources =
oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
- logDebug("Received suitableSources from Master " + suitableSources)
+ logDebug("Received suitableSources from Driver " + suitableSources)
addToListOfSources(suitableSources)
@@ -532,7 +532,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
oosSource.writeObject(blockToAskFor)
oosSource.flush()
- // CHANGED: Master might send some other block than the one
+ // CHANGED: Driver might send some other block than the one
// requested to ensure fast spreading of all blocks.
val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
@@ -982,9 +982,9 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
// Receive which block to send
var blockToSend = ois.readObject.asInstanceOf[Int]
- // If it is master AND at least one copy of each block has not been
+ // If it is the driver AND at least one copy of each block has not been
// sent out already, MODIFY blockToSend
- if (MultiTracker.isMaster && sentBlocks.get < totalBlocks) {
+ if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) {
blockToSend = sentBlocks.getAndIncrement
}
@@ -1031,7 +1031,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
private[spark] class BitTorrentBroadcastFactory
extends BroadcastFactory {
- def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
+ def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new BitTorrentBroadcast[T](value_, isLocal, id)
diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala
index 2ffe7f741d..415bde5d67 100644
--- a/core/src/main/scala/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/spark/broadcast/Broadcast.scala
@@ -15,7 +15,7 @@ abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
}
private[spark]
-class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable {
+class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable {
private var initialized = false
private var broadcastFactory: BroadcastFactory = null
@@ -33,7 +33,7 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject
- broadcastFactory.initialize(isMaster)
+ broadcastFactory.initialize(isDriver)
initialized = true
}
@@ -49,5 +49,5 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl
def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
- def isMaster = isMaster_
+ def isDriver = _isDriver
}
diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
index ab6d302827..5c6184c3c7 100644
--- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
@@ -7,7 +7,7 @@ package spark.broadcast
* entire Spark job.
*/
private[spark] trait BroadcastFactory {
- def initialize(isMaster: Boolean): Unit
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T]
+ def initialize(isDriver: Boolean): Unit
+ def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
}
diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
index 7eb4ddb74f..7e30b8f7d2 100644
--- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
@@ -11,6 +11,7 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark._
import spark.storage.StorageLevel
+import util.{MetadataCleaner, TimeStampedHashSet}
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
@@ -47,7 +48,7 @@ extends Broadcast[T](id) with Logging with Serializable {
}
private[spark] class HttpBroadcastFactory extends BroadcastFactory {
- def initialize(isMaster: Boolean) { HttpBroadcast.initialize(isMaster) }
+ def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
@@ -64,12 +65,16 @@ private object HttpBroadcast extends Logging {
private var serverUri: String = null
private var server: HttpServer = null
- def initialize(isMaster: Boolean) {
+ private val files = new TimeStampedHashSet[String]
+ private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup)
+
+
+ def initialize(isDriver: Boolean) {
synchronized {
if (!initialized) {
bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
compress = System.getProperty("spark.broadcast.compress", "true").toBoolean
- if (isMaster) {
+ if (isDriver) {
createServer()
}
serverUri = System.getProperty("spark.httpBroadcast.uri")
@@ -85,11 +90,12 @@ private object HttpBroadcast extends Logging {
server = null
}
initialized = false
+ cleaner.cancel()
}
}
private def createServer() {
- broadcastDir = Utils.createTempDir()
+ broadcastDir = Utils.createTempDir(Utils.getLocalDir)
server = new HttpServer(broadcastDir)
server.start()
serverUri = server.uri
@@ -108,6 +114,7 @@ private object HttpBroadcast extends Logging {
val serOut = ser.serializeStream(out)
serOut.writeObject(value)
serOut.close()
+ files += file.getAbsolutePath
}
def read[T](id: Long): T = {
@@ -123,4 +130,21 @@ private object HttpBroadcast extends Logging {
serIn.close()
obj
}
+
+ def cleanup(cleanupTime: Long) {
+ val iterator = files.internalMap.entrySet().iterator()
+ while(iterator.hasNext) {
+ val entry = iterator.next()
+ val (file, time) = (entry.getKey, entry.getValue)
+ if (time < cleanupTime) {
+ try {
+ iterator.remove()
+ new File(file.toString).delete()
+ logInfo("Deleted broadcast file '" + file + "'")
+ } catch {
+ case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
+ }
+ }
+ }
+ }
}
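HttpBroadcast now records every broadcast file in a TimeStampedHashSet and relies on a MetadataCleaner to periodically drop entries older than a cutoff. A standalone sketch of the same timestamp-driven cleanup, using a plain ConcurrentHashMap in place of Spark's internal collection:

    import java.io.File
    import java.util.concurrent.ConcurrentHashMap

    object BroadcastFileCleanupSketch {
      // Absolute file path -> timestamp at which the file was recorded.
      private val files = new ConcurrentHashMap[String, Long]()

      def record(file: File) { files.put(file.getAbsolutePath, System.currentTimeMillis) }

      /** Delete every recorded file older than cleanupTime, mirroring HttpBroadcast.cleanup(). */
      def cleanup(cleanupTime: Long) {
        val iterator = files.entrySet().iterator()
        while (iterator.hasNext) {
          val entry = iterator.next()
          if (entry.getValue < cleanupTime) {
            iterator.remove()
            if (new File(entry.getKey).delete()) {
              println("Deleted broadcast file " + entry.getKey)
            }
          }
        }
      }
    }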
diff --git a/core/src/main/scala/spark/broadcast/MultiTracker.scala b/core/src/main/scala/spark/broadcast/MultiTracker.scala
index 5e76dedb94..3fd77af73f 100644
--- a/core/src/main/scala/spark/broadcast/MultiTracker.scala
+++ b/core/src/main/scala/spark/broadcast/MultiTracker.scala
@@ -23,25 +23,24 @@ extends Logging {
var ranGen = new Random
private var initialized = false
- private var isMaster_ = false
+ private var _isDriver = false
private var stopBroadcast = false
private var trackMV: TrackMultipleValues = null
- def initialize(isMaster__ : Boolean) {
+ def initialize(__isDriver: Boolean) {
synchronized {
if (!initialized) {
+ _isDriver = __isDriver
- isMaster_ = isMaster__
-
- if (isMaster) {
+ if (isDriver) {
trackMV = new TrackMultipleValues
trackMV.setDaemon(true)
trackMV.start()
- // Set masterHostAddress to the master's IP address for the slaves to read
- System.setProperty("spark.MultiTracker.MasterHostAddress", Utils.localIpAddress)
+ // Set DriverHostAddress to the driver's IP address for the slaves to read
+ System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
}
initialized = true
@@ -54,10 +53,10 @@ extends Logging {
}
// Load common parameters
- private var MasterHostAddress_ = System.getProperty(
- "spark.MultiTracker.MasterHostAddress", "")
- private var MasterTrackerPort_ = System.getProperty(
- "spark.broadcast.masterTrackerPort", "11111").toInt
+ private var DriverHostAddress_ = System.getProperty(
+ "spark.MultiTracker.DriverHostAddress", "")
+ private var DriverTrackerPort_ = System.getProperty(
+ "spark.broadcast.driverTrackerPort", "11111").toInt
private var BlockSize_ = System.getProperty(
"spark.broadcast.blockSize", "4096").toInt * 1024
private var MaxRetryCount_ = System.getProperty(
@@ -91,11 +90,11 @@ extends Logging {
private var EndGameFraction_ = System.getProperty(
"spark.broadcast.endGameFraction", "0.95").toDouble
- def isMaster = isMaster_
+ def isDriver = _isDriver
// Common config params
- def MasterHostAddress = MasterHostAddress_
- def MasterTrackerPort = MasterTrackerPort_
+ def DriverHostAddress = DriverHostAddress_
+ def DriverTrackerPort = DriverTrackerPort_
def BlockSize = BlockSize_
def MaxRetryCount = MaxRetryCount_
@@ -123,7 +122,7 @@ extends Logging {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
- serverSocket = new ServerSocket(MasterTrackerPort)
+ serverSocket = new ServerSocket(DriverTrackerPort)
logInfo("TrackMultipleValues started at " + serverSocket)
try {
@@ -235,7 +234,7 @@ extends Logging {
try {
// Connect to the tracker to find out GuideInfo
clientSocketToTracker =
- new Socket(MultiTracker.MasterHostAddress, MultiTracker.MasterTrackerPort)
+ new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort)
oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
@@ -276,7 +275,7 @@ extends Logging {
}
def registerBroadcast(id: Long, gInfo: SourceInfo) {
- val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
+ val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
@@ -303,7 +302,7 @@ extends Logging {
}
def unregisterBroadcast(id: Long) {
- val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
+ val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
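
The renamed tracker settings above are ordinary system properties with hard-coded defaults. As a quick reference, the stand-alone sketch below reads them the same way MultiTracker does; the object name is a placeholder and nothing here is Spark API.

// Illustrative only: the property names and defaults come from the diff above.
object MultiTrackerConfigSketch {
  def main(args: Array[String]) {
    val driverHost  = System.getProperty("spark.MultiTracker.DriverHostAddress", "")
    val trackerPort = System.getProperty("spark.broadcast.driverTrackerPort", "11111").toInt
    // Block size is configured in KB and multiplied by 1024 inside MultiTracker.
    val blockSize   = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024
    println("driver=" + driverHost + " port=" + trackerPort + " blockSize=" + blockSize + " bytes")
  }
}
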
diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
index f573512835..c55c476117 100644
--- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
@@ -98,7 +98,7 @@ extends Broadcast[T](id) with Logging with Serializable {
case None =>
logInfo("Started reading broadcast variable " + id)
- // Initializing everything because Master will only send null/0 values
+ // Initializing everything because Driver will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables()
@@ -157,55 +157,55 @@ extends Broadcast[T](id) with Logging with Serializable {
listenPortLock.synchronized { listenPortLock.wait() }
}
- var clientSocketToMaster: Socket = null
- var oosMaster: ObjectOutputStream = null
- var oisMaster: ObjectInputStream = null
+ var clientSocketToDriver: Socket = null
+ var oosDriver: ObjectOutputStream = null
+ var oisDriver: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = MultiTracker.MaxRetryCount
do {
- // Connect to Master and send this worker's Information
- clientSocketToMaster = new Socket(MultiTracker.MasterHostAddress, gInfo.listenPort)
- oosMaster = new ObjectOutputStream(clientSocketToMaster.getOutputStream)
- oosMaster.flush()
- oisMaster = new ObjectInputStream(clientSocketToMaster.getInputStream)
+ // Connect to Driver and send this worker's Information
+ clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
+ oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
+ oosDriver.flush()
+ oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)
- logDebug("Connected to Master's guiding object")
+ logDebug("Connected to Driver's guiding object")
// Send local source information
- oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
- oosMaster.flush()
+ oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
+ oosDriver.flush()
- // Receive source information from Master
- var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
+ // Receive source information from Driver
+ var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
totalBytes = sourceInfo.totalBytes
- logDebug("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
+ logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
val time = (System.nanoTime - start) / 1e9
- // Updating some statistics in sourceInfo. Master will be using them later
+ // Updating some statistics in sourceInfo. Driver will be using them later
if (!receptionSucceeded) {
sourceInfo.receptionFailed = true
}
- // Send back statistics to the Master
- oosMaster.writeObject(sourceInfo)
+ // Send back statistics to the Driver
+ oosDriver.writeObject(sourceInfo)
- if (oisMaster != null) {
- oisMaster.close()
+ if (oisDriver != null) {
+ oisDriver.close()
}
- if (oosMaster != null) {
- oosMaster.close()
+ if (oosDriver != null) {
+ oosDriver.close()
}
- if (clientSocketToMaster != null) {
- clientSocketToMaster.close()
+ if (clientSocketToDriver != null) {
+ clientSocketToDriver.close()
}
retriesLeft -= 1
@@ -552,7 +552,7 @@ extends Broadcast[T](id) with Logging with Serializable {
}
private def sendObject() {
- // Wait till receiving the SourceInfo from Master
+ // Wait till receiving the SourceInfo from Driver
while (totalBlocks == -1) {
totalBlocksLock.synchronized { totalBlocksLock.wait() }
}
@@ -576,7 +576,7 @@ extends Broadcast[T](id) with Logging with Serializable {
private[spark] class TreeBroadcastFactory
extends BroadcastFactory {
- def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
+ def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TreeBroadcast[T](value_, isLocal, id)
diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala
index 457122745b..35f40c6e91 100644
--- a/core/src/main/scala/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/spark/deploy/DeployMessage.scala
@@ -4,7 +4,6 @@ import spark.deploy.ExecutorState.ExecutorState
import spark.deploy.master.{WorkerInfo, JobInfo}
import spark.deploy.worker.ExecutorRunner
import scala.collection.immutable.List
-import scala.collection.mutable.HashMap
private[spark] sealed trait DeployMessage extends Serializable
@@ -42,7 +41,8 @@ private[spark] case class LaunchExecutor(
execId: Int,
jobDesc: JobDescription,
cores: Int,
- memory: Int)
+ memory: Int,
+ sparkHome: String)
extends DeployMessage
diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala
index 20879c5f11..7160fc05fc 100644
--- a/core/src/main/scala/spark/deploy/JobDescription.scala
+++ b/core/src/main/scala/spark/deploy/JobDescription.scala
@@ -4,7 +4,8 @@ private[spark] class JobDescription(
val name: String,
val cores: Int,
val memoryPerSlave: Int,
- val command: Command)
+ val command: Command,
+ val sparkHome: String)
extends Serializable {
val user = System.getProperty("user.name", "<unknown>")
diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
index 4211d80596..2836574ecb 100644
--- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
@@ -9,8 +9,14 @@ import spark.{Logging, Utils}
import scala.collection.mutable.ArrayBuffer
+/**
+ * Testing class that creates a Spark standalone cluster in-process (that is, running the
+ * spark.deploy.master.Master and spark.deploy.worker.Worker actors in the same JVM). Executors launched
+ * by the Workers still run in separate JVMs. This can be used to test distributed operation and
+ * fault recovery without spinning up a lot of processes.
+ */
private[spark]
-class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging {
+class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
val localIpAddress = Utils.localIpAddress
@@ -19,33 +25,28 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int)
var masterPort : Int = _
var masterUrl : String = _
- val slaveActorSystems = ArrayBuffer[ActorSystem]()
- val slaveActors = ArrayBuffer[ActorRef]()
+ val workerActorSystems = ArrayBuffer[ActorSystem]()
+ val workerActors = ArrayBuffer[ActorRef]()
def start() : String = {
- logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.")
+ logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
masterActorSystem = actorSystem
masterUrl = "spark://" + localIpAddress + ":" + masterPort
- val actor = masterActorSystem.actorOf(
+ masterActor = masterActorSystem.actorOf(
Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
- masterActor = actor
/* Start the Slaves */
- for (slaveNum <- 1 to numSlaves) {
- /* We can pretend to test distributed stuff by giving the slaves distinct hostnames.
- All of 127/8 should be a loopback, we use 127.100.*.* in hopes that it is
- sufficiently distinctive. */
- val slaveIpAddress = "127.100.0." + (slaveNum % 256)
+ for (workerNum <- 1 to numWorkers) {
val (actorSystem, boundPort) =
- AkkaUtils.createActorSystem("sparkWorker" + slaveNum, slaveIpAddress, 0)
- slaveActorSystems += actorSystem
+ AkkaUtils.createActorSystem("sparkWorker" + workerNum, localIpAddress, 0)
+ workerActorSystems += actorSystem
val actor = actorSystem.actorOf(
- Props(new Worker(slaveIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)),
- name = "Worker")
- slaveActors += actor
+ Props(new Worker(localIpAddress, boundPort, 0, coresPerWorker, memoryPerWorker, masterUrl)),
+ name = "Worker")
+ workerActors += actor
}
return masterUrl
@@ -53,9 +54,9 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int)
def stop() {
logInfo("Shutting down local Spark cluster.")
- // Stop the slaves before the master so they don't get upset that it disconnected
- slaveActorSystems.foreach(_.shutdown())
- slaveActorSystems.foreach(_.awaitTermination())
+ // Stop the workers before the master so they don't get upset that it disconnected
+ workerActorSystems.foreach(_.shutdown())
+ workerActorSystems.foreach(_.awaitTermination())
masterActorSystem.shutdown()
masterActorSystem.awaitTermination()
}
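
Given the constructor and the start()/stop() signatures visible above, a hypothetical test harness could drive LocalSparkCluster roughly as follows. The class is private[spark], so the sketch assumes it lives inside the spark.deploy package; the SparkContext wiring is omitted because it is not part of this diff.

package spark.deploy

// Hypothetical usage sketch based only on the signatures shown in the diff above.
object LocalSparkClusterSketch {
  def main(args: Array[String]) {
    val cluster = new LocalSparkCluster(numWorkers = 2, coresPerWorker = 1, memoryPerWorker = 512)
    val masterUrl = cluster.start()   // e.g. "spark://<local-ip>:<port>"
    try {
      // A job could be pointed at masterUrl here to exercise distributed
      // operation and fault recovery without spawning extra daemon processes.
      println("Local standalone cluster running at " + masterUrl)
    } finally {
      cluster.stop()                  // shuts down the workers first, then the master
    }
  }
}
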
diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala
index da6abcc9c2..7035f4b394 100644
--- a/core/src/main/scala/spark/deploy/client/ClientListener.scala
+++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala
@@ -12,7 +12,7 @@ private[spark] trait ClientListener {
def disconnected(): Unit
- def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int): Unit
+ def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit
- def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit
+ def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
}
diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala
index 57a7e123b7..8764c400e2 100644
--- a/core/src/main/scala/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/spark/deploy/client/TestClient.scala
@@ -25,7 +25,7 @@ private[spark] object TestClient {
val url = args(0)
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0)
val desc = new JobDescription(
- "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()))
+ "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home")
val listener = new TestListener
val client = new Client(actorSystem, url, desc, listener)
client.start()
diff --git a/core/src/main/scala/spark/deploy/master/JobInfo.scala b/core/src/main/scala/spark/deploy/master/JobInfo.scala
index 130b031a2a..a274b21c34 100644
--- a/core/src/main/scala/spark/deploy/master/JobInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/JobInfo.scala
@@ -10,7 +10,7 @@ private[spark] class JobInfo(
val id: String,
val desc: JobDescription,
val submitDate: Date,
- val actor: ActorRef)
+ val driver: ActorRef)
{
var state = JobState.WAITING
var executors = new mutable.HashMap[Int, ExecutorInfo]
diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala
index 6ecebe626a..c618e87cdd 100644
--- a/core/src/main/scala/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/spark/deploy/master/Master.scala
@@ -88,7 +88,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
execOption match {
case Some(exec) => {
exec.state = state
- exec.job.actor ! ExecutorUpdated(execId, state, message, exitStatus)
+ exec.job.driver ! ExecutorUpdated(execId, state, message, exitStatus)
if (ExecutorState.isFinished(state)) {
val jobInfo = idToJob(jobId)
// Remove this executor from the worker and job
@@ -97,14 +97,12 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
exec.worker.removeExecutor(exec)
// Only retry certain number of times so we don't go into an infinite loop.
- if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) {
+ if (jobInfo.incrementRetryCount < JobState.MAX_NUM_RETRY) {
schedule()
} else {
- val e = new SparkException("Job %s wth ID %s failed %d times.".format(
+ logError("Job %s with ID %s failed %d times, removing it".format(
jobInfo.desc.name, jobInfo.id, jobInfo.retryCount))
- logError(e.getMessage, e)
- throw e
- //System.exit(1)
+ removeJob(jobInfo)
}
}
}
@@ -173,7 +171,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
for (pos <- 0 until numUsable) {
if (assigned(pos) > 0) {
val exec = job.addExecutor(usableWorkers(pos), assigned(pos))
- launchExecutor(usableWorkers(pos), exec)
+ launchExecutor(usableWorkers(pos), exec, job.desc.sparkHome)
job.state = JobState.RUNNING
}
}
@@ -186,7 +184,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
val coresToUse = math.min(worker.coresFree, job.coresLeft)
if (coresToUse > 0) {
val exec = job.addExecutor(worker, coresToUse)
- launchExecutor(worker, exec)
+ launchExecutor(worker, exec, job.desc.sparkHome)
job.state = JobState.RUNNING
}
}
@@ -195,11 +193,11 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
}
- def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) {
+ def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
- worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory)
- exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
+ worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome)
+ exec.job.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
}
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
@@ -221,19 +219,19 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
actorToWorker -= worker.actor
addressToWorker -= worker.actor.path.address
for (exec <- worker.executors.values) {
- exec.job.actor ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None)
+ exec.job.driver ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None)
exec.job.executors -= exec.id
}
}
- def addJob(desc: JobDescription, actor: ActorRef): JobInfo = {
+ def addJob(desc: JobDescription, driver: ActorRef): JobInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
- val job = new JobInfo(now, newJobId(date), desc, date, actor)
+ val job = new JobInfo(now, newJobId(date), desc, date, driver)
jobs += job
idToJob(job.id) = job
- actorToJob(sender) = job
- addressToJob(sender.path.address) = job
+ actorToJob(driver) = job
+ addressToJob(driver.path.address) = job
return job
}
@@ -242,8 +240,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
logInfo("Removing job " + job.id)
jobs -= job
idToJob -= job.id
- actorToJob -= job.actor
- addressToWorker -= job.actor.path.address
+ actorToJob -= job.driver
+ addressToWorker -= job.driver.path.address
completedJobs += job // Remember it in our history
waitingJobs -= job
for (exec <- job.executors.values) {
diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
index 458ee2d665..529f72e9da 100644
--- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
@@ -14,12 +14,15 @@ import cc.spray.typeconversion.SprayJsonSupport._
import spark.deploy._
import spark.deploy.JsonProtocol._
+/**
+ * Web UI server for the standalone master.
+ */
private[spark]
class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/master/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
- implicit val timeout = Timeout(1 seconds)
+ implicit val timeout = Timeout(10 seconds)
val handler = {
get {
@@ -42,13 +45,9 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) =>
val future = master ? RequestMasterState
val jobInfo = for (masterState <- future.mapTo[MasterState]) yield {
- masterState.activeJobs.find(_.id == jobId) match {
- case Some(job) => job
- case _ => masterState.completedJobs.find(_.id == jobId) match {
- case Some(job) => job
- case _ => null
- }
- }
+ masterState.activeJobs.find(_.id == jobId).getOrElse({
+ masterState.completedJobs.find(_.id == jobId).getOrElse(null)
+ })
}
respondWithMediaType(MediaTypes.`application/json`) { ctx =>
ctx.complete(jobInfo.mapTo[JobInfo])
@@ -58,14 +57,10 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
val future = master ? RequestMasterState
future.map { state =>
val masterState = state.asInstanceOf[MasterState]
-
- masterState.activeJobs.find(_.id == jobId) match {
- case Some(job) => spark.deploy.master.html.job_details.render(job)
- case _ => masterState.completedJobs.find(_.id == jobId) match {
- case Some(job) => spark.deploy.master.html.job_details.render(job)
- case _ => null
- }
- }
+ val job = masterState.activeJobs.find(_.id == jobId).getOrElse({
+ masterState.completedJobs.find(_.id == jobId).getOrElse(null)
+ })
+ spark.deploy.master.html.job_details.render(job)
}
}
}
@@ -76,5 +71,4 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
getFromResourceDirectory(RESOURCE_DIR)
}
}
-
}
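
The handler rewrite above collapses the nested match into getOrElse chains that look first in activeJobs and then in completedJobs. Below is a small stand-alone sketch of that two-step lookup using Option.orElse and a placeholder JobRecord type instead of Spark's JobInfo; returning Option rather than null is a variation for illustration, not what the diff itself does.

case class JobRecord(id: String, name: String)

object JobLookupSketch {
  val activeJobs    = Seq(JobRecord("job-1", "etl"))
  val completedJobs = Seq(JobRecord("job-0", "backfill"))

  // Prefer the active list, fall back to the completed list, else None.
  def find(jobId: String): Option[JobRecord] =
    activeJobs.find(_.id == jobId).orElse(completedJobs.find(_.id == jobId))

  def main(args: Array[String]) {
    println(find("job-0"))  // Some(JobRecord(job-0,backfill))
    println(find("job-9"))  // None
  }
}
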
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index e910416235..4ef637090c 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -65,9 +65,9 @@ private[spark] class ExecutorRunner(
}
}
- /** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */
+ /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */
def substituteVariables(argument: String): String = argument match {
- case "{{SLAVEID}}" => workerId
+ case "{{EXECUTOR_ID}}" => execId.toString
case "{{HOSTNAME}}" => hostname
case "{{CORES}}" => cores.toString
case other => other
@@ -106,11 +106,6 @@ private[spark] class ExecutorRunner(
throw new IOException("Failed to create directory " + executorDir)
}
- // Download the files it depends on into it (disabled for now)
- //for (url <- jobDesc.fileUrls) {
- // fetchFile(url, executorDir)
- //}
-
// Launch the process
val command = buildCommandSeq()
val builder = new ProcessBuilder(command: _*).directory(executorDir)
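
The substituteVariables change above swaps the {{SLAVEID}} placeholder for {{EXECUTOR_ID}} and expands it to the numeric executor id. A minimal sketch of how such a template expansion behaves; the placeholder names match the diff, but the values and command are made up.

// Stand-alone sketch, not ExecutorRunner itself.
object SubstituteVariablesSketch {
  val execId   = 7
  val hostname = "worker-1.example.com"
  val cores    = 4

  def substituteVariables(argument: String): String = argument match {
    case "{{EXECUTOR_ID}}" => execId.toString
    case "{{HOSTNAME}}"    => hostname
    case "{{CORES}}"       => cores.toString
    case other             => other
  }

  def main(args: Array[String]) {
    val command = Seq("run-executor", "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
    println(command.map(substituteVariables).mkString(" "))
    // prints: run-executor 7 worker-1.example.com 4
  }
}
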
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 7c9e588ea2..8b41620d98 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -119,10 +119,10 @@ private[spark] class Worker(
logError("Worker registration failed: " + message)
System.exit(1)
- case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_) =>
+ case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) =>
logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name))
val manager = new ExecutorRunner(
- jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, sparkHome, workDir)
+ jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
executors(jobId + "/" + execId) = manager
manager.start()
coresUsed += cores_
@@ -134,7 +134,9 @@ private[spark] class Worker(
val fullId = jobId + "/" + execId
if (ExecutorState.isFinished(state)) {
val executor = executors(fullId)
- logInfo("Executor " + fullId + " finished with state " + state)
+ logInfo("Executor " + fullId + " finished with state " + state +
+ message.map(" message " + _).getOrElse("") +
+ exitStatus.map(" exitStatus " + _).getOrElse(""))
finishedExecutors(fullId) = executor
executors -= fullId
coresUsed -= executor.cores
@@ -143,9 +145,13 @@ private[spark] class Worker(
case KillExecutor(jobId, execId) =>
val fullId = jobId + "/" + execId
- val executor = executors(fullId)
- logInfo("Asked to kill executor " + fullId)
- executor.kill()
+ executors.get(fullId) match {
+ case Some(executor) =>
+ logInfo("Asked to kill executor " + fullId)
+ executor.kill()
+ case None =>
+ logInfo("Asked to kill unknown executor " + fullId)
+ }
case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
masterDisconnected()
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
index f9489d99fc..ef81f072a3 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
@@ -13,12 +13,15 @@ import cc.spray.typeconversion.SprayJsonSupport._
import spark.deploy.{WorkerState, RequestWorkerState}
import spark.deploy.JsonProtocol._
+/**
+ * Web UI server for the standalone worker.
+ */
private[spark]
class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/worker/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
- implicit val timeout = Timeout(1 seconds)
+ implicit val timeout = Timeout(10 seconds)
val handler = {
get {
@@ -50,5 +53,4 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
getFromResourceDirectory(RESOURCE_DIR)
}
}
-
}
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index 2552958d27..bd21ba719a 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -30,7 +30,7 @@ private[spark] class Executor extends Logging {
initLogging()
- def initialize(slaveHostname: String, properties: Seq[(String, String)]) {
+ def initialize(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) {
// Make sure the local hostname we report matches the cluster scheduler's name for this host
Utils.setCustomHostname(slaveHostname)
@@ -64,7 +64,7 @@ private[spark] class Executor extends Logging {
)
// Initialize Spark environment (using system properties read above)
- env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
+ env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
SparkEnv.set(env)
// Start worker thread pool
@@ -159,22 +159,24 @@ private[spark] class Executor extends Logging {
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
- // Fetch missing dependencies
- for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
- currentFiles(name) = timestamp
- }
- for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
- currentJars(name) = timestamp
- // Add it to our class loader
- val localName = name.split("/").last
- val url = new File(".", localName).toURI.toURL
- if (!urlClassLoader.getURLs.contains(url)) {
- logInfo("Adding " + url + " to class loader")
- urlClassLoader.addURL(url)
+ synchronized {
+ // Fetch missing dependencies
+ for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentFiles(name) = timestamp
+ }
+ for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val localName = name.split("/").last
+ val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
+ if (!urlClassLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ urlClassLoader.addURL(url)
+ }
}
}
}
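
The updateDependencies change above does two things: it serializes the whole fetch under a lock, so concurrent tasks on one executor do not race on the same download, and it fetches into SparkFiles.getRootDirectory instead of the working directory. A small sketch of the timestamp-guarded bookkeeping; fetchFile here is a stub standing in for Utils.fetchFile, and the destination directory is an ordinary parameter rather than SparkFiles state.

import scala.collection.mutable.HashMap

object DependencyFetchSketch {
  private val currentFiles = new HashMap[String, Long]()

  private def fetchFile(name: String, destDir: String) {
    println("Fetching " + name + " into " + destDir)
  }

  // Synchronized so two tasks arriving at once cannot both decide to fetch.
  def update(newFiles: Seq[(String, Long)], destDir: String) = synchronized {
    for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
      fetchFile(name, destDir)
      currentFiles(name) = timestamp
    }
  }
}
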
diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala
index eeab3959c6..818d6d1dda 100644
--- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala
@@ -29,9 +29,14 @@ private[spark] class MesosExecutorBackend(executor: Executor)
executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) {
+ logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
this.driver = driver
val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
- executor.initialize(slaveInfo.getHostname, properties)
+ executor.initialize(
+ executorInfo.getExecutorId.getValue,
+ slaveInfo.getHostname,
+ properties
+ )
}
override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
index 915f71ba9f..e45288ff53 100644
--- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
@@ -8,47 +8,44 @@ import akka.actor.{ActorRef, Actor, Props}
import java.util.concurrent.{TimeUnit, ThreadPoolExecutor, SynchronousQueue}
import akka.remote.RemoteClientLifeCycleEvent
import spark.scheduler.cluster._
-import spark.scheduler.cluster.RegisteredSlave
+import spark.scheduler.cluster.RegisteredExecutor
import spark.scheduler.cluster.LaunchTask
-import spark.scheduler.cluster.RegisterSlaveFailed
-import spark.scheduler.cluster.RegisterSlave
+import spark.scheduler.cluster.RegisterExecutorFailed
+import spark.scheduler.cluster.RegisterExecutor
private[spark] class StandaloneExecutorBackend(
executor: Executor,
- masterUrl: String,
- slaveId: String,
+ driverUrl: String,
+ executorId: String,
hostname: String,
cores: Int)
extends Actor
with ExecutorBackend
with Logging {
- val threadPool = new ThreadPoolExecutor(
- 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
-
- var master: ActorRef = null
+ var driver: ActorRef = null
override def preStart() {
try {
- logInfo("Connecting to master: " + masterUrl)
- master = context.actorFor(masterUrl)
- master ! RegisterSlave(slaveId, hostname, cores)
+ logInfo("Connecting to driver: " + driverUrl)
+ driver = context.actorFor(driverUrl)
+ driver ! RegisterExecutor(executorId, hostname, cores)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
+ context.watch(driver) // Doesn't work with remote actors, but useful for testing
} catch {
case e: Exception =>
- logError("Failed to connect to master", e)
+ logError("Failed to connect to driver", e)
System.exit(1)
}
}
override def receive = {
- case RegisteredSlave(sparkProperties) =>
- logInfo("Successfully registered with master")
- executor.initialize(hostname, sparkProperties)
+ case RegisteredExecutor(sparkProperties) =>
+ logInfo("Successfully registered with driver")
+ executor.initialize(executorId, hostname, sparkProperties)
- case RegisterSlaveFailed(message) =>
+ case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
System.exit(1)
@@ -58,24 +55,24 @@ private[spark] class StandaloneExecutorBackend(
}
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
- master ! StatusUpdate(slaveId, taskId, state, data)
+ driver ! StatusUpdate(executorId, taskId, state, data)
}
}
private[spark] object StandaloneExecutorBackend {
- def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) {
+ def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
val actor = actorSystem.actorOf(
- Props(new StandaloneExecutorBackend(new Executor, masterUrl, slaveId, hostname, cores)),
+ Props(new StandaloneExecutorBackend(new Executor, driverUrl, executorId, hostname, cores)),
name = "Executor")
actorSystem.awaitTermination()
}
def main(args: Array[String]) {
if (args.length != 4) {
- System.err.println("Usage: StandaloneExecutorBackend <master> <slaveId> <hostname> <cores>")
+ System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores>")
System.exit(1)
}
run(args(0), args(1), args(2), args(3).toInt)
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index c193bf7c8d..cd5b7d57f3 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -12,7 +12,14 @@ import java.net._
private[spark]
-abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging {
+abstract class Connection(val channel: SocketChannel, val selector: Selector,
+ val remoteConnectionManagerId: ConnectionManagerId) extends Logging {
+ def this(channel_ : SocketChannel, selector_ : Selector) = {
+ this(channel_, selector_,
+ ConnectionManagerId.fromSocketAddress(
+ channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
+ ))
+ }
channel.configureBlocking(false)
channel.socket.setTcpNoDelay(true)
@@ -25,7 +32,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
val remoteAddress = getRemoteAddress()
- val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
def key() = channel.keyFor(selector)
@@ -103,8 +109,9 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
}
-private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector)
-extends Connection(SocketChannel.open, selector_) {
+private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
+ remoteId_ : ConnectionManagerId)
+extends Connection(SocketChannel.open, selector_, remoteId_) {
class Outbox(fair: Int = 0) {
val messages = new Queue[Message]()
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index 36c01ad629..c7f226044d 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -52,9 +52,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
val sendMessageRequests = new Queue[(Message, SendingConnection)]
- implicit val futureExecContext = ExecutionContext.fromExecutor(
- Executors.newCachedThreadPool(DaemonThreadFactory))
-
+ implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
+
var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
serverChannel.configureBlocking(false)
@@ -300,7 +299,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
- val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector))
+ val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId,
+ new SendingConnection(inetSocketAddress, selector, connectionManagerId))
newConnection
}
val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
index 42f46e06ed..24b4909380 100644
--- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala
+++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
@@ -32,7 +32,7 @@ private[spark] class ApproximateActionListener[T, U, R](
if (finishedTasks == totalTasks) {
// If we had already returned a PartialResult, set its final value
resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
- // Notify any waiting thread that may have called getResult
+ // Notify any waiting thread that may have called awaitResult
this.notifyAll()
}
}
@@ -49,7 +49,7 @@ private[spark] class ApproximateActionListener[T, U, R](
* Waits for up to timeout milliseconds since the listener was created and then returns a
* PartialResult with the result so far. This may be complete if the whole job is done.
*/
- def getResult(): PartialResult[R] = synchronized {
+ def awaitResult(): PartialResult[R] = synchronized {
val finishTime = startTime + timeout
while (true) {
val time = System.currentTimeMillis()
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index f98528a183..2c022f88e0 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -1,9 +1,7 @@
package spark.rdd
import scala.collection.mutable.HashMap
-
-import spark.{Dependency, RDD, SparkContext, SparkEnv, Split, TaskContext}
-
+import spark.{RDD, SparkContext, SparkEnv, Split, TaskContext}
private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
val index = idx
@@ -11,22 +9,20 @@ private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split
private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
- extends RDD[T](sc) {
+ extends RDD[T](sc, Nil) {
- @transient
- val splits_ = (0 until blockIds.size).map(i => {
+ @transient var splits_ : Array[Split] = (0 until blockIds.size).map(i => {
new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
}).toArray
- @transient
- lazy val locations_ = {
+ @transient lazy val locations_ = {
val blockManager = SparkEnv.get.blockManager
/*val locations = blockIds.map(id => blockManager.getLocations(id))*/
val locations = blockManager.getLocations(blockIds)
HashMap(blockIds.zip(locations):_*)
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(split: Split, context: TaskContext): Iterator[T] = {
val blockManager = SparkEnv.get.blockManager
@@ -38,9 +34,11 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
}
}
- override def preferredLocations(split: Split) =
+ override def getPreferredLocations(split: Split) =
locations_(split.asInstanceOf[BlockRDDSplit].blockId)
- override val dependencies: List[Dependency[_]] = Nil
+ override def clearDependencies() {
+ splits_ = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index 4a7e5f3d06..0f9ca06531 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -1,37 +1,51 @@
package spark.rdd
-import spark.{NarrowDependency, RDD, SparkContext, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+import spark._
private[spark]
-class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable {
+class CartesianSplit(
+ idx: Int,
+ @transient rdd1: RDD[_],
+ @transient rdd2: RDD[_],
+ s1Index: Int,
+ s2Index: Int
+ ) extends Split {
+ var s1 = rdd1.splits(s1Index)
+ var s2 = rdd2.splits(s2Index)
override val index: Int = idx
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ s1 = rdd1.splits(s1Index)
+ s2 = rdd2.splits(s2Index)
+ oos.defaultWriteObject()
+ }
}
private[spark]
class CartesianRDD[T: ClassManifest, U:ClassManifest](
sc: SparkContext,
- rdd1: RDD[T],
- rdd2: RDD[U])
- extends RDD[Pair[T, U]](sc)
+ var rdd1 : RDD[T],
+ var rdd2 : RDD[U])
+ extends RDD[Pair[T, U]](sc, Nil)
with Serializable {
val numSplitsInRdd2 = rdd2.splits.size
- @transient
- val splits_ = {
+ override def getSplits: Array[Split] = {
// create the cross product split
val array = new Array[Split](rdd1.splits.size * rdd2.splits.size)
for (s1 <- rdd1.splits; s2 <- rdd2.splits) {
val idx = s1.index * numSplitsInRdd2 + s2.index
- array(idx) = new CartesianSplit(idx, s1, s2)
+ array(idx) = new CartesianSplit(idx, rdd1, rdd2, s1.index, s2.index)
}
array
}
- override def splits = splits_
-
- override def preferredLocations(split: Split) = {
+ override def getPreferredLocations(split: Split) = {
val currSplit = split.asInstanceOf[CartesianSplit]
rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
}
@@ -42,7 +56,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}
- override val dependencies = List(
+ override def getDependencies: Seq[Dependency[_]] = List(
new NarrowDependency(rdd1) {
def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
},
@@ -50,4 +64,9 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2)
}
)
+
+ override def clearDependencies() {
+ rdd1 = null
+ rdd2 = null
+ }
}
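
The recurring pattern in this part of the diff (CartesianSplit here, and NarrowCoGroupSplitDep and CoalescedRDDSplit below) is to store the parent's split index plus a transient reference to the parent RDD, then re-resolve the actual Split inside writeObject so that task serialization always picks up the parent's current splits, for example after checkpointing replaces them. A stripped-down sketch of that trick with placeholder classes, not Spark types:

import java.io.{IOException, ObjectOutputStream}

// Placeholders: ParentSplit/Parent stand in for Split/RDD.
class ParentSplit(val index: Int) extends Serializable
class Parent(var splits: Array[ParentSplit])

class ChildSplit(@transient parent: Parent, val parentIndex: Int) extends Serializable {
  // Resolved eagerly for local use on the driver...
  var parentSplit: ParentSplit = parent.splits(parentIndex)

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream) {
    // ...and refreshed every time the split is serialized into a task, so a
    // parent whose splits array was swapped out is still honored.
    parentSplit = parent.splits(parentIndex)
    oos.defaultWriteObject()
  }
}
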
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
new file mode 100644
index 0000000000..96b593ba7c
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -0,0 +1,129 @@
+package spark.rdd
+
+import spark._
+import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.{NullWritable, BytesWritable}
+import org.apache.hadoop.util.ReflectionUtils
+import org.apache.hadoop.fs.Path
+import java.io.{File, IOException, EOFException}
+import java.text.NumberFormat
+
+private[spark] class CheckpointRDDSplit(val index: Int) extends Split {}
+
+/**
+ * This RDD represents an RDD checkpoint file (similar to HadoopRDD).
+ */
+private[spark]
+class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String)
+ extends RDD[T](sc, Nil) {
+
+ @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
+
+ @transient val splits_ : Array[Split] = {
+ val dirContents = fs.listStatus(new Path(checkpointPath))
+ val splitFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
+ val numSplits = splitFiles.size
+ if (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
+ !splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1))) {
+ throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
+ }
+ Array.tabulate(numSplits)(i => new CheckpointRDDSplit(i))
+ }
+
+ checkpointData = Some(new RDDCheckpointData[T](this))
+ checkpointData.get.cpFile = Some(checkpointPath)
+
+ override def getSplits = splits_
+
+ override def getPreferredLocations(split: Split): Seq[String] = {
+ val status = fs.getFileStatus(new Path(checkpointPath))
+ val locations = fs.getFileBlockLocations(status, 0, status.getLen)
+ locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
+ }
+
+ override def compute(split: Split, context: TaskContext): Iterator[T] = {
+ val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
+ CheckpointRDD.readFromFile(file, context)
+ }
+
+ override def checkpoint() {
+ // Do nothing. CheckpointRDD should not be checkpointed.
+ }
+}
+
+private[spark] object CheckpointRDD extends Logging {
+
+ def splitIdToFile(splitId: Int): String = {
+ "part-%05d".format(splitId)
+ }
+
+ def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
+ val outputDir = new Path(path)
+ val fs = outputDir.getFileSystem(new Configuration())
+
+ val finalOutputName = splitIdToFile(ctx.splitId)
+ val finalOutputPath = new Path(outputDir, finalOutputName)
+ val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
+
+ if (fs.exists(tempOutputPath)) {
+ throw new IOException("Checkpoint failed: temporary path " +
+ tempOutputPath + " already exists")
+ }
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
+ val fileOutputStream = if (blockSize < 0) {
+ fs.create(tempOutputPath, false, bufferSize)
+ } else {
+      // This is mainly for testing purposes
+ fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
+ }
+ val serializer = SparkEnv.get.serializer.newInstance()
+ val serializeStream = serializer.serializeStream(fileOutputStream)
+ serializeStream.writeAll(iterator)
+ serializeStream.close()
+
+ if (!fs.rename(tempOutputPath, finalOutputPath)) {
+ if (!fs.exists(finalOutputPath)) {
+ fs.delete(tempOutputPath, false)
+ throw new IOException("Checkpoint failed: failed to save output of task: "
+ + ctx.attemptId + " and final output path does not exist")
+ } else {
+ // Some other copy of this task must've finished before us and renamed it
+ logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
+ fs.delete(tempOutputPath, false)
+ }
+ }
+ }
+
+ def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
+ val fs = path.getFileSystem(new Configuration())
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val fileInputStream = fs.open(path, bufferSize)
+ val serializer = SparkEnv.get.serializer.newInstance()
+ val deserializeStream = serializer.deserializeStream(fileInputStream)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback(() => deserializeStream.close())
+
+ deserializeStream.asIterator.asInstanceOf[Iterator[T]]
+ }
+
+  // Test whether CheckpointRDD generates the expected number of splits despite
+ // each split file having multiple blocks. This needs to be run on a
+ // cluster (mesos or standalone) using HDFS.
+ def main(args: Array[String]) {
+ import spark._
+
+ val Array(cluster, hdfsPath) = args
+ val sc = new SparkContext(cluster, "CheckpointRDD Test")
+ val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
+ val path = new Path(hdfsPath, "temp")
+ val fs = path.getFileSystem(new Configuration())
+ sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
+ val cpRDD = new CheckpointRDD[Int](sc, path.toString)
+ assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same")
+ assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same")
+ fs.delete(path)
+ }
+}
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index ce5f171911..8fafd27bb6 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -1,5 +1,6 @@
package spark.rdd
+import java.io.{ObjectOutputStream, IOException}
import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
@@ -9,7 +10,21 @@ import spark.{Dependency, OneToOneDependency, ShuffleDependency}
private[spark] sealed trait CoGroupSplitDep extends Serializable
-private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
+
+private[spark] case class NarrowCoGroupSplitDep(
+ rdd: RDD[_],
+ splitIndex: Int,
+ var split: Split
+ ) extends CoGroupSplitDep {
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ split = rdd.splits(splitIndex)
+ oos.defaultWriteObject()
+ }
+}
+
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
private[spark]
@@ -25,30 +40,29 @@ private[spark] class CoGroupAggregator
{ (b1, b2) => b1 ++ b2 })
with Serializable
-class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
- extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
+class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
+ extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging {
val aggr = new CoGroupAggregator
- @transient
- override val dependencies = {
+ @transient var deps_ = {
val deps = new ArrayBuffer[Dependency[_]]
for ((rdd, index) <- rdds.zipWithIndex) {
- val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
- if (mapSideCombinedRDD.partitioner == Some(part)) {
- logInfo("Adding one-to-one dependency with " + mapSideCombinedRDD)
- deps += new OneToOneDependency(mapSideCombinedRDD)
+ if (rdd.partitioner == Some(part)) {
+ logInfo("Adding one-to-one dependency with " + rdd)
+ deps += new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
+ val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
}
}
deps.toList
}
- @transient
- val splits_ : Array[Split] = {
- val firstRdd = rdds.head
+ override def getDependencies = deps_
+
+ @transient var splits_ : Array[Split] = {
val array = new Array[Split](part.numPartitions)
for (i <- 0 until array.size) {
array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) =>
@@ -56,19 +70,17 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
case s: ShuffleDependency[_, _] =>
new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep
case _ =>
- new NarrowCoGroupSplitDep(r, r.splits(i)): CoGroupSplitDep
+ new NarrowCoGroupSplitDep(r, i, r.splits(i)): CoGroupSplitDep
}
}.toList)
}
array
}
- override def splits = splits_
-
+ override def getSplits = splits_
+
override val partitioner = Some(part)
- override def preferredLocations(s: Split) = Nil
-
override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit]
val numRdds = split.deps.size
@@ -84,7 +96,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
}
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
- case NarrowCoGroupSplitDep(rdd, itsSplit) => {
+ case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => {
// Read them from the parent
for ((k, v) <- rdd.iterator(itsSplit, context)) {
getSeq(k.asInstanceOf[K])(depNum) += v
@@ -103,4 +115,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
JavaConversions.mapAsScalaMap(map).iterator
}
+
+ override def clearDependencies() {
+ deps_ = null
+ splits_ = null
+ rdds = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index 1affe0e0ef..4c57434b65 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -1,9 +1,22 @@
package spark.rdd
-import spark.{NarrowDependency, RDD, Split, TaskContext}
-
-
-private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split
+import spark.{Dependency, OneToOneDependency, NarrowDependency, RDD, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+
+private[spark] case class CoalescedRDDSplit(
+ index: Int,
+ @transient rdd: RDD[_],
+ parentsIndices: Array[Int]
+ ) extends Split {
+ var parents: Seq[Split] = parentsIndices.map(rdd.splits(_))
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ parents = parentsIndices.map(rdd.splits(_))
+ oos.defaultWriteObject()
+ }
+}
/**
* Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of
@@ -13,34 +26,38 @@ private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) exten
* This transformation is useful when an RDD with many partitions gets filtered into a smaller one,
* or to avoid having a large number of small tasks when processing a directory with many files.
*/
-class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int)
- extends RDD[T](prev.context) {
+class CoalescedRDD[T: ClassManifest](
+ @transient var prev: RDD[T],
+ maxPartitions: Int)
+ extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
- @transient val splits_ : Array[Split] = {
+ override def getSplits: Array[Split] = {
val prevSplits = prev.splits
if (prevSplits.length < maxPartitions) {
- prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) }
+ prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) }
} else {
(0 until maxPartitions).map { i =>
val rangeStart = (i * prevSplits.length) / maxPartitions
val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions
- new CoalescedRDDSplit(i, prevSplits.slice(rangeStart, rangeEnd))
+ new CoalescedRDDSplit(i, prev, (rangeStart until rangeEnd).toArray)
}.toArray
}
}
- override def splits = splits_
-
override def compute(split: Split, context: TaskContext): Iterator[T] = {
- split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap {
- parentSplit => prev.iterator(parentSplit, context)
+ split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit =>
+ firstParent[T].iterator(parentSplit, context)
}
}
- val dependencies = List(
+ override def getDependencies: Seq[Dependency[_]] = List(
new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] =
- splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index)
+ splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices
}
)
+
+ override def clearDependencies() {
+ prev = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
index d46549b8b6..6dbe235bd9 100644
--- a/core/src/main/scala/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -2,11 +2,15 @@ package spark.rdd
import spark.{OneToOneDependency, RDD, Split, TaskContext}
+private[spark] class FilteredRDD[T: ClassManifest](
+ prev: RDD[T],
+ f: T => Boolean)
+ extends RDD[T](prev) {
+
+ override def getSplits = firstParent[T].splits
-private[spark]
-class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
override val partitioner = prev.partitioner // Since filter cannot change a partition's keys
- override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).filter(f)
+
+ override def compute(split: Split, context: TaskContext) =
+ firstParent[T].iterator(split, context).filter(f)
}
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
index 785662b2da..1b604c66e2 100644
--- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -1,16 +1,16 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
+
private[spark]
class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: T => TraversableOnce[U])
- extends RDD[U](prev.context) {
+ extends RDD[U](prev) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
+ override def getSplits = firstParent[T].splits
override def compute(split: Split, context: TaskContext) =
- prev.iterator(split, context).flatMap(f)
+ firstParent[T].iterator(split, context).flatMap(f)
}
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
index fac8ffb4cb..051bffed19 100644
--- a/core/src/main/scala/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -1,12 +1,12 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
+private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T])
+ extends RDD[Array[T]](prev) {
+
+ override def getSplits = firstParent[T].splits
-private[spark]
-class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split, context: TaskContext) =
- Array(prev.iterator(split, context).toArray).iterator
-} \ No newline at end of file
+ Array(firstParent[T].iterator(split, context).toArray).iterator
+}
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index ab163f569b..f547f53812 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -22,9 +22,8 @@ import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskCo
* A Spark split class that wraps around a Hadoop InputSplit.
*/
private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
- extends Split
- with Serializable {
-
+ extends Split {
+
val inputSplit = new SerializableWritable[InputSplit](s)
override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
@@ -43,7 +42,7 @@ class HadoopRDD[K, V](
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int)
- extends RDD[(K, V)](sc) {
+ extends RDD[(K, V)](sc, Nil) {
// A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it
val confBroadcast = sc.broadcast(new SerializableWritable(conf))
@@ -64,7 +63,7 @@ class HadoopRDD[K, V](
.asInstanceOf[InputFormat[K, V]]
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[HadoopSplit]
@@ -110,11 +109,13 @@ class HadoopRDD[K, V](
}
}
- override def preferredLocations(split: Split) = {
+ override def getPreferredLocations(split: Split) = {
// TODO: Filtering out "localhost" in case of file:// URLs
val hadoopSplit = split.asInstanceOf[HadoopSplit]
hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
}
- override val dependencies: List[Dependency[_]] = Nil
+ override def checkpoint() {
+ // Do nothing. Hadoop RDD should not be checkpointed.
+ }
}
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
index c764505345..073f7d7d2a 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -1,6 +1,6 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
private[spark]
@@ -8,11 +8,13 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: Iterator[T] => Iterator[U],
preservesPartitioning: Boolean = false)
- extends RDD[U](prev.context) {
+ extends RDD[U](prev) {
- override val partitioner = if (preservesPartitioning) prev.partitioner else None
+ override val partitioner =
+ if (preservesPartitioning) firstParent[T].partitioner else None
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split, context: TaskContext) = f(prev.iterator(split, context))
+ override def getSplits = firstParent[T].splits
+
+ override def compute(split: Split, context: TaskContext) =
+ f(firstParent[T].iterator(split, context))
} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
index 3d9888bd34..2ddc3d01b6 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -1,6 +1,7 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
+
/**
* A variant of the MapPartitionsRDD that passes the split index into the
@@ -11,12 +12,13 @@ private[spark]
class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean)
- extends RDD[U](prev.context) {
+ preservesPartitioning: Boolean
+ ) extends RDD[U](prev) {
+
+ override def getSplits = firstParent[T].splits
override val partitioner = if (preservesPartitioning) prev.partitioner else None
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
+
override def compute(split: Split, context: TaskContext) =
- f(split.index, prev.iterator(split, context))
+ f(split.index, firstParent[T].iterator(split, context))
} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index 70fa8f4497..5466c9c657 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -1,14 +1,13 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
private[spark]
-class MappedRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T],
- f: T => U)
- extends RDD[U](prev.context) {
-
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).map(f)
-} \ No newline at end of file
+class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U)
+ extends RDD[U](prev) {
+
+ override def getSplits = firstParent[T].splits
+
+ override def compute(split: Split, context: TaskContext) =
+ firstParent[T].iterator(split, context).map(f)
+}
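The Glommed/MapPartitions/Mapped rewrites above all follow one pattern: passing prev directly to the RDD constructor records the one-to-one dependency, so a subclass only overrides getSplits and compute through firstParent. A hypothetical one-parent RDD written the same way (illustrative, not part of this patch):

package spark.rdd

import spark.{RDD, Split, TaskContext}

// Hypothetical: a filtering RDD written against the RDD(prev) constructor
// and the firstParent helper introduced in this patch.
private[spark] class FilteredRDD[T: ClassManifest](prev: RDD[T], p: T => Boolean)
  extends RDD[T](prev) {  // one-to-one dependency on prev is implied by the constructor

  override def getSplits = firstParent[T].splits

  override def compute(split: Split, context: TaskContext) =
    firstParent[T].iterator(split, context).filter(p)
}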
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index 197ed5ea17..c3b155fcbd 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -20,11 +20,12 @@ class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit
}
class NewHadoopRDD[K, V](
- sc: SparkContext,
+ sc : SparkContext,
inputFormatClass: Class[_ <: InputFormat[K, V]],
- keyClass: Class[K], valueClass: Class[V],
+ keyClass: Class[K],
+ valueClass: Class[V],
@transient conf: Configuration)
- extends RDD[(K, V)](sc)
+ extends RDD[(K, V)](sc, Nil)
with HadoopMapReduceUtil {
// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
@@ -36,11 +37,9 @@ class NewHadoopRDD[K, V](
formatter.format(new Date())
}
- @transient
- private val jobId = new JobID(jobtrackerId, id)
+ @transient private val jobId = new JobID(jobtrackerId, id)
- @transient
- private val splits_ : Array[Split] = {
+ @transient private val splits_ : Array[Split] = {
val inputFormat = inputFormatClass.newInstance
val jobContext = newJobContext(conf, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray
@@ -51,7 +50,7 @@ class NewHadoopRDD[K, V](
result
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopSplit]
@@ -86,10 +85,8 @@ class NewHadoopRDD[K, V](
}
}
- override def preferredLocations(split: Split) = {
+ override def getPreferredLocations(split: Split) = {
val theSplit = split.asInstanceOf[NewHadoopSplit]
theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
}
-
- override val dependencies: List[Dependency[_]] = Nil
}
diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
new file mode 100644
index 0000000000..a50ce75171
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
@@ -0,0 +1,42 @@
+package spark.rdd
+
+import spark.{NarrowDependency, RDD, SparkEnv, Split, TaskContext}
+
+
+class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split {
+ override val index = idx
+}
+
+
+/**
+ * Represents a dependency between the PartitionPruningRDD and its parent. In this
+ * case, the child RDD contains a subset of its parent's partitions.
+ */
+class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
+ extends NarrowDependency[T](rdd) {
+
+ @transient
+ val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index))
+ .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDSplit(idx, split) : Split }
+
+ override def getParents(partitionId: Int) = List(partitions(partitionId).index)
+}
+
+
+/**
+ * An RDD used to prune RDD partitions/splits so we can avoid launching tasks on
+ * all partitions. An example use case: if we know the RDD is partitioned by range,
+ * and the execution DAG has a filter on the key, we can avoid launching tasks
+ * on partitions whose range cannot contain the key.
+ */
+class PartitionPruningRDD[T: ClassManifest](
+ @transient prev: RDD[T],
+ @transient partitionFilterFunc: Int => Boolean)
+ extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
+
+ override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(
+ split.asInstanceOf[PartitionPruningRDDSplit].parentSplit, context)
+
+ override protected def getSplits =
+ getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
+}
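A hedged usage sketch of the new class: the rangePartitionedRdd value and the partition index below are assumptions, but the two-argument constructor is the one added in this file.

// Fragment, not a complete program: assumes rangePartitionedRdd is an RDD[(K, V)]
// whose range partitioner places each key range in a known partition index.
val wantedPartition = 3  // assumed: the one partition whose range can contain the key
val pruned = new PartitionPruningRDD(rangePartitionedRdd, _ == wantedPartition)
// 'pruned' exposes a single split, so an action on it launches one task instead of
// one task per partition of rangePartitionedRdd.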
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index 336e193217..6631f83510 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -8,7 +8,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
-import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext}
+import spark.{RDD, SparkEnv, Split, TaskContext}
/**
@@ -16,18 +16,18 @@ import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext}
* (printing them one per line) and returns the output as a collection of strings.
*/
class PipedRDD[T: ClassManifest](
- parent: RDD[T], command: Seq[String], envVars: Map[String, String])
- extends RDD[String](parent.context) {
+ prev: RDD[T],
+ command: Seq[String],
+ envVars: Map[String, String])
+ extends RDD[String](prev) {
- def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map())
+ def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map())
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
- def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command))
+ def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command))
- override def splits = parent.splits
-
- override val dependencies = List(new OneToOneDependency(parent))
+ override def getSplits = firstParent[T].splits
override def compute(split: Split, context: TaskContext): Iterator[String] = {
val pb = new ProcessBuilder(command)
@@ -52,7 +52,7 @@ class PipedRDD[T: ClassManifest](
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
- for (elem <- parent.iterator(split, context)) {
+ for (elem <- firstParent[T].iterator(split, context)) {
out.println(elem)
}
out.close()
diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index 6e4797aabb..e24ad23b21 100644
--- a/core/src/main/scala/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -1,11 +1,11 @@
package spark.rdd
import java.util.Random
+
import cern.jet.random.Poisson
import cern.jet.random.engine.DRand
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
-
+import spark.{RDD, Split, TaskContext}
private[spark]
class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
@@ -14,23 +14,20 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali
class SampledRDD[T: ClassManifest](
prev: RDD[T],
- withReplacement: Boolean,
+ withReplacement: Boolean,
frac: Double,
seed: Int)
- extends RDD[T](prev.context) {
+ extends RDD[T](prev) {
- @transient
- val splits_ = {
+ @transient var splits_ : Array[Split] = {
val rg = new Random(seed)
- prev.splits.map(x => new SampledRDDSplit(x, rg.nextInt))
+ firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt))
}
- override def splits = splits_.asInstanceOf[Array[Split]]
-
- override val dependencies = List(new OneToOneDependency(prev))
+ override def getSplits = splits_
- override def preferredLocations(split: Split) =
- prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
+ override def getPreferredLocations(split: Split) =
+ firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
override def compute(splitIn: Split, context: TaskContext) = {
val split = splitIn.asInstanceOf[SampledRDDSplit]
@@ -38,7 +35,7 @@ class SampledRDD[T: ClassManifest](
// For large datasets, the expected number of occurrences of each element in a sample with
// replacement is Poisson(frac). We use that to get a count for each element.
val poisson = new Poisson(frac, new DRand(split.seed))
- prev.iterator(split.prev, context).flatMap { element =>
+ firstParent[T].iterator(split.prev, context).flatMap { element =>
val count = poisson.nextInt()
if (count == 0) {
Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
@@ -48,7 +45,11 @@ class SampledRDD[T: ClassManifest](
}
} else { // Sampling without replacement
val rand = new Random(split.seed)
- prev.iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
+ firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
}
}
+
+ override def clearDependencies() {
+ splits_ = null
+ }
}
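The with-replacement branch relies on the approximation that, for a large dataset, the number of occurrences of each element in a fraction-frac sample drawn with replacement is Poisson(frac). A self-contained sketch of that idea, using Knuth's sampler in place of the colt Poisson class used here:

import scala.util.Random

// Draw a Poisson(lambda) variate with Knuth's algorithm; adequate for small lambda
// such as a sampling fraction.
def poisson(lambda: Double, rand: Random): Int = {
  val limit = math.exp(-lambda)
  var k = 0
  var p = 1.0
  do {
    k += 1
    p *= rand.nextDouble()
  } while (p > limit)
  k - 1
}

// Sample with replacement by emitting each element 'count' times, where
// count ~ Poisson(frac); most counts are 0, so most elements are skipped cheaply.
def sampleWithReplacement[T](data: Iterator[T], frac: Double, seed: Int): Iterator[T] = {
  val rand = new Random(seed)
  data.flatMap { elem =>
    val count = poisson(frac, rand)
    if (count == 0) Iterator.empty else Iterator.fill(count)(elem)
  }
}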
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index f832633646..d396478673 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -1,7 +1,7 @@
package spark.rdd
-import spark.{OneToOneDependency, Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext}
-
+import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext}
+import spark.SparkContext._
private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
override val index = idx
@@ -10,28 +10,22 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
/**
* The resulting RDD from a shuffle (e.g. repartitioning of data).
- * @param parent the parent RDD.
+ * @param prev the parent RDD.
* @param part the partitioner used to partition the RDD
* @tparam K the key class.
* @tparam V the value class.
*/
class ShuffledRDD[K, V](
- @transient parent: RDD[(K, V)],
- part: Partitioner) extends RDD[(K, V)](parent.context) {
+ prev: RDD[(K, V)],
+ part: Partitioner)
+ extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) {
override val partitioner = Some(part)
- @transient
- val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
-
- override def splits = splits_
-
- override def preferredLocations(split: Split) = Nil
-
- val dep = new ShuffleDependency(parent, part)
- override val dependencies = List(dep)
+ override def getSplits = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
- SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index)
+ val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
+ SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
}
}
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index a08473f7be..26a2d511f2 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -1,55 +1,60 @@
package spark.rdd
import scala.collection.mutable.ArrayBuffer
-
import spark.{Dependency, RangeDependency, RDD, SparkContext, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int)
+ extends Split {
-private[spark] class UnionSplit[T: ClassManifest](
- idx: Int,
- rdd: RDD[T],
- split: Split)
- extends Split
- with Serializable {
+ var split: Split = rdd.splits(splitIndex)
def iterator(context: TaskContext) = rdd.iterator(split, context)
+
def preferredLocations() = rdd.preferredLocations(split)
+
override val index: Int = idx
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ split = rdd.splits(splitIndex)
+ oos.defaultWriteObject()
+ }
}
class UnionRDD[T: ClassManifest](
sc: SparkContext,
- @transient rdds: Seq[RDD[T]])
- extends RDD[T](sc)
- with Serializable {
+ @transient var rdds: Seq[RDD[T]])
+ extends RDD[T](sc, Nil) { // Nil since we implement getDependencies
- @transient
- val splits_ : Array[Split] = {
+ override def getSplits: Array[Split] = {
val array = new Array[Split](rdds.map(_.splits.size).sum)
var pos = 0
for (rdd <- rdds; split <- rdd.splits) {
- array(pos) = new UnionSplit(pos, rdd, split)
+ array(pos) = new UnionSplit(pos, rdd, split.index)
pos += 1
}
array
}
- override def splits = splits_
-
- @transient
- override val dependencies = {
+ override def getDependencies: Seq[Dependency[_]] = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for (rdd <- rdds) {
deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
pos += rdd.splits.size
}
- deps.toList
+ deps
}
override def compute(s: Split, context: TaskContext): Iterator[T] =
s.asInstanceOf[UnionSplit[T]].iterator(context)
- override def preferredLocations(s: Split): Seq[String] =
+ override def getPreferredLocations(s: Split): Seq[String] =
s.asInstanceOf[UnionSplit[T]].preferredLocations()
+
+ override def clearDependencies() {
+ rdds = null
+ }
}
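The custom writeObject above re-resolves the parent split at serialization time, so a task serialized after the parent RDD has been checkpointed captures the checkpointed split rather than a stale reference. A self-contained analog of the refresh-on-serialize pattern (plain Serializable classes, not the Spark ones):

import java.io.{IOException, ObjectOutputStream}

class Parent(@volatile var parts: Array[String]) extends Serializable

class ChildSlot(val parent: Parent, val idx: Int) extends Serializable {
  // Cached reference; it can go stale if parent.parts is swapped out,
  // the way checkpointing swaps an RDD's splits.
  var part: String = parent.parts(idx)

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream) {
    part = parent.parts(idx)  // refresh the reference at serialization time
    oos.defaultWriteObject()
  }
}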
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index 92d667ff1e..e5df6d8c72 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -1,53 +1,60 @@
package spark.rdd
import spark.{OneToOneDependency, RDD, SparkContext, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
idx: Int,
- rdd1: RDD[T],
- rdd2: RDD[U],
- split1: Split,
- split2: Split)
- extends Split
- with Serializable {
+ @transient rdd1: RDD[T],
+ @transient rdd2: RDD[U]
+ ) extends Split {
- def iterator(context: TaskContext): Iterator[(T, U)] =
- rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
+ var split1 = rdd1.splits(idx)
+ var split2 = rdd2.splits(idx)
+ override val index: Int = idx
- def preferredLocations(): Seq[String] =
- rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+ def splits = (split1, split2)
- override val index: Int = idx
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ split1 = rdd1.splits(idx)
+ split2 = rdd2.splits(idx)
+ oos.defaultWriteObject()
+ }
}
class ZippedRDD[T: ClassManifest, U: ClassManifest](
sc: SparkContext,
- @transient rdd1: RDD[T],
- @transient rdd2: RDD[U])
- extends RDD[(T, U)](sc)
+ var rdd1: RDD[T],
+ var rdd2: RDD[U])
+ extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)))
with Serializable {
- @transient
- val splits_ : Array[Split] = {
+ override def getSplits: Array[Split] = {
if (rdd1.splits.size != rdd2.splits.size) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
}
val array = new Array[Split](rdd1.splits.size)
for (i <- 0 until rdd1.splits.size) {
- array(i) = new ZippedSplit(i, rdd1, rdd2, rdd1.splits(i), rdd2.splits(i))
+ array(i) = new ZippedSplit(i, rdd1, rdd2)
}
array
}
- override def splits = splits_
-
- @transient
- override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
+ override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = {
+ val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
+ rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
+ }
- override def compute(s: Split, context: TaskContext): Iterator[(T, U)] =
- s.asInstanceOf[ZippedSplit[T, U]].iterator(context)
+ override def getPreferredLocations(s: Split): Seq[String] = {
+ val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
+ rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+ }
- override def preferredLocations(s: Split): Seq[String] =
- s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()
+ override def clearDependencies() {
+ rdd1 = null
+ rdd2 = null
+ }
}
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 29757b1178..908a22b2df 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -14,6 +14,7 @@ import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
import spark.storage.BlockManagerMaster
import spark.storage.BlockManagerId
+import util.{MetadataCleaner, TimeStampedHashMap}
/**
* A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
@@ -34,12 +35,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
eventQueue.put(CompletionEvent(task, reason, result, accumUpdates))
}
- // Called by TaskScheduler when a host fails.
- override def hostLost(host: String) {
- eventQueue.put(HostLost(host))
+ // Called by TaskScheduler when an executor fails.
+ override def executorLost(execId: String) {
+ eventQueue.put(ExecutorLost(execId))
}
- // Called by TaskScheduler to cancel an entier TaskSet due to repeated failures.
+ // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
override def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
}
@@ -53,36 +54,41 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// resubmit failed stages
val POLL_TIMEOUT = 10L
- private val lock = new Object // Used for access to the entire DAGScheduler
-
private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
val nextRunId = new AtomicInteger(0)
val nextStageId = new AtomicInteger(0)
- val idToStage = new HashMap[Int, Stage]
+ val idToStage = new TimeStampedHashMap[Int, Stage]
- val shuffleToMapStage = new HashMap[Int, Stage]
+ val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
var cacheLocs = new HashMap[Int, Array[List[String]]]
val env = SparkEnv.get
- val cacheTracker = env.cacheTracker
val mapOutputTracker = env.mapOutputTracker
+ val blockManagerMaster = env.blockManager.master
- val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back;
- // that's not going to be a realistic assumption in general
+ // For tracking failed nodes, we use the MapOutputTracker's generation number, which is
+ // sent with every task. When we detect a node failing, we note the current generation number
+ // and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask
+ // results.
+ // TODO: Garbage collect information about failure generations when we know there are no more
+ // stray messages to detect.
+ val failedGeneration = new HashMap[String, Long]
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now
val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
- val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
+ val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits
val activeJobs = new HashSet[ActiveJob]
val resultStageToJob = new HashMap[Stage, ActiveJob]
+ val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
+
// Start a thread to run the DAGScheduler event loop
new Thread("DAGScheduler") {
setDaemon(true)
@@ -91,12 +97,18 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}.start()
- def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
+ private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
+ if (!cacheLocs.contains(rdd.id)) {
+ val blockIds = rdd.splits.indices.map(index => "rdd_%d_%d".format(rdd.id, index)).toArray
+ cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
+ locations => locations.map(_.ip).toList
+ }.toArray
+ }
cacheLocs(rdd.id)
}
- def updateCacheLocs() {
- cacheLocs = cacheTracker.getLocationsSnapshot()
+ private def clearCacheLocs() {
+ cacheLocs.clear()
}
/**
@@ -104,7 +116,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* The priority value passed in will be used if the stage doesn't already exist with
* a lower priority (we assume that priorities always increase across jobs for now).
*/
- def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
+ private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
@@ -119,12 +131,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* as a result stage for the final RDD used directly in an action. The stage will also be given
* the provided priority.
*/
- def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
- // Kind of ugly: need to register RDDs with the cache and map output tracker here
- // since we can't do it in the RDD constructor because # of splits is unknown
- logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
- cacheTracker.registerRDD(rdd.id, rdd.splits.size)
+ private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
if (shuffleDep != None) {
+ // Kind of ugly: need to register shuffle dependencies with the map output tracker
+ // here, since we can't do it in the RDD constructor because # of splits is unknown
+ logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
}
val id = nextStageId.getAndIncrement()
@@ -137,7 +148,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Get or create the list of parent stages for a given RDD. The stages will be assigned the
* provided priority if they haven't already been created with a lower priority.
*/
- def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
+ private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
val parents = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(r: RDD[_]) {
@@ -145,8 +156,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visited += r
// Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of splits is unknown
- logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")")
- cacheTracker.registerRDD(r.id, r.splits.size)
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
@@ -161,7 +170,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
parents.toList
}
- def getMissingParentStages(stage: Stage): List[Stage] = {
+ private def getMissingParentStages(stage: Stage): List[Stage] = {
val missing = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(rdd: RDD[_]) {
@@ -194,18 +203,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
- allowLocal: Boolean)
- : Array[U] =
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit)
{
if (partitions.size == 0) {
- return new Array[U](0)
+ return
}
- val waiter = new JobWaiter(partitions.size)
+ val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter))
- waiter.getResult() match {
- case JobSucceeded(results: Seq[_]) =>
- return results.asInstanceOf[Seq[U]].toArray
+ waiter.awaitResult() match {
+ case JobSucceeded => {}
case JobFailed(exception: Exception) =>
logInfo("Failed to run " + callSite)
throw exception
@@ -224,7 +232,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.splits.size).toArray
eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
- return listener.getResult() // Will throw an exception if the job fails
+ return listener.awaitResult() // Will throw an exception if the job fails
}
/**
@@ -232,7 +240,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* events and responds by launching tasks. This runs in a dedicated thread and receives events
* via the eventQueue.
*/
- def run() {
+ private def run() {
SparkEnv.set(env)
while (true) {
@@ -247,7 +255,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId)
val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
- updateCacheLocs()
+ clearCacheLocs()
logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
" output partitions")
logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
@@ -262,8 +270,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
submitStage(finalStage)
}
- case HostLost(host) =>
- handleHostLost(host)
+ case ExecutorLost(execId) =>
+ handleExecutorLost(execId)
case completion: CompletionEvent =>
handleTaskCompletion(completion)
@@ -290,7 +298,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// on the failed node.
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
logInfo("Resubmitting failed stages")
- updateCacheLocs()
+ clearCacheLocs()
val failed2 = failed.toArray
failed.clear()
for (stage <- failed2.sortBy(_.priority)) {
@@ -299,10 +307,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
} else {
// TODO: We might want to run this less often, when we are sure that something has become
// runnable that wasn't before.
- logDebug("Checking for newly runnable parent stages")
- logDebug("running: " + running)
- logDebug("waiting: " + waiting)
- logDebug("failed: " + failed)
+ logTrace("Checking for newly runnable parent stages")
+ logTrace("running: " + running)
+ logTrace("waiting: " + waiting)
+ logTrace("failed: " + failed)
val waiting2 = waiting.toArray
waiting.clear()
for (stage <- waiting2.sortBy(_.priority)) {
@@ -317,7 +325,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* We run the operation in a separate thread just in case it takes a bunch of time, so that we
* don't block the DAGScheduler event loop or other concurrent jobs.
*/
- def runLocally(job: ActiveJob) {
+ private def runLocally(job: ActiveJob) {
logInfo("Computing the requested partition locally")
new Thread("Local computation of job " + job.runId) {
override def run() {
@@ -326,9 +334,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val rdd = job.finalStage.rdd
val split = rdd.splits(job.partitions(0))
val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
- val result = job.func(taskContext, rdd.iterator(split, taskContext))
- taskContext.executeOnCompleteCallbacks()
- job.listener.taskSucceeded(0, result)
+ try {
+ val result = job.func(taskContext, rdd.iterator(split, taskContext))
+ job.listener.taskSucceeded(0, result)
+ } finally {
+ taskContext.executeOnCompleteCallbacks()
+ }
} catch {
case e: Exception =>
job.listener.jobFailed(e)
@@ -337,13 +348,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}.start()
}
- def submitStage(stage: Stage) {
+ /** Submits stage, but first recursively submits any missing parents. */
+ private def submitStage(stage: Stage) {
logDebug("submitStage(" + stage + ")")
if (!waiting(stage) && !running(stage) && !failed(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
if (missing == Nil) {
- logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents")
+ logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
submitMissingTasks(stage)
running += stage
} else {
@@ -355,7 +367,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
- def submitMissingTasks(stage: Stage) {
+ /** Called when a stage's parents are available and we can now run its tasks. */
+ private def submitMissingTasks(stage: Stage) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
@@ -376,11 +389,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
if (tasks.size > 0) {
- logInfo("Submitting " + tasks.size + " missing tasks from " + stage)
+ logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority))
+ if (!stage.submissionTime.isDefined) {
+ stage.submissionTime = Some(System.currentTimeMillis())
+ }
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -392,9 +408,18 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
*/
- def handleTaskCompletion(event: CompletionEvent) {
+ private def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
val stage = idToStage(task.stageId)
+
+ def markStageAsFinished(stage: Stage) = {
+ val serviceTime = stage.submissionTime match {
+ case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
+ case _ => "Unkown"
+ }
+ logInfo("%s (%s) finished in %s s".format(stage, stage.origin, serviceTime))
+ running -= stage
+ }
event.reason match {
case Success =>
logInfo("Completed " + task)
@@ -409,13 +434,13 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (!job.finished(rt.outputId)) {
job.finished(rt.outputId) = true
job.numFinished += 1
- job.listener.taskSucceeded(rt.outputId, event.result)
// If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) {
activeJobs -= job
resultStageToJob -= stage
- running -= stage
+ markStageAsFinished(stage)
}
+ job.listener.taskSucceeded(rt.outputId, event.result)
}
case None =>
logInfo("Ignoring result from " + rt + " because its job has finished")
@@ -424,23 +449,32 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
case smt: ShuffleMapTask =>
val stage = idToStage(smt.stageId)
val status = event.result.asInstanceOf[MapStatus]
- val host = status.address.ip
- logInfo("ShuffleMapTask finished with host " + host)
- if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos
+ val execId = status.location.executorId
+ logDebug("ShuffleMapTask finished on " + execId)
+ if (failedGeneration.contains(execId) && smt.generation <= failedGeneration(execId)) {
+ logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
+ } else {
stage.addOutputLoc(smt.partition, status)
}
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
- logInfo(stage + " (" + stage.origin + ") finished; looking for newly runnable stages")
- running -= stage
+ markStageAsFinished(stage)
+ logInfo("looking for newly runnable stages")
logInfo("running: " + running)
logInfo("waiting: " + waiting)
logInfo("failed: " + failed)
if (stage.shuffleDep != None) {
+ // We supply true to increment the generation number here in case this is a
+ // recomputation of the map outputs. In that case, some nodes may have cached
+ // locations with holes (from when we detected the error) and will need the
+ // generation incremented to refetch them.
+ // TODO: Only increment the generation number if this is not the first time
+ // we registered these map outputs.
mapOutputTracker.registerMapOutputs(
stage.shuffleDep.get.shuffleId,
- stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray)
+ stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
+ true)
}
- updateCacheLocs()
+ clearCacheLocs()
if (stage.outputLocs.count(_ == Nil) != 0) {
// Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this
@@ -459,7 +493,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
waiting --= newlyRunnable
running ++= newlyRunnable
for (stage <- newlyRunnable.sortBy(_.id)) {
- logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable")
+ logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
submitMissingTasks(stage)
}
}
@@ -490,9 +524,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// Remember that a fetch failed now; this is used to resubmit the broken
// stages later, after a small wait (to give other tasks the chance to fail)
lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
- // TODO: mark the host as failed only if there were lots of fetch failures on it
+ // TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
- handleHostLost(bmAddress.ip)
+ handleExecutorLost(bmAddress.executorId, Some(task.generation))
}
case other =>
@@ -502,22 +536,31 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
/**
- * Responds to a host being lost. This is called inside the event loop so it assumes that it can
- * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside.
+ * Responds to an executor being lost. This is called inside the event loop, so it assumes it can
+ * modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
+ *
+ * Optionally the generation during which the failure was caught can be passed to avoid allowing
+ * stray fetch failures from possibly retriggering the detection of a node as lost.
*/
- def handleHostLost(host: String) {
- if (!deadHosts.contains(host)) {
- logInfo("Host lost: " + host)
- deadHosts += host
- env.blockManager.master.notifyADeadHost(host)
+ private def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) {
+ val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration)
+ if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) {
+ failedGeneration(execId) = currentGeneration
+ logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration))
+ env.blockManager.master.removeExecutor(execId)
// TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) {
- stage.removeOutputsOnHost(host)
+ stage.removeOutputsOnExecutor(execId)
val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
mapOutputTracker.registerMapOutputs(shuffleId, locs, true)
}
- cacheTracker.cacheLost(host)
- updateCacheLocs()
+ if (shuffleToMapStage.isEmpty) {
+ mapOutputTracker.incrementGeneration()
+ }
+ clearCacheLocs()
+ } else {
+ logDebug("Additional executor lost message for " + execId +
+ "(generation " + currentGeneration + ")")
}
}
@@ -525,7 +568,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
- def abortStage(failedStage: Stage, reason: String) {
+ private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
@@ -541,7 +584,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
/**
* Return true if one of stage's ancestors is target.
*/
- def stageDependsOn(stage: Stage, target: Stage): Boolean = {
+ private def stageDependsOn(stage: Stage, target: Stage): Boolean = {
if (stage == target) {
return true
}
@@ -568,7 +611,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visitedRdds.contains(target.rdd)
}
- def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
+ private def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
// If the partition is cached, return the cache locations
val cached = getCacheLocs(rdd)(partition)
if (cached != Nil) {
@@ -594,8 +637,23 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
return Nil
}
+ private def cleanup(cleanupTime: Long) {
+ var sizeBefore = idToStage.size
+ idToStage.clearOldValues(cleanupTime)
+ logInfo("idToStage " + sizeBefore + " --> " + idToStage.size)
+
+ sizeBefore = shuffleToMapStage.size
+ shuffleToMapStage.clearOldValues(cleanupTime)
+ logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size)
+
+ sizeBefore = pendingTasks.size
+ pendingTasks.clearOldValues(cleanupTime)
+ logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size)
+ }
+
def stop() {
eventQueue.put(StopDAGScheduler)
+ metadataCleaner.cancel()
taskSched.stop()
}
}
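The failedGeneration map plus the check in handleTaskCompletion form a small fencing scheme: each task carries the map-output generation it was launched with, and a ShuffleMapTask result is ignored if its executor was marked failed at that generation or later. A minimal self-contained sketch of the acceptance rule (names are illustrative, not the DAGScheduler's):

import scala.collection.mutable.HashMap

object GenerationFence {
  val failedGeneration = new HashMap[String, Long]

  // Record a failure observed at 'generation', keeping the highest one per executor.
  def markFailed(execId: String, generation: Long) {
    if (!failedGeneration.contains(execId) || failedGeneration(execId) < generation) {
      failedGeneration(execId) = generation
    }
  }

  // A map-output result is accepted only if the executor was not marked failed
  // at a generation >= the one the task was launched with.
  def accept(execId: String, taskGeneration: Long): Boolean =
    !failedGeneration.contains(execId) || taskGeneration > failedGeneration(execId)
}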
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
index 3422a21d9d..b34fa78c07 100644
--- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -28,7 +28,7 @@ private[spark] case class CompletionEvent(
accumUpdates: Map[Long, Any])
extends DAGSchedulerEvent
-private[spark] case class HostLost(host: String) extends DAGSchedulerEvent
+private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala
index c4a74e526f..654131ee84 100644
--- a/core/src/main/scala/spark/scheduler/JobResult.scala
+++ b/core/src/main/scala/spark/scheduler/JobResult.scala
@@ -5,5 +5,5 @@ package spark.scheduler
*/
private[spark] sealed trait JobResult
-private[spark] case class JobSucceeded(results: Seq[_]) extends JobResult
+private[spark] case object JobSucceeded extends JobResult
private[spark] case class JobFailed(exception: Exception) extends JobResult
diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala
index b3d4feebe5..3cc6a86345 100644
--- a/core/src/main/scala/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/spark/scheduler/JobWaiter.scala
@@ -3,10 +3,12 @@ package spark.scheduler
import scala.collection.mutable.ArrayBuffer
/**
- * An object that waits for a DAGScheduler job to complete.
+ * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
+ * results to the given handler function.
*/
-private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
- private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null)
+private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
+ extends JobListener {
+
private var finishedTasks = 0
private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
@@ -17,11 +19,11 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
if (jobFinished) {
throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
}
- taskResults(index) = result
+ resultHandler(index, result.asInstanceOf[T])
finishedTasks += 1
if (finishedTasks == totalTasks) {
jobFinished = true
- jobResult = JobSucceeded(taskResults)
+ jobResult = JobSucceeded
this.notifyAll()
}
}
@@ -38,7 +40,7 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
}
}
- def getResult(): JobResult = synchronized {
+ def awaitResult(): JobResult = synchronized {
while (!jobFinished) {
this.wait()
}
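Since JobSucceeded no longer carries results, a caller supplies a (partitionIndex, result) handler and stores each partition's result itself. A self-contained analog of how a collect()-style caller rebuilds an array of results on top of a handler-based API:

// Stand-in for the scheduler: invokes the handler once per "task" as it finishes.
def runFakeJob[T](perPartition: Array[T], resultHandler: (Int, T) => Unit) {
  for (i <- perPartition.indices) resultHandler(i, perPartition(i))
}

val results = new Array[Int](4)
runFakeJob(Array(10, 20, 30, 40), (index: Int, res: Int) => results(index) = res)
// results now holds Array(10, 20, 30, 40), without the scheduler ever building it.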
diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala
index 4532d9497f..203abb917b 100644
--- a/core/src/main/scala/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/spark/scheduler/MapStatus.scala
@@ -8,19 +8,19 @@ import java.io.{ObjectOutput, ObjectInput, Externalizable}
* task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
* The map output sizes are compressed using MapOutputTracker.compressSize.
*/
-private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: Array[Byte])
+private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte])
extends Externalizable {
def this() = this(null, null) // For deserialization only
def writeExternal(out: ObjectOutput) {
- address.writeExternal(out)
+ location.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
}
def readExternal(in: ObjectInput) {
- address = new BlockManagerId(in)
+ location = BlockManagerId(in)
compressedSizes = new Array[Byte](in.readInt())
in.readFully(compressedSizes)
}
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index e492279b4e..8cd4c661eb 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -1,26 +1,112 @@
package spark.scheduler
import spark._
+import java.io._
+import util.{MetadataCleaner, TimeStampedHashMap}
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+private[spark] object ResultTask {
+
+ // A simple map between the stage id to the serialized byte array of a task.
+ // Served as a cache for task serialization because serialization can be
+ // expensive on the master node if it needs to launch thousands of tasks.
+ val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
+
+ val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues)
+
+ def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
+ synchronized {
+ val old = serializedInfoCache.get(stageId).orNull
+ if (old != null) {
+ return old
+ } else {
+ val out = new ByteArrayOutputStream
+ val ser = SparkEnv.get.closureSerializer.newInstance
+ val objOut = ser.serializeStream(new GZIPOutputStream(out))
+ objOut.writeObject(rdd)
+ objOut.writeObject(func)
+ objOut.close()
+ val bytes = out.toByteArray
+ serializedInfoCache.put(stageId, bytes)
+ return bytes
+ }
+ }
+ }
+
+ def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
+ synchronized {
+ val loader = Thread.currentThread.getContextClassLoader
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val ser = SparkEnv.get.closureSerializer.newInstance
+ val objIn = ser.deserializeStream(in)
+ val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+ val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
+ return (rdd, func)
+ }
+ }
+
+ def clearCache() {
+ synchronized {
+ serializedInfoCache.clear()
+ }
+ }
+}
+
private[spark] class ResultTask[T, U](
stageId: Int,
- rdd: RDD[T],
- func: (TaskContext, Iterator[T]) => U,
- val partition: Int,
+ var rdd: RDD[T],
+ var func: (TaskContext, Iterator[T]) => U,
+ var partition: Int,
@transient locs: Seq[String],
val outputId: Int)
- extends Task[U](stageId) {
+ extends Task[U](stageId) with Externalizable {
- val split = rdd.splits(partition)
+ def this() = this(0, null, null, 0, null, 0)
+
+ var split = if (rdd == null) {
+ null
+ } else {
+ rdd.splits(partition)
+ }
override def run(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId)
- val result = func(context, rdd.iterator(split, context))
- context.executeOnCompleteCallbacks()
- result
+ try {
+ func(context, rdd.iterator(split, context))
+ } finally {
+ context.executeOnCompleteCallbacks()
+ }
}
override def preferredLocations: Seq[String] = locs
override def toString = "ResultTask(" + stageId + ", " + partition + ")"
+
+ override def writeExternal(out: ObjectOutput) {
+ RDDCheckpointData.synchronized {
+ split = rdd.splits(partition)
+ out.writeInt(stageId)
+ val bytes = ResultTask.serializeInfo(
+ stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ out.writeInt(partition)
+ out.writeInt(outputId)
+ out.writeObject(split)
+ }
+ }
+
+ override def readExternal(in: ObjectInput) {
+ val stageId = in.readInt()
+ val numBytes = in.readInt()
+ val bytes = new Array[Byte](numBytes)
+ in.readFully(bytes)
+ val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
+ rdd = rdd_.asInstanceOf[RDD[T]]
+ func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
+ partition = in.readInt()
+ val outputId = in.readInt()
+ split = in.readObject().asInstanceOf[Split]
+ }
}
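The serializedInfoCache trades memory for CPU on the master: the (rdd, func) pair is serialized and gzip-compressed once per stage and reused for every task in that stage. A self-contained sketch of the same serialize-once idea with plain Java serialization (not Spark's closure serializer):

import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import java.util.zip.GZIPOutputStream
import scala.collection.mutable.HashMap

object StageInfoCache {
  private val cache = new HashMap[Int, Array[Byte]]

  // Serialize and compress the per-stage payload once; later tasks reuse the bytes.
  def serializedInfo(stageId: Int, payload: Serializable): Array[Byte] = synchronized {
    cache.getOrElseUpdate(stageId, {
      val out = new ByteArrayOutputStream
      val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
      objOut.writeObject(payload)
      objOut.close()
      out.toByteArray
    })
  }
}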
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index bd1911fce2..bed9f1864f 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -14,22 +14,25 @@ import com.ning.compress.lzf.LZFOutputStream
import spark._
import spark.storage._
+import util.{TimeStampedHashMap, MetadataCleaner}
private[spark] object ShuffleMapTask {
// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
- val serializedInfoCache = new JHashMap[Int, Array[Byte]]
+ val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
+
+ val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues)
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
synchronized {
- val old = serializedInfoCache.get(stageId)
+ val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
return old
} else {
val out = new ByteArrayOutputStream
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objOut = ser.serializeStream(new GZIPOutputStream(out))
objOut.writeObject(rdd)
objOut.writeObject(dep)
@@ -45,7 +48,7 @@ private[spark] object ShuffleMapTask {
synchronized {
val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
@@ -78,7 +81,7 @@ private[spark] class ShuffleMapTask(
with Externalizable
with Logging {
- def this() = this(0, null, null, 0, null)
+ protected def this() = this(0, null, null, 0, null)
var split = if (rdd == null) {
null
@@ -87,13 +90,16 @@ private[spark] class ShuffleMapTask(
}
override def writeExternal(out: ObjectOutput) {
- out.writeInt(stageId)
- val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
- out.writeInt(bytes.length)
- out.write(bytes)
- out.writeInt(partition)
- out.writeLong(generation)
- out.writeObject(split)
+ RDDCheckpointData.synchronized {
+ split = rdd.splits(partition)
+ out.writeInt(stageId)
+ val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ out.writeInt(partition)
+ out.writeLong(generation)
+ out.writeObject(split)
+ }
}
override def readExternal(in: ObjectInput) {
@@ -111,34 +117,33 @@ private[spark] class ShuffleMapTask(
override def run(attemptId: Long): MapStatus = {
val numOutputSplits = dep.partitioner.numPartitions
- val partitioner = dep.partitioner
val taskContext = new TaskContext(stageId, partition, attemptId)
+ try {
+ // Partition the map output.
+ val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
+ for (elem <- rdd.iterator(split, taskContext)) {
+ val pair = elem.asInstanceOf[(Any, Any)]
+ val bucketId = dep.partitioner.getPartition(pair._1)
+ buckets(bucketId) += pair
+ }
- // Partition the map output.
- val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
- for (elem <- rdd.iterator(split, taskContext)) {
- val pair = elem.asInstanceOf[(Any, Any)]
- val bucketId = partitioner.getPartition(pair._1)
- buckets(bucketId) += pair
- }
- val bucketIterators = buckets.map(_.iterator)
+ val compressedSizes = new Array[Byte](numOutputSplits)
- val compressedSizes = new Array[Byte](numOutputSplits)
+ val blockManager = SparkEnv.get.blockManager
+ for (i <- 0 until numOutputSplits) {
+ val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
+ // Get a Scala iterator from Java map
+ val iter: Iterator[(Any, Any)] = buckets(i).iterator
+ val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+ compressedSizes(i) = MapOutputTracker.compressSize(size)
+ }
- val blockManager = SparkEnv.get.blockManager
- for (i <- 0 until numOutputSplits) {
- val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
- // Get a Scala iterator from Java map
- val iter: Iterator[(Any, Any)] = bucketIterators(i)
- val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
- compressedSizes(i) = MapOutputTracker.compressSize(size)
+ return new MapStatus(blockManager.blockManagerId, compressedSizes)
+ } finally {
+ // Execute the callbacks on task completion.
+ taskContext.executeOnCompleteCallbacks()
}
-
- // Execute the callbacks on task completion.
- taskContext.executeOnCompleteCallbacks()
-
- return new MapStatus(blockManager.blockManagerId, compressedSizes)
}
override def preferredLocations: Seq[String] = locs
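The restructured run() does two things per map task: it buckets every (key, value) pair by the shuffle's partitioner, then writes each bucket as a block named shuffle_<shuffleId>_<mapPartition>_<reduceId>. A self-contained sketch of that bucketing and naming step, with a plain hash partitioner standing in for dep.partitioner:

import scala.collection.mutable.ArrayBuffer

// Simple stand-in for the shuffle partitioner: non-negative key hash mod numPartitions.
def getPartition(key: Any, numPartitions: Int): Int = {
  val mod = key.hashCode % numPartitions
  if (mod < 0) mod + numPartitions else mod
}

def bucketize(shuffleId: Int, mapPartition: Int, numOutputSplits: Int,
              elems: Iterator[(Any, Any)]): Seq[(String, ArrayBuffer[(Any, Any)])] = {
  val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
  for (pair <- elems) {
    buckets(getPartition(pair._1, numOutputSplits)) += pair
  }
  // One block per reduce partition, named the way the shuffle fetcher looks them up.
  (0 until numOutputSplits).map { i =>
    ("shuffle_" + shuffleId + "_" + mapPartition + "_" + i, buckets(i))
  }
}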
diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala
index 4846b66729..374114d870 100644
--- a/core/src/main/scala/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/spark/scheduler/Stage.scala
@@ -32,6 +32,9 @@ private[spark] class Stage(
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0
+ /** When first task was submitted to scheduler. */
+ var submissionTime: Option[Long] = None
+
private var nextAttemptId = 0
def isAvailable: Boolean = {
@@ -51,18 +54,18 @@ private[spark] class Stage(
def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) {
val prevList = outputLocs(partition)
- val newList = prevList.filterNot(_.address == bmAddress)
+ val newList = prevList.filterNot(_.location == bmAddress)
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
numAvailableOutputs -= 1
}
}
- def removeOutputsOnHost(host: String) {
+ def removeOutputsOnExecutor(execId: String) {
var becameUnavailable = false
for (partition <- 0 until numPartitions) {
val prevList = outputLocs(partition)
- val newList = prevList.filterNot(_.address.ip == host)
+ val newList = prevList.filterNot(_.location.executorId == execId)
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
becameUnavailable = true
@@ -70,7 +73,8 @@ private[spark] class Stage(
}
}
if (becameUnavailable) {
- logInfo("%s is now unavailable on %s (%d/%d, %s)".format(this, host, numAvailableOutputs, numPartitions, isAvailable))
+ logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format(
+ this, execId, numAvailableOutputs, numPartitions, isAvailable))
}
}
@@ -82,7 +86,7 @@ private[spark] class Stage(
def origin: String = rdd.origin
- override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]"
+ override def toString = "Stage " + id
override def hashCode(): Int = id
}
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
index fa4de15d0d..9fcef86e46 100644
--- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -12,7 +12,7 @@ private[spark] trait TaskSchedulerListener {
def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit
// A node was lost from the cluster.
- def hostLost(host: String): Unit
+ def executorLost(execId: String): Unit
// The TaskScheduler wants to abort an entire task set.
def taskSetFailed(taskSet: TaskSet, reason: String): Unit
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 20f6e65020..1e4fbdb874 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -27,19 +27,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
- val taskIdToSlaveId = new HashMap[Long, String]
+ val taskIdToExecutorId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
// Incrementing Mesos task IDs
val nextTaskId = new AtomicLong(0)
- // Which hosts in the cluster are alive (contains hostnames)
- val hostsAlive = new HashSet[String]
+ // Which executor IDs we have executors on
+ val activeExecutorIds = new HashSet[String]
- // Which slave IDs we have executors on
- val slaveIdsWithExecutors = new HashSet[String]
+ // The set of executors we have on each host; this is used to compute hostsAlive, which
+ // in turn is used to decide when we can attain data locality on a given host
+ val executorsByHost = new HashMap[String, HashSet[String]]
- val slaveIdToHost = new HashMap[String, String]
+ val executorIdToHost = new HashMap[String, String]
// JAR server, if any JARs were added by the user to the SparkContext
var jarServer: HttpServer = null
@@ -85,7 +86,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
- def submitTasks(taskSet: TaskSet) {
+ override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
@@ -102,7 +103,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
activeTaskSets -= manager.taskSet.id
activeTaskSetsQueue -= manager
taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
- taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id)
+ taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
taskSetTaskIds.remove(manager.taskSet.id)
}
}
@@ -117,8 +118,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
SparkEnv.set(sc.env)
// Mark each slave as alive and remember its hostname
for (o <- offers) {
- slaveIdToHost(o.slaveId) = o.hostname
- hostsAlive += o.hostname
+ executorIdToHost(o.executorId) = o.hostname
}
// Build a list of tasks to assign to each slave
val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
@@ -128,16 +128,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
do {
launchedTask = false
for (i <- 0 until offers.size) {
- val sid = offers(i).slaveId
+ val execId = offers(i).executorId
val host = offers(i).hostname
- manager.slaveOffer(sid, host, availableCpus(i)) match {
+ manager.slaveOffer(execId, host, availableCpus(i)) match {
case Some(task) =>
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetId(tid) = manager.taskSet.id
taskSetTaskIds(manager.taskSet.id) += tid
- taskIdToSlaveId(tid) = sid
- slaveIdsWithExecutors += sid
+ taskIdToExecutorId(tid) = execId
+ activeExecutorIds += execId
+ if (!executorsByHost.contains(host)) {
+ executorsByHost(host) = new HashSet()
+ }
+ executorsByHost(host) += execId
availableCpus(i) -= 1
launchedTask = true
@@ -152,25 +156,21 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
var taskSetToUpdate: Option[TaskSetManager] = None
- var failedHost: Option[String] = None
+ var failedExecutor: Option[String] = None
var taskFailed = false
synchronized {
try {
- if (state == TaskState.LOST && taskIdToSlaveId.contains(tid)) {
- // We lost the executor on this slave, so remember that it's gone
- val slaveId = taskIdToSlaveId(tid)
- val host = slaveIdToHost(slaveId)
- if (hostsAlive.contains(host)) {
- slaveIdsWithExecutors -= slaveId
- hostsAlive -= host
- activeTaskSetsQueue.foreach(_.hostLost(host))
- failedHost = Some(host)
+ if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
+ // We lost this entire executor, so remember that it's gone
+ val execId = taskIdToExecutorId(tid)
+ if (activeExecutorIds.contains(execId)) {
+ removeExecutor(execId)
+ failedExecutor = Some(execId)
}
}
taskIdToTaskSetId.get(tid) match {
case Some(taskSetId) =>
if (activeTaskSets.contains(taskSetId)) {
- //activeTaskSets(taskSetId).statusUpdate(status)
taskSetToUpdate = Some(activeTaskSets(taskSetId))
}
if (TaskState.isFinished(state)) {
@@ -178,7 +178,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (taskSetTaskIds.contains(taskSetId)) {
taskSetTaskIds(taskSetId) -= tid
}
- taskIdToSlaveId.remove(tid)
+ taskIdToExecutorId.remove(tid)
}
if (state == TaskState.FAILED) {
taskFailed = true
@@ -190,12 +190,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
case e: Exception => logError("Exception in statusUpdate", e)
}
}
- // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
+ // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock
if (taskSetToUpdate != None) {
taskSetToUpdate.get.statusUpdate(tid, state, serializedData)
}
- if (failedHost != None) {
- listener.hostLost(failedHost.get)
+ if (failedExecutor != None) {
+ listener.executorLost(failedExecutor.get)
backend.reviveOffers()
}
if (taskFailed) {
@@ -249,27 +249,42 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
- def slaveLost(slaveId: String, reason: ExecutorLossReason) {
- var failedHost: Option[String] = None
+ def executorLost(executorId: String, reason: ExecutorLossReason) {
+ var failedExecutor: Option[String] = None
synchronized {
- val host = slaveIdToHost(slaveId)
- if (hostsAlive.contains(host)) {
- logError("Lost an executor on " + host + ": " + reason)
- slaveIdsWithExecutors -= slaveId
- hostsAlive -= host
- activeTaskSetsQueue.foreach(_.hostLost(host))
- failedHost = Some(host)
+ if (activeExecutorIds.contains(executorId)) {
+ val host = executorIdToHost(executorId)
+ logError("Lost executor %s on %s: %s".format(executorId, host, reason))
+ removeExecutor(executorId)
+ failedExecutor = Some(executorId)
} else {
- // We may get multiple slaveLost() calls with different loss reasons. For example, one
- // may be triggered by a dropped connection from the slave while another may be a report
- // of executor termination from Mesos. We produce log messages for both so we eventually
- // report the termination reason.
- logError("Lost an executor on " + host + " (already removed): " + reason)
+ // We may get multiple executorLost() calls with different loss reasons. For example, one
+ // may be triggered by a dropped connection from the slave while another may be a report
+ // of executor termination from Mesos. We produce log messages for both so we eventually
+ // report the termination reason.
+ logError("Lost an executor " + executorId + " (already removed): " + reason)
}
}
- if (failedHost != None) {
- listener.hostLost(failedHost.get)
+ // Call listener.executorLost without holding the lock on this to prevent deadlock
+ if (failedExecutor != None) {
+ listener.executorLost(failedExecutor.get)
backend.reviveOffers()
}
}
+
+ /** Get a list of hosts that currently have executors */
+ def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet
+
+ /** Remove an executor from all our data structures and mark it as lost */
+ private def removeExecutor(executorId: String) {
+ activeExecutorIds -= executorId
+ val host = executorIdToHost(executorId)
+ val execs = executorsByHost.getOrElse(host, new HashSet)
+ execs -= executorId
+ if (execs.isEmpty) {
+ executorsByHost -= host
+ }
+ executorIdToHost -= executorId
+ activeTaskSetsQueue.foreach(_.executorLost(executorId, host))
+ }
}
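To make the new executor-level bookkeeping in the ClusterScheduler hunks above easier to follow, here is a minimal standalone sketch using only plain Scala mutable collections (no Spark classes; the object and method names are illustrative, not part of the patch). It shows how a host-to-executors map can back both a derived hostsAlive view and the executor-removal path:

    import scala.collection.mutable.{HashMap, HashSet}

    // Simplified sketch: track executors per host and derive the set of live hosts.
    object ExecutorBookkeepingSketch {
      val executorsByHost = new HashMap[String, HashSet[String]]
      val executorIdToHost = new HashMap[String, String]
      val activeExecutorIds = new HashSet[String]

      // Hosts count as "alive" exactly while at least one executor is registered on them.
      def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet

      def addExecutor(execId: String, host: String) {
        activeExecutorIds += execId
        executorIdToHost(execId) = host
        executorsByHost.getOrElseUpdate(host, new HashSet[String]) += execId
      }

      def removeExecutor(execId: String) {
        activeExecutorIds -= execId
        for (host <- executorIdToHost.remove(execId)) {
          val execs = executorsByHost.getOrElse(host, new HashSet[String])
          execs -= execId
          if (execs.isEmpty) {
            executorsByHost -= host  // host drops out of hostsAlive automatically
          }
        }
      }

      def main(args: Array[String]) {
        addExecutor("exec-1", "hostA")
        addExecutor("exec-2", "hostA")
        removeExecutor("exec-1")
        println(hostsAlive)  // Set(hostA): one executor left, host still alive
        removeExecutor("exec-2")
        println(hostsAlive)  // Set(): last executor gone, host no longer alive
      }
    }
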
diff --git a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala b/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala
deleted file mode 100644
index 96ebaa4601..0000000000
--- a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala
+++ /dev/null
@@ -1,4 +0,0 @@
-package spark.scheduler.cluster
-
-private[spark]
-class SlaveResources(val slaveId: String, val hostname: String, val coresFree: Int) {}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index f2fb244b24..2f7099c5b9 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -19,7 +19,6 @@ private[spark] class SparkDeploySchedulerBackend(
var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _
val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
- val executorIdToSlaveId = new HashMap[String, String]
// Memory used by each executor (in megabytes)
val executorMemory = {
@@ -33,19 +32,21 @@ private[spark] class SparkDeploySchedulerBackend(
override def start() {
super.start()
- val masterUrl = "akka://spark@%s:%s/user/%s".format(
- System.getProperty("spark.master.host"), System.getProperty("spark.master.port"),
+ // The endpoint for executors to talk to us
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
StandaloneSchedulerBackend.ACTOR_NAME)
- val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}")
+ val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
- val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command)
+ val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone"))
+ val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome)
client = new Client(sc.env.actorSystem, master, jobDesc, this)
client.start()
}
override def stop() {
- stopping = true;
+ stopping = true
super.stop()
client.stop()
if (shutdownCallback != null) {
@@ -53,35 +54,28 @@ private[spark] class SparkDeploySchedulerBackend(
}
}
- def connected(jobId: String) {
+ override def connected(jobId: String) {
logInfo("Connected to Spark cluster with job ID " + jobId)
}
- def disconnected() {
+ override def disconnected() {
if (!stopping) {
logError("Disconnected from Spark cluster!")
scheduler.error("Disconnected from Spark cluster")
}
}
- def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {
- executorIdToSlaveId += id -> workerId
+ override def executorAdded(executorId: String, workerId: String, host: String, cores: Int, memory: Int) {
logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format(
- id, host, cores, Utils.memoryMegabytesToString(memory)))
+ executorId, host, cores, Utils.memoryMegabytesToString(memory)))
}
- def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {
+ override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) {
val reason: ExecutorLossReason = exitStatus match {
case Some(code) => ExecutorExited(code)
case None => SlaveLost(message)
}
- logInfo("Executor %s removed: %s".format(id, message))
- executorIdToSlaveId.get(id) match {
- case Some(slaveId) =>
- executorIdToSlaveId.remove(id)
- scheduler.slaveLost(slaveId, reason)
- case None =>
- logInfo("No slave ID known for executor %s".format(id))
- }
+ logInfo("Executor %s removed: %s".format(executorId, message))
+ scheduler.executorLost(executorId, reason)
}
}
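As a rough illustration of how the executor-facing driver URL above is assembled, here is a small self-contained sketch. The property names and the akka actor-path format follow the hunk above; the host, port, and actor name used here are made-up example values, and the comment about template substitution describes the intended worker-side behaviour rather than code shown in this diff:

    // Sketch only: assembles the driver URL the way the hunk above does, with example values.
    object DriverUrlSketch {
      def driverUrl(host: String, port: String, actorName: String): String =
        "akka://spark@%s:%s/user/%s".format(host, port, actorName)

      def main(args: Array[String]) {
        // Example values; in the patch these come from spark.driver.host / spark.driver.port.
        System.setProperty("spark.driver.host", "driver.example.com")
        System.setProperty("spark.driver.port", "7077")
        val url = driverUrl(
          System.getProperty("spark.driver.host"),
          System.getProperty("spark.driver.port"),
          "StandaloneScheduler")  // assumed stand-in for StandaloneSchedulerBackend.ACTOR_NAME
        // The remaining arguments stay as literal templates; they are intended to be
        // filled in on the worker side when the executor process is actually launched.
        println(Seq(url, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}").mkString(" "))
      }
    }
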
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
index 1386cd9d44..da7dcf4b6b 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
@@ -6,32 +6,34 @@ import spark.util.SerializableBuffer
private[spark] sealed trait StandaloneClusterMessage extends Serializable
-// Master to slaves
+// Driver to executors
private[spark]
case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage
private[spark]
-case class RegisteredSlave(sparkProperties: Seq[(String, String)]) extends StandaloneClusterMessage
+case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
+ extends StandaloneClusterMessage
private[spark]
-case class RegisterSlaveFailed(message: String) extends StandaloneClusterMessage
+case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage
-// Slaves to master
+// Executors to driver
private[spark]
-case class RegisterSlave(slaveId: String, host: String, cores: Int) extends StandaloneClusterMessage
+case class RegisterExecutor(executorId: String, host: String, cores: Int)
+ extends StandaloneClusterMessage
private[spark]
-case class StatusUpdate(slaveId: String, taskId: Long, state: TaskState, data: SerializableBuffer)
+case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer)
extends StandaloneClusterMessage
private[spark]
object StatusUpdate {
/** Alternate factory method that takes a ByteBuffer directly for the data field */
- def apply(slaveId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = {
- StatusUpdate(slaveId, taskId, state, new SerializableBuffer(data))
+ def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = {
+ StatusUpdate(executorId, taskId, state, new SerializableBuffer(data))
}
}
-// Internal messages in master
+// Internal messages in driver
private[spark] case object ReviveOffers extends StandaloneClusterMessage
-private[spark] case object StopMaster extends StandaloneClusterMessage
+private[spark] case object StopDriver extends StandaloneClusterMessage
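For readers unfamiliar with the companion-object factory used by StatusUpdate above, this is a self-contained sketch of the same pattern with hypothetical names (PayloadWrapper stands in for spark.util.SerializableBuffer): the generated case-class apply coexists with an overload that wraps a raw ByteBuffer first.

    import java.nio.ByteBuffer

    // Hypothetical stand-in for a serializable buffer wrapper.
    class PayloadWrapper(val buffer: ByteBuffer) extends Serializable {
      // A real wrapper would implement custom (de)serialization; omitted in this sketch.
    }

    sealed trait ClusterMessageSketch extends Serializable

    case class StatusUpdateSketch(executorId: String, taskId: Long, data: PayloadWrapper)
      extends ClusterMessageSketch

    object StatusUpdateSketch {
      // Alternate factory: accept a raw ByteBuffer and wrap it, like StatusUpdate.apply above.
      def apply(executorId: String, taskId: Long, data: ByteBuffer): StatusUpdateSketch =
        StatusUpdateSketch(executorId, taskId, new PayloadWrapper(data))
    }

    object ProtocolDemo {
      def main(args: Array[String]) {
        val raw = ByteBuffer.wrap(Array[Byte](1, 2, 3))
        val msg = StatusUpdateSketch("exec-7", 42L, raw)
        msg match {
          case StatusUpdateSketch(execId, tid, payload) =>
            println("update from " + execId + " for task " + tid +
              " (" + payload.buffer.remaining + " bytes)")
        }
      }
    }
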
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index eeaae23dc8..082022be1c 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -23,13 +23,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
var totalCoreCount = new AtomicInteger(0)
- class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor {
- val slaveActor = new HashMap[String, ActorRef]
- val slaveAddress = new HashMap[String, Address]
- val slaveHost = new HashMap[String, String]
+ class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor {
+ val executorActor = new HashMap[String, ActorRef]
+ val executorAddress = new HashMap[String, Address]
+ val executorHost = new HashMap[String, String]
val freeCores = new HashMap[String, Int]
- val actorToSlaveId = new HashMap[ActorRef, String]
- val addressToSlaveId = new HashMap[Address, String]
+ val actorToExecutorId = new HashMap[ActorRef, String]
+ val addressToExecutorId = new HashMap[Address, String]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
@@ -37,86 +37,86 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
def receive = {
- case RegisterSlave(slaveId, host, cores) =>
- if (slaveActor.contains(slaveId)) {
- sender ! RegisterSlaveFailed("Duplicate slave ID: " + slaveId)
+ case RegisterExecutor(executorId, host, cores) =>
+ if (executorActor.contains(executorId)) {
+ sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
} else {
- logInfo("Registered slave: " + sender + " with ID " + slaveId)
- sender ! RegisteredSlave(sparkProperties)
+ logInfo("Registered executor: " + sender + " with ID " + executorId)
+ sender ! RegisteredExecutor(sparkProperties)
context.watch(sender)
- slaveActor(slaveId) = sender
- slaveHost(slaveId) = host
- freeCores(slaveId) = cores
- slaveAddress(slaveId) = sender.path.address
- actorToSlaveId(sender) = slaveId
- addressToSlaveId(sender.path.address) = slaveId
+ executorActor(executorId) = sender
+ executorHost(executorId) = host
+ freeCores(executorId) = cores
+ executorAddress(executorId) = sender.path.address
+ actorToExecutorId(sender) = executorId
+ addressToExecutorId(sender.path.address) = executorId
totalCoreCount.addAndGet(cores)
makeOffers()
}
- case StatusUpdate(slaveId, taskId, state, data) =>
+ case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
- freeCores(slaveId) += 1
- makeOffers(slaveId)
+ freeCores(executorId) += 1
+ makeOffers(executorId)
}
case ReviveOffers =>
makeOffers()
- case StopMaster =>
+ case StopDriver =>
sender ! true
context.stop(self)
case Terminated(actor) =>
- actorToSlaveId.get(actor).foreach(removeSlave(_, "Akka actor terminated"))
+ actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated"))
case RemoteClientDisconnected(transport, address) =>
- addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client disconnected"))
+ addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected"))
case RemoteClientShutdown(transport, address) =>
- addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client shutdown"))
+ addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown"))
}
- // Make fake resource offers on all slaves
+ // Make fake resource offers on all executors
def makeOffers() {
launchTasks(scheduler.resourceOffers(
- slaveHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
+ executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
}
- // Make fake resource offers on just one slave
- def makeOffers(slaveId: String) {
+ // Make fake resource offers on just one executor
+ def makeOffers(executorId: String) {
launchTasks(scheduler.resourceOffers(
- Seq(new WorkerOffer(slaveId, slaveHost(slaveId), freeCores(slaveId)))))
+ Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId)))))
}
// Launch tasks returned by a set of resource offers
def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
for (task <- tasks.flatten) {
- freeCores(task.slaveId) -= 1
- slaveActor(task.slaveId) ! LaunchTask(task)
+ freeCores(task.executorId) -= 1
+ executorActor(task.executorId) ! LaunchTask(task)
}
}
// Remove a disconnected slave from the cluster
- def removeSlave(slaveId: String, reason: String) {
- logInfo("Slave " + slaveId + " disconnected, so removing it")
- val numCores = freeCores(slaveId)
- actorToSlaveId -= slaveActor(slaveId)
- addressToSlaveId -= slaveAddress(slaveId)
- slaveActor -= slaveId
- slaveHost -= slaveId
- freeCores -= slaveId
- slaveHost -= slaveId
+ def removeExecutor(executorId: String, reason: String) {
+ logInfo("Executor " + executorId + " disconnected, so removing it")
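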
+ val numCores = freeCores(executorId)
+ actorToExecutorId -= executorActor(executorId)
+ addressToExecutorId -= executorAddress(executorId)
+ executorActor -= executorId
+ executorHost -= executorId
+ freeCores -= executorId
+ executorHost -= executorId
totalCoreCount.addAndGet(-numCores)
- scheduler.slaveLost(slaveId, SlaveLost(reason))
+ scheduler.executorLost(executorId, SlaveLost(reason))
}
}
- var masterActor: ActorRef = null
+ var driverActor: ActorRef = null
val taskIdsOnSlave = new HashMap[String, HashSet[String]]
- def start() {
+ override def start() {
val properties = new ArrayBuffer[(String, String)]
val iterator = System.getProperties.entrySet.iterator
while (iterator.hasNext) {
@@ -126,15 +126,15 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
properties += ((key, value))
}
}
- masterActor = actorSystem.actorOf(
- Props(new MasterActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
+ driverActor = actorSystem.actorOf(
+ Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
}
- def stop() {
+ override def stop() {
try {
- if (masterActor != null) {
+ if (driverActor != null) {
val timeout = 5.seconds
- val future = masterActor.ask(StopMaster)(timeout)
+ val future = driverActor.ask(StopDriver)(timeout)
Await.result(future, timeout)
}
} catch {
@@ -143,11 +143,11 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
}
- def reviveOffers() {
- masterActor ! ReviveOffers
+ override def reviveOffers() {
+ driverActor ! ReviveOffers
}
- def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2)
+ override def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2)
}
private[spark] object StandaloneSchedulerBackend {
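The DriverActor above mainly maintains per-executor maps and turns free cores into resource offers. The following single-threaded sketch (plain Scala only, no Akka, illustrative names) mirrors that bookkeeping to show the registration and makeOffers flow in isolation:

    import scala.collection.mutable.HashMap

    // Sketch only: single-threaded stand-in for the per-executor maps kept by DriverActor.
    object DriverBookkeepingSketch {
      val executorHost = new HashMap[String, String]
      val freeCores = new HashMap[String, Int]

      // Returns false on a duplicate executor ID, mirroring RegisterExecutorFailed above.
      def registerExecutor(execId: String, host: String, cores: Int): Boolean = {
        if (executorHost.contains(execId)) {
          false
        } else {
          executorHost(execId) = host
          freeCores(execId) = cores
          true
        }
      }

      // Build fake resource offers from current free cores, mirroring makeOffers().
      def makeOffers(): Seq[(String, String, Int)] =
        executorHost.toSeq.map { case (id, host) => (id, host, freeCores(id)) }

      def main(args: Array[String]) {
        println(registerExecutor("exec-1", "hostA", 2))  // true
        println(registerExecutor("exec-1", "hostA", 2))  // false: duplicate ID
        println(registerExecutor("exec-2", "hostB", 4))  // true
        println(makeOffers())  // offers of (executorId, hostname, freeCores)
      }
    }
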
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
index aa097fd3a2..b41e951be9 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
@@ -5,7 +5,7 @@ import spark.util.SerializableBuffer
private[spark] class TaskDescription(
val taskId: Long,
- val slaveId: String,
+ val executorId: String,
val name: String,
_serializedTask: ByteBuffer)
extends Serializable {
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
index ca84503780..0f975ce1eb 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
@@ -4,7 +4,12 @@ package spark.scheduler.cluster
* Information about a running task attempt inside a TaskSet.
*/
private[spark]
-class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: String) {
+class TaskInfo(
+ val taskId: Long,
+ val index: Int,
+ val launchTime: Long,
+ val executorId: String,
+ val host: String) {
var finishTime: Long = 0
var failed = false
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index a089b71644..3dabdd76b1 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -17,10 +17,7 @@ import java.nio.ByteBuffer
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/
-private[spark] class TaskSetManager(
- sched: ClusterScheduler,
- val taskSet: TaskSet)
- extends Logging {
+private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
@@ -100,7 +97,7 @@ private[spark] class TaskSetManager(
}
// Add a task to all the pending-task lists that it should be on.
- def addPendingTask(index: Int) {
+ private def addPendingTask(index: Int) {
val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
if (locations.size == 0) {
pendingTasksWithNoPrefs += index
@@ -115,7 +112,7 @@ private[spark] class TaskSetManager(
// Return the pending tasks list for a given host, or an empty list if
// there is no map entry for that host
- def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+ private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
@@ -123,7 +120,7 @@ private[spark] class TaskSetManager(
// Return None if the list is empty.
// This method also cleans up any tasks in the list that have already
// been launched, since we want that to happen lazily.
- def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+ private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
@@ -137,11 +134,12 @@ private[spark] class TaskSetManager(
// Return a speculative task for a given host if any are available. The task should not have an
// attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
// task must have a preference for this host (or no preferred locations at all).
- def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
+ private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
+ val hostsAlive = sched.hostsAlive
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
val localTask = speculatableTasks.find {
index =>
- val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
+ val locations = tasks(index).preferredLocations.toSet & hostsAlive
val attemptLocs = taskAttempts(index).map(_.host)
(locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
}
@@ -161,7 +159,7 @@ private[spark] class TaskSetManager(
// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
- def findTask(host: String, localOnly: Boolean): Option[Int] = {
+ private def findTask(host: String, localOnly: Boolean): Option[Int] = {
val localTask = findTaskFromList(getPendingTasksForHost(host))
if (localTask != None) {
return localTask
@@ -183,13 +181,13 @@ private[spark] class TaskSetManager(
// Does a host count as a preferred location for a task? This is true if
// either the task has preferred locations and this host is one, or it has
  // no preferred locations (in which case we still count the launch as preferred).
- def isPreferredLocation(task: Task[_], host: String): Boolean = {
+ private def isPreferredLocation(task: Task[_], host: String): Boolean = {
val locs = task.preferredLocations
return (locs.contains(host) || locs.isEmpty)
}
// Respond to an offer of a single slave from the scheduler by finding a task
- def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[TaskDescription] = {
+ def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = {
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
val time = System.currentTimeMillis
val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
@@ -206,11 +204,11 @@ private[spark] class TaskSetManager(
} else {
"non-preferred, not one of " + task.preferredLocations.mkString(", ")
}
- logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
- taskSet.id, index, taskId, slaveId, host, prefStr))
+ logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
+ taskSet.id, index, taskId, execId, host, prefStr))
// Do various bookkeeping
copiesRunning(index) += 1
- val info = new TaskInfo(taskId, index, time, host)
+ val info = new TaskInfo(taskId, index, time, execId, host)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
if (preferred) {
@@ -224,7 +222,7 @@ private[spark] class TaskSetManager(
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
- return Some(new TaskDescription(taskId, slaveId, taskName, serializedTask))
+ return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
}
case _ =>
}
@@ -334,7 +332,7 @@ private[spark] class TaskSetManager(
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
- abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES))
+ abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
}
}
} else {
@@ -356,19 +354,22 @@ private[spark] class TaskSetManager(
sched.taskSetFinished(this)
}
- def hostLost(hostname: String) {
- logInfo("Re-queueing tasks for " + hostname + " from TaskSet " + taskSet.id)
- // If some task has preferred locations only on hostname, put it in the no-prefs list
- // to avoid the wait from delay scheduling
- for (index <- getPendingTasksForHost(hostname)) {
- val newLocs = tasks(index).preferredLocations.toSet & sched.hostsAlive
- if (newLocs.isEmpty) {
- pendingTasksWithNoPrefs += index
+ def executorLost(execId: String, hostname: String) {
+ logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
+ val newHostsAlive = sched.hostsAlive
+ // If some task has preferred locations only on hostname, and there are no more executors there,
+ // put it in the no-prefs list to avoid the wait from delay scheduling
+ if (!newHostsAlive.contains(hostname)) {
+ for (index <- getPendingTasksForHost(hostname)) {
+ val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive
+ if (newLocs.isEmpty) {
+ pendingTasksWithNoPrefs += index
+ }
}
}
- // Re-enqueue any tasks that ran on the failed host if this is a shuffle map stage
+ // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
- for ((tid, info) <- taskInfos if info.host == hostname) {
+ for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (finished(index)) {
finished(index) = false
@@ -382,7 +383,7 @@ private[spark] class TaskSetManager(
}
}
// Also re-enqueue any tasks that were running on the node
- for ((tid, info) <- taskInfos if info.running && info.host == hostname) {
+ for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
taskLost(tid, TaskState.KILLED, null)
}
}
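The slaveOffer path above keeps the existing delay-scheduling rule: non-local tasks are only allowed once the last preferred launch is older than spark.locality.wait. A tiny sketch of just that check (illustrative object name):

    // Sketch only: the delay-scheduling check from slaveOffer, isolated.
    object LocalityWaitSketch {
      val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong

      // Non-local tasks become acceptable once the last preferred launch is old enough.
      def allowNonLocal(now: Long, lastPreferredLaunchTime: Long): Boolean =
        now - lastPreferredLaunchTime >= LOCALITY_WAIT

      def main(args: Array[String]) {
        val lastPreferred = System.currentTimeMillis
        println(allowNonLocal(System.currentTimeMillis, lastPreferred))           // false: still waiting
        println(allowNonLocal(lastPreferred + LOCALITY_WAIT + 1, lastPreferred))  // true: wait expired
      }
    }
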
diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
index 6b919d68b2..3c3afcbb14 100644
--- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
@@ -1,8 +1,8 @@
package spark.scheduler.cluster
/**
- * Represents free resources available on a worker node.
+ * Represents free resources available on an executor.
*/
private[spark]
-class WorkerOffer(val slaveId: String, val hostname: String, val cores: Int) {
+class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) {
}
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 2593c0e3a0..482d1cc853 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -20,7 +20,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
with Logging {
var attemptId = new AtomicInteger(0)
- var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
+ var threadPool = Utils.newDaemonFixedThreadPool(threads)
val env = SparkEnv.get
var listener: TaskSchedulerListener = null
@@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
}
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
- logInfo("Running task " + idInJob)
+ logInfo("Running " + task)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
try {
@@ -80,8 +80,11 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val resultToReturn = ser.deserialize[Any](ser.serialize(result))
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
- logInfo("Finished task " + idInJob)
- listener.taskEnded(task, Success, resultToReturn, accumUpdates)
+ logInfo("Finished " + task)
+
+ // If the threadpool has not already been shutdown, notify DAGScheduler
+ if (!Thread.currentThread().isInterrupted)
+ listener.taskEnded(task, Success, resultToReturn, accumUpdates)
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
@@ -91,7 +94,8 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
submitTask(task, idInJob)
} else {
// TODO: Do something nicer here to return all the way to the user
- listener.taskEnded(task, new ExceptionFailure(t), null, null)
+ if (!Thread.currentThread().isInterrupted)
+ listener.taskEnded(task, new ExceptionFailure(t), null, null)
}
}
}
@@ -112,16 +116,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
- val url = new File(".", localName).toURI.toURL
+ val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
if (!classLoader.getURLs.contains(url)) {
logInfo("Adding " + url + " to class loader")
classLoader.addURL(url)
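The dependency-sync loop above re-fetches a file or JAR only when a newer timestamp is seen. A minimal sketch of that timestamp check, with illustrative names and no actual fetching:

    import scala.collection.mutable.HashMap

    // Sketch only: a dependency is re-fetched when its timestamp is newer than what we have.
    object DependencySyncSketch {
      val currentFiles = new HashMap[String, Long]

      def needsFetch(name: String, timestamp: Long): Boolean =
        currentFiles.getOrElse(name, -1L) < timestamp

      def recordFetched(name: String, timestamp: Long) {
        currentFiles(name) = timestamp
      }

      def main(args: Array[String]) {
        println(needsFetch("lib.jar", 100L))  // true: never seen before
        recordFetched("lib.jar", 100L)
        println(needsFetch("lib.jar", 100L))  // false: already up to date
        println(needsFetch("lib.jar", 200L))  // true: newer version published
      }
    }
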
diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
index c45c7df69c..7bf56a05d6 100644
--- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
@@ -64,13 +64,9 @@ private[spark] class CoarseMesosSchedulerBackend(
val taskIdToSlaveId = new HashMap[Int, String]
val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed
- val sparkHome = sc.getSparkHome() match {
- case Some(path) =>
- path
- case None =>
- throw new SparkException("Spark home is not set; set it through the spark.home system " +
- "property, the SPARK_HOME environment variable or the SparkContext constructor")
- }
+ val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
+ "Spark home is not set; set it through the spark.home system " +
+ "property, the SPARK_HOME environment variable or the SparkContext constructor"))
val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt
@@ -108,11 +104,11 @@ private[spark] class CoarseMesosSchedulerBackend(
def createCommand(offer: Offer, numCores: Int): CommandInfo = {
val runScript = new File(sparkHome, "run").getCanonicalPath
- val masterUrl = "akka://spark@%s:%s/user/%s".format(
- System.getProperty("spark.master.host"), System.getProperty("spark.master.port"),
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
StandaloneSchedulerBackend.ACTOR_NAME)
val command = "\"%s\" spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
- runScript, masterUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)
+ runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)
val environment = Environment.newBuilder()
sc.executorEnvs.foreach { case (key, value) =>
environment.addVariables(Environment.Variable.newBuilder()
@@ -184,7 +180,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Helper function to pull out a resource from a Mesos Resources protobuf */
- def getResource(res: JList[Resource], name: String): Double = {
+ private def getResource(res: JList[Resource], name: String): Double = {
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
@@ -193,7 +189,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Build a Mesos resource protobuf object */
- def createResource(resourceName: String, quantity: Double): Protos.Resource = {
+ private def createResource(resourceName: String, quantity: Double): Protos.Resource = {
Resource.newBuilder()
.setName(resourceName)
.setType(Value.Type.SCALAR)
@@ -202,7 +198,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Check whether a Mesos task state represents a finished task */
- def isFinished(state: MesosTaskState) = {
+ private def isFinished(state: MesosTaskState) = {
state == MesosTaskState.TASK_FINISHED ||
state == MesosTaskState.TASK_FAILED ||
state == MesosTaskState.TASK_KILLED ||
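The sparkHome change above relies on Option.getOrElse taking its default by name, so the exception is constructed only when the option is empty. A small sketch of the idiom (hypothetical helper name):

    // Sketch only: getOrElse(throw ...) raises only when the Option is None.
    object GetOrElseThrowSketch {
      def requireSetting(opt: Option[String], name: String): String =
        opt.getOrElse(throw new IllegalArgumentException(name + " is not set"))

      def main(args: Array[String]) {
        println(requireSetting(Some("/opt/spark"), "spark.home"))  // prints the value
        try {
          requireSetting(None, "spark.home")
        } catch {
          case e: IllegalArgumentException => println("caught: " + e.getMessage)
        }
      }
    }
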
diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
index 8c7a1dfbc0..eab1c60e0b 100644
--- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
@@ -51,7 +51,7 @@ private[spark] class MesosSchedulerBackend(
val taskIdToSlaveId = new HashMap[Long, String]
// An ExecutorInfo for our tasks
- var executorInfo: ExecutorInfo = null
+ var execArgs: Array[Byte] = null
override def start() {
synchronized {
@@ -70,19 +70,14 @@ private[spark] class MesosSchedulerBackend(
}
}.start()
- executorInfo = createExecutorInfo()
waitForRegister()
}
}
- def createExecutorInfo(): ExecutorInfo = {
- val sparkHome = sc.getSparkHome() match {
- case Some(path) =>
- path
- case None =>
- throw new SparkException("Spark home is not set; set it through the spark.home system " +
- "property, the SPARK_HOME environment variable or the SparkContext constructor")
- }
+ def createExecutorInfo(execId: String): ExecutorInfo = {
+ val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
+ "Spark home is not set; set it through the spark.home system " +
+ "property, the SPARK_HOME environment variable or the SparkContext constructor"))
val execScript = new File(sparkHome, "spark-executor").getCanonicalPath
val environment = Environment.newBuilder()
sc.executorEnvs.foreach { case (key, value) =>
@@ -101,7 +96,7 @@ private[spark] class MesosSchedulerBackend(
.setEnvironment(environment)
.build()
ExecutorInfo.newBuilder()
- .setExecutorId(ExecutorID.newBuilder().setValue("default").build())
+ .setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
.setCommand(command)
.setData(ByteString.copyFrom(createExecArg()))
.addResources(memory)
@@ -113,17 +108,20 @@ private[spark] class MesosSchedulerBackend(
* containing all the spark.* system properties in the form of (String, String) pairs.
*/
private def createExecArg(): Array[Byte] = {
- val props = new HashMap[String, String]
- val iterator = System.getProperties.entrySet.iterator
- while (iterator.hasNext) {
- val entry = iterator.next
- val (key, value) = (entry.getKey.toString, entry.getValue.toString)
- if (key.startsWith("spark.")) {
- props(key) = value
+ if (execArgs == null) {
+ val props = new HashMap[String, String]
+ val iterator = System.getProperties.entrySet.iterator
+ while (iterator.hasNext) {
+ val entry = iterator.next
+ val (key, value) = (entry.getKey.toString, entry.getValue.toString)
+ if (key.startsWith("spark.")) {
+ props(key) = value
+ }
}
+ // Serialize the map as an array of (String, String) pairs
+ execArgs = Utils.serialize(props.toArray)
}
- // Serialize the map as an array of (String, String) pairs
- return Utils.serialize(props.toArray)
+ return execArgs
}
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
@@ -220,7 +218,7 @@ private[spark] class MesosSchedulerBackend(
return MesosTaskInfo.newBuilder()
.setTaskId(taskId)
.setSlaveId(SlaveID.newBuilder().setValue(slaveId).build())
- .setExecutor(executorInfo)
+ .setExecutor(createExecutorInfo(slaveId))
.setName(task.name)
.addResources(cpuResource)
.setData(ByteString.copyFrom(task.serializedTask))
@@ -272,7 +270,7 @@ private[spark] class MesosSchedulerBackend(
synchronized {
slaveIdsWithExecutors -= slaveId.getValue
}
- scheduler.slaveLost(slaveId.getValue, reason)
+ scheduler.executorLost(slaveId.getValue, reason)
}
override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) {
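createExecArg above now memoizes the serialized spark.* properties in execArgs instead of rebuilding them on every call. The sketch below shows the same memoization shape, with plain Java serialization standing in for Utils.serialize (all names illustrative):

    import scala.collection.mutable.HashMap

    // Sketch only: collect spark.* system properties once, serialize once, reuse the bytes.
    object ExecArgCacheSketch {
      private var cached: Array[Byte] = null

      def execArg(): Array[Byte] = {
        if (cached == null) {
          val props = new HashMap[String, String]
          val it = System.getProperties.entrySet.iterator
          while (it.hasNext) {
            val entry = it.next
            val (key, value) = (entry.getKey.toString, entry.getValue.toString)
            if (key.startsWith("spark.")) {
              props(key) = value
            }
          }
          // Stand-in for Utils.serialize: Java-serialize the (String, String) pairs.
          val bytes = new java.io.ByteArrayOutputStream
          val out = new java.io.ObjectOutputStream(bytes)
          out.writeObject(props.toArray)
          out.close()
          cached = bytes.toByteArray
        }
        cached
      }

      def main(args: Array[String]) {
        System.setProperty("spark.app.name", "demo")
        println(execArg() eq execArg())  // true: the second call reuses the cached array
      }
    }
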
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 7a8ac10cdd..9893e9625d 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -16,7 +16,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils}
+import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils}
import spark.network._
import spark.serializer.Serializer
import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap}
@@ -30,6 +30,7 @@ extends Exception(message)
private[spark]
class BlockManager(
+ executorId: String,
actorSystem: ActorSystem,
val master: BlockManagerMaster,
val serializer: Serializer,
@@ -68,11 +69,8 @@ class BlockManager(
val connectionManager = new ConnectionManager(0)
implicit val futureExecContext = connectionManager.futureExecContext
- val connectionManagerId = connectionManager.id
- val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port)
-
- // TODO: This will be removed after cacheTracker is removed from the code base.
- var cacheTracker: CacheTracker = null
+ val blockManagerId = BlockManagerId(
+ executorId, connectionManager.id.host, connectionManager.id.port)
// Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
// for receiving shuffle outputs)
@@ -93,7 +91,10 @@ class BlockManager(
val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
- @volatile private var shuttingDown = false
+ // Pending reregistration action being executed asynchronously or null if none
+ // is pending. Accesses should synchronize on asyncReregisterLock.
+ var asyncReregisterTask: Future[Unit] = null
+ val asyncReregisterLock = new Object
private def heartBeat() {
if (!master.sendHeartBeat(blockManagerId)) {
@@ -109,8 +110,9 @@ class BlockManager(
/**
* Construct a BlockManager with a memory limit set based on system properties.
*/
- def this(actorSystem: ActorSystem, master: BlockManagerMaster, serializer: Serializer) = {
- this(actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties)
+ def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster,
+ serializer: Serializer) = {
+ this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties)
}
/**
@@ -150,6 +152,8 @@ class BlockManager(
/**
* Reregister with the master and report all blocks to it. This will be called by the heart beat
 * thread if our heartbeat to the block manager indicates that we were not registered.
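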
+ *
+ * Note that this method must be called without any BlockInfo locks held.
*/
def reregister() {
// TODO: We might need to rate limit reregistering.
@@ -159,6 +163,32 @@ class BlockManager(
}
/**
+ * Reregister with the master sometime soon.
+ */
+ def asyncReregister() {
+ asyncReregisterLock.synchronized {
+ if (asyncReregisterTask == null) {
+ asyncReregisterTask = Future[Unit] {
+ reregister()
+ asyncReregisterLock.synchronized {
+ asyncReregisterTask = null
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * For testing. Wait for any pending asynchronous reregistration; otherwise, do nothing.
+ */
+ def waitForAsyncReregister() {
+ val task = asyncReregisterTask
+ if (task != null) {
+ Await.ready(task, Duration.Inf)
+ }
+ }
+
+ /**
* Get storage level of local block. If no info exists for the block, then returns null.
*/
def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
@@ -173,7 +203,7 @@ class BlockManager(
if (needReregister) {
logInfo("Got told to reregister updating block " + blockId)
// Reregistering will report our new block for free.
- reregister()
+ asyncReregister()
}
logDebug("Told master about block " + blockId)
}
@@ -191,7 +221,7 @@ class BlockManager(
case level =>
val inMem = level.useMemory && memoryStore.contains(blockId)
val onDisk = level.useDisk && diskStore.contains(blockId)
- val storageLevel = new StorageLevel(onDisk, inMem, level.deserialized, level.replication)
+ val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication)
val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
(storageLevel, memSize, diskSize, info.tellMaster)
@@ -213,7 +243,7 @@ class BlockManager(
val startTimeMs = System.currentTimeMillis
var managers = master.getLocations(blockId)
val locations = managers.map(_.ip)
- logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs))
+ logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs))
return locations
}
@@ -223,7 +253,7 @@ class BlockManager(
def getLocations(blockIds: Array[String]): Array[Seq[String]] = {
val startTimeMs = System.currentTimeMillis
val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray
- logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
+ logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
return locations
}
@@ -615,7 +645,7 @@ class BlockManager(
var size = 0L
myInfo.synchronized {
- logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
if (level.useMemory) {
@@ -647,8 +677,10 @@ class BlockManager(
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
+
// Replicate block if required
if (level.replication > 1) {
+ val remoteStartTime = System.currentTimeMillis
// Serialize the block if not already done
if (bytesAfterPut == null) {
if (valuesAfterPut == null) {
@@ -658,16 +690,10 @@ class BlockManager(
bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
}
replicate(blockId, bytesAfterPut, level)
+ logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime))
}
-
BlockManager.dispose(bytesAfterPut)
- // TODO: This code will be removed when CacheTracker is gone.
- if (blockId.startsWith("rdd")) {
- notifyCacheTracker(blockId)
- }
- logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs))
-
return size
}
@@ -733,11 +759,6 @@ class BlockManager(
}
}
- // TODO: This code will be removed when CacheTracker is gone.
- if (blockId.startsWith("rdd")) {
- notifyCacheTracker(blockId)
- }
-
// If replication had started, then wait for it to finish
if (level.replication > 1) {
if (replicationFuture == null) {
@@ -760,8 +781,7 @@ class BlockManager(
*/
var cachedPeers: Seq[BlockManagerId] = null
private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
- val tLevel: StorageLevel =
- new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
+ val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
if (cachedPeers == null) {
cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
}
@@ -780,16 +800,6 @@ class BlockManager(
}
}
- // TODO: This code will be removed when CacheTracker is gone.
- private def notifyCacheTracker(key: String) {
- if (cacheTracker != null) {
- val rddInfo = key.split("_")
- val rddId: Int = rddInfo(1).toInt
- val partition: Int = rddInfo(2).toInt
- cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host))
- }
- }
-
/**
* Read a block consisting of a single object.
*/
@@ -940,6 +950,7 @@ class BlockManager(
blockInfo.clear()
memoryStore.clear()
diskStore.clear()
+ metadataCleaner.cancel()
logInfo("BlockManager stopped")
}
}
@@ -968,7 +979,7 @@ object BlockManager extends Logging {
*/
def dispose(buffer: ByteBuffer) {
if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) {
- logDebug("Unmapping " + buffer)
+ logTrace("Unmapping " + buffer)
if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) {
buffer.asInstanceOf[DirectBuffer].cleaner().clean()
}
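The asyncReregister/waitForAsyncReregister pair above coalesces concurrent re-registration requests: at most one runs at a time, and a test hook can block until it finishes. The sketch below reproduces that shape with a plain background thread instead of the Future used in the patch (all names illustrative):

    // Sketch only: coalesce concurrent reregistration requests into one background run.
    object AsyncReregisterSketch {
      private val lock = new Object
      private var pending: Thread = null

      private def reregister() {
        Thread.sleep(100)  // stand-in for re-sending registration and block list
        println("reregistered")
      }

      def asyncReregister() {
        lock.synchronized {
          if (pending == null) {
            pending = new Thread(new Runnable {
              def run() {
                reregister()
                lock.synchronized { pending = null }
              }
            })
            pending.start()
          }
        }
      }

      // For testing: block until any in-flight reregistration completes.
      def waitForAsyncReregister() {
        val t = lock.synchronized { pending }
        if (t != null) {
          t.join()
        }
      }

      def main(args: Array[String]) {
        asyncReregister()
        asyncReregister()         // coalesced: no second run is started
        waitForAsyncReregister()  // "reregistered" is printed exactly once before this returns
      }
    }
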
diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala
index 488679f049..f2f1e77d41 100644
--- a/core/src/main/scala/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerId.scala
@@ -3,38 +3,67 @@ package spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap
+/**
+ * This class represents a unique identifier for a BlockManager.
+ * The first two constructors of this class are made private to ensure that
+ * BlockManagerId objects can be created only using the factory method in
+ * [[spark.storage.BlockManagerId$]]. This allows de-duplication of ID objects.
+ * Also, constructor parameters are private to ensure that parameters cannot
+ * be modified from outside this class.
+ */
+private[spark] class BlockManagerId private (
+ private var executorId_ : String,
+ private var ip_ : String,
+ private var port_ : Int
+ ) extends Externalizable {
-private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
- def this() = this(null, 0) // For deserialization only
+ private def this() = this(null, null, 0) // For deserialization only
- def this(in: ObjectInput) = this(in.readUTF(), in.readInt())
+ def executorId: String = executorId_
+
+ def ip: String = ip_
+
+ def port: Int = port_
override def writeExternal(out: ObjectOutput) {
- out.writeUTF(ip)
- out.writeInt(port)
+ out.writeUTF(executorId_)
+ out.writeUTF(ip_)
+ out.writeInt(port_)
}
override def readExternal(in: ObjectInput) {
- ip = in.readUTF()
- port = in.readInt()
+ executorId_ = in.readUTF()
+ ip_ = in.readUTF()
+ port_ = in.readInt()
}
@throws(classOf[IOException])
private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
- override def toString = "BlockManagerId(" + ip + ", " + port + ")"
+ override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port)
- override def hashCode = ip.hashCode * 41 + port
+ override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port
override def equals(that: Any) = that match {
- case id: BlockManagerId => port == id.port && ip == id.ip
- case _ => false
+ case id: BlockManagerId =>
+ executorId == id.executorId && port == id.port && ip == id.ip
+ case _ =>
+ false
}
}
private[spark] object BlockManagerId {
+ def apply(execId: String, ip: String, port: Int) =
+ getCachedBlockManagerId(new BlockManagerId(execId, ip, port))
+
+ def apply(in: ObjectInput) = {
+ val obj = new BlockManagerId()
+ obj.readExternal(in)
+ getCachedBlockManagerId(obj)
+ }
+
val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
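The private constructor plus factory-and-cache combination above guarantees that equal BlockManagerIds share one canonical object. Here is a self-contained sketch of the same interning pattern under a hypothetical class name; unlike the real class it skips Externalizable/readResolve, which is what extends the de-duplication across deserialization:

    import java.util.concurrent.ConcurrentHashMap

    // Sketch only: construction goes through a factory that returns a canonical cached instance.
    class ExecutorLocation private (val executorId: String, val ip: String, val port: Int)
        extends Serializable {
      override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port
      override def equals(that: Any) = that match {
        case other: ExecutorLocation =>
          executorId == other.executorId && ip == other.ip && port == other.port
        case _ => false
      }
      override def toString = "ExecutorLocation(%s, %s, %d)".format(executorId, ip, port)
    }

    object ExecutorLocation {
      private val cache = new ConcurrentHashMap[ExecutorLocation, ExecutorLocation]()

      def apply(executorId: String, ip: String, port: Int): ExecutorLocation = {
        val id = new ExecutorLocation(executorId, ip, port)
        val prev = cache.putIfAbsent(id, id)
        if (prev == null) id else prev  // return the canonical instance
      }
    }

    object ExecutorLocationDemo {
      def main(args: Array[String]) {
        val a = ExecutorLocation("exec-1", "10.0.0.1", 7077)
        val b = ExecutorLocation("exec-1", "10.0.0.1", 7077)
        println(a eq b)  // true: both calls yield the same cached object
      }
    }
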
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index a3d8671834..36398095a2 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -1,6 +1,10 @@
package spark.storage
-import scala.collection.mutable.ArrayBuffer
+import java.io._
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.util.Random
import akka.actor.{Actor, ActorRef, ActorSystem, Props}
@@ -11,52 +15,51 @@ import akka.util.duration._
import spark.{Logging, SparkException, Utils}
-
private[spark] class BlockManagerMaster(
val actorSystem: ActorSystem,
- isMaster: Boolean,
+ isDriver: Boolean,
isLocal: Boolean,
- masterIp: String,
- masterPort: Int)
+ driverIp: String,
+ driverPort: Int)
extends Logging {
- val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
+ val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
- val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager"
+ val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager"
val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager"
val DEFAULT_MANAGER_IP: String = Utils.localHostName()
val timeout = 10.seconds
- var masterActor: ActorRef = {
- if (isMaster) {
- val masterActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)),
- name = MASTER_AKKA_ACTOR_NAME)
+ var driverActor: ActorRef = {
+ if (isDriver) {
+ val driverActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)),
+ name = DRIVER_AKKA_ACTOR_NAME)
logInfo("Registered BlockManagerMaster Actor")
- masterActor
+ driverActor
} else {
- val url = "akka://spark@%s:%s/user/%s".format(masterIp, masterPort, MASTER_AKKA_ACTOR_NAME)
+ val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, DRIVER_AKKA_ACTOR_NAME)
logInfo("Connecting to BlockManagerMaster: " + url)
actorSystem.actorFor(url)
}
}
- /** Remove a dead host from the master actor. This is only called on the master side. */
- def notifyADeadHost(host: String) {
- tell(RemoveHost(host))
- logInfo("Removed " + host + " successfully in notifyADeadHost")
+ /** Remove a dead executor from the driver actor. This is only called on the driver side. */
+ def removeExecutor(execId: String) {
+ tell(RemoveExecutor(execId))
+ logInfo("Removed " + execId + " successfully in removeExecutor")
}
/**
- * Send the master actor a heart beat from the slave. Returns true if everything works out,
- * false if the master does not know about the given block manager, which means the block
+ * Send the driver actor a heart beat from the slave. Returns true if everything works out,
+ * false if the driver does not know about the given block manager, which means the block
* manager should re-register.
*/
def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = {
- askMasterWithRetry[Boolean](HeartBeat(blockManagerId))
+ askDriverWithReply[Boolean](HeartBeat(blockManagerId))
}
- /** Register the BlockManager's id with the master. */
+ /** Register the BlockManager's id with the driver. */
def registerBlockManager(
blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
logInfo("Trying to register BlockManager")
@@ -70,25 +73,25 @@ private[spark] class BlockManagerMaster(
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long): Boolean = {
- val res = askMasterWithRetry[Boolean](
+ val res = askDriverWithReply[Boolean](
UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize))
logInfo("Updated info of block " + blockId)
res
}
- /** Get locations of the blockId from the master */
+ /** Get locations of the blockId from the driver */
def getLocations(blockId: String): Seq[BlockManagerId] = {
- askMasterWithRetry[Seq[BlockManagerId]](GetLocations(blockId))
+ askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId))
}
- /** Get locations of multiple blockIds from the master */
+ /** Get locations of multiple blockIds from the driver */
def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
- askMasterWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
+ askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
- /** Get ids of other nodes in the cluster from the master */
+ /** Get ids of other nodes in the cluster from the driver */
def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = {
- val result = askMasterWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
+ val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
if (result.length != numPeers) {
throw new SparkException(
"Error getting peers, only got " + result.size + " instead of " + numPeers)
@@ -98,10 +101,10 @@ private[spark] class BlockManagerMaster(
/**
* Remove a block from the slaves that have it. This can only be used to remove
- * blocks that the master knows about.
+ * blocks that the driver knows about.
*/
def removeBlock(blockId: String) {
- askMasterWithRetry(RemoveBlock(blockId))
+ askDriverWithReply(RemoveBlock(blockId))
}
/**
@@ -111,41 +114,41 @@ private[spark] class BlockManagerMaster(
* amount of remaining memory.
*/
def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
- askMasterWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
+ askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
}
- /** Stop the master actor, called only on the Spark master node */
+ /** Stop the driver actor, called only on the Spark driver node */
def stop() {
- if (masterActor != null) {
+ if (driverActor != null) {
tell(StopBlockManagerMaster)
- masterActor = null
+ driverActor = null
logInfo("BlockManagerMaster stopped")
}
}
/** Send a one-way message to the master actor, to which we expect it to reply with true. */
private def tell(message: Any) {
- if (!askMasterWithRetry[Boolean](message)) {
+ if (!askDriverWithReply[Boolean](message)) {
throw new SparkException("BlockManagerMasterActor returned false, expected true.")
}
}
/**
- * Send a message to the master actor and get its result within a default timeout, or
+ * Send a message to the driver actor and get its result within a default timeout, or
* throw a SparkException if this fails.
*/
- private def askMasterWithRetry[T](message: Any): T = {
+ private def askDriverWithReply[T](message: Any): T = {
// TODO: Consider removing multiple attempts
- if (masterActor == null) {
- throw new SparkException("Error sending message to BlockManager as masterActor is null " +
+ if (driverActor == null) {
+ throw new SparkException("Error sending message to BlockManager as driverActor is null " +
"[message = " + message + "]")
}
var attempts = 0
var lastException: Exception = null
- while (attempts < AKKA_RETRY_ATTEMPS) {
+ while (attempts < AKKA_RETRY_ATTEMPTS) {
attempts += 1
try {
- val future = masterActor.ask(message)(timeout)
+ val future = driverActor.ask(message)(timeout)
val result = Await.result(future, timeout)
if (result == null) {
throw new Exception("BlockManagerMaster returned null")
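askDriverWithReply above is a bounded retry loop around an ask: it remembers the last failure, treats a null reply as an error, and gives up after a fixed number of attempts. A generic sketch of that loop without Akka (illustrative names and retry settings):

    // Sketch only: retry a request a bounded number of times, keeping the last failure.
    object AskWithRetrySketch {
      val RETRY_ATTEMPTS = 3
      val RETRY_INTERVAL_MS = 10L

      def askWithRetry[T](request: () => T): T = {
        var attempts = 0
        var lastException: Exception = null
        while (attempts < RETRY_ATTEMPTS) {
          attempts += 1
          try {
            val result = request()
            if (result == null) {
              throw new Exception("driver returned null")
            }
            return result
          } catch {
            case e: Exception =>
              lastException = e
              Thread.sleep(RETRY_INTERVAL_MS)
          }
        }
        throw new RuntimeException("Error after " + RETRY_ATTEMPTS + " attempts", lastException)
      }

      def main(args: Array[String]) {
        var calls = 0
        val answer = askWithRetry { () =>
          calls += 1
          if (calls < 3) throw new RuntimeException("transient failure") else "ok"
        }
        println(answer + " after " + calls + " attempts")  // "ok after 3 attempts"
      }
    }
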
diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
index f4d026da33..2830bc6297 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
@@ -23,9 +23,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private val blockManagerInfo =
new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo]
- // Mapping from host name to block manager id. We allow multiple block managers
- // on the same host name (ip).
- private val blockManagerIdByHost = new HashMap[String, ArrayBuffer[BlockManagerId]]
+ // Mapping from executor ID to block manager ID.
+ private val blockManagerIdByExecutor = new HashMap[String, BlockManagerId]
// Mapping from block id to the set of block managers that have the block.
private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]]
@@ -68,11 +67,14 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
case GetMemoryStatus =>
getMemoryStatus
+ case GetStorageStatus =>
+ getStorageStatus
+
case RemoveBlock(blockId) =>
removeBlock(blockId)
- case RemoveHost(host) =>
- removeHost(host)
+ case RemoveExecutor(execId) =>
+ removeExecutor(execId)
sender ! true
case StopBlockManagerMaster =>
@@ -96,16 +98,12 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
def removeBlockManager(blockManagerId: BlockManagerId) {
val info = blockManagerInfo(blockManagerId)
- // Remove the block manager from blockManagerIdByHost. If the list of block
- // managers belonging to the IP is empty, remove the entry from the hash map.
- blockManagerIdByHost.get(blockManagerId.ip).foreach { managers: ArrayBuffer[BlockManagerId] =>
- managers -= blockManagerId
- if (managers.size == 0) blockManagerIdByHost.remove(blockManagerId.ip)
- }
+ // Remove the block manager from blockManagerIdByExecutor.
+ blockManagerIdByExecutor -= blockManagerId.executorId
// Remove it from blockManagerInfo and remove all the blocks.
blockManagerInfo.remove(blockManagerId)
- var iterator = info.blocks.keySet.iterator
+ val iterator = info.blocks.keySet.iterator
while (iterator.hasNext) {
val blockId = iterator.next
val locations = blockLocations.get(blockId)._2
@@ -117,7 +115,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
def expireDeadHosts() {
- logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.")
+ logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.")
val now = System.currentTimeMillis()
val minSeenTime = now - slaveTimeout
val toRemove = new HashSet[BlockManagerId]
@@ -130,17 +128,15 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
toRemove.foreach(removeBlockManager)
}
- def removeHost(host: String) {
- logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.")
- logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq)
- blockManagerIdByHost.get(host).foreach(_.foreach(removeBlockManager))
- logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq)
+ def removeExecutor(execId: String) {
+ logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
+ blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
sender ! true
}
def heartBeat(blockManagerId: BlockManagerId) {
if (!blockManagerInfo.contains(blockManagerId)) {
- if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
+ if (blockManagerId.executorId == "<driver>" && !isLocal) {
sender ! true
} else {
sender ! false
@@ -177,24 +173,28 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
sender ! res
}
- private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockManagerId + " "
+ private def getStorageStatus() {
+ val res = blockManagerInfo.map { case(blockManagerId, info) =>
+ import collection.JavaConverters._
+ StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap)
+ }
+ sender ! res
+ }
- if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
- logInfo("Got Register Msg from master node, don't register it")
- } else {
- blockManagerIdByHost.get(blockManagerId.ip) match {
- case Some(managers) =>
- // A block manager of the same host name already exists.
- logInfo("Got another registration for host " + blockManagerId)
- managers += blockManagerId
+ private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+ if (id.executorId == "<driver>" && !isLocal) {
+ // Got a register message from the master node; don't register it
+ } else if (!blockManagerInfo.contains(id)) {
+ blockManagerIdByExecutor.get(id.executorId) match {
+ case Some(manager) =>
+ // A block manager of the same host name already exists
+ logError("Got two different block manager registrations on " + id.executorId)
+ System.exit(1)
case None =>
- blockManagerIdByHost += (blockManagerId.ip -> ArrayBuffer(blockManagerId))
+ blockManagerIdByExecutor(id.executorId) = id
}
-
- blockManagerInfo += (blockManagerId -> new BlockManagerMasterActor.BlockManagerInfo(
- blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor))
+ blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo(
+ id, System.currentTimeMillis(), maxMemSize, slaveActor)
}
sender ! true
}
@@ -206,11 +206,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
memSize: Long,
diskSize: Long) {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockManagerId + " " + blockId + " "
-
if (!blockManagerInfo.contains(blockManagerId)) {
- if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
+ if (blockManagerId.executorId == "<driver>" && !isLocal) {
// We intentionally do not register the master (except in local mode),
// so we should not indicate failure.
sender ! true
@@ -342,8 +339,8 @@ object BlockManagerMasterActor {
_lastSeenMs = System.currentTimeMillis()
}
- def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long)
- : Unit = synchronized {
+ def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long,
+ diskSize: Long) {
updateLastSeenMs()
diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
index d73a9b790f..1494f90103 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
@@ -54,11 +54,9 @@ class UpdateBlockInfo(
}
override def readExternal(in: ObjectInput) {
- blockManagerId = new BlockManagerId()
- blockManagerId.readExternal(in)
+ blockManagerId = BlockManagerId(in)
blockId = in.readUTF()
- storageLevel = new StorageLevel()
- storageLevel.readExternal(in)
+ storageLevel = StorageLevel(in)
memSize = in.readInt()
diskSize = in.readInt()
}
@@ -90,7 +88,7 @@ private[spark]
case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
private[spark]
-case class RemoveHost(host: String) extends ToBlockManagerMaster
+case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
private[spark]
case object StopBlockManagerMaster extends ToBlockManagerMaster
@@ -100,3 +98,6 @@ case object GetMemoryStatus extends ToBlockManagerMaster
private[spark]
case object ExpireDeadHosts extends ToBlockManagerMaster
+
+private[spark]
+case object GetStorageStatus extends ToBlockManagerMaster
diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala
new file mode 100644
index 0000000000..eda320fa47
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala
@@ -0,0 +1,85 @@
+package spark.storage
+
+import akka.actor.{ActorRef, ActorSystem}
+import akka.pattern.ask
+import akka.util.Timeout
+import akka.util.duration._
+import cc.spray.directives._
+import cc.spray.typeconversion.TwirlSupport._
+import cc.spray.Directives
+import scala.collection.mutable.ArrayBuffer
+import spark.{Logging, SparkContext}
+import spark.util.AkkaUtils
+import spark.Utils
+
+
+/**
+ * Web UI server for the BlockManager inside each SparkContext.
+ */
+private[spark]
+class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, sc: SparkContext)
+ extends Directives with Logging {
+
+ val STATIC_RESOURCE_DIR = "spark/deploy/static"
+
+ implicit val timeout = Timeout(10 seconds)
+
+ /** Start a HTTP server to run the Web interface */
+ def start() {
+ try {
+ val port = if (System.getProperty("spark.ui.port") != null) {
+ System.getProperty("spark.ui.port").toInt
+ } else {
+ // TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which
+ // random port it bound to, so we have to try to find a local one by creating a socket.
+ Utils.findFreePort()
+ }
+ AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, handler, "BlockManagerHTTPServer")
+ logInfo("Started BlockManager web UI at http://%s:%d".format(Utils.localHostName(), port))
+ } catch {
+ case e: Exception =>
+ logError("Failed to create BlockManager WebUI", e)
+ System.exit(1)
+ }
+ }
+
+ val handler = {
+ get {
+ path("") {
+ completeWith {
+ // Request the current storage status from the Master
+ val future = blockManagerMaster ? GetStorageStatus
+ future.map { status =>
+ // Calculate macro-level statistics
+ val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray
+ val maxMem = storageStatusList.map(_.maxMem).reduce(_+_)
+ val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_)
+ val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize))
+ .reduceOption(_+_).getOrElse(0L)
+ val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc)
+ spark.storage.html.index.
+ render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList)
+ }
+ }
+ } ~
+ path("rdd") {
+ parameter("id") { id =>
+ completeWith {
+ val future = blockManagerMaster ? GetStorageStatus
+ future.map { status =>
+ val prefix = "rdd_" + id.toString
+ val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray
+ val filteredStorageStatusList = StorageUtils.
+ filterStorageStatusByPrefix(storageStatusList, prefix)
+ val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
+ spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList)
+ }
+ }
+ }
+ } ~
+ pathPrefix("static") {
+ getFromResourceDirectory(STATIC_RESOURCE_DIR)
+ }
+ }
+ }
+}
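A note on configuration: the UI above binds to the port in the "spark.ui.port" system property and falls back to a free local port when it is unset. A minimal sketch of overriding it (the port number is illustrative, and setting it before the SparkContext is constructed is an assumption about when the UI is started, which this hunk does not show):

    // Choose the BlockManager web UI port explicitly instead of a random free port.
    System.setProperty("spark.ui.port", "33000")
    val sc = new spark.SparkContext("local", "storage-ui-demo")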
diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala
index 3f234df654..30d7500e01 100644
--- a/core/src/main/scala/spark/storage/BlockMessage.scala
+++ b/core/src/main/scala/spark/storage/BlockMessage.scala
@@ -64,7 +64,7 @@ private[spark] class BlockMessage() {
val booleanInt = buffer.getInt()
val replication = buffer.getInt()
- level = new StorageLevel(booleanInt, replication)
+ level = StorageLevel(booleanInt, replication)
val dataLength = buffer.getInt()
data = ByteBuffer.allocate(dataLength)
diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala
index 00e32f753c..ae88ff0bb1 100644
--- a/core/src/main/scala/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/spark/storage/MemoryStore.scala
@@ -17,7 +17,6 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true)
private var currentMemory = 0L
-
// Object used to ensure that only one thread is putting blocks and if necessary, dropping
// blocks from the memory store.
private val putLock = new Object()
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
index e3544e5aae..3b5a77ab22 100644
--- a/core/src/main/scala/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/spark/storage/StorageLevel.scala
@@ -7,25 +7,30 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
* whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory
* in a serialized format, and whether to replicate the RDD partitions on multiple nodes.
* The [[spark.storage.StorageLevel$]] singleton object contains some static constants for
- * commonly useful storage levels.
+ * commonly useful storage levels. To create your own storage level object, use the factory method
+ * of the singleton object (`StorageLevel(...)`).
*/
-class StorageLevel(
- var useDisk: Boolean,
- var useMemory: Boolean,
- var deserialized: Boolean,
- var replication: Int = 1)
+class StorageLevel private(
+ private var useDisk_ : Boolean,
+ private var useMemory_ : Boolean,
+ private var deserialized_ : Boolean,
+ private var replication_ : Int = 1)
extends Externalizable {
// TODO: Also add fields for caching priority, dataset ID, and flushing.
-
- assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes")
-
- def this(flags: Int, replication: Int) {
+ private def this(flags: Int, replication: Int) {
this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
}
def this() = this(false, true, false) // For deserialization
+ def useDisk = useDisk_
+ def useMemory = useMemory_
+ def deserialized = deserialized_
+ def replication = replication_
+
+ assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes")
+
override def clone(): StorageLevel = new StorageLevel(
this.useDisk, this.useMemory, this.deserialized, this.replication)
@@ -43,13 +48,13 @@ class StorageLevel(
def toInt: Int = {
var ret = 0
- if (useDisk) {
+ if (useDisk_) {
ret |= 4
}
- if (useMemory) {
+ if (useMemory_) {
ret |= 2
}
- if (deserialized) {
+ if (deserialized_) {
ret |= 1
}
return ret
@@ -57,15 +62,15 @@ class StorageLevel(
override def writeExternal(out: ObjectOutput) {
out.writeByte(toInt)
- out.writeByte(replication)
+ out.writeByte(replication_)
}
override def readExternal(in: ObjectInput) {
val flags = in.readByte()
- useDisk = (flags & 4) != 0
- useMemory = (flags & 2) != 0
- deserialized = (flags & 1) != 0
- replication = in.readByte()
+ useDisk_ = (flags & 4) != 0
+ useMemory_ = (flags & 2) != 0
+ deserialized_ = (flags & 1) != 0
+ replication_ = in.readByte()
}
@throws(classOf[IOException])
@@ -75,6 +80,14 @@ class StorageLevel(
"StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
override def hashCode(): Int = toInt * 41 + replication
+ def description : String = {
+ var result = ""
+ result += (if (useDisk) "Disk " else "")
+ result += (if (useMemory) "Memory " else "")
+ result += (if (deserialized) "Deserialized " else "Serialized")
+ result += "%sx Replicated".format(replication)
+ result
+ }
}
@@ -91,6 +104,21 @@ object StorageLevel {
val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false)
val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2)
+ /** Create a new StorageLevel object */
+ def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) =
+ getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication))
+
+ /** Create a new StorageLevel object from its integer representation */
+ def apply(flags: Int, replication: Int) =
+ getCachedStorageLevel(new StorageLevel(flags, replication))
+
+ /** Read StorageLevel object from ObjectInput stream */
+ def apply(in: ObjectInput) = {
+ val obj = new StorageLevel()
+ obj.readExternal(in)
+ getCachedStorageLevel(obj)
+ }
+
private[spark]
val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()
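With the constructor now private, callers go through the companion object's apply factories, which route through getCachedStorageLevel so equal levels share a cached instance. A minimal usage sketch (the flag values are illustrative):

    import spark.storage.StorageLevel

    // Disk off, memory on, serialized, replicated twice -- built via the new factory method.
    val level = StorageLevel(false, true, false, 2)
    println(level.description)   // prints a human-readable summary of the level
    println(level.toInt)         // 2, since only the memory bit is set

    // A second call with the same flags yields an equivalent level.
    val again = StorageLevel(false, true, false, 2)
    assert(again.toInt == level.toInt && again.replication == level.replication)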
diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala
new file mode 100644
index 0000000000..a10e3a95c6
--- /dev/null
+++ b/core/src/main/scala/spark/storage/StorageUtils.scala
@@ -0,0 +1,78 @@
+package spark.storage
+
+import spark.SparkContext
+import BlockManagerMasterActor.BlockStatus
+
+private[spark]
+case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
+ blocks: Map[String, BlockStatus]) {
+
+ def memUsed(blockPrefix: String = "") = {
+ blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize).
+ reduceOption(_+_).getOrElse(0l)
+ }
+
+ def diskUsed(blockPrefix: String = "") = {
+ blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize).
+ reduceOption(_+_).getOrElse(0l)
+ }
+
+ def memRemaining : Long = maxMem - memUsed()
+
+}
+
+case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
+ numPartitions: Int, memSize: Long, diskSize: Long)
+
+
+/* Helper methods for storage-related objects */
+private[spark]
+object StorageUtils {
+
+ /* Given the current storage status of the BlockManager, returns information for each RDD */
+ def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus],
+ sc: SparkContext) : Array[RDDInfo] = {
+ rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
+ }
+
+ /* Given a list of BlockStatus objects, returns information for each RDD */
+ def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
+ sc: SparkContext) : Array[RDDInfo] = {
+ // Find all RDD Blocks (ignore broadcast variables)
+ val rddBlocks = infos.filterKeys(_.startsWith("rdd"))
+
+ // Group by rddId, ignore the partition name
+ val groupedRddBlocks = infos.groupBy { case(k, v) =>
+ k.substring(0,k.lastIndexOf('_'))
+ }.mapValues(_.values.toArray)
+
+ // For each RDD, generate an RDDInfo object
+ groupedRddBlocks.map { case(rddKey, rddBlocks) =>
+
+ // Add up memory and disk sizes
+ val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
+ val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _)
+
+ // Find the id of the RDD, e.g. rdd_1 => 1
+ val rddId = rddKey.split("_").last.toInt
+ // Get the friendly name for the rdd, if available.
+ val rddName = Option(sc.persistentRdds(rddId).name).getOrElse(rddKey)
+ val rddStorageLevel = sc.persistentRdds(rddId).getStorageLevel
+
+ RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize)
+ }.toArray
+ }
+
+ /* Removes all BlockStatus object that are not part of a block prefix */
+ def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
+ prefix: String) : Array[StorageStatus] = {
+
+ storageStatusList.map { status =>
+ val newBlocks = status.blocks.filterKeys(_.startsWith(prefix))
+ //val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _)
+ StorageStatus(status.blockManagerId, status.maxMem, newBlocks)
+ }
+
+ }
+
+} \ No newline at end of file
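The helpers above are what the storage UI pages use to aggregate per-RDD usage. A minimal sketch of the same aggregation, assuming it lives inside the spark.storage package (StorageStatus and StorageUtils are private[spark]) and that a storageStatusList was already fetched from the master as BlockManagerUI does:

    package spark.storage

    object RddUsageExample {
      // Sum memory and disk usage for one RDD id across all block managers.
      def summarize(storageStatusList: Array[StorageStatus], rddId: Int): (Long, Long) = {
        val prefix = "rdd_" + rddId
        val filtered = StorageUtils.filterStorageStatusByPrefix(storageStatusList, prefix)
        val memUsed  = filtered.map(_.memUsed(prefix)).reduceOption(_ + _).getOrElse(0L)
        val diskUsed = filtered.map(_.diskUsed(prefix)).reduceOption(_ + _).getOrElse(0L)
        (memUsed, diskUsed)
      }
    }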
diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala
index 689f07b969..a70d1c8e78 100644
--- a/core/src/main/scala/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/spark/storage/ThreadingTest.scala
@@ -75,10 +75,11 @@ private[spark] object ThreadingTest {
System.setProperty("spark.kryoserializer.buffer.mb", "1")
val actorSystem = ActorSystem("test")
val serializer = new KryoSerializer
- val masterIp: String = System.getProperty("spark.master.host", "localhost")
- val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
- val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, masterIp, masterPort)
- val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer, 1024 * 1024)
+ val driverIp: String = System.getProperty("spark.driver.host", "localhost")
+ val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
+ val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, driverIp, driverPort)
+ val blockManager = new BlockManager(
+ "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024)
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
producers.foreach(_.start)
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index e67cb0336d..e43fbd6b1c 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -1,6 +1,6 @@
package spark.util
-import akka.actor.{Props, ActorSystemImpl, ActorSystem}
+import akka.actor.{ActorRef, Props, ActorSystemImpl, ActorSystem}
import com.typesafe.config.ConfigFactory
import akka.util.duration._
import akka.pattern.ask
@@ -30,8 +30,10 @@ private[spark] object AkkaUtils {
val akkaConf = ConfigFactory.parseString("""
akka.daemonic = on
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
+ akka.stdout-loglevel = "ERROR"
akka.actor.provider = "akka.remote.RemoteActorRefProvider"
akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
+ akka.remote.log-remote-lifecycle-events = on
akka.remote.netty.hostname = "%s"
akka.remote.netty.port = %d
akka.remote.netty.connection-timeout = %ds
@@ -51,21 +53,22 @@ private[spark] object AkkaUtils {
/**
* Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to
- * handle requests. Throws a SparkException if this fails.
+ * handle requests. Returns the bound port or throws a SparkException on failure.
*/
- def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route) {
+ def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route,
+ name: String = "HttpServer"): ActorRef = {
val ioWorker = new IoWorker(actorSystem).start()
val httpService = actorSystem.actorOf(Props(new HttpService(route)))
val rootService = actorSystem.actorOf(Props(new SprayCanRootService(httpService)))
val server = actorSystem.actorOf(
- Props(new HttpServer(ioWorker, SingletonHandler(rootService))), name = "HttpServer")
+ Props(new HttpServer(ioWorker, SingletonHandler(rootService))), name = name)
actorSystem.registerOnTermination { ioWorker.stop() }
val timeout = 3.seconds
val future = server.ask(HttpServer.Bind(ip, port))(timeout)
try {
Await.result(future, timeout) match {
case bound: HttpServer.Bound =>
- return
+ return server
case other: Any =>
throw new SparkException("Failed to bind web UI to port " + port + ": " + other)
}
diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala
index 19e67acd0c..a342d378ff 100644
--- a/core/src/main/scala/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/spark/util/MetadataCleaner.scala
@@ -4,28 +4,30 @@ import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors}
import java.util.{TimerTask, Timer}
import spark.Logging
-class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
- val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt
- val periodSeconds = math.max(10, delaySeconds / 10)
- val timer = new Timer(name + " cleanup timer", true)
+/**
+ * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries)
+ */
+class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
+ private val delaySeconds = MetadataCleaner.getDelaySeconds
+ private val periodSeconds = math.max(10, delaySeconds / 10)
+ private val timer = new Timer(name + " cleanup timer", true)
- val task = new TimerTask {
- def run() {
+ private val task = new TimerTask {
+ override def run() {
try {
- if (delaySeconds > 0) {
- cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000))
- logInfo("Ran metadata cleaner for " + name)
- }
+ cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000))
+ logInfo("Ran metadata cleaner for " + name)
} catch {
case e: Exception => logError("Error running cleanup task for " + name, e)
}
}
}
- if (periodSeconds > 0) {
- logInfo(
- "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and "
- + "period of " + periodSeconds + " secs")
+
+ if (delaySeconds > 0) {
+ logDebug(
+ "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " +
+ "and period of " + periodSeconds + " secs")
timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000)
}
@@ -33,3 +35,10 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging
timer.cancel()
}
}
+
+
+object MetadataCleaner {
+ def getDelaySeconds = System.getProperty("spark.cleaner.delay", "-1").toInt
+ def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.delay", delay.toString) }
+}
+
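A minimal usage sketch for the cleaner, assuming it is wired to a TimeStampedHashMap the way Spark's trackers use it; the map name is hypothetical, and cancel() is assumed to be the method wrapping the timer.cancel() call shown above:

    import spark.util.{MetadataCleaner, TimeStampedHashMap}

    // Enable periodic cleanup; with the default of -1 the timer is never scheduled.
    MetadataCleaner.setDelaySeconds(60)

    val taskMetadata = new TimeStampedHashMap[Int, String]   // hypothetical map to be pruned
    val cleaner = new MetadataCleaner("taskMetadata", taskMetadata.clearOldValues)

    // ... use taskMetadata; entries older than the delay are dropped by the timer task ...

    cleaner.cancel()   // assumed name for the method that cancels the underlying timer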
diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala
new file mode 100644
index 0000000000..e3f00ea8c7
--- /dev/null
+++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala
@@ -0,0 +1,62 @@
+package spark.util
+
+import scala.annotation.tailrec
+
+import java.io.OutputStream
+import java.util.concurrent.TimeUnit._
+
+class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream {
+ val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS)
+ val CHUNK_SIZE = 8192
+ var lastSyncTime = System.nanoTime
+ var bytesWrittenSinceSync: Long = 0
+
+ override def write(b: Int) {
+ waitToWrite(1)
+ out.write(b)
+ }
+
+ override def write(bytes: Array[Byte]) {
+ write(bytes, 0, bytes.length)
+ }
+
+ @tailrec
+ override final def write(bytes: Array[Byte], offset: Int, length: Int) {
+ val writeSize = math.min(length - offset, CHUNK_SIZE)
+ if (writeSize > 0) {
+ waitToWrite(writeSize)
+ out.write(bytes, offset, writeSize)
+ write(bytes, offset + writeSize, length)
+ }
+ }
+
+ override def flush() {
+ out.flush()
+ }
+
+ override def close() {
+ out.close()
+ }
+
+ @tailrec
+ private def waitToWrite(numBytes: Int) {
+ val now = System.nanoTime
+ val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS)
+ val rate = bytesWrittenSinceSync.toDouble / elapsedSecs
+ if (rate < bytesPerSec) {
+ // It's okay to write; just update some variables and return
+ bytesWrittenSinceSync += numBytes
+ if (now > lastSyncTime + SYNC_INTERVAL) {
+ // Sync interval has passed; let's resync
+ lastSyncTime = now
+ bytesWrittenSinceSync = numBytes
+ }
+ } else {
+ // Calculate how much time we should sleep to bring ourselves to the desired rate.
+ // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala)
+ val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS)
+ if (sleepTime > 0) Thread.sleep(sleepTime)
+ waitToWrite(numBytes)
+ }
+ }
+}
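A short usage sketch for the rate limiter, wrapping a plain FileOutputStream (the path and rate are illustrative):

    import java.io.FileOutputStream
    import spark.util.RateLimitedOutputStream

    // Cap writes to roughly 1 MB/s; write() blocks as needed to honor the rate.
    val out = new RateLimitedOutputStream(new FileOutputStream("/tmp/throttled.bin"), 1024 * 1024)
    try {
      val chunk = new Array[Byte](64 * 1024)
      for (i <- 1 to 256) {
        out.write(chunk)   // 16 MB total, so this should take on the order of 16 seconds
      }
    } finally {
      out.close()
    }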
diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
index 070ee19ac0..188f8910da 100644
--- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
@@ -1,16 +1,16 @@
package spark.util
import java.util.concurrent.ConcurrentHashMap
-import scala.collection.JavaConversions._
-import scala.collection.mutable.{HashMap, Map}
+import scala.collection.JavaConversions
+import scala.collection.mutable.Map
/**
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion
* time stamp along with each key-value pair. Key-value pairs that are older than a particular
- * threshold time can them be removed using the cleanup method. This is intended to be a drop-in
+ * threshold time can then be removed using the clearOldValues method. This is intended to be a drop-in
* replacement of scala.collection.mutable.HashMap.
*/
-class TimeStampedHashMap[A, B] extends Map[A, B]() {
+class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging {
val internalMap = new ConcurrentHashMap[A, (B, Long)]()
def get(key: A): Option[B] = {
@@ -20,7 +20,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() {
def iterator: Iterator[(A, B)] = {
val jIterator = internalMap.entrySet().iterator()
- jIterator.map(kv => (kv.getKey, kv.getValue._1))
+ JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1))
}
override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = {
@@ -31,8 +31,10 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() {
}
override def - (key: A): Map[A, B] = {
- internalMap.remove(key)
- this
+ val newMap = new TimeStampedHashMap[A, B]
+ newMap.internalMap.putAll(this.internalMap)
+ newMap.internalMap.remove(key)
+ newMap
}
override def += (kv: (A, B)): this.type = {
@@ -56,14 +58,14 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() {
}
override def filter(p: ((A, B)) => Boolean): Map[A, B] = {
- internalMap.map(kv => (kv._1, kv._2._1)).filter(p)
+ JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p)
}
override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()
- override def size(): Int = internalMap.size()
+ override def size: Int = internalMap.size
- override def foreach[U](f: ((A, B)) => U): Unit = {
+ override def foreach[U](f: ((A, B)) => U) {
val iterator = internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
@@ -72,11 +74,15 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() {
}
}
- def cleanup(threshTime: Long) {
+ /**
+ * Removes old key-value pairs whose timestamp is earlier than `threshTime`
+ */
+ def clearOldValues(threshTime: Long) {
val iterator = internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
if (entry.getValue._2 < threshTime) {
+ logDebug("Removing key " + entry.getKey)
iterator.remove()
}
}
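A minimal usage sketch of the renamed clearOldValues, showing that only entries inserted before the cutoff are dropped (the keys and the sleep are illustrative):

    import spark.util.TimeStampedHashMap

    val map = new TimeStampedHashMap[String, Int]
    map("old") = 1
    Thread.sleep(10)
    val cutoff = System.currentTimeMillis()
    map("new") = 2

    map.clearOldValues(cutoff)   // removes entries whose timestamp is earlier than `cutoff`
    assert(map.get("old") == None)
    assert(map.get("new") == Some(2))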
diff --git a/core/src/main/scala/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/spark/util/TimeStampedHashSet.scala
new file mode 100644
index 0000000000..5f1cc93752
--- /dev/null
+++ b/core/src/main/scala/spark/util/TimeStampedHashSet.scala
@@ -0,0 +1,69 @@
+package spark.util
+
+import scala.collection.mutable.Set
+import scala.collection.JavaConversions
+import java.util.concurrent.ConcurrentHashMap
+
+
+class TimeStampedHashSet[A] extends Set[A] {
+ val internalMap = new ConcurrentHashMap[A, Long]()
+
+ def contains(key: A): Boolean = {
+ internalMap.contains(key)
+ }
+
+ def iterator: Iterator[A] = {
+ val jIterator = internalMap.entrySet().iterator()
+ JavaConversions.asScalaIterator(jIterator).map(_.getKey)
+ }
+
+ override def + (elem: A): Set[A] = {
+ val newSet = new TimeStampedHashSet[A]
+ newSet ++= this
+ newSet += elem
+ newSet
+ }
+
+ override def - (elem: A): Set[A] = {
+ val newSet = new TimeStampedHashSet[A]
+ newSet ++= this
+ newSet -= elem
+ newSet
+ }
+
+ override def += (key: A): this.type = {
+ internalMap.put(key, currentTime)
+ this
+ }
+
+ override def -= (key: A): this.type = {
+ internalMap.remove(key)
+ this
+ }
+
+ override def empty: Set[A] = new TimeStampedHashSet[A]()
+
+ override def size(): Int = internalMap.size()
+
+ override def foreach[U](f: (A) => U): Unit = {
+ val iterator = internalMap.entrySet().iterator()
+ while(iterator.hasNext) {
+ f(iterator.next.getKey)
+ }
+ }
+
+ /**
+ * Removes old values whose timestamp is earlier than `threshTime`
+ */
+ def clearOldValues(threshTime: Long) {
+ val iterator = internalMap.entrySet().iterator()
+ while(iterator.hasNext) {
+ val entry = iterator.next()
+ if (entry.getValue < threshTime) {
+ iterator.remove()
+ }
+ }
+ }
+
+ private def currentTime: Long = System.currentTimeMillis()
+}
diff --git a/core/src/main/twirl/spark/deploy/common/layout.scala.html b/core/src/main/twirl/spark/common/layout.scala.html
index b9192060aa..b9192060aa 100644
--- a/core/src/main/twirl/spark/deploy/common/layout.scala.html
+++ b/core/src/main/twirl/spark/common/layout.scala.html
diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html
index 18c32e5a1f..285645c389 100644
--- a/core/src/main/twirl/spark/deploy/master/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/index.scala.html
@@ -2,7 +2,7 @@
@import spark.deploy.master._
@import spark.Utils
-@spark.deploy.common.html.layout(title = "Spark Master on " + state.uri) {
+@spark.common.html.layout(title = "Spark Master on " + state.uri) {
<!-- Cluster Details -->
<div class="row">
diff --git a/core/src/main/twirl/spark/deploy/master/job_details.scala.html b/core/src/main/twirl/spark/deploy/master/job_details.scala.html
index dcf41c28f2..d02a51b214 100644
--- a/core/src/main/twirl/spark/deploy/master/job_details.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/job_details.scala.html
@@ -1,6 +1,6 @@
@(job: spark.deploy.master.JobInfo)
-@spark.deploy.common.html.layout(title = "Job Details") {
+@spark.common.html.layout(title = "Job Details") {
<!-- Job Details -->
<div class="row">
diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html
index b247307dab..1d703dae58 100644
--- a/core/src/main/twirl/spark/deploy/worker/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html
@@ -1,8 +1,7 @@
@(worker: spark.deploy.WorkerState)
-
@import spark.Utils
-@spark.deploy.common.html.layout(title = "Spark Worker on " + worker.uri) {
+@spark.common.html.layout(title = "Spark Worker on " + worker.uri) {
<!-- Worker Details -->
<div class="row">
diff --git a/core/src/main/twirl/spark/storage/index.scala.html b/core/src/main/twirl/spark/storage/index.scala.html
new file mode 100644
index 0000000000..2b337f6133
--- /dev/null
+++ b/core/src/main/twirl/spark/storage/index.scala.html
@@ -0,0 +1,40 @@
+@(maxMem: Long, remainingMem: Long, diskSpaceUsed: Long, rdds: Array[spark.storage.RDDInfo], storageStatusList: Array[spark.storage.StorageStatus])
+@import spark.Utils
+
+@spark.common.html.layout(title = "Storage Dashboard") {
+
+ <!-- High-Level Information -->
+ <div class="row">
+ <div class="span12">
+ <ul class="unstyled">
+ <li><strong>Memory:</strong>
+ @{Utils.memoryBytesToString(maxMem - remainingMem)} Used
+ (@{Utils.memoryBytesToString(remainingMem)} Available) </li>
+ <li><strong>Disk:</strong> @{Utils.memoryBytesToString(diskSpaceUsed)} Used </li>
+ </ul>
+ </div>
+ </div>
+
+ <hr/>
+
+ <!-- RDD Summary -->
+ <div class="row">
+ <div class="span12">
+ <h3> RDD Summary </h3>
+ <br/>
+ @rdd_table(rdds)
+ </div>
+ </div>
+
+ <hr/>
+
+ <!-- Worker Summary -->
+ <div class="row">
+ <div class="span12">
+ <h3> Worker Summary </h3>
+ <br/>
+ @worker_table(storageStatusList)
+ </div>
+ </div>
+
+} \ No newline at end of file
diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html
new file mode 100644
index 0000000000..ac7f8c981f
--- /dev/null
+++ b/core/src/main/twirl/spark/storage/rdd.scala.html
@@ -0,0 +1,77 @@
+@(rddInfo: spark.storage.RDDInfo, storageStatusList: Array[spark.storage.StorageStatus])
+@import spark.Utils
+
+@spark.common.html.layout(title = "RDD Info ") {
+
+ <!-- High-Level Information -->
+ <div class="row">
+ <div class="span12">
+ <ul class="unstyled">
+ <li>
+ <strong>Storage Level:</strong>
+ @(rddInfo.storageLevel.description)
+ <li>
+ <strong>Partitions:</strong>
+ @(rddInfo.numPartitions)
+ </li>
+ <li>
+ <strong>Memory Size:</strong>
+ @{Utils.memoryBytesToString(rddInfo.memSize)}
+ </li>
+ <li>
+ <strong>Disk Size:</strong>
+ @{Utils.memoryBytesToString(rddInfo.diskSize)}
+ </li>
+ </ul>
+ </div>
+ </div>
+
+ <hr/>
+
+ <!-- RDD Summary -->
+ <div class="row">
+ <div class="span12">
+ <h3> RDD Summary </h3>
+ <br/>
+
+
+ <!-- Block Table Summary -->
+ <table class="table table-bordered table-striped table-condensed sortable">
+ <thead>
+ <tr>
+ <th>Block Name</th>
+ <th>Storage Level</th>
+ <th>Size in Memory</th>
+ <th>Size on Disk</th>
+ </tr>
+ </thead>
+ <tbody>
+ @storageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1).map { case (k,v) =>
+ <tr>
+ <td>@k</td>
+ <td>
+ @(v.storageLevel.description)
+ </td>
+ <td>@{Utils.memoryBytesToString(v.memSize)}</td>
+ <td>@{Utils.memoryBytesToString(v.diskSize)}</td>
+ </tr>
+ }
+ </tbody>
+ </table>
+
+
+ </div>
+ </div>
+
+ <hr/>
+
+ <!-- Worker Table -->
+ <div class="row">
+ <div class="span12">
+ <h3> Worker Summary </h3>
+ <br/>
+ @worker_table(storageStatusList, "rdd_" + rddInfo.id )
+ </div>
+ </div>
+
+} \ No newline at end of file
diff --git a/core/src/main/twirl/spark/storage/rdd_table.scala.html b/core/src/main/twirl/spark/storage/rdd_table.scala.html
new file mode 100644
index 0000000000..af801cf229
--- /dev/null
+++ b/core/src/main/twirl/spark/storage/rdd_table.scala.html
@@ -0,0 +1,30 @@
+@(rdds: Array[spark.storage.RDDInfo])
+@import spark.Utils
+
+<table class="table table-bordered table-striped table-condensed sortable">
+ <thead>
+ <tr>
+ <th>RDD Name</th>
+ <th>Storage Level</th>
+ <th>Partitions</th>
+ <th>Size in Memory</th>
+ <th>Size on Disk</th>
+ </tr>
+ </thead>
+ <tbody>
+ @for(rdd <- rdds) {
+ <tr>
+ <td>
+ <a href="rdd?id=@(rdd.id)">
+ @rdd.name
+ </a>
+ </td>
+ <td>@(rdd.storageLevel.description)
+ </td>
+ <td>@rdd.numPartitions</td>
+ <td>@{Utils.memoryBytesToString(rdd.memSize)}</td>
+ <td>@{Utils.memoryBytesToString(rdd.diskSize)}</td>
+ </tr>
+ }
+ </tbody>
+</table> \ No newline at end of file
diff --git a/core/src/main/twirl/spark/storage/worker_table.scala.html b/core/src/main/twirl/spark/storage/worker_table.scala.html
new file mode 100644
index 0000000000..d54b8de4cc
--- /dev/null
+++ b/core/src/main/twirl/spark/storage/worker_table.scala.html
@@ -0,0 +1,24 @@
+@(workersStatusList: Array[spark.storage.StorageStatus], prefix: String = "")
+@import spark.Utils
+
+<table class="table table-bordered table-striped table-condensed sortable">
+ <thead>
+ <tr>
+ <th>Host</th>
+ <th>Memory Usage</th>
+ <th>Disk Usage</th>
+ </tr>
+ </thead>
+ <tbody>
+ @for(status <- workersStatusList) {
+ <tr>
+ <td>@(status.blockManagerId.ip + ":" + status.blockManagerId.port)</td>
+ <td>
+ @(Utils.memoryBytesToString(status.memUsed(prefix)))
+ (@(Utils.memoryBytesToString(status.memRemaining)) Total Available)
+ </td>
+ <td>@(Utils.memoryBytesToString(status.diskUsed(prefix)))</td>
+ </tr>
+ }
+ </tbody>
+</table> \ No newline at end of file
diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties
index 4c99e450bc..6ec89c0184 100644
--- a/core/src/test/resources/log4j.properties
+++ b/core/src/test/resources/log4j.properties
@@ -1,8 +1,8 @@
-# Set everything to be logged to the console
+# Set everything to be logged to the file core/target/unit-tests.log
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
-log4j.appender.file.file=spark-tests.log
+log4j.appender.file.file=core/target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
index d8be99dde7..ac8ae7d308 100644
--- a/core/src/test/scala/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -1,6 +1,5 @@
package spark
-import org.scalatest.BeforeAndAfter
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import collection.mutable
@@ -9,18 +8,7 @@ import scala.math.exp
import scala.math.signum
import spark.SparkContext._
-class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
-
- var sc: SparkContext = null
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
test ("basic accumulation"){
sc = new SparkContext("local", "test")
@@ -29,6 +17,12 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val d = sc.parallelize(1 to 20)
d.foreach{x => acc += x}
acc.value should be (210)
+
+
+ val longAcc = sc.accumulator(0l)
+ val maxInt = Integer.MAX_VALUE.toLong
+ d.foreach{x => longAcc += maxInt + x}
+ longAcc.value should be (210l + maxInt * 20)
}
test ("value not assignable from tasks") {
@@ -53,10 +47,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
for (i <- 1 to maxI) {
v should contain(i)
}
- sc.stop()
- sc = null
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ resetSparkContext()
}
}
@@ -86,10 +77,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
x => acc.value += x
}
} should produce [SparkException]
- sc.stop()
- sc = null
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ resetSparkContext()
}
}
@@ -115,10 +103,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
bufferAcc.value should contain(i)
mapAcc.value should contain (i -> i.toString)
}
- sc.stop()
- sc = null
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
+ resetSparkContext()
}
}
@@ -134,8 +119,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
x => acc.localValue ++= x
}
acc.value should be ( (0 to maxI).toSet)
- sc.stop()
- sc = null
+ resetSparkContext()
}
}
diff --git a/core/src/test/scala/spark/BroadcastSuite.scala b/core/src/test/scala/spark/BroadcastSuite.scala
index 2d3302f0aa..362a31fb0d 100644
--- a/core/src/test/scala/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/spark/BroadcastSuite.scala
@@ -1,20 +1,8 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
-class BroadcastSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class BroadcastSuite extends FunSuite with LocalSparkContext {
test("basic broadcast") {
sc = new SparkContext("local", "test")
diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala
deleted file mode 100644
index 467605981b..0000000000
--- a/core/src/test/scala/spark/CacheTrackerSuite.scala
+++ /dev/null
@@ -1,131 +0,0 @@
-package spark
-
-import org.scalatest.FunSuite
-
-import scala.collection.mutable.HashMap
-
-import akka.actor._
-import akka.dispatch._
-import akka.pattern.ask
-import akka.remote._
-import akka.util.Duration
-import akka.util.Timeout
-import akka.util.duration._
-
-class CacheTrackerSuite extends FunSuite {
- // Send a message to an actor and wait for a reply, in a blocking manner
- private def ask(actor: ActorRef, message: Any): Any = {
- try {
- val timeout = 10.seconds
- val future = actor.ask(message)(timeout)
- return Await.result(future, timeout)
- } catch {
- case e: Exception =>
- throw new SparkException("Error communicating with actor", e)
- }
- }
-
- test("CacheTrackerActor slave initialization & cache status") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 0L)))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("RegisterRDD") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 3)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(getCacheLocations(tracker) === Map(1 -> List(Nil, Nil, Nil), 2 -> List(Nil)))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("AddedToCache") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 2)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true)
- assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true)
- assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L)))
-
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("DroppedFromCache") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 2)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true)
- assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true)
- assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
-
- assert(ask(tracker, DroppedFromCache(1, 1, "host001", 2L << 11)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 68608L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- /**
- * Helper function to get cacheLocations from CacheTracker
- */
- def getCacheLocations(tracker: ActorRef): HashMap[Int, List[List[String]]] = {
- val answer = ask(tracker, GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
- answer.map { case (i, arr) => (i, arr.toList) }
- }
-}
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
new file mode 100644
index 0000000000..0b74607fb8
--- /dev/null
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -0,0 +1,355 @@
+package spark
+
+import org.scalatest.FunSuite
+import java.io.File
+import spark.rdd._
+import spark.SparkContext._
+import storage.StorageLevel
+
+class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
+ initLogging()
+
+ var checkpointDir: File = _
+ val partitioner = new HashPartitioner(2)
+
+ override def beforeEach() {
+ super.beforeEach()
+ checkpointDir = File.createTempFile("temp", "")
+ checkpointDir.delete()
+ sc = new SparkContext("local", "test")
+ sc.setCheckpointDir(checkpointDir.toString)
+ }
+
+ override def afterEach() {
+ super.afterEach()
+ if (checkpointDir != null) {
+ checkpointDir.delete()
+ }
+ }
+
+ test("RDDs with one-to-one dependencies") {
+ testCheckpointing(_.map(x => x.toString))
+ testCheckpointing(_.flatMap(x => 1 to x))
+ testCheckpointing(_.filter(_ % 2 == 0))
+ testCheckpointing(_.sample(false, 0.5, 0))
+ testCheckpointing(_.glom())
+ testCheckpointing(_.mapPartitions(_.map(_.toString)))
+ testCheckpointing(r => new MapPartitionsWithSplitRDD(r,
+ (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false ))
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
+ testCheckpointing(_.pipe(Seq("cat")))
+ }
+
+ test("ParallelCollection") {
+ val parCollection = sc.makeRDD(1 to 4, 2)
+ val numSplits = parCollection.splits.size
+ parCollection.checkpoint()
+ assert(parCollection.dependencies === Nil)
+ val result = parCollection.collect()
+ assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
+ assert(parCollection.dependencies != Nil)
+ assert(parCollection.splits.length === numSplits)
+ assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList)
+ assert(parCollection.collect() === result)
+ }
+
+ test("BlockRDD") {
+ val blockId = "id"
+ val blockManager = SparkEnv.get.blockManager
+ blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
+ val blockRDD = new BlockRDD[String](sc, Array(blockId))
+ val numSplits = blockRDD.splits.size
+ blockRDD.checkpoint()
+ val result = blockRDD.collect()
+ assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
+ assert(blockRDD.dependencies != Nil)
+ assert(blockRDD.splits.length === numSplits)
+ assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList)
+ assert(blockRDD.collect() === result)
+ }
+
+ test("ShuffledRDD") {
+ testCheckpointing(rdd => {
+ // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
+ new ShuffledRDD(rdd.map(x => (x % 2, 1)), partitioner)
+ })
+ }
+
+ test("UnionRDD") {
+ def otherRDD = sc.makeRDD(1 to 10, 1)
+
+ // Test whether the size of UnionRDDSplits reduce in size after parent RDD is checkpointed.
+ // Current implementation of UnionRDD has transient reference to parent RDDs,
+ // so only the splits will reduce in serialized size, not the RDD.
+ testCheckpointing(_.union(otherRDD), false, true)
+ testParentCheckpointing(_.union(otherRDD), false, true)
+ }
+
+ test("CartesianRDD") {
+ def otherRDD = sc.makeRDD(1 to 10, 1)
+ testCheckpointing(new CartesianRDD(sc, _, otherRDD))
+
+ // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
+ // Current implementation of CoalescedRDDSplit has transient reference to parent RDD,
+ // so only the RDD will reduce in serialized size, not the splits.
+ testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false)
+
+ // Test that the CartesianRDD updates parent splits (CartesianRDD.s1/s2) after
+ // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
+ // Note that this test is very specific to the current implementation of CartesianRDD.
+ val ones = sc.makeRDD(1 to 100, 10).map(x => x)
+ ones.checkpoint() // checkpoint that MappedRDD
+ val cartesian = new CartesianRDD(sc, ones, ones)
+ val splitBeforeCheckpoint =
+ serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
+ cartesian.count() // do the checkpointing
+ val splitAfterCheckpoint =
+ serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
+ assert(
+ (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) &&
+ (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2),
+ "CartesianRDD.parents not updated after parent RDD checkpointed"
+ )
+ }
+
+ test("CoalescedRDD") {
+ testCheckpointing(new CoalescedRDD(_, 2))
+
+ // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
+ // Current implementation of CoalescedRDDSplit has transient reference to parent RDD,
+ // so only the RDD will reduce in serialized size, not the splits.
+ testParentCheckpointing(new CoalescedRDD(_, 2), true, false)
+
+ // Test that the CoalescedRDDSplit updates parent splits (CoalescedRDDSplit.parents) after
+ // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
+ // Note that this test is very specific to the current implementation of CoalescedRDDSplits
+ val ones = sc.makeRDD(1 to 100, 10).map(x => x)
+ ones.checkpoint() // checkpoint that MappedRDD
+ val coalesced = new CoalescedRDD(ones, 2)
+ val splitBeforeCheckpoint =
+ serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
+ coalesced.count() // do the checkpointing
+ val splitAfterCheckpoint =
+ serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
+ assert(
+ splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head,
+ "CoalescedRDDSplit.parents not updated after parent RDD checkpointed"
+ )
+ }
+
+ test("CoGroupedRDD") {
+ val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD()
+ testCheckpointing(rdd => {
+ CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner)
+ }, false, true)
+
+ val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD()
+ testParentCheckpointing(rdd => {
+ CheckpointSuite.cogroup(
+ longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner)
+ }, false, true)
+ }
+
+ test("ZippedRDD") {
+ testCheckpointing(
+ rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+
+ // Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed
+ // Current implementation of ZippedRDDSplit has transient references to parent RDDs,
+ // so only the RDD will reduce in serialized size, not the splits.
+ testParentCheckpointing(
+ rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+ }
+
+ /**
+ * Test checkpointing of the final RDD generated by the given operation. By default,
+ * this method tests whether the size of serialized RDD has reduced after checkpointing or not.
+ * It can also test whether the size of serialized RDD splits has reduced after checkpointing or
+ * not, but this is not done by default as usually the splits do not refer to any RDD and
+ * therefore never store the lineage.
+ */
+ def testCheckpointing[U: ClassManifest](
+ op: (RDD[Int]) => RDD[U],
+ testRDDSize: Boolean = true,
+ testRDDSplitSize: Boolean = false
+ ) {
+ // Generate the final RDD using given RDD operation
+ val baseRDD = generateLongLineageRDD()
+ val operatedRDD = op(baseRDD)
+ val parentRDD = operatedRDD.dependencies.headOption.orNull
+ val rddType = operatedRDD.getClass.getSimpleName
+ val numSplits = operatedRDD.splits.length
+
+ // Find serialized sizes before and after the checkpoint
+ val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ operatedRDD.checkpoint()
+ val result = operatedRDD.collect()
+ val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+
+ // Test whether the checkpoint file has been created
+ assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result)
+
+ // Test whether dependencies have been changed from its earlier parent RDD
+ assert(operatedRDD.dependencies.head.rdd != parentRDD)
+
+ // Test whether the splits have been changed to the new Hadoop splits
+ assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.getSplits.toList)
+
+ // Test whether the number of splits is same as before
+ assert(operatedRDD.splits.length === numSplits)
+
+ // Test whether the data in the checkpointed RDD is same as original
+ assert(operatedRDD.collect() === result)
+
+ // Test whether serialized size of the RDD has reduced. If the RDD
+ // does not have any dependency to another RDD (e.g., ParallelCollection,
+ // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing.
+ if (testRDDSize) {
+ logInfo("Size of " + rddType +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
+ assert(
+ rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+ "Size of " + rddType + " did not reduce after checkpointing " +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+ )
+ }
+
+ // Test whether serialized size of the splits has reduced. If the splits
+ // do not have any non-transient reference to another RDD or another RDD's splits, it
+ // does not refer to a lineage and therefore may not reduce in size after checkpointing.
+ // However, if the original splits before checkpointing do refer to a parent RDD, the splits
+ // must be forgotten after checkpointing (to remove all reference to parent RDDs) and
+ // replaced with the HadoopSplits of the checkpointed RDD.
+ if (testRDDSplitSize) {
+ logInfo("Size of " + rddType + " splits "
+ + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]")
+ assert(
+ splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
+ "Size of " + rddType + " splits did not reduce after checkpointing " +
+ "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
+ )
+ }
+ }
+
+ /**
+ * Test whether checkpointing of the parent of the generated RDD also
+ * truncates the lineage or not. Some RDDs, like CoGroupedRDD, hold on to their parent
+ * RDDs' splits. So even if the parent RDD is checkpointed and its splits changed,
+ * this RDD will remember the splits and therefore potentially the whole lineage.
+ */
+ def testParentCheckpointing[U: ClassManifest](
+ op: (RDD[Int]) => RDD[U],
+ testRDDSize: Boolean,
+ testRDDSplitSize: Boolean
+ ) {
+ // Generate the final RDD using given RDD operation
+ val baseRDD = generateLongLineageRDD()
+ val operatedRDD = op(baseRDD)
+ val parentRDD = operatedRDD.dependencies.head.rdd
+ val rddType = operatedRDD.getClass.getSimpleName
+ val parentRDDType = parentRDD.getClass.getSimpleName
+
+ // Get the splits and dependencies of the parent in case they're lazily computed
+ parentRDD.dependencies
+ parentRDD.splits
+
+ // Find serialized sizes before and after the checkpoint
+ val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one
+ val result = operatedRDD.collect()
+ val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+
+ // Test whether the data in the checkpointed RDD is the same as the original
+ assert(operatedRDD.collect() === result)
+
+ // Test whether the serialized size of the RDD has reduced because its parent was
+ // checkpointed. If this RDD or its parent RDD does not have any dependency
+ // on another RDD (e.g., ParallelCollection, ShuffledRDD with a ShuffleDependency), it may
+ // not reduce in size after checkpointing.
+ if (testRDDSize) {
+ assert(
+ rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+ "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+ )
+ }
+
+ // Test whether the serialized size of the splits has reduced because the parent was
+ // checkpointed. If the splits do not hold any non-transient reference to another RDD
+ // or to another RDD's splits, they do not capture a lineage and therefore may not reduce
+ // in size after checkpointing. However, if the splits do refer to the *splits* of a parent
+ // RDD, then they must update their references to the parent RDD's splits, since the parent
+ // RDD's splits must have changed after checkpointing.
+ if (testRDDSplitSize) {
+ assert(
+ splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
+ "Size of " + rddType + " splits did not reduce after checkpointing parent " + parentRDDType +
+ "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
+ )
+ }
+
+ }
+
+ /**
+ * Generate an RDD with a long lineage of one-to-one dependencies.
+ */
+ def generateLongLineageRDD(): RDD[Int] = {
+ var rdd = sc.makeRDD(1 to 100, 4)
+ for (i <- 1 to 50) {
+ rdd = rdd.map(x => x + 1)
+ }
+ rdd
+ }
+
+ /**
+ * Generate an RDD with a long lineage specifically for CoGroupedRDD.
+ * A CoGroupedRDD can have a long lineage only if one of its parents has a long lineage
+ * and a narrow dependency with this RDD. This method generates such an RDD through a sequence
+ * of cogroups and mapValues, which creates a long lineage of narrow dependencies.
+ */
+ def generateLongLineageRDDForCoGroupedRDD() = {
+ val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _)
+
+ def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
+
+ var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones)
+ for(i <- 1 to 10) {
+ cogrouped = cogrouped.mapValues(add).cogroup(ones)
+ }
+ cogrouped.mapValues(add)
+ }
+
+ /**
+ * Get serialized sizes of the RDD and its splits, in order to test whether the size shrinks
+ * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
+ */
+ def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
+ (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length,
+ Utils.serialize(rdd.splits).length)
+ }
+
+ /**
+ * Serialize and deserialize an object. This is useful for verifying an object's
+ * contents after deserialization (e.g., the contents of an RDD split after
+ * it is sent to a slave along with a task).
+ */
+ def serializeDeserialize[T](obj: T): T = {
+ val bytes = Utils.serialize(obj)
+ Utils.deserialize[T](bytes)
+ }
+}
+
+
+object CheckpointSuite {
+ // This is a custom cogroup function that does not use mapValues the way
+ // PairRDDFunctions.cogroup() does.
+ def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = {
+ //println("First = " + first + ", second = " + second)
+ new CoGroupedRDD[K](
+ Seq(first.asInstanceOf[RDD[(_, _)]], second.asInstanceOf[RDD[(_, _)]]),
+ part
+ ).asInstanceOf[RDD[(K, Seq[Seq[V]])]]
+ }
+
+}
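
For illustration, a hypothetical driver for the custom helper above; everything below is an assumption except CheckpointSuite.cogroup itself and the SparkContext/RDD calls that already appear elsewhere in this patch:

package spark

// Minimal sketch (not part of the patch): build a CoGroupedRDD through the
// custom CheckpointSuite.cogroup helper, which skips the mapValues step that
// PairRDDFunctions.cogroup performs on top of CoGroupedRDD.
object CheckpointCogroupSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "CheckpointCogroupSketch")
    val part = new HashPartitioner(2)
    val a = sc.makeRDD(Seq((1, "a"), (2, "b")), 2)
    val b = sc.makeRDD(Seq((1, "c"), (2, "d")), 2)
    // Each key maps to a Seq of per-parent Seqs of values.
    val grouped: RDD[(Int, Seq[Seq[String]])] = CheckpointSuite.cogroup(a, b, part)
    println(grouped.collect().toList)
    sc.stop()
  }
}
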
diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala
index 7c0334d957..b2d0dd4627 100644
--- a/core/src/test/scala/spark/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/spark/ClosureCleanerSuite.scala
@@ -3,6 +3,7 @@ package spark
import java.io.NotSerializableException
import org.scalatest.FunSuite
+import spark.LocalSparkContext._
import SparkContext._
class ClosureCleanerSuite extends FunSuite {
@@ -43,11 +44,10 @@ object TestObject {
def run(): Int = {
var nonSer = new NonSerializable
var x = 5
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + x).reduce(_ + _)
- sc.stop()
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + x).reduce(_ + _)
+ }
}
}
@@ -58,11 +58,10 @@ class TestClass extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + getX).reduce(_ + _)
- sc.stop()
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + getX).reduce(_ + _)
+ }
}
}
@@ -71,11 +70,10 @@ class TestClassWithoutDefaultConstructor(x: Int) extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + getX).reduce(_ + _)
- sc.stop()
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + getX).reduce(_ + _)
+ }
}
}
@@ -87,11 +85,10 @@ class TestClassWithoutFieldAccess {
def run(): Int = {
var nonSer2 = new NonSerializable
var x = 5
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- val answer = nums.map(_ + x).reduce(_ + _)
- sc.stop()
- return answer
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ nums.map(_ + x).reduce(_ + _)
+ }
}
}
@@ -100,16 +97,16 @@ object TestObjectWithNesting {
def run(): Int = {
var nonSer = new NonSerializable
var answer = 0
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- var y = 1
- for (i <- 1 to 4) {
- var nonSer2 = new NonSerializable
- var x = i
- answer += nums.map(_ + x + y).reduce(_ + _)
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ var y = 1
+ for (i <- 1 to 4) {
+ var nonSer2 = new NonSerializable
+ var x = i
+ answer += nums.map(_ + x + y).reduce(_ + _)
+ }
+ answer
}
- sc.stop()
- return answer
}
}
@@ -119,14 +116,14 @@ class TestClassWithNesting(val y: Int) extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
var answer = 0
- val sc = new SparkContext("local", "test")
- val nums = sc.parallelize(Array(1, 2, 3, 4))
- for (i <- 1 to 4) {
- var nonSer2 = new NonSerializable
- var x = i
- answer += nums.map(_ + x + getY).reduce(_ + _)
+ return withSpark(new SparkContext("local", "test")) { sc =>
+ val nums = sc.parallelize(Array(1, 2, 3, 4))
+ for (i <- 1 to 4) {
+ var nonSer2 = new NonSerializable
+ var x = i
+ answer += nums.map(_ + x + getY).reduce(_ + _)
+ }
+ answer
}
- sc.stop()
- return answer
}
}
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index cacc2796b6..0e2585daa4 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -15,41 +15,28 @@ import scala.collection.mutable.ArrayBuffer
import SparkContext._
import storage.StorageLevel
-class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
+class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext {
val clusterUrl = "local-cluster[2,1,512]"
- @transient var sc: SparkContext = _
-
after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
System.clearProperty("spark.reducer.maxMbInFlight")
System.clearProperty("spark.storage.memoryFraction")
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
}
test("local-cluster format") {
sc = new SparkContext("local-cluster[2,1,512]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
+ resetSparkContext()
sc = new SparkContext("local-cluster[2 , 1 , 512]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
+ resetSparkContext()
sc = new SparkContext("local-cluster[2, 1, 512]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
+ resetSparkContext()
sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
- sc.stop()
- System.clearProperty("spark.master.port")
- sc = null
+ resetSparkContext()
}
test("simple groupByKey") {
@@ -188,4 +175,73 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect()
assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE"))
}
+
+ test("recover from node failures") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ val data = sc.parallelize(Seq(true, true), 2)
+ assert(data.count === 2) // force executors to start
+ val masterId = SparkEnv.get.blockManager.blockManagerId
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ assert(data.map(failOnMarkedIdentity).collect.size === 2)
+ }
+
+ test("recover from repeated node failures during shuffle-map") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, false), 2)
+ assert(data.count === 2)
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2)
+ }
+ }
+
+ test("recover from repeated node failures during shuffle-reduce") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, true), 2)
+ assert(data.count === 2)
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ // This relies on mergeCombiners being used to perform the actual reduce; otherwise
+ // this test would not be testing what it claims.
+ val grouped = data.map(x => x -> x).combineByKey(
+ x => x,
+ (x: Boolean, y: Boolean) => x,
+ (x: Boolean, y: Boolean) => failOnMarkedIdentity(x)
+ )
+ assert(grouped.collect.size === 1)
+ }
+ }
+}
+
+object DistributedSuite {
+ // Indicates whether this JVM is marked for failure.
+ var mark = false
+
+ // Set by the tests to remember whether we are running in the driver program, so that
+ // tasks can assert that they are not.
+ var amMaster = false
+
+ // Act like an identity function, but if the argument is true, set mark to true.
+ def markNodeIfIdentity(item: Boolean): Boolean = {
+ if (item) {
+ assert(!amMaster)
+ mark = true
+ }
+ item
+ }
+
+ // Act like an identity function, but if mark was previously set to true, fail by
+ // crashing the entire JVM.
+ def failOnMarkedIdentity(item: Boolean): Boolean = {
+ if (mark) {
+ System.exit(42)
+ }
+ item
+ }
}
diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala
new file mode 100644
index 0000000000..342610e1dd
--- /dev/null
+++ b/core/src/test/scala/spark/DriverSuite.scala
@@ -0,0 +1,32 @@
+package spark
+
+import java.io.File
+
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.prop.TableDrivenPropertyChecks._
+import org.scalatest.time.SpanSugar._
+
+class DriverSuite extends FunSuite with Timeouts {
+ test("driver should exit after finishing") {
+ // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
+ val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]"))
+ forAll(masters) { (master: String) =>
+ failAfter(10 seconds) {
+ Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master),
+ new File(System.getenv("SPARK_HOME")))
+ }
+ }
+ }
+}
+
+/**
+ * Program that creates a Spark driver but doesn't call SparkContext.stop() or
+ * System.exit() after finishing.
+ */
+object DriverWithoutCleanup {
+ def main(args: Array[String]) {
+ val sc = new SparkContext(args(0), "DriverWithoutCleanup")
+ sc.parallelize(1 to 100, 4).count()
+ }
+}
diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala
index a3454f25f6..8c1445a465 100644
--- a/core/src/test/scala/spark/FailureSuite.scala
+++ b/core/src/test/scala/spark/FailureSuite.scala
@@ -1,7 +1,6 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import org.scalatest.prop.Checkers
import scala.collection.mutable.ArrayBuffer
@@ -23,18 +22,7 @@ object FailureSuiteState {
}
}
-class FailureSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class FailureSuite extends FunSuite with LocalSparkContext {
// Run a 3-task map job in which task 1 deterministically fails once, and check
// whether the job completes successfully and we ran 4 tasks in total.
diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala
index b4283d9604..f1a35bced3 100644
--- a/core/src/test/scala/spark/FileServerSuite.scala
+++ b/core/src/test/scala/spark/FileServerSuite.scala
@@ -2,17 +2,16 @@ package spark
import com.google.common.io.Files
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import java.io.{File, PrintWriter, FileReader, BufferedReader}
import SparkContext._
-class FileServerSuite extends FunSuite with BeforeAndAfter {
+class FileServerSuite extends FunSuite with LocalSparkContext {
- @transient var sc: SparkContext = _
- @transient var tmpFile : File = _
- @transient var testJarFile : File = _
+ @transient var tmpFile: File = _
+ @transient var testJarFile: File = _
- before {
+ override def beforeEach() {
+ super.beforeEach()
// Create a sample text file
val tmpdir = new File(Files.createTempDir(), "test")
tmpdir.mkdir()
@@ -22,17 +21,12 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
pw.close()
}
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
+ override def afterEach() {
+ super.afterEach()
// Clean up downloaded file
if (tmpFile.exists) {
tmpFile.delete()
}
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
}
test("Distributing files locally") {
@@ -40,7 +34,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -54,7 +49,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile((new File(tmpFile.toString)).toURL.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -83,7 +79,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala
index 554bea53a9..91b48c7456 100644
--- a/core/src/test/scala/spark/FileSuite.scala
+++ b/core/src/test/scala/spark/FileSuite.scala
@@ -6,24 +6,12 @@ import scala.io.Source
import com.google.common.io.Files
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import org.apache.hadoop.io._
import SparkContext._
-class FileSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
-
+class FileSuite extends FunSuite with LocalSparkContext {
+
test("text files") {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index c61913fc82..934e4c2f67 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -46,7 +46,7 @@ public class JavaAPISuite implements Serializable {
sc.stop();
sc = null;
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port");
+ System.clearProperty("spark.driver.port");
}
static class ReverseIntComparator implements Comparator<Integer>, Serializable {
@@ -356,6 +356,34 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void mapsFromPairsToPairs() {
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs);
+
+ // Regression test for SPARK-668:
+ JavaPairRDD<String, Integer> swapped = pairRDD.flatMap(
+ new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() {
+ @Override
+ public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) throws Exception {
+ return Collections.singletonList(item.swap());
+ }
+ });
+ swapped.collect();
+
+ // There was never a bug here, but it's worth testing:
+ pairRDD.map(new PairFunction<Tuple2<Integer, String>, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(Tuple2<Integer, String> item) throws Exception {
+ return item.swap();
+ }
+ }).collect();
+ }
+
+ @Test
public void mapPartitions() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
JavaRDD<Integer> partitionSums = rdd.mapPartitions(
@@ -586,7 +614,7 @@ public class JavaAPISuite implements Serializable {
public void accumulators() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
- final Accumulator<Integer> intAccum = sc.accumulator(10);
+ final Accumulator<Integer> intAccum = sc.intAccumulator(10);
rdd.foreach(new VoidFunction<Integer>() {
public void call(Integer x) {
intAccum.add(x);
@@ -594,7 +622,7 @@ public class JavaAPISuite implements Serializable {
});
Assert.assertEquals((Integer) 25, intAccum.value());
- final Accumulator<Double> doubleAccum = sc.accumulator(10.0);
+ final Accumulator<Double> doubleAccum = sc.doubleAccumulator(10.0);
rdd.foreach(new VoidFunction<Integer>() {
public void call(Integer x) {
doubleAccum.add((double) x);
@@ -641,4 +669,31 @@ public class JavaAPISuite implements Serializable {
Assert.assertEquals(new Tuple2<String, Integer>("1", 1), s.get(0));
Assert.assertEquals(new Tuple2<String, Integer>("2", 2), s.get(1));
}
+
+ @Test
+ public void checkpointAndComputation() {
+ File tempDir = Files.createTempDir();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ sc.setCheckpointDir(tempDir.getAbsolutePath(), true);
+ Assert.assertEquals(false, rdd.isCheckpointed());
+ rdd.checkpoint();
+ rdd.count(); // Forces the DAG to cause a checkpoint
+ Assert.assertEquals(true, rdd.isCheckpointed());
+ Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect());
+ }
+
+ @Test
+ public void checkpointAndRestore() {
+ File tempDir = Files.createTempDir();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ sc.setCheckpointDir(tempDir.getAbsolutePath(), true);
+ Assert.assertEquals(false, rdd.isCheckpointed());
+ rdd.checkpoint();
+ rdd.count(); // Forces the DAG to cause a checkpoint
+ Assert.assertEquals(true, rdd.isCheckpointed());
+
+ Assert.assertTrue(rdd.getCheckpointFile().isPresent());
+ JavaRDD<Integer> recovered = sc.checkpointFile(rdd.getCheckpointFile().get());
+ Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
+ }
}
diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala
new file mode 100644
index 0000000000..ff00dd05dd
--- /dev/null
+++ b/core/src/test/scala/spark/LocalSparkContext.scala
@@ -0,0 +1,41 @@
+package spark
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterEach
+
+/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */
+trait LocalSparkContext extends BeforeAndAfterEach { self: Suite =>
+
+ @transient var sc: SparkContext = _
+
+ override def afterEach() {
+ resetSparkContext()
+ super.afterEach()
+ }
+
+ def resetSparkContext() = {
+ if (sc != null) {
+ LocalSparkContext.stop(sc)
+ sc = null
+ }
+ }
+
+}
+
+object LocalSparkContext {
+ def stop(sc: SparkContext) {
+ sc.stop()
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.driver.port")
+ }
+
+ /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */
+ def withSpark[T](sc: SparkContext)(f: SparkContext => T) = {
+ try {
+ f(sc)
+ } finally {
+ stop(sc)
+ }
+ }
+
+}
\ No newline at end of file
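
Since nearly every suite in this patch is converted to the new trait, a hypothetical suite showing both usage patterns may help; the class name and test bodies below are assumptions, while the trait, its managed sc field, and the withSpark helper are the ones defined in the file above:

package spark

import org.scalatest.FunSuite
import spark.LocalSparkContext._

// Minimal sketch (not part of the patch) of a suite using LocalSparkContext.
class LocalSparkContextUsageSketch extends FunSuite with LocalSparkContext {

  // Pattern 1: assign to the managed `sc` field; afterEach() stops the context
  // and clears spark.driver.port automatically.
  test("managed sc field") {
    sc = new SparkContext("local", "test")
    assert(sc.parallelize(1 to 4).reduce(_ + _) === 10)
  }

  // Pattern 2: scope a throwaway context with withSpark, which stops it even
  // if the body throws.
  test("withSpark helper") {
    val sum = withSpark(new SparkContext("local", "test")) { ctx =>
      ctx.parallelize(1 to 4).reduce(_ + _)
    }
    assert(sum === 10)
  }
}
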
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index 5b4b198960..f4e7ec39fe 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -5,8 +5,10 @@ import org.scalatest.FunSuite
import akka.actor._
import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
+import spark.util.AkkaUtils
-class MapOutputTrackerSuite extends FunSuite {
+class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
+
test("compressSize") {
assert(MapOutputTracker.compressSize(0L) === 0)
assert(MapOutputTracker.compressSize(1L) === 1)
@@ -41,13 +43,13 @@ class MapOutputTrackerSuite extends FunSuite {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
Array(compressedSize1000, compressedSize10000)))
- tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000)))
val statuses = tracker.getServerStatuses(10, 0)
- assert(statuses.toSeq === Seq((new BlockManagerId("hostA", 1000), size1000),
- (new BlockManagerId("hostB", 1000), size10000)))
+ assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000),
+ (BlockManagerId("b", "hostB", 1000), size10000)))
tracker.stop()
}
@@ -59,18 +61,52 @@ class MapOutputTrackerSuite extends FunSuite {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
- tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
// As if we had two simultaneous fetch failures
- tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
- tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
- // The remaining reduce task might try to grab the output dispite the shuffle failure;
+ // The remaining reduce task might try to grab the output despite the shuffle failure;
// this should cause it to fail, and the scheduler will ignore the failure due to the
// stage already being aborted.
- intercept[Exception] { tracker.getServerStatuses(10, 1) }
+ intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) }
+ }
+
+ test("remote fetch") {
+ try {
+ System.clearProperty("spark.driver.host") // In case some previous test had set it
+ val (actorSystem, boundPort) =
+ AkkaUtils.createActorSystem("test", "localhost", 0)
+ System.setProperty("spark.driver.port", boundPort.toString)
+ val masterTracker = new MapOutputTracker(actorSystem, true)
+ val slaveTracker = new MapOutputTracker(actorSystem, false)
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+ Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+
+ masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+
+ // failure should be cached
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+ } finally {
+ System.clearProperty("spark.driver.port")
+ }
}
}
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index eb3c8f238f..af1107cd19 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -1,25 +1,12 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import scala.collection.mutable.ArrayBuffer
import SparkContext._
-class PartitioningSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if(sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
-
+class PartitioningSuite extends FunSuite with LocalSparkContext {
test("HashPartitioner equality") {
val p2 = new HashPartitioner(2)
diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala
index 9b84b29227..a6344edf8f 100644
--- a/core/src/test/scala/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/spark/PipedRDDSuite.scala
@@ -1,21 +1,9 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import SparkContext._
-class PipedRDDSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class PipedRDDSuite extends FunSuite with LocalSparkContext {
test("basic pipe") {
sc = new SparkContext("local", "test")
@@ -51,5 +39,3 @@ class PipedRDDSuite extends FunSuite with BeforeAndAfter {
}
}
-
-
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index d74e9786c3..89a3687386 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -2,32 +2,20 @@ package spark
import scala.collection.mutable.HashMap
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
+import spark.SparkContext._
+import spark.rdd.{CoalescedRDD, PartitionPruningRDD}
-import spark.rdd.CoalescedRDD
-import SparkContext._
-
-class RDDSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class RDDSuite extends FunSuite with LocalSparkContext {
test("basic operations") {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
- assert(dups.distinct.count === 4)
- assert(dups.distinct().collect === dups.distinct.collect)
- assert(dups.distinct(2).collect === dups.distinct.collect)
+ assert(dups.distinct().count() === 4)
+ assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses?
+ assert(dups.distinct().collect === dups.distinct().collect)
+ assert(dups.distinct(2).collect === dups.distinct().collect)
assert(nums.reduce(_ + _) === 10)
assert(nums.fold(0)(_ + _) === 10)
assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
@@ -44,6 +32,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
case(split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
}
assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7)))
+
+ intercept[UnsupportedOperationException] {
+ nums.filter(_ > 5).reduce(_ + _)
+ }
}
test("SparkContext.union") {
@@ -76,10 +68,23 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
}
- test("checkpointing") {
+ test("basic checkpointing") {
+ import java.io.File
+ val checkpointDir = File.createTempFile("temp", "")
+ checkpointDir.delete()
+
sc = new SparkContext("local", "test")
- val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).flatMap(x => 1 to x).checkpoint()
- assert(rdd.collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
+ sc.setCheckpointDir(checkpointDir.toString)
+ val parCollection = sc.makeRDD(1 to 4)
+ val flatMappedRDD = parCollection.flatMap(x => 1 to x)
+ flatMappedRDD.checkpoint()
+ assert(flatMappedRDD.dependencies.head.rdd == parCollection)
+ val result = flatMappedRDD.collect()
+ Thread.sleep(1000)
+ assert(flatMappedRDD.dependencies.head.rdd != parCollection)
+ assert(flatMappedRDD.collect() === result)
+
+ checkpointDir.deleteOnExit()
}
test("basic caching") {
@@ -91,12 +96,12 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
}
test("caching with failures") {
- sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test")
val onlySplit = new Split { override def index: Int = 0 }
var shouldFail = true
- val rdd = new RDD[Int](sc) {
- override def splits: Array[Split] = Array(onlySplit)
- override val dependencies = List[Dependency[_]]()
+ val rdd = new RDD[Int](sc, Nil) {
+ override def getSplits: Array[Split] = Array(onlySplit)
+ override val getDependencies = List[Dependency[_]]()
override def compute(split: Split, context: TaskContext): Iterator[Int] = {
if (shouldFail) {
throw new Exception("injected failure")
@@ -123,8 +128,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10)))
// Check that the narrow dependency is also specified correctly
- assert(coalesced1.dependencies.head.getParents(0).toList === List(0, 1, 2, 3, 4))
- assert(coalesced1.dependencies.head.getParents(1).toList === List(5, 6, 7, 8, 9))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList ===
+ List(0, 1, 2, 3, 4))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList ===
+ List(5, 6, 7, 8, 9))
val coalesced2 = new CoalescedRDD(data, 3)
assert(coalesced2.collect().toList === (1 to 10).toList)
@@ -155,4 +162,15 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
nums.zip(sc.parallelize(1 to 4, 1)).collect()
}
}
+
+ test("partition pruning") {
+ sc = new SparkContext("local", "test")
+ val data = sc.parallelize(1 to 10, 10)
+ // Note that split numbers start from 0, so > 8 means only the 10th partition is left.
+ val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
+ assert(prunedRdd.splits.size === 1)
+ val prunedData = prunedRdd.collect()
+ assert(prunedData.size === 1)
+ assert(prunedData(0) === 10)
+ }
}
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index bebb8ebe86..3493b9511f 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -3,7 +3,6 @@ package spark
import scala.collection.mutable.ArrayBuffer
import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
@@ -15,18 +14,7 @@ import com.google.common.io.Files
import spark.rdd.ShuffledRDD
import spark.SparkContext._
-class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
test("groupByKey") {
sc = new SparkContext("local", "test")
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index 1ad11ff4c3..edb8c839fc 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -5,18 +5,7 @@ import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import SparkContext._
-class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with Logging {
-
- var sc: SparkContext = _
-
- after {
- if (sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
+class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers with Logging {
test("sortByKey") {
sc = new SparkContext("local", "test")
diff --git a/core/src/test/scala/spark/ThreadingSuite.scala b/core/src/test/scala/spark/ThreadingSuite.scala
index e9b1837d89..ff315b6693 100644
--- a/core/src/test/scala/spark/ThreadingSuite.scala
+++ b/core/src/test/scala/spark/ThreadingSuite.scala
@@ -22,19 +22,7 @@ object ThreadingSuiteState {
}
}
-class ThreadingSuite extends FunSuite with BeforeAndAfter {
-
- var sc: SparkContext = _
-
- after {
- if(sc != null) {
- sc.stop()
- sc = null
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.master.port")
- }
-
+class ThreadingSuite extends FunSuite with LocalSparkContext {
test("accessing SparkContext form a different thread") {
sc = new SparkContext("local", "test")
diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
new file mode 100644
index 0000000000..a5db7103f5
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
@@ -0,0 +1,32 @@
+package spark.scheduler
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import spark.TaskContext
+import spark.RDD
+import spark.SparkContext
+import spark.Split
+import spark.LocalSparkContext
+
+class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+ test("Calls executeOnCompleteCallbacks after failure") {
+ var completed = false
+ sc = new SparkContext("local", "test")
+ val rdd = new RDD[String](sc, List()) {
+ override def getSplits = Array[Split](StubSplit(0))
+ override def compute(split: Split, context: TaskContext) = {
+ context.addOnCompleteCallback(() => completed = true)
+ sys.error("failed")
+ }
+ }
+ val func = (c: TaskContext, i: Iterator[String]) => i.next
+ val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0)
+ intercept[RuntimeException] {
+ task.run(0)
+ }
+ assert(completed === true)
+ }
+
+ case class StubSplit(val index: Int) extends Split
+}
\ No newline at end of file
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index 8f86e3170e..2d177bbf67 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -69,33 +69,41 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("StorageLevel object caching") {
- val level1 = new StorageLevel(false, false, false, 3)
- val level2 = new StorageLevel(false, false, false, 3)
+ val level1 = StorageLevel(false, false, false, 3)
+ val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1
+ val level3 = StorageLevel(false, false, false, 2) // this should return a different object
+ assert(level2 === level1, "level2 is not same as level1")
+ assert(level2.eq(level1), "level2 is not the same object as level1")
+ assert(level3 != level1, "level3 is same as level1")
val bytes1 = spark.Utils.serialize(level1)
val level1_ = spark.Utils.deserialize[StorageLevel](bytes1)
val bytes2 = spark.Utils.serialize(level2)
val level2_ = spark.Utils.deserialize[StorageLevel](bytes2)
assert(level1_ === level1, "Deserialized level1 not same as original level1")
- assert(level2_ === level2, "Deserialized level2 not same as original level1")
- assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2")
- assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1")
+ assert(level1_.eq(level1), "Deserialized level1 not the same object as original level1")
+ assert(level2_ === level2, "Deserialized level2 not same as original level2")
+ assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1")
}
test("BlockManagerId object caching") {
- val id1 = new StorageLevel(false, false, false, 3)
- val id2 = new StorageLevel(false, false, false, 3)
+ val id1 = BlockManagerId("e1", "XXX", 1)
+ val id2 = BlockManagerId("e1", "XXX", 1) // this should return the same object as id1
+ val id3 = BlockManagerId("e1", "XXX", 2) // this should return a different object
+ assert(id2 === id1, "id2 is not same as id1")
+ assert(id2.eq(id1), "id2 is not the same object as id1")
+ assert(id3 != id1, "id3 is same as id1")
val bytes1 = spark.Utils.serialize(id1)
- val id1_ = spark.Utils.deserialize[StorageLevel](bytes1)
+ val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1)
val bytes2 = spark.Utils.serialize(id2)
- val id2_ = spark.Utils.deserialize[StorageLevel](bytes2)
- assert(id1_ === id1, "Deserialized id1 not same as original id1")
- assert(id2_ === id2, "Deserialized id2 not same as original id1")
- assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2")
- assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized level1")
+ val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2)
+ assert(id1_ === id1, "Deserialized id1 is not same as original id1")
+ assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1")
+ assert(id2_ === id2, "Deserialized id2 is not same as original id2")
+ assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1")
}
test("master + 1 manager interaction") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -125,8 +133,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("master + 2 managers interaction") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
- store2 = new BlockManager(actorSystem, master, new KryoSerializer, 2000)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000)
+ store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer, 2000)
val peers = master.getPeers(store.blockManagerId, 1)
assert(peers.size === 1, "master did not return the other manager as a peer")
@@ -141,7 +149,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("removing block") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -190,7 +198,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("reregistration on heart beat") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
@@ -198,7 +206,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
assert(store.getSingle("a1") != None, "a1 was not in store")
assert(master.getLocations("a1").size > 0, "master was not told about a1")
- master.notifyADeadHost(store.blockManagerId.ip)
+ master.removeExecutor(store.blockManagerId.executorId)
assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
store invokePrivate heartBeat()
@@ -206,25 +214,63 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("reregistration on block update") {
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
-
assert(master.getLocations("a1").size > 0, "master was not told about a1")
- master.notifyADeadHost(store.blockManagerId.ip)
+ master.removeExecutor(store.blockManagerId.executorId)
assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY)
+ store.waitForAsyncReregister()
assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
assert(master.getLocations("a2").size > 0, "master was not told about a2")
}
+ test("reregistration doesn't dead lock") {
+ val heartBeat = PrivateMethod[Unit]('heartBeat)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = List(new Array[Byte](400))
+
+ // try many times to trigger any deadlocks
+ for (i <- 1 to 100) {
+ master.removeExecutor(store.blockManagerId.executorId)
+ val t1 = new Thread {
+ override def run() {
+ store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true)
+ }
+ }
+ val t2 = new Thread {
+ override def run() {
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+ }
+ }
+ val t3 = new Thread {
+ override def run() {
+ store invokePrivate heartBeat()
+ }
+ }
+
+ t1.start()
+ t2.start()
+ t3.start()
+ t1.join()
+ t2.join()
+ t3.join()
+
+ store.dropFromMemory("a1", null)
+ store.dropFromMemory("a2", null)
+ store.waitForAsyncReregister()
+ }
+ }
+
test("in-memory LRU storage") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -243,7 +289,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU storage with serialization") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -262,14 +308,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of same RDD") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY)
- // Even though we accessed rdd_0_3 last, it should not have replaced partitiosn 1 and 2
+ // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2
// from the same RDD
assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store")
@@ -281,7 +327,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of multiple RDDs") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
@@ -304,7 +350,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("on-disk storage") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -317,7 +363,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -332,7 +378,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with getLocalBytes") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -347,7 +393,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -362,7 +408,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization and getLocalBytes") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -377,7 +423,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -402,7 +448,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU with streams") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -426,7 +472,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels and streams") {
- store = new BlockManager(actorSystem, master, serializer, 1200)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -472,7 +518,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("overly large block") {
- store = new BlockManager(actorSystem, master, serializer, 500)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 500)
store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.getSingle("a1") === None, "a1 was in store")
store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK)
@@ -483,49 +529,49 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("block compression") {
try {
System.setProperty("spark.shuffle.compress", "true")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000)
store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.shuffle.compress", "false")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec2", actorSystem, master, serializer, 2000)
store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed")
store.stop()
store = null
System.setProperty("spark.broadcast.compress", "true")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec3", actorSystem, master, serializer, 2000)
store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.broadcast.compress", "false")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec4", actorSystem, master, serializer, 2000)
store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed")
store.stop()
store = null
System.setProperty("spark.rdd.compress", "true")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec5", actorSystem, master, serializer, 2000)
store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.rdd.compress", "false")
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec6", actorSystem, master, serializer, 2000)
store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed")
store.stop()
store = null
// Check that any other block types are also kept uncompressed
- store = new BlockManager(actorSystem, master, serializer, 2000)
+ store = new BlockManager("exec7", actorSystem, master, serializer, 2000)
store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed")
store.stop()
diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala
new file mode 100644
index 0000000000..794063fb6d
--- /dev/null
+++ b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala
@@ -0,0 +1,23 @@
+package spark.util
+
+import org.scalatest.FunSuite
+import java.io.ByteArrayOutputStream
+import java.util.concurrent.TimeUnit._
+
+class RateLimitedOutputStreamSuite extends FunSuite {
+
+ private def benchmark[U](f: => U): Long = {
+ val start = System.nanoTime
+ f
+ System.nanoTime - start
+ }
+
+ test("write") {
+ val underlying = new ByteArrayOutputStream
+ val data = "X" * 41000
+ val stream = new RateLimitedOutputStream(underlying, 10000)
+ val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) }
+ assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4)
+ assert(underlying.toString("UTF-8") == data)
+ }
+}
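
The expected value in the elapsed-time assertion above follows from the write size and the rate limit; a small arithmetic sketch (the object name is an assumption, and it presumes the limiter spreads the whole 41000-byte write across its bandwidth budget):

import java.util.concurrent.TimeUnit._

// 41000 bytes at 10000 bytes/second is a 4.1-second budget; converting
// nanoseconds to whole seconds truncates that to 4, matching the assertion.
object RateLimitArithmeticSketch {
  def main(args: Array[String]) {
    val bytes = 41000L
    val bytesPerSecond = 10000L
    val budgetNanos = bytes * 1000000000L / bytesPerSecond  // 4,100,000,000 ns
    println(SECONDS.convert(budgetNanos, NANOSECONDS))      // prints 4
  }
}
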