aboutsummaryrefslogtreecommitdiff
path: root/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'core/src')
-rw-r--r--core/src/main/scala/spark/Accumulators.scala41
-rw-r--r--core/src/main/scala/spark/BoundedMemoryCache.scala118
-rw-r--r--core/src/main/scala/spark/CacheManager.scala65
-rw-r--r--core/src/main/scala/spark/CacheTracker.scala238
-rw-r--r--core/src/main/scala/spark/DaemonThreadFactory.scala18
-rw-r--r--core/src/main/scala/spark/HttpFileServer.scala8
-rw-r--r--core/src/main/scala/spark/HttpServer.scala9
-rw-r--r--core/src/main/scala/spark/KryoSerializer.scala210
-rw-r--r--core/src/main/scala/spark/Logging.scala3
-rw-r--r--core/src/main/scala/spark/MapOutputTracker.scala60
-rw-r--r--core/src/main/scala/spark/PairRDDFunctions.scala86
-rw-r--r--core/src/main/scala/spark/ParallelCollection.scala24
-rw-r--r--core/src/main/scala/spark/Partitioner.scala4
-rw-r--r--core/src/main/scala/spark/RDD.scala197
-rw-r--r--core/src/main/scala/spark/RDDCheckpointData.scala105
-rw-r--r--core/src/main/scala/spark/SequenceFileRDDFunctions.scala8
-rw-r--r--core/src/main/scala/spark/SizeEstimator.scala13
-rw-r--r--core/src/main/scala/spark/SparkContext.scala157
-rw-r--r--core/src/main/scala/spark/SparkEnv.scala39
-rw-r--r--core/src/main/scala/spark/SparkFiles.java25
-rw-r--r--core/src/main/scala/spark/TaskContext.scala3
-rw-r--r--core/src/main/scala/spark/Utils.scala125
-rw-r--r--core/src/main/scala/spark/api/java/JavaPairRDD.scala10
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDDLike.scala33
-rw-r--r--core/src/main/scala/spark/api/java/JavaSparkContext.scala105
-rw-r--r--core/src/main/scala/spark/api/java/StorageLevels.java11
-rw-r--r--core/src/main/scala/spark/api/python/PythonPartitioner.scala48
-rw-r--r--core/src/main/scala/spark/api/python/PythonRDD.scala293
-rw-r--r--core/src/main/scala/spark/broadcast/Broadcast.scala2
-rw-r--r--core/src/main/scala/spark/broadcast/HttpBroadcast.scala26
-rw-r--r--core/src/main/scala/spark/deploy/DeployMessage.scala4
-rw-r--r--core/src/main/scala/spark/deploy/JobDescription.scala3
-rw-r--r--core/src/main/scala/spark/deploy/JsonProtocol.scala78
-rw-r--r--core/src/main/scala/spark/deploy/client/TestClient.scala2
-rw-r--r--core/src/main/scala/spark/deploy/master/Master.scala15
-rw-r--r--core/src/main/scala/spark/deploy/master/MasterWebUI.scala58
-rw-r--r--core/src/main/scala/spark/deploy/master/WorkerInfo.scala6
-rw-r--r--core/src/main/scala/spark/deploy/master/WorkerState.scala7
-rw-r--r--core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala5
-rw-r--r--core/src/main/scala/spark/deploy/worker/Worker.scala4
-rw-r--r--core/src/main/scala/spark/deploy/worker/WorkerArguments.scala22
-rw-r--r--core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala19
-rw-r--r--core/src/main/scala/spark/executor/Executor.scala34
-rw-r--r--core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala3
-rw-r--r--core/src/main/scala/spark/network/Connection.scala7
-rw-r--r--core/src/main/scala/spark/network/ConnectionManager.scala17
-rw-r--r--core/src/main/scala/spark/network/ConnectionManagerTest.scala24
-rw-r--r--core/src/main/scala/spark/rdd/BlockRDD.scala20
-rw-r--r--core/src/main/scala/spark/rdd/CartesianRDD.scala47
-rw-r--r--core/src/main/scala/spark/rdd/CheckpointRDD.scala128
-rw-r--r--core/src/main/scala/spark/rdd/CoGroupedRDD.scala70
-rw-r--r--core/src/main/scala/spark/rdd/CoalescedRDD.scala47
-rw-r--r--core/src/main/scala/spark/rdd/FilteredRDD.scala17
-rw-r--r--core/src/main/scala/spark/rdd/FlatMappedRDD.scala10
-rw-r--r--core/src/main/scala/spark/rdd/GlommedRDD.scala14
-rw-r--r--core/src/main/scala/spark/rdd/HadoopRDD.scala15
-rw-r--r--core/src/main/scala/spark/rdd/MapPartitionsRDD.scala14
-rw-r--r--core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala14
-rw-r--r--core/src/main/scala/spark/rdd/MappedRDD.scala11
-rw-r--r--core/src/main/scala/spark/rdd/NewHadoopRDD.scala19
-rw-r--r--core/src/main/scala/spark/rdd/PipedRDD.scala18
-rw-r--r--core/src/main/scala/spark/rdd/SampledRDD.scala29
-rw-r--r--core/src/main/scala/spark/rdd/ShuffledRDD.scala28
-rw-r--r--core/src/main/scala/spark/rdd/UnionRDD.scala45
-rw-r--r--core/src/main/scala/spark/rdd/ZippedRDD.scala60
-rw-r--r--core/src/main/scala/spark/scheduler/DAGScheduler.scala90
-rw-r--r--core/src/main/scala/spark/scheduler/MapStatus.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/ResultTask.scala102
-rw-r--r--core/src/main/scala/spark/scheduler/ShuffleMapTask.scala24
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala31
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala3
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala6
-rw-r--r--core/src/main/scala/spark/scheduler/local/LocalScheduler.scala44
-rw-r--r--core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala16
-rw-r--r--core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala10
-rw-r--r--core/src/main/scala/spark/storage/BlockManager.scala231
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerId.scala70
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMaster.scala727
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMasterActor.scala401
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMessages.scala100
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala16
-rw-r--r--core/src/main/scala/spark/storage/BlockMessage.scala2
-rw-r--r--core/src/main/scala/spark/storage/BlockStore.scala7
-rw-r--r--core/src/main/scala/spark/storage/DiskStore.scala5
-rw-r--r--core/src/main/scala/spark/storage/MemoryStore.scala6
-rw-r--r--core/src/main/scala/spark/storage/StorageLevel.scala80
-rw-r--r--core/src/main/scala/spark/storage/ThreadingTest.scala13
-rw-r--r--core/src/main/scala/spark/util/AkkaUtils.scala1
-rw-r--r--core/src/main/scala/spark/util/IdGenerator.scala14
-rw-r--r--core/src/main/scala/spark/util/MetadataCleaner.scala44
-rw-r--r--core/src/main/scala/spark/util/RateLimitedOutputStream.scala62
-rw-r--r--core/src/main/scala/spark/util/TimeStampedHashMap.scala93
-rw-r--r--core/src/main/scala/spark/util/TimeStampedHashSet.scala69
-rw-r--r--core/src/main/twirl/spark/deploy/master/worker_row.scala.html1
-rw-r--r--core/src/main/twirl/spark/deploy/master/worker_table.scala.html1
-rw-r--r--core/src/test/resources/log4j.properties4
-rw-r--r--core/src/test/scala/spark/BoundedMemoryCacheSuite.scala58
-rw-r--r--core/src/test/scala/spark/CacheTrackerSuite.scala131
-rw-r--r--core/src/test/scala/spark/CheckpointSuite.scala357
-rw-r--r--core/src/test/scala/spark/ClosureCleanerSuite.scala2
-rw-r--r--core/src/test/scala/spark/DistributedSuite.scala69
-rw-r--r--core/src/test/scala/spark/DriverSuite.scala31
-rw-r--r--core/src/test/scala/spark/FileServerSuite.scala13
-rw-r--r--core/src/test/scala/spark/JavaAPISuite.java98
-rw-r--r--core/src/test/scala/spark/MapOutputTrackerSuite.scala56
-rw-r--r--core/src/test/scala/spark/PartitioningSuite.scala26
-rw-r--r--core/src/test/scala/spark/RDDSuite.scala60
-rw-r--r--core/src/test/scala/spark/ShuffleSuite.scala7
-rw-r--r--core/src/test/scala/spark/SizeEstimatorSuite.scala48
-rw-r--r--core/src/test/scala/spark/scheduler/TaskContextSuite.scala42
-rw-r--r--core/src/test/scala/spark/storage/BlockManagerSuite.scala168
-rw-r--r--core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala23
112 files changed, 4287 insertions, 2175 deletions
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index bacd0ace37..57c6df35be 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -25,8 +25,7 @@ class Accumulable[R, T] (
extends Serializable {
val id = Accumulators.newId
- @transient
- private var value_ = initialValue // Current value on master
+ @transient private var value_ = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
var deserialized = false
@@ -39,19 +38,36 @@ class Accumulable[R, T] (
def += (term: T) { value_ = param.addAccumulator(value_, term) }
/**
+ * Add more data to this accumulator / accumulable
+ * @param term the data to add
+ */
+ def add(term: T) { value_ = param.addAccumulator(value_, term) }
+
+ /**
* Merge two accumulable objects together
- *
+ *
* Normally, a user will not want to use this version, but will instead call `+=`.
- * @param term the other Accumulable that will get merged with this
+ * @param term the other `R` that will get merged with this
*/
def ++= (term: R) { value_ = param.addInPlace(value_, term)}
/**
+ * Merge two accumulable objects together
+ *
+ * Normally, a user will not want to use this version, but will instead call `add`.
+ * @param term the other `R` that will get merged with this
+ */
+ def merge(term: R) { value_ = param.addInPlace(value_, term)}
+
+ /**
* Access the accumulator's current value; only allowed on master.
*/
- def value = {
- if (!deserialized) value_
- else throw new UnsupportedOperationException("Can't read accumulator value in task")
+ def value: R = {
+ if (!deserialized) {
+ value_
+ } else {
+ throw new UnsupportedOperationException("Can't read accumulator value in task")
+ }
}
/**
@@ -68,10 +84,17 @@ class Accumulable[R, T] (
/**
* Set the accumulator's value; only allowed on master.
*/
- def value_= (r: R) {
- if (!deserialized) value_ = r
+ def value_= (newValue: R) {
+ if (!deserialized) value_ = newValue
else throw new UnsupportedOperationException("Can't assign accumulator value in task")
}
+
+ /**
+ * Set the accumulator's value; only allowed on master
+ */
+ def setValue(newValue: R) {
+ this.value = newValue
+ }
// Called by Java when deserializing an object
private def readObject(in: ObjectInputStream) {
diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala
deleted file mode 100644
index e8392a194f..0000000000
--- a/core/src/main/scala/spark/BoundedMemoryCache.scala
+++ /dev/null
@@ -1,118 +0,0 @@
-package spark
-
-import java.util.LinkedHashMap
-
-/**
- * An implementation of Cache that estimates the sizes of its entries and attempts to limit its
- * total memory usage to a fraction of the JVM heap. Objects' sizes are estimated using
- * SizeEstimator, which has limitations; most notably, we will overestimate total memory used if
- * some cache entries have pointers to a shared object. Nonetheless, this Cache should work well
- * when most of the space is used by arrays of primitives or of simple classes.
- */
-private[spark] class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
- logInfo("BoundedMemoryCache.maxBytes = " + maxBytes)
-
- def this() {
- this(BoundedMemoryCache.getMaxBytes)
- }
-
- private var currentBytes = 0L
- private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true)
-
- override def get(datasetId: Any, partition: Int): Any = {
- synchronized {
- val entry = map.get((datasetId, partition))
- if (entry != null) {
- entry.value
- } else {
- null
- }
- }
- }
-
- override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
- val key = (datasetId, partition)
- logInfo("Asked to add key " + key)
- val size = estimateValueSize(key, value)
- synchronized {
- if (size > getCapacity) {
- return CachePutFailure()
- } else if (ensureFreeSpace(datasetId, size)) {
- logInfo("Adding key " + key)
- map.put(key, new Entry(value, size))
- currentBytes += size
- logInfo("Number of entries is now " + map.size)
- return CachePutSuccess(size)
- } else {
- logInfo("Didn't add key " + key + " because we would have evicted part of same dataset")
- return CachePutFailure()
- }
- }
- }
-
- override def getCapacity: Long = maxBytes
-
- /**
- * Estimate sizeOf 'value'
- */
- private def estimateValueSize(key: (Any, Int), value: Any) = {
- val startTime = System.currentTimeMillis
- val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef])
- val timeTaken = System.currentTimeMillis - startTime
- logInfo("Estimated size for key %s is %d".format(key, size))
- logInfo("Size estimation for key %s took %d ms".format(key, timeTaken))
- size
- }
-
- /**
- * Remove least recently used entries from the map until at least space bytes are free, in order
- * to make space for a partition from the given dataset ID. If this cannot be done without
- * evicting other data from the same dataset, returns false; otherwise, returns true. Assumes
- * that a lock is held on the BoundedMemoryCache.
- */
- private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = {
- logInfo("ensureFreeSpace(%s, %d) called with curBytes=%d, maxBytes=%d".format(
- datasetId, space, currentBytes, maxBytes))
- val iter = map.entrySet.iterator // Will give entries in LRU order
- while (maxBytes - currentBytes < space && iter.hasNext) {
- val mapEntry = iter.next()
- val (entryDatasetId, entryPartition) = mapEntry.getKey
- if (entryDatasetId == datasetId) {
- // Cannot make space without removing part of the same dataset, or a more recently used one
- return false
- }
- reportEntryDropped(entryDatasetId, entryPartition, mapEntry.getValue)
- currentBytes -= mapEntry.getValue.size
- iter.remove()
- }
- return true
- }
-
- protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
- logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
- // TODO: remove BoundedMemoryCache
-
- val (keySpaceId, innerDatasetId) = datasetId.asInstanceOf[(Any, Any)]
- innerDatasetId match {
- case rddId: Int =>
- SparkEnv.get.cacheTracker.dropEntry(rddId, partition)
- case broadcastUUID: java.util.UUID =>
- // TODO: Maybe something should be done if the broadcasted variable falls out of cache
- case _ =>
- }
- }
-}
-
-// An entry in our map; stores a cached object and its size in bytes
-private[spark] case class Entry(value: Any, size: Long)
-
-private[spark] object BoundedMemoryCache {
- /**
- * Get maximum cache capacity from system configuration
- */
- def getMaxBytes: Long = {
- val memoryFractionToUse = System.getProperty("spark.boundedMemoryCache.memoryFraction", "0.66").toDouble
- (Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong
- }
-}
-
diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala
new file mode 100644
index 0000000000..a0b53fd9d6
--- /dev/null
+++ b/core/src/main/scala/spark/CacheManager.scala
@@ -0,0 +1,65 @@
+package spark
+
+import scala.collection.mutable.{ArrayBuffer, HashSet}
+import spark.storage.{BlockManager, StorageLevel}
+
+
+/** Spark class responsible for passing RDDs split contents to the BlockManager and making
+ sure a node doesn't load two copies of an RDD at once.
+ */
+private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
+ private val loading = new HashSet[String]
+
+ /** Gets or computes an RDD split. Used by RDD.iterator() when a RDD is cached. */
+ def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
+ : Iterator[T] = {
+ val key = "rdd_%d_%d".format(rdd.id, split.index)
+ logInfo("Cache key is " + key)
+ blockManager.get(key) match {
+ case Some(cachedValues) =>
+ // Split is in cache, so just return its values
+ logInfo("Found partition in cache!")
+ return cachedValues.asInstanceOf[Iterator[T]]
+
+ case None =>
+ // Mark the split as loading (unless someone else marks it first)
+ loading.synchronized {
+ if (loading.contains(key)) {
+ logInfo("Loading contains " + key + ", waiting...")
+ while (loading.contains(key)) {
+ try {loading.wait()} catch {case _ =>}
+ }
+ logInfo("Loading no longer contains " + key + ", so returning cached result")
+ // See whether someone else has successfully loaded it. The main way this would fail
+ // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
+ // partition but we didn't want to make space for it. However, that case is unlikely
+ // because it's unlikely that two threads would work on the same RDD partition. One
+ // downside of the current code is that threads wait serially if this does happen.
+ blockManager.get(key) match {
+ case Some(values) =>
+ return values.asInstanceOf[Iterator[T]]
+ case None =>
+ logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
+ loading.add(key)
+ }
+ } else {
+ loading.add(key)
+ }
+ }
+ try {
+ // If we got here, we have to load the split
+ val elements = new ArrayBuffer[Any]
+ logInfo("Computing partition " + split)
+ elements ++= rdd.compute(split, context)
+ // Try to put this block in the blockManager
+ blockManager.put(key, elements, storageLevel, true)
+ return elements.iterator.asInstanceOf[Iterator[T]]
+ } finally {
+ loading.synchronized {
+ loading.remove(key)
+ loading.notifyAll()
+ }
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
deleted file mode 100644
index 3d79078733..0000000000
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ /dev/null
@@ -1,238 +0,0 @@
-package spark
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-
-import akka.actor._
-import akka.dispatch._
-import akka.pattern.ask
-import akka.remote._
-import akka.util.Duration
-import akka.util.Timeout
-import akka.util.duration._
-
-import spark.storage.BlockManager
-import spark.storage.StorageLevel
-
-private[spark] sealed trait CacheTrackerMessage
-
-private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
- extends CacheTrackerMessage
-private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
- extends CacheTrackerMessage
-private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage
-private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
-private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
-private[spark] case object GetCacheStatus extends CacheTrackerMessage
-private[spark] case object GetCacheLocations extends CacheTrackerMessage
-private[spark] case object StopCacheTracker extends CacheTrackerMessage
-
-private[spark] class CacheTrackerActor extends Actor with Logging {
- // TODO: Should probably store (String, CacheType) tuples
- private val locs = new HashMap[Int, Array[List[String]]]
-
- /**
- * A map from the slave's host name to its cache size.
- */
- private val slaveCapacity = new HashMap[String, Long]
- private val slaveUsage = new HashMap[String, Long]
-
- private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
- private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
- private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
-
- def receive = {
- case SlaveCacheStarted(host: String, size: Long) =>
- slaveCapacity.put(host, size)
- slaveUsage.put(host, 0)
- sender ! true
-
- case RegisterRDD(rddId: Int, numPartitions: Int) =>
- logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
- locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
- sender ! true
-
- case AddedToCache(rddId, partition, host, size) =>
- slaveUsage.put(host, getCacheUsage(host) + size)
- locs(rddId)(partition) = host :: locs(rddId)(partition)
- sender ! true
-
- case DroppedFromCache(rddId, partition, host, size) =>
- slaveUsage.put(host, getCacheUsage(host) - size)
- // Do a sanity check to make sure usage is greater than 0.
- locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
- sender ! true
-
- case MemoryCacheLost(host) =>
- logInfo("Memory cache lost on " + host)
- for ((id, locations) <- locs) {
- for (i <- 0 until locations.length) {
- locations(i) = locations(i).filterNot(_ == host)
- }
- }
- sender ! true
-
- case GetCacheLocations =>
- logInfo("Asked for current cache locations")
- sender ! locs.map{case (rrdId, array) => (rrdId -> array.clone())}
-
- case GetCacheStatus =>
- val status = slaveCapacity.map { case (host, capacity) =>
- (host, capacity, getCacheUsage(host))
- }.toSeq
- sender ! status
-
- case StopCacheTracker =>
- logInfo("Stopping CacheTrackerActor")
- sender ! true
- context.stop(self)
- }
-}
-
-private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
- extends Logging {
-
- // Tracker actor on the master, or remote reference to it on workers
- val ip: String = System.getProperty("spark.master.host", "localhost")
- val port: Int = System.getProperty("spark.master.port", "7077").toInt
- val actorName: String = "CacheTracker"
-
- val timeout = 10.seconds
-
- var trackerActor: ActorRef = if (isMaster) {
- val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName)
- logInfo("Registered CacheTrackerActor actor")
- actor
- } else {
- val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
- actorSystem.actorFor(url)
- }
-
- val registeredRddIds = new HashSet[Int]
-
- // Remembers which splits are currently being loaded (on worker nodes)
- val loading = new HashSet[String]
-
- // Send a message to the trackerActor and get its result within a default timeout, or
- // throw a SparkException if this fails.
- def askTracker(message: Any): Any = {
- try {
- val future = trackerActor.ask(message)(timeout)
- return Await.result(future, timeout)
- } catch {
- case e: Exception =>
- throw new SparkException("Error communicating with CacheTracker", e)
- }
- }
-
- // Send a one-way message to the trackerActor, to which we expect it to reply with true.
- def communicate(message: Any) {
- if (askTracker(message) != true) {
- throw new SparkException("Error reply received from CacheTracker")
- }
- }
-
- // Registers an RDD (on master only)
- def registerRDD(rddId: Int, numPartitions: Int) {
- registeredRddIds.synchronized {
- if (!registeredRddIds.contains(rddId)) {
- logInfo("Registering RDD ID " + rddId + " with cache")
- registeredRddIds += rddId
- communicate(RegisterRDD(rddId, numPartitions))
- }
- }
- }
-
- // For BlockManager.scala only
- def cacheLost(host: String) {
- communicate(MemoryCacheLost(host))
- logInfo("CacheTracker successfully removed entries on " + host)
- }
-
- // Get the usage status of slave caches. Each tuple in the returned sequence
- // is in the form of (host name, capacity, usage).
- def getCacheStatus(): Seq[(String, Long, Long)] = {
- askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]]
- }
-
- // For BlockManager.scala only
- def notifyFromBlockManager(t: AddedToCache) {
- communicate(t)
- }
-
- // Get a snapshot of the currently known locations
- def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
- askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
- }
-
- // Gets or computes an RDD split
- def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
- : Iterator[T] = {
- val key = "rdd_%d_%d".format(rdd.id, split.index)
- logInfo("Cache key is " + key)
- blockManager.get(key) match {
- case Some(cachedValues) =>
- // Split is in cache, so just return its values
- logInfo("Found partition in cache!")
- return cachedValues.asInstanceOf[Iterator[T]]
-
- case None =>
- // Mark the split as loading (unless someone else marks it first)
- loading.synchronized {
- if (loading.contains(key)) {
- logInfo("Loading contains " + key + ", waiting...")
- while (loading.contains(key)) {
- try {loading.wait()} catch {case _ =>}
- }
- logInfo("Loading no longer contains " + key + ", so returning cached result")
- // See whether someone else has successfully loaded it. The main way this would fail
- // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
- // partition but we didn't want to make space for it. However, that case is unlikely
- // because it's unlikely that two threads would work on the same RDD partition. One
- // downside of the current code is that threads wait serially if this does happen.
- blockManager.get(key) match {
- case Some(values) =>
- return values.asInstanceOf[Iterator[T]]
- case None =>
- logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
- loading.add(key)
- }
- } else {
- loading.add(key)
- }
- }
- // If we got here, we have to load the split
- // Tell the master that we're doing so
- //val host = System.getProperty("spark.hostname", Utils.localHostName)
- //val future = trackerActor !! AddedToCache(rdd.id, split.index, host)
- // TODO: fetch any remote copy of the split that may be available
- // TODO: also register a listener for when it unloads
- logInfo("Computing partition " + split)
- val elements = new ArrayBuffer[Any]
- elements ++= rdd.compute(split, context)
- try {
- // Try to put this block in the blockManager
- blockManager.put(key, elements, storageLevel, true)
- //future.apply() // Wait for the reply from the cache tracker
- } finally {
- loading.synchronized {
- loading.remove(key)
- loading.notifyAll()
- }
- }
- return elements.iterator.asInstanceOf[Iterator[T]]
- }
- }
-
- // Called by the Cache to report that an entry has been dropped from it
- def dropEntry(rddId: Int, partition: Int) {
- communicate(DroppedFromCache(rddId, partition, Utils.localHostName()))
- }
-
- def stop() {
- communicate(StopCacheTracker)
- registeredRddIds.clear()
- trackerActor = null
- }
-}
diff --git a/core/src/main/scala/spark/DaemonThreadFactory.scala b/core/src/main/scala/spark/DaemonThreadFactory.scala
deleted file mode 100644
index 56e59adeb7..0000000000
--- a/core/src/main/scala/spark/DaemonThreadFactory.scala
+++ /dev/null
@@ -1,18 +0,0 @@
-package spark
-
-import java.util.concurrent.ThreadFactory
-
-/**
- * A ThreadFactory that creates daemon threads
- */
-private object DaemonThreadFactory extends ThreadFactory {
- override def newThread(r: Runnable): Thread = new DaemonThread(r)
-}
-
-private class DaemonThread(r: Runnable = null) extends Thread {
- override def run() {
- if (r != null) {
- r.run()
- }
- }
-} \ No newline at end of file
diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala
index 659d17718f..00901d95e2 100644
--- a/core/src/main/scala/spark/HttpFileServer.scala
+++ b/core/src/main/scala/spark/HttpFileServer.scala
@@ -1,9 +1,7 @@
package spark
-import java.io.{File, PrintWriter}
-import java.net.URL
-import scala.collection.mutable.HashMap
-import org.apache.hadoop.fs.FileUtil
+import java.io.{File}
+import com.google.common.io.Files
private[spark] class HttpFileServer extends Logging {
@@ -40,7 +38,7 @@ private[spark] class HttpFileServer extends Logging {
}
def addFileToDir(file: File, dir: File) : String = {
- Utils.copyFile(file, new File(dir, file.getName))
+ Files.copy(file, new File(dir, file.getName))
return dir + "/" + file.getName
}
diff --git a/core/src/main/scala/spark/HttpServer.scala b/core/src/main/scala/spark/HttpServer.scala
index 0196595ba1..4e0507c080 100644
--- a/core/src/main/scala/spark/HttpServer.scala
+++ b/core/src/main/scala/spark/HttpServer.scala
@@ -4,6 +4,7 @@ import java.io.File
import java.net.InetAddress
import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.server.bio.SocketConnector
import org.eclipse.jetty.server.handler.DefaultHandler
import org.eclipse.jetty.server.handler.HandlerList
import org.eclipse.jetty.server.handler.ResourceHandler
@@ -27,7 +28,13 @@ private[spark] class HttpServer(resourceBase: File) extends Logging {
if (server != null) {
throw new ServerStateException("Server is already started")
} else {
- server = new Server(0)
+ server = new Server()
+ val connector = new SocketConnector
+ connector.setMaxIdleTime(60*1000)
+ connector.setSoLingerTime(-1)
+ connector.setPort(0)
+ server.addConnector(connector)
+
val threadPool = new QueuedThreadPool
threadPool.setDaemon(true)
server.setThreadPool(threadPool)
diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 44b630e478..0bd73e936b 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -9,153 +9,80 @@ import scala.collection.mutable
import com.esotericsoftware.kryo._
import com.esotericsoftware.kryo.{Serializer => KSerializer}
-import com.esotericsoftware.kryo.serialize.ClassSerializer
-import com.esotericsoftware.kryo.serialize.SerializableSerializer
+import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
+import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
import serializer.{SerializerInstance, DeserializationStream, SerializationStream}
import spark.broadcast._
import spark.storage._
-/**
- * Zig-zag encoder used to write object sizes to serialization streams.
- * Based on Kryo's integer encoder.
- */
-private[spark] object ZigZag {
- def writeInt(n: Int, out: OutputStream) {
- var value = n
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- out.write(value)
- }
+private[spark]
+class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream {
- def readInt(in: InputStream): Int = {
- var offset = 0
- var result = 0
- while (offset < 32) {
- val b = in.read()
- if (b == -1) {
- throw new EOFException("End of stream")
- }
- result |= ((b & 0x7F) << offset)
- if ((b & 0x80) == 0) {
- return result
- }
- offset += 7
- }
- throw new SparkException("Malformed zigzag-encoded integer")
- }
-}
-
-private[spark]
-class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream)
-extends SerializationStream {
- val channel = Channels.newChannel(out)
+ val output = new KryoOutput(outStream)
def writeObject[T](t: T): SerializationStream = {
- kryo.writeClassAndObject(threadBuffer, t)
- ZigZag.writeInt(threadBuffer.position(), out)
- threadBuffer.flip()
- channel.write(threadBuffer)
- threadBuffer.clear()
+ kryo.writeClassAndObject(output, t)
this
}
- def flush() { out.flush() }
- def close() { out.close() }
+ def flush() { output.flush() }
+ def close() { output.close() }
}
-private[spark]
-class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream)
-extends DeserializationStream {
+private[spark]
+class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {
+
+ val input = new KryoInput(inStream)
+
def readObject[T](): T = {
- val len = ZigZag.readInt(in)
- objectBuffer.readClassAndObject(in, len).asInstanceOf[T]
+ try {
+ kryo.readClassAndObject(input).asInstanceOf[T]
+ } catch {
+ // DeserializationStream uses the EOF exception to indicate stopping condition.
+ case e: com.esotericsoftware.kryo.KryoException => throw new java.io.EOFException
+ }
}
- def close() { in.close() }
+ def close() {
+ // Kryo's Input automatically closes the input stream it is using.
+ input.close()
+ }
}
private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
- val kryo = ks.kryo
- val threadBuffer = ks.threadBuffer.get()
- val objectBuffer = ks.objectBuffer.get()
+
+ val kryo = ks.kryo.get()
+ val output = ks.output.get()
+ val input = ks.input.get()
def serialize[T](t: T): ByteBuffer = {
- // Write it to our thread-local scratch buffer first to figure out the size, then return a new
- // ByteBuffer of the appropriate size
- threadBuffer.clear()
- kryo.writeClassAndObject(threadBuffer, t)
- val newBuf = ByteBuffer.allocate(threadBuffer.position)
- threadBuffer.flip()
- newBuf.put(threadBuffer)
- newBuf.flip()
- newBuf
+ output.clear()
+ kryo.writeClassAndObject(output, t)
+ ByteBuffer.wrap(output.toBytes)
}
def deserialize[T](bytes: ByteBuffer): T = {
- kryo.readClassAndObject(bytes).asInstanceOf[T]
+ input.setBuffer(bytes.array)
+ kryo.readClassAndObject(input).asInstanceOf[T]
}
def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
val oldClassLoader = kryo.getClassLoader
kryo.setClassLoader(loader)
- val obj = kryo.readClassAndObject(bytes).asInstanceOf[T]
+ input.setBuffer(bytes.array)
+ val obj = kryo.readClassAndObject(input).asInstanceOf[T]
kryo.setClassLoader(oldClassLoader)
obj
}
def serializeStream(s: OutputStream): SerializationStream = {
- threadBuffer.clear()
- new KryoSerializationStream(kryo, threadBuffer, s)
+ new KryoSerializationStream(kryo, s)
}
def deserializeStream(s: InputStream): DeserializationStream = {
- new KryoDeserializationStream(objectBuffer, s)
- }
-
- override def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
- threadBuffer.clear()
- while (iterator.hasNext) {
- val element = iterator.next()
- // TODO: Do we also want to write the object's size? Doesn't seem necessary.
- kryo.writeClassAndObject(threadBuffer, element)
- }
- val newBuf = ByteBuffer.allocate(threadBuffer.position)
- threadBuffer.flip()
- newBuf.put(threadBuffer)
- newBuf.flip()
- newBuf
- }
-
- override def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
- buffer.rewind()
- new Iterator[Any] {
- override def hasNext: Boolean = buffer.remaining > 0
- override def next(): Any = kryo.readClassAndObject(buffer)
- }
+ new KryoDeserializationStream(kryo, s)
}
}
@@ -171,18 +98,19 @@ trait KryoRegistrator {
* A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
*/
class KryoSerializer extends spark.serializer.Serializer with Logging {
- // Make this lazy so that it only gets called once we receive our first task on each executor,
- // so we can pull out any custom Kryo registrator from the user's JARs.
- lazy val kryo = createKryo()
- val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024
+ val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
- val objectBuffer = new ThreadLocal[ObjectBuffer] {
- override def initialValue = new ObjectBuffer(kryo, bufferSize)
+ val kryo = new ThreadLocal[Kryo] {
+ override def initialValue = createKryo()
}
- val threadBuffer = new ThreadLocal[ByteBuffer] {
- override def initialValue = ByteBuffer.allocate(bufferSize)
+ val output = new ThreadLocal[KryoOutput] {
+ override def initialValue = new KryoOutput(bufferSize)
+ }
+
+ val input = new ThreadLocal[KryoInput] {
+ override def initialValue = new KryoInput(bufferSize)
}
def createKryo(): Kryo = {
@@ -213,41 +141,44 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {
kryo.register(obj.getClass)
}
- // Register the following classes for passing closures.
- kryo.register(classOf[Class[_]], new ClassSerializer(kryo))
- kryo.setRegistrationOptional(true)
-
// Allow sending SerializableWritable
- kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer())
- kryo.register(classOf[HttpBroadcast[_]], new SerializableSerializer())
+ kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
+ kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
// Register some commonly used Scala singleton objects. Because these
// are singletons, we must return the exact same local object when we
// deserialize rather than returning a clone as FieldSerializer would.
- class SingletonSerializer(obj: AnyRef) extends KSerializer {
- override def writeObjectData(buf: ByteBuffer, obj: AnyRef) {}
- override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = obj.asInstanceOf[T]
+ class SingletonSerializer[T](obj: T) extends KSerializer[T] {
+ override def write(kryo: Kryo, output: KryoOutput, obj: T) {}
+ override def read(kryo: Kryo, input: KryoInput, cls: java.lang.Class[T]): T = obj
}
- kryo.register(None.getClass, new SingletonSerializer(None))
- kryo.register(Nil.getClass, new SingletonSerializer(Nil))
+ kryo.register(None.getClass, new SingletonSerializer[AnyRef](None))
+ kryo.register(Nil.getClass, new SingletonSerializer[AnyRef](Nil))
// Register maps with a special serializer since they have complex internal structure
class ScalaMapSerializer(buildMap: Array[(Any, Any)] => scala.collection.Map[Any, Any])
- extends KSerializer {
- override def writeObjectData(buf: ByteBuffer, obj: AnyRef) {
+ extends KSerializer[Array[(Any, Any)] => scala.collection.Map[Any, Any]] {
+ override def write(
+ kryo: Kryo,
+ output: KryoOutput,
+ obj: Array[(Any, Any)] => scala.collection.Map[Any, Any]) {
val map = obj.asInstanceOf[scala.collection.Map[Any, Any]]
- kryo.writeObject(buf, map.size.asInstanceOf[java.lang.Integer])
+ kryo.writeObject(output, map.size.asInstanceOf[java.lang.Integer])
for ((k, v) <- map) {
- kryo.writeClassAndObject(buf, k)
- kryo.writeClassAndObject(buf, v)
+ kryo.writeClassAndObject(output, k)
+ kryo.writeClassAndObject(output, v)
}
}
- override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = {
- val size = kryo.readObject(buf, classOf[java.lang.Integer]).intValue
+ override def read (
+ kryo: Kryo,
+ input: KryoInput,
+ cls: Class[Array[(Any, Any)] => scala.collection.Map[Any, Any]])
+ : Array[(Any, Any)] => scala.collection.Map[Any, Any] = {
+ val size = kryo.readObject(input, classOf[java.lang.Integer]).intValue
val elems = new Array[(Any, Any)](size)
for (i <- 0 until size)
- elems(i) = (kryo.readClassAndObject(buf), kryo.readClassAndObject(buf))
- buildMap(elems).asInstanceOf[T]
+ elems(i) = (kryo.readClassAndObject(input), kryo.readClassAndObject(input))
+ buildMap(elems).asInstanceOf[Array[(Any, Any)] => scala.collection.Map[Any, Any]]
}
}
kryo.register(mutable.HashMap().getClass, new ScalaMapSerializer(mutable.HashMap() ++ _))
@@ -275,5 +206,8 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {
kryo
}
- def newInstance(): SerializerInstance = new KryoSerializerInstance(this)
+ def newInstance(): SerializerInstance = {
+ this.kryo.get().setClassLoader(Thread.currentThread().getContextClassLoader)
+ new KryoSerializerInstance(this)
+ }
}
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
index 90bae26202..7c1c1bb144 100644
--- a/core/src/main/scala/spark/Logging.scala
+++ b/core/src/main/scala/spark/Logging.scala
@@ -11,8 +11,7 @@ import org.slf4j.LoggerFactory
trait Logging {
// Make the log field transient so that objects with Logging can
// be serialized and used on another machine
- @transient
- private var log_ : Logger = null
+ @transient private var log_ : Logger = null
// Method to get or create the logger for this object
protected def log: Logger = {
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 70eb9f702e..ac02f3363a 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -17,6 +17,7 @@ import akka.util.duration._
import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
+import spark.util.{MetadataCleaner, TimeStampedHashMap}
private[spark] sealed trait MapOutputTrackerMessage
@@ -44,7 +45,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
val timeout = 10.seconds
- var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
+ var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
@@ -53,7 +54,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
- val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]
+ val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
@@ -64,6 +65,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
actorSystem.actorFor(url)
}
+ val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
+
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
@@ -84,14 +87,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.get(shuffleId) != null) {
+ if (mapStatuses.get(shuffleId) != None) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
- var array = mapStatuses.get(shuffleId)
+ var array = mapStatuses(shuffleId)
array.synchronized {
array(mapId) = status
}
@@ -108,7 +111,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var array = mapStatuses.get(shuffleId)
+ var array = mapStatuses(shuffleId)
if (array != null) {
array.synchronized {
if (array(mapId) != null && array(mapId).address == bmAddress) {
@@ -126,7 +129,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
- val statuses = mapStatuses.get(shuffleId)
+ val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
fetching.synchronized {
@@ -139,8 +142,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
case e: InterruptedException =>
}
}
- return mapStatuses.get(shuffleId).map(status =>
- (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId))))
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId))
} else {
fetching += shuffleId
}
@@ -156,27 +158,27 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
- if (fetchedStatuses.contains(null)) {
- throw new FetchFailedException(null, shuffleId, -1, reduceId,
- new Exception("Missing an output location for shuffle " + shuffleId))
- }
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
- return fetchedStatuses.map(s =>
- (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
} else {
- return statuses.map(s =>
- (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
}
}
+ def cleanup(cleanupTime: Long) {
+ mapStatuses.clearOldValues(cleanupTime)
+ cachedSerializedStatuses.clearOldValues(cleanupTime)
+ }
+
def stop() {
communicate(StopMapOutputTracker)
mapStatuses.clear()
+ metadataCleaner.cancel()
trackerActor = null
}
@@ -202,7 +204,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
- mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
+ mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
generation = newGen
}
}
@@ -220,7 +222,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
case Some(bytes) =>
return bytes
case None =>
- statuses = mapStatuses.get(shuffleId)
+ statuses = mapStatuses(shuffleId)
generationGotten = generation
}
}
@@ -258,6 +260,28 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
private[spark] object MapOutputTracker {
private val LOG_BASE = 1.1
+ // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
+ // any of the statuses is null (indicating a missing location due to a failed mapper),
+ // throw a FetchFailedException.
+ def convertMapStatuses(
+ shuffleId: Int,
+ reduceId: Int,
+ statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
+ if (statuses == null) {
+ throw new FetchFailedException(null, shuffleId, -1, reduceId,
+ new Exception("Missing all output locations for shuffle " + shuffleId))
+ }
+ statuses.map {
+ status =>
+ if (status == null) {
+ throw new FetchFailedException(null, shuffleId, -1, reduceId,
+ new Exception("Missing an output location for shuffle " + shuffleId))
+ } else {
+ (status.address, decompressSize(status.compressedSizes(reduceId)))
+ }
+ }
+ }
+
/**
* Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
* We do this by encoding the log base 1.1 of the size as an integer, which can support
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 08ae06e865..53b051f1c5 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -52,6 +52,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
mapSideCombine: Boolean = true): RDD[(K, C)] = {
+ if (getKeyClass().isArray) {
+ if (mapSideCombine) {
+ throw new SparkException("Cannot use map-side combining with array keys.")
+ }
+ if (partitioner.isInstanceOf[HashPartitioner]) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
+ }
val aggregator =
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (mapSideCombine) {
@@ -92,6 +100,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* before sending results to a reducer, similarly to a "combiner" in MapReduce.
*/
def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
+
+ if (getKeyClass().isArray) {
+ throw new SparkException("reduceByKeyLocally() does not support array keys")
+ }
+
def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
val map = new JHashMap[K, V]
for ((k, v) <- iter) {
@@ -165,6 +178,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* be set to true.
*/
def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = {
+ if (getKeyClass().isArray) {
+ if (mapSideCombine) {
+ throw new SparkException("Cannot use map-side combining with array keys.")
+ }
+ if (partitioner.isInstanceOf[HashPartitioner]) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
+ }
if (mapSideCombine) {
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
@@ -178,9 +199,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
/**
- * Merge the values for each key using an associative reduce function. This will also perform
- * the merging locally on each mapper before sending results to a reducer, similarly to a
- * "combiner" in MapReduce.
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
*/
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
this.cogroup(other, partitioner).flatMapValues {
@@ -336,6 +357,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = {
+ if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
val cg = new CoGroupedRDD[K](
Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]),
partitioner)
@@ -352,6 +376,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
*/
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner)
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
+ if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
val cg = new CoGroupedRDD[K](
Seq(self.asInstanceOf[RDD[(_, _)]],
other1.asInstanceOf[RDD[(_, _)]],
@@ -438,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
val res = self.context.runJob(self, process _, Array(index), false)
res(0)
case None =>
- throw new UnsupportedOperationException("lookup() called on an RDD without a partitioner")
+ self.filter(_._1 == key).map(_._2).collect
}
}
@@ -466,20 +493,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
path: String,
keyClass: Class[_],
valueClass: Class[_],
- outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) {
- saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration)
- }
-
- /**
- * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
- * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
- */
- def saveAsNewAPIHadoopFile(
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
outputFormatClass: Class[_ <: NewOutputFormat[_, _]],
- conf: Configuration) {
+ conf: Configuration = self.context.hadoopConfiguration) {
val job = new NewAPIHadoopJob(conf)
job.setOutputKeyClass(keyClass)
job.setOutputValueClass(valueClass)
@@ -530,7 +545,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
- conf: JobConf = new JobConf) {
+ conf: JobConf = new JobConf(self.context.hadoopConfiguration)) {
conf.setOutputKeyClass(keyClass)
conf.setOutputValueClass(valueClass)
// conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
@@ -588,6 +603,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
writer.cleanup()
}
+ /**
+ * Return an RDD with the keys of each tuple.
+ */
+ def keys: RDD[K] = self.map(_._1)
+
+ /**
+ * Return an RDD with the values of each tuple.
+ */
+ def values: RDD[V] = self.map(_._2)
+
private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure
private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure
@@ -624,24 +649,23 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
}
private[spark]
-class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override val partitioner = prev.partitioner
- override def compute(split: Split, taskContext: TaskContext) =
- prev.iterator(split, taskContext).map{case (k, v) => (k, f(v))}
+class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U)
+ extends RDD[(K, U)](prev) {
+
+ override def getSplits = firstParent[(K, V)].splits
+ override val partitioner = firstParent[(K, V)].partitioner
+ override def compute(split: Split, context: TaskContext) =
+ firstParent[(K, V)].iterator(split, context).map{ case (k, v) => (k, f(v)) }
}
private[spark]
class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U])
- extends RDD[(K, U)](prev.context) {
-
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override val partitioner = prev.partitioner
+ extends RDD[(K, U)](prev) {
- override def compute(split: Split, taskContext: TaskContext) = {
- prev.iterator(split, taskContext).flatMap { case (k, v) => f(v).map(x => (k, x)) }
+ override def getSplits = firstParent[(K, V)].splits
+ override val partitioner = firstParent[(K, V)].partitioner
+ override def compute(split: Split, context: TaskContext) = {
+ firstParent[(K, V)].iterator(split, context).flatMap { case (k, v) => f(v).map(x => (k, x)) }
}
}
diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala
index a27f766e31..10adcd53ec 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -2,6 +2,7 @@ package spark
import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer
+import scala.collection.Map
private[spark] class ParallelCollectionSplit[T: ClassManifest](
val rddId: Long,
@@ -22,28 +23,33 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
}
private[spark] class ParallelCollection[T: ClassManifest](
- sc: SparkContext,
+ @transient sc: SparkContext,
@transient data: Seq[T],
- numSlices: Int)
- extends RDD[T](sc) {
+ numSlices: Int,
+ locationPrefs: Map[Int,Seq[String]])
+ extends RDD[T](sc, Nil) {
// TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
// cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
// instead.
+ // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.
- @transient
- val splits_ = {
+ @transient var splits_ : Array[Split] = {
val slices = ParallelCollection.slice(data, numSlices).toArray
slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray
}
- override def splits = splits_.asInstanceOf[Array[Split]]
+ override def getSplits = splits_
- override def compute(s: Split, taskContext: TaskContext) =
+ override def compute(s: Split, context: TaskContext) =
s.asInstanceOf[ParallelCollectionSplit[T]].iterator
- override def preferredLocations(s: Split): Seq[String] = Nil
+ override def getPreferredLocations(s: Split): Seq[String] = {
+ locationPrefs.getOrElse(s.index, Nil)
+ }
- override val dependencies: List[Dependency[_]] = Nil
+ override def clearDependencies() {
+ splits_ = null
+ }
}
private object ParallelCollection {
diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index b71021a082..9d5b966e1e 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -11,6 +11,10 @@ abstract class Partitioner extends Serializable {
/**
* A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`.
+ *
+ * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
+ * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
+ * produce an unexpected or incorrect result.
*/
class HashPartitioner(partitions: Int) extends Partitioner {
def numPartitions = partitions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index bb4c13c494..c79f34342f 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -1,10 +1,8 @@
package spark
-import java.io.EOFException
-import java.io.ObjectInputStream
+import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream}
import java.net.URL
-import java.util.Random
-import java.util.Date
+import java.util.{Date, Random}
import java.util.{HashMap => JHashMap}
import java.util.concurrent.atomic.AtomicLong
@@ -13,6 +11,7 @@ import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
@@ -73,41 +72,42 @@ import SparkContext._
* [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details
* on RDD internals.
*/
-abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serializable {
+abstract class RDD[T: ClassManifest](
+ @transient var sc: SparkContext,
+ var dependencies_ : List[Dependency[_]]
+ ) extends Serializable with Logging {
- // Methods that must be implemented by subclasses:
- /** Set of partitions in this RDD. */
- def splits: Array[Split]
+ def this(@transient oneParent: RDD[_]) =
+ this(oneParent.context , List(new OneToOneDependency(oneParent)))
+
+ // =======================================================================
+ // Methods that should be implemented by subclasses of RDD
+ // =======================================================================
/** Function for computing a given partition. */
def compute(split: Split, context: TaskContext): Iterator[T]
- /** How this RDD depends on any parent RDDs. */
- @transient val dependencies: List[Dependency[_]]
+ /** Set of partitions in this RDD. */
+ protected def getSplits(): Array[Split]
- // Methods available on all RDDs:
+ /** How this RDD depends on any parent RDDs. */
+ protected def getDependencies(): List[Dependency[_]] = dependencies_
- /** Record user function generating this RDD. */
- private[spark] val origin = Utils.getSparkCallSite
+ /** Optionally overridden by subclasses to specify placement preferences. */
+ protected def getPreferredLocations(split: Split): Seq[String] = Nil
/** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
- /** Optionally overridden by subclasses to specify placement preferences. */
- def preferredLocations(split: Split): Seq[String] = Nil
-
- /** The [[spark.SparkContext]] that this RDD was created on. */
- def context = sc
- private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+ // =======================================================================
+ // Methods and fields available on all RDDs
+ // =======================================================================
/** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
- // Variables relating to persistence
- private var storageLevel: StorageLevel = StorageLevel.NONE
-
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
@@ -131,22 +131,39 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
- private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
- if (!level.useDisk && level.replication < 2) {
- throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
+ /**
+ * Get the preferred location of a split, taking into account whether the
+ * RDD is checkpointed or not.
+ */
+ final def preferredLocations(split: Split): Seq[String] = {
+ if (isCheckpointed) {
+ checkpointData.get.getPreferredLocations(split)
+ } else {
+ getPreferredLocations(split)
}
+ }
- // This is a hack. Ideally this should re-use the code used by the CacheTracker
- // to generate the key.
- def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)
-
- persist(level)
- sc.runJob(this, (iter: Iterator[T]) => {} )
-
- val p = this.partitioner
+ /**
+ * Get the array of splits of this RDD, taking into account whether the
+ * RDD is checkpointed or not.
+ */
+ final def splits: Array[Split] = {
+ if (isCheckpointed) {
+ checkpointData.get.getSplits
+ } else {
+ getSplits
+ }
+ }
- new BlockRDD[T](sc, splits.map(getSplitKey).toArray) {
- override val partitioner = p
+ /**
+ * Get the list of dependencies of this RDD, taking into account whether the
+ * RDD is checkpointed or not.
+ */
+ final def dependencies: List[Dependency[_]] = {
+ if (isCheckpointed) {
+ dependencies_
+ } else {
+ getDependencies
}
}
@@ -156,8 +173,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
* subclasses of RDD.
*/
final def iterator(split: Split, context: TaskContext): Iterator[T] = {
- if (storageLevel != StorageLevel.NONE) {
- SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel)
+ if (isCheckpointed) {
+ checkpointData.get.iterator(split, context)
+ } else if (storageLevel != StorageLevel.NONE) {
+ SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
} else {
compute(split, context)
}
@@ -185,9 +204,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
- def distinct(numSplits: Int = splits.size): RDD[T] =
+ def distinct(numSplits: Int): RDD[T] =
map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1)
+ def distinct(): RDD[T] = distinct(splits.size)
+
/**
* Return a sampled subset of this RDD.
*/
@@ -328,6 +349,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def toArray(): Array[T] = collect()
/**
+ * Return an RDD that contains all matching values by applying `f`.
+ */
+ def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = {
+ filter(f.isDefinedAt).map(f)
+ }
+
+ /**
* Reduces the elements of this RDD using the specified associative binary operator.
*/
def reduce(f: (T, T) => T): T = {
@@ -415,6 +443,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): Map[T, Long] = {
+ if (elementClassManifest.erasure.isArray) {
+ throw new SparkException("countByValue() does not support arrays")
+ }
// TODO: This should perhaps be distributed by default.
def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = {
val map = new OLMap[T]
@@ -443,6 +474,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
timeout: Long,
confidence: Double = 0.95
): PartialResult[Map[T, BoundedDouble]] = {
+ if (elementClassManifest.erasure.isArray) {
+ throw new SparkException("countByValueApprox() does not support arrays")
+ }
val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
val map = new OLMap[T]
while (iter.hasNext) {
@@ -502,8 +536,95 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
.saveAsSequenceFile(path)
}
+ /**
+ * Creates tuples of the elements in this RDD by applying `f`.
+ */
+ def keyBy[K](f: T => K): RDD[(K, T)] = {
+ map(x => (f(x), x))
+ }
+
/** A private method for tests, to look at the contents of each partition */
private[spark] def collectPartitions(): Array[Array[T]] = {
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
}
+
+ /**
+ * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
+ * directory set with SparkContext.setCheckpointDir() and all references to its parent
+ * RDDs will be removed. This function must be called before any job has been
+ * executed on this RDD. It is strongly recommended that this RDD is persisted in
+ * memory, otherwise saving it on a file will require recomputation.
+ */
+ def checkpoint() {
+ if (context.checkpointDir.isEmpty) {
+ throw new Exception("Checkpoint directory has not been set in the SparkContext")
+ } else if (checkpointData.isEmpty) {
+ checkpointData = Some(new RDDCheckpointData(this))
+ checkpointData.get.markForCheckpoint()
+ }
+ }
+
+ /**
+ * Return whether this RDD has been checkpointed or not
+ */
+ def isCheckpointed(): Boolean = {
+ if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false
+ }
+
+ /**
+ * Gets the name of the file to which this RDD was checkpointed
+ */
+ def getCheckpointFile(): Option[String] = {
+ if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None
+ }
+
+ // =======================================================================
+ // Other internal methods and fields
+ // =======================================================================
+
+ private var storageLevel: StorageLevel = StorageLevel.NONE
+
+ /** Record user function generating this RDD. */
+ private[spark] val origin = Utils.getSparkCallSite
+
+ private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+
+ private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
+
+ /** Returns the first parent RDD */
+ protected[spark] def firstParent[U: ClassManifest] = {
+ dependencies.head.rdd.asInstanceOf[RDD[U]]
+ }
+
+ /** The [[spark.SparkContext]] that this RDD was created on. */
+ def context = sc
+
+ /**
+ * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler
+ * after a job using this RDD has completed (therefore the RDD has been materialized and
+ * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
+ */
+ protected[spark] def doCheckpoint() {
+ if (checkpointData.isDefined) checkpointData.get.doCheckpoint()
+ dependencies.foreach(_.rdd.doCheckpoint())
+ }
+
+ /**
+ * Changes the dependencies of this RDD from its original parents to the new RDD
+ * (`newRDD`) created from the checkpoint file.
+ */
+ protected[spark] def changeDependencies(newRDD: RDD[_]) {
+ clearDependencies()
+ dependencies_ = List(new OneToOneDependency(newRDD))
+ }
+
+ /**
+ * Clears the dependencies of this RDD. This method must ensure that all references
+ * to the original parent RDDs is removed to enable the parent RDDs to be garbage
+ * collected. Subclasses of RDD may override this method for implementing their own cleaning
+ * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea.
+ */
+ protected[spark] def clearDependencies() {
+ dependencies_ = null
+ }
}
diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
new file mode 100644
index 0000000000..18df530b7d
--- /dev/null
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -0,0 +1,105 @@
+package spark
+
+import org.apache.hadoop.fs.Path
+import rdd.{CheckpointRDD, CoalescedRDD}
+import scheduler.{ResultTask, ShuffleMapTask}
+
+/**
+ * Enumeration to manage state transitions of an RDD through checkpointing
+ * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
+ */
+private[spark] object CheckpointState extends Enumeration {
+ type CheckpointState = Value
+ val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value
+}
+
+/**
+ * This class contains all the information related to RDD checkpointing. Each instance of this class
+ * is associated with a RDD. It manages process of checkpointing of the associated RDD, as well as,
+ * manages the post-checkpoint state by providing the updated splits, iterator and preferred locations
+ * of the checkpointed RDD.
+ */
+private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
+extends Logging with Serializable {
+
+ import CheckpointState._
+
+ // The checkpoint state of the associated RDD.
+ var cpState = Initialized
+
+ // The file to which the associated RDD has been checkpointed to
+ @transient var cpFile: Option[String] = None
+
+ // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD.
+ @transient var cpRDD: Option[RDD[T]] = None
+
+ // Mark the RDD for checkpointing
+ def markForCheckpoint() {
+ RDDCheckpointData.synchronized {
+ if (cpState == Initialized) cpState = MarkedForCheckpoint
+ }
+ }
+
+ // Is the RDD already checkpointed
+ def isCheckpointed(): Boolean = {
+ RDDCheckpointData.synchronized { cpState == Checkpointed }
+ }
+
+ // Get the file to which this RDD was checkpointed to as an Option
+ def getCheckpointFile(): Option[String] = {
+ RDDCheckpointData.synchronized { cpFile }
+ }
+
+ // Do the checkpointing of the RDD. Called after the first job using that RDD is over.
+ def doCheckpoint() {
+ // If it is marked for checkpointing AND checkpointing is not already in progress,
+ // then set it to be in progress, else return
+ RDDCheckpointData.synchronized {
+ if (cpState == MarkedForCheckpoint) {
+ cpState = CheckpointingInProgress
+ } else {
+ return
+ }
+ }
+
+ // Save to file, and reload it as an RDD
+ val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString
+ rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
+ val newRDD = new CheckpointRDD[T](rdd.context, path)
+
+ // Change the dependencies and splits of the RDD
+ RDDCheckpointData.synchronized {
+ cpFile = Some(path)
+ cpRDD = Some(newRDD)
+ rdd.changeDependencies(newRDD)
+ cpState = Checkpointed
+ RDDCheckpointData.clearTaskCaches()
+ logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
+ }
+ }
+
+ // Get preferred location of a split after checkpointing
+ def getPreferredLocations(split: Split) = {
+ RDDCheckpointData.synchronized {
+ cpRDD.get.preferredLocations(split)
+ }
+ }
+
+ def getSplits: Array[Split] = {
+ RDDCheckpointData.synchronized {
+ cpRDD.get.splits
+ }
+ }
+
+ // Get iterator. This is called at the worker nodes.
+ def iterator(split: Split, context: TaskContext): Iterator[T] = {
+ rdd.firstParent[T].iterator(split, context)
+ }
+}
+
+private[spark] object RDDCheckpointData {
+ def clearTaskCaches() {
+ ShuffleMapTask.clearCache()
+ ResultTask.clearCache()
+ }
+}
diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
index a34aee69c1..6b4a11d6d3 100644
--- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
@@ -42,7 +42,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) {
classManifest[T].erasure
} else {
- implicitly[T => Writable].getClass.getMethods()(0).getReturnType
+ // We get the type of the Writable class by looking at the apply method which converts
+ // from T to Writable. Since we have two apply methods we filter out the one which
+ // is of the form "java.lang.Object apply(java.lang.Object)"
+ implicitly[T => Writable].getClass.getDeclaredMethods().filter(
+ m => m.getReturnType().toString != "java.lang.Object" &&
+ m.getName() == "apply")(0).getReturnType
+
}
// TODO: use something like WritableConverter to avoid reflection
}
diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala
index 7c3e8640e9..d4e1157250 100644
--- a/core/src/main/scala/spark/SizeEstimator.scala
+++ b/core/src/main/scala/spark/SizeEstimator.scala
@@ -9,7 +9,6 @@ import java.util.Random
import javax.management.MBeanServer
import java.lang.management.ManagementFactory
-import com.sun.management.HotSpotDiagnosticMXBean
import scala.collection.mutable.ArrayBuffer
@@ -76,12 +75,20 @@ private[spark] object SizeEstimator extends Logging {
if (System.getProperty("spark.test.useCompressedOops") != null) {
return System.getProperty("spark.test.useCompressedOops").toBoolean
}
+
try {
val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic"
val server = ManagementFactory.getPlatformMBeanServer()
+
+ // NOTE: This should throw an exception in non-Sun JVMs
+ val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean")
+ val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption",
+ Class.forName("java.lang.String"))
+
val bean = ManagementFactory.newPlatformMXBeanProxy(server,
- hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean])
- return bean.getVMOption("UseCompressedOops").getValue.toBoolean
+ hotSpotMBeanName, hotSpotMBeanClass)
+ // TODO: We could use reflection on the VMOption returned ?
+ return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true")
} catch {
case e: Exception => {
// Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 0afab522af..66bdbe7cda 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -3,10 +3,12 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
import java.net.{URI, URLClassLoader}
+import java.lang.ref.WeakReference
import scala.collection.Map
import scala.collection.generic.Growable
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.JavaConversions._
import akka.actor.Actor
import akka.actor.Actor._
@@ -36,12 +38,8 @@ import spark.broadcast._
import spark.deploy.LocalSparkCluster
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
-import spark.rdd.HadoopRDD
-import spark.rdd.NewHadoopRDD
-import spark.rdd.UnionRDD
-import spark.scheduler.ShuffleMapTask
-import spark.scheduler.DAGScheduler
-import spark.scheduler.TaskScheduler
+import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD}
+import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
@@ -58,29 +56,13 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend
* @param environment Environment variables to set on worker nodes.
*/
class SparkContext(
- master: String,
- jobName: String,
- val sparkHome: String,
- jars: Seq[String],
- environment: Map[String, String])
+ val master: String,
+ val jobName: String,
+ val sparkHome: String = null,
+ val jars: Seq[String] = Nil,
+ environment: Map[String, String] = Map())
extends Logging {
- /**
- * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
- * @param sparkHome Location where Spark is installed on cluster nodes.
- * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
- * system or HDFS, HTTP, HTTPS, or FTP URLs.
- */
- def this(master: String, jobName: String, sparkHome: String, jars: Seq[String]) =
- this(master, jobName, sparkHome, jars, Map())
-
- /**
- * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
- */
- def this(master: String, jobName: String) = this(master, jobName, null, Nil, Map())
-
// Ensure logging is initialized before we spawn any threads
initLogging()
@@ -187,11 +169,32 @@ class SparkContext(
private var dagScheduler = new DAGScheduler(taskScheduler)
+ /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
+ val hadoopConfiguration = {
+ val conf = new Configuration()
+ // Explicitly check for S3 environment variables
+ if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
+ conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ }
+ // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
+ for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
+ conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
+ }
+ val bufferSize = System.getProperty("spark.buffer.size", "65536")
+ conf.set("io.file.buffer.size", bufferSize)
+ conf
+ }
+
+ private[spark] var checkpointDir: Option[String] = None
+
// Methods for creating RDDs
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
- new ParallelCollection[T](this, seq, numSlices)
+ new ParallelCollection[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
/** Distribute a local Scala collection to form an RDD. */
@@ -199,6 +202,14 @@ class SparkContext(
parallelize(seq, numSlices)
}
+ /** Distribute a local Scala collection to form an RDD, with one or more
+ * location preferences (hostnames of Spark nodes) for each object.
+ * Create a new partition for each collection item. */
+ def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
+ new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs)
+ }
+
/**
* Read a text file from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI, and return it as an RDD of Strings.
@@ -231,10 +242,8 @@ class SparkContext(
valueClass: Class[V],
minSplits: Int = defaultMinSplits
) : RDD[(K, V)] = {
- val conf = new JobConf()
+ val conf = new JobConf(hadoopConfiguration)
FileInputFormat.setInputPaths(conf, path)
- val bufferSize = System.getProperty("spark.buffer.size", "65536")
- conf.set("io.file.buffer.size", bufferSize)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
}
@@ -275,8 +284,7 @@ class SparkContext(
path,
fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]],
- vm.erasure.asInstanceOf[Class[V]],
- new Configuration)
+ vm.erasure.asInstanceOf[Class[V]])
}
/**
@@ -288,7 +296,7 @@ class SparkContext(
fClass: Class[F],
kClass: Class[K],
vClass: Class[V],
- conf: Configuration): RDD[(K, V)] = {
+ conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
@@ -300,7 +308,7 @@ class SparkContext(
* and extra configuration options to pass to the input format.
*/
def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
- conf: Configuration,
+ conf: Configuration = hadoopConfiguration,
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
@@ -365,6 +373,13 @@ class SparkContext(
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes))
}
+
+ protected[spark] def checkpointFile[T: ClassManifest](
+ path: String
+ ): RDD[T] = {
+ new CheckpointRDD[T](this, path)
+ }
+
/** Build the union of a list of RDDs. */
def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
@@ -382,11 +397,12 @@ class SparkContext(
new Accumulator(initialValue, param)
/**
- * Create an [[spark.Accumulable]] shared variable, with a `+=` method
+ * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`.
+ * Only the master can access the accumuable's `value`.
* @tparam T accumulator type
* @tparam R type that can be added to the accumulator
*/
- def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
+ def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) =
new Accumulable(initialValue, param)
/**
@@ -404,12 +420,13 @@ class SparkContext(
* Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each cluster only once.
*/
- def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal)
+ def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
/**
- * Add a file to be downloaded into the working directory of this Spark job on every node.
+ * Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
- * filesystems), or an HTTP, HTTPS or FTP URI.
+ * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+ * use `SparkFiles.get(path)` to find its download location.
*/
def addFile(path: String) {
val uri = new URI(path)
@@ -419,9 +436,10 @@ class SparkContext(
}
addedFiles(key) = System.currentTimeMillis
- // Fetch the file locally in case the task is executed locally
- val filename = new File(path.split("/").last)
- Utils.fetchFile(path, new File("."))
+ // Fetch the file locally in case a job is executed locally.
+ // Jobs that run through LocalScheduler will already fetch the required dependencies,
+ // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
+ Utils.fetchFile(path, new File(SparkFiles.getRootDirectory))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
@@ -437,11 +455,10 @@ class SparkContext(
}
/**
- * Clear the job's list of files added by `addFile` so that they do not get donwloaded to
+ * Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes.
*/
def clearFiles() {
- addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedFiles.clear()
}
@@ -465,23 +482,27 @@ class SparkContext(
* any new nodes.
*/
def clearJars() {
- addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedJars.clear()
}
/** Shut down the SparkContext. */
def stop() {
- dagScheduler.stop()
- dagScheduler = null
- taskScheduler = null
- // TODO: Cache.stop()?
- env.stop()
- // Clean up locally linked files
- clearFiles()
- clearJars()
- SparkEnv.set(null)
- ShuffleMapTask.clearCache()
- logInfo("Successfully stopped SparkContext")
+ if (dagScheduler != null) {
+ dagScheduler.stop()
+ dagScheduler = null
+ taskScheduler = null
+ // TODO: Cache.stop()?
+ env.stop()
+ // Clean up locally linked files
+ clearFiles()
+ clearJars()
+ SparkEnv.set(null)
+ ShuffleMapTask.clearCache()
+ ResultTask.clearCache()
+ logInfo("Successfully stopped SparkContext")
+ } else {
+ logInfo("SparkContext already stopped")
+ }
}
/**
@@ -518,6 +539,7 @@ class SparkContext(
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
+ rdd.doCheckpoint()
result
}
@@ -574,6 +596,26 @@ class SparkContext(
return f
}
+ /**
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be a HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists and useExisting is set to true, then the
+ * exisiting directory will be used. Otherwise an exception will be thrown to
+ * prevent accidental overriding of checkpoint files in the existing directory.
+ */
+ def setCheckpointDir(dir: String, useExisting: Boolean = false) {
+ val path = new Path(dir)
+ val fs = path.getFileSystem(new Configuration())
+ if (!useExisting) {
+ if (fs.exists(path)) {
+ throw new Exception("Checkpoint directory '" + path + "' already exists.")
+ } else {
+ fs.mkdirs(path)
+ }
+ }
+ checkpointDir = Some(dir)
+ }
+
/** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */
def defaultParallelism: Int = taskScheduler.defaultParallelism
@@ -595,6 +637,7 @@ class SparkContext(
* various Spark features.
*/
object SparkContext {
+
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 272d7cdad3..2a7a8af83d 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -22,24 +22,19 @@ class SparkEnv (
val actorSystem: ActorSystem,
val serializer: Serializer,
val closureSerializer: Serializer,
- val cacheTracker: CacheTracker,
+ val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
- val httpFileServer: HttpFileServer
+ val httpFileServer: HttpFileServer,
+ val sparkFilesDir: String
) {
- /** No-parameter constructor for unit tests. */
- def this() = {
- this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
- }
-
def stop() {
httpFileServer.stop()
mapOutputTracker.stop()
- cacheTracker.stop()
shuffleFetcher.stop()
broadcastManager.stop()
blockManager.stop()
@@ -86,10 +81,13 @@ object SparkEnv extends Logging {
}
val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
-
- val blockManagerMaster = new BlockManagerMaster(actorSystem, isMaster, isLocal)
+
+ val masterIp: String = System.getProperty("spark.master.host", "localhost")
+ val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
+ val blockManagerMaster = new BlockManagerMaster(
+ actorSystem, isMaster, isLocal, masterIp, masterPort)
val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer)
-
+
val connectionManager = blockManager.connectionManager
val broadcastManager = new BroadcastManager(isMaster)
@@ -97,18 +95,26 @@ object SparkEnv extends Logging {
val closureSerializer = instantiateClass[Serializer](
"spark.closure.serializer", "spark.JavaSerializer")
- val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager)
- blockManager.cacheTracker = cacheTracker
+ val cacheManager = new CacheManager(blockManager)
val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster)
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
-
+
val httpFileServer = new HttpFileServer()
httpFileServer.initialize()
System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
+ // Set the sparkFiles directory, used when downloading dependencies. In local mode,
+ // this is a temporary directory; in distributed mode, this is the executor's current working
+ // directory.
+ val sparkFilesDir: String = if (isMaster) {
+ Utils.createTempDir().getAbsolutePath
+ } else {
+ "."
+ }
+
// Warn about deprecated spark.cache.class property
if (System.getProperty("spark.cache.class") != null) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@@ -119,12 +125,13 @@ object SparkEnv extends Logging {
actorSystem,
serializer,
closureSerializer,
- cacheTracker,
+ cacheManager,
mapOutputTracker,
shuffleFetcher,
broadcastManager,
blockManager,
connectionManager,
- httpFileServer)
+ httpFileServer,
+ sparkFilesDir)
}
}
diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java
new file mode 100644
index 0000000000..566aec622c
--- /dev/null
+++ b/core/src/main/scala/spark/SparkFiles.java
@@ -0,0 +1,25 @@
+package spark;
+
+import java.io.File;
+
+/**
+ * Resolves paths to files added through `SparkContext.addFile()`.
+ */
+public class SparkFiles {
+
+ private SparkFiles() {}
+
+ /**
+ * Get the absolute path of a file added through `SparkContext.addFile()`.
+ */
+ public static String get(String filename) {
+ return new File(getRootDirectory(), filename).getAbsolutePath();
+ }
+
+ /**
+ * Get the root directory that contains files added through `SparkContext.addFile()`.
+ */
+ public static String getRootDirectory() {
+ return SparkEnv.get().sparkFilesDir();
+ }
+}
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
index d2746b26b3..eab85f85a2 100644
--- a/core/src/main/scala/spark/TaskContext.scala
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -5,8 +5,7 @@ import scala.collection.mutable.ArrayBuffer
class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
- @transient
- val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+ @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
// Add a callback function to be executed on task completion. An example use
// is for HadoopRDD to register a callback to close the input stream.
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 6d64b32174..ae77264372 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -1,7 +1,7 @@
package spark
import java.io._
-import java.net.{NetworkInterface, InetAddress, URL, URI}
+import java.net.{NetworkInterface, InetAddress, Inet4Address, URL, URI}
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
@@ -9,6 +9,8 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.io.Source
+import com.google.common.io.Files
+import com.google.common.util.concurrent.ThreadFactoryBuilder
/**
* Various utility methods used by Spark.
@@ -110,48 +112,56 @@ private object Utils extends Logging {
}
}
- /** Copy a file on the local file system */
- def copyFile(source: File, dest: File) {
- val in = new FileInputStream(source)
- val out = new FileOutputStream(dest)
- copyStream(in, out, true)
- }
-
- /** Download a file from a given URL to the local filesystem */
- def downloadFile(url: URL, localPath: String) {
- val in = url.openStream()
- val out = new FileOutputStream(localPath)
- Utils.copyStream(in, out, true)
- }
-
/**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
+ *
+ * Throws SparkException if the target file already exists and has different contents than
+ * the requested file.
*/
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
+ val tempDir = getLocalDir
+ val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
val targetFile = new File(targetDir, filename)
val uri = new URI(url)
uri.getScheme match {
case "http" | "https" | "ftp" =>
- logInfo("Fetching " + url + " to " + targetFile)
+ logInfo("Fetching " + url + " to " + tempFile)
val in = new URL(url).openStream()
- val out = new FileOutputStream(targetFile)
+ val out = new FileOutputStream(tempFile)
Utils.copyStream(in, out, true)
+ if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
+ tempFile.delete()
+ throw new SparkException("File " + targetFile + " exists and does not match contents of" +
+ " " + url)
+ } else {
+ Files.move(tempFile, targetFile)
+ }
case "file" | null =>
- // Remove the file if it already exists
- targetFile.delete()
- // Symlink the file locally.
- if (uri.isAbsolute) {
- // url is absolute, i.e. it starts with "file:///". Extract the source
- // file's absolute path from the url.
- val sourceFile = new File(uri)
- logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
- FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath)
+ val sourceFile = if (uri.isAbsolute) {
+ new File(uri)
+ } else {
+ new File(url)
+ }
+ if (targetFile.exists && !Files.equal(sourceFile, targetFile)) {
+ throw new SparkException("File " + targetFile + " exists and does not match contents of" +
+ " " + url)
} else {
- // url is not absolute, i.e. itself is the path to the source file.
- logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath)
- FileUtil.symLink(url, targetFile.getAbsolutePath)
+ // Remove the file if it already exists
+ targetFile.delete()
+ // Symlink the file locally.
+ if (uri.isAbsolute) {
+ // url is absolute, i.e. it starts with "file:///". Extract the source
+ // file's absolute path from the url.
+ val sourceFile = new File(uri)
+ logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
+ FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath)
+ } else {
+ // url is not absolute, i.e. itself is the path to the source file.
+ logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath)
+ FileUtil.symLink(url, targetFile.getAbsolutePath)
+ }
}
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
@@ -159,8 +169,15 @@ private object Utils extends Logging {
val conf = new Configuration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
- val out = new FileOutputStream(targetFile)
+ val out = new FileOutputStream(tempFile)
Utils.copyStream(in, out, true)
+ if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
+ tempFile.delete()
+ throw new SparkException("File " + targetFile + " exists and does not match contents of" +
+ " " + url)
+ } else {
+ Files.move(tempFile, targetFile)
+ }
}
// Decompress the file if it's a .tar or .tar.gz
if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
@@ -171,7 +188,16 @@ private object Utils extends Logging {
Utils.execute(Seq("tar", "-xf", filename), targetDir)
}
// Make the file executable - That's necessary for scripts
- FileUtil.chmod(filename, "a+x")
+ FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
+ }
+
+ /**
+ * Get a temporary directory using Spark's spark.local.dir property, if set. This will always
+ * return a single directory, even though the spark.local.dir property might be a list of
+ * multiple paths.
+ */
+ def getLocalDir: String = {
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0)
}
/**
@@ -212,7 +238,8 @@ private object Utils extends Logging {
// Address resolves to something like 127.0.1.1, which happens on Debian; try to find
// a better address using the local network interfaces
for (ni <- NetworkInterface.getNetworkInterfaces) {
- for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && !addr.isLoopbackAddress) {
+ for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress &&
+ !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) {
// We've found an address that looks reasonable!
logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" +
" a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress +
@@ -247,48 +274,28 @@ private object Utils extends Logging {
customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
}
- /**
- * Returns a standard ThreadFactory except all threads are daemons.
- */
- private def newDaemonThreadFactory: ThreadFactory = {
- new ThreadFactory {
- def newThread(r: Runnable): Thread = {
- var t = Executors.defaultThreadFactory.newThread (r)
- t.setDaemon (true)
- return t
- }
- }
- }
+ private[spark] val daemonThreadFactory: ThreadFactory =
+ new ThreadFactoryBuilder().setDaemon(true).build()
/**
* Wrapper over newCachedThreadPool.
*/
- def newDaemonCachedThreadPool(): ThreadPoolExecutor = {
- var threadPool = Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory (newDaemonThreadFactory)
-
- return threadPool
- }
+ def newDaemonCachedThreadPool(): ThreadPoolExecutor =
+ Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
/**
* Return the string to tell how long has passed in seconds. The passing parameter should be in
* millisecond.
*/
def getUsedTimeMs(startTimeMs: Long): String = {
- return " " + (System.currentTimeMillis - startTimeMs) + " ms "
+ return " " + (System.currentTimeMillis - startTimeMs) + " ms"
}
/**
* Wrapper over newFixedThreadPool.
*/
- def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
- var threadPool = Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory(newDaemonThreadFactory)
-
- return threadPool
- }
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor =
+ Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
/**
* Delete a file or directory and its contents recursively.
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index 5c2be534ff..8ce32e0e2f 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -471,6 +471,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
fromRDD(new OrderedRDDFunctions(rdd).sortByKey(ascending))
}
+
+ /**
+ * Return an RDD with the keys of each tuple.
+ */
+ def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1))
+
+ /**
+ * Return an RDD with the values of each tuple.
+ */
+ def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2))
}
object JavaPairRDD {
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 81d3a94466..b3698ffa44 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -9,6 +9,7 @@ import spark.api.java.JavaPairRDD._
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
import spark.partial.{PartialResult, BoundedDouble}
import spark.storage.StorageLevel
+import com.google.common.base.Optional
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
@@ -298,4 +299,36 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Save this RDD as a SequenceFile of serialized objects.
*/
def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path)
+
+ /**
+ * Creates tuples of the elements in this RDD by applying `f`.
+ */
+ def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = {
+ implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ JavaPairRDD.fromRDD(rdd.keyBy(f))
+ }
+
+ /**
+ * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
+ * directory set with SparkContext.setCheckpointDir() and all references to its parent
+ * RDDs will be removed. This function must be called before any job has been
+ * executed on this RDD. It is strongly recommended that this RDD is persisted in
+ * memory, otherwise saving it on a file will require recomputation.
+ */
+ def checkpoint() = rdd.checkpoint()
+
+ /**
+ * Return whether this RDD has been checkpointed or not
+ */
+ def isCheckpointed(): Boolean = rdd.isCheckpointed()
+
+ /**
+ * Gets the name of the file to which this RDD was checkpointed
+ */
+ def getCheckpointFile(): Optional[String] = {
+ rdd.getCheckpointFile match {
+ case Some(file) => Optional.of(file)
+ case _ => Optional.absent()
+ }
+ }
}
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index edbb187b1b..50b8970cd8 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -10,7 +10,7 @@ import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-import spark.{Accumulator, AccumulatorParam, RDD, SparkContext}
+import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext}
import spark.SparkContext.IntAccumulatorParam
import spark.SparkContext.DoubleAccumulatorParam
import spark.broadcast.Broadcast
@@ -265,26 +265,46 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
/**
* Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
- * to using the `+=` method. Only the master can access the accumulator's `value`.
+ * to using the `add` method. Only the master can access the accumulator's `value`.
*/
- def intAccumulator(initialValue: Int): Accumulator[Int] =
- sc.accumulator(initialValue)(IntAccumulatorParam)
+ def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] =
+ sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]]
/**
* Create an [[spark.Accumulator]] double variable, which tasks can "add" values
- * to using the `+=` method. Only the master can access the accumulator's `value`.
+ * to using the `add` method. Only the master can access the accumulator's `value`.
*/
- def doubleAccumulator(initialValue: Double): Accumulator[Double] =
- sc.accumulator(initialValue)(DoubleAccumulatorParam)
+ def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] =
+ sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]]
+
+ /**
+ * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
+ * to using the `add` method. Only the master can access the accumulator's `value`.
+ */
+ def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue)
+
+ /**
+ * Create an [[spark.Accumulator]] double variable, which tasks can "add" values
+ * to using the `add` method. Only the master can access the accumulator's `value`.
+ */
+ def accumulator(initialValue: Double): Accumulator[java.lang.Double] =
+ doubleAccumulator(initialValue)
/**
* Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
- * to using the `+=` method. Only the master can access the accumulator's `value`.
+ * to using the `add` method. Only the master can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] =
sc.accumulator(initialValue)(accumulatorParam)
/**
+ * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can
+ * "add" values with `add`. Only the master can access the accumuable's `value`.
+ */
+ def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] =
+ sc.accumulable(initialValue)(param)
+
+ /**
* Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each cluster only once.
*/
@@ -301,6 +321,75 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* (in that order of preference). If neither of these is set, return None.
*/
def getSparkHome(): Option[String] = sc.getSparkHome()
+
+ /**
+ * Add a file to be downloaded with this Spark job on every node.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+ * use `SparkFiles.get(path)` to find its download location.
+ */
+ def addFile(path: String) {
+ sc.addFile(path)
+ }
+
+ /**
+ * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI.
+ */
+ def addJar(path: String) {
+ sc.addJar(path)
+ }
+
+ /**
+ * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to
+ * any new nodes.
+ */
+ def clearJars() {
+ sc.clearJars()
+ }
+
+ /**
+ * Clear the job's list of files added by `addFile` so that they do not get downloaded to
+ * any new nodes.
+ */
+ def clearFiles() {
+ sc.clearFiles()
+ }
+
+ /**
+ * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse.
+ */
+ def hadoopConfiguration(): Configuration = {
+ sc.hadoopConfiguration
+ }
+
+ /**
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be a HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists and useExisting is set to true, then the
+ * exisiting directory will be used. Otherwise an exception will be thrown to
+ * prevent accidental overriding of checkpoint files in the existing directory.
+ */
+ def setCheckpointDir(dir: String, useExisting: Boolean) {
+ sc.setCheckpointDir(dir, useExisting)
+ }
+
+ /**
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be a HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists, an exception will be thrown to prevent accidental
+ * overriding of checkpoint files.
+ */
+ def setCheckpointDir(dir: String) {
+ sc.setCheckpointDir(dir)
+ }
+
+ protected def checkpointFile[T](path: String): JavaRDD[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ new JavaRDD(sc.checkpointFile(path))
+ }
}
object JavaSparkContext {
diff --git a/core/src/main/scala/spark/api/java/StorageLevels.java b/core/src/main/scala/spark/api/java/StorageLevels.java
index 722af3c06c..5e5845ac3a 100644
--- a/core/src/main/scala/spark/api/java/StorageLevels.java
+++ b/core/src/main/scala/spark/api/java/StorageLevels.java
@@ -17,4 +17,15 @@ public class StorageLevels {
public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2);
public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1);
public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2);
+
+ /**
+ * Create a new StorageLevel object.
+ * @param useDisk saved to disk, if true
+ * @param useMemory saved to memory, if true
+ * @param deserialized saved as deserialized objects, if true
+ * @param replication replication factor
+ */
+ public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) {
+ return StorageLevel.apply(useDisk, useMemory, deserialized, replication);
+ }
}
diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
new file mode 100644
index 0000000000..519e310323
--- /dev/null
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -0,0 +1,48 @@
+package spark.api.python
+
+import spark.Partitioner
+
+import java.util.Arrays
+
+/**
+ * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
+ *
+ * Stores the unique id() of the Python-side partitioning function so that it is incorporated into
+ * equality comparisons. Correctness requires that the id is a unique identifier for the
+ * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning
+ * function). This can be ensured by using the Python id() function and maintaining a reference
+ * to the Python partitioning function so that its id() is not reused.
+ */
+private[spark] class PythonPartitioner(
+ override val numPartitions: Int,
+ val pyPartitionFunctionId: Long)
+ extends Partitioner {
+
+ override def getPartition(key: Any): Int = {
+ if (key == null) {
+ return 0
+ }
+ else {
+ val hashCode = {
+ if (key.isInstanceOf[Array[Byte]]) {
+ Arrays.hashCode(key.asInstanceOf[Array[Byte]])
+ } else {
+ key.hashCode()
+ }
+ }
+ val mod = hashCode % numPartitions
+ if (mod < 0) {
+ mod + numPartitions
+ } else {
+ mod // Guard against negative hash codes
+ }
+ }
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case h: PythonPartitioner =>
+ h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
+ case _ =>
+ false
+ }
+}
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
new file mode 100644
index 0000000000..f43a152ca7
--- /dev/null
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -0,0 +1,293 @@
+package spark.api.python
+
+import java.io._
+import java.net._
+import java.util.{List => JList, ArrayList => JArrayList, Collections}
+
+import scala.collection.JavaConversions._
+import scala.io.Source
+
+import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import spark.broadcast.Broadcast
+import spark._
+import spark.rdd.PipedRDD
+
+
+private[spark] class PythonRDD[T: ClassManifest](
+ parent: RDD[T],
+ command: Seq[String],
+ envVars: java.util.Map[String, String],
+ preservePartitoning: Boolean,
+ pythonExec: String,
+ broadcastVars: JList[Broadcast[Array[Byte]]],
+ accumulator: Accumulator[JList[Array[Byte]]])
+ extends RDD[Array[Byte]](parent) {
+
+ // Similar to Runtime.exec(), if we are given a single string, split it into words
+ // using a standard StringTokenizer (i.e. by spaces)
+ def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
+ preservePartitoning: Boolean, pythonExec: String,
+ broadcastVars: JList[Broadcast[Array[Byte]]],
+ accumulator: Accumulator[JList[Array[Byte]]]) =
+ this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
+ broadcastVars, accumulator)
+
+ override def getSplits = parent.splits
+
+ override val partitioner = if (preservePartitoning) parent.partitioner else None
+
+ override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = {
+ val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
+
+ val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
+ // Add the environmental variables to the process.
+ val currentEnvVars = pb.environment()
+
+ for ((variable, value) <- envVars) {
+ currentEnvVars.put(variable, value)
+ }
+
+ val proc = pb.start()
+ val env = SparkEnv.get
+
+ // Start a thread to print the process's stderr to ours
+ new Thread("stderr reader for " + command) {
+ override def run() {
+ for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
+ System.err.println(line)
+ }
+ }
+ }.start()
+
+ // Start a thread to feed the process input from our parent's iterator
+ new Thread("stdin writer for " + command) {
+ override def run() {
+ SparkEnv.set(env)
+ val out = new PrintWriter(proc.getOutputStream)
+ val dOut = new DataOutputStream(proc.getOutputStream)
+ // Split index
+ dOut.writeInt(split.index)
+ // sparkFilesDir
+ PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
+ // Broadcast variables
+ dOut.writeInt(broadcastVars.length)
+ for (broadcast <- broadcastVars) {
+ dOut.writeLong(broadcast.id)
+ dOut.writeInt(broadcast.value.length)
+ dOut.write(broadcast.value)
+ dOut.flush()
+ }
+ // Serialized user code
+ for (elem <- command) {
+ out.println(elem)
+ }
+ out.flush()
+ // Data values
+ for (elem <- parent.iterator(split, context)) {
+ PythonRDD.writeAsPickle(elem, dOut)
+ }
+ dOut.flush()
+ out.flush()
+ proc.getOutputStream.close()
+ }
+ }.start()
+
+ // Return an iterator that read lines from the process's stdout
+ val stream = new DataInputStream(proc.getInputStream)
+ return new Iterator[Array[Byte]] {
+ def next(): Array[Byte] = {
+ val obj = _nextObj
+ _nextObj = read()
+ obj
+ }
+
+ private def read(): Array[Byte] = {
+ try {
+ val length = stream.readInt()
+ if (length != -1) {
+ val obj = new Array[Byte](length)
+ stream.readFully(obj)
+ obj
+ } else {
+ // We've finished the data section of the output, but we can still read some
+ // accumulator updates; let's do that, breaking when we get EOFException
+ while (true) {
+ val len2 = stream.readInt()
+ val update = new Array[Byte](len2)
+ stream.readFully(update)
+ accumulator += Collections.singletonList(update)
+ }
+ new Array[Byte](0)
+ }
+ } catch {
+ case eof: EOFException => {
+ val exitStatus = proc.waitFor()
+ if (exitStatus != 0) {
+ throw new Exception("Subprocess exited with status " + exitStatus)
+ }
+ new Array[Byte](0)
+ }
+ case e => throw e
+ }
+ }
+
+ var _nextObj = read()
+
+ def hasNext = _nextObj.length != 0
+ }
+ }
+
+ val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+}
+
+/**
+ * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
+ * This is used by PySpark's shuffle operations.
+ */
+private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
+ RDD[(Array[Byte], Array[Byte])](prev) {
+ override def getSplits = prev.splits
+ override def compute(split: Split, context: TaskContext) =
+ prev.iterator(split, context).grouped(2).map {
+ case Seq(a, b) => (a, b)
+ case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
+ }
+ val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
+}
+
+private[spark] object PythonRDD {
+
+ /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
+ def stripPickle(arr: Array[Byte]) : Array[Byte] = {
+ arr.slice(2, arr.length - 1)
+ }
+
+ /**
+ * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
+ * The data format is a 32-bit integer representing the pickled object's length (in bytes),
+ * followed by the pickled data.
+ *
+ * Pickle module:
+ *
+ * http://docs.python.org/2/library/pickle.html
+ *
+ * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
+ *
+ * http://hg.python.org/cpython/file/2.6/Lib/pickle.py
+ * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
+ *
+ * @param elem the object to write
+ * @param dOut a data output stream
+ */
+ def writeAsPickle(elem: Any, dOut: DataOutputStream) {
+ if (elem.isInstanceOf[Array[Byte]]) {
+ val arr = elem.asInstanceOf[Array[Byte]]
+ dOut.writeInt(arr.length)
+ dOut.write(arr)
+ } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
+ val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
+ val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes
+ dOut.writeInt(length)
+ dOut.writeByte(Pickle.PROTO)
+ dOut.writeByte(Pickle.TWO)
+ dOut.write(PythonRDD.stripPickle(t._1))
+ dOut.write(PythonRDD.stripPickle(t._2))
+ dOut.writeByte(Pickle.TUPLE2)
+ dOut.writeByte(Pickle.STOP)
+ } else if (elem.isInstanceOf[String]) {
+ // For uniformity, strings are wrapped into Pickles.
+ val s = elem.asInstanceOf[String].getBytes("UTF-8")
+ val length = 2 + 1 + 4 + s.length + 1
+ dOut.writeInt(length)
+ dOut.writeByte(Pickle.PROTO)
+ dOut.writeByte(Pickle.TWO)
+ dOut.write(Pickle.BINUNICODE)
+ dOut.writeInt(Integer.reverseBytes(s.length))
+ dOut.write(s)
+ dOut.writeByte(Pickle.STOP)
+ } else {
+ throw new Exception("Unexpected RDD type")
+ }
+ }
+
+ def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+ JavaRDD[Array[Byte]] = {
+ val file = new DataInputStream(new FileInputStream(filename))
+ val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+ try {
+ while (true) {
+ val length = file.readInt()
+ val obj = new Array[Byte](length)
+ file.readFully(obj)
+ objs.append(obj)
+ }
+ } catch {
+ case eof: EOFException => {}
+ case e => throw e
+ }
+ JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+ }
+
+ def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+ val file = new DataOutputStream(new FileOutputStream(filename))
+ for (item <- items) {
+ writeAsPickle(item, file)
+ }
+ file.close()
+ }
+
+ def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] =
+ rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head
+}
+
+private object Pickle {
+ val PROTO: Byte = 0x80.toByte
+ val TWO: Byte = 0x02.toByte
+ val BINUNICODE: Byte = 'X'
+ val STOP: Byte = '.'
+ val TUPLE2: Byte = 0x86.toByte
+ val EMPTY_LIST: Byte = ']'
+ val MARK: Byte = '('
+ val APPENDS: Byte = 'e'
+}
+
+private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
+ override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
+}
+
+/**
+ * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * collects a list of pickled strings that we pass to Python through a socket.
+ */
+class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+ extends AccumulatorParam[JList[Array[Byte]]] {
+
+ override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+
+ override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
+ : JList[Array[Byte]] = {
+ if (serverHost == null) {
+ // This happens on the worker node, where we just want to remember all the updates
+ val1.addAll(val2)
+ val1
+ } else {
+ // This happens on the master, where we pass the updates to Python through a socket
+ val socket = new Socket(serverHost, serverPort)
+ val in = socket.getInputStream
+ val out = new DataOutputStream(socket.getOutputStream)
+ out.writeInt(val2.size)
+ for (array <- val2) {
+ out.writeInt(array.length)
+ out.write(array)
+ }
+ out.flush()
+ // Wait for a byte from the Python side as an acknowledgement
+ val byteRead = in.read()
+ if (byteRead == -1) {
+ throw new SparkException("EOF reached before Python server acknowledged")
+ }
+ socket.close()
+ null
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala
index 6055bfd045..2ffe7f741d 100644
--- a/core/src/main/scala/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/spark/broadcast/Broadcast.scala
@@ -5,7 +5,7 @@ import java.util.concurrent.atomic.AtomicLong
import spark._
-abstract class Broadcast[T](id: Long) extends Serializable {
+abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
def value: T
// We cannot have an abstract readObject here due to some weird issues with
diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
index 7eb4ddb74f..8e490e6bad 100644
--- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
@@ -11,6 +11,7 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark._
import spark.storage.StorageLevel
+import util.{MetadataCleaner, TimeStampedHashSet}
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
@@ -64,6 +65,10 @@ private object HttpBroadcast extends Logging {
private var serverUri: String = null
private var server: HttpServer = null
+ private val files = new TimeStampedHashSet[String]
+ private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup)
+
+
def initialize(isMaster: Boolean) {
synchronized {
if (!initialized) {
@@ -85,11 +90,12 @@ private object HttpBroadcast extends Logging {
server = null
}
initialized = false
+ cleaner.cancel()
}
}
private def createServer() {
- broadcastDir = Utils.createTempDir()
+ broadcastDir = Utils.createTempDir(Utils.getLocalDir)
server = new HttpServer(broadcastDir)
server.start()
serverUri = server.uri
@@ -108,6 +114,7 @@ private object HttpBroadcast extends Logging {
val serOut = ser.serializeStream(out)
serOut.writeObject(value)
serOut.close()
+ files += file.getAbsolutePath
}
def read[T](id: Long): T = {
@@ -123,4 +130,21 @@ private object HttpBroadcast extends Logging {
serIn.close()
obj
}
+
+ def cleanup(cleanupTime: Long) {
+ val iterator = files.internalMap.entrySet().iterator()
+ while(iterator.hasNext) {
+ val entry = iterator.next()
+ val (file, time) = (entry.getKey, entry.getValue)
+ if (time < cleanupTime) {
+ try {
+ iterator.remove()
+ new File(file.toString).delete()
+ logInfo("Deleted broadcast file '" + file + "'")
+ } catch {
+ case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
+ }
+ }
+ }
+ }
}
diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala
index 457122745b..35f40c6e91 100644
--- a/core/src/main/scala/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/spark/deploy/DeployMessage.scala
@@ -4,7 +4,6 @@ import spark.deploy.ExecutorState.ExecutorState
import spark.deploy.master.{WorkerInfo, JobInfo}
import spark.deploy.worker.ExecutorRunner
import scala.collection.immutable.List
-import scala.collection.mutable.HashMap
private[spark] sealed trait DeployMessage extends Serializable
@@ -42,7 +41,8 @@ private[spark] case class LaunchExecutor(
execId: Int,
jobDesc: JobDescription,
cores: Int,
- memory: Int)
+ memory: Int,
+ sparkHome: String)
extends DeployMessage
diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala
index 20879c5f11..7160fc05fc 100644
--- a/core/src/main/scala/spark/deploy/JobDescription.scala
+++ b/core/src/main/scala/spark/deploy/JobDescription.scala
@@ -4,7 +4,8 @@ private[spark] class JobDescription(
val name: String,
val cores: Int,
val memoryPerSlave: Int,
- val command: Command)
+ val command: Command,
+ val sparkHome: String)
extends Serializable {
val user = System.getProperty("user.name", "<unknown>")
diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala
new file mode 100644
index 0000000000..732fa08064
--- /dev/null
+++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala
@@ -0,0 +1,78 @@
+package spark.deploy
+
+import master.{JobInfo, WorkerInfo}
+import worker.ExecutorRunner
+import cc.spray.json._
+
+/**
+ * spray-json helper class containing implicit conversion to json for marshalling responses
+ */
+private[spark] object JsonProtocol extends DefaultJsonProtocol {
+ implicit object WorkerInfoJsonFormat extends RootJsonWriter[WorkerInfo] {
+ def write(obj: WorkerInfo) = JsObject(
+ "id" -> JsString(obj.id),
+ "host" -> JsString(obj.host),
+ "webuiaddress" -> JsString(obj.webUiAddress),
+ "cores" -> JsNumber(obj.cores),
+ "coresused" -> JsNumber(obj.coresUsed),
+ "memory" -> JsNumber(obj.memory),
+ "memoryused" -> JsNumber(obj.memoryUsed)
+ )
+ }
+
+ implicit object JobInfoJsonFormat extends RootJsonWriter[JobInfo] {
+ def write(obj: JobInfo) = JsObject(
+ "starttime" -> JsNumber(obj.startTime),
+ "id" -> JsString(obj.id),
+ "name" -> JsString(obj.desc.name),
+ "cores" -> JsNumber(obj.desc.cores),
+ "user" -> JsString(obj.desc.user),
+ "memoryperslave" -> JsNumber(obj.desc.memoryPerSlave),
+ "submitdate" -> JsString(obj.submitDate.toString))
+ }
+
+ implicit object JobDescriptionJsonFormat extends RootJsonWriter[JobDescription] {
+ def write(obj: JobDescription) = JsObject(
+ "name" -> JsString(obj.name),
+ "cores" -> JsNumber(obj.cores),
+ "memoryperslave" -> JsNumber(obj.memoryPerSlave),
+ "user" -> JsString(obj.user)
+ )
+ }
+
+ implicit object ExecutorRunnerJsonFormat extends RootJsonWriter[ExecutorRunner] {
+ def write(obj: ExecutorRunner) = JsObject(
+ "id" -> JsNumber(obj.execId),
+ "memory" -> JsNumber(obj.memory),
+ "jobid" -> JsString(obj.jobId),
+ "jobdesc" -> obj.jobDesc.toJson.asJsObject
+ )
+ }
+
+ implicit object MasterStateJsonFormat extends RootJsonWriter[MasterState] {
+ def write(obj: MasterState) = JsObject(
+ "url" -> JsString("spark://" + obj.uri),
+ "workers" -> JsArray(obj.workers.toList.map(_.toJson)),
+ "cores" -> JsNumber(obj.workers.map(_.cores).sum),
+ "coresused" -> JsNumber(obj.workers.map(_.coresUsed).sum),
+ "memory" -> JsNumber(obj.workers.map(_.memory).sum),
+ "memoryused" -> JsNumber(obj.workers.map(_.memoryUsed).sum),
+ "activejobs" -> JsArray(obj.activeJobs.toList.map(_.toJson)),
+ "completedjobs" -> JsArray(obj.completedJobs.toList.map(_.toJson))
+ )
+ }
+
+ implicit object WorkerStateJsonFormat extends RootJsonWriter[WorkerState] {
+ def write(obj: WorkerState) = JsObject(
+ "id" -> JsString(obj.workerId),
+ "masterurl" -> JsString(obj.masterUrl),
+ "masterwebuiurl" -> JsString(obj.masterWebUiUrl),
+ "cores" -> JsNumber(obj.cores),
+ "coresused" -> JsNumber(obj.coresUsed),
+ "memory" -> JsNumber(obj.memory),
+ "memoryused" -> JsNumber(obj.memoryUsed),
+ "executors" -> JsArray(obj.executors.toList.map(_.toJson)),
+ "finishedexecutors" -> JsArray(obj.finishedExecutors.toList.map(_.toJson))
+ )
+ }
+}
diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala
index 57a7e123b7..8764c400e2 100644
--- a/core/src/main/scala/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/spark/deploy/client/TestClient.scala
@@ -25,7 +25,7 @@ private[spark] object TestClient {
val url = args(0)
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0)
val desc = new JobDescription(
- "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()))
+ "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home")
val listener = new TestListener
val client = new Client(actorSystem, url, desc, listener)
client.start()
diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala
index b30c8e99b5..2c2cd0231b 100644
--- a/core/src/main/scala/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/spark/deploy/master/Master.scala
@@ -156,7 +156,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
if (spreadOutJobs) {
// Try to spread out each job among all the nodes, until it has all its cores
for (job <- waitingJobs if job.coresLeft > 0) {
- val usableWorkers = workers.toArray.filter(canUse(job, _)).sortBy(_.coresFree).reverse
+ val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE)
+ .filter(canUse(job, _)).sortBy(_.coresFree).reverse
val numUsable = usableWorkers.length
val assigned = new Array[Int](numUsable) // Number of cores to give on each node
var toAssign = math.min(job.coresLeft, usableWorkers.map(_.coresFree).sum)
@@ -172,7 +173,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
for (pos <- 0 until numUsable) {
if (assigned(pos) > 0) {
val exec = job.addExecutor(usableWorkers(pos), assigned(pos))
- launchExecutor(usableWorkers(pos), exec)
+ launchExecutor(usableWorkers(pos), exec, job.desc.sparkHome)
job.state = JobState.RUNNING
}
}
@@ -185,7 +186,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
val coresToUse = math.min(worker.coresFree, job.coresLeft)
if (coresToUse > 0) {
val exec = job.addExecutor(worker, coresToUse)
- launchExecutor(worker, exec)
+ launchExecutor(worker, exec, job.desc.sparkHome)
job.state = JobState.RUNNING
}
}
@@ -194,15 +195,17 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
}
- def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) {
+ def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
- worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory)
+ worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome)
exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
}
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
publicAddress: String): WorkerInfo = {
+ // There may be one or more refs to dead workers on this same node (w/ different ID's), remove them.
+ workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
workers += worker
idToWorker(worker.id) = worker
@@ -213,7 +216,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
def removeWorker(worker: WorkerInfo) {
logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port)
- workers -= worker
+ worker.setState(WorkerState.DEAD)
idToWorker -= worker.id
actorToWorker -= worker.actor
addressToWorker -= worker.actor.path.address
diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
index 3cdd3721f5..458ee2d665 100644
--- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
@@ -8,7 +8,11 @@ import akka.util.duration._
import cc.spray.Directives
import cc.spray.directives._
import cc.spray.typeconversion.TwirlSupport._
+import cc.spray.http.MediaTypes
+import cc.spray.typeconversion.SprayJsonSupport._
+
import spark.deploy._
+import spark.deploy.JsonProtocol._
private[spark]
class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives {
@@ -19,29 +23,51 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
val handler = {
get {
- path("") {
- completeWith {
+ (path("") & parameters('format ?)) {
+ case Some(js) if js.equalsIgnoreCase("json") =>
val future = master ? RequestMasterState
- future.map {
- masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState])
+ respondWithMediaType(MediaTypes.`application/json`) { ctx =>
+ ctx.complete(future.mapTo[MasterState])
+ }
+ case _ =>
+ completeWith {
+ val future = master ? RequestMasterState
+ future.map {
+ masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState])
+ }
}
- }
} ~
path("job") {
- parameter("jobId") { jobId =>
- completeWith {
+ parameters("jobId", 'format ?) {
+ case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) =>
val future = master ? RequestMasterState
- future.map { state =>
- val masterState = state.asInstanceOf[MasterState]
-
- // A bit ugly an inefficient, but we won't have a number of jobs
- // so large that it will make a significant difference.
- (masterState.activeJobs ++ masterState.completedJobs).find(_.id == jobId) match {
- case Some(job) => spark.deploy.master.html.job_details.render(job)
- case _ => null
+ val jobInfo = for (masterState <- future.mapTo[MasterState]) yield {
+ masterState.activeJobs.find(_.id == jobId) match {
+ case Some(job) => job
+ case _ => masterState.completedJobs.find(_.id == jobId) match {
+ case Some(job) => job
+ case _ => null
+ }
+ }
+ }
+ respondWithMediaType(MediaTypes.`application/json`) { ctx =>
+ ctx.complete(jobInfo.mapTo[JobInfo])
+ }
+ case (jobId, _) =>
+ completeWith {
+ val future = master ? RequestMasterState
+ future.map { state =>
+ val masterState = state.asInstanceOf[MasterState]
+
+ masterState.activeJobs.find(_.id == jobId) match {
+ case Some(job) => spark.deploy.master.html.job_details.render(job)
+ case _ => masterState.completedJobs.find(_.id == jobId) match {
+ case Some(job) => spark.deploy.master.html.job_details.render(job)
+ case _ => null
+ }
+ }
}
}
- }
}
} ~
pathPrefix("static") {
diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
index a0a698ef04..5a7f5fef8a 100644
--- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
@@ -14,7 +14,7 @@ private[spark] class WorkerInfo(
val publicAddress: String) {
var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
-
+ var state: WorkerState.Value = WorkerState.ALIVE
var coresUsed = 0
var memoryUsed = 0
@@ -42,4 +42,8 @@ private[spark] class WorkerInfo(
def webUiAddress : String = {
"http://" + this.publicAddress + ":" + this.webUiPort
}
+
+ def setState(state: WorkerState.Value) = {
+ this.state = state
+ }
}
diff --git a/core/src/main/scala/spark/deploy/master/WorkerState.scala b/core/src/main/scala/spark/deploy/master/WorkerState.scala
new file mode 100644
index 0000000000..0bf35014c8
--- /dev/null
+++ b/core/src/main/scala/spark/deploy/master/WorkerState.scala
@@ -0,0 +1,7 @@
+package spark.deploy.master
+
+private[spark] object WorkerState extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED") {
+ type WorkerState = Value
+
+ val ALIVE, DEAD, DECOMMISSIONED = Value
+}
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index beceb55ecd..0d1fe2a6b4 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -106,11 +106,6 @@ private[spark] class ExecutorRunner(
throw new IOException("Failed to create directory " + executorDir)
}
- // Download the files it depends on into it (disabled for now)
- //for (url <- jobDesc.fileUrls) {
- // fetchFile(url, executorDir)
- //}
-
// Launch the process
val command = buildCommandSeq()
val builder = new ProcessBuilder(command: _*).directory(executorDir)
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 7c9e588ea2..19bf2be118 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -119,10 +119,10 @@ private[spark] class Worker(
logError("Worker registration failed: " + message)
System.exit(1)
- case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_) =>
+ case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) =>
logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name))
val manager = new ExecutorRunner(
- jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, sparkHome, workDir)
+ jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
executors(jobId + "/" + execId) = manager
manager.start()
coresUsed += cores_
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
index 340920025b..37524a7c82 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
@@ -104,9 +104,25 @@ private[spark] class WorkerArguments(args: Array[String]) {
}
def inferDefaultMemory(): Int = {
- val bean = ManagementFactory.getOperatingSystemMXBean
- .asInstanceOf[com.sun.management.OperatingSystemMXBean]
- val totalMb = (bean.getTotalPhysicalMemorySize / 1024 / 1024).toInt
+ val ibmVendor = System.getProperty("java.vendor").contains("IBM")
+ var totalMb = 0
+ try {
+ val bean = ManagementFactory.getOperatingSystemMXBean()
+ if (ibmVendor) {
+ val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean")
+ val method = beanClass.getDeclaredMethod("getTotalPhysicalMemory")
+ totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt
+ } else {
+ val beanClass = Class.forName("com.sun.management.OperatingSystemMXBean")
+ val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize")
+ totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt
+ }
+ } catch {
+ case e: Exception => {
+ totalMb = 2*1024
+ System.out.println("Failed to get total physical memory. Using " + totalMb + " MB")
+ }
+ }
// Leave out 1 GB for the operating system, but don't return a negative memory size
math.max(totalMb - 1024, 512)
}
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
index d06f4884ee..f9489d99fc 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
@@ -7,7 +7,11 @@ import akka.util.Timeout
import akka.util.duration._
import cc.spray.Directives
import cc.spray.typeconversion.TwirlSupport._
+import cc.spray.http.MediaTypes
+import cc.spray.typeconversion.SprayJsonSupport._
+
import spark.deploy.{WorkerState, RequestWorkerState}
+import spark.deploy.JsonProtocol._
private[spark]
class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives {
@@ -18,13 +22,20 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
val handler = {
get {
- path("") {
- completeWith{
+ (path("") & parameters('format ?)) {
+ case Some(js) if js.equalsIgnoreCase("json") => {
val future = worker ? RequestWorkerState
- future.map { workerState =>
- spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState])
+ respondWithMediaType(MediaTypes.`application/json`) { ctx =>
+ ctx.complete(future.mapTo[WorkerState])
}
}
+ case _ =>
+ completeWith{
+ val future = worker ? RequestWorkerState
+ future.map { workerState =>
+ spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState])
+ }
+ }
} ~
path("log") {
parameters("jobId", "executorId", "logType") { (jobId, executorId, logType) =>
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index 2552958d27..28d9d40d43 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -159,22 +159,24 @@ private[spark] class Executor extends Logging {
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
- // Fetch missing dependencies
- for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
- currentFiles(name) = timestamp
- }
- for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
- currentJars(name) = timestamp
- // Add it to our class loader
- val localName = name.split("/").last
- val url = new File(".", localName).toURI.toURL
- if (!urlClassLoader.getURLs.contains(url)) {
- logInfo("Adding " + url + " to class loader")
- urlClassLoader.addURL(url)
+ synchronized {
+ // Fetch missing dependencies
+ for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentFiles(name) = timestamp
+ }
+ for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val localName = name.split("/").last
+ val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
+ if (!urlClassLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ urlClassLoader.addURL(url)
+ }
}
}
}
diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
index 915f71ba9f..a29bf974d2 100644
--- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
@@ -24,9 +24,6 @@ private[spark] class StandaloneExecutorBackend(
with ExecutorBackend
with Logging {
- val threadPool = new ThreadPoolExecutor(
- 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
-
var master: ActorRef = null
override def preStart() {
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index 80262ab7b4..c193bf7c8d 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -135,8 +135,11 @@ extends Connection(SocketChannel.open, selector_) {
val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
messages += message // this is probably incorrect, it wont work as fifo
- if (!message.started) logDebug("Starting to send [" + message + "]")
- message.started = true
+ if (!message.started) {
+ logDebug("Starting to send [" + message + "]")
+ message.started = true
+ message.startTime = System.currentTimeMillis
+ }
return chunk
} else {
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index 642fa4b525..2ecd14f536 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -43,18 +43,17 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
val selector = SelectorProvider.provider.openSelector()
- val handleMessageExecutor = Executors.newFixedThreadPool(4)
+ val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt)
val serverChannel = ServerSocketChannel.open()
val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val messageStatuses = new HashMap[Int, MessageStatus]
- val connectionRequests = new SynchronizedQueue[SendingConnection]
+ val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
val sendMessageRequests = new Queue[(Message, SendingConnection)]
- implicit val futureExecContext = ExecutionContext.fromExecutor(
- Executors.newCachedThreadPool(DaemonThreadFactory))
-
+ implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
+
var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
serverChannel.configureBlocking(false)
@@ -79,10 +78,10 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def run() {
try {
while(!selectorThread.isInterrupted) {
- while(!connectionRequests.isEmpty) {
- val sendingConnection = connectionRequests.dequeue
+ for( (connectionManagerId, sendingConnection) <- connectionRequests) {
sendingConnection.connect()
addConnection(sendingConnection)
+ connectionRequests -= connectionManagerId
}
sendMessageRequests.synchronized {
while(!sendMessageRequests.isEmpty) {
@@ -300,8 +299,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
- val newConnection = new SendingConnection(inetSocketAddress, selector)
- connectionRequests += newConnection
+ val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector))
newConnection
}
val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
@@ -473,6 +471,7 @@ private[spark] object ConnectionManager {
val mb = size * count / 1024.0 / 1024.0
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
+ println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------")
println()
}
diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
index 47ceaf3c07..533e4610f3 100644
--- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala
+++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
@@ -13,8 +13,14 @@ import akka.util.duration._
private[spark] object ConnectionManagerTest extends Logging{
def main(args: Array[String]) {
+ //<mesos cluster> - the master URL
+ //<slaves file> - a list slaves to run connectionTest on
+ //[num of tasks] - the number of parallel tasks to be initiated default is number of slave hosts
+ //[size of msg in MB (integer)] - the size of messages to be sent in each task, default is 10
+ //[count] - how many times to run, default is 3
+ //[await time in seconds] : await time (in seconds), default is 600
if (args.length < 2) {
- println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>")
+ println("Usage: ConnectionManagerTest <mesos cluster> <slaves file> [num of tasks] [size of msg in MB (integer)] [count] [await time in seconds)] ")
System.exit(1)
}
@@ -29,16 +35,19 @@ private[spark] object ConnectionManagerTest extends Logging{
/*println("Slaves")*/
/*slaves.foreach(println)*/
-
- val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map(
+ val tasknum = if (args.length > 2) args(2).toInt else slaves.length
+ val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024
+ val count = if (args.length > 4) args(4).toInt else 3
+ val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second
+ println("Running "+count+" rounds of test: " + "parallel tasks = " + tasknum + ", msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime)
+ val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map(
i => SparkEnv.get.connectionManager.id).collect()
println("\nSlave ConnectionManagerIds")
slaveConnManagerIds.foreach(println)
println
- val count = 10
(0 until count).foreach(i => {
- val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => {
+ val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => {
val connManager = SparkEnv.get.connectionManager
val thisConnManagerId = connManager.id
connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
@@ -46,7 +55,6 @@ private[spark] object ConnectionManagerTest extends Logging{
None
})
- val size = 100 * 1024 * 1024
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
@@ -56,13 +64,13 @@ private[spark] object ConnectionManagerTest extends Logging{
logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
})
- val results = futures.map(f => Await.result(f, 1.second))
+ val results = futures.map(f => Await.result(f, awaitTime))
val finishTime = System.currentTimeMillis
Thread.sleep(5000)
val mb = size * results.size / 1024.0 / 1024.0
val ms = finishTime - startTime
- val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
+ val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
logInfo(resultStr)
resultStr
}).collect()
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index f98528a183..2c022f88e0 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -1,9 +1,7 @@
package spark.rdd
import scala.collection.mutable.HashMap
-
-import spark.{Dependency, RDD, SparkContext, SparkEnv, Split, TaskContext}
-
+import spark.{RDD, SparkContext, SparkEnv, Split, TaskContext}
private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
val index = idx
@@ -11,22 +9,20 @@ private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split
private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
- extends RDD[T](sc) {
+ extends RDD[T](sc, Nil) {
- @transient
- val splits_ = (0 until blockIds.size).map(i => {
+ @transient var splits_ : Array[Split] = (0 until blockIds.size).map(i => {
new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
}).toArray
- @transient
- lazy val locations_ = {
+ @transient lazy val locations_ = {
val blockManager = SparkEnv.get.blockManager
/*val locations = blockIds.map(id => blockManager.getLocations(id))*/
val locations = blockManager.getLocations(blockIds)
HashMap(blockIds.zip(locations):_*)
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(split: Split, context: TaskContext): Iterator[T] = {
val blockManager = SparkEnv.get.blockManager
@@ -38,9 +34,11 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
}
}
- override def preferredLocations(split: Split) =
+ override def getPreferredLocations(split: Split) =
locations_(split.asInstanceOf[BlockRDDSplit].blockId)
- override val dependencies: List[Dependency[_]] = Nil
+ override def clearDependencies() {
+ splits_ = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index 4a7e5f3d06..453d410ad4 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -1,37 +1,53 @@
package spark.rdd
-import spark.{NarrowDependency, RDD, SparkContext, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+import spark.{OneToOneDependency, NarrowDependency, RDD, SparkContext, Split, TaskContext}
private[spark]
-class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable {
+class CartesianSplit(
+ idx: Int,
+ @transient rdd1: RDD[_],
+ @transient rdd2: RDD[_],
+ s1Index: Int,
+ s2Index: Int
+ ) extends Split {
+ var s1 = rdd1.splits(s1Index)
+ var s2 = rdd2.splits(s2Index)
override val index: Int = idx
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ s1 = rdd1.splits(s1Index)
+ s2 = rdd2.splits(s2Index)
+ oos.defaultWriteObject()
+ }
}
private[spark]
class CartesianRDD[T: ClassManifest, U:ClassManifest](
sc: SparkContext,
- rdd1: RDD[T],
- rdd2: RDD[U])
- extends RDD[Pair[T, U]](sc)
+ var rdd1 : RDD[T],
+ var rdd2 : RDD[U])
+ extends RDD[Pair[T, U]](sc, Nil)
with Serializable {
val numSplitsInRdd2 = rdd2.splits.size
- @transient
- val splits_ = {
+ @transient var splits_ = {
// create the cross product split
val array = new Array[Split](rdd1.splits.size * rdd2.splits.size)
for (s1 <- rdd1.splits; s2 <- rdd2.splits) {
val idx = s1.index * numSplitsInRdd2 + s2.index
- array(idx) = new CartesianSplit(idx, s1, s2)
+ array(idx) = new CartesianSplit(idx, rdd1, rdd2, s1.index, s2.index)
}
array
}
- override def splits = splits_
+ override def getSplits = splits_
- override def preferredLocations(split: Split) = {
+ override def getPreferredLocations(split: Split) = {
val currSplit = split.asInstanceOf[CartesianSplit]
rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
}
@@ -42,7 +58,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}
- override val dependencies = List(
+ var deps_ = List(
new NarrowDependency(rdd1) {
def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
},
@@ -50,4 +66,13 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2)
}
)
+
+ override def getDependencies = deps_
+
+ override def clearDependencies() {
+ deps_ = Nil
+ splits_ = null
+ rdd1 = null
+ rdd2 = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
new file mode 100644
index 0000000000..6f00f6ac73
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -0,0 +1,128 @@
+package spark.rdd
+
+import spark._
+import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.{NullWritable, BytesWritable}
+import org.apache.hadoop.util.ReflectionUtils
+import org.apache.hadoop.fs.Path
+import java.io.{File, IOException, EOFException}
+import java.text.NumberFormat
+
+private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split {
+ override val index: Int = idx
+}
+
+/**
+ * This RDD represents a RDD checkpoint file (similar to HadoopRDD).
+ */
+private[spark]
+class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
+ extends RDD[T](sc, Nil) {
+
+ @transient val path = new Path(checkpointPath)
+ @transient val fs = path.getFileSystem(new Configuration())
+
+ @transient val splits_ : Array[Split] = {
+ val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted
+ splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray
+ }
+
+ checkpointData = Some(new RDDCheckpointData[T](this))
+ checkpointData.get.cpFile = Some(checkpointPath)
+
+ override def getSplits = splits_
+
+ override def getPreferredLocations(split: Split): Seq[String] = {
+ val status = fs.getFileStatus(path)
+ val locations = fs.getFileBlockLocations(status, 0, status.getLen)
+ locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
+ }
+
+ override def compute(split: Split, context: TaskContext): Iterator[T] = {
+ CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile, context)
+ }
+
+ override def checkpoint() {
+ // Do nothing. Hadoop RDD should not be checkpointed.
+ }
+}
+
+private[spark] object CheckpointRDD extends Logging {
+
+ def splitIdToFileName(splitId: Int): String = {
+ val numfmt = NumberFormat.getInstance()
+ numfmt.setMinimumIntegerDigits(5)
+ numfmt.setGroupingUsed(false)
+ "part-" + numfmt.format(splitId)
+ }
+
+ def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) {
+ val outputDir = new Path(path)
+ val fs = outputDir.getFileSystem(new Configuration())
+
+ val finalOutputName = splitIdToFileName(context.splitId)
+ val finalOutputPath = new Path(outputDir, finalOutputName)
+ val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId)
+
+ if (fs.exists(tempOutputPath)) {
+ throw new IOException("Checkpoint failed: temporary path " +
+ tempOutputPath + " already exists")
+ }
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
+ val fileOutputStream = if (blockSize < 0) {
+ fs.create(tempOutputPath, false, bufferSize)
+ } else {
+ // This is mainly for testing purpose
+ fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
+ }
+ val serializer = SparkEnv.get.serializer.newInstance()
+ val serializeStream = serializer.serializeStream(fileOutputStream)
+ serializeStream.writeAll(iterator)
+ serializeStream.close()
+
+ if (!fs.rename(tempOutputPath, finalOutputPath)) {
+ if (!fs.delete(finalOutputPath, true)) {
+ throw new IOException("Checkpoint failed: failed to delete earlier output of task "
+ + context.attemptId)
+ }
+ if (!fs.rename(tempOutputPath, finalOutputPath)) {
+ throw new IOException("Checkpoint failed: failed to save output of task: "
+ + context.attemptId)
+ }
+ }
+ }
+
+ def readFromFile[T](path: String, context: TaskContext): Iterator[T] = {
+ val inputPath = new Path(path)
+ val fs = inputPath.getFileSystem(new Configuration())
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val fileInputStream = fs.open(inputPath, bufferSize)
+ val serializer = SparkEnv.get.serializer.newInstance()
+ val deserializeStream = serializer.deserializeStream(fileInputStream)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback(() => deserializeStream.close())
+
+ deserializeStream.asIterator.asInstanceOf[Iterator[T]]
+ }
+
+ // Test whether CheckpointRDD generate expected number of splits despite
+ // each split file having multiple blocks. This needs to be run on a
+ // cluster (mesos or standalone) using HDFS.
+ def main(args: Array[String]) {
+ import spark._
+
+ val Array(cluster, hdfsPath) = args
+ val sc = new SparkContext(cluster, "CheckpointRDD Test")
+ val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
+ val path = new Path(hdfsPath, "temp")
+ val fs = path.getFileSystem(new Configuration())
+ sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
+ val cpRDD = new CheckpointRDD[Int](sc, path.toString)
+ assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same")
+ assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same")
+ fs.delete(path)
+ }
+}
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index de0d9fad88..8fafd27bb6 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -1,14 +1,30 @@
package spark.rdd
+import java.io.{ObjectOutputStream, IOException}
+import java.util.{HashMap => JHashMap}
+import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
private[spark] sealed trait CoGroupSplitDep extends Serializable
-private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
+
+private[spark] case class NarrowCoGroupSplitDep(
+ rdd: RDD[_],
+ splitIndex: Int,
+ var split: Split
+ ) extends CoGroupSplitDep {
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ split = rdd.splits(splitIndex)
+ oos.defaultWriteObject()
+ }
+}
+
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
private[spark]
@@ -24,30 +40,29 @@ private[spark] class CoGroupAggregator
{ (b1, b2) => b1 ++ b2 })
with Serializable
-class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
- extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
+class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
+ extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging {
val aggr = new CoGroupAggregator
- @transient
- override val dependencies = {
+ @transient var deps_ = {
val deps = new ArrayBuffer[Dependency[_]]
for ((rdd, index) <- rdds.zipWithIndex) {
- val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
- if (mapSideCombinedRDD.partitioner == Some(part)) {
- logInfo("Adding one-to-one dependency with " + mapSideCombinedRDD)
- deps += new OneToOneDependency(mapSideCombinedRDD)
+ if (rdd.partitioner == Some(part)) {
+ logInfo("Adding one-to-one dependency with " + rdd)
+ deps += new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
+ val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
}
}
deps.toList
}
- @transient
- val splits_ : Array[Split] = {
- val firstRdd = rdds.head
+ override def getDependencies = deps_
+
+ @transient var splits_ : Array[Split] = {
val array = new Array[Split](part.numPartitions)
for (i <- 0 until array.size) {
array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) =>
@@ -55,28 +70,33 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
case s: ShuffleDependency[_, _] =>
new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep
case _ =>
- new NarrowCoGroupSplitDep(r, r.splits(i)): CoGroupSplitDep
+ new NarrowCoGroupSplitDep(r, i, r.splits(i)): CoGroupSplitDep
}
}.toList)
}
array
}
- override def splits = splits_
-
+ override def getSplits = splits_
+
override val partitioner = Some(part)
- override def preferredLocations(s: Split) = Nil
-
override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit]
val numRdds = split.deps.size
- val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
+ val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
- map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any]))
+ val seq = map.get(k)
+ if (seq != null) {
+ seq
+ } else {
+ val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
+ map.put(k, seq)
+ seq
+ }
}
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
- case NarrowCoGroupSplitDep(rdd, itsSplit) => {
+ case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => {
// Read them from the parent
for ((k, v) <- rdd.iterator(itsSplit, context)) {
getSeq(k.asInstanceOf[K])(depNum) += v
@@ -93,6 +113,12 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair)
}
}
- map.iterator
+ JavaConversions.mapAsScalaMap(map).iterator
+ }
+
+ override def clearDependencies() {
+ deps_ = null
+ splits_ = null
+ rdds = null
}
}
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index 1affe0e0ef..167755bbba 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -1,9 +1,22 @@
package spark.rdd
-import spark.{NarrowDependency, RDD, Split, TaskContext}
+import spark.{Dependency, OneToOneDependency, NarrowDependency, RDD, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+private[spark] case class CoalescedRDDSplit(
+ index: Int,
+ @transient rdd: RDD[_],
+ parentsIndices: Array[Int]
+ ) extends Split {
+ var parents: Seq[Split] = parentsIndices.map(rdd.splits(_))
-private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ parents = parentsIndices.map(rdd.splits(_))
+ oos.defaultWriteObject()
+ }
+}
/**
* Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of
@@ -13,34 +26,44 @@ private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) exten
* This transformation is useful when an RDD with many partitions gets filtered into a smaller one,
* or to avoid having a large number of small tasks when processing a directory with many files.
*/
-class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int)
- extends RDD[T](prev.context) {
+class CoalescedRDD[T: ClassManifest](
+ var prev: RDD[T],
+ maxPartitions: Int)
+ extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs
- @transient val splits_ : Array[Split] = {
+ @transient var splits_ : Array[Split] = {
val prevSplits = prev.splits
if (prevSplits.length < maxPartitions) {
- prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) }
+ prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) }
} else {
(0 until maxPartitions).map { i =>
val rangeStart = (i * prevSplits.length) / maxPartitions
val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions
- new CoalescedRDDSplit(i, prevSplits.slice(rangeStart, rangeEnd))
+ new CoalescedRDDSplit(i, prev, (rangeStart until rangeEnd).toArray)
}.toArray
}
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(split: Split, context: TaskContext): Iterator[T] = {
- split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap {
- parentSplit => prev.iterator(parentSplit, context)
+ split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit =>
+ firstParent[T].iterator(parentSplit, context)
}
}
- val dependencies = List(
+ var deps_ : List[Dependency[_]] = List(
new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] =
- splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index)
+ splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices
}
)
+
+ override def getDependencies() = deps_
+
+ override def clearDependencies() {
+ deps_ = Nil
+ splits_ = null
+ prev = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
index b148da28de..6dbe235bd9 100644
--- a/core/src/main/scala/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -2,10 +2,15 @@ package spark.rdd
import spark.{OneToOneDependency, RDD, Split, TaskContext}
+private[spark] class FilteredRDD[T: ClassManifest](
+ prev: RDD[T],
+ f: T => Boolean)
+ extends RDD[T](prev) {
-private[spark]
-class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).filter(f)
-} \ No newline at end of file
+ override def getSplits = firstParent[T].splits
+
+ override val partitioner = prev.partitioner // Since filter cannot change a partition's keys
+
+ override def compute(split: Split, context: TaskContext) =
+ firstParent[T].iterator(split, context).filter(f)
+}
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
index 785662b2da..1b604c66e2 100644
--- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -1,16 +1,16 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
+
private[spark]
class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: T => TraversableOnce[U])
- extends RDD[U](prev.context) {
+ extends RDD[U](prev) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
+ override def getSplits = firstParent[T].splits
override def compute(split: Split, context: TaskContext) =
- prev.iterator(split, context).flatMap(f)
+ firstParent[T].iterator(split, context).flatMap(f)
}
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
index fac8ffb4cb..051bffed19 100644
--- a/core/src/main/scala/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -1,12 +1,12 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
+private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T])
+ extends RDD[Array[T]](prev) {
+
+ override def getSplits = firstParent[T].splits
-private[spark]
-class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split, context: TaskContext) =
- Array(prev.iterator(split, context).toArray).iterator
-} \ No newline at end of file
+ Array(firstParent[T].iterator(split, context).toArray).iterator
+}
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index ab163f569b..f547f53812 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -22,9 +22,8 @@ import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskCo
* A Spark split class that wraps around a Hadoop InputSplit.
*/
private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
- extends Split
- with Serializable {
-
+ extends Split {
+
val inputSplit = new SerializableWritable[InputSplit](s)
override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
@@ -43,7 +42,7 @@ class HadoopRDD[K, V](
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int)
- extends RDD[(K, V)](sc) {
+ extends RDD[(K, V)](sc, Nil) {
// A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it
val confBroadcast = sc.broadcast(new SerializableWritable(conf))
@@ -64,7 +63,7 @@ class HadoopRDD[K, V](
.asInstanceOf[InputFormat[K, V]]
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[HadoopSplit]
@@ -110,11 +109,13 @@ class HadoopRDD[K, V](
}
}
- override def preferredLocations(split: Split) = {
+ override def getPreferredLocations(split: Split) = {
// TODO: Filtering out "localhost" in case of file:// URLs
val hadoopSplit = split.asInstanceOf[HadoopSplit]
hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
}
- override val dependencies: List[Dependency[_]] = Nil
+ override def checkpoint() {
+ // Do nothing. Hadoop RDD should not be checkpointed.
+ }
}
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
index c764505345..073f7d7d2a 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -1,6 +1,6 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
private[spark]
@@ -8,11 +8,13 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: Iterator[T] => Iterator[U],
preservesPartitioning: Boolean = false)
- extends RDD[U](prev.context) {
+ extends RDD[U](prev) {
- override val partitioner = if (preservesPartitioning) prev.partitioner else None
+ override val partitioner =
+ if (preservesPartitioning) firstParent[T].partitioner else None
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split, context: TaskContext) = f(prev.iterator(split, context))
+ override def getSplits = firstParent[T].splits
+
+ override def compute(split: Split, context: TaskContext) =
+ f(firstParent[T].iterator(split, context))
} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
index 3d9888bd34..2ddc3d01b6 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -1,6 +1,7 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
+
/**
* A variant of the MapPartitionsRDD that passes the split index into the
@@ -11,12 +12,13 @@ private[spark]
class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean)
- extends RDD[U](prev.context) {
+ preservesPartitioning: Boolean
+ ) extends RDD[U](prev) {
+
+ override def getSplits = firstParent[T].splits
override val partitioner = if (preservesPartitioning) prev.partitioner else None
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
+
override def compute(split: Split, context: TaskContext) =
- f(split.index, prev.iterator(split, context))
+ f(split.index, firstParent[T].iterator(split, context))
} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index 70fa8f4497..c6ceb272cd 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -1,14 +1,15 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
private[spark]
class MappedRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: T => U)
- extends RDD[U](prev.context) {
+ extends RDD[U](prev) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).map(f)
+ override def getSplits = firstParent[T].splits
+
+ override def compute(split: Split, context: TaskContext) =
+ firstParent[T].iterator(split, context).map(f)
} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index 197ed5ea17..c3b155fcbd 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -20,11 +20,12 @@ class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit
}
class NewHadoopRDD[K, V](
- sc: SparkContext,
+ sc : SparkContext,
inputFormatClass: Class[_ <: InputFormat[K, V]],
- keyClass: Class[K], valueClass: Class[V],
+ keyClass: Class[K],
+ valueClass: Class[V],
@transient conf: Configuration)
- extends RDD[(K, V)](sc)
+ extends RDD[(K, V)](sc, Nil)
with HadoopMapReduceUtil {
// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
@@ -36,11 +37,9 @@ class NewHadoopRDD[K, V](
formatter.format(new Date())
}
- @transient
- private val jobId = new JobID(jobtrackerId, id)
+ @transient private val jobId = new JobID(jobtrackerId, id)
- @transient
- private val splits_ : Array[Split] = {
+ @transient private val splits_ : Array[Split] = {
val inputFormat = inputFormatClass.newInstance
val jobContext = newJobContext(conf, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray
@@ -51,7 +50,7 @@ class NewHadoopRDD[K, V](
result
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopSplit]
@@ -86,10 +85,8 @@ class NewHadoopRDD[K, V](
}
}
- override def preferredLocations(split: Split) = {
+ override def getPreferredLocations(split: Split) = {
val theSplit = split.asInstanceOf[NewHadoopSplit]
theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
}
-
- override val dependencies: List[Dependency[_]] = Nil
}
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index 336e193217..6631f83510 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -8,7 +8,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
-import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext}
+import spark.{RDD, SparkEnv, Split, TaskContext}
/**
@@ -16,18 +16,18 @@ import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext}
* (printing them one per line) and returns the output as a collection of strings.
*/
class PipedRDD[T: ClassManifest](
- parent: RDD[T], command: Seq[String], envVars: Map[String, String])
- extends RDD[String](parent.context) {
+ prev: RDD[T],
+ command: Seq[String],
+ envVars: Map[String, String])
+ extends RDD[String](prev) {
- def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map())
+ def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map())
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
- def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command))
+ def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command))
- override def splits = parent.splits
-
- override val dependencies = List(new OneToOneDependency(parent))
+ override def getSplits = firstParent[T].splits
override def compute(split: Split, context: TaskContext): Iterator[String] = {
val pb = new ProcessBuilder(command)
@@ -52,7 +52,7 @@ class PipedRDD[T: ClassManifest](
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
- for (elem <- parent.iterator(split, context)) {
+ for (elem <- firstParent[T].iterator(split, context)) {
out.println(elem)
}
out.close()
diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index 6e4797aabb..e24ad23b21 100644
--- a/core/src/main/scala/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -1,11 +1,11 @@
package spark.rdd
import java.util.Random
+
import cern.jet.random.Poisson
import cern.jet.random.engine.DRand
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
-
+import spark.{RDD, Split, TaskContext}
private[spark]
class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
@@ -14,23 +14,20 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali
class SampledRDD[T: ClassManifest](
prev: RDD[T],
- withReplacement: Boolean,
+ withReplacement: Boolean,
frac: Double,
seed: Int)
- extends RDD[T](prev.context) {
+ extends RDD[T](prev) {
- @transient
- val splits_ = {
+ @transient var splits_ : Array[Split] = {
val rg = new Random(seed)
- prev.splits.map(x => new SampledRDDSplit(x, rg.nextInt))
+ firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt))
}
- override def splits = splits_.asInstanceOf[Array[Split]]
-
- override val dependencies = List(new OneToOneDependency(prev))
+ override def getSplits = splits_
- override def preferredLocations(split: Split) =
- prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
+ override def getPreferredLocations(split: Split) =
+ firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
override def compute(splitIn: Split, context: TaskContext) = {
val split = splitIn.asInstanceOf[SampledRDDSplit]
@@ -38,7 +35,7 @@ class SampledRDD[T: ClassManifest](
// For large datasets, the expected number of occurrences of each element in a sample with
// replacement is Poisson(frac). We use that to get a count for each element.
val poisson = new Poisson(frac, new DRand(split.seed))
- prev.iterator(split.prev, context).flatMap { element =>
+ firstParent[T].iterator(split.prev, context).flatMap { element =>
val count = poisson.nextInt()
if (count == 0) {
Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
@@ -48,7 +45,11 @@ class SampledRDD[T: ClassManifest](
}
} else { // Sampling without replacement
val rand = new Random(split.seed)
- prev.iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
+ firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
}
}
+
+ override def clearDependencies() {
+ splits_ = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index f832633646..28ff19876d 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -1,7 +1,7 @@
package spark.rdd
-import spark.{OneToOneDependency, Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext}
-
+import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext}
+import spark.SparkContext._
private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
override val index = idx
@@ -10,28 +10,28 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
/**
* The resulting RDD from a shuffle (e.g. repartitioning of data).
- * @param parent the parent RDD.
+ * @param prev the parent RDD.
* @param part the partitioner used to partition the RDD
* @tparam K the key class.
* @tparam V the value class.
*/
class ShuffledRDD[K, V](
- @transient parent: RDD[(K, V)],
- part: Partitioner) extends RDD[(K, V)](parent.context) {
+ prev: RDD[(K, V)],
+ part: Partitioner)
+ extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) {
override val partitioner = Some(part)
- @transient
- val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
-
- override def splits = splits_
-
- override def preferredLocations(split: Split) = Nil
+ @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
- val dep = new ShuffleDependency(parent, part)
- override val dependencies = List(dep)
+ override def getSplits = splits_
override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
- SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index)
+ val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
+ SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
+ }
+
+ override def clearDependencies() {
+ splits_ = null
}
}
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index a08473f7be..82f0a44ecd 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -1,43 +1,46 @@
package spark.rdd
import scala.collection.mutable.ArrayBuffer
-
import spark.{Dependency, RangeDependency, RDD, SparkContext, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int)
+ extends Split {
-private[spark] class UnionSplit[T: ClassManifest](
- idx: Int,
- rdd: RDD[T],
- split: Split)
- extends Split
- with Serializable {
+ var split: Split = rdd.splits(splitIndex)
def iterator(context: TaskContext) = rdd.iterator(split, context)
+
def preferredLocations() = rdd.preferredLocations(split)
+
override val index: Int = idx
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ split = rdd.splits(splitIndex)
+ oos.defaultWriteObject()
+ }
}
class UnionRDD[T: ClassManifest](
sc: SparkContext,
- @transient rdds: Seq[RDD[T]])
- extends RDD[T](sc)
- with Serializable {
+ @transient var rdds: Seq[RDD[T]])
+ extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs
- @transient
- val splits_ : Array[Split] = {
+ @transient var splits_ : Array[Split] = {
val array = new Array[Split](rdds.map(_.splits.size).sum)
var pos = 0
for (rdd <- rdds; split <- rdd.splits) {
- array(pos) = new UnionSplit(pos, rdd, split)
+ array(pos) = new UnionSplit(pos, rdd, split.index)
pos += 1
}
array
}
- override def splits = splits_
+ override def getSplits = splits_
- @transient
- override val dependencies = {
+ @transient var deps_ = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for (rdd <- rdds) {
@@ -47,9 +50,17 @@ class UnionRDD[T: ClassManifest](
deps.toList
}
+ override def getDependencies = deps_
+
override def compute(s: Split, context: TaskContext): Iterator[T] =
s.asInstanceOf[UnionSplit[T]].iterator(context)
- override def preferredLocations(s: Split): Seq[String] =
+ override def getPreferredLocations(s: Split): Seq[String] =
s.asInstanceOf[UnionSplit[T]].preferredLocations()
+
+ override def clearDependencies() {
+ deps_ = null
+ splits_ = null
+ rdds = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index 92d667ff1e..d950b06c85 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -1,53 +1,65 @@
package spark.rdd
import spark.{OneToOneDependency, RDD, SparkContext, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
idx: Int,
- rdd1: RDD[T],
- rdd2: RDD[U],
- split1: Split,
- split2: Split)
- extends Split
- with Serializable {
+ @transient rdd1: RDD[T],
+ @transient rdd2: RDD[U]
+ ) extends Split {
- def iterator(context: TaskContext): Iterator[(T, U)] =
- rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
+ var split1 = rdd1.splits(idx)
+ var split2 = rdd1.splits(idx)
+ override val index: Int = idx
- def preferredLocations(): Seq[String] =
- rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+ def splits = (split1, split2)
- override val index: Int = idx
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ split1 = rdd1.splits(idx)
+ split2 = rdd2.splits(idx)
+ oos.defaultWriteObject()
+ }
}
class ZippedRDD[T: ClassManifest, U: ClassManifest](
sc: SparkContext,
- @transient rdd1: RDD[T],
- @transient rdd2: RDD[U])
- extends RDD[(T, U)](sc)
+ var rdd1: RDD[T],
+ var rdd2: RDD[U])
+ extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)))
with Serializable {
- @transient
- val splits_ : Array[Split] = {
+ // TODO: FIX THIS.
+
+ @transient var splits_ : Array[Split] = {
if (rdd1.splits.size != rdd2.splits.size) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
}
val array = new Array[Split](rdd1.splits.size)
for (i <- 0 until rdd1.splits.size) {
- array(i) = new ZippedSplit(i, rdd1, rdd2, rdd1.splits(i), rdd2.splits(i))
+ array(i) = new ZippedSplit(i, rdd1, rdd2)
}
array
}
- override def splits = splits_
+ override def getSplits = splits_
- @transient
- override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
+ override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = {
+ val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
+ rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
+ }
- override def compute(s: Split, context: TaskContext): Iterator[(T, U)] =
- s.asInstanceOf[ZippedSplit[T, U]].iterator(context)
+ override def getPreferredLocations(s: Split): Seq[String] = {
+ val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
+ rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+ }
- override def preferredLocations(s: Split): Seq[String] =
- s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()
+ override def clearDependencies() {
+ splits_ = null
+ rdd1 = null
+ rdd2 = null
+ }
}
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 29757b1178..b320be8863 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -14,6 +14,7 @@ import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
import spark.storage.BlockManagerMaster
import spark.storage.BlockManagerId
+import util.{MetadataCleaner, TimeStampedHashMap}
/**
* A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
@@ -61,28 +62,35 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val nextStageId = new AtomicInteger(0)
- val idToStage = new HashMap[Int, Stage]
+ val idToStage = new TimeStampedHashMap[Int, Stage]
- val shuffleToMapStage = new HashMap[Int, Stage]
+ val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
var cacheLocs = new HashMap[Int, Array[List[String]]]
val env = SparkEnv.get
- val cacheTracker = env.cacheTracker
val mapOutputTracker = env.mapOutputTracker
+ val blockManagerMaster = env.blockManager.master
- val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back;
- // that's not going to be a realistic assumption in general
+ // For tracking failed nodes, we use the MapOutputTracker's generation number, which is
+ // sent with every task. When we detect a node failing, we note the current generation number
+ // and failed host, increment it for new tasks, and use this to ignore stray ShuffleMapTask
+ // results.
+ // TODO: Garbage collect information about failure generations when we know there are no more
+ // stray messages to detect.
+ val failedGeneration = new HashMap[String, Long]
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now
val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
- val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
+ val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits
val activeJobs = new HashSet[ActiveJob]
val resultStageToJob = new HashMap[Stage, ActiveJob]
+ val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
+
// Start a thread to run the DAGScheduler event loop
new Thread("DAGScheduler") {
setDaemon(true)
@@ -92,11 +100,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}.start()
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
+ if (!cacheLocs.contains(rdd.id)) {
+ val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
+ cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
+ locations => locations.map(_.ip).toList
+ }.toArray
+ }
cacheLocs(rdd.id)
}
- def updateCacheLocs() {
- cacheLocs = cacheTracker.getLocationsSnapshot()
+ def clearCacheLocs() {
+ cacheLocs.clear
}
/**
@@ -123,7 +137,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of splits is unknown
logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
- cacheTracker.registerRDD(rdd.id, rdd.splits.size)
if (shuffleDep != None) {
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
}
@@ -145,8 +158,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visited += r
// Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of splits is unknown
- logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")")
- cacheTracker.registerRDD(r.id, r.splits.size)
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
@@ -247,7 +258,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId)
val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
- updateCacheLocs()
+ clearCacheLocs()
logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
" output partitions")
logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
@@ -290,7 +301,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// on the failed node.
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
logInfo("Resubmitting failed stages")
- updateCacheLocs()
+ clearCacheLocs()
val failed2 = failed.toArray
failed.clear()
for (stage <- failed2.sortBy(_.priority)) {
@@ -426,7 +437,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val status = event.result.asInstanceOf[MapStatus]
val host = status.address.ip
logInfo("ShuffleMapTask finished with host " + host)
- if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos
+ if (failedGeneration.contains(host) && smt.generation <= failedGeneration(host)) {
+ logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + host)
+ } else {
stage.addOutputLoc(smt.partition, status)
}
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
@@ -436,11 +449,18 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
logInfo("waiting: " + waiting)
logInfo("failed: " + failed)
if (stage.shuffleDep != None) {
+ // We supply true to increment the generation number here in case this is a
+ // recomputation of the map outputs. In that case, some nodes may have cached
+ // locations with holes (from when we detected the error) and will need the
+ // generation incremented to refetch them.
+ // TODO: Only increment the generation number if this is not the first time
+ // we registered these map outputs.
mapOutputTracker.registerMapOutputs(
stage.shuffleDep.get.shuffleId,
- stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray)
+ stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
+ true)
}
- updateCacheLocs()
+ clearCacheLocs()
if (stage.outputLocs.count(_ == Nil) != 0) {
// Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this
@@ -492,7 +512,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
// TODO: mark the host as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
- handleHostLost(bmAddress.ip)
+ handleHostLost(bmAddress.ip, Some(task.generation))
}
case other =>
@@ -504,11 +524,15 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
/**
* Responds to a host being lost. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use hostLost() to post a host lost event from outside.
+ *
+ * Optionally the generation during which the failure was caught can be passed to avoid allowing
+ * stray fetch failures from possibly retriggering the detection of a node as lost.
*/
- def handleHostLost(host: String) {
- if (!deadHosts.contains(host)) {
- logInfo("Host lost: " + host)
- deadHosts += host
+ def handleHostLost(host: String, maybeGeneration: Option[Long] = None) {
+ val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration)
+ if (!failedGeneration.contains(host) || failedGeneration(host) < currentGeneration) {
+ failedGeneration(host) = currentGeneration
+ logInfo("Host lost: " + host + " (generation " + currentGeneration + ")")
env.blockManager.master.notifyADeadHost(host)
// TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) {
@@ -516,8 +540,13 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
mapOutputTracker.registerMapOutputs(shuffleId, locs, true)
}
- cacheTracker.cacheLost(host)
- updateCacheLocs()
+ if (shuffleToMapStage.isEmpty) {
+ mapOutputTracker.incrementGeneration()
+ }
+ clearCacheLocs()
+ } else {
+ logDebug("Additional host lost message for " + host +
+ "(generation " + currentGeneration + ")")
}
}
@@ -594,8 +623,23 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
return Nil
}
+ def cleanup(cleanupTime: Long) {
+ var sizeBefore = idToStage.size
+ idToStage.clearOldValues(cleanupTime)
+ logInfo("idToStage " + sizeBefore + " --> " + idToStage.size)
+
+ sizeBefore = shuffleToMapStage.size
+ shuffleToMapStage.clearOldValues(cleanupTime)
+ logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size)
+
+ sizeBefore = pendingTasks.size
+ pendingTasks.clearOldValues(cleanupTime)
+ logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size)
+ }
+
def stop() {
eventQueue.put(StopDAGScheduler)
+ metadataCleaner.cancel()
taskSched.stop()
}
}
diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala
index 4532d9497f..fae643f3a8 100644
--- a/core/src/main/scala/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/spark/scheduler/MapStatus.scala
@@ -20,7 +20,7 @@ private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes:
}
def readExternal(in: ObjectInput) {
- address = new BlockManagerId(in)
+ address = BlockManagerId(in)
compressedSizes = new Array[Byte](in.readInt())
in.readFully(compressedSizes)
}
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index e492279b4e..8cd4c661eb 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -1,26 +1,112 @@
package spark.scheduler
import spark._
+import java.io._
+import util.{MetadataCleaner, TimeStampedHashMap}
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+private[spark] object ResultTask {
+
+ // A simple map between the stage id to the serialized byte array of a task.
+ // Served as a cache for task serialization because serialization can be
+ // expensive on the master node if it needs to launch thousands of tasks.
+ val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
+
+ val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues)
+
+ def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
+ synchronized {
+ val old = serializedInfoCache.get(stageId).orNull
+ if (old != null) {
+ return old
+ } else {
+ val out = new ByteArrayOutputStream
+ val ser = SparkEnv.get.closureSerializer.newInstance
+ val objOut = ser.serializeStream(new GZIPOutputStream(out))
+ objOut.writeObject(rdd)
+ objOut.writeObject(func)
+ objOut.close()
+ val bytes = out.toByteArray
+ serializedInfoCache.put(stageId, bytes)
+ return bytes
+ }
+ }
+ }
+
+ def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
+ synchronized {
+ val loader = Thread.currentThread.getContextClassLoader
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val ser = SparkEnv.get.closureSerializer.newInstance
+ val objIn = ser.deserializeStream(in)
+ val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+ val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
+ return (rdd, func)
+ }
+ }
+
+ def clearCache() {
+ synchronized {
+ serializedInfoCache.clear()
+ }
+ }
+}
+
private[spark] class ResultTask[T, U](
stageId: Int,
- rdd: RDD[T],
- func: (TaskContext, Iterator[T]) => U,
- val partition: Int,
+ var rdd: RDD[T],
+ var func: (TaskContext, Iterator[T]) => U,
+ var partition: Int,
@transient locs: Seq[String],
val outputId: Int)
- extends Task[U](stageId) {
+ extends Task[U](stageId) with Externalizable {
- val split = rdd.splits(partition)
+ def this() = this(0, null, null, 0, null, 0)
+
+ var split = if (rdd == null) {
+ null
+ } else {
+ rdd.splits(partition)
+ }
override def run(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId)
- val result = func(context, rdd.iterator(split, context))
- context.executeOnCompleteCallbacks()
- result
+ try {
+ func(context, rdd.iterator(split, context))
+ } finally {
+ context.executeOnCompleteCallbacks()
+ }
}
override def preferredLocations: Seq[String] = locs
override def toString = "ResultTask(" + stageId + ", " + partition + ")"
+
+ override def writeExternal(out: ObjectOutput) {
+ RDDCheckpointData.synchronized {
+ split = rdd.splits(partition)
+ out.writeInt(stageId)
+ val bytes = ResultTask.serializeInfo(
+ stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ out.writeInt(partition)
+ out.writeInt(outputId)
+ out.writeObject(split)
+ }
+ }
+
+ override def readExternal(in: ObjectInput) {
+ val stageId = in.readInt()
+ val numBytes = in.readInt()
+ val bytes = new Array[Byte](numBytes)
+ in.readFully(bytes)
+ val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
+ rdd = rdd_.asInstanceOf[RDD[T]]
+ func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
+ partition = in.readInt()
+ val outputId = in.readInt()
+ split = in.readObject().asInstanceOf[Split]
+ }
}
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index bd1911fce2..19f5328eee 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -14,17 +14,20 @@ import com.ning.compress.lzf.LZFOutputStream
import spark._
import spark.storage._
+import util.{TimeStampedHashMap, MetadataCleaner}
private[spark] object ShuffleMapTask {
// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
- val serializedInfoCache = new JHashMap[Int, Array[Byte]]
+ val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
+
+ val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues)
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
synchronized {
- val old = serializedInfoCache.get(stageId)
+ val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
return old
} else {
@@ -87,13 +90,16 @@ private[spark] class ShuffleMapTask(
}
override def writeExternal(out: ObjectOutput) {
- out.writeInt(stageId)
- val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
- out.writeInt(bytes.length)
- out.write(bytes)
- out.writeInt(partition)
- out.writeLong(generation)
- out.writeObject(split)
+ RDDCheckpointData.synchronized {
+ split = rdd.splits(partition)
+ out.writeInt(stageId)
+ val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ out.writeInt(partition)
+ out.writeLong(generation)
+ out.writeObject(split)
+ }
}
override def readExternal(in: ObjectInput) {
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 20f6e65020..a639b72795 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -252,19 +252,24 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
def slaveLost(slaveId: String, reason: ExecutorLossReason) {
var failedHost: Option[String] = None
synchronized {
- val host = slaveIdToHost(slaveId)
- if (hostsAlive.contains(host)) {
- logError("Lost an executor on " + host + ": " + reason)
- slaveIdsWithExecutors -= slaveId
- hostsAlive -= host
- activeTaskSetsQueue.foreach(_.hostLost(host))
- failedHost = Some(host)
- } else {
- // We may get multiple slaveLost() calls with different loss reasons. For example, one
- // may be triggered by a dropped connection from the slave while another may be a report
- // of executor termination from Mesos. We produce log messages for both so we eventually
- // report the termination reason.
- logError("Lost an executor on " + host + " (already removed): " + reason)
+ slaveIdToHost.get(slaveId) match {
+ case Some(host) =>
+ if (hostsAlive.contains(host)) {
+ logError("Lost an executor on " + host + ": " + reason)
+ slaveIdsWithExecutors -= slaveId
+ hostsAlive -= host
+ activeTaskSetsQueue.foreach(_.hostLost(host))
+ failedHost = Some(host)
+ } else {
+ // We may get multiple slaveLost() calls with different loss reasons. For example, one
+ // may be triggered by a dropped connection from the slave while another may be a report
+ // of executor termination from Mesos. We produce log messages for both so we eventually
+ // report the termination reason.
+ logError("Lost an executor on " + host + " (already removed): " + reason)
+ }
+ case None =>
+ // We were told about a slave being lost before we could even allocate work to it
+ logError("Lost slave " + slaveId + " (no work assigned yet)")
}
}
if (failedHost != None) {
diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index e2301347e5..4f82cd96dd 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -39,7 +39,8 @@ private[spark] class SparkDeploySchedulerBackend(
StandaloneSchedulerBackend.ACTOR_NAME)
val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}")
val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
- val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command)
+ val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone"))
+ val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome)
client = new Client(sc.env.actorSystem, master, jobDesc, this)
client.start()
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index cf4aae03a7..a089b71644 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -201,7 +201,11 @@ private[spark] class TaskSetManager(
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
val preferred = isPreferredLocation(task, host)
- val prefStr = if (preferred) "preferred" else "non-preferred"
+ val prefStr = if (preferred) {
+ "preferred"
+ } else {
+ "non-preferred, not one of " + task.preferredLocations.mkString(", ")
+ }
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
taskSet.id, index, taskId, slaveId, host, prefStr))
// Do various bookkeeping
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index eb20fe41b2..9ff7c02097 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -20,7 +20,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
with Logging {
var attemptId = new AtomicInteger(0)
- var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
+ var threadPool = Utils.newDaemonFixedThreadPool(threads)
val env = SparkEnv.get
var listener: TaskSchedulerListener = null
@@ -81,7 +81,10 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
logInfo("Finished task " + idInJob)
- listener.taskEnded(task, Success, resultToReturn, accumUpdates)
+
+ // If the threadpool has not already been shutdown, notify DAGScheduler
+ if (!Thread.currentThread().isInterrupted)
+ listener.taskEnded(task, Success, resultToReturn, accumUpdates)
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
@@ -91,7 +94,8 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
submitTask(task, idInJob)
} else {
// TODO: Do something nicer here to return all the way to the user
- listener.taskEnded(task, new ExceptionFailure(t), null, null)
+ if (!Thread.currentThread().isInterrupted)
+ listener.taskEnded(task, new ExceptionFailure(t), null, null)
}
}
}
@@ -108,22 +112,24 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
- // Fetch missing dependencies
- for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
- currentFiles(name) = timestamp
- }
- for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
- currentJars(name) = timestamp
- // Add it to our class loader
- val localName = name.split("/").last
- val url = new File(".", localName).toURI.toURL
- if (!classLoader.getURLs.contains(url)) {
- logInfo("Adding " + url + " to class loader")
- classLoader.addURL(url)
+ synchronized {
+ // Fetch missing dependencies
+ for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentFiles(name) = timestamp
+ }
+ for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val localName = name.split("/").last
+ val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
+ if (!classLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ classLoader.addURL(url)
+ }
}
}
}
diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
index c45c7df69c..014906b028 100644
--- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
@@ -64,13 +64,9 @@ private[spark] class CoarseMesosSchedulerBackend(
val taskIdToSlaveId = new HashMap[Int, String]
val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed
- val sparkHome = sc.getSparkHome() match {
- case Some(path) =>
- path
- case None =>
- throw new SparkException("Spark home is not set; set it through the spark.home system " +
- "property, the SPARK_HOME environment variable or the SparkContext constructor")
- }
+ val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
+ "Spark home is not set; set it through the spark.home system " +
+ "property, the SPARK_HOME environment variable or the SparkContext constructor"))
val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt
@@ -184,7 +180,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Helper function to pull out a resource from a Mesos Resources protobuf */
- def getResource(res: JList[Resource], name: String): Double = {
+ private def getResource(res: JList[Resource], name: String): Double = {
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
@@ -193,7 +189,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Build a Mesos resource protobuf object */
- def createResource(resourceName: String, quantity: Double): Protos.Resource = {
+ private def createResource(resourceName: String, quantity: Double): Protos.Resource = {
Resource.newBuilder()
.setName(resourceName)
.setType(Value.Type.SCALAR)
@@ -202,7 +198,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Check whether a Mesos task state represents a finished task */
- def isFinished(state: MesosTaskState) = {
+ private def isFinished(state: MesosTaskState) = {
state == MesosTaskState.TASK_FINISHED ||
state == MesosTaskState.TASK_FAILED ||
state == MesosTaskState.TASK_KILLED ||
diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
index 8c7a1dfbc0..2989e31f5e 100644
--- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
@@ -76,13 +76,9 @@ private[spark] class MesosSchedulerBackend(
}
def createExecutorInfo(): ExecutorInfo = {
- val sparkHome = sc.getSparkHome() match {
- case Some(path) =>
- path
- case None =>
- throw new SparkException("Spark home is not set; set it through the spark.home system " +
- "property, the SPARK_HOME environment variable or the SparkContext constructor")
- }
+ val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
+ "Spark home is not set; set it through the spark.home system " +
+ "property, the SPARK_HOME environment variable or the SparkContext constructor"))
val execScript = new File(sparkHome, "spark-executor").getCanonicalPath
val environment = Environment.newBuilder()
sc.executorEnvs.foreach { case (key, value) =>
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index df295b1820..19cdaaa984 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -1,59 +1,39 @@
package spark.storage
-import akka.actor.{ActorSystem, Cancellable}
-import akka.dispatch.{Await, Future}
-import akka.util.Duration
-import akka.util.duration._
-
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-
-import java.io.{InputStream, OutputStream, Externalizable, ObjectInput, ObjectOutput}
-import java.nio.{MappedByteBuffer, ByteBuffer}
+import java.io.{InputStream, OutputStream}
+import java.nio.{ByteBuffer, MappedByteBuffer}
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
import scala.collection.JavaConversions._
-import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils}
-import spark.network._
-import spark.serializer.Serializer
-import spark.util.ByteBufferInputStream
-import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
-import sun.nio.ch.DirectBuffer
-
-
-private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
- def this() = this(null, 0) // For deserialization only
-
- def this(in: ObjectInput) = this(in.readUTF(), in.readInt())
+import akka.actor.{ActorSystem, Cancellable, Props}
+import akka.dispatch.{Await, Future}
+import akka.util.Duration
+import akka.util.duration._
- override def writeExternal(out: ObjectOutput) {
- out.writeUTF(ip)
- out.writeInt(port)
- }
+import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
- override def readExternal(in: ObjectInput) {
- ip = in.readUTF()
- port = in.readInt()
- }
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
- override def toString = "BlockManagerId(" + ip + ", " + port + ")"
+import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils}
+import spark.network._
+import spark.serializer.Serializer
+import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap}
- override def hashCode = ip.hashCode * 41 + port
+import sun.nio.ch.DirectBuffer
- override def equals(that: Any) = that match {
- case id: BlockManagerId => port == id.port && ip == id.ip
- case _ => false
- }
-}
private[spark]
case class BlockException(blockId: String, message: String, ex: Exception = null)
extends Exception(message)
private[spark]
-class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
- val serializer: Serializer, maxMemory: Long)
+class BlockManager(
+ actorSystem: ActorSystem,
+ val master: BlockManagerMaster,
+ val serializer: Serializer,
+ maxMemory: Long)
extends Logging {
class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
@@ -79,7 +59,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
}
}
- private val blockInfo = new ConcurrentHashMap[String, BlockInfo](1000)
+ private val blockInfo = new TimeStampedHashMap[String, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
private[storage] val diskStore: BlockStore =
@@ -89,10 +69,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
implicit val futureExecContext = connectionManager.futureExecContext
val connectionManagerId = connectionManager.id
- val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port)
-
- // TODO: This will be removed after cacheTracker is removed from the code base.
- var cacheTracker: CacheTracker = null
+ val blockManagerId = BlockManagerId(connectionManagerId.host, connectionManagerId.port)
// Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
// for receiving shuffle outputs)
@@ -110,16 +87,20 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
val host = System.getProperty("spark.hostname", Utils.localHostName())
+ val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
+ name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
+
@volatile private var shuttingDown = false
private def heartBeat() {
- if (!master.mustHeartBeat(HeartBeat(blockManagerId))) {
+ if (!master.sendHeartBeat(blockManagerId)) {
reregister()
}
}
var heartBeatTask: Cancellable = null
+ val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks)
initialize()
/**
@@ -134,8 +115,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
* BlockManagerWorker actor.
*/
private def initialize() {
- master.mustRegisterBlockManager(
- RegisterBlockManager(blockManagerId, maxMemory))
+ master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
BlockManagerWorker.startBlockManagerWorker(this)
if (!BlockManager.getDisableHeartBeatsForTesting) {
heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) {
@@ -156,8 +136,8 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
*/
private def reportAllBlocks() {
logInfo("Reporting " + blockInfo.size + " blocks to the master.")
- for (blockId <- blockInfo.keys) {
- if (!tryToReportBlockStatus(blockId)) {
+ for ((blockId, info) <- blockInfo) {
+ if (!tryToReportBlockStatus(blockId, info)) {
logError("Failed to report " + blockId + " to master; giving up.")
return
}
@@ -171,26 +151,22 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
def reregister() {
// TODO: We might need to rate limit reregistering.
logInfo("BlockManager reregistering with master")
- master.mustRegisterBlockManager(
- RegisterBlockManager(blockManagerId, maxMemory))
+ master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
reportAllBlocks()
}
/**
* Get storage level of local block. If no info exists for the block, then returns null.
*/
- def getLevel(blockId: String): StorageLevel = {
- val info = blockInfo.get(blockId)
- if (info != null) info.level else null
- }
+ def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
/**
* Tell the master about the current storage status of a block. This will send a block update
* message reflecting the current status, *not* the desired storage level in its block info.
* For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk.
*/
- def reportBlockStatus(blockId: String) {
- val needReregister = !tryToReportBlockStatus(blockId)
+ def reportBlockStatus(blockId: String, info: BlockInfo) {
+ val needReregister = !tryToReportBlockStatus(blockId, info)
if (needReregister) {
logInfo("Got told to reregister updating block " + blockId)
// Reregistering will report our new block for free.
@@ -200,33 +176,27 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
}
/**
- * Actually send a BlockUpdate message. Returns the mater's response, which will be true if the
- * block was successfully recorded and false if the slave needs to re-register.
+ * Actually send a UpdateBlockInfo message. Returns the mater's response,
+ * which will be true if the block was successfully recorded and false if
+ * the slave needs to re-register.
*/
- private def tryToReportBlockStatus(blockId: String): Boolean = {
- val (curLevel, inMemSize, onDiskSize, tellMaster) = blockInfo.get(blockId) match {
- case null =>
- (StorageLevel.NONE, 0L, 0L, false)
- case info =>
- info.synchronized {
- info.level match {
- case null =>
- (StorageLevel.NONE, 0L, 0L, false)
- case level =>
- val inMem = level.useMemory && memoryStore.contains(blockId)
- val onDisk = level.useDisk && diskStore.contains(blockId)
- (
- new StorageLevel(onDisk, inMem, level.deserialized, level.replication),
- if (inMem) memoryStore.getSize(blockId) else 0L,
- if (onDisk) diskStore.getSize(blockId) else 0L,
- info.tellMaster
- )
- }
- }
+ private def tryToReportBlockStatus(blockId: String, info: BlockInfo): Boolean = {
+ val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized {
+ info.level match {
+ case null =>
+ (StorageLevel.NONE, 0L, 0L, false)
+ case level =>
+ val inMem = level.useMemory && memoryStore.contains(blockId)
+ val onDisk = level.useDisk && diskStore.contains(blockId)
+ val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication)
+ val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
+ val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
+ (storageLevel, memSize, diskSize, info.tellMaster)
+ }
}
if (tellMaster) {
- master.mustBlockUpdate(BlockUpdate(blockManagerId, blockId, curLevel, inMemSize, onDiskSize))
+ master.updateBlockInfo(blockManagerId, blockId, curLevel, inMemSize, onDiskSize)
} else {
true
}
@@ -238,7 +208,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
*/
def getLocations(blockId: String): Seq[String] = {
val startTimeMs = System.currentTimeMillis
- var managers = master.mustGetLocations(GetLocations(blockId))
+ var managers = master.getLocations(blockId)
val locations = managers.map(_.ip)
logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs))
return locations
@@ -249,8 +219,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
*/
def getLocations(blockIds: Array[String]): Array[Seq[String]] = {
val startTimeMs = System.currentTimeMillis
- val locations = master.mustGetLocationsMultipleBlockIds(
- GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray
+ val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray
logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
return locations
}
@@ -272,7 +241,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
}
}
- val info = blockInfo.get(blockId)
+ val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
info.waitForReady() // In case the block is still being put() by another thread
@@ -357,7 +326,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
}
}
- val info = blockInfo.get(blockId)
+ val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
info.waitForReady() // In case the block is still being put() by another thread
@@ -413,7 +382,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
}
logDebug("Getting remote block " + blockId)
// Get locations of block
- val locations = master.mustGetLocations(GetLocations(blockId))
+ val locations = master.getLocations(blockId)
// Get block from remote locations
for (loc <- locations) {
@@ -615,7 +584,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
throw new IllegalArgumentException("Storage level is null or invalid")
}
- val oldBlock = blockInfo.get(blockId)
+ val oldBlock = blockInfo.get(blockId).orNull
if (oldBlock != null) {
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
oldBlock.waitForReady()
@@ -670,7 +639,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
// and tell the master about it.
myInfo.markReady(size)
if (tellMaster) {
- reportBlockStatus(blockId)
+ reportBlockStatus(blockId, myInfo)
}
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
@@ -690,10 +659,6 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
BlockManager.dispose(bytesAfterPut)
- // TODO: This code will be removed when CacheTracker is gone.
- if (blockId.startsWith("rdd")) {
- notifyCacheTracker(blockId)
- }
logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs))
return size
@@ -716,7 +681,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
throw new IllegalArgumentException("Storage level is null or invalid")
}
- if (blockInfo.containsKey(blockId)) {
+ if (blockInfo.contains(blockId)) {
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
return
}
@@ -757,15 +722,10 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
// and tell the master about it.
myInfo.markReady(bytes.limit)
if (tellMaster) {
- reportBlockStatus(blockId)
+ reportBlockStatus(blockId, myInfo)
}
}
- // TODO: This code will be removed when CacheTracker is gone.
- if (blockId.startsWith("rdd")) {
- notifyCacheTracker(blockId)
- }
-
// If replication had started, then wait for it to finish
if (level.replication > 1) {
if (replicationFuture == null) {
@@ -788,10 +748,9 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
*/
var cachedPeers: Seq[BlockManagerId] = null
private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
- val tLevel: StorageLevel =
- new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
+ val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
if (cachedPeers == null) {
- cachedPeers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1))
+ cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
}
for (peer: BlockManagerId <- cachedPeers) {
val start = System.nanoTime
@@ -808,16 +767,6 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
}
}
- // TODO: This code will be removed when CacheTracker is gone.
- private def notifyCacheTracker(key: String) {
- if (cacheTracker != null) {
- val rddInfo = key.split("_")
- val rddId: Int = rddInfo(1).toInt
- val partition: Int = rddInfo(2).toInt
- cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host))
- }
- }
-
/**
* Read a block consisting of a single object.
*/
@@ -838,7 +787,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
*/
def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) {
logInfo("Dropping block " + blockId + " from memory")
- val info = blockInfo.get(blockId)
+ val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
val level = info.level
@@ -851,9 +800,12 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
diskStore.putBytes(blockId, bytes, level)
}
}
- memoryStore.remove(blockId)
+ val blockWasRemoved = memoryStore.remove(blockId)
+ if (!blockWasRemoved) {
+ logWarning("Block " + blockId + " could not be dropped from memory as it does not exist")
+ }
if (info.tellMaster) {
- reportBlockStatus(blockId)
+ reportBlockStatus(blockId, info)
}
if (!level.useDisk) {
// The block is completely gone from this node; forget it so we can put() it again later.
@@ -865,6 +817,53 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
}
}
+ /**
+ * Remove a block from both memory and disk.
+ */
+ def removeBlock(blockId: String) {
+ logInfo("Removing block " + blockId)
+ val info = blockInfo.get(blockId).orNull
+ if (info != null) info.synchronized {
+ // Removals are idempotent in disk store and memory store. At worst, we get a warning.
+ val removedFromMemory = memoryStore.remove(blockId)
+ val removedFromDisk = diskStore.remove(blockId)
+ if (!removedFromMemory && !removedFromDisk) {
+ logWarning("Block " + blockId + " could not be removed as it was not found in either " +
+ "the disk or memory store")
+ }
+ blockInfo.remove(blockId)
+ if (info.tellMaster) {
+ reportBlockStatus(blockId, info)
+ }
+ } else {
+ // The block has already been removed; do nothing.
+ logWarning("Asked to remove block " + blockId + ", which does not exist")
+ }
+ }
+
+ def dropOldBlocks(cleanupTime: Long) {
+ logInfo("Dropping blocks older than " + cleanupTime)
+ val iterator = blockInfo.internalMap.entrySet().iterator()
+ while (iterator.hasNext) {
+ val entry = iterator.next()
+ val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
+ if (time < cleanupTime) {
+ info.synchronized {
+ val level = info.level
+ if (level.useMemory) {
+ memoryStore.remove(id)
+ }
+ if (level.useDisk) {
+ diskStore.remove(id)
+ }
+ iterator.remove()
+ logInfo("Dropped block " + id)
+ }
+ reportBlockStatus(id, info)
+ }
+ }
+ }
+
def shouldCompress(blockId: String): Boolean = {
if (blockId.startsWith("shuffle_")) {
compressShuffle
@@ -914,6 +913,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
heartBeatTask.cancel()
}
connectionManager.stop()
+ master.actorSystem.stop(slaveActor)
blockInfo.clear()
memoryStore.clear()
diskStore.clear()
@@ -923,6 +923,9 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
private[spark]
object BlockManager extends Logging {
+
+ val ID_GENERATOR = new IdGenerator
+
def getMaxMemoryFromSystemProperties: Long = {
val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
(Runtime.getRuntime.maxMemory * memoryFraction).toLong
diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala
new file mode 100644
index 0000000000..abb8b45a1f
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerId.scala
@@ -0,0 +1,70 @@
+package spark.storage
+
+import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
+import java.util.concurrent.ConcurrentHashMap
+
+/**
+ * This class represent an unique identifier for a BlockManager.
+ * The first 2 constructors of this class is made private to ensure that
+ * BlockManagerId objects can be created only using the factory method in
+ * [[spark.storage.BlockManager$]]. This allows de-duplication of id objects.
+ * Also, constructor parameters are private to ensure that parameters cannot
+ * be modified from outside this class.
+ */
+private[spark] class BlockManagerId private (
+ private var ip_ : String,
+ private var port_ : Int
+ ) extends Externalizable {
+
+ private def this() = this(null, 0) // For deserialization only
+
+ def ip = ip_
+
+ def port = port_
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeUTF(ip_)
+ out.writeInt(port_)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ ip_ = in.readUTF()
+ port_ = in.readInt()
+ }
+
+ @throws(classOf[IOException])
+ private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
+
+ override def toString = "BlockManagerId(" + ip + ", " + port + ")"
+
+ override def hashCode = ip.hashCode * 41 + port
+
+ override def equals(that: Any) = that match {
+ case id: BlockManagerId => port == id.port && ip == id.ip
+ case _ => false
+ }
+}
+
+
+private[spark] object BlockManagerId {
+
+ def apply(ip: String, port: Int) =
+ getCachedBlockManagerId(new BlockManagerId(ip, port))
+
+ def apply(in: ObjectInput) = {
+ val obj = new BlockManagerId()
+ obj.readExternal(in)
+ getCachedBlockManagerId(obj)
+ }
+
+ val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
+
+ def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
+ if (blockManagerIdCache.containsKey(id)) {
+ blockManagerIdCache.get(id)
+ } else {
+ blockManagerIdCache.put(id, id)
+ id
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index 0a4e68f437..a3d8671834 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -1,676 +1,167 @@
package spark.storage
-import java.io._
-import java.util.{HashMap => JHashMap}
-
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.mutable.ArrayBuffer
import scala.util.Random
-import akka.actor._
-import akka.dispatch._
+import akka.actor.{Actor, ActorRef, ActorSystem, Props}
+import akka.dispatch.Await
import akka.pattern.ask
-import akka.remote._
import akka.util.{Duration, Timeout}
import akka.util.duration._
import spark.{Logging, SparkException, Utils}
-private[spark]
-sealed trait ToBlockManagerMaster
-
-private[spark]
-case class RegisterBlockManager(
- blockManagerId: BlockManagerId,
- maxMemSize: Long)
- extends ToBlockManagerMaster
-
-private[spark]
-case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
-
-private[spark]
-class BlockUpdate(
- var blockManagerId: BlockManagerId,
- var blockId: String,
- var storageLevel: StorageLevel,
- var memSize: Long,
- var diskSize: Long)
- extends ToBlockManagerMaster
- with Externalizable {
-
- def this() = this(null, null, null, 0, 0) // For deserialization only
-
- override def writeExternal(out: ObjectOutput) {
- blockManagerId.writeExternal(out)
- out.writeUTF(blockId)
- storageLevel.writeExternal(out)
- out.writeInt(memSize.toInt)
- out.writeInt(diskSize.toInt)
- }
-
- override def readExternal(in: ObjectInput) {
- blockManagerId = new BlockManagerId()
- blockManagerId.readExternal(in)
- blockId = in.readUTF()
- storageLevel = new StorageLevel()
- storageLevel.readExternal(in)
- memSize = in.readInt()
- diskSize = in.readInt()
- }
-}
-
-private[spark]
-object BlockUpdate {
- def apply(blockManagerId: BlockManagerId,
- blockId: String,
- storageLevel: StorageLevel,
- memSize: Long,
- diskSize: Long): BlockUpdate = {
- new BlockUpdate(blockManagerId, blockId, storageLevel, memSize, diskSize)
- }
-
- // For pattern-matching
- def unapply(h: BlockUpdate): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
- Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize))
- }
-}
-
-private[spark]
-case class GetLocations(blockId: String) extends ToBlockManagerMaster
-
-private[spark]
-case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster
-
-private[spark]
-case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
-
-private[spark]
-case class RemoveHost(host: String) extends ToBlockManagerMaster
-
-private[spark]
-case object StopBlockManagerMaster extends ToBlockManagerMaster
-
-private[spark]
-case object GetMemoryStatus extends ToBlockManagerMaster
-
-private[spark]
-case object ExpireDeadHosts extends ToBlockManagerMaster
-
-
-private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
-
- class BlockManagerInfo(
- val blockManagerId: BlockManagerId,
- timeMs: Long,
- val maxMem: Long) {
- private var _lastSeenMs = timeMs
- private var _remainingMem = maxMem
- private val _blocks = new JHashMap[String, StorageLevel]
-
- logInfo("Registering block manager %s:%d with %s RAM".format(
- blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem)))
-
- def updateLastSeenMs() {
- _lastSeenMs = System.currentTimeMillis()
- }
-
- def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long)
- : Unit = synchronized {
-
- updateLastSeenMs()
-
- if (_blocks.containsKey(blockId)) {
- // The block exists on the slave already.
- val originalLevel: StorageLevel = _blocks.get(blockId)
-
- if (originalLevel.useMemory) {
- _remainingMem += memSize
- }
- }
-
- if (storageLevel.isValid) {
- // isValid means it is either stored in-memory or on-disk.
- _blocks.put(blockId, storageLevel)
- if (storageLevel.useMemory) {
- _remainingMem -= memSize
- logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
- Utils.memoryBytesToString(_remainingMem)))
- }
- if (storageLevel.useDisk) {
- logInfo("Added %s on disk on %s:%d (size: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
- }
- } else if (_blocks.containsKey(blockId)) {
- // If isValid is not true, drop the block.
- val originalLevel: StorageLevel = _blocks.get(blockId)
- _blocks.remove(blockId)
- if (originalLevel.useMemory) {
- _remainingMem += memSize
- logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
- Utils.memoryBytesToString(_remainingMem)))
- }
- if (originalLevel.useDisk) {
- logInfo("Removed %s on %s:%d on disk (size: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
- }
- }
- }
-
- def remainingMem: Long = _remainingMem
-
- def lastSeenMs: Long = _lastSeenMs
-
- def blocks: JHashMap[String, StorageLevel] = _blocks
-
- override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
-
- def clear() {
- _blocks.clear()
- }
- }
-
- private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerInfo]
- private val blockManagerIdByHost = new HashMap[String, BlockManagerId]
- private val blockInfo = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]]
-
- initLogging()
-
- val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs",
- "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong
-
- val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs",
- "5000").toLong
-
- var timeoutCheckingTask: Cancellable = null
+private[spark] class BlockManagerMaster(
+ val actorSystem: ActorSystem,
+ isMaster: Boolean,
+ isLocal: Boolean,
+ masterIp: String,
+ masterPort: Int)
+ extends Logging {
- override def preStart() {
- if (!BlockManager.getDisableHeartBeatsForTesting) {
- timeoutCheckingTask = context.system.scheduler.schedule(
- 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts)
- }
- super.preStart()
- }
+ val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
+ val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
- def removeBlockManager(blockManagerId: BlockManagerId) {
- val info = blockManagerInfo(blockManagerId)
- blockManagerIdByHost.remove(blockManagerId.ip)
- blockManagerInfo.remove(blockManagerId)
- var iterator = info.blocks.keySet.iterator
- while (iterator.hasNext) {
- val blockId = iterator.next
- val locations = blockInfo.get(blockId)._2
- locations -= blockManagerId
- if (locations.size == 0) {
- blockInfo.remove(locations)
- }
- }
- }
-
- def expireDeadHosts() {
- logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.")
- val now = System.currentTimeMillis()
- val minSeenTime = now - slaveTimeout
- val toRemove = new HashSet[BlockManagerId]
- for (info <- blockManagerInfo.values) {
- if (info.lastSeenMs < minSeenTime) {
- logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats")
- toRemove += info.blockManagerId
- }
- }
- // TODO: Remove corresponding block infos
- toRemove.foreach(removeBlockManager)
- }
-
- def removeHost(host: String) {
- logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.")
- logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq)
- blockManagerIdByHost.get(host).foreach(removeBlockManager)
- logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq)
- sender ! true
- }
+ val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager"
+ val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager"
+ val DEFAULT_MANAGER_IP: String = Utils.localHostName()
- def heartBeat(blockManagerId: BlockManagerId) {
- if (!blockManagerInfo.contains(blockManagerId)) {
- if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
- sender ! true
- } else {
- sender ! false
- }
+ val timeout = 10.seconds
+ var masterActor: ActorRef = {
+ if (isMaster) {
+ val masterActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)),
+ name = MASTER_AKKA_ACTOR_NAME)
+ logInfo("Registered BlockManagerMaster Actor")
+ masterActor
} else {
- blockManagerInfo(blockManagerId).updateLastSeenMs()
- sender ! true
+ val url = "akka://spark@%s:%s/user/%s".format(masterIp, masterPort, MASTER_AKKA_ACTOR_NAME)
+ logInfo("Connecting to BlockManagerMaster: " + url)
+ actorSystem.actorFor(url)
}
}
- def receive = {
- case RegisterBlockManager(blockManagerId, maxMemSize) =>
- register(blockManagerId, maxMemSize)
-
- case BlockUpdate(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
- blockUpdate(blockManagerId, blockId, storageLevel, deserializedSize, size)
-
- case GetLocations(blockId) =>
- getLocations(blockId)
-
- case GetLocationsMultipleBlockIds(blockIds) =>
- getLocationsMultipleBlockIds(blockIds)
-
- case GetPeers(blockManagerId, size) =>
- getPeersDeterministic(blockManagerId, size)
- /*getPeers(blockManagerId, size)*/
-
- case GetMemoryStatus =>
- getMemoryStatus
-
- case RemoveHost(host) =>
- removeHost(host)
- sender ! true
-
- case StopBlockManagerMaster =>
- logInfo("Stopping BlockManagerMaster")
- sender ! true
- if (timeoutCheckingTask != null) {
- timeoutCheckingTask.cancel
- }
- context.stop(self)
-
- case ExpireDeadHosts =>
- expireDeadHosts()
-
- case HeartBeat(blockManagerId) =>
- heartBeat(blockManagerId)
-
- case other =>
- logInfo("Got unknown message: " + other)
+ /** Remove a dead host from the master actor. This is only called on the master side. */
+ def notifyADeadHost(host: String) {
+ tell(RemoveHost(host))
+ logInfo("Removed " + host + " successfully in notifyADeadHost")
}
- // Return a map from the block manager id to max memory and remaining memory.
- private def getMemoryStatus() {
- val res = blockManagerInfo.map { case(blockManagerId, info) =>
- (blockManagerId, (info.maxMem, info.remainingMem))
- }.toMap
- sender ! res
+ /**
+ * Send the master actor a heart beat from the slave. Returns true if everything works out,
+ * false if the master does not know about the given block manager, which means the block
+ * manager should re-register.
+ */
+ def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = {
+ askMasterWithRetry[Boolean](HeartBeat(blockManagerId))
}
- private def register(blockManagerId: BlockManagerId, maxMemSize: Long) {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockManagerId + " "
- logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
- if (blockManagerIdByHost.contains(blockManagerId.ip) &&
- blockManagerIdByHost(blockManagerId.ip) != blockManagerId) {
- val oldId = blockManagerIdByHost(blockManagerId.ip)
- logInfo("Got second registration for host " + blockManagerId +
- "; removing old slave " + oldId)
- removeBlockManager(oldId)
- }
- if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
- logInfo("Got Register Msg from master node, don't register it")
- } else {
- blockManagerInfo += (blockManagerId -> new BlockManagerInfo(
- blockManagerId, System.currentTimeMillis(), maxMemSize))
- }
- blockManagerIdByHost += (blockManagerId.ip -> blockManagerId)
- logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs))
- sender ! true
+ /** Register the BlockManager's id with the master. */
+ def registerBlockManager(
+ blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+ logInfo("Trying to register BlockManager")
+ tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor))
+ logInfo("Registered BlockManager")
}
- private def blockUpdate(
+ def updateBlockInfo(
blockManagerId: BlockManagerId,
blockId: String,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long) {
-
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockManagerId + " " + blockId + " "
-
- if (!blockManagerInfo.contains(blockManagerId)) {
- if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
- // We intentionally do not register the master (except in local mode),
- // so we should not indicate failure.
- sender ! true
- } else {
- sender ! false
- }
- return
- }
-
- if (blockId == null) {
- blockManagerInfo(blockManagerId).updateLastSeenMs()
- logDebug("Got in block update 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
- sender ! true
- return
- }
-
- blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize)
-
- var locations: HashSet[BlockManagerId] = null
- if (blockInfo.containsKey(blockId)) {
- locations = blockInfo.get(blockId)._2
- } else {
- locations = new HashSet[BlockManagerId]
- blockInfo.put(blockId, (storageLevel.replication, locations))
- }
-
- if (storageLevel.isValid) {
- locations += blockManagerId
- } else {
- locations.remove(blockManagerId)
- }
-
- if (locations.size == 0) {
- blockInfo.remove(blockId)
- }
- sender ! true
+ diskSize: Long): Boolean = {
+ val res = askMasterWithRetry[Boolean](
+ UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize))
+ logInfo("Updated info of block " + blockId)
+ res
}
- private def getLocations(blockId: String) {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockId + " "
- logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
- if (blockInfo.containsKey(blockId)) {
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(blockInfo.get(blockId)._2)
- logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at "
- + Utils.getUsedTimeMs(startTimeMs))
- sender ! res.toSeq
- } else {
- logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs))
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- sender ! res
- }
+ /** Get locations of the blockId from the master */
+ def getLocations(blockId: String): Seq[BlockManagerId] = {
+ askMasterWithRetry[Seq[BlockManagerId]](GetLocations(blockId))
}
- private def getLocationsMultipleBlockIds(blockIds: Array[String]) {
- def getLocations(blockId: String): Seq[BlockManagerId] = {
- val tmp = blockId
- logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp)
- if (blockInfo.containsKey(blockId)) {
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(blockInfo.get(blockId)._2)
- logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq)
- return res.toSeq
- } else {
- logDebug("Got in getLocationsMultipleBlockIds Sub 2 " + tmp)
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- return res.toSeq
- }
- }
-
- logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq)
- var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]]
- for (blockId <- blockIds) {
- res.append(getLocations(blockId))
- }
- logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq)
- sender ! res.toSeq
+ /** Get locations of multiple blockIds from the master */
+ def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ askMasterWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
- private def getPeers(blockManagerId: BlockManagerId, size: Int) {
- var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(peers)
- res -= blockManagerId
- val rand = new Random(System.currentTimeMillis())
- while (res.length > size) {
- res.remove(rand.nextInt(res.length))
+ /** Get ids of other nodes in the cluster from the master */
+ def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = {
+ val result = askMasterWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
+ if (result.length != numPeers) {
+ throw new SparkException(
+ "Error getting peers, only got " + result.size + " instead of " + numPeers)
}
- sender ! res.toSeq
+ result
}
- private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) {
- var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
-
- val peersWithIndices = peers.zipWithIndex
- val selfIndex = peersWithIndices.find(_._1 == blockManagerId).map(_._2).getOrElse(-1)
- if (selfIndex == -1) {
- throw new Exception("Self index for " + blockManagerId + " not found")
- }
-
- var index = selfIndex
- while (res.size < size) {
- index += 1
- if (index == selfIndex) {
- throw new Exception("More peer expected than available")
- }
- res += peers(index % peers.size)
- }
- sender ! res.toSeq
+ /**
+ * Remove a block from the slaves that have it. This can only be used to remove
+ * blocks that the master knows about.
+ */
+ def removeBlock(blockId: String) {
+ askMasterWithRetry(RemoveBlock(blockId))
}
-}
-
-private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean)
- extends Logging {
-
- val AKKA_ACTOR_NAME: String = "BlockMasterManager"
- val REQUEST_RETRY_INTERVAL_MS = 100
- val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost")
- val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt
- val DEFAULT_MANAGER_IP: String = Utils.localHostName()
- val timeout = 10.seconds
- var masterActor: ActorRef = null
-
- if (isMaster) {
- masterActor = actorSystem.actorOf(
- Props(new BlockManagerMasterActor(isLocal)), name = AKKA_ACTOR_NAME)
- logInfo("Registered BlockManagerMaster Actor")
- } else {
- val url = "akka://spark@%s:%s/user/%s".format(
- DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT, AKKA_ACTOR_NAME)
- logInfo("Connecting to BlockManagerMaster: " + url)
- masterActor = actorSystem.actorFor(url)
+ /**
+ * Return the memory status for each block manager, in the form of a map from
+ * the block manager's id to two long values. The first value is the maximum
+ * amount of memory allocated for the block manager, while the second is the
+ * amount of remaining memory.
+ */
+ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
+ askMasterWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
}
+ /** Stop the master actor, called only on the Spark master node */
def stop() {
if (masterActor != null) {
- communicate(StopBlockManagerMaster)
+ tell(StopBlockManagerMaster)
masterActor = null
logInfo("BlockManagerMaster stopped")
}
}
- // Send a message to the master actor and get its result within a default timeout, or
- // throw a SparkException if this fails.
- def askMaster(message: Any): Any = {
- try {
- val future = masterActor.ask(message)(timeout)
- return Await.result(future, timeout)
- } catch {
- case e: Exception =>
- throw new SparkException("Error communicating with BlockManagerMaster", e)
- }
- }
-
- // Send a one-way message to the master actor, to which we expect it to reply with true.
- def communicate(message: Any) {
- if (askMaster(message) != true) {
- throw new SparkException("Error reply received from BlockManagerMaster")
- }
- }
-
- def notifyADeadHost(host: String) {
- communicate(RemoveHost(host))
- logInfo("Removed " + host + " successfully in notifyADeadHost")
- }
-
- def mustRegisterBlockManager(msg: RegisterBlockManager) {
- logInfo("Trying to register BlockManager")
- while (! syncRegisterBlockManager(msg)) {
- logWarning("Failed to register " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- }
- logInfo("Done registering BlockManager")
- }
-
- def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = {
- //val masterActor = RemoteActor.select(node, name)
- val startTimeMs = System.currentTimeMillis()
- val tmp = " msg " + msg + " "
- logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
-
- try {
- communicate(msg)
- logInfo("BlockManager registered successfully @ syncRegisterBlockManager")
- logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
- return true
- } catch {
- case e: Exception =>
- logError("Failed in syncRegisterBlockManager", e)
- return false
- }
- }
-
- def mustHeartBeat(msg: HeartBeat): Boolean = {
- var res = syncHeartBeat(msg)
- while (!res.isDefined) {
- logWarning("Failed to send heart beat " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+ /** Send a one-way message to the master actor, to which we expect it to reply with true. */
+ private def tell(message: Any) {
+ if (!askMasterWithRetry[Boolean](message)) {
+ throw new SparkException("BlockManagerMasterActor returned false, expected true.")
}
- return res.get
}
- def syncHeartBeat(msg: HeartBeat): Option[Boolean] = {
- try {
- val answer = askMaster(msg).asInstanceOf[Boolean]
- return Some(answer)
- } catch {
- case e: Exception =>
- logError("Failed in syncHeartBeat", e)
- return None
+ /**
+ * Send a message to the master actor and get its result within a default timeout, or
+ * throw a SparkException if this fails.
+ */
+ private def askMasterWithRetry[T](message: Any): T = {
+ // TODO: Consider removing multiple attempts
+ if (masterActor == null) {
+ throw new SparkException("Error sending message to BlockManager as masterActor is null " +
+ "[message = " + message + "]")
}
- }
-
- def mustBlockUpdate(msg: BlockUpdate): Boolean = {
- var res = syncBlockUpdate(msg)
- while (!res.isDefined) {
- logWarning("Failed to send block update " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- }
- return res.get
- }
-
- def syncBlockUpdate(msg: BlockUpdate): Option[Boolean] = {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " msg " + msg + " "
- logDebug("Got in syncBlockUpdate " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs))
-
- try {
- val answer = askMaster(msg).asInstanceOf[Boolean]
- logDebug("Block update sent successfully")
- logDebug("Got in synbBlockUpdate " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs))
- return Some(answer)
- } catch {
- case e: Exception =>
- logError("Failed in syncBlockUpdate", e)
- return None
- }
- }
-
- def mustGetLocations(msg: GetLocations): Seq[BlockManagerId] = {
- var res = syncGetLocations(msg)
- while (res == null) {
- logInfo("Failed to get locations " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- res = syncGetLocations(msg)
- }
- return res
- }
-
- def syncGetLocations(msg: GetLocations): Seq[BlockManagerId] = {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " msg " + msg + " "
- logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
-
- try {
- val answer = askMaster(msg).asInstanceOf[ArrayBuffer[BlockManagerId]]
- if (answer != null) {
- logDebug("GetLocations successful")
- logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
- return answer
- } else {
- logError("Master replied null in response to GetLocations")
- return null
+ var attempts = 0
+ var lastException: Exception = null
+ while (attempts < AKKA_RETRY_ATTEMPS) {
+ attempts += 1
+ try {
+ val future = masterActor.ask(message)(timeout)
+ val result = Await.result(future, timeout)
+ if (result == null) {
+ throw new Exception("BlockManagerMaster returned null")
+ }
+ return result.asInstanceOf[T]
+ } catch {
+ case ie: InterruptedException => throw ie
+ case e: Exception =>
+ lastException = e
+ logWarning("Error sending message to BlockManagerMaster in " + attempts + " attempts", e)
}
- } catch {
- case e: Exception =>
- logError("GetLocations failed", e)
- return null
+ Thread.sleep(AKKA_RETRY_INTERVAL_MS)
}
- }
- def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
- Seq[Seq[BlockManagerId]] = {
- var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg)
- while (res == null) {
- logWarning("Failed to GetLocationsMultipleBlockIds " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- res = syncGetLocationsMultipleBlockIds(msg)
- }
- return res
+ throw new SparkException(
+ "Error sending message to BlockManagerMaster [message = " + message + "]", lastException)
}
- def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
- Seq[Seq[BlockManagerId]] = {
- val startTimeMs = System.currentTimeMillis
- val tmp = " msg " + msg + " "
- logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
-
- try {
- val answer = askMaster(msg).asInstanceOf[Seq[Seq[BlockManagerId]]]
- if (answer != null) {
- logDebug("GetLocationsMultipleBlockIds successful")
- logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp +
- Utils.getUsedTimeMs(startTimeMs))
- return answer
- } else {
- logError("Master replied null in response to GetLocationsMultipleBlockIds")
- return null
- }
- } catch {
- case e: Exception =>
- logError("GetLocationsMultipleBlockIds failed", e)
- return null
- }
- }
-
- def mustGetPeers(msg: GetPeers): Seq[BlockManagerId] = {
- var res = syncGetPeers(msg)
- while ((res == null) || (res.length != msg.size)) {
- logInfo("Failed to get peers " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- res = syncGetPeers(msg)
- }
-
- return res
- }
-
- def syncGetPeers(msg: GetPeers): Seq[BlockManagerId] = {
- val startTimeMs = System.currentTimeMillis
- val tmp = " msg " + msg + " "
- logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
-
- try {
- val answer = askMaster(msg).asInstanceOf[Seq[BlockManagerId]]
- if (answer != null) {
- logDebug("GetPeers successful")
- logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
- return answer
- } else {
- logError("Master replied null in response to GetPeers")
- return null
- }
- } catch {
- case e: Exception =>
- logError("GetPeers failed", e)
- return null
- }
- }
-
- def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
- askMaster(GetMemoryStatus).asInstanceOf[Map[BlockManagerId, (Long, Long)]]
- }
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
new file mode 100644
index 0000000000..f4d026da33
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
@@ -0,0 +1,401 @@
+package spark.storage
+
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.JavaConversions._
+import scala.util.Random
+
+import akka.actor.{Actor, ActorRef, Cancellable}
+import akka.util.{Duration, Timeout}
+import akka.util.duration._
+
+import spark.{Logging, Utils}
+
+/**
+ * BlockManagerMasterActor is an actor on the master node to track statuses of
+ * all slaves' block managers.
+ */
+private[spark]
+class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
+
+ // Mapping from block manager id to the block manager's information.
+ private val blockManagerInfo =
+ new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo]
+
+ // Mapping from host name to block manager id. We allow multiple block managers
+ // on the same host name (ip).
+ private val blockManagerIdByHost = new HashMap[String, ArrayBuffer[BlockManagerId]]
+
+ // Mapping from block id to the set of block managers that have the block.
+ private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]]
+
+ initLogging()
+
+ val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs",
+ "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong
+
+ val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs",
+ "5000").toLong
+
+ var timeoutCheckingTask: Cancellable = null
+
+ override def preStart() {
+ if (!BlockManager.getDisableHeartBeatsForTesting) {
+ timeoutCheckingTask = context.system.scheduler.schedule(
+ 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts)
+ }
+ super.preStart()
+ }
+
+ def receive = {
+ case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) =>
+ register(blockManagerId, maxMemSize, slaveActor)
+
+ case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
+ updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size)
+
+ case GetLocations(blockId) =>
+ getLocations(blockId)
+
+ case GetLocationsMultipleBlockIds(blockIds) =>
+ getLocationsMultipleBlockIds(blockIds)
+
+ case GetPeers(blockManagerId, size) =>
+ getPeersDeterministic(blockManagerId, size)
+ /*getPeers(blockManagerId, size)*/
+
+ case GetMemoryStatus =>
+ getMemoryStatus
+
+ case RemoveBlock(blockId) =>
+ removeBlock(blockId)
+
+ case RemoveHost(host) =>
+ removeHost(host)
+ sender ! true
+
+ case StopBlockManagerMaster =>
+ logInfo("Stopping BlockManagerMaster")
+ sender ! true
+ if (timeoutCheckingTask != null) {
+ timeoutCheckingTask.cancel
+ }
+ context.stop(self)
+
+ case ExpireDeadHosts =>
+ expireDeadHosts()
+
+ case HeartBeat(blockManagerId) =>
+ heartBeat(blockManagerId)
+
+ case other =>
+ logInfo("Got unknown message: " + other)
+ }
+
+ def removeBlockManager(blockManagerId: BlockManagerId) {
+ val info = blockManagerInfo(blockManagerId)
+
+ // Remove the block manager from blockManagerIdByHost. If the list of block
+ // managers belonging to the IP is empty, remove the entry from the hash map.
+ blockManagerIdByHost.get(blockManagerId.ip).foreach { managers: ArrayBuffer[BlockManagerId] =>
+ managers -= blockManagerId
+ if (managers.size == 0) blockManagerIdByHost.remove(blockManagerId.ip)
+ }
+
+ // Remove it from blockManagerInfo and remove all the blocks.
+ blockManagerInfo.remove(blockManagerId)
+ var iterator = info.blocks.keySet.iterator
+ while (iterator.hasNext) {
+ val blockId = iterator.next
+ val locations = blockLocations.get(blockId)._2
+ locations -= blockManagerId
+ if (locations.size == 0) {
+ blockLocations.remove(locations)
+ }
+ }
+ }
+
+ def expireDeadHosts() {
+ logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.")
+ val now = System.currentTimeMillis()
+ val minSeenTime = now - slaveTimeout
+ val toRemove = new HashSet[BlockManagerId]
+ for (info <- blockManagerInfo.values) {
+ if (info.lastSeenMs < minSeenTime) {
+ logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats")
+ toRemove += info.blockManagerId
+ }
+ }
+ toRemove.foreach(removeBlockManager)
+ }
+
+ def removeHost(host: String) {
+ logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.")
+ logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq)
+ blockManagerIdByHost.get(host).foreach(_.foreach(removeBlockManager))
+ logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq)
+ sender ! true
+ }
+
+ def heartBeat(blockManagerId: BlockManagerId) {
+ if (!blockManagerInfo.contains(blockManagerId)) {
+ if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
+ sender ! true
+ } else {
+ sender ! false
+ }
+ } else {
+ blockManagerInfo(blockManagerId).updateLastSeenMs()
+ sender ! true
+ }
+ }
+
+ // Remove a block from the slaves that have it. This can only be used to remove
+ // blocks that the master knows about.
+ private def removeBlock(blockId: String) {
+ val block = blockLocations.get(blockId)
+ if (block != null) {
+ block._2.foreach { blockManagerId: BlockManagerId =>
+ val blockManager = blockManagerInfo.get(blockManagerId)
+ if (blockManager.isDefined) {
+ // Remove the block from the slave's BlockManager.
+ // Doesn't actually wait for a confirmation and the message might get lost.
+ // If message loss becomes frequent, we should add retry logic here.
+ blockManager.get.slaveActor ! RemoveBlock(blockId)
+ }
+ }
+ }
+ sender ! true
+ }
+
+ // Return a map from the block manager id to max memory and remaining memory.
+ private def getMemoryStatus() {
+ val res = blockManagerInfo.map { case(blockManagerId, info) =>
+ (blockManagerId, (info.maxMem, info.remainingMem))
+ }.toMap
+ sender ! res
+ }
+
+ private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " " + blockManagerId + " "
+
+ if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
+ logInfo("Got Register Msg from master node, don't register it")
+ } else {
+ blockManagerIdByHost.get(blockManagerId.ip) match {
+ case Some(managers) =>
+ // A block manager of the same host name already exists.
+ logInfo("Got another registration for host " + blockManagerId)
+ managers += blockManagerId
+ case None =>
+ blockManagerIdByHost += (blockManagerId.ip -> ArrayBuffer(blockManagerId))
+ }
+
+ blockManagerInfo += (blockManagerId -> new BlockManagerMasterActor.BlockManagerInfo(
+ blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor))
+ }
+ sender ! true
+ }
+
+ private def updateBlockInfo(
+ blockManagerId: BlockManagerId,
+ blockId: String,
+ storageLevel: StorageLevel,
+ memSize: Long,
+ diskSize: Long) {
+
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " " + blockManagerId + " " + blockId + " "
+
+ if (!blockManagerInfo.contains(blockManagerId)) {
+ if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
+ // We intentionally do not register the master (except in local mode),
+ // so we should not indicate failure.
+ sender ! true
+ } else {
+ sender ! false
+ }
+ return
+ }
+
+ if (blockId == null) {
+ blockManagerInfo(blockManagerId).updateLastSeenMs()
+ sender ! true
+ return
+ }
+
+ blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize)
+
+ var locations: HashSet[BlockManagerId] = null
+ if (blockLocations.containsKey(blockId)) {
+ locations = blockLocations.get(blockId)._2
+ } else {
+ locations = new HashSet[BlockManagerId]
+ blockLocations.put(blockId, (storageLevel.replication, locations))
+ }
+
+ if (storageLevel.isValid) {
+ locations.add(blockManagerId)
+ } else {
+ locations.remove(blockManagerId)
+ }
+
+ // Remove the block from master tracking if it has been removed on all slaves.
+ if (locations.size == 0) {
+ blockLocations.remove(blockId)
+ }
+ sender ! true
+ }
+
+ private def getLocations(blockId: String) {
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " " + blockId + " "
+ if (blockLocations.containsKey(blockId)) {
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ res.appendAll(blockLocations.get(blockId)._2)
+ sender ! res.toSeq
+ } else {
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ sender ! res
+ }
+ }
+
+ private def getLocationsMultipleBlockIds(blockIds: Array[String]) {
+ def getLocations(blockId: String): Seq[BlockManagerId] = {
+ val tmp = blockId
+ if (blockLocations.containsKey(blockId)) {
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ res.appendAll(blockLocations.get(blockId)._2)
+ return res.toSeq
+ } else {
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ return res.toSeq
+ }
+ }
+
+ var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]]
+ for (blockId <- blockIds) {
+ res.append(getLocations(blockId))
+ }
+ sender ! res.toSeq
+ }
+
+ private def getPeers(blockManagerId: BlockManagerId, size: Int) {
+ var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ res.appendAll(peers)
+ res -= blockManagerId
+ val rand = new Random(System.currentTimeMillis())
+ while (res.length > size) {
+ res.remove(rand.nextInt(res.length))
+ }
+ sender ! res.toSeq
+ }
+
+ private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) {
+ var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+
+ val selfIndex = peers.indexOf(blockManagerId)
+ if (selfIndex == -1) {
+ throw new Exception("Self index for " + blockManagerId + " not found")
+ }
+
+ // Note that this logic will select the same node multiple times if there aren't enough peers
+ var index = selfIndex
+ while (res.size < size) {
+ index += 1
+ if (index == selfIndex) {
+ throw new Exception("More peer expected than available")
+ }
+ res += peers(index % peers.size)
+ }
+ sender ! res.toSeq
+ }
+}
+
+
+private[spark]
+object BlockManagerMasterActor {
+
+ case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long)
+
+ class BlockManagerInfo(
+ val blockManagerId: BlockManagerId,
+ timeMs: Long,
+ val maxMem: Long,
+ val slaveActor: ActorRef)
+ extends Logging {
+
+ private var _lastSeenMs: Long = timeMs
+ private var _remainingMem: Long = maxMem
+
+ // Mapping from block id to its status.
+ private val _blocks = new JHashMap[String, BlockStatus]
+
+ logInfo("Registering block manager %s:%d with %s RAM".format(
+ blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem)))
+
+ def updateLastSeenMs() {
+ _lastSeenMs = System.currentTimeMillis()
+ }
+
+ def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long)
+ : Unit = synchronized {
+
+ updateLastSeenMs()
+
+ if (_blocks.containsKey(blockId)) {
+ // The block exists on the slave already.
+ val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel
+
+ if (originalLevel.useMemory) {
+ _remainingMem += memSize
+ }
+ }
+
+ if (storageLevel.isValid) {
+ // isValid means it is either stored in-memory or on-disk.
+ _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize))
+ if (storageLevel.useMemory) {
+ _remainingMem -= memSize
+ logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format(
+ blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
+ Utils.memoryBytesToString(_remainingMem)))
+ }
+ if (storageLevel.useDisk) {
+ logInfo("Added %s on disk on %s:%d (size: %s)".format(
+ blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
+ }
+ } else if (_blocks.containsKey(blockId)) {
+ // If isValid is not true, drop the block.
+ val blockStatus: BlockStatus = _blocks.get(blockId)
+ _blocks.remove(blockId)
+ if (blockStatus.storageLevel.useMemory) {
+ _remainingMem += blockStatus.memSize
+ logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format(
+ blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
+ Utils.memoryBytesToString(_remainingMem)))
+ }
+ if (blockStatus.storageLevel.useDisk) {
+ logInfo("Removed %s on %s:%d on disk (size: %s)".format(
+ blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
+ }
+ }
+ }
+
+ def remainingMem: Long = _remainingMem
+
+ def lastSeenMs: Long = _lastSeenMs
+
+ def blocks: JHashMap[String, BlockStatus] = _blocks
+
+ override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
+
+ def clear() {
+ _blocks.clear()
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
new file mode 100644
index 0000000000..30483b0b37
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
@@ -0,0 +1,100 @@
+package spark.storage
+
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
+
+import akka.actor.ActorRef
+
+
+//////////////////////////////////////////////////////////////////////////////////
+// Messages from the master to slaves.
+//////////////////////////////////////////////////////////////////////////////////
+private[spark]
+sealed trait ToBlockManagerSlave
+
+// Remove a block from the slaves that have it. This can only be used to remove
+// blocks that the master knows about.
+private[spark]
+case class RemoveBlock(blockId: String) extends ToBlockManagerSlave
+
+
+//////////////////////////////////////////////////////////////////////////////////
+// Messages from slaves to the master.
+//////////////////////////////////////////////////////////////////////////////////
+private[spark]
+sealed trait ToBlockManagerMaster
+
+private[spark]
+case class RegisterBlockManager(
+ blockManagerId: BlockManagerId,
+ maxMemSize: Long,
+ sender: ActorRef)
+ extends ToBlockManagerMaster
+
+private[spark]
+case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
+
+private[spark]
+class UpdateBlockInfo(
+ var blockManagerId: BlockManagerId,
+ var blockId: String,
+ var storageLevel: StorageLevel,
+ var memSize: Long,
+ var diskSize: Long)
+ extends ToBlockManagerMaster
+ with Externalizable {
+
+ def this() = this(null, null, null, 0, 0) // For deserialization only
+
+ override def writeExternal(out: ObjectOutput) {
+ blockManagerId.writeExternal(out)
+ out.writeUTF(blockId)
+ storageLevel.writeExternal(out)
+ out.writeInt(memSize.toInt)
+ out.writeInt(diskSize.toInt)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ blockManagerId = BlockManagerId(in)
+ blockId = in.readUTF()
+ storageLevel = StorageLevel(in)
+ memSize = in.readInt()
+ diskSize = in.readInt()
+ }
+}
+
+private[spark]
+object UpdateBlockInfo {
+ def apply(blockManagerId: BlockManagerId,
+ blockId: String,
+ storageLevel: StorageLevel,
+ memSize: Long,
+ diskSize: Long): UpdateBlockInfo = {
+ new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)
+ }
+
+ // For pattern-matching
+ def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
+ Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize))
+ }
+}
+
+private[spark]
+case class GetLocations(blockId: String) extends ToBlockManagerMaster
+
+private[spark]
+case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster
+
+private[spark]
+case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
+
+private[spark]
+case class RemoveHost(host: String) extends ToBlockManagerMaster
+
+private[spark]
+case object StopBlockManagerMaster extends ToBlockManagerMaster
+
+private[spark]
+case object GetMemoryStatus extends ToBlockManagerMaster
+
+private[spark]
+case object ExpireDeadHosts extends ToBlockManagerMaster
diff --git a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala
new file mode 100644
index 0000000000..f570cdc52d
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala
@@ -0,0 +1,16 @@
+package spark.storage
+
+import akka.actor.Actor
+
+import spark.{Logging, SparkException, Utils}
+
+
+/**
+ * An actor to take commands from the master to execute options. For example,
+ * this is used to remove blocks from the slave's BlockManager.
+ */
+class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
+ override def receive = {
+ case RemoveBlock(blockId) => blockManager.removeBlock(blockId)
+ }
+}
diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala
index 3f234df654..30d7500e01 100644
--- a/core/src/main/scala/spark/storage/BlockMessage.scala
+++ b/core/src/main/scala/spark/storage/BlockMessage.scala
@@ -64,7 +64,7 @@ private[spark] class BlockMessage() {
val booleanInt = buffer.getInt()
val replication = buffer.getInt()
- level = new StorageLevel(booleanInt, replication)
+ level = StorageLevel(booleanInt, replication)
val dataLength = buffer.getInt()
data = ByteBuffer.allocate(dataLength)
diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala
index 096bf8bdd9..8188d3595e 100644
--- a/core/src/main/scala/spark/storage/BlockStore.scala
+++ b/core/src/main/scala/spark/storage/BlockStore.scala
@@ -31,7 +31,12 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging {
def getValues(blockId: String): Option[Iterator[Any]]
- def remove(blockId: String)
+ /**
+ * Remove a block, if it exists.
+ * @param blockId the block to remove.
+ * @return True if the block was found and removed, False otherwise.
+ */
+ def remove(blockId: String): Boolean
def contains(blockId: String): Boolean
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index b5561479db..7e5b820cbb 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -92,10 +92,13 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
}
- override def remove(blockId: String) {
+ override def remove(blockId: String): Boolean = {
val file = getFile(blockId)
if (file.exists()) {
file.delete()
+ true
+ } else {
+ false
}
}
diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala
index 02098b82fe..ae88ff0bb1 100644
--- a/core/src/main/scala/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/spark/storage/MemoryStore.scala
@@ -17,7 +17,6 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true)
private var currentMemory = 0L
-
// Object used to ensure that only one thread is putting blocks and if necessary, dropping
// blocks from the memory store.
private val putLock = new Object()
@@ -90,7 +89,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def remove(blockId: String) {
+ override def remove(blockId: String): Boolean = {
entries.synchronized {
val entry = entries.get(blockId)
if (entry != null) {
@@ -98,8 +97,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
currentMemory -= entry.size
logInfo("Block %s of size %d dropped from memory (free %d)".format(
blockId, entry.size, freeMemory))
+ true
} else {
- logWarning("Block " + blockId + " could not be removed as it does not exist")
+ false
}
}
}
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
index c497f03e0c..d1d1c61c1c 100644
--- a/core/src/main/scala/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/spark/storage/StorageLevel.scala
@@ -1,53 +1,60 @@
package spark.storage
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
/**
* Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
* whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory
* in a serialized format, and whether to replicate the RDD partitions on multiple nodes.
* The [[spark.storage.StorageLevel$]] singleton object contains some static constants for
- * commonly useful storage levels.
+ * commonly useful storage levels. To create your own storage level object, use the factor method
+ * of the singleton object (`StorageLevel(...)`).
*/
-class StorageLevel(
- var useDisk: Boolean,
- var useMemory: Boolean,
- var deserialized: Boolean,
- var replication: Int = 1)
+class StorageLevel private(
+ private var useDisk_ : Boolean,
+ private var useMemory_ : Boolean,
+ private var deserialized_ : Boolean,
+ private var replication_ : Int = 1)
extends Externalizable {
// TODO: Also add fields for caching priority, dataset ID, and flushing.
-
- def this(flags: Int, replication: Int) {
+ private def this(flags: Int, replication: Int) {
this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
}
def this() = this(false, true, false) // For deserialization
+ def useDisk = useDisk_
+ def useMemory = useMemory_
+ def deserialized = deserialized_
+ def replication = replication_
+
+ assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes")
+
override def clone(): StorageLevel = new StorageLevel(
this.useDisk, this.useMemory, this.deserialized, this.replication)
override def equals(other: Any): Boolean = other match {
case s: StorageLevel =>
- s.useDisk == useDisk &&
+ s.useDisk == useDisk &&
s.useMemory == useMemory &&
s.deserialized == deserialized &&
- s.replication == replication
+ s.replication == replication
case _ =>
false
}
-
+
def isValid = ((useMemory || useDisk) && (replication > 0))
def toInt: Int = {
var ret = 0
- if (useDisk) {
+ if (useDisk_) {
ret |= 4
}
- if (useMemory) {
+ if (useMemory_) {
ret |= 2
}
- if (deserialized) {
+ if (deserialized_) {
ret |= 1
}
return ret
@@ -55,21 +62,27 @@ class StorageLevel(
override def writeExternal(out: ObjectOutput) {
out.writeByte(toInt)
- out.writeByte(replication)
+ out.writeByte(replication_)
}
override def readExternal(in: ObjectInput) {
val flags = in.readByte()
- useDisk = (flags & 4) != 0
- useMemory = (flags & 2) != 0
- deserialized = (flags & 1) != 0
- replication = in.readByte()
+ useDisk_ = (flags & 4) != 0
+ useMemory_ = (flags & 2) != 0
+ deserialized_ = (flags & 1) != 0
+ replication_ = in.readByte()
}
+ @throws(classOf[IOException])
+ private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this)
+
override def toString: String =
"StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
+
+ override def hashCode(): Int = toInt * 41 + replication
}
+
object StorageLevel {
val NONE = new StorageLevel(false, false, false)
val DISK_ONLY = new StorageLevel(true, false, false)
@@ -82,4 +95,31 @@ object StorageLevel {
val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2)
val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false)
val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2)
+
+ /** Create a new StorageLevel object */
+ def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) =
+ getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication))
+
+ /** Create a new StorageLevel object from its integer representation */
+ def apply(flags: Int, replication: Int) =
+ getCachedStorageLevel(new StorageLevel(flags, replication))
+
+ /** Read StorageLevel object from ObjectInput stream */
+ def apply(in: ObjectInput) = {
+ val obj = new StorageLevel()
+ obj.readExternal(in)
+ getCachedStorageLevel(obj)
+ }
+
+ private[spark]
+ val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()
+
+ private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = {
+ if (storageLevelCache.containsKey(level)) {
+ storageLevelCache.get(level)
+ } else {
+ storageLevelCache.put(level, level)
+ level
+ }
+ }
}
diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala
index 5bb5a29cc4..689f07b969 100644
--- a/core/src/main/scala/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/spark/storage/ThreadingTest.scala
@@ -58,8 +58,10 @@ private[spark] object ThreadingTest {
val startTime = System.currentTimeMillis()
manager.get(blockId) match {
case Some(retrievedBlock) =>
- assert(retrievedBlock.toList.asInstanceOf[List[Int]] == block.toList, "Block " + blockId + " did not match")
- println("Got block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms")
+ assert(retrievedBlock.toList.asInstanceOf[List[Int]] == block.toList,
+ "Block " + blockId + " did not match")
+ println("Got block " + blockId + " in " +
+ (System.currentTimeMillis - startTime) + " ms")
case None =>
assert(false, "Block " + blockId + " could not be retrieved")
}
@@ -73,7 +75,9 @@ private[spark] object ThreadingTest {
System.setProperty("spark.kryoserializer.buffer.mb", "1")
val actorSystem = ActorSystem("test")
val serializer = new KryoSerializer
- val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true)
+ val masterIp: String = System.getProperty("spark.master.host", "localhost")
+ val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
+ val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, masterIp, masterPort)
val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer, 1024 * 1024)
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
@@ -86,6 +90,7 @@ private[spark] object ThreadingTest {
actorSystem.shutdown()
actorSystem.awaitTermination()
println("Everything stopped.")
- println("It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.")
+ println(
+ "It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.")
}
}
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index e67cb0336d..fbd0ff46bf 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -32,6 +32,7 @@ private[spark] object AkkaUtils {
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
akka.actor.provider = "akka.remote.RemoteActorRefProvider"
akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
+ akka.remote.log-remote-lifecycle-events = on
akka.remote.netty.hostname = "%s"
akka.remote.netty.port = %d
akka.remote.netty.connection-timeout = %ds
diff --git a/core/src/main/scala/spark/util/IdGenerator.scala b/core/src/main/scala/spark/util/IdGenerator.scala
new file mode 100644
index 0000000000..b6e309fe1a
--- /dev/null
+++ b/core/src/main/scala/spark/util/IdGenerator.scala
@@ -0,0 +1,14 @@
+package spark.util
+
+import java.util.concurrent.atomic.AtomicInteger
+
+/**
+ * A util used to get a unique generation ID. This is a wrapper around Java's
+ * AtomicInteger. An example usage is in BlockManager, where each BlockManager
+ * instance would start an Akka actor and we use this utility to assign the Akka
+ * actors unique names.
+ */
+private[spark] class IdGenerator {
+ private var id = new AtomicInteger
+ def next: Int = id.incrementAndGet
+}
diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala
new file mode 100644
index 0000000000..139e21d09e
--- /dev/null
+++ b/core/src/main/scala/spark/util/MetadataCleaner.scala
@@ -0,0 +1,44 @@
+package spark.util
+
+import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors}
+import java.util.{TimerTask, Timer}
+import spark.Logging
+
+
+class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
+
+ val delaySeconds = MetadataCleaner.getDelaySeconds
+ val periodSeconds = math.max(10, delaySeconds / 10)
+ val timer = new Timer(name + " cleanup timer", true)
+
+ val task = new TimerTask {
+ def run() {
+ try {
+ if (delaySeconds > 0) {
+ cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000))
+ logInfo("Ran metadata cleaner for " + name)
+ }
+ } catch {
+ case e: Exception => logError("Error running cleanup task for " + name, e)
+ }
+ }
+ }
+
+ if (periodSeconds > 0) {
+ logInfo(
+ "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and "
+ + "period of " + periodSeconds + " secs")
+ timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000)
+ }
+
+ def cancel() {
+ timer.cancel()
+ }
+}
+
+
+object MetadataCleaner {
+ def getDelaySeconds = (System.getProperty("spark.cleaner.delay", "-100").toDouble * 60).toInt
+ def setDelaySeconds(delay: Long) { System.setProperty("spark.cleaner.delay", delay.toString) }
+}
+
diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala
new file mode 100644
index 0000000000..e3f00ea8c7
--- /dev/null
+++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala
@@ -0,0 +1,62 @@
+package spark.util
+
+import scala.annotation.tailrec
+
+import java.io.OutputStream
+import java.util.concurrent.TimeUnit._
+
+class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream {
+ val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS)
+ val CHUNK_SIZE = 8192
+ var lastSyncTime = System.nanoTime
+ var bytesWrittenSinceSync: Long = 0
+
+ override def write(b: Int) {
+ waitToWrite(1)
+ out.write(b)
+ }
+
+ override def write(bytes: Array[Byte]) {
+ write(bytes, 0, bytes.length)
+ }
+
+ @tailrec
+ override final def write(bytes: Array[Byte], offset: Int, length: Int) {
+ val writeSize = math.min(length - offset, CHUNK_SIZE)
+ if (writeSize > 0) {
+ waitToWrite(writeSize)
+ out.write(bytes, offset, writeSize)
+ write(bytes, offset + writeSize, length)
+ }
+ }
+
+ override def flush() {
+ out.flush()
+ }
+
+ override def close() {
+ out.close()
+ }
+
+ @tailrec
+ private def waitToWrite(numBytes: Int) {
+ val now = System.nanoTime
+ val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS)
+ val rate = bytesWrittenSinceSync.toDouble / elapsedSecs
+ if (rate < bytesPerSec) {
+ // It's okay to write; just update some variables and return
+ bytesWrittenSinceSync += numBytes
+ if (now > lastSyncTime + SYNC_INTERVAL) {
+ // Sync interval has passed; let's resync
+ lastSyncTime = now
+ bytesWrittenSinceSync = numBytes
+ }
+ } else {
+ // Calculate how much time we should sleep to bring ourselves to the desired rate.
+ // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala)
+ val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS)
+ if (sleepTime > 0) Thread.sleep(sleepTime)
+ waitToWrite(numBytes)
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
new file mode 100644
index 0000000000..bb7c5c01c8
--- /dev/null
+++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
@@ -0,0 +1,93 @@
+package spark.util
+
+import java.util.concurrent.ConcurrentHashMap
+import scala.collection.JavaConversions
+import scala.collection.mutable.Map
+
+/**
+ * This is a custom implementation of scala.collection.mutable.Map which stores the insertion
+ * time stamp along with each key-value pair. Key-value pairs that are older than a particular
+ * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in
+ * replacement of scala.collection.mutable.HashMap.
+ */
+class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging {
+ val internalMap = new ConcurrentHashMap[A, (B, Long)]()
+
+ def get(key: A): Option[B] = {
+ val value = internalMap.get(key)
+ if (value != null) Some(value._1) else None
+ }
+
+ def iterator: Iterator[(A, B)] = {
+ val jIterator = internalMap.entrySet().iterator()
+ JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1))
+ }
+
+ override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = {
+ val newMap = new TimeStampedHashMap[A, B1]
+ newMap.internalMap.putAll(this.internalMap)
+ newMap.internalMap.put(kv._1, (kv._2, currentTime))
+ newMap
+ }
+
+ override def - (key: A): Map[A, B] = {
+ val newMap = new TimeStampedHashMap[A, B]
+ newMap.internalMap.putAll(this.internalMap)
+ newMap.internalMap.remove(key)
+ newMap
+ }
+
+ override def += (kv: (A, B)): this.type = {
+ internalMap.put(kv._1, (kv._2, currentTime))
+ this
+ }
+
+ override def -= (key: A): this.type = {
+ internalMap.remove(key)
+ this
+ }
+
+ override def update(key: A, value: B) {
+ this += ((key, value))
+ }
+
+ override def apply(key: A): B = {
+ val value = internalMap.get(key)
+ if (value == null) throw new NoSuchElementException()
+ value._1
+ }
+
+ override def filter(p: ((A, B)) => Boolean): Map[A, B] = {
+ JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p)
+ }
+
+ override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()
+
+ override def size(): Int = internalMap.size()
+
+ override def foreach[U](f: ((A, B)) => U): Unit = {
+ val iterator = internalMap.entrySet().iterator()
+ while(iterator.hasNext) {
+ val entry = iterator.next()
+ val kv = (entry.getKey, entry.getValue._1)
+ f(kv)
+ }
+ }
+
+ /**
+ * Removes old key-value pairs that have timestamp earlier than `threshTime`
+ */
+ def clearOldValues(threshTime: Long) {
+ val iterator = internalMap.entrySet().iterator()
+ while(iterator.hasNext) {
+ val entry = iterator.next()
+ if (entry.getValue._2 < threshTime) {
+ logDebug("Removing key " + entry.getKey)
+ iterator.remove()
+ }
+ }
+ }
+
+ private def currentTime: Long = System.currentTimeMillis()
+
+}
diff --git a/core/src/main/scala/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/spark/util/TimeStampedHashSet.scala
new file mode 100644
index 0000000000..5f1cc93752
--- /dev/null
+++ b/core/src/main/scala/spark/util/TimeStampedHashSet.scala
@@ -0,0 +1,69 @@
+package spark.util
+
+import scala.collection.mutable.Set
+import scala.collection.JavaConversions
+import java.util.concurrent.ConcurrentHashMap
+
+
+class TimeStampedHashSet[A] extends Set[A] {
+ val internalMap = new ConcurrentHashMap[A, Long]()
+
+ def contains(key: A): Boolean = {
+ internalMap.contains(key)
+ }
+
+ def iterator: Iterator[A] = {
+ val jIterator = internalMap.entrySet().iterator()
+ JavaConversions.asScalaIterator(jIterator).map(_.getKey)
+ }
+
+ override def + (elem: A): Set[A] = {
+ val newSet = new TimeStampedHashSet[A]
+ newSet ++= this
+ newSet += elem
+ newSet
+ }
+
+ override def - (elem: A): Set[A] = {
+ val newSet = new TimeStampedHashSet[A]
+ newSet ++= this
+ newSet -= elem
+ newSet
+ }
+
+ override def += (key: A): this.type = {
+ internalMap.put(key, currentTime)
+ this
+ }
+
+ override def -= (key: A): this.type = {
+ internalMap.remove(key)
+ this
+ }
+
+ override def empty: Set[A] = new TimeStampedHashSet[A]()
+
+ override def size(): Int = internalMap.size()
+
+ override def foreach[U](f: (A) => U): Unit = {
+ val iterator = internalMap.entrySet().iterator()
+ while(iterator.hasNext) {
+ f(iterator.next.getKey)
+ }
+ }
+
+ /**
+ * Removes old values that have timestamp earlier than `threshTime`
+ */
+ def clearOldValues(threshTime: Long) {
+ val iterator = internalMap.entrySet().iterator()
+ while(iterator.hasNext) {
+ val entry = iterator.next()
+ if (entry.getValue < threshTime) {
+ iterator.remove()
+ }
+ }
+ }
+
+ private def currentTime: Long = System.currentTimeMillis()
+}
diff --git a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
index c32ab30401..be69e9bf02 100644
--- a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
@@ -7,6 +7,7 @@
<a href="@worker.webUiAddress">@worker.id</href>
</td>
<td>@{worker.host}:@{worker.port}</td>
+ <td>@worker.state</td>
<td>@worker.cores (@worker.coresUsed Used)</td>
<td>@{Utils.memoryMegabytesToString(worker.memory)}
(@{Utils.memoryMegabytesToString(worker.memoryUsed)} Used)</td>
diff --git a/core/src/main/twirl/spark/deploy/master/worker_table.scala.html b/core/src/main/twirl/spark/deploy/master/worker_table.scala.html
index fad1af41dc..b249411a62 100644
--- a/core/src/main/twirl/spark/deploy/master/worker_table.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/worker_table.scala.html
@@ -5,6 +5,7 @@
<tr>
<th>ID</th>
<th>Address</th>
+ <th>State</th>
<th>Cores</th>
<th>Memory</th>
</tr>
diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties
index 4c99e450bc..6ec89c0184 100644
--- a/core/src/test/resources/log4j.properties
+++ b/core/src/test/resources/log4j.properties
@@ -1,8 +1,8 @@
-# Set everything to be logged to the console
+# Set everything to be logged to the file core/target/unit-tests.log
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
-log4j.appender.file.file=spark-tests.log
+log4j.appender.file.file=core/target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
diff --git a/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala b/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala
deleted file mode 100644
index 37cafd1e8e..0000000000
--- a/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-package spark
-
-import org.scalatest.FunSuite
-import org.scalatest.PrivateMethodTester
-import org.scalatest.matchers.ShouldMatchers
-
-// TODO: Replace this with a test of MemoryStore
-class BoundedMemoryCacheSuite extends FunSuite with PrivateMethodTester with ShouldMatchers {
- test("constructor test") {
- val cache = new BoundedMemoryCache(60)
- expect(60)(cache.getCapacity)
- }
-
- test("caching") {
- // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
- val oldArch = System.setProperty("os.arch", "amd64")
- val oldOops = System.setProperty("spark.test.useCompressedOops", "true")
- val initialize = PrivateMethod[Unit]('initialize)
- SizeEstimator invokePrivate initialize()
-
- val cache = new BoundedMemoryCache(60) {
- //TODO sorry about this, but there is not better way how to skip 'cacheTracker.dropEntry'
- override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
- logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
- }
- }
-
- // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length
- // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6.
- // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html
- // Work around to check for either.
-
- //should be OK
- cache.put("1", 0, "Meh") should (equal (CachePutSuccess(56)) or equal (CachePutSuccess(48)))
-
- //we cannot add this to cache (there is not enough space in cache) & we cannot evict the only value from
- //cache because it's from the same dataset
- expect(CachePutFailure())(cache.put("1", 1, "Meh"))
-
- //should be OK, dataset '1' can be evicted from cache
- cache.put("2", 0, "Meh") should (equal (CachePutSuccess(56)) or equal (CachePutSuccess(48)))
-
- //should fail, cache should obey it's capacity
- expect(CachePutFailure())(cache.put("3", 0, "Very_long_and_useless_string"))
-
- if (oldArch != null) {
- System.setProperty("os.arch", oldArch)
- } else {
- System.clearProperty("os.arch")
- }
-
- if (oldOops != null) {
- System.setProperty("spark.test.useCompressedOops", oldOops)
- } else {
- System.clearProperty("spark.test.useCompressedOops")
- }
- }
-}
diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala
deleted file mode 100644
index 467605981b..0000000000
--- a/core/src/test/scala/spark/CacheTrackerSuite.scala
+++ /dev/null
@@ -1,131 +0,0 @@
-package spark
-
-import org.scalatest.FunSuite
-
-import scala.collection.mutable.HashMap
-
-import akka.actor._
-import akka.dispatch._
-import akka.pattern.ask
-import akka.remote._
-import akka.util.Duration
-import akka.util.Timeout
-import akka.util.duration._
-
-class CacheTrackerSuite extends FunSuite {
- // Send a message to an actor and wait for a reply, in a blocking manner
- private def ask(actor: ActorRef, message: Any): Any = {
- try {
- val timeout = 10.seconds
- val future = actor.ask(message)(timeout)
- return Await.result(future, timeout)
- } catch {
- case e: Exception =>
- throw new SparkException("Error communicating with actor", e)
- }
- }
-
- test("CacheTrackerActor slave initialization & cache status") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 0L)))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("RegisterRDD") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 3)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(getCacheLocations(tracker) === Map(1 -> List(Nil, Nil, Nil), 2 -> List(Nil)))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("AddedToCache") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 2)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true)
- assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true)
- assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L)))
-
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("DroppedFromCache") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 2)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true)
- assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true)
- assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
-
- assert(ask(tracker, DroppedFromCache(1, 1, "host001", 2L << 11)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 68608L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- /**
- * Helper function to get cacheLocations from CacheTracker
- */
- def getCacheLocations(tracker: ActorRef): HashMap[Int, List[List[String]]] = {
- val answer = ask(tracker, GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
- answer.map { case (i, arr) => (i, arr.toList) }
- }
-}
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
new file mode 100644
index 0000000000..51573254ca
--- /dev/null
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -0,0 +1,357 @@
+package spark
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import java.io.File
+import spark.rdd._
+import spark.SparkContext._
+import storage.StorageLevel
+
+class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
+ initLogging()
+
+ var sc: SparkContext = _
+ var checkpointDir: File = _
+ val partitioner = new HashPartitioner(2)
+
+ before {
+ checkpointDir = File.createTempFile("temp", "")
+ checkpointDir.delete()
+
+ sc = new SparkContext("local", "test")
+ sc.setCheckpointDir(checkpointDir.toString)
+ }
+
+ after {
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.master.port")
+
+ if (checkpointDir != null) {
+ checkpointDir.delete()
+ }
+ }
+
+ test("RDDs with one-to-one dependencies") {
+ testCheckpointing(_.map(x => x.toString))
+ testCheckpointing(_.flatMap(x => 1 to x))
+ testCheckpointing(_.filter(_ % 2 == 0))
+ testCheckpointing(_.sample(false, 0.5, 0))
+ testCheckpointing(_.glom())
+ testCheckpointing(_.mapPartitions(_.map(_.toString)))
+ testCheckpointing(r => new MapPartitionsWithSplitRDD(r,
+ (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false ))
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
+ testCheckpointing(_.pipe(Seq("cat")))
+ }
+
+ test("ParallelCollection") {
+ val parCollection = sc.makeRDD(1 to 4, 2)
+ val numSplits = parCollection.splits.size
+ parCollection.checkpoint()
+ assert(parCollection.dependencies === Nil)
+ val result = parCollection.collect()
+ assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
+ assert(parCollection.dependencies != Nil)
+ assert(parCollection.splits.length === numSplits)
+ assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList)
+ assert(parCollection.collect() === result)
+ }
+
+ test("BlockRDD") {
+ val blockId = "id"
+ val blockManager = SparkEnv.get.blockManager
+ blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
+ val blockRDD = new BlockRDD[String](sc, Array(blockId))
+ val numSplits = blockRDD.splits.size
+ blockRDD.checkpoint()
+ val result = blockRDD.collect()
+ assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
+ assert(blockRDD.dependencies != Nil)
+ assert(blockRDD.splits.length === numSplits)
+ assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList)
+ assert(blockRDD.collect() === result)
+ }
+
+ test("ShuffledRDD") {
+ testCheckpointing(rdd => {
+ // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
+ new ShuffledRDD(rdd.map(x => (x % 2, 1)), partitioner)
+ })
+ }
+
+ test("UnionRDD") {
+ def otherRDD = sc.makeRDD(1 to 10, 1)
+
+ // Test whether the size of UnionRDDSplits reduce in size after parent RDD is checkpointed.
+ // Current implementation of UnionRDD has transient reference to parent RDDs,
+ // so only the splits will reduce in serialized size, not the RDD.
+ testCheckpointing(_.union(otherRDD), false, true)
+ testParentCheckpointing(_.union(otherRDD), false, true)
+ }
+
+ test("CartesianRDD") {
+ def otherRDD = sc.makeRDD(1 to 10, 1)
+ testCheckpointing(new CartesianRDD(sc, _, otherRDD))
+
+ // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
+ // Current implementation of CoalescedRDDSplit has transient reference to parent RDD,
+ // so only the RDD will reduce in serialized size, not the splits.
+ testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false)
+
+ // Test that the CartesianRDD updates parent splits (CartesianRDD.s1/s2) after
+ // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
+ // Note that this test is very specific to the current implementation of CartesianRDD.
+ val ones = sc.makeRDD(1 to 100, 10).map(x => x)
+ ones.checkpoint // checkpoint that MappedRDD
+ val cartesian = new CartesianRDD(sc, ones, ones)
+ val splitBeforeCheckpoint =
+ serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
+ cartesian.count() // do the checkpointing
+ val splitAfterCheckpoint =
+ serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
+ assert(
+ (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) &&
+ (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2),
+ "CartesianRDD.parents not updated after parent RDD checkpointed"
+ )
+ }
+
+ test("CoalescedRDD") {
+ testCheckpointing(new CoalescedRDD(_, 2))
+
+ // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
+ // Current implementation of CoalescedRDDSplit has transient reference to parent RDD,
+ // so only the RDD will reduce in serialized size, not the splits.
+ testParentCheckpointing(new CoalescedRDD(_, 2), true, false)
+
+ // Test that the CoalescedRDDSplit updates parent splits (CoalescedRDDSplit.parents) after
+ // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
+ // Note that this test is very specific to the current implementation of CoalescedRDDSplits
+ val ones = sc.makeRDD(1 to 100, 10).map(x => x)
+ ones.checkpoint // checkpoint that MappedRDD
+ val coalesced = new CoalescedRDD(ones, 2)
+ val splitBeforeCheckpoint =
+ serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
+ coalesced.count() // do the checkpointing
+ val splitAfterCheckpoint =
+ serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
+ assert(
+ splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head,
+ "CoalescedRDDSplit.parents not updated after parent RDD checkpointed"
+ )
+ }
+
+ test("CoGroupedRDD") {
+ val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD()
+ testCheckpointing(rdd => {
+ CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner)
+ }, false, true)
+
+ val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD()
+ testParentCheckpointing(rdd => {
+ CheckpointSuite.cogroup(
+ longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner)
+ }, false, true)
+ }
+
+ test("ZippedRDD") {
+ testCheckpointing(
+ rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+
+ // Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed
+ // Current implementation of ZippedRDDSplit has transient references to parent RDDs,
+ // so only the RDD will reduce in serialized size, not the splits.
+ testParentCheckpointing(
+ rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+
+ }
+
+ /**
+ * Test checkpointing of the final RDD generated by the given operation. By default,
+ * this method tests whether the size of serialized RDD has reduced after checkpointing or not.
+ * It can also test whether the size of serialized RDD splits has reduced after checkpointing or
+ * not, but this is not done by default as usually the splits do not refer to any RDD and
+ * therefore never store the lineage.
+ */
+ def testCheckpointing[U: ClassManifest](
+ op: (RDD[Int]) => RDD[U],
+ testRDDSize: Boolean = true,
+ testRDDSplitSize: Boolean = false
+ ) {
+ // Generate the final RDD using given RDD operation
+ val baseRDD = generateLongLineageRDD
+ val operatedRDD = op(baseRDD)
+ val parentRDD = operatedRDD.dependencies.headOption.orNull
+ val rddType = operatedRDD.getClass.getSimpleName
+ val numSplits = operatedRDD.splits.length
+
+ // Find serialized sizes before and after the checkpoint
+ val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ operatedRDD.checkpoint()
+ val result = operatedRDD.collect()
+ val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+
+ // Test whether the checkpoint file has been created
+ assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result)
+
+ // Test whether dependencies have been changed from its earlier parent RDD
+ assert(operatedRDD.dependencies.head.rdd != parentRDD)
+
+ // Test whether the splits have been changed to the new Hadoop splits
+ assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.getSplits.toList)
+
+ // Test whether the number of splits is same as before
+ assert(operatedRDD.splits.length === numSplits)
+
+ // Test whether the data in the checkpointed RDD is same as original
+ assert(operatedRDD.collect() === result)
+
+ // Test whether serialized size of the RDD has reduced. If the RDD
+ // does not have any dependency to another RDD (e.g., ParallelCollection,
+ // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing.
+ if (testRDDSize) {
+ logInfo("Size of " + rddType +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
+ assert(
+ rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+ "Size of " + rddType + " did not reduce after checkpointing " +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+ )
+ }
+
+ // Test whether serialized size of the splits has reduced. If the splits
+ // do not have any non-transient reference to another RDD or another RDD's splits, it
+ // does not refer to a lineage and therefore may not reduce in size after checkpointing.
+ // However, if the original splits before checkpointing do refer to a parent RDD, the splits
+ // must be forgotten after checkpointing (to remove all reference to parent RDDs) and
+ // replaced with the HadoopSplits of the checkpointed RDD.
+ if (testRDDSplitSize) {
+ logInfo("Size of " + rddType + " splits "
+ + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]")
+ assert(
+ splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
+ "Size of " + rddType + " splits did not reduce after checkpointing " +
+ "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
+ )
+ }
+ }
+
+ /**
+ * Test whether checkpointing of the parent of the generated RDD also
+ * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent
+ * RDDs splits. So even if the parent RDD is checkpointed and its splits changed,
+ * this RDD will remember the splits and therefore potentially the whole lineage.
+ */
+ def testParentCheckpointing[U: ClassManifest](
+ op: (RDD[Int]) => RDD[U],
+ testRDDSize: Boolean,
+ testRDDSplitSize: Boolean
+ ) {
+ // Generate the final RDD using given RDD operation
+ val baseRDD = generateLongLineageRDD
+ val operatedRDD = op(baseRDD)
+ val parentRDD = operatedRDD.dependencies.head.rdd
+ val rddType = operatedRDD.getClass.getSimpleName
+ val parentRDDType = parentRDD.getClass.getSimpleName
+
+ // Find serialized sizes before and after the checkpoint
+ val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one
+ val result = operatedRDD.collect()
+ val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+
+ // Test whether the data in the checkpointed RDD is same as original
+ assert(operatedRDD.collect() === result)
+
+ // Test whether serialized size of the RDD has reduced because of its parent being
+ // checkpointed. If this RDD or its parent RDD do not have any dependency
+ // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may
+ // not reduce in size after checkpointing.
+ if (testRDDSize) {
+ assert(
+ rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+ "Size of " + rddType + " did not reduce after parent checkpointing parent " + parentRDDType +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+ )
+ }
+
+ // Test whether serialized size of the splits has reduced because of its parent being
+ // checkpointed. If the splits do not have any non-transient reference to another RDD
+ // or another RDD's splits, it does not refer to a lineage and therefore may not reduce
+ // in size after checkpointing. However, if the splits do refer to the *splits* of a parent
+ // RDD, then these splits must update reference to the parent RDD splits as the parent RDD's
+ // splits must have changed after checkpointing.
+ if (testRDDSplitSize) {
+ assert(
+ splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
+ "Size of " + rddType + " splits did not reduce after checkpointing parent " + parentRDDType +
+ "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
+ )
+ }
+
+ }
+
+ /**
+ * Generate an RDD with a long lineage of one-to-one dependencies.
+ */
+ def generateLongLineageRDD(): RDD[Int] = {
+ var rdd = sc.makeRDD(1 to 100, 4)
+ for (i <- 1 to 50) {
+ rdd = rdd.map(x => x + 1)
+ }
+ rdd
+ }
+
+ /**
+ * Generate an RDD with a long lineage specifically for CoGroupedRDD.
+ * A CoGroupedRDD can have a long lineage only one of its parents have a long lineage
+ * and narrow dependency with this RDD. This method generate such an RDD by a sequence
+ * of cogroups and mapValues which creates a long lineage of narrow dependencies.
+ */
+ def generateLongLineageRDDForCoGroupedRDD() = {
+ val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _)
+
+ def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
+
+ var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones)
+ for(i <- 1 to 10) {
+ cogrouped = cogrouped.mapValues(add).cogroup(ones)
+ }
+ cogrouped.mapValues(add)
+ }
+
+ /**
+ * Get serialized sizes of the RDD and its splits
+ */
+ def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
+ (Utils.serialize(rdd).size, Utils.serialize(rdd.splits).size)
+ }
+
+ /**
+ * Serialize and deserialize an object. This is useful to verify the objects
+ * contents after deserialization (e.g., the contents of an RDD split after
+ * it is sent to a slave along with a task)
+ */
+ def serializeDeserialize[T](obj: T): T = {
+ val bytes = Utils.serialize(obj)
+ Utils.deserialize[T](bytes)
+ }
+}
+
+
+object CheckpointSuite {
+ // This is a custom cogroup function that does not use mapValues like
+ // the PairRDDFunctions.cogroup()
+ def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = {
+ //println("First = " + first + ", second = " + second)
+ new CoGroupedRDD[K](
+ Seq(first.asInstanceOf[RDD[(_, _)]], second.asInstanceOf[RDD[(_, _)]]),
+ part
+ ).asInstanceOf[RDD[(K, Seq[Seq[V]])]]
+ }
+
+}
diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala
index 7c0334d957..dfa2de80e6 100644
--- a/core/src/test/scala/spark/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/spark/ClosureCleanerSuite.scala
@@ -47,6 +47,8 @@ object TestObject {
val nums = sc.parallelize(Array(1, 2, 3, 4))
val answer = nums.map(_ + x).reduce(_ + _)
sc.stop()
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.master.port")
return answer
}
}
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index cacc2796b6..0487e06d12 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -188,4 +188,73 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect()
assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE"))
}
+
+ test("recover from node failures") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ val data = sc.parallelize(Seq(true, true), 2)
+ assert(data.count === 2) // force executors to start
+ val masterId = SparkEnv.get.blockManager.blockManagerId
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ assert(data.map(failOnMarkedIdentity).collect.size === 2)
+ }
+
+ test("recover from repeated node failures during shuffle-map") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, false), 2)
+ assert(data.count === 2)
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2)
+ }
+ }
+
+ test("recover from repeated node failures during shuffle-reduce") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, true), 2)
+ assert(data.count === 2)
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ // This relies on mergeCombiners being used to perform the actual reduce for this
+ // test to actually be testing what it claims.
+ val grouped = data.map(x => x -> x).combineByKey(
+ x => x,
+ (x: Boolean, y: Boolean) => x,
+ (x: Boolean, y: Boolean) => failOnMarkedIdentity(x)
+ )
+ assert(grouped.collect.size === 1)
+ }
+ }
+}
+
+object DistributedSuite {
+ // Indicates whether this JVM is marked for failure.
+ var mark = false
+
+ // Set by test to remember if we are in the driver program so we can assert
+ // that we are not.
+ var amMaster = false
+
+ // Act like an identity function, but if the argument is true, set mark to true.
+ def markNodeIfIdentity(item: Boolean): Boolean = {
+ if (item) {
+ assert(!amMaster)
+ mark = true
+ }
+ item
+ }
+
+ // Act like an identity function, but if mark was set to true previously, fail,
+ // crashing the entire JVM.
+ def failOnMarkedIdentity(item: Boolean): Boolean = {
+ if (mark) {
+ System.exit(42)
+ }
+ item
+ }
}
diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala
new file mode 100644
index 0000000000..70a7c8bc2f
--- /dev/null
+++ b/core/src/test/scala/spark/DriverSuite.scala
@@ -0,0 +1,31 @@
+package spark
+
+import java.io.File
+
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.prop.TableDrivenPropertyChecks._
+import org.scalatest.time.SpanSugar._
+
+class DriverSuite extends FunSuite with Timeouts {
+ test("driver should exit after finishing") {
+ // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
+ val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]"))
+ forAll(masters) { (master: String) =>
+ failAfter(10 seconds) {
+ Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), new File(System.getenv("SPARK_HOME")))
+ }
+ }
+ }
+}
+
+/**
+ * Program that creates a Spark driver but doesn't call SparkContext.stop() or
+ * Sys.exit() after finishing.
+ */
+object DriverWithoutCleanup {
+ def main(args: Array[String]) {
+ val sc = new SparkContext(args(0), "DriverWithoutCleanup")
+ sc.parallelize(1 to 100, 4).count()
+ }
+} \ No newline at end of file
diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala
index b4283d9604..b9e1248829 100644
--- a/core/src/test/scala/spark/FileServerSuite.scala
+++ b/core/src/test/scala/spark/FileServerSuite.scala
@@ -9,8 +9,8 @@ import SparkContext._
class FileServerSuite extends FunSuite with BeforeAndAfter {
@transient var sc: SparkContext = _
- @transient var tmpFile : File = _
- @transient var testJarFile : File = _
+ @transient var tmpFile: File = _
+ @transient var testJarFile: File = _
before {
// Create a sample text file
@@ -40,7 +40,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -54,7 +55,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile((new File(tmpFile.toString)).toURL.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -83,7 +85,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 46a0b68f89..01351de4ae 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -131,6 +131,17 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void lookup() {
+ JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, String>("Apples", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Citrus")
+ ));
+ Assert.assertEquals(2, categories.lookup("Oranges").size());
+ Assert.assertEquals(2, categories.groupByKey().lookup("Oranges").get(0).size());
+ }
+
+ @Test
public void groupBy() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
Function<Integer, Boolean> isOdd = new Function<Integer, Boolean>() {
@@ -570,4 +581,91 @@ public class JavaAPISuite implements Serializable {
JavaPairRDD<Integer, Double> zipped = rdd.zip(doubles);
zipped.count();
}
+
+ @Test
+ public void accumulators() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+
+ final Accumulator<Integer> intAccum = sc.intAccumulator(10);
+ rdd.foreach(new VoidFunction<Integer>() {
+ public void call(Integer x) {
+ intAccum.add(x);
+ }
+ });
+ Assert.assertEquals((Integer) 25, intAccum.value());
+
+ final Accumulator<Double> doubleAccum = sc.doubleAccumulator(10.0);
+ rdd.foreach(new VoidFunction<Integer>() {
+ public void call(Integer x) {
+ doubleAccum.add((double) x);
+ }
+ });
+ Assert.assertEquals((Double) 25.0, doubleAccum.value());
+
+ // Try a custom accumulator type
+ AccumulatorParam<Float> floatAccumulatorParam = new AccumulatorParam<Float>() {
+ public Float addInPlace(Float r, Float t) {
+ return r + t;
+ }
+
+ public Float addAccumulator(Float r, Float t) {
+ return r + t;
+ }
+
+ public Float zero(Float initialValue) {
+ return 0.0f;
+ }
+ };
+
+ final Accumulator<Float> floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam);
+ rdd.foreach(new VoidFunction<Integer>() {
+ public void call(Integer x) {
+ floatAccum.add((float) x);
+ }
+ });
+ Assert.assertEquals((Float) 25.0f, floatAccum.value());
+
+ // Test the setValue method
+ floatAccum.setValue(5.0f);
+ Assert.assertEquals((Float) 5.0f, floatAccum.value());
+ }
+
+ @Test
+ public void keyBy() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2));
+ List<Tuple2<String, Integer>> s = rdd.keyBy(new Function<Integer, String>() {
+ public String call(Integer t) throws Exception {
+ return t.toString();
+ }
+ }).collect();
+ Assert.assertEquals(new Tuple2<String, Integer>("1", 1), s.get(0));
+ Assert.assertEquals(new Tuple2<String, Integer>("2", 2), s.get(1));
+ }
+
+ @Test
+ public void checkpointAndComputation() {
+ File tempDir = Files.createTempDir();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ sc.setCheckpointDir(tempDir.getAbsolutePath(), true);
+ Assert.assertEquals(false, rdd.isCheckpointed());
+ rdd.checkpoint();
+ rdd.count(); // Forces the DAG to cause a checkpoint
+ Assert.assertEquals(true, rdd.isCheckpointed());
+ Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect());
+ }
+
+ @Test
+ public void checkpointAndRestore() {
+ File tempDir = Files.createTempDir();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ sc.setCheckpointDir(tempDir.getAbsolutePath(), true);
+ Assert.assertEquals(false, rdd.isCheckpointed());
+ rdd.checkpoint();
+ rdd.count(); // Forces the DAG to cause a checkpoint
+ Assert.assertEquals(true, rdd.isCheckpointed());
+
+ Assert.assertTrue(rdd.getCheckpointFile().isPresent());
+ JavaRDD<Integer> recovered = sc.checkpointFile(rdd.getCheckpointFile().get());
+ Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
+ }
}
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index 5b4b198960..095f415978 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -1,12 +1,18 @@
package spark
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import akka.actor._
import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
+import spark.util.AkkaUtils
-class MapOutputTrackerSuite extends FunSuite {
+class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter {
+ after {
+ System.clearProperty("spark.master.port")
+ }
+
test("compressSize") {
assert(MapOutputTracker.compressSize(0L) === 0)
assert(MapOutputTracker.compressSize(1L) === 1)
@@ -41,13 +47,13 @@ class MapOutputTrackerSuite extends FunSuite {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000),
Array(compressedSize1000, compressedSize10000)))
- tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000),
Array(compressedSize10000, compressedSize1000)))
val statuses = tracker.getServerStatuses(10, 0)
- assert(statuses.toSeq === Seq((new BlockManagerId("hostA", 1000), size1000),
- (new BlockManagerId("hostB", 1000), size10000)))
+ assert(statuses.toSeq === Seq((BlockManagerId("hostA", 1000), size1000),
+ (BlockManagerId("hostB", 1000), size10000)))
tracker.stop()
}
@@ -59,18 +65,48 @@ class MapOutputTrackerSuite extends FunSuite {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
- tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
// As if we had two simulatenous fetch failures
- tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
- tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000))
// The remaining reduce task might try to grab the output dispite the shuffle failure;
// this should cause it to fail, and the scheduler will ignore the failure due to the
// stage already being aborted.
- intercept[Exception] { tracker.getServerStatuses(10, 1) }
+ intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) }
+ }
+
+ test("remote fetch") {
+ System.clearProperty("spark.master.host")
+ val (actorSystem, boundPort) =
+ AkkaUtils.createActorSystem("test", "localhost", 0)
+ System.setProperty("spark.master.port", boundPort.toString)
+ val masterTracker = new MapOutputTracker(actorSystem, true)
+ val slaveTracker = new MapOutputTracker(actorSystem, false)
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("hostA", 1000), Array(compressedSize1000)))
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+ Seq((BlockManagerId("hostA", 1000), size1000)))
+
+ masterTracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000))
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+
+ // failure should be cached
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
}
}
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index 3dadc7acec..eb3c8f238f 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -106,5 +106,31 @@ class PartitioningSuite extends FunSuite with BeforeAndAfter {
assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner)
+
+ assert(grouped2.map(_ => 1).partitioner === None)
+ assert(grouped2.mapValues(_ => 1).partitioner === grouped2.partitioner)
+ assert(grouped2.flatMapValues(_ => Seq(1)).partitioner === grouped2.partitioner)
+ assert(grouped2.filter(_._1 > 4).partitioner === grouped2.partitioner)
+ }
+
+ test("partitioning Java arrays should fail") {
+ sc = new SparkContext("local", "test")
+ val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x))
+ val arrPairs: RDD[(Array[Int], Int)] =
+ sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x))
+
+ assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array"))
+ // We can't catch all usages of arrays, since they might occur inside other collections:
+ //assert(fails { arrPairs.distinct() })
+ assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array"))
}
}
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index b3c820ed94..db217f8482 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -8,9 +8,9 @@ import spark.rdd.CoalescedRDD
import SparkContext._
class RDDSuite extends FunSuite with BeforeAndAfter {
-
+
var sc: SparkContext = _
-
+
after {
if (sc != null) {
sc.stop()
@@ -19,11 +19,15 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.master.port")
}
-
+
test("basic operations") {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
+ val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
+ assert(dups.distinct.count === 4)
+ assert(dups.distinct().collect === dups.distinct.collect)
+ assert(dups.distinct(2).collect === dups.distinct.collect)
assert(nums.reduce(_ + _) === 10)
assert(nums.fold(0)(_ + _) === 10)
assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
@@ -31,6 +35,8 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
+ assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4"))
+ assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4)))
val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
assert(partitionSums.collect().toList === List(3, 7))
@@ -70,10 +76,23 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
}
- test("checkpointing") {
+ test("basic checkpointing") {
+ import java.io.File
+ val checkpointDir = File.createTempFile("temp", "")
+ checkpointDir.delete()
+
sc = new SparkContext("local", "test")
- val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).flatMap(x => 1 to x).checkpoint()
- assert(rdd.collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
+ sc.setCheckpointDir(checkpointDir.toString)
+ val parCollection = sc.makeRDD(1 to 4)
+ val flatMappedRDD = parCollection.flatMap(x => 1 to x)
+ flatMappedRDD.checkpoint()
+ assert(flatMappedRDD.dependencies.head.rdd == parCollection)
+ val result = flatMappedRDD.collect()
+ Thread.sleep(1000)
+ assert(flatMappedRDD.dependencies.head.rdd != parCollection)
+ assert(flatMappedRDD.collect() === result)
+
+ checkpointDir.deleteOnExit()
}
test("basic caching") {
@@ -84,6 +103,29 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(rdd.collect().toList === List(1, 2, 3, 4))
}
+ test("caching with failures") {
+ sc = new SparkContext("local", "test")
+ val onlySplit = new Split { override def index: Int = 0 }
+ var shouldFail = true
+ val rdd = new RDD[Int](sc, Nil) {
+ override def getSplits: Array[Split] = Array(onlySplit)
+ override val getDependencies = List[Dependency[_]]()
+ override def compute(split: Split, context: TaskContext): Iterator[Int] = {
+ if (shouldFail) {
+ throw new Exception("injected failure")
+ } else {
+ return Array(1, 2, 3, 4).iterator
+ }
+ }
+ }.cache()
+ val thrown = intercept[Exception]{
+ rdd.collect()
+ }
+ assert(thrown.getMessage.contains("injected failure"))
+ shouldFail = false
+ assert(rdd.collect().toList === List(1, 2, 3, 4))
+ }
+
test("coalesced RDDs") {
sc = new SparkContext("local", "test")
val data = sc.parallelize(1 to 10, 10)
@@ -94,8 +136,8 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10)))
// Check that the narrow dependency is also specified correctly
- assert(coalesced1.dependencies.head.getParents(0).toList === List(0, 1, 2, 3, 4))
- assert(coalesced1.dependencies.head.getParents(1).toList === List(5, 6, 7, 8, 9))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9))
val coalesced2 = new CoalescedRDD(data, 3)
assert(coalesced2.collect().toList === (1 to 10).toList)
@@ -121,7 +163,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
val zipped = nums.zip(nums.map(_ + 1.0))
assert(zipped.glom().map(_.toList).collect().toList ===
List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0))))
-
+
intercept[IllegalArgumentException] {
nums.zip(sc.parallelize(1 to 4, 1)).collect()
}
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 8170100f1d..bebb8ebe86 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -216,6 +216,13 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
// Test that a shuffle on the file works, because this used to be a bug
assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
}
+
+ test("keys and values") {
+ sc = new SparkContext("local", "test")
+ val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
+ assert(rdd.keys.collect().toList === List(1, 2))
+ assert(rdd.values.collect().toList === List("a", "b"))
+ }
}
object ShuffleSuite {
diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala
index 17f366212b..e235ef2f67 100644
--- a/core/src/test/scala/spark/SizeEstimatorSuite.scala
+++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala
@@ -3,7 +3,6 @@ package spark
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfterAll
import org.scalatest.PrivateMethodTester
-import org.scalatest.matchers.ShouldMatchers
class DummyClass1 {}
@@ -20,8 +19,17 @@ class DummyClass4(val d: DummyClass3) {
val x: Int = 0
}
+object DummyString {
+ def apply(str: String) : DummyString = new DummyString(str.toArray)
+}
+class DummyString(val arr: Array[Char]) {
+ override val hashCode: Int = 0
+ // JDK-7 has an extra hash32 field http://hg.openjdk.java.net/jdk7u/jdk7u6/jdk/rev/11987e85555f
+ @transient val hash32: Int = 0
+}
+
class SizeEstimatorSuite
- extends FunSuite with BeforeAndAfterAll with PrivateMethodTester with ShouldMatchers {
+ extends FunSuite with BeforeAndAfterAll with PrivateMethodTester {
var oldArch: String = _
var oldOops: String = _
@@ -45,15 +53,13 @@ class SizeEstimatorSuite
expect(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3)))
}
- // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length.
- // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6.
- // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html
- // Work around to check for either.
+ // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
+ // (Sun vs IBM). Use a DummyString class to make tests deterministic.
test("strings") {
- SizeEstimator.estimate("") should (equal (48) or equal (40))
- SizeEstimator.estimate("a") should (equal (56) or equal (48))
- SizeEstimator.estimate("ab") should (equal (56) or equal (48))
- SizeEstimator.estimate("abcdefgh") should (equal(64) or equal(56))
+ expect(40)(SizeEstimator.estimate(DummyString("")))
+ expect(48)(SizeEstimator.estimate(DummyString("a")))
+ expect(48)(SizeEstimator.estimate(DummyString("ab")))
+ expect(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
}
test("primitive arrays") {
@@ -105,18 +111,16 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- expect(40)(SizeEstimator.estimate(""))
- expect(48)(SizeEstimator.estimate("a"))
- expect(48)(SizeEstimator.estimate("ab"))
- expect(56)(SizeEstimator.estimate("abcdefgh"))
+ expect(40)(SizeEstimator.estimate(DummyString("")))
+ expect(48)(SizeEstimator.estimate(DummyString("a")))
+ expect(48)(SizeEstimator.estimate(DummyString("ab")))
+ expect(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
resetOrClear("os.arch", arch)
}
- // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length.
- // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6.
- // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html
- // Work around to check for either.
+ // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
+ // (Sun vs IBM). Use a DummyString class to make tests deterministic.
test("64-bit arch with no compressed oops") {
val arch = System.setProperty("os.arch", "amd64")
val oops = System.setProperty("spark.test.useCompressedOops", "false")
@@ -124,10 +128,10 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- SizeEstimator.estimate("") should (equal (64) or equal (56))
- SizeEstimator.estimate("a") should (equal (72) or equal (64))
- SizeEstimator.estimate("ab") should (equal (72) or equal (64))
- SizeEstimator.estimate("abcdefgh") should (equal (80) or equal (72))
+ expect(56)(SizeEstimator.estimate(DummyString("")))
+ expect(64)(SizeEstimator.estimate(DummyString("a")))
+ expect(64)(SizeEstimator.estimate(DummyString("ab")))
+ expect(72)(SizeEstimator.estimate(DummyString("abcdefgh")))
resetOrClear("os.arch", arch)
resetOrClear("spark.test.useCompressedOops", oops)
diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
new file mode 100644
index 0000000000..ba6f8b588f
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
@@ -0,0 +1,42 @@
+package spark.scheduler
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import spark.TaskContext
+import spark.RDD
+import spark.SparkContext
+import spark.Split
+
+class TaskContextSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.master.port")
+ }
+
+ test("Calls executeOnCompleteCallbacks after failure") {
+ var completed = false
+ sc = new SparkContext("local", "test")
+ val rdd = new RDD[String](sc, List()) {
+ override def getSplits = Array[Split](StubSplit(0))
+ override def compute(split: Split, context: TaskContext) = {
+ context.addOnCompleteCallback(() => completed = true)
+ sys.error("failed")
+ }
+ }
+ val func = (c: TaskContext, i: Iterator[String]) => i.next
+ val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0)
+ intercept[RuntimeException] {
+ task.run(0)
+ }
+ assert(completed === true)
+ }
+
+ case class StubSplit(val index: Int) extends Split
+} \ No newline at end of file
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index ad2253596d..a1aeb12f25 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -7,6 +7,10 @@ import akka.actor._
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.scalatest.PrivateMethodTester
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.matchers.ShouldMatchers._
+import org.scalatest.time.SpanSugar._
import spark.KryoSerializer
import spark.SizeEstimator
@@ -20,15 +24,16 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
var oldArch: String = null
var oldOops: String = null
var oldHeartBeat: String = null
-
- // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
+
+ // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
+ System.setProperty("spark.kryoserializer.buffer.mb", "1")
val serializer = new KryoSerializer
before {
actorSystem = ActorSystem("test")
- master = new BlockManagerMaster(actorSystem, true, true)
+ master = new BlockManagerMaster(actorSystem, true, true, "localhost", 7077)
- // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
+ // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
oldArch = System.setProperty("os.arch", "amd64")
oldOops = System.setProperty("spark.test.useCompressedOops", "true")
oldHeartBeat = System.setProperty("spark.storage.disableBlockManagerHeartBeat", "true")
@@ -63,7 +68,41 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
}
- test("manager-master interaction") {
+ test("StorageLevel object caching") {
+ val level1 = StorageLevel(false, false, false, 3)
+ val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1
+ val level3 = StorageLevel(false, false, false, 2) // this should return a different object
+ assert(level2 === level1, "level2 is not same as level1")
+ assert(level2.eq(level1), "level2 is not the same object as level1")
+ assert(level3 != level1, "level3 is same as level1")
+ val bytes1 = spark.Utils.serialize(level1)
+ val level1_ = spark.Utils.deserialize[StorageLevel](bytes1)
+ val bytes2 = spark.Utils.serialize(level2)
+ val level2_ = spark.Utils.deserialize[StorageLevel](bytes2)
+ assert(level1_ === level1, "Deserialized level1 not same as original level1")
+ assert(level1_.eq(level1), "Deserialized level1 not the same object as original level2")
+ assert(level2_ === level2, "Deserialized level2 not same as original level2")
+ assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1")
+ }
+
+ test("BlockManagerId object caching") {
+ val id1 = BlockManagerId("XXX", 1)
+ val id2 = BlockManagerId("XXX", 1) // this should return the same object as id1
+ val id3 = BlockManagerId("XXX", 2) // this should return a different object
+ assert(id2 === id1, "id2 is not same as id1")
+ assert(id2.eq(id1), "id2 is not the same object as id1")
+ assert(id3 != id1, "id3 is same as id1")
+ val bytes1 = spark.Utils.serialize(id1)
+ val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1)
+ val bytes2 = spark.Utils.serialize(id2)
+ val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2)
+ assert(id1_ === id1, "Deserialized id1 is not same as original id1")
+ assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1")
+ assert(id2_ === id2, "Deserialized id2 is not same as original id2")
+ assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1")
+ }
+
+ test("master + 1 manager interaction") {
store = new BlockManager(actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
@@ -74,83 +113,122 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY)
store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, false)
- // Checking whether blocks are in memory
+ // Checking whether blocks are in memory
assert(store.getSingle("a1") != None, "a1 was not in store")
assert(store.getSingle("a2") != None, "a2 was not in store")
assert(store.getSingle("a3") != None, "a3 was not in store")
// Checking whether master knows about the blocks or not
- assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1")
- assert(master.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2")
- assert(master.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3")
-
+ assert(master.getLocations("a1").size > 0, "master was not told about a1")
+ assert(master.getLocations("a2").size > 0, "master was not told about a2")
+ assert(master.getLocations("a3").size === 0, "master was told about a3")
+
// Drop a1 and a2 from memory; this should be reported back to the master
store.dropFromMemory("a1", null)
store.dropFromMemory("a2", null)
assert(store.getSingle("a1") === None, "a1 not removed from store")
assert(store.getSingle("a2") === None, "a2 not removed from store")
- assert(master.mustGetLocations(GetLocations("a1")).size === 0, "master did not remove a1")
- assert(master.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2")
+ assert(master.getLocations("a1").size === 0, "master did not remove a1")
+ assert(master.getLocations("a2").size === 0, "master did not remove a2")
}
- test("reregistration on heart beat") {
- val heartBeat = PrivateMethod[Unit]('heartBeat)
+ test("master + 2 managers interaction") {
store = new BlockManager(actorSystem, master, serializer, 2000)
+ store2 = new BlockManager(actorSystem, master, new KryoSerializer, 2000)
+
+ val peers = master.getPeers(store.blockManagerId, 1)
+ assert(peers.size === 1, "master did not return the other manager as a peer")
+ assert(peers.head === store2.blockManagerId, "peer returned by master is not the other manager")
+
val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2)
+ store2.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2)
+ assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1")
+ assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2")
+ }
- store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+ test("removing block") {
+ store = new BlockManager(actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
- assert(store.getSingle("a1") != None, "a1 was not in store")
- assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1")
+ // Putting a1, a2 and a3 in memory and telling master only about a1 and a2
+ store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, false)
- master.notifyADeadHost(store.blockManagerId.ip)
- assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a1 was not removed from master")
+ // Checking whether blocks are in memory and memory size
+ val memStatus = master.getMemoryStatus.head._2
+ assert(memStatus._1 == 2000L, "total memory " + memStatus._1 + " should equal 2000")
+ assert(memStatus._2 <= 1200L, "remaining memory " + memStatus._2 + " should <= 1200")
+ assert(store.getSingle("a1-to-remove") != None, "a1 was not in store")
+ assert(store.getSingle("a2-to-remove") != None, "a2 was not in store")
+ assert(store.getSingle("a3-to-remove") != None, "a3 was not in store")
- store invokePrivate heartBeat()
- assert(master.mustGetLocations(GetLocations("a1")).size > 0,
- "a1 was not reregistered with master")
+ // Checking whether master knows about the blocks or not
+ assert(master.getLocations("a1-to-remove").size > 0, "master was not told about a1")
+ assert(master.getLocations("a2-to-remove").size > 0, "master was not told about a2")
+ assert(master.getLocations("a3-to-remove").size === 0, "master was told about a3")
+
+ // Remove a1 and a2 and a3. Should be no-op for a3.
+ master.removeBlock("a1-to-remove")
+ master.removeBlock("a2-to-remove")
+ master.removeBlock("a3-to-remove")
+
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("a1-to-remove") should be (None)
+ master.getLocations("a1-to-remove") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("a2-to-remove") should be (None)
+ master.getLocations("a2-to-remove") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("a3-to-remove") should not be (None)
+ master.getLocations("a3-to-remove") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ val memStatus = master.getMemoryStatus.head._2
+ memStatus._1 should equal (2000L)
+ memStatus._2 should equal (2000L)
+ }
}
- test("reregistration on block update") {
+ test("reregistration on heart beat") {
+ val heartBeat = PrivateMethod[Unit]('heartBeat)
store = new BlockManager(actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
- val a2 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
- assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1")
+ assert(store.getSingle("a1") != None, "a1 was not in store")
+ assert(master.getLocations("a1").size > 0, "master was not told about a1")
master.notifyADeadHost(store.blockManagerId.ip)
- assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a1 was not removed from master")
-
- store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY)
+ assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
- assert(master.mustGetLocations(GetLocations("a1")).size > 0,
- "a1 was not reregistered with master")
- assert(master.mustGetLocations(GetLocations("a2")).size > 0,
- "master was not told about a2")
+ store invokePrivate heartBeat()
+ assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
}
- test("deregistration on duplicate") {
- val heartBeat = PrivateMethod[Unit]('heartBeat)
+ test("reregistration on block update") {
store = new BlockManager(actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
- assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1")
+ assert(master.getLocations("a1").size > 0, "master was not told about a1")
- store2 = new BlockManager(actorSystem, master, serializer, 2000)
-
- assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a1 was not removed from master")
+ master.notifyADeadHost(store.blockManagerId.ip)
+ assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
- store invokePrivate heartBeat()
-
- assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1")
+ store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY)
- store2 invokePrivate heartBeat()
-
- assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a2 was not removed from master")
+ assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
+ assert(master.getLocations("a2").size > 0, "master was not told about a2")
}
test("in-memory LRU storage") {
@@ -171,7 +249,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
assert(store.getSingle("a2") != None, "a2 was not in store")
assert(store.getSingle("a3") === None, "a3 was in store")
}
-
+
test("in-memory LRU storage with serialization") {
store = new BlockManager(actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala
new file mode 100644
index 0000000000..794063fb6d
--- /dev/null
+++ b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala
@@ -0,0 +1,23 @@
+package spark.util
+
+import org.scalatest.FunSuite
+import java.io.ByteArrayOutputStream
+import java.util.concurrent.TimeUnit._
+
+class RateLimitedOutputStreamSuite extends FunSuite {
+
+ private def benchmark[U](f: => U): Long = {
+ val start = System.nanoTime
+ f
+ System.nanoTime - start
+ }
+
+ test("write") {
+ val underlying = new ByteArrayOutputStream
+ val data = "X" * 41000
+ val stream = new RateLimitedOutputStream(underlying, 10000)
+ val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) }
+ assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4)
+ assert(underlying.toString("UTF-8") == data)
+ }
+}