author     Matei Zaharia <matei@eecs.berkeley.edu>   2012-06-07 00:25:47 -0700
committer  Matei Zaharia <matei@eecs.berkeley.edu>   2012-06-07 12:45:38 -0700
commit     63051dd2bcc4bf09d413ff7cf89a37967edc33ba (patch)
tree       4389cb7212c6c643ed0565551b4ad654d2218767
parent     7e1c97fc4b5a225e496ebd95c0ef6095dc4aeae9 (diff)
Merge in engine improvements from the Spark Streaming project, developed
jointly with Tathagata Das and Haoyuan Li. This commit imports those changes and ports them to Mesos 0.9, but it does not yet pass unit tests because several classes do not support a graceful stop().
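The graceful stop() mentioned in the message is visible further down in this commit: CacheTracker, for example, gains a stop() method that shuts down its tracker actor. For orientation, this is the shape of that hook as it appears in the CacheTracker.scala hunk below (quoted from that hunk, not new code); per the commit message, other services do not implement an equivalent yet.

    // From the CacheTracker.scala hunk later in this diff: the per-service shutdown hook.
    def stop() {
      trackerActor !! StopCacheTracker   // tell the tracker actor to exit
      registeredRddIds.clear()
      trackerActor = null
    }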
-rw-r--r--  bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala | 11
-rw-r--r--  core/src/main/scala/spark/BlockStoreShuffleFetcher.scala | 70
-rw-r--r--  core/src/main/scala/spark/BoundedMemoryCache.scala | 3
-rw-r--r--  core/src/main/scala/spark/CacheTracker.scala | 334
-rw-r--r--  core/src/main/scala/spark/CoGroupedRDD.scala | 6
-rw-r--r--  core/src/main/scala/spark/DAGScheduler.scala | 374
-rw-r--r--  core/src/main/scala/spark/Dependency.scala | 2
-rw-r--r--  core/src/main/scala/spark/DiskSpillingCache.scala | 75
-rw-r--r--  core/src/main/scala/spark/DoubleRDDFunctions.scala | 39
-rw-r--r--  core/src/main/scala/spark/Executor.scala | 26
-rw-r--r--  core/src/main/scala/spark/FetchFailedException.scala | 8
-rw-r--r--  core/src/main/scala/spark/JavaSerializer.scala | 39
-rw-r--r--  core/src/main/scala/spark/Job.scala | 16
-rw-r--r--  core/src/main/scala/spark/KryoSerializer.scala | 98
-rw-r--r--  core/src/main/scala/spark/Logging.scala | 7
-rw-r--r--  core/src/main/scala/spark/MapOutputTracker.scala | 108
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala | 56
-rw-r--r--  core/src/main/scala/spark/ParallelShuffleFetcher.scala | 119
-rw-r--r--  core/src/main/scala/spark/Partitioner.scala | 1
-rw-r--r--  core/src/main/scala/spark/PipedRDD.scala | 1
-rw-r--r--  core/src/main/scala/spark/RDD.scala | 104
-rw-r--r--  core/src/main/scala/spark/Scheduler.scala | 27
-rw-r--r--  core/src/main/scala/spark/SequenceFileRDDFunctions.scala | 2
-rw-r--r--  core/src/main/scala/spark/Serializer.scala | 86
-rw-r--r--  core/src/main/scala/spark/SerializingCache.scala | 26
-rw-r--r--  core/src/main/scala/spark/ShuffleMapTask.scala | 56
-rw-r--r--  core/src/main/scala/spark/ShuffledRDD.scala | 2
-rw-r--r--  core/src/main/scala/spark/SimpleShuffleFetcher.scala | 46
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala | 77
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala | 79
-rw-r--r--  core/src/main/scala/spark/Stage.scala | 41
-rw-r--r--  core/src/main/scala/spark/Task.scala | 9
-rw-r--r--  core/src/main/scala/spark/TaskContext.scala | 3
-rw-r--r--  core/src/main/scala/spark/TaskEndReason.scala | 16
-rw-r--r--  core/src/main/scala/spark/TaskResult.scala | 8
-rw-r--r--  core/src/main/scala/spark/UnionRDD.scala | 3
-rw-r--r--  core/src/main/scala/spark/Utils.scala | 35
-rw-r--r--  core/src/main/scala/spark/network/Connection.scala | 364
-rw-r--r--  core/src/main/scala/spark/network/ConnectionManager.scala | 467
-rw-r--r--  core/src/main/scala/spark/network/ConnectionManagerTest.scala | 74
-rw-r--r--  core/src/main/scala/spark/network/Message.scala | 219
-rw-r--r--  core/src/main/scala/spark/network/ReceiverTest.scala | 20
-rw-r--r--  core/src/main/scala/spark/network/SenderTest.scala | 53
-rw-r--r--  core/src/main/scala/spark/partial/ApproximateActionListener.scala | 66
-rw-r--r--  core/src/main/scala/spark/partial/ApproximateEvaluator.scala | 10
-rw-r--r--  core/src/main/scala/spark/partial/BoundedDouble.scala | 8
-rw-r--r--  core/src/main/scala/spark/partial/CountEvaluator.scala | 38
-rw-r--r--  core/src/main/scala/spark/partial/GroupedCountEvaluator.scala | 62
-rw-r--r--  core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala | 65
-rw-r--r--  core/src/main/scala/spark/partial/GroupedSumEvaluator.scala | 72
-rw-r--r--  core/src/main/scala/spark/partial/MeanEvaluator.scala | 41
-rw-r--r--  core/src/main/scala/spark/partial/PartialResult.scala | 86
-rw-r--r--  core/src/main/scala/spark/partial/StudentTCacher.scala | 26
-rw-r--r--  core/src/main/scala/spark/partial/SumEvaluator.scala | 51
-rw-r--r--  core/src/main/scala/spark/scheduler/ActiveJob.scala | 18
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGScheduler.scala | 532
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala | 30
-rw-r--r--  core/src/main/scala/spark/scheduler/JobListener.scala | 11
-rw-r--r--  core/src/main/scala/spark/scheduler/JobResult.scala | 9
-rw-r--r--  core/src/main/scala/spark/scheduler/JobWaiter.scala | 43
-rw-r--r--  core/src/main/scala/spark/scheduler/ResultTask.scala (renamed from core/src/main/scala/spark/ResultTask.scala) | 15
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 135
-rw-r--r--  core/src/main/scala/spark/scheduler/Stage.scala | 86
-rw-r--r--  core/src/main/scala/spark/scheduler/Task.scala | 11
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskResult.scala | 34
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskScheduler.scala | 27
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala | 16
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskSet.scala | 9
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalScheduler.scala (renamed from core/src/main/scala/spark/LocalScheduler.scala) | 43
-rw-r--r--  core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala | 364
-rw-r--r--  core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala (renamed from core/src/main/scala/spark/MesosScheduler.scala) | 271
-rw-r--r--  core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala | 32
-rw-r--r--  core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala (renamed from core/src/main/scala/spark/SimpleJob.scala) | 259
-rw-r--r--  core/src/main/scala/spark/storage/BlockManager.scala | 507
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerMaster.scala | 516
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerWorker.scala | 142
-rw-r--r--  core/src/main/scala/spark/storage/BlockMessage.scala | 219
-rw-r--r--  core/src/main/scala/spark/storage/BlockMessageArray.scala | 140
-rw-r--r--  core/src/main/scala/spark/storage/BlockStore.scala | 282
-rw-r--r--  core/src/main/scala/spark/storage/StorageLevel.scala | 78
-rw-r--r--  core/src/main/scala/spark/util/ByteBufferInputStream.scala | 30
-rw-r--r--  core/src/main/scala/spark/util/StatCounter.scala | 89
-rw-r--r--  core/src/test/scala/spark/CacheTrackerSuite.scala | 86
-rw-r--r--  core/src/test/scala/spark/MesosSchedulerSuite.scala | 2
-rw-r--r--  core/src/test/scala/spark/UtilsSuite.scala | 2
-rw-r--r--  project/SparkBuild.scala | 10
86 files changed, 6374 insertions, 1409 deletions
diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
index 7084ff97d9..4c18cb9134 100644
--- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
+++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
@@ -11,6 +11,7 @@ import scala.xml.{XML,NodeSeq}
import scala.collection.mutable.ArrayBuffer
import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
+import java.nio.ByteBuffer
object WikipediaPageRankStandalone {
def main(args: Array[String]) {
@@ -118,23 +119,23 @@ class WPRSerializer extends spark.Serializer {
}
class WPRSerializerInstance extends SerializerInstance {
- def serialize[T](t: T): Array[Byte] = {
+ def serialize[T](t: T): ByteBuffer = {
throw new UnsupportedOperationException()
}
- def deserialize[T](bytes: Array[Byte]): T = {
+ def deserialize[T](bytes: ByteBuffer): T = {
throw new UnsupportedOperationException()
}
- def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+ def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
throw new UnsupportedOperationException()
}
- def outputStream(s: OutputStream): SerializationStream = {
+ def serializeStream(s: OutputStream): SerializationStream = {
new WPRSerializationStream(s)
}
- def inputStream(s: InputStream): DeserializationStream = {
+ def deserializeStream(s: InputStream): DeserializationStream = {
new WPRDeserializationStream(s)
}
}
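The hunk above adapts WPRSerializer to a reworked SerializerInstance interface: the byte-array methods now take and return java.nio.ByteBuffer, and outputStream/inputStream become serializeStream/deserializeStream. The following is a sketch of the interface implied by these call sites (signatures inferred from this hunk and the JavaSerializer.scala and KryoSerializer.scala hunks below; the actual definitions live in the new core/src/main/scala/spark/Serializer.scala listed in the diffstat and include additional members such as serializeMany/deserializeMany, so treat this as orientation only):

    // Inferred sketch of the post-change serializer surface, not a verbatim copy of Serializer.scala.
    import java.io.{InputStream, OutputStream}
    import java.nio.ByteBuffer

    trait SerializationStream { def writeObject[T](t: T): Unit; def flush(): Unit; def close(): Unit }
    trait DeserializationStream { def readObject[T](): T; def close(): Unit }

    trait SerializerInstance {
      def serialize[T](t: T): ByteBuffer
      def deserialize[T](bytes: ByteBuffer): T
      def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T
      def serializeStream(s: OutputStream): SerializationStream
      def deserializeStream(s: InputStream): DeserializationStream
    }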
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
new file mode 100644
index 0000000000..e00a0d80fa
--- /dev/null
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -0,0 +1,70 @@
+package spark
+
+import java.io.EOFException
+import java.net.URL
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import spark.storage.BlockException
+import spark.storage.BlockManagerId
+
+import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+
+
+class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
+ def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
+ logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
+ val ser = SparkEnv.get.serializer.newInstance()
+ val blockManager = SparkEnv.get.blockManager
+
+ val startTime = System.currentTimeMillis
+ val addresses = SparkEnv.get.mapOutputTracker.getServerAddresses(shuffleId)
+ logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
+ shuffleId, reduceId, System.currentTimeMillis - startTime))
+
+ val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[Int]]
+ for ((address, index) <- addresses.zipWithIndex) {
+ splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += index
+ }
+
+ val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map {
+ case (address, splits) =>
+ (address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId)))
+ }
+
+ try {
+ val blockOptions = blockManager.get(blocksByAddress)
+ logDebug("Fetching map output blocks for shuffle %d, reduce %d took %d ms".format(
+ shuffleId, reduceId, System.currentTimeMillis - startTime))
+ blockOptions.foreach(x => {
+ val (blockId, blockOption) = x
+ blockOption match {
+ case Some(block) => {
+ val values = block.asInstanceOf[Iterator[Any]]
+ for(value <- values) {
+ val v = value.asInstanceOf[(K, V)]
+ func(v._1, v._2)
+ }
+ }
+ case None => {
+ throw new BlockException(blockId, "Did not get block " + blockId)
+ }
+ }
+ })
+ } catch {
+ case be: BlockException => {
+ val regex = "shuffledid_([0-9]*)_([0-9]*)_([0-9]]*)".r
+ be.blockId match {
+ case regex(sId, mId, rId) => {
+ val address = addresses(mId.toInt)
+ throw new FetchFailedException(address, sId.toInt, mId.toInt, rId.toInt, be)
+ }
+ case _ => {
+ throw be
+ }
+ }
+ }
+ }
+ }
+}
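BlockStoreShuffleFetcher (added above) resolves map output locations through the MapOutputTracker, pulls the corresponding shuffle blocks via the new BlockManager, and hands each (K, V) pair to a caller-supplied callback. A hedged usage sketch, assuming an initialized SparkEnv on the calling thread; the aggregation below is illustrative, and the real caller is the shuffle-reading RDD code rather than user code:

    // Illustrative only: accumulate fetched pairs for reduce partition 0 of shuffle 0.
    import scala.collection.mutable.HashMap

    val combiners = new HashMap[String, Int]
    val fetcher = new BlockStoreShuffleFetcher
    fetcher.fetch[String, Int](0, 0, (k, v) => {
      combiners(k) = combiners.getOrElse(k, 0) + v   // merge each pair as it streams in
    })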
diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala
index 1162e34ab0..fa5dcee7bb 100644
--- a/core/src/main/scala/spark/BoundedMemoryCache.scala
+++ b/core/src/main/scala/spark/BoundedMemoryCache.scala
@@ -90,7 +90,8 @@ class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
- SparkEnv.get.cacheTracker.dropEntry(datasetId, partition)
+ // TODO: remove BoundedMemoryCache
+ SparkEnv.get.cacheTracker.dropEntry(datasetId.asInstanceOf[(Int, Int)]._2, partition)
}
}
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index 4867829c17..64b4af0ae2 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -1,11 +1,17 @@
package spark
-import scala.actors._
-import scala.actors.Actor._
-import scala.actors.remote._
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
+import akka.util.duration._
+
+import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
+import spark.storage.BlockManager
+import spark.storage.StorageLevel
+
sealed trait CacheTrackerMessage
case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
@@ -18,8 +24,8 @@ case object GetCacheStatus extends CacheTrackerMessage
case object GetCacheLocations extends CacheTrackerMessage
case object StopCacheTracker extends CacheTrackerMessage
-
-class CacheTrackerActor extends DaemonActor with Logging {
+class CacheTrackerActor extends Actor with Logging {
+ // TODO: Should probably store (String, CacheType) tuples
private val locs = new HashMap[Int, Array[List[String]]]
/**
@@ -28,109 +34,93 @@ class CacheTrackerActor extends DaemonActor with Logging {
private val slaveCapacity = new HashMap[String, Long]
private val slaveUsage = new HashMap[String, Long]
- // TODO: Should probably store (String, CacheType) tuples
-
private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
- def act() {
- val port = System.getProperty("spark.master.port").toInt
- RemoteActor.alive(port)
- RemoteActor.register('CacheTracker, self)
- logInfo("Registered actor on port " + port)
-
- loop {
- react {
- case SlaveCacheStarted(host: String, size: Long) =>
- logInfo("Started slave cache (size %s) on %s".format(
- Utils.memoryBytesToString(size), host))
- slaveCapacity.put(host, size)
- slaveUsage.put(host, 0)
- reply('OK)
-
- case RegisterRDD(rddId: Int, numPartitions: Int) =>
- logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
- locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
- reply('OK)
-
- case AddedToCache(rddId, partition, host, size) =>
- if (size > 0) {
- slaveUsage.put(host, getCacheUsage(host) + size)
- logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
- rddId, partition, host, Utils.memoryBytesToString(size),
- Utils.memoryBytesToString(getCacheAvailable(host))))
- } else {
- logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
- }
- locs(rddId)(partition) = host :: locs(rddId)(partition)
- reply('OK)
-
- case DroppedFromCache(rddId, partition, host, size) =>
- if (size > 0) {
- logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
- rddId, partition, host, Utils.memoryBytesToString(size),
- Utils.memoryBytesToString(getCacheAvailable(host))))
- slaveUsage.put(host, getCacheUsage(host) - size)
-
- // Do a sanity check to make sure usage is greater than 0.
- val usage = getCacheUsage(host)
- if (usage < 0) {
- logError("Cache usage on %s is negative (%d)".format(host, usage))
- }
- } else {
- logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
- }
- locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
- reply('OK)
+ def receive = {
+ case SlaveCacheStarted(host: String, size: Long) =>
+ logInfo("Started slave cache (size %s) on %s".format(
+ Utils.memoryBytesToString(size), host))
+ slaveCapacity.put(host, size)
+ slaveUsage.put(host, 0)
+ self.reply(true)
- case MemoryCacheLost(host) =>
- logInfo("Memory cache lost on " + host)
- // TODO: Drop host from the memory locations list of all RDDs
-
- case GetCacheLocations =>
- logInfo("Asked for current cache locations")
- reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())})
-
- case GetCacheStatus =>
- val status = slaveCapacity.map { case (host,capacity) =>
- (host, capacity, getCacheUsage(host))
- }.toSeq
- reply(status)
-
- case StopCacheTracker =>
- reply('OK)
- exit()
+ case RegisterRDD(rddId: Int, numPartitions: Int) =>
+ logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
+ locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
+ self.reply(true)
+
+ case AddedToCache(rddId, partition, host, size) =>
+ slaveUsage.put(host, getCacheUsage(host) + size)
+ logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
+ rddId, partition, host, Utils.memoryBytesToString(size),
+ Utils.memoryBytesToString(getCacheAvailable(host))))
+ locs(rddId)(partition) = host :: locs(rddId)(partition)
+ self.reply(true)
+
+ case DroppedFromCache(rddId, partition, host, size) =>
+ logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
+ rddId, partition, host, Utils.memoryBytesToString(size),
+ Utils.memoryBytesToString(getCacheAvailable(host))))
+ slaveUsage.put(host, getCacheUsage(host) - size)
+ // Do a sanity check to make sure usage is greater than 0.
+ val usage = getCacheUsage(host)
+ if (usage < 0) {
+ logError("Cache usage on %s is negative (%d)".format(host, usage))
}
- }
- }
-}
+ locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
+ self.reply(true)
+ case MemoryCacheLost(host) =>
+ logInfo("Memory cache lost on " + host)
+ for ((id, locations) <- locs) {
+ for (i <- 0 until locations.length) {
+ locations(i) = locations(i).filterNot(_ == host)
+ }
+ }
+ self.reply(true)
-class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
- // Tracker actor on the master, or remote reference to it on workers
- var trackerActor: AbstractActor = null
+ case GetCacheLocations =>
+ logInfo("Asked for current cache locations")
+ self.reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())})
- val registeredRddIds = new HashSet[Int]
+ case GetCacheStatus =>
+ val status = slaveCapacity.map { case (host, capacity) =>
+ (host, capacity, getCacheUsage(host))
+ }.toSeq
+ self.reply(status)
- // Stores map results for various splits locally
- val cache = theCache.newKeySpace()
+ case StopCacheTracker =>
+ logInfo("CacheTrackerActor Server stopped!")
+ self.reply(true)
+ self.exit()
+ }
+}
+class CacheTracker(isMaster: Boolean, blockManager: BlockManager) extends Logging {
+ // Tracker actor on the master, or remote reference to it on workers
+ val ip: String = System.getProperty("spark.master.host", "localhost")
+ val port: Int = System.getProperty("spark.master.port", "7077").toInt
+ val aName: String = "CacheTracker"
+
if (isMaster) {
- val tracker = new CacheTrackerActor
- tracker.start()
- trackerActor = tracker
+ }
+
+ var trackerActor: ActorRef = if (isMaster) {
+ val actor = actorOf(new CacheTrackerActor)
+ remote.register(aName, actor)
+ actor.start()
+ logInfo("Registered CacheTrackerActor actor @ " + ip + ":" + port)
+ actor
} else {
- val host = System.getProperty("spark.master.host")
- val port = System.getProperty("spark.master.port").toInt
- trackerActor = RemoteActor.select(Node(host, port), 'CacheTracker)
+ remote.actorFor(aName, ip, port)
}
- // Report the cache being started.
- trackerActor !? SlaveCacheStarted(Utils.getHost, cache.getCapacity)
+ val registeredRddIds = new HashSet[Int]
// Remembers which splits are currently being loaded (on worker nodes)
- val loading = new HashSet[(Int, Int)]
+ val loading = new HashSet[String]
// Registers an RDD (on master only)
def registerRDD(rddId: Int, numPartitions: Int) {
@@ -138,24 +128,33 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
if (!registeredRddIds.contains(rddId)) {
logInfo("Registering RDD ID " + rddId + " with cache")
registeredRddIds += rddId
- trackerActor !? RegisterRDD(rddId, numPartitions)
+ (trackerActor ? RegisterRDD(rddId, numPartitions)).as[Any] match {
+ case Some(true) =>
+ logInfo("CacheTracker registerRDD " + RegisterRDD(rddId, numPartitions) + " successfully.")
+ case Some(oops) =>
+ logError("CacheTracker registerRDD" + RegisterRDD(rddId, numPartitions) + " failed: " + oops)
+ case None =>
+ logError("CacheTracker registerRDD None. " + RegisterRDD(rddId, numPartitions))
+ throw new SparkException("Internal error: CacheTracker registerRDD None.")
+ }
}
}
}
-
- // Get a snapshot of the currently known locations
- def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
- (trackerActor !? GetCacheLocations) match {
- case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]]
-
- case _ => throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap")
+
+ // For BlockManager.scala only
+ def cacheLost(host: String) {
+ (trackerActor ? MemoryCacheLost(host)).as[Any] match {
+ case Some(true) =>
+ logInfo("CacheTracker successfully removed entries on " + host)
+ case _ =>
+ logError("CacheTracker did not reply to MemoryCacheLost")
}
}
// Get the usage status of slave caches. Each tuple in the returned sequence
// is in the form of (host name, capacity, usage).
def getCacheStatus(): Seq[(String, Long, Long)] = {
- (trackerActor !? GetCacheStatus) match {
+ (trackerActor ? GetCacheStatus) match {
case h: Seq[(String, Long, Long)] => h.asInstanceOf[Seq[(String, Long, Long)]]
case _ =>
@@ -164,75 +163,94 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
}
}
+ // For BlockManager.scala only
+ def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) {
+ (trackerActor ? t).as[Any] match {
+ case Some(true) =>
+ logInfo("CacheTracker notifyTheCacheTrackerFromBlockManager successfully.")
+ case Some(oops) =>
+ logError("CacheTracker notifyTheCacheTrackerFromBlockManager failed: " + oops)
+ case None =>
+ logError("CacheTracker notifyTheCacheTrackerFromBlockManager None.")
+ }
+ }
+
+ // Get a snapshot of the currently known locations
+ def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
+ (trackerActor ? GetCacheLocations).as[Any] match {
+ case Some(h: HashMap[_, _]) =>
+ h.asInstanceOf[HashMap[Int, Array[List[String]]]]
+
+ case _ =>
+ throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap")
+ }
+ }
+
// Gets or computes an RDD split
- def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T]): Iterator[T] = {
- logInfo("Looking for RDD partition %d:%d".format(rdd.id, split.index))
- val cachedVal = cache.get(rdd.id, split.index)
- if (cachedVal != null) {
- // Split is in cache, so just return its values
- logInfo("Found partition in cache!")
- return cachedVal.asInstanceOf[Array[T]].iterator
- } else {
- // Mark the split as loading (unless someone else marks it first)
- val key = (rdd.id, split.index)
- loading.synchronized {
- while (loading.contains(key)) {
- // Someone else is loading it; let's wait for them
- try { loading.wait() } catch { case _ => }
- }
- // See whether someone else has successfully loaded it. The main way this would fail
- // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
- // partition but we didn't want to make space for it. However, that case is unlikely
- // because it's unlikely that two threads would work on the same RDD partition. One
- // downside of the current code is that threads wait serially if this does happen.
- val cachedVal = cache.get(rdd.id, split.index)
- if (cachedVal != null) {
- return cachedVal.asInstanceOf[Array[T]].iterator
- }
- // Nobody's loading it and it's not in the cache; let's load it ourselves
- loading.add(key)
- }
- // If we got here, we have to load the split
- // Tell the master that we're doing so
-
- // TODO: fetch any remote copy of the split that may be available
- logInfo("Computing partition " + split)
- var array: Array[T] = null
- var putResponse: CachePutResponse = null
- try {
- array = rdd.compute(split).toArray(m)
- putResponse = cache.put(rdd.id, split.index, array)
- } finally {
- // Tell other threads that we've finished our attempt to load the key (whether or not
- // we've actually succeeded to put it in the map)
+ def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = {
+ val key = "rdd:%d:%d".format(rdd.id, split.index)
+ logInfo("Cache key is " + key)
+ blockManager.get(key) match {
+ case Some(cachedValues) =>
+ // Split is in cache, so just return its values
+ logInfo("Found partition in cache!")
+ return cachedValues.asInstanceOf[Iterator[T]]
+
+ case None =>
+ // Mark the split as loading (unless someone else marks it first)
loading.synchronized {
- loading.remove(key)
- loading.notifyAll()
+ if (loading.contains(key)) {
+ logInfo("Loading contains " + key + ", waiting...")
+ while (loading.contains(key)) {
+ try {loading.wait()} catch {case _ =>}
+ }
+ logInfo("Loading no longer contains " + key + ", so returning cached result")
+ // See whether someone else has successfully loaded it. The main way this would fail
+ // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
+ // partition but we didn't want to make space for it. However, that case is unlikely
+ // because it's unlikely that two threads would work on the same RDD partition. One
+ // downside of the current code is that threads wait serially if this does happen.
+ blockManager.get(key) match {
+ case Some(values) =>
+ return values.asInstanceOf[Iterator[T]]
+ case None =>
+ logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
+ loading.add(key)
+ }
+ } else {
+ loading.add(key)
+ }
}
- }
-
- putResponse match {
- case CachePutSuccess(size) => {
- // Tell the master that we added the entry. Don't return until it
- // replies so it can properly schedule future tasks that use this RDD.
- trackerActor !? AddedToCache(rdd.id, split.index, Utils.getHost, size)
+ // If we got here, we have to load the split
+ // Tell the master that we're doing so
+ //val host = System.getProperty("spark.hostname", Utils.localHostName)
+ //val future = trackerActor !! AddedToCache(rdd.id, split.index, host)
+ // TODO: fetch any remote copy of the split that may be available
+ // TODO: also register a listener for when it unloads
+ logInfo("Computing partition " + split)
+ try {
+ val values = new ArrayBuffer[Any]
+ values ++= rdd.compute(split)
+ blockManager.put(key, values.iterator, storageLevel, false)
+ //future.apply() // Wait for the reply from the cache tracker
+ return values.iterator.asInstanceOf[Iterator[T]]
+ } finally {
+ loading.synchronized {
+ loading.remove(key)
+ loading.notifyAll()
+ }
}
- case _ => null
- }
- return array.iterator
}
}
// Called by the Cache to report that an entry has been dropped from it
- def dropEntry(datasetId: Any, partition: Int) {
- datasetId match {
- //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here.
- case (cache.keySpaceId, rddId: Int) => trackerActor !! DroppedFromCache(rddId, partition, Utils.getHost)
- }
+ def dropEntry(rddId: Int, partition: Int) {
+ //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here.
+ trackerActor !! DroppedFromCache(rddId, partition, Utils.localHostName())
}
def stop() {
- trackerActor !? StopCacheTracker
+ trackerActor !! StopCacheTracker
registeredRddIds.clear()
trackerActor = null
}
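The CacheTracker hunk above replaces scala.actors with Akka's actor API as used in this era of the codebase: the master registers a named remote actor, workers look it up with remote.actorFor, and requests use the ask operator with a blocking .as[Any] on the reply. Condensed from the code above (no API beyond what the hunk itself uses), the request/reply pattern looks like this:

    // Ask pattern as used throughout the new CacheTracker; the actor answers with self.reply(true).
    (trackerActor ? RegisterRDD(rddId, numPartitions)).as[Any] match {
      case Some(true) => logInfo("RDD " + rddId + " registered with the cache tracker")
      case other      => throw new SparkException("Unexpected CacheTracker reply: " + other)
    }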
diff --git a/core/src/main/scala/spark/CoGroupedRDD.scala b/core/src/main/scala/spark/CoGroupedRDD.scala
index 93f453bc5e..3543c8afa8 100644
--- a/core/src/main/scala/spark/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/CoGroupedRDD.scala
@@ -22,11 +22,12 @@ class CoGroupAggregator
{ (b1, b2) => b1 ++ b2 })
with Serializable
-class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner)
+class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
val aggr = new CoGroupAggregator
+ @transient
override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
for ((rdd, index) <- rdds.zipWithIndex) {
@@ -67,9 +68,10 @@ class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner)
override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit]
+ val numRdds = split.deps.size
val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
- map.getOrElseUpdate(k, Array.fill(rdds.size)(new ArrayBuffer[Any]))
+ map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any]))
}
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, itsSplit) => {
diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
deleted file mode 100644
index 1b4af9d84c..0000000000
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ /dev/null
@@ -1,374 +0,0 @@
-package spark
-
-import java.util.concurrent.atomic.AtomicInteger
-
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
-
-/**
- * A task created by the DAG scheduler. Knows its stage ID and map ouput tracker generation.
- */
-abstract class DAGTask[T](val runId: Int, val stageId: Int) extends Task[T] {
- val gen = SparkEnv.get.mapOutputTracker.getGeneration
- override def generation: Option[Long] = Some(gen)
-}
-
-/**
- * A completion event passed by the underlying task scheduler to the DAG scheduler.
- */
-case class CompletionEvent(
- task: DAGTask[_],
- reason: TaskEndReason,
- result: Any,
- accumUpdates: Map[Long, Any])
-
-/**
- * Various possible reasons why a DAG task ended. The underlying scheduler is supposed to retry
- * tasks several times for "ephemeral" failures, and only report back failures that require some
- * old stages to be resubmitted, such as shuffle map fetch failures.
- */
-sealed trait TaskEndReason
-case object Success extends TaskEndReason
-case class FetchFailed(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
-case class ExceptionFailure(exception: Throwable) extends TaskEndReason
-case class OtherFailure(message: String) extends TaskEndReason
-
-/**
- * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
- * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal
- * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
- * and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
- */
-private trait DAGScheduler extends Scheduler with Logging {
- // Must be implemented by subclasses to start running a set of tasks. The subclass should also
- // attempt to run different sets of tasks in the order given by runId (lower values first).
- def submitTasks(tasks: Seq[Task[_]], runId: Int): Unit
-
- // Must be called by subclasses to report task completions or failures.
- def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]) {
- lock.synchronized {
- val dagTask = task.asInstanceOf[DAGTask[_]]
- eventQueues.get(dagTask.runId) match {
- case Some(queue) =>
- queue += CompletionEvent(dagTask, reason, result, accumUpdates)
- lock.notifyAll()
- case None =>
- logInfo("Ignoring completion event for DAG job " + dagTask.runId + " because it's gone")
- }
- }
- }
-
- // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
- // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
- // as more failure events come in
- val RESUBMIT_TIMEOUT = 2000L
-
- // The time, in millis, to wake up between polls of the completion queue in order to potentially
- // resubmit failed stages
- val POLL_TIMEOUT = 500L
-
- private val lock = new Object // Used for access to the entire DAGScheduler
-
- private val eventQueues = new HashMap[Int, Queue[CompletionEvent]] // Indexed by run ID
-
- val nextRunId = new AtomicInteger(0)
-
- val nextStageId = new AtomicInteger(0)
-
- val idToStage = new HashMap[Int, Stage]
-
- val shuffleToMapStage = new HashMap[Int, Stage]
-
- var cacheLocs = new HashMap[Int, Array[List[String]]]
-
- val env = SparkEnv.get
- val cacheTracker = env.cacheTracker
- val mapOutputTracker = env.mapOutputTracker
-
- def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
- cacheLocs(rdd.id)
- }
-
- def updateCacheLocs() {
- cacheLocs = cacheTracker.getLocationsSnapshot()
- }
-
- def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {
- shuffleToMapStage.get(shuf.shuffleId) match {
- case Some(stage) => stage
- case None =>
- val stage = newStage(shuf.rdd, Some(shuf))
- shuffleToMapStage(shuf.shuffleId) = stage
- stage
- }
- }
-
- def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
- // Kind of ugly: need to register RDDs with the cache and map output tracker here
- // since we can't do it in the RDD constructor because # of splits is unknown
- cacheTracker.registerRDD(rdd.id, rdd.splits.size)
- if (shuffleDep != None) {
- mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
- }
- val id = nextStageId.getAndIncrement()
- val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
- idToStage(id) = stage
- stage
- }
-
- def getParentStages(rdd: RDD[_]): List[Stage] = {
- val parents = new HashSet[Stage]
- val visited = new HashSet[RDD[_]]
- def visit(r: RDD[_]) {
- if (!visited(r)) {
- visited += r
- // Kind of ugly: need to register RDDs with the cache here since
- // we can't do it in its constructor because # of splits is unknown
- cacheTracker.registerRDD(r.id, r.splits.size)
- for (dep <- r.dependencies) {
- dep match {
- case shufDep: ShuffleDependency[_,_,_] =>
- parents += getShuffleMapStage(shufDep)
- case _ =>
- visit(dep.rdd)
- }
- }
- }
- }
- visit(rdd)
- parents.toList
- }
-
- def getMissingParentStages(stage: Stage): List[Stage] = {
- val missing = new HashSet[Stage]
- val visited = new HashSet[RDD[_]]
- def visit(rdd: RDD[_]) {
- if (!visited(rdd)) {
- visited += rdd
- val locs = getCacheLocs(rdd)
- for (p <- 0 until rdd.splits.size) {
- if (locs(p) == Nil) {
- for (dep <- rdd.dependencies) {
- dep match {
- case shufDep: ShuffleDependency[_,_,_] =>
- val stage = getShuffleMapStage(shufDep)
- if (!stage.isAvailable) {
- missing += stage
- }
- case narrowDep: NarrowDependency[_] =>
- visit(narrowDep.rdd)
- }
- }
- }
- }
- }
- }
- visit(stage.rdd)
- missing.toList
- }
-
- override def runJob[T, U](
- finalRdd: RDD[T],
- func: (TaskContext, Iterator[T]) => U,
- partitions: Seq[Int],
- allowLocal: Boolean)
- (implicit m: ClassManifest[U]): Array[U] = {
- lock.synchronized {
- val runId = nextRunId.getAndIncrement()
-
- val outputParts = partitions.toArray
- val numOutputParts: Int = partitions.size
- val finalStage = newStage(finalRdd, None)
- val results = new Array[U](numOutputParts)
- val finished = new Array[Boolean](numOutputParts)
- var numFinished = 0
-
- val waiting = new HashSet[Stage] // stages we need to run whose parents aren't done
- val running = new HashSet[Stage] // stages we are running right now
- val failed = new HashSet[Stage] // stages that must be resubmitted due to fetch failures
- val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // missing tasks from each stage
- var lastFetchFailureTime: Long = 0 // used to wait a bit to avoid repeated resubmits
-
- SparkEnv.set(env)
-
- updateCacheLocs()
-
- logInfo("Final stage: " + finalStage)
- logInfo("Parents of final stage: " + finalStage.parents)
- logInfo("Missing parents: " + getMissingParentStages(finalStage))
-
- // Optimization for short actions like first() and take() that can be computed locally
- // without shipping tasks to the cluster.
- if (allowLocal && finalStage.parents.size == 0 && numOutputParts == 1) {
- logInfo("Computing the requested partition locally")
- val split = finalRdd.splits(outputParts(0))
- val taskContext = new TaskContext(finalStage.id, outputParts(0), 0)
- return Array(func(taskContext, finalRdd.iterator(split)))
- }
-
- // Register the job ID so that we can get completion events for it
- eventQueues(runId) = new Queue[CompletionEvent]
-
- def submitStage(stage: Stage) {
- if (!waiting(stage) && !running(stage)) {
- val missing = getMissingParentStages(stage)
- if (missing == Nil) {
- logInfo("Submitting " + stage + ", which has no missing parents")
- submitMissingTasks(stage)
- running += stage
- } else {
- for (parent <- missing) {
- submitStage(parent)
- }
- waiting += stage
- }
- }
- }
-
- def submitMissingTasks(stage: Stage) {
- // Get our pending tasks and remember them in our pendingTasks entry
- val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
- var tasks = ArrayBuffer[Task[_]]()
- if (stage == finalStage) {
- for (id <- 0 until numOutputParts if (!finished(id))) {
- val part = outputParts(id)
- val locs = getPreferredLocs(finalRdd, part)
- tasks += new ResultTask(runId, finalStage.id, finalRdd, func, part, locs, id)
- }
- } else {
- for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
- val locs = getPreferredLocs(stage.rdd, p)
- tasks += new ShuffleMapTask(runId, stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
- }
- }
- myPending ++= tasks
- submitTasks(tasks, runId)
- }
-
- submitStage(finalStage)
-
- while (numFinished != numOutputParts) {
- val eventOption = waitForEvent(runId, POLL_TIMEOUT)
- val time = System.currentTimeMillis // TODO: use a pluggable clock for testability
-
- // If we got an event off the queue, mark the task done or react to a fetch failure
- if (eventOption != None) {
- val evt = eventOption.get
- val stage = idToStage(evt.task.stageId)
- pendingTasks(stage) -= evt.task
- if (evt.reason == Success) {
- // A task ended
- logInfo("Completed " + evt.task)
- Accumulators.add(evt.accumUpdates)
- evt.task match {
- case rt: ResultTask[_, _] =>
- results(rt.outputId) = evt.result.asInstanceOf[U]
- finished(rt.outputId) = true
- numFinished += 1
- case smt: ShuffleMapTask =>
- val stage = idToStage(smt.stageId)
- stage.addOutputLoc(smt.partition, evt.result.asInstanceOf[String])
- if (running.contains(stage) && pendingTasks(stage).isEmpty) {
- logInfo(stage + " finished; looking for newly runnable stages")
- running -= stage
- if (stage.shuffleDep != None) {
- mapOutputTracker.registerMapOutputs(
- stage.shuffleDep.get.shuffleId,
- stage.outputLocs.map(_.head).toArray)
- }
- updateCacheLocs()
- val newlyRunnable = new ArrayBuffer[Stage]
- for (stage <- waiting if getMissingParentStages(stage) == Nil) {
- newlyRunnable += stage
- }
- waiting --= newlyRunnable
- running ++= newlyRunnable
- for (stage <- newlyRunnable) {
- submitMissingTasks(stage)
- }
- }
- }
- } else {
- evt.reason match {
- case FetchFailed(serverUri, shuffleId, mapId, reduceId) =>
- // Mark the stage that the reducer was in as unrunnable
- val failedStage = idToStage(evt.task.stageId)
- running -= failedStage
- failed += failedStage
- // TODO: Cancel running tasks in the stage
- logInfo("Marking " + failedStage + " for resubmision due to a fetch failure")
- // Mark the map whose fetch failed as broken in the map stage
- val mapStage = shuffleToMapStage(shuffleId)
- mapStage.removeOutputLoc(mapId, serverUri)
- mapOutputTracker.unregisterMapOutput(shuffleId, mapId, serverUri)
- logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
- failed += mapStage
- // Remember that a fetch failed now; this is used to resubmit the broken
- // stages later, after a small wait (to give other tasks the chance to fail)
- lastFetchFailureTime = time
- // TODO: If there are a lot of fetch failures on the same node, maybe mark all
- // outputs on the node as dead.
- case _ =>
- // Non-fetch failure -- probably a bug in the job, so bail out
- throw new SparkException("Task failed: " + evt.task + ", reason: " + evt.reason)
- // TODO: Cancel all tasks that are still running
- }
- }
- } // end if (evt != null)
-
- // If fetches have failed recently and we've waited for the right timeout,
- // resubmit all the failed stages
- if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
- logInfo("Resubmitting failed stages")
- updateCacheLocs()
- for (stage <- failed) {
- submitStage(stage)
- }
- failed.clear()
- }
- }
-
- eventQueues -= runId
- return results
- }
- }
-
- def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
- // If the partition is cached, return the cache locations
- val cached = getCacheLocs(rdd)(partition)
- if (cached != Nil) {
- return cached
- }
- // If the RDD has some placement preferences (as is the case for input RDDs), get those
- val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList
- if (rddPrefs != Nil) {
- return rddPrefs
- }
- // If the RDD has narrow dependencies, pick the first partition of the first narrow dep
- // that has any placement preferences. Ideally we would choose based on transfer sizes,
- // but this will do for now.
- rdd.dependencies.foreach(_ match {
- case n: NarrowDependency[_] =>
- for (inPart <- n.getParents(partition)) {
- val locs = getPreferredLocs(n.rdd, inPart)
- if (locs != Nil)
- return locs;
- }
- case _ =>
- })
- return Nil
- }
-
- // Assumes that lock is held on entrance, but will release it to wait for the next event.
- def waitForEvent(runId: Int, timeout: Long): Option[CompletionEvent] = {
- val endTime = System.currentTimeMillis() + timeout // TODO: Use pluggable clock for testing
- while (eventQueues(runId).isEmpty) {
- val time = System.currentTimeMillis()
- if (time >= endTime) {
- return None
- } else {
- lock.wait(endTime - time)
- }
- }
- return Some(eventQueues(runId).dequeue())
- }
-}
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index d93c84924a..c0ff94acc6 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -8,7 +8,7 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd, false) {
class ShuffleDependency[K, V, C](
val shuffleId: Int,
- rdd: RDD[(K, V)],
+ @transient rdd: RDD[(K, V)],
val aggregator: Aggregator[K, V, C],
val partitioner: Partitioner)
extends Dependency(rdd, true)
diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala
deleted file mode 100644
index e11466eb64..0000000000
--- a/core/src/main/scala/spark/DiskSpillingCache.scala
+++ /dev/null
@@ -1,75 +0,0 @@
-package spark
-
-import java.io.File
-import java.io.{FileOutputStream,FileInputStream}
-import java.io.IOException
-import java.util.LinkedHashMap
-import java.util.UUID
-
-// TODO: cache into a separate directory using Utils.createTempDir
-// TODO: clean up disk cache afterwards
-class DiskSpillingCache extends BoundedMemoryCache {
- private val diskMap = new LinkedHashMap[(Any, Int), File](32, 0.75f, true)
-
- override def get(datasetId: Any, partition: Int): Any = {
- synchronized {
- val ser = SparkEnv.get.serializer.newInstance()
- super.get(datasetId, partition) match {
- case bytes: Any => // found in memory
- ser.deserialize(bytes.asInstanceOf[Array[Byte]])
-
- case _ => diskMap.get((datasetId, partition)) match {
- case file: Any => // found on disk
- try {
- val startTime = System.currentTimeMillis
- val bytes = new Array[Byte](file.length.toInt)
- new FileInputStream(file).read(bytes)
- val timeTaken = System.currentTimeMillis - startTime
- logInfo("Reading key (%s, %d) of size %d bytes from disk took %d ms".format(
- datasetId, partition, file.length, timeTaken))
- super.put(datasetId, partition, bytes)
- ser.deserialize(bytes.asInstanceOf[Array[Byte]])
- } catch {
- case e: IOException =>
- logWarning("Failed to read key (%s, %d) from disk at %s: %s".format(
- datasetId, partition, file.getPath(), e.getMessage()))
- diskMap.remove((datasetId, partition)) // remove dead entry
- null
- }
-
- case _ => // not found
- null
- }
- }
- }
- }
-
- override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
- var ser = SparkEnv.get.serializer.newInstance()
- super.put(datasetId, partition, ser.serialize(value))
- }
-
- /**
- * Spill the given entry to disk. Assumes that a lock is held on the
- * DiskSpillingCache. Assumes that entry.value is a byte array.
- */
- override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
- logInfo("Spilling key (%s, %d) of size %d to make space".format(
- datasetId, partition, entry.size))
- val cacheDir = System.getProperty(
- "spark.diskSpillingCache.cacheDir",
- System.getProperty("java.io.tmpdir"))
- val file = new File(cacheDir, "spark-dsc-" + UUID.randomUUID.toString)
- try {
- val stream = new FileOutputStream(file)
- stream.write(entry.value.asInstanceOf[Array[Byte]])
- stream.close()
- diskMap.put((datasetId, partition), file)
- } catch {
- case e: IOException =>
- logWarning("Failed to spill key (%s, %d) to disk at %s: %s".format(
- datasetId, partition, file.getPath(), e.getMessage()))
- // Do nothing and let the entry be discarded
- }
- }
-}
diff --git a/core/src/main/scala/spark/DoubleRDDFunctions.scala b/core/src/main/scala/spark/DoubleRDDFunctions.scala
new file mode 100644
index 0000000000..1fbf66b7de
--- /dev/null
+++ b/core/src/main/scala/spark/DoubleRDDFunctions.scala
@@ -0,0 +1,39 @@
+package spark
+
+import spark.partial.BoundedDouble
+import spark.partial.MeanEvaluator
+import spark.partial.PartialResult
+import spark.partial.SumEvaluator
+
+import spark.util.StatCounter
+
+/**
+ * Extra functions available on RDDs of Doubles through an implicit conversion.
+ */
+class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
+ def sum(): Double = {
+ self.reduce(_ + _)
+ }
+
+ def stats(): StatCounter = {
+ self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
+ }
+
+ def mean(): Double = stats().mean
+
+ def variance(): Double = stats().variance
+
+ def stdev(): Double = stats().stdev
+
+ def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+ val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
+ val evaluator = new MeanEvaluator(self.splits.size, confidence)
+ self.context.runApproximateJob(self, processPartition, evaluator, timeout)
+ }
+
+ def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+ val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
+ val evaluator = new SumEvaluator(self.splits.size, confidence)
+ self.context.runApproximateJob(self, processPartition, evaluator, timeout)
+ }
+}
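DoubleRDDFunctions (added above) provides exact aggregates (sum, mean, variance, stdev via StatCounter) plus approximate variants that return a PartialResult[BoundedDouble] within a time budget. A hedged usage sketch: sc stands for an existing SparkContext, and the wrapper is constructed explicitly rather than assuming an implicit conversion is in scope.

    // Illustrative: approximate mean of an RDD[Double] within a 10 s budget at 95% confidence.
    val doubles = sc.parallelize(1 to 1000000).map(_.toDouble)
    val approxMean: spark.partial.PartialResult[spark.partial.BoundedDouble] =
      new spark.DoubleRDDFunctions(doubles).meanApprox(10000L, 0.95)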
diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala
index c795b6c351..af9eb9c878 100644
--- a/core/src/main/scala/spark/Executor.scala
+++ b/core/src/main/scala/spark/Executor.scala
@@ -10,9 +10,10 @@ import scala.collection.mutable.ArrayBuffer
import com.google.protobuf.ByteString
import org.apache.mesos._
-import org.apache.mesos.Protos._
+import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
import spark.broadcast._
+import spark.scheduler._
/**
* The Mesos executor for Spark.
@@ -29,6 +30,9 @@ class Executor extends org.apache.mesos.Executor with Logging {
executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) {
+ // Make sure the local hostname we report matches Mesos's name for this host
+ Utils.setCustomHostname(slaveInfo.getHostname())
+
// Read spark.* system properties from executor arg
val props = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
for ((key, value) <- props) {
@@ -39,7 +43,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
RemoteActor.classLoader = getClass.getClassLoader
// Initialize Spark environment (using system properties read above)
- env = SparkEnv.createFromSystemProperties(false)
+ env = SparkEnv.createFromSystemProperties(false, false)
SparkEnv.set(env)
// Old stuff that isn't yet using env
Broadcast.initialize(false)
@@ -57,11 +61,11 @@ class Executor extends org.apache.mesos.Executor with Logging {
override def reregistered(d: ExecutorDriver, s: SlaveInfo) {}
- override def launchTask(d: ExecutorDriver, task: TaskInfo) {
+ override def launchTask(d: ExecutorDriver, task: MTaskInfo) {
threadPool.execute(new TaskRunner(task, d))
}
- class TaskRunner(info: TaskInfo, d: ExecutorDriver)
+ class TaskRunner(info: MTaskInfo, d: ExecutorDriver)
extends Runnable {
override def run() = {
val tid = info.getTaskId.getValue
@@ -74,11 +78,11 @@ class Executor extends org.apache.mesos.Executor with Logging {
.setState(TaskState.TASK_RUNNING)
.build())
try {
+ SparkEnv.set(env)
+ Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear
- val task = ser.deserialize[Task[Any]](info.getData.toByteArray, classLoader)
- for (gen <- task.generation) {// Update generation if any is set
- env.mapOutputTracker.updateGeneration(gen)
- }
+ val task = ser.deserialize[Task[Any]](info.getData.asReadOnlyByteBuffer, classLoader)
+ env.mapOutputTracker.updateGeneration(task.generation)
val value = task.run(tid.toInt)
val accumUpdates = Accumulators.values
val result = new TaskResult(value, accumUpdates)
@@ -105,9 +109,11 @@ class Executor extends org.apache.mesos.Executor with Logging {
.setData(ByteString.copyFrom(ser.serialize(reason)))
.build())
- // TODO: Handle errors in tasks less dramatically
+ // TODO: Should we exit the whole executor here? On the one hand, the failed task may
+ // have left some weird state around depending on when the exception was thrown, but on
+ // the other hand, maybe we could detect that when future tasks fail and exit then.
logError("Exception in task ID " + tid, t)
- System.exit(1)
+ //System.exit(1)
}
}
}
diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala
index a3c4e7873d..55512f4481 100644
--- a/core/src/main/scala/spark/FetchFailedException.scala
+++ b/core/src/main/scala/spark/FetchFailedException.scala
@@ -1,7 +1,9 @@
package spark
+import spark.storage.BlockManagerId
+
class FetchFailedException(
- val serverUri: String,
+ val bmAddress: BlockManagerId,
val shuffleId: Int,
val mapId: Int,
val reduceId: Int,
@@ -9,10 +11,10 @@ class FetchFailedException(
extends Exception {
override def getMessage(): String =
- "Fetch failed: %s %d %d %d".format(serverUri, shuffleId, mapId, reduceId)
+ "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
override def getCause(): Throwable = cause
def toTaskEndReason: TaskEndReason =
- FetchFailed(serverUri, shuffleId, mapId, reduceId)
+ FetchFailed(bmAddress, shuffleId, mapId, reduceId)
}
diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala
index 80f615eeb0..ec5c33d1df 100644
--- a/core/src/main/scala/spark/JavaSerializer.scala
+++ b/core/src/main/scala/spark/JavaSerializer.scala
@@ -1,6 +1,7 @@
package spark
import java.io._
+import java.nio.ByteBuffer
class JavaSerializationStream(out: OutputStream) extends SerializationStream {
val objOut = new ObjectOutputStream(out)
@@ -9,10 +10,11 @@ class JavaSerializationStream(out: OutputStream) extends SerializationStream {
def close() { objOut.close() }
}
-class JavaDeserializationStream(in: InputStream) extends DeserializationStream {
+class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
+extends DeserializationStream {
val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) =
- Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
+ Class.forName(desc.getName, false, loader)
}
def readObject[T](): T = objIn.readObject().asInstanceOf[T]
@@ -20,35 +22,36 @@ class JavaDeserializationStream(in: InputStream) extends DeserializationStream {
}
class JavaSerializerInstance extends SerializerInstance {
- def serialize[T](t: T): Array[Byte] = {
+ def serialize[T](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
- val out = outputStream(bos)
+ val out = serializeStream(bos)
out.writeObject(t)
out.close()
- bos.toByteArray
+ ByteBuffer.wrap(bos.toByteArray)
}
- def deserialize[T](bytes: Array[Byte]): T = {
- val bis = new ByteArrayInputStream(bytes)
- val in = inputStream(bis)
+ def deserialize[T](bytes: ByteBuffer): T = {
+ val bis = new ByteArrayInputStream(bytes.array())
+ val in = deserializeStream(bis)
in.readObject().asInstanceOf[T]
}
- def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
- val bis = new ByteArrayInputStream(bytes)
- val ois = new ObjectInputStream(bis) {
- override def resolveClass(desc: ObjectStreamClass) =
- Class.forName(desc.getName, false, loader)
- }
- return ois.readObject.asInstanceOf[T]
+ def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+ val bis = new ByteArrayInputStream(bytes.array())
+ val in = deserializeStream(bis, loader)
+ in.readObject().asInstanceOf[T]
}
- def outputStream(s: OutputStream): SerializationStream = {
+ def serializeStream(s: OutputStream): SerializationStream = {
new JavaSerializationStream(s)
}
- def inputStream(s: InputStream): DeserializationStream = {
- new JavaDeserializationStream(s)
+ def deserializeStream(s: InputStream): DeserializationStream = {
+ new JavaDeserializationStream(s, currentThread.getContextClassLoader)
+ }
+
+ def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
+ new JavaDeserializationStream(s, loader)
}
}
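With the change above, JavaSerializerInstance round-trips values through heap ByteBuffers: serialize wraps the ByteArrayOutputStream's bytes, and deserialize reads them back via bytes.array(), so it assumes an array-backed buffer. A minimal sketch of that round trip:

    // Illustrative round trip with the ByteBuffer-based API shown in this hunk.
    val ser = new JavaSerializerInstance
    val buf: java.nio.ByteBuffer = ser.serialize(List(1, 2, 3))
    val restored = ser.deserialize[List[Int]](buf)   // relies on buf being array-backed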
diff --git a/core/src/main/scala/spark/Job.scala b/core/src/main/scala/spark/Job.scala
deleted file mode 100644
index b7b0361c62..0000000000
--- a/core/src/main/scala/spark/Job.scala
+++ /dev/null
@@ -1,16 +0,0 @@
-package spark
-
-import org.apache.mesos._
-import org.apache.mesos.Protos._
-
-/**
- * Class representing a parallel job in MesosScheduler. Schedules the job by implementing various
- * callbacks.
- */
-abstract class Job(val runId: Int, val jobId: Int) {
- def slaveOffer(s: Offer, availableCpus: Double): Option[TaskInfo]
-
- def statusUpdate(t: TaskStatus): Unit
-
- def error(message: String): Unit
-}
diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 5693613d6d..65d0532bd5 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -12,6 +12,8 @@ import com.esotericsoftware.kryo.{Serializer => KSerializer}
import com.esotericsoftware.kryo.serialize.ClassSerializer
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
+import spark.storage._
+
/**
* Zig-zag encoder used to write object sizes to serialization streams.
* Based on Kryo's integer encoder.
@@ -64,57 +66,90 @@ object ZigZag {
}
}
-class KryoSerializationStream(kryo: Kryo, buf: ByteBuffer, out: OutputStream)
+class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream)
extends SerializationStream {
val channel = Channels.newChannel(out)
def writeObject[T](t: T) {
- kryo.writeClassAndObject(buf, t)
- ZigZag.writeInt(buf.position(), out)
- buf.flip()
- channel.write(buf)
- buf.clear()
+ kryo.writeClassAndObject(threadBuffer, t)
+ ZigZag.writeInt(threadBuffer.position(), out)
+ threadBuffer.flip()
+ channel.write(threadBuffer)
+ threadBuffer.clear()
}
def flush() { out.flush() }
def close() { out.close() }
}
-class KryoDeserializationStream(buf: ObjectBuffer, in: InputStream)
+class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream)
extends DeserializationStream {
def readObject[T](): T = {
val len = ZigZag.readInt(in)
- buf.readClassAndObject(in, len).asInstanceOf[T]
+ objectBuffer.readClassAndObject(in, len).asInstanceOf[T]
}
def close() { in.close() }
}
class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
- val buf = ks.threadBuf.get()
+ val kryo = ks.kryo
+ val threadBuffer = ks.threadBuffer.get()
+ val objectBuffer = ks.objectBuffer.get()
- def serialize[T](t: T): Array[Byte] = {
- buf.writeClassAndObject(t)
+ def serialize[T](t: T): ByteBuffer = {
+ // Write it to our thread-local scratch buffer first to figure out the size, then return a new
+ // ByteBuffer of the appropriate size
+ threadBuffer.clear()
+ kryo.writeClassAndObject(threadBuffer, t)
+ val newBuf = ByteBuffer.allocate(threadBuffer.position)
+ threadBuffer.flip()
+ newBuf.put(threadBuffer)
+ newBuf.flip()
+ newBuf
}
- def deserialize[T](bytes: Array[Byte]): T = {
- buf.readClassAndObject(bytes).asInstanceOf[T]
+ def deserialize[T](bytes: ByteBuffer): T = {
+ kryo.readClassAndObject(bytes).asInstanceOf[T]
}
- def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
- val oldClassLoader = ks.kryo.getClassLoader
- ks.kryo.setClassLoader(loader)
- val obj = buf.readClassAndObject(bytes).asInstanceOf[T]
- ks.kryo.setClassLoader(oldClassLoader)
+ def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+ val oldClassLoader = kryo.getClassLoader
+ kryo.setClassLoader(loader)
+ val obj = kryo.readClassAndObject(bytes).asInstanceOf[T]
+ kryo.setClassLoader(oldClassLoader)
obj
}
- def outputStream(s: OutputStream): SerializationStream = {
- new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s)
+ def serializeStream(s: OutputStream): SerializationStream = {
+ threadBuffer.clear()
+ new KryoSerializationStream(kryo, threadBuffer, s)
+ }
+
+ def deserializeStream(s: InputStream): DeserializationStream = {
+ new KryoDeserializationStream(objectBuffer, s)
}
- def inputStream(s: InputStream): DeserializationStream = {
- new KryoDeserializationStream(buf, s)
+ override def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
+ threadBuffer.clear()
+ while (iterator.hasNext) {
+ val element = iterator.next()
+ // TODO: Do we also want to write the object's size? Doesn't seem necessary.
+ kryo.writeClassAndObject(threadBuffer, element)
+ }
+ val newBuf = ByteBuffer.allocate(threadBuffer.position)
+ threadBuffer.flip()
+ newBuf.put(threadBuffer)
+ newBuf.flip()
+ newBuf
+ }
+
+ override def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
+ buffer.rewind()
+ new Iterator[Any] {
+ override def hasNext: Boolean = buffer.remaining > 0
+ override def next(): Any = kryo.readClassAndObject(buffer)
+ }
}
}
@@ -126,20 +161,17 @@ trait KryoRegistrator {
class KryoSerializer extends Serializer with Logging {
val kryo = createKryo()
- val bufferSize =
- System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+ val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024
- val threadBuf = new ThreadLocal[ObjectBuffer] {
+ val objectBuffer = new ThreadLocal[ObjectBuffer] {
override def initialValue = new ObjectBuffer(kryo, bufferSize)
}
- val threadByteBuf = new ThreadLocal[ByteBuffer] {
+ val threadBuffer = new ThreadLocal[ByteBuffer] {
override def initialValue = ByteBuffer.allocate(bufferSize)
}
def createKryo(): Kryo = {
- // This is used so we can serialize/deserialize objects without a zero-arg
- // constructor.
val kryo = new KryoReflectionFactorySupport()
// Register some commonly used classes
@@ -148,14 +180,20 @@ class KryoSerializer extends Serializer with Logging {
Array(1), Array(1.0), Array(1.0f), Array(1L), Array(""), Array(("", "")),
Array(new java.lang.Object), Array(1.toByte), Array(true), Array('c'),
// Specialized Tuple2s
- ("", ""), (1, 1), (1.0, 1.0), (1L, 1L),
+ ("", ""), ("", 1), (1, 1), (1.0, 1.0), (1L, 1L),
(1, 1.0), (1.0, 1), (1L, 1.0), (1.0, 1L), (1, 1L), (1L, 1),
// Scala collections
List(1), mutable.ArrayBuffer(1),
// Options and Either
Some(1), Left(1), Right(1),
// Higher-dimensional tuples
- (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1)
+ (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1),
+ None,
+ ByteBuffer.allocate(1),
+ StorageLevel.MEMORY_ONLY_DESER,
+ PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY_DESER),
+ GotBlock("1", ByteBuffer.allocate(1)),
+ GetBlock("1")
)
for (obj <- toRegister) {
kryo.register(obj.getClass)
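
A minimal sketch of the ByteBuffer-based round trip introduced above, assuming KryoSerializer.newInstance() returns the KryoSerializerInstance defined in this hunk (the newInstance method itself falls outside the hunk's context):

// Round-trips single objects and whole iterators through the revised Kryo API.
import java.nio.ByteBuffer
import spark.KryoSerializer

object KryoRoundTripSketch {
  def main(args: Array[String]) {
    val ser = new KryoSerializer().newInstance()

    // serialize now returns a ByteBuffer sized to the object instead of an Array[Byte]
    val buf: ByteBuffer = ser.serialize(("answer", 42))
    println(ser.deserialize[(String, Int)](buf))      // (answer,42)

    // serializeMany/deserializeMany pack an entire iterator into one buffer
    val many = ser.serializeMany(Iterator(1, 2, 3))
    println(ser.deserializeMany(many).toList)         // List(1, 2, 3)
  }
}
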
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
index 0d11ab9cbd..54bd57f6d3 100644
--- a/core/src/main/scala/spark/Logging.scala
+++ b/core/src/main/scala/spark/Logging.scala
@@ -28,9 +28,11 @@ trait Logging {
}
// Log methods that take only a String
- def logInfo(msg: => String) = if (log.isInfoEnabled) log.info(msg)
+ def logInfo(msg: => String) = if (log.isInfoEnabled /*&& msg.contains("job finished in")*/) log.info(msg)
def logDebug(msg: => String) = if (log.isDebugEnabled) log.debug(msg)
+
+ def logTrace(msg: => String) = if (log.isTraceEnabled) log.trace(msg)
def logWarning(msg: => String) = if (log.isWarnEnabled) log.warn(msg)
@@ -43,6 +45,9 @@ trait Logging {
def logDebug(msg: => String, throwable: Throwable) =
if (log.isDebugEnabled) log.debug(msg)
+ def logTrace(msg: => String, throwable: Throwable) =
+ if (log.isTraceEnabled) log.trace(msg)
+
def logWarning(msg: => String, throwable: Throwable) =
if (log.isWarnEnabled) log.warn(msg, throwable)
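
A short sketch of the logTrace hook added above, assuming the underlying logger has TRACE enabled for this class:

import spark.Logging

class ChunkTracer extends Logging {
  def record(id: Int) {
    logTrace("Handling chunk " + id)   // by-name message, only built if TRACE is enabled
    logDebug("Processed chunk " + id)
  }
}
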
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index a934c5a02f..d938a6eb62 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -2,80 +2,80 @@ package spark
import java.util.concurrent.ConcurrentHashMap
-import scala.actors._
-import scala.actors.Actor._
-import scala.actors.remote._
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
+import akka.util.duration._
+
import scala.collection.mutable.HashSet
+import spark.storage.BlockManagerId
+
sealed trait MapOutputTrackerMessage
case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
case object StopMapOutputTracker extends MapOutputTrackerMessage
-class MapOutputTrackerActor(serverUris: ConcurrentHashMap[Int, Array[String]])
-extends DaemonActor with Logging {
- def act() {
- val port = System.getProperty("spark.master.port").toInt
- RemoteActor.alive(port)
- RemoteActor.register('MapOutputTracker, self)
- logInfo("Registered actor on port " + port)
-
- loop {
- react {
- case GetMapOutputLocations(shuffleId: Int) =>
- logInfo("Asked to get map output locations for shuffle " + shuffleId)
- reply(serverUris.get(shuffleId))
-
- case StopMapOutputTracker =>
- reply('OK)
- exit()
- }
- }
+class MapOutputTrackerActor(bmAddresses: ConcurrentHashMap[Int, Array[BlockManagerId]])
+extends Actor with Logging {
+ def receive = {
+ case GetMapOutputLocations(shuffleId: Int) =>
+ logInfo("Asked to get map output locations for shuffle " + shuffleId)
+ self.reply(bmAddresses.get(shuffleId))
+
+ case StopMapOutputTracker =>
+ logInfo("MapOutputTrackerActor stopped!")
+ self.reply(true)
+ self.exit()
}
}
class MapOutputTracker(isMaster: Boolean) extends Logging {
- var trackerActor: AbstractActor = null
+ val ip: String = System.getProperty("spark.master.host", "localhost")
+ val port: Int = System.getProperty("spark.master.port", "7077").toInt
+ val aName: String = "MapOutputTracker"
- private var serverUris = new ConcurrentHashMap[Int, Array[String]]
+ private var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
private var generation: Long = 0
private var generationLock = new java.lang.Object
-
- if (isMaster) {
- val tracker = new MapOutputTrackerActor(serverUris)
- tracker.start()
- trackerActor = tracker
+
+ var trackerActor: ActorRef = if (isMaster) {
+ val actor = actorOf(new MapOutputTrackerActor(bmAddresses))
+ remote.register(aName, actor)
+ logInfo("Registered MapOutputTrackerActor actor @ " + ip + ":" + port)
+ actor
} else {
- val host = System.getProperty("spark.master.host")
- val port = System.getProperty("spark.master.port").toInt
- trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
+ remote.actorFor(aName, ip, port)
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (serverUris.get(shuffleId) != null) {
+ if (bmAddresses.get(shuffleId) != null) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
- serverUris.put(shuffleId, new Array[String](numMaps))
+ bmAddresses.put(shuffleId, new Array[BlockManagerId](numMaps))
}
- def registerMapOutput(shuffleId: Int, mapId: Int, serverUri: String) {
- var array = serverUris.get(shuffleId)
+ def registerMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+ var array = bmAddresses.get(shuffleId)
array.synchronized {
- array(mapId) = serverUri
+ array(mapId) = bmAddress
}
}
- def registerMapOutputs(shuffleId: Int, locs: Array[String]) {
- serverUris.put(shuffleId, Array[String]() ++ locs)
+ def registerMapOutputs(shuffleId: Int, locs: Array[BlockManagerId], changeGeneration: Boolean = false) {
+ bmAddresses.put(shuffleId, Array[BlockManagerId]() ++ locs)
+ if (changeGeneration) {
+ incrementGeneration()
+ }
}
- def unregisterMapOutput(shuffleId: Int, mapId: Int, serverUri: String) {
- var array = serverUris.get(shuffleId)
+ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+ var array = bmAddresses.get(shuffleId)
if (array != null) {
array.synchronized {
- if (array(mapId) == serverUri) {
+ if (array(mapId) == bmAddress) {
array(mapId) = null
}
}
@@ -89,10 +89,10 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
val fetching = new HashSet[Int]
// Called on possibly remote nodes to get the server URIs for a given shuffle
- def getServerUris(shuffleId: Int): Array[String] = {
- val locs = serverUris.get(shuffleId)
+ def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = {
+ val locs = bmAddresses.get(shuffleId)
if (locs == null) {
- logInfo("Don't have map outputs for " + shuffleId + ", fetching them")
+ logInfo("Don't have map outputs for shuffe " + shuffleId + ", fetching them")
fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done
@@ -103,15 +103,17 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
case _ =>
}
}
- return serverUris.get(shuffleId)
+ return bmAddresses.get(shuffleId)
} else {
fetching += shuffleId
}
}
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
- val fetched = (trackerActor !? GetMapOutputLocations(shuffleId)).asInstanceOf[Array[String]]
- serverUris.put(shuffleId, fetched)
+ val fetched = (trackerActor ? GetMapOutputLocations(shuffleId)).as[Array[BlockManagerId]].get
+
+ logInfo("Got the output locations")
+ bmAddresses.put(shuffleId, fetched)
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
@@ -121,14 +123,10 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
return locs
}
}
-
- def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String = {
- "%s/shuffle/%s/%s/%s".format(serverUri, shuffleId, mapId, reduceId)
- }
def stop() {
- trackerActor !? StopMapOutputTracker
- serverUris.clear()
+ trackerActor !! StopMapOutputTracker
+ bmAddresses.clear()
trackerActor = null
}
@@ -153,7 +151,7 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
- serverUris = new ConcurrentHashMap[Int, Array[String]]
+ bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
generation = newGen
}
}
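
A hedged sketch of the master-side flow implied by the Akka-based tracker above. It mirrors how SparkContext starts the remote layer; BlockManagerId(host, port) is assumed to be a valid constructor, since that class is not part of this diff:

import akka.actor.Actor._
import spark.MapOutputTracker
import spark.storage.BlockManagerId

object MapOutputTrackerSketch {
  def main(args: Array[String]) {
    System.setProperty("spark.master.host", "localhost")
    System.setProperty("spark.master.port", "7077")
    remote.start("localhost", 7077)              // as SparkContext does before building trackers

    val tracker = new MapOutputTracker(true)     // isMaster = true
    tracker.registerShuffle(0, 2)                // shuffle 0 with two map tasks
    tracker.registerMapOutput(0, 0, new BlockManagerId("host-a", 10000))
    tracker.registerMapOutput(0, 1, new BlockManagerId("host-b", 10000))

    // Reducers call this to learn which block manager holds each map output
    tracker.getServerAddresses(0).foreach(println)
    tracker.stop()
  }
}
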
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 8b63d1aba1..ff6764e0a2 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -4,14 +4,14 @@ import java.io.EOFException
import java.net.URL
import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong
-import java.util.HashSet
-import java.util.Random
+import java.util.{HashMap => JHashMap}
import java.util.Date
import java.text.SimpleDateFormat
+import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.Map
import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable
@@ -34,7 +34,9 @@ import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
import org.apache.hadoop.mapreduce.TaskAttemptID
import org.apache.hadoop.mapreduce.TaskAttemptContext
-import SparkContext._
+import spark.SparkContext._
+import spark.partial.BoundedDouble
+import spark.partial.PartialResult
/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -43,19 +45,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
self: RDD[(K, V)])
extends Logging
with Serializable {
-
- def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = {
- def mergeMaps(m1: HashMap[K, V], m2: HashMap[K, V]): HashMap[K, V] = {
- for ((k, v) <- m2) {
- m1.get(k) match {
- case None => m1(k) = v
- case Some(w) => m1(k) = func(w, v)
- }
- }
- return m1
- }
- self.map(pair => HashMap(pair)).reduce(mergeMaps)
- }
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
@@ -77,6 +66,39 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = {
combineByKey[V]((v: V) => v, func, func, numSplits)
}
+
+ def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
+ def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
+ val map = new JHashMap[K, V]
+ for ((k, v) <- iter) {
+ val old = map.get(k)
+ map.put(k, if (old == null) v else func(old, v))
+ }
+ Iterator(map)
+ }
+
+ def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = {
+ for ((k, v) <- m2) {
+ val old = m1.get(k)
+ m1.put(k, if (old == null) v else func(old, v))
+ }
+ return m1
+ }
+
+ self.mapPartitions(reducePartition).reduce(mergeMaps)
+ }
+
+ // Alias for backwards compatibility
+ def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func)
+
+ // TODO: This should probably be a distributed version
+ def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
+
+ // TODO: This should probably be a distributed version
+ def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
+ : PartialResult[Map[K, BoundedDouble]] = {
+ self.map(_._1).countByValueApprox(timeout, confidence)
+ }
def groupByKey(numSplits: Int): RDD[(K, Seq[V])] = {
def createCombiner(v: V) = ArrayBuffer(v)
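
A hedged sketch of the new pair-RDD helpers above. The commit message notes this snapshot does not yet pass unit tests, so treat this as illustrative; spark.master.host/port are set because SparkContext now reads them at construction:

import spark.SparkContext
import spark.SparkContext._

object PairOpsSketch {
  def main(args: Array[String]) {
    System.setProperty("spark.master.host", "localhost")
    System.setProperty("spark.master.port", "7077")
    val sc = new SparkContext("local", "PairOpsSketch")

    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

    // Reduces each partition into a java.util.HashMap, then merges the maps on the driver
    println(pairs.reduceByKeyLocally(_ + _))     // e.g. Map(a -> 4, b -> 2)

    // countByKey maps to the keys and runs countByValue, combining locally on the master
    println(pairs.countByKey())                  // e.g. Map(a -> 2, b -> 1)

    sc.stop()
  }
}
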
diff --git a/core/src/main/scala/spark/ParallelShuffleFetcher.scala b/core/src/main/scala/spark/ParallelShuffleFetcher.scala
deleted file mode 100644
index 19eb288e84..0000000000
--- a/core/src/main/scala/spark/ParallelShuffleFetcher.scala
+++ /dev/null
@@ -1,119 +0,0 @@
-package spark
-
-import java.io.ByteArrayInputStream
-import java.io.EOFException
-import java.net.URL
-import java.util.concurrent.LinkedBlockingQueue
-import java.util.concurrent.TimeUnit
-import java.util.concurrent.atomic.AtomicBoolean
-import java.util.concurrent.atomic.AtomicInteger
-import java.util.concurrent.atomic.AtomicReference
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-
-import it.unimi.dsi.fastutil.io.FastBufferedInputStream
-
-
-class ParallelShuffleFetcher extends ShuffleFetcher with Logging {
- val parallelFetches = System.getProperty("spark.parallel.fetches", "3").toInt
-
- def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
- logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-
- // Figure out a list of input IDs (mapper IDs) for each server
- val ser = SparkEnv.get.serializer.newInstance()
- val inputsByUri = new HashMap[String, ArrayBuffer[Int]]
- val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
- for ((serverUri, index) <- serverUris.zipWithIndex) {
- inputsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
- }
-
- // Randomize them and put them in a LinkedBlockingQueue
- val serverQueue = new LinkedBlockingQueue[(String, ArrayBuffer[Int])]
- for (pair <- Utils.randomize(inputsByUri)) {
- serverQueue.put(pair)
- }
-
- // Create a queue to hold the fetched data
- val resultQueue = new LinkedBlockingQueue[Array[Byte]]
-
- // Atomic variables to communicate failures and # of fetches done
- var failure = new AtomicReference[FetchFailedException](null)
-
- // Start multiple threads to do the fetching (TODO: may be possible to do it asynchronously)
- for (i <- 0 until parallelFetches) {
- new Thread("Fetch thread " + i + " for reduce " + reduceId) {
- override def run() {
- while (true) {
- val pair = serverQueue.poll()
- if (pair == null)
- return
- val (serverUri, inputIds) = pair
- //logInfo("Pulled out server URI " + serverUri)
- for (i <- inputIds) {
- if (failure.get != null)
- return
- val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
- logInfo("Starting HTTP request for " + url)
- try {
- val conn = new URL(url).openConnection()
- conn.connect()
- val len = conn.getContentLength()
- if (len == -1) {
- throw new SparkException("Content length was not specified by server")
- }
- val buf = new Array[Byte](len)
- val in = new FastBufferedInputStream(conn.getInputStream())
- var pos = 0
- while (pos < len) {
- val n = in.read(buf, pos, len-pos)
- if (n == -1) {
- throw new SparkException("EOF before reading the expected " + len + " bytes")
- } else {
- pos += n
- }
- }
- // Done reading everything
- resultQueue.put(buf)
- in.close()
- } catch {
- case e: Exception =>
- logError("Fetch failed from " + url, e)
- failure.set(new FetchFailedException(serverUri, shuffleId, i, reduceId, e))
- return
- }
- }
- //logInfo("Done with server URI " + serverUri)
- }
- }
- }.start()
- }
-
- // Wait for results from the threads (either a failure or all servers done)
- var resultsDone = 0
- var totalResults = inputsByUri.map{case (uri, inputs) => inputs.size}.sum
- while (failure.get == null && resultsDone < totalResults) {
- try {
- val result = resultQueue.poll(100, TimeUnit.MILLISECONDS)
- if (result != null) {
- //logInfo("Pulled out a result")
- val in = ser.inputStream(new ByteArrayInputStream(result))
- try {
- while (true) {
- val pair = in.readObject().asInstanceOf[(K, V)]
- func(pair._1, pair._2)
- }
- } catch {
- case e: EOFException => {} // TODO: cleaner way to detect EOF, such as a sentinel
- }
- resultsDone += 1
- //logInfo("Results done = " + resultsDone)
- }
- } catch { case e: InterruptedException => {} }
- }
- if (failure.get != null) {
- throw failure.get
- }
- }
-}
diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index ac61fe3b54..8f3f0f5e15 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -70,4 +70,3 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
false
}
}
-
diff --git a/core/src/main/scala/spark/PipedRDD.scala b/core/src/main/scala/spark/PipedRDD.scala
index 8a5de3d7e9..9e0a01b5f9 100644
--- a/core/src/main/scala/spark/PipedRDD.scala
+++ b/core/src/main/scala/spark/PipedRDD.scala
@@ -3,6 +3,7 @@ package spark
import java.io.PrintWriter
import java.util.StringTokenizer
+import scala.collection.Map
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index fa53d9be2c..22dcc27bad 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -4,11 +4,14 @@ import java.io.EOFException
import java.net.URL
import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong
-import java.util.HashSet
import java.util.Random
import java.util.Date
+import java.util.{HashMap => JHashMap}
import scala.collection.mutable.ArrayBuffer
+import scala.collection.Map
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions.mapAsScalaMap
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
@@ -22,6 +25,14 @@ import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.TextOutputFormat
+import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+
+import spark.partial.BoundedDouble
+import spark.partial.CountEvaluator
+import spark.partial.GroupedCountEvaluator
+import spark.partial.PartialResult
+import spark.storage.StorageLevel
+
import SparkContext._
/**
@@ -61,19 +72,32 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
// Get a unique ID for this RDD
val id = sc.newRddId()
- // Variables relating to caching
- private var shouldCache = false
+ // Variables relating to persistence
+ private var storageLevel: StorageLevel = StorageLevel.NONE
- // Change this RDD's caching
- def cache(): RDD[T] = {
- shouldCache = true
+ // Change this RDD's storage level
+ def persist(newLevel: StorageLevel): RDD[T] = {
+ // TODO: Handle changes of StorageLevel
+ if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) {
+ throw new UnsupportedOperationException(
+ "Cannot change storage level of an RDD after it was already assigned a level")
+ }
+ storageLevel = newLevel
this
}
+
+ // Turn on the default caching level for this RDD
+ def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY_DESER)
+
+ // Turn on the default caching level for this RDD
+ def cache(): RDD[T] = persist()
+
+ def getStorageLevel = storageLevel
// Read this RDD; will read from cache if applicable, or otherwise compute
final def iterator(split: Split): Iterator[T] = {
- if (shouldCache) {
- SparkEnv.get.cacheTracker.getOrCompute[T](this, split)
+ if (storageLevel != StorageLevel.NONE) {
+ SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel)
} else {
compute(split)
}
@@ -162,6 +186,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
Array.concat(results: _*)
}
+ def toArray(): Array[T] = collect()
+
def reduce(f: (T, T) => T): T = {
val cleanF = sc.clean(f)
val reducePartition: Iterator[T] => Option[T] = iter => {
@@ -222,7 +248,67 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
}).sum
}
- def toArray(): Array[T] = collect()
+ /**
+ * Approximate version of count() that returns a potentially incomplete result after a timeout.
+ */
+ def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+ val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) =>
+ var result = 0L
+ while (iter.hasNext) {
+ result += 1L
+ iter.next
+ }
+ result
+ }
+ val evaluator = new CountEvaluator(splits.size, confidence)
+ sc.runApproximateJob(this, countElements, evaluator, timeout)
+ }
+
+ /**
+ * Count elements equal to each value, returning a map of (value, count) pairs. The final combine
+ * step happens locally on the master, equivalent to running a single reduce task.
+ *
+ * TODO: This should perhaps be distributed by default.
+ */
+ def countByValue(): Map[T, Long] = {
+ def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = {
+ val map = new OLMap[T]
+ while (iter.hasNext) {
+ val v = iter.next()
+ map.put(v, map.getLong(v) + 1L)
+ }
+ Iterator(map)
+ }
+ def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = {
+ val iter = m2.object2LongEntrySet.fastIterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue)
+ }
+ return m1
+ }
+ val myResult = mapPartitions(countPartition).reduce(mergeMaps)
+ myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map
+ }
+
+ /**
+ * Approximate version of countByValue().
+ */
+ def countByValueApprox(
+ timeout: Long,
+ confidence: Double = 0.95
+ ): PartialResult[Map[T, BoundedDouble]] = {
+ val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
+ val map = new OLMap[T]
+ while (iter.hasNext) {
+ val v = iter.next()
+ map.put(v, map.getLong(v) + 1L)
+ }
+ map
+ }
+ val evaluator = new GroupedCountEvaluator[T](splits.size, confidence)
+ sc.runApproximateJob(this, countPartition, evaluator, timeout)
+ }
/**
* Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
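
A hedged sketch of the persistence and counting APIs added to RDD above, under the same local-master caveats as the previous example:

import spark.SparkContext
import spark.SparkContext._
import spark.storage.StorageLevel

object RddPersistSketch {
  def main(args: Array[String]) {
    System.setProperty("spark.master.host", "localhost")
    System.setProperty("spark.master.port", "7077")
    val sc = new SparkContext("local", "RddPersistSketch")

    val words = sc.parallelize(Seq("a", "b", "a", "c", "a"))

    // cache() is now an alias for persist(StorageLevel.MEMORY_ONLY_DESER)
    words.cache()
    println(words.getStorageLevel == StorageLevel.MEMORY_ONLY_DESER)   // true

    // Exact counts, merged on the master
    println(words.countByValue())                // e.g. Map(a -> 3, b -> 1, c -> 1)

    // Approximate count with a 2 s timeout; PartialResult's accessors live in spark.partial
    // and are not shown in this diff, so only the handle is printed here
    println(words.countApprox(2000L, 0.95))

    sc.stop()
  }
}
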
diff --git a/core/src/main/scala/spark/Scheduler.scala b/core/src/main/scala/spark/Scheduler.scala
deleted file mode 100644
index 6c7e569313..0000000000
--- a/core/src/main/scala/spark/Scheduler.scala
+++ /dev/null
@@ -1,27 +0,0 @@
-package spark
-
-/**
- * Scheduler trait, implemented by both MesosScheduler and LocalScheduler.
- */
-private trait Scheduler {
- def start()
-
- // Wait for registration with Mesos.
- def waitForRegister()
-
- /**
- * Run a function on some partitions of an RDD, returning an array of results. The allowLocal
- * flag specifies whether the scheduler is allowed to run the job on the master machine rather
- * than shipping it to the cluster, for actions that create short jobs such as first() and take().
- */
- def runJob[T, U: ClassManifest](
- rdd: RDD[T],
- func: (TaskContext, Iterator[T]) => U,
- partitions: Seq[Int],
- allowLocal: Boolean): Array[U]
-
- def stop()
-
- // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
- def defaultParallelism(): Int
-}
diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
index b213ca9dcb..9da73c4b02 100644
--- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
@@ -44,7 +44,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
}
// TODO: use something like WritableConverter to avoid reflection
}
- c.asInstanceOf[Class[ _ <: Writable]]
+ c.asInstanceOf[Class[_ <: Writable]]
}
def saveAsSequenceFile(path: String) {
diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala
index 2429bbfeb9..61a70beaf1 100644
--- a/core/src/main/scala/spark/Serializer.scala
+++ b/core/src/main/scala/spark/Serializer.scala
@@ -1,6 +1,12 @@
package spark
-import java.io.{InputStream, OutputStream}
+import java.io.{EOFException, InputStream, OutputStream}
+import java.nio.ByteBuffer
+import java.nio.channels.Channels
+
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
+import spark.util.ByteBufferInputStream
/**
* A serializer. Because some serialization libraries are not thread safe, this class is used to
@@ -14,11 +20,31 @@ trait Serializer {
* An instance of the serializer, for use by one thread at a time.
*/
trait SerializerInstance {
- def serialize[T](t: T): Array[Byte]
- def deserialize[T](bytes: Array[Byte]): T
- def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T
- def outputStream(s: OutputStream): SerializationStream
- def inputStream(s: InputStream): DeserializationStream
+ def serialize[T](t: T): ByteBuffer
+
+ def deserialize[T](bytes: ByteBuffer): T
+
+ def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T
+
+ def serializeStream(s: OutputStream): SerializationStream
+
+ def deserializeStream(s: InputStream): DeserializationStream
+
+ def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
+ // Default implementation uses serializeStream
+ val stream = new FastByteArrayOutputStream()
+ serializeStream(stream).writeAll(iterator)
+ val buffer = ByteBuffer.allocate(stream.position.toInt)
+ buffer.put(stream.array, 0, stream.position.toInt)
+ buffer.flip()
+ buffer
+ }
+
+ def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
+ // Default implementation uses deserializeStream
+ buffer.rewind()
+ deserializeStream(new ByteBufferInputStream(buffer)).toIterator
+ }
}
/**
@@ -28,6 +54,13 @@ trait SerializationStream {
def writeObject[T](t: T): Unit
def flush(): Unit
def close(): Unit
+
+ def writeAll[T](iter: Iterator[T]): SerializationStream = {
+ while (iter.hasNext) {
+ writeObject(iter.next())
+ }
+ this
+ }
}
/**
@@ -36,4 +69,45 @@ trait SerializationStream {
trait DeserializationStream {
def readObject[T](): T
def close(): Unit
+
+ /**
+ * Read the elements of this stream through an iterator. This can only be called once, as
+ * reading each element will consume data from the input source.
+ */
+ def toIterator: Iterator[Any] = new Iterator[Any] {
+ var gotNext = false
+ var finished = false
+ var nextValue: Any = null
+
+ private def getNext() {
+ try {
+ nextValue = readObject[Any]()
+ } catch {
+ case eof: EOFException =>
+ finished = true
+ }
+ gotNext = true
+ }
+
+ override def hasNext: Boolean = {
+ if (!gotNext) {
+ getNext()
+ }
+ if (finished) {
+ close()
+ }
+ !finished
+ }
+
+ override def next(): Any = {
+ if (!gotNext) {
+ getNext()
+ }
+ if (finished) {
+ throw new NoSuchElementException("End of stream")
+ }
+ gotNext = false
+ nextValue
+ }
+ }
}
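
A small sketch exercising the default serializeMany/deserializeMany implementations of the revised trait, using the existing spark.JavaSerializer (also updated in this commit):

import spark.JavaSerializer

object SerializerDefaultsSketch {
  def main(args: Array[String]) {
    val ser = new JavaSerializer().newInstance()

    // serializeMany falls back to serializeStream + writeAll
    val buffer = ser.serializeMany(Iterator("x", "y", "z"))

    // deserializeMany rewinds the buffer and reads until EOF via DeserializationStream.toIterator
    println(ser.deserializeMany(buffer).toList)  // List(x, y, z)
  }
}
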
diff --git a/core/src/main/scala/spark/SerializingCache.scala b/core/src/main/scala/spark/SerializingCache.scala
deleted file mode 100644
index 3d192f2403..0000000000
--- a/core/src/main/scala/spark/SerializingCache.scala
+++ /dev/null
@@ -1,26 +0,0 @@
-package spark
-
-import java.io._
-
-/**
- * Wrapper around a BoundedMemoryCache that stores serialized objects as byte arrays in order to
- * reduce storage cost and GC overhead
- */
-class SerializingCache extends Cache with Logging {
- val bmc = new BoundedMemoryCache
-
- override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
- val ser = SparkEnv.get.serializer.newInstance()
- bmc.put(datasetId, partition, ser.serialize(value))
- }
-
- override def get(datasetId: Any, partition: Int): Any = {
- val bytes = bmc.get(datasetId, partition)
- if (bytes != null) {
- val ser = SparkEnv.get.serializer.newInstance()
- return ser.deserialize(bytes.asInstanceOf[Array[Byte]])
- } else {
- return null
- }
- }
-}
diff --git a/core/src/main/scala/spark/ShuffleMapTask.scala b/core/src/main/scala/spark/ShuffleMapTask.scala
deleted file mode 100644
index 5fc59af06c..0000000000
--- a/core/src/main/scala/spark/ShuffleMapTask.scala
+++ /dev/null
@@ -1,56 +0,0 @@
-package spark
-
-import java.io.BufferedOutputStream
-import java.io.FileOutputStream
-import java.io.ObjectOutputStream
-import java.util.{HashMap => JHashMap}
-
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
-class ShuffleMapTask(
- runId: Int,
- stageId: Int,
- rdd: RDD[_],
- dep: ShuffleDependency[_,_,_],
- val partition: Int,
- locs: Seq[String])
- extends DAGTask[String](runId, stageId)
- with Logging {
-
- val split = rdd.splits(partition)
-
- override def run (attemptId: Int): String = {
- val numOutputSplits = dep.partitioner.numPartitions
- val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
- val partitioner = dep.partitioner.asInstanceOf[Partitioner]
- val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
- for (elem <- rdd.iterator(split)) {
- val (k, v) = elem.asInstanceOf[(Any, Any)]
- var bucketId = partitioner.getPartition(k)
- val bucket = buckets(bucketId)
- var existing = bucket.get(k)
- if (existing == null) {
- bucket.put(k, aggregator.createCombiner(v))
- } else {
- bucket.put(k, aggregator.mergeValue(existing, v))
- }
- }
- val ser = SparkEnv.get.serializer.newInstance()
- for (i <- 0 until numOutputSplits) {
- val file = SparkEnv.get.shuffleManager.getOutputFile(dep.shuffleId, partition, i)
- val out = ser.outputStream(new FastBufferedOutputStream(new FileOutputStream(file)))
- val iter = buckets(i).entrySet().iterator()
- while (iter.hasNext()) {
- val entry = iter.next()
- out.writeObject((entry.getKey, entry.getValue))
- }
- // TODO: have some kind of EOF marker
- out.close()
- }
- return SparkEnv.get.shuffleManager.getServerUri
- }
-
- override def preferredLocations: Seq[String] = locs
-
- override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
-}
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index 5efc8cf50b..5434197eca 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -8,7 +8,7 @@ class ShuffledRDDSplit(val idx: Int) extends Split {
}
class ShuffledRDD[K, V, C](
- parent: RDD[(K, V)],
+ @transient parent: RDD[(K, V)],
aggregator: Aggregator[K, V, C],
part : Partitioner)
extends RDD[(K, C)](parent.context) {
diff --git a/core/src/main/scala/spark/SimpleShuffleFetcher.scala b/core/src/main/scala/spark/SimpleShuffleFetcher.scala
deleted file mode 100644
index 196c64cf1f..0000000000
--- a/core/src/main/scala/spark/SimpleShuffleFetcher.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-package spark
-
-import java.io.EOFException
-import java.net.URL
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-
-import it.unimi.dsi.fastutil.io.FastBufferedInputStream
-
-class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
- def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
- logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
- val ser = SparkEnv.get.serializer.newInstance()
- val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
- val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
- for ((serverUri, index) <- serverUris.zipWithIndex) {
- splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
- }
- for ((serverUri, inputIds) <- Utils.randomize(splitsByUri)) {
- for (i <- inputIds) {
- try {
- val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
- // TODO: multithreaded fetch
- // TODO: would be nice to retry multiple times
- val inputStream = ser.inputStream(
- new FastBufferedInputStream(new URL(url).openStream()))
- try {
- while (true) {
- val pair = inputStream.readObject().asInstanceOf[(K, V)]
- func(pair._1, pair._2)
- }
- } finally {
- inputStream.close()
- }
- } catch {
- case e: EOFException => {} // We currently assume EOF means we read the whole thing
- case other: Exception => {
- logError("Fetch failed", other)
- throw new FetchFailedException(serverUri, shuffleId, i, reduceId, other)
- }
- }
- }
- }
- }
-}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 6e019d6e7f..7a9a70fee0 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -3,6 +3,9 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
+import akka.actor.Actor
+import akka.actor.Actor._
+
import scala.actors.remote.RemoteActor
import scala.collection.mutable.ArrayBuffer
@@ -32,6 +35,15 @@ import org.apache.mesos.MesosNativeLibrary
import spark.broadcast._
+import spark.partial.ApproximateEvaluator
+import spark.partial.PartialResult
+
+import spark.scheduler.DAGScheduler
+import spark.scheduler.TaskScheduler
+import spark.scheduler.local.LocalScheduler
+import spark.scheduler.mesos.MesosScheduler
+import spark.scheduler.mesos.CoarseMesosScheduler
+
class SparkContext(
master: String,
frameworkName: String,
@@ -54,14 +66,19 @@ class SparkContext(
if (RemoteActor.classLoader == null) {
RemoteActor.classLoader = getClass.getClassLoader
}
+
+ remote.start(System.getProperty("spark.master.host"),
+ System.getProperty("spark.master.port").toInt)
+ private val isLocal = master.startsWith("local") // TODO: better check for local
+
// Create the Spark execution environment (cache, map output tracker, etc)
- val env = SparkEnv.createFromSystemProperties(true)
+ val env = SparkEnv.createFromSystemProperties(true, isLocal)
SparkEnv.set(env)
Broadcast.initialize(true)
// Create and start the scheduler
- private var scheduler: Scheduler = {
+ private var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
@@ -74,13 +91,17 @@ class SparkContext(
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
new LocalScheduler(threads.toInt, maxFailures.toInt)
case _ =>
- MesosNativeLibrary.load()
- new MesosScheduler(this, master, frameworkName)
+ System.loadLibrary("mesos")
+ if (System.getProperty("spark.mesos.coarse", "false") == "true") {
+ new CoarseMesosScheduler(this, master, frameworkName)
+ } else {
+ new MesosScheduler(this, master, frameworkName)
+ }
}
}
- scheduler.start()
+ taskScheduler.start()
- private val isLocal = scheduler.isInstanceOf[LocalScheduler]
+ private var dagScheduler = new DAGScheduler(taskScheduler)
// Methods for creating RDDs
@@ -237,19 +258,21 @@ class SparkContext(
// Stop the SparkContext
def stop() {
- scheduler.stop()
- scheduler = null
+ dagScheduler.stop()
+ dagScheduler = null
+ taskScheduler = null
// TODO: Broadcast.stop(), Cache.stop()?
env.mapOutputTracker.stop()
env.cacheTracker.stop()
env.shuffleFetcher.stop()
env.shuffleManager.stop()
+ env.connectionManager.stop()
SparkEnv.set(null)
}
- // Wait for the scheduler to be registered
+ // Wait for the scheduler to be registered with the cluster manager
def waitForRegister() {
- scheduler.waitForRegister()
+ taskScheduler.waitForRegister()
}
// Get Spark's home location from either a value set through the constructor,
@@ -281,7 +304,7 @@ class SparkContext(
): Array[U] = {
logInfo("Starting job...")
val start = System.nanoTime
- val result = scheduler.runJob(rdd, func, partitions, allowLocal)
+ val result = dagScheduler.runJob(rdd, func, partitions, allowLocal)
logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
result
}
@@ -306,6 +329,22 @@ class SparkContext(
runJob(rdd, func, 0 until rdd.splits.size, false)
}
+ /**
+ * Run a job that can return approximate results.
+ */
+ def runApproximateJob[T, U, R](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ evaluator: ApproximateEvaluator[U, R],
+ timeout: Long
+ ): PartialResult[R] = {
+ logInfo("Starting job...")
+ val start = System.nanoTime
+ val result = dagScheduler.runApproximateJob(rdd, func, evaluator, timeout)
+ logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
+ result
+ }
+
// Clean a closure to make it ready to serialized and send to tasks
// (removes unreferenced variables in $outer's, updates REPL variables)
private[spark] def clean[F <: AnyRef](f: F): F = {
@@ -314,7 +353,7 @@ class SparkContext(
}
// Default level of parallelism to use when not given by user (e.g. for reduce tasks)
- def defaultParallelism: Int = scheduler.defaultParallelism
+ def defaultParallelism: Int = taskScheduler.defaultParallelism
// Default min number of splits for Hadoop RDDs when not given by user
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
@@ -349,15 +388,23 @@ object SparkContext {
}
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
+
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
-
- implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](rdd: RDD[(K, V)]) =
+
+ implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
+ rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
- implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
+ implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
+ rdd: RDD[(K, V)]) =
new OrderedRDDFunctions(rdd)
+ implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd)
+
+ implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
+ new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
+
// Implicit conversions to common Writable types, for saveAsSequenceFile
implicit def intToIntWritable(i: Int) = new IntWritable(i)
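
A hedged sketch of the new implicit conversions registered above. It assumes DoubleRDDFunctions exposes a sum() method (that file changes in this commit but is not shown here), plus the usual local-master caveats:

import spark.SparkContext
import spark.SparkContext._

object DoubleImplicitsSketch {
  def main(args: Array[String]) {
    System.setProperty("spark.master.host", "localhost")
    System.setProperty("spark.master.port", "7077")
    val sc = new SparkContext("local", "DoubleImplicitsSketch")

    // numericRDDToDoubleRDDFunctions converts any numeric RDD to doubles first
    val ints = sc.parallelize(1 to 10)
    println(ints.sum())                          // 55.0

    sc.stop()
  }
}
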
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index cd752f8b65..897a5ef82d 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -1,14 +1,26 @@
package spark
+import akka.actor.Actor
+
+import spark.storage.BlockManager
+import spark.storage.BlockManagerMaster
+import spark.network.ConnectionManager
+
class SparkEnv (
- val cache: Cache,
- val serializer: Serializer,
- val closureSerializer: Serializer,
- val cacheTracker: CacheTracker,
- val mapOutputTracker: MapOutputTracker,
- val shuffleFetcher: ShuffleFetcher,
- val shuffleManager: ShuffleManager
-)
+ val cache: Cache,
+ val serializer: Serializer,
+ val closureSerializer: Serializer,
+ val cacheTracker: CacheTracker,
+ val mapOutputTracker: MapOutputTracker,
+ val shuffleFetcher: ShuffleFetcher,
+ val shuffleManager: ShuffleManager,
+ val blockManager: BlockManager,
+ val connectionManager: ConnectionManager
+ ) {
+
+ /** No-parameter constructor for unit tests. */
+ def this() = this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null)
+}
object SparkEnv {
private val env = new ThreadLocal[SparkEnv]
@@ -21,36 +33,55 @@ object SparkEnv {
env.get()
}
- def createFromSystemProperties(isMaster: Boolean): SparkEnv = {
- val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
- val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
-
- val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
+ def createFromSystemProperties(isMaster: Boolean, isLocal: Boolean): SparkEnv = {
+ val serializerClass = System.getProperty("spark.serializer", "spark.KryoSerializer")
val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
+
+ BlockManagerMaster.startBlockManagerMaster(isMaster, isLocal)
+
+ var blockManager = new BlockManager(serializer)
+
+ val connectionManager = blockManager.connectionManager
+
+ val shuffleManager = new ShuffleManager()
val closureSerializerClass =
System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
val closureSerializer =
Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer]
+ val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
+ val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
- val cacheTracker = new CacheTracker(isMaster, cache)
+ val cacheTracker = new CacheTracker(isMaster, blockManager)
+ blockManager.cacheTracker = cacheTracker
val mapOutputTracker = new MapOutputTracker(isMaster)
val shuffleFetcherClass =
- System.getProperty("spark.shuffle.fetcher", "spark.SimpleShuffleFetcher")
+ System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
val shuffleFetcher =
Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher]
- val shuffleMgr = new ShuffleManager()
+ /*
+ if (System.getProperty("spark.stream.distributed", "false") == "true") {
+ val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]]
+ if (isLocal || !isMaster) {
+ (new Thread() {
+ override def run() {
+ println("Wait started")
+ Thread.sleep(60000)
+ println("Wait ended")
+ val receiverClass = Class.forName("spark.stream.TestStreamReceiver4")
+ val constructor = receiverClass.getConstructor(blockManagerClass)
+ val receiver = constructor.newInstance(blockManager)
+ receiver.asInstanceOf[Thread].start()
+ }
+ }).start()
+ }
+ }
+ */
- new SparkEnv(
- cache,
- serializer,
- closureSerializer,
- cacheTracker,
- mapOutputTracker,
- shuffleFetcher,
- shuffleMgr)
+ new SparkEnv(cache, serializer, closureSerializer, cacheTracker, mapOutputTracker, shuffleFetcher,
+ shuffleManager, blockManager, connectionManager)
}
}
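
A hedged sketch showing how the environment built by createFromSystemProperties can be inspected from the driver thread, illustrating the new defaults (KryoSerializer, BlockStoreShuffleFetcher):

import spark.{SparkContext, SparkEnv}

object SparkEnvSketch {
  def main(args: Array[String]) {
    System.setProperty("spark.master.host", "localhost")
    System.setProperty("spark.master.port", "7077")
    val sc = new SparkContext("local", "SparkEnvSketch")

    val env = SparkEnv.get
    println(env.serializer.getClass.getName)       // spark.KryoSerializer by default now
    println(env.shuffleFetcher.getClass.getName)   // spark.BlockStoreShuffleFetcher by default

    sc.stop()
  }
}
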
diff --git a/core/src/main/scala/spark/Stage.scala b/core/src/main/scala/spark/Stage.scala
deleted file mode 100644
index 9452ea3a8e..0000000000
--- a/core/src/main/scala/spark/Stage.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-package spark
-
-class Stage(
- val id: Int,
- val rdd: RDD[_],
- val shuffleDep: Option[ShuffleDependency[_,_,_]],
- val parents: List[Stage]) {
-
- val isShuffleMap = shuffleDep != None
- val numPartitions = rdd.splits.size
- val outputLocs = Array.fill[List[String]](numPartitions)(Nil)
- var numAvailableOutputs = 0
-
- def isAvailable: Boolean = {
- if (parents.size == 0 && !isShuffleMap) {
- true
- } else {
- numAvailableOutputs == numPartitions
- }
- }
-
- def addOutputLoc(partition: Int, host: String) {
- val prevList = outputLocs(partition)
- outputLocs(partition) = host :: prevList
- if (prevList == Nil)
- numAvailableOutputs += 1
- }
-
- def removeOutputLoc(partition: Int, host: String) {
- val prevList = outputLocs(partition)
- val newList = prevList.filterNot(_ == host)
- outputLocs(partition) = newList
- if (prevList != Nil && newList == Nil) {
- numAvailableOutputs -= 1
- }
- }
-
- override def toString = "Stage " + id
-
- override def hashCode(): Int = id
-}
diff --git a/core/src/main/scala/spark/Task.scala b/core/src/main/scala/spark/Task.scala
deleted file mode 100644
index bc3b374344..0000000000
--- a/core/src/main/scala/spark/Task.scala
+++ /dev/null
@@ -1,9 +0,0 @@
-package spark
-
-class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable
-
-abstract class Task[T] extends Serializable {
- def run(id: Int): T
- def preferredLocations: Seq[String] = Nil
- def generation: Option[Long] = None
-}
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
new file mode 100644
index 0000000000..7a6214aab6
--- /dev/null
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -0,0 +1,3 @@
+package spark
+
+class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable
diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala
new file mode 100644
index 0000000000..6e4eb25ed4
--- /dev/null
+++ b/core/src/main/scala/spark/TaskEndReason.scala
@@ -0,0 +1,16 @@
+package spark
+
+import spark.storage.BlockManagerId
+
+/**
+ * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry
+ * tasks several times for "ephemeral" failures, and only report back failures that require some
+ * old stages to be resubmitted, such as shuffle map fetch failures.
+ */
+sealed trait TaskEndReason
+
+case object Success extends TaskEndReason
+case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
+case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
+case class ExceptionFailure(exception: Throwable) extends TaskEndReason
+case class OtherFailure(message: String) extends TaskEndReason
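
A minimal sketch of how a scheduler might pattern match over the new TaskEndReason hierarchy; the handler strings are placeholders:

import spark._

object TaskEndReasonSketch {
  def describe(reason: TaskEndReason): String = reason match {
    case Success => "task succeeded"
    case Resubmitted => "finished earlier but lost; resubmitted"
    case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
      "lost map output " + mapId + " of shuffle " + shuffleId + " on " + bmAddress
    case ExceptionFailure(e) => "task threw " + e
    case OtherFailure(message) => message
  }

  def main(args: Array[String]) {
    println(describe(Success))
    println(describe(OtherFailure("slave lost")))
  }
}
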
diff --git a/core/src/main/scala/spark/TaskResult.scala b/core/src/main/scala/spark/TaskResult.scala
deleted file mode 100644
index 2b7fd1a4b2..0000000000
--- a/core/src/main/scala/spark/TaskResult.scala
+++ /dev/null
@@ -1,8 +0,0 @@
-package spark
-
-import scala.collection.mutable.Map
-
-// Task result. Also contains updates to accumulator variables.
-// TODO: Use of distributed cache to return result is a hack to get around
-// what seems to be a bug with messages over 60KB in libprocess; fix it
-private class TaskResult[T](val value: T, val accumUpdates: Map[Long, Any]) extends Serializable
diff --git a/core/src/main/scala/spark/UnionRDD.scala b/core/src/main/scala/spark/UnionRDD.scala
index 4c0f255e6b..17522e2bbb 100644
--- a/core/src/main/scala/spark/UnionRDD.scala
+++ b/core/src/main/scala/spark/UnionRDD.scala
@@ -33,7 +33,8 @@ class UnionRDD[T: ClassManifest](
override def splits = splits_
- @transient override val dependencies = {
+ @transient
+ override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for ((rdd, index) <- rdds.zipWithIndex) {
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index cfd6dc8b2a..742e60b176 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -118,6 +118,23 @@ object Utils {
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
*/
def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress
+
+ private var customHostname: Option[String] = None
+
+ /**
+ * Allow setting a custom hostname because, when running on Mesos, we need to use the same
+ * hostname that Mesos reports to the master.
+ */
+ def setCustomHostname(hostname: String) {
+ customHostname = Some(hostname)
+ }
+
+ /**
+ * Get the local machine's hostname
+ */
+ def localHostName(): String = {
+ customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
+ }
/**
* Returns a standard ThreadFactory except all threads are daemons.
@@ -142,6 +159,14 @@ object Utils {
return threadPool
}
+
+ /**
+ * Return a string describing how much time has passed, in milliseconds, since the given
+ * start time (also in milliseconds).
+ */
+ def getUsedTimeMs(startTimeMs: Long): String = {
+ return " " + (System.currentTimeMillis - startTimeMs) + " ms "
+ }
/**
* Wrapper over newFixedThreadPool.
@@ -155,16 +180,6 @@ object Utils {
}
/**
- * Get the local machine's hostname.
- */
- def localHostName(): String = InetAddress.getLocalHost.getHostName
-
- /**
- * Get current host
- */
- def getHost = System.getProperty("spark.hostname", localHostName())
-
- /**
* Delete a file or directory and its contents recursively.
*/
def deleteRecursively(file: File) {
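
A short sketch of the Utils helpers added above (setCustomHostname, localHostName, getUsedTimeMs):

import spark.Utils

object UtilsSketch {
  def main(args: Array[String]) {
    // On Mesos the executor overrides the hostname to match what it reports to the master
    Utils.setCustomHostname("worker-42.example.com")
    println(Utils.localHostName())               // worker-42.example.com

    val start = System.currentTimeMillis
    Thread.sleep(250)
    println("slept for" + Utils.getUsedTimeMs(start))   // e.g. "slept for 251 ms "
  }
}
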
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
new file mode 100644
index 0000000000..4546dfa0fa
--- /dev/null
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -0,0 +1,364 @@
+package spark.network
+
+import spark._
+
+import scala.collection.mutable.{HashMap, Queue, ArrayBuffer}
+
+import java.io._
+import java.nio._
+import java.nio.channels._
+import java.nio.channels.spi._
+import java.net._
+
+
+abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging {
+
+ channel.configureBlocking(false)
+ channel.socket.setTcpNoDelay(true)
+ channel.socket.setReuseAddress(true)
+ channel.socket.setKeepAlive(true)
+ /*channel.socket.setReceiveBufferSize(32768) */
+
+ var onCloseCallback: Connection => Unit = null
+ var onExceptionCallback: (Connection, Exception) => Unit = null
+ var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
+
+ lazy val remoteAddress = getRemoteAddress()
+ lazy val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
+
+ def key() = channel.keyFor(selector)
+
+ def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
+
+ def read() {
+ throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString)
+ }
+
+ def write() {
+ throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString)
+ }
+
+ def close() {
+ key.cancel()
+ channel.close()
+ callOnCloseCallback()
+ }
+
+ def onClose(callback: Connection => Unit) {onCloseCallback = callback}
+
+ def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback}
+
+ def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback}
+
+ def callOnExceptionCallback(e: Exception) {
+ if (onExceptionCallback != null) {
+ onExceptionCallback(this, e)
+ } else {
+ logError("Error in connection to " + remoteConnectionManagerId +
+ " and OnExceptionCallback not registered", e)
+ }
+ }
+
+ def callOnCloseCallback() {
+ if (onCloseCallback != null) {
+ onCloseCallback(this)
+ } else {
+ logWarning("Connection to " + remoteConnectionManagerId +
+ " closed and OnExceptionCallback not registered")
+ }
+
+ }
+
+ def changeConnectionKeyInterest(ops: Int) {
+ if (onKeyInterestChangeCallback != null) {
+ onKeyInterestChangeCallback(this, ops)
+ } else {
+ throw new Exception("OnKeyInterestChangeCallback not registered")
+ }
+ }
+
+ def printRemainingBuffer(buffer: ByteBuffer) {
+ val bytes = new Array[Byte](buffer.remaining)
+ val curPosition = buffer.position
+ buffer.get(bytes)
+ bytes.foreach(x => print(x + " "))
+ buffer.position(curPosition)
+ print(" (" + bytes.size + ")")
+ }
+
+ def printBuffer(buffer: ByteBuffer, position: Int, length: Int) {
+ val bytes = new Array[Byte](length)
+ val curPosition = buffer.position
+ buffer.position(position)
+ buffer.get(bytes)
+ bytes.foreach(x => print(x + " "))
+ print(" (" + position + ", " + length + ")")
+ buffer.position(curPosition)
+ }
+
+}
+
+
+class SendingConnection(val address: InetSocketAddress, selector_ : Selector)
+extends Connection(SocketChannel.open, selector_) {
+
+ class Outbox(fair: Int = 0) {
+ val messages = new Queue[Message]()
+ val defaultChunkSize = 65536 //32768 //16384
+ var nextMessageToBeUsed = 0
+
+ def addMessage(message: Message): Unit = {
+ messages.synchronized{
+ /*messages += message*/
+ messages.enqueue(message)
+ logInfo("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]")
+ }
+ }
+
+ def getChunk(): Option[MessageChunk] = {
+ fair match {
+ case 0 => getChunkFIFO()
+ case 1 => getChunkRR()
+ case _ => throw new Exception("Unexpected fairness policy in outbox")
+ }
+ }
+
+ private def getChunkFIFO(): Option[MessageChunk] = {
+ /*logInfo("Using FIFO")*/
+ messages.synchronized {
+ while (!messages.isEmpty) {
+ val message = messages(0)
+ val chunk = message.getChunkForSending(defaultChunkSize)
+ if (chunk.isDefined) {
+ messages += message // this is probably incorrect; it won't behave as FIFO
+ if (!message.started) logDebug("Starting to send [" + message + "]")
+ message.started = true
+ return chunk
+ }
+ /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
+ logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
+ }
+ }
+ None
+ }
+
+ private def getChunkRR(): Option[MessageChunk] = {
+ messages.synchronized {
+ while (!messages.isEmpty) {
+ /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
+ /*val message = messages(nextMessageToBeUsed)*/
+ val message = messages.dequeue
+ val chunk = message.getChunkForSending(defaultChunkSize)
+ if (chunk.isDefined) {
+ messages.enqueue(message)
+ nextMessageToBeUsed = nextMessageToBeUsed + 1
+ if (!message.started) {
+ logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]")
+ message.started = true
+ message.startTime = System.currentTimeMillis
+ }
+ logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
+ return chunk
+ }
+ /*messages -= message*/
+ message.finishTime = System.currentTimeMillis
+ logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
+ }
+ }
+ None
+ }
+ }
+
+ val outbox = new Outbox(1)
+ val currentBuffers = new ArrayBuffer[ByteBuffer]()
+
+ /*channel.socket.setSendBufferSize(256 * 1024)*/
+
+ override def getRemoteAddress() = address
+
+ def send(message: Message) {
+ outbox.synchronized {
+ outbox.addMessage(message)
+ if (channel.isConnected) {
+ changeConnectionKeyInterest(SelectionKey.OP_WRITE)
+ }
+ }
+ }
+
+ def connect() {
+ try{
+ channel.connect(address)
+ channel.register(selector, SelectionKey.OP_CONNECT)
+ logInfo("Initiating connection to [" + address + "]")
+ } catch {
+ case e: Exception => {
+ logError("Error connecting to " + address, e)
+ callOnExceptionCallback(e)
+ }
+ }
+ }
+
+ def finishConnect() {
+ try {
+ channel.finishConnect
+ changeConnectionKeyInterest(SelectionKey.OP_WRITE)
+ logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
+ } catch {
+ case e: Exception => {
+ logWarning("Error finishing connection to " + address, e)
+ callOnExceptionCallback(e)
+ }
+ }
+ }
+
+ override def write() {
+ try{
+ while(true) {
+ if (currentBuffers.size == 0) {
+ outbox.synchronized {
+ outbox.getChunk match {
+ case Some(chunk) => {
+ currentBuffers ++= chunk.buffers
+ }
+ case None => {
+ changeConnectionKeyInterest(0)
+ /*key.interestOps(0)*/
+ return
+ }
+ }
+ }
+ }
+
+ if (currentBuffers.size > 0) {
+ val buffer = currentBuffers(0)
+ val remainingBytes = buffer.remaining
+ val writtenBytes = channel.write(buffer)
+ if (buffer.remaining == 0) {
+ currentBuffers -= buffer
+ }
+ if (writtenBytes < remainingBytes) {
+ return
+ }
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logWarning("Error writing in connection to " + remoteConnectionManagerId, e)
+ callOnExceptionCallback(e)
+ close()
+ }
+ }
+ }
+}
+
+
+class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
+extends Connection(channel_, selector_) {
+
+ class Inbox() {
+ val messages = new HashMap[Int, BufferMessage]()
+
+ def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
+
+ def createNewMessage: BufferMessage = {
+ val newMessage = Message.create(header).asInstanceOf[BufferMessage]
+ newMessage.started = true
+ newMessage.startTime = System.currentTimeMillis
+ logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]")
+ messages += ((newMessage.id, newMessage))
+ newMessage
+ }
+
+ val message = messages.getOrElseUpdate(header.id, createNewMessage)
+ logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]")
+ message.getChunkForReceiving(header.chunkSize)
+ }
+
+ def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
+ messages.get(chunk.header.id)
+ }
+
+ def removeMessage(message: Message) {
+ messages -= message.id
+ }
+ }
+
+ val inbox = new Inbox()
+ val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
+ var onReceiveCallback: (Connection , Message) => Unit = null
+ var currentChunk: MessageChunk = null
+
+ channel.register(selector, SelectionKey.OP_READ)
+
+ override def read() {
+ try {
+ while (true) {
+ if (currentChunk == null) {
+ val headerBytesRead = channel.read(headerBuffer)
+ if (headerBytesRead == -1) {
+ close()
+ return
+ }
+ if (headerBuffer.remaining > 0) {
+ return
+ }
+ headerBuffer.flip
+ if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
+ throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
+ }
+ val header = MessageChunkHeader.create(headerBuffer)
+ headerBuffer.clear()
+ header.typ match {
+ case Message.BUFFER_MESSAGE => {
+ if (header.totalSize == 0) {
+ if (onReceiveCallback != null) {
+ onReceiveCallback(this, Message.create(header))
+ }
+ currentChunk = null
+ return
+ } else {
+ currentChunk = inbox.getChunk(header).orNull
+ }
+ }
+ case _ => throw new Exception("Message of unknown type received")
+ }
+ }
+
+ if (currentChunk == null) throw new Exception("No message chunk to receive data")
+
+ val bytesRead = channel.read(currentChunk.buffer)
+ if (bytesRead == 0) {
+ return
+ } else if (bytesRead == -1) {
+ close()
+ return
+ }
+
+ /*logDebug("Read " + bytesRead + " bytes for the buffer")*/
+
+ if (currentChunk.buffer.remaining == 0) {
+ /*println("Filled buffer at " + System.currentTimeMillis)*/
+ val bufferMessage = inbox.getMessageForChunk(currentChunk).get
+ if (bufferMessage.isCompletelyReceived) {
+ bufferMessage.flip
+ bufferMessage.finishTime = System.currentTimeMillis
+ logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken)
+ if (onReceiveCallback != null) {
+ onReceiveCallback(this, bufferMessage)
+ }
+ inbox.removeMessage(bufferMessage)
+ }
+ currentChunk = null
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logWarning("Error reading from connection to " + remoteConnectionManagerId, e)
+ callOnExceptionCallback(e)
+ close()
+ }
+ }
+ }
+
+ def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
+}
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
new file mode 100644
index 0000000000..e9f254d0f3
--- /dev/null
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -0,0 +1,467 @@
+package spark.network
+
+import spark._
+
+import scala.actors.Future
+import scala.actors.Futures.future
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.SynchronizedMap
+import scala.collection.mutable.SynchronizedQueue
+import scala.collection.mutable.Queue
+import scala.collection.mutable.ArrayBuffer
+
+import java.io._
+import java.nio._
+import java.nio.channels._
+import java.nio.channels.spi._
+import java.net._
+import java.util.concurrent.Executors
+
+case class ConnectionManagerId(val host: String, val port: Int) {
+ def toSocketAddress() = new InetSocketAddress(host, port)
+}
+
+object ConnectionManagerId {
+ def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
+ new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
+ }
+}
+
+class ConnectionManager(port: Int) extends Logging {
+
+ case class MessageStatus(message: Message, connectionManagerId: ConnectionManagerId) {
+ var ackMessage: Option[Message] = None
+ var attempted = false
+ var acked = false
+ }
+
+ val selector = SelectorProvider.provider.openSelector()
+ /*val handleMessageExecutor = new ThreadPoolExecutor(4, 4, 600, TimeUnit.SECONDS, new LinkedBlockingQueue()) */
+ val handleMessageExecutor = Executors.newFixedThreadPool(4)
+ val serverChannel = ServerSocketChannel.open()
+ val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
+ val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
+ val messageStatuses = new HashMap[Int, MessageStatus]
+ val connectionRequests = new SynchronizedQueue[SendingConnection]
+ val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
+ val sendMessageRequests = new Queue[(Message, SendingConnection)]
+
+ var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
+
+ serverChannel.configureBlocking(false)
+ serverChannel.socket.setReuseAddress(true)
+ serverChannel.socket.setReceiveBufferSize(256 * 1024)
+
+ serverChannel.socket.bind(new InetSocketAddress(port))
+ serverChannel.register(selector, SelectionKey.OP_ACCEPT)
+
+ val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
+ logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
+
+ val thisInstance = this
+ var selectorThread = new Thread("connection-manager-thread") {
+ override def run() {
+ thisInstance.run()
+ }
+ }
+ selectorThread.setDaemon(true)
+ selectorThread.start()
+
+ def run() {
+ try {
+ var interrupted = false
+ while(!interrupted) {
+ while(!connectionRequests.isEmpty) {
+ val sendingConnection = connectionRequests.dequeue
+ sendingConnection.connect()
+ addConnection(sendingConnection)
+ }
+ sendMessageRequests.synchronized {
+ while(!sendMessageRequests.isEmpty) {
+ val (message, connection) = sendMessageRequests.dequeue
+ connection.send(message)
+ }
+ }
+
+ while(!keyInterestChangeRequests.isEmpty) {
+ val (key, ops) = keyInterestChangeRequests.dequeue
+ val connection = connectionsByKey(key)
+ val lastOps = key.interestOps()
+ key.interestOps(ops)
+
+ def intToOpStr(op: Int): String = {
+ val opStrs = ArrayBuffer[String]()
+ if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
+ if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
+ if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
+ if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
+ if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+ }
+
+ logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId +
+ "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+
+ }
+
+ val selectedKeysCount = selector.select()
+ if (selectedKeysCount == 0) logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
+
+ interrupted = selectorThread.isInterrupted
+
+ val selectedKeys = selector.selectedKeys().iterator()
+ while (selectedKeys.hasNext()) {
+ val key = selectedKeys.next.asInstanceOf[SelectionKey]
+ selectedKeys.remove()
+ if (key.isValid) {
+ if (key.isAcceptable) {
+ acceptConnection(key)
+ } else
+ if (key.isConnectable) {
+ connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect()
+ } else
+ if (key.isReadable) {
+ connectionsByKey(key).read()
+ } else
+ if (key.isWritable) {
+ connectionsByKey(key).write()
+ }
+ }
+ }
+ }
+ } catch {
+ case e: Exception => logError("Error in select loop", e)
+ }
+ }
+
+ def acceptConnection(key: SelectionKey) {
+ val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
+ val newChannel = serverChannel.accept()
+ val newConnection = new ReceivingConnection(newChannel, selector)
+ newConnection.onReceive(receiveMessage)
+ newConnection.onClose(removeConnection)
+ addConnection(newConnection)
+ logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
+ }
+
+ def addConnection(connection: Connection) {
+ connectionsByKey += ((connection.key, connection))
+ if (connection.isInstanceOf[SendingConnection]) {
+ val sendingConnection = connection.asInstanceOf[SendingConnection]
+ connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection))
+ }
+ connection.onKeyInterestChange(changeConnectionKeyInterest)
+ connection.onException(handleConnectionError)
+ connection.onClose(removeConnection)
+ }
+
+ def removeConnection(connection: Connection) {
+ /*logInfo("Removing connection")*/
+ connectionsByKey -= connection.key
+ if (connection.isInstanceOf[SendingConnection]) {
+ val sendingConnection = connection.asInstanceOf[SendingConnection]
+ val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId
+ logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
+
+ connectionsById -= sendingConnectionManagerId
+
+ messageStatuses.synchronized {
+ messageStatuses
+ .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
+ logInfo("Notifying " + status)
+ status.synchronized {
+ status.attempted = true
+ status.acked = false
+ status.notifyAll
+ }
+ })
+
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
+ })
+ }
+ } else if (connection.isInstanceOf[ReceivingConnection]) {
+ val receivingConnection = connection.asInstanceOf[ReceivingConnection]
+ val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId
+ logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
+
+ val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull
+ if (sendingConnectionManagerId == null) {
+ logError("Corresponding SendingConnectionManagerId not found")
+ return
+ }
+ logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId)
+
+ val sendingConnection = connectionsById(sendingConnectionManagerId)
+ sendingConnection.close()
+ connectionsById -= sendingConnectionManagerId
+
+ messageStatuses.synchronized {
+ messageStatuses
+ .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
+ logInfo("Notifying " + status)
+ status.synchronized {
+ status.attempted = true
+ status.acked = false
+ status.notifyAll
+ }
+ })
+
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
+ })
+ }
+ }
+ }
+
+ def handleConnectionError(connection: Connection, e: Exception) {
+ logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
+ removeConnection(connection)
+ }
+
+ def changeConnectionKeyInterest(connection: Connection, ops: Int) {
+ keyInterestChangeRequests += ((connection.key, ops))
+ }
+
+ def receiveMessage(connection: Connection, message: Message) {
+ val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
+ logInfo("Received [" + message + "] from [" + connectionManagerId + "]")
+ val runnable = new Runnable() {
+ val creationTime = System.currentTimeMillis
+ def run() {
+ logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ handleMessage(connectionManagerId, message)
+ logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ }
+ }
+ handleMessageExecutor.execute(runnable)
+ /*handleMessage(connection, message)*/
+ }
+
+ private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+ logInfo("Handling [" + message + "] from [" + connectionManagerId + "]")
+ message match {
+ case bufferMessage: BufferMessage => {
+ if (bufferMessage.hasAckId) {
+ val sentMessageStatus = messageStatuses.synchronized {
+ messageStatuses.get(bufferMessage.ackId) match {
+ case Some(status) => {
+ messageStatuses -= bufferMessage.ackId
+ status
+ }
+              case None => {
+                throw new Exception("Could not find reference for received ack message " + message.id)
+              }
+ }
+ }
+ sentMessageStatus.synchronized {
+ sentMessageStatus.ackMessage = Some(message)
+ sentMessageStatus.attempted = true
+ sentMessageStatus.acked = true
+ sentMessageStatus.notifyAll
+ }
+ } else {
+ val ackMessage = if (onReceiveCallback != null) {
+ logDebug("Calling back")
+ onReceiveCallback(bufferMessage, connectionManagerId)
+ } else {
+ logWarning("Not calling back as callback is null")
+ None
+ }
+
+ if (ackMessage.isDefined) {
+ if (!ackMessage.get.isInstanceOf[BufferMessage]) {
+ logWarning("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
+ } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
+ logWarning("Response to " + bufferMessage + " does not have ack id set")
+ ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
+ }
+ }
+
+ sendMessage(connectionManagerId, ackMessage.getOrElse {
+ Message.createBufferMessage(bufferMessage.id)
+ })
+ }
+ }
+      case _ => throw new Exception("Message of unknown type received")
+ }
+ }
+
+ private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+ def startNewConnection(): SendingConnection = {
+ val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
+ val newConnection = new SendingConnection(inetSocketAddress, selector)
+ connectionRequests += newConnection
+ newConnection
+ }
+ val connection = connectionsById.getOrElse(connectionManagerId, startNewConnection)
+ message.senderAddress = id.toSocketAddress()
+ logInfo("Sending [" + message + "] to [" + connectionManagerId + "]")
+ /*connection.send(message)*/
+ sendMessageRequests.synchronized {
+ sendMessageRequests += ((message, connection))
+ }
+ selector.wakeup()
+ }
+
+ def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message): Future[Option[Message]] = {
+ val messageStatus = new MessageStatus(message, connectionManagerId)
+ messageStatuses.synchronized {
+ messageStatuses += ((message.id, messageStatus))
+ }
+ sendMessage(connectionManagerId, message)
+ future {
+ messageStatus.synchronized {
+ if (!messageStatus.attempted) {
+ logTrace("Waiting, " + messageStatuses.size + " statuses" )
+ messageStatus.wait()
+ logTrace("Done waiting")
+ }
+ }
+ messageStatus.ackMessage
+ }
+ }
+
+ def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = {
+ sendMessageReliably(connectionManagerId, message)()
+ }
+
+ def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
+ onReceiveCallback = callback
+ }
+
+ def stop() {
+ selectorThread.interrupt()
+ selectorThread.join()
+ selector.close()
+ val connections = connectionsByKey.values
+ connections.foreach(_.close())
+ if (connectionsByKey.size != 0) {
+      logWarning("Not all connections were cleaned up")
+ }
+ handleMessageExecutor.shutdown()
+ logInfo("ConnectionManager stopped")
+ }
+}
+
+
+object ConnectionManager {
+
+ def main(args: Array[String]) {
+
+ val manager = new ConnectionManager(9999)
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ println("Received [" + msg + "] from [" + id + "]")
+ None
+ })
+
+ /*testSequentialSending(manager)*/
+ /*System.gc()*/
+
+ /*testParallelSending(manager)*/
+ /*System.gc()*/
+
+ /*testParallelDecreasingSending(manager)*/
+ /*System.gc()*/
+
+ testContinuousSending(manager)
+ System.gc()
+ }
+
+ def testSequentialSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Sequential Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliablySync(manager.id, bufferMessage)
+ })
+ println("--------------------------")
+ println()
+ }
+
+ def testParallelSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Parallel Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val startTime = System.currentTimeMillis
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {if (!f().isDefined) println("Failed")})
+ val finishTime = System.currentTimeMillis
+
+ val mb = size * count / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("--------------------------")
+ println("Started at " + startTime + ", finished at " + finishTime)
+ println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)")
+ println("--------------------------")
+ println()
+ }
+
+ def testParallelDecreasingSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Parallel Decreasing Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+ val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte)))
+ buffers.foreach(_.flip)
+ val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0
+
+ val startTime = System.currentTimeMillis
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {if (!f().isDefined) println("Failed")})
+ val finishTime = System.currentTimeMillis
+
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("--------------------------")
+ /*println("Started at " + startTime + ", finished at " + finishTime) */
+ println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
+ println("--------------------------")
+ println()
+ }
+
+ def testContinuousSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Continuous Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val startTime = System.currentTimeMillis
+ while(true) {
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {if (!f().isDefined) println("Failed")})
+ val finishTime = System.currentTimeMillis
+ Thread.sleep(1000)
+ val mb = size * count / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("--------------------------")
+ println()
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
new file mode 100644
index 0000000000..5d21bb793f
--- /dev/null
+++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
@@ -0,0 +1,74 @@
+package spark.network
+
+import spark._
+import spark.SparkContext._
+
+import scala.io.Source
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+
+object ConnectionManagerTest extends Logging {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>")
+ System.exit(1)
+ }
+
+ if (args(0).startsWith("local")) {
+ println("This runs only on a mesos cluster")
+ }
+
+ val sc = new SparkContext(args(0), "ConnectionManagerTest")
+ val slavesFile = Source.fromFile(args(1))
+ val slaves = slavesFile.mkString.split("\n")
+ slavesFile.close()
+
+ /*println("Slaves")*/
+ /*slaves.foreach(println)*/
+
+ val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map(
+ i => SparkEnv.get.connectionManager.id).collect()
+ println("\nSlave ConnectionManagerIds")
+ slaveConnManagerIds.foreach(println)
+ println
+
+ val count = 10
+ (0 until count).foreach(i => {
+ val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => {
+ val connManager = SparkEnv.get.connectionManager
+ val thisConnManagerId = connManager.id
+ connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ logInfo("Received [" + msg + "] from [" + id + "]")
+ None
+ })
+
+ val size = 100 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val startTime = System.currentTimeMillis
+ val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
+ connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
+ })
+ val results = futures.map(f => f())
+ val finishTime = System.currentTimeMillis
+ Thread.sleep(5000)
+
+ val mb = size * results.size / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
+ logInfo(resultStr)
+ resultStr
+ }).collect()
+
+ println("---------------------")
+ println("Run " + i)
+ resultStrs.foreach(println)
+ println("---------------------")
+ })
+ }
+}
+
diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala
new file mode 100644
index 0000000000..2e85803679
--- /dev/null
+++ b/core/src/main/scala/spark/network/Message.scala
@@ -0,0 +1,219 @@
+package spark.network
+
+import spark._
+
+import scala.collection.mutable.ArrayBuffer
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+import java.net.InetSocketAddress
+
+class MessageChunkHeader(
+ val typ: Long,
+ val id: Int,
+ val totalSize: Int,
+ val chunkSize: Int,
+ val other: Int,
+ val address: InetSocketAddress) {
+ lazy val buffer = {
+ val ip = address.getAddress.getAddress()
+ val port = address.getPort()
+ ByteBuffer.
+ allocate(MessageChunkHeader.HEADER_SIZE).
+ putLong(typ).
+ putInt(id).
+ putInt(totalSize).
+ putInt(chunkSize).
+ putInt(other).
+ putInt(ip.size).
+ put(ip).
+ putInt(port).
+ position(MessageChunkHeader.HEADER_SIZE).
+ flip.asInstanceOf[ByteBuffer]
+ }
+
+ override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
+ " and sizes " + totalSize + " / " + chunkSize + " bytes"
+}
+
+class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
+ val size = if (buffer == null) 0 else buffer.remaining
+ lazy val buffers = {
+ val ab = new ArrayBuffer[ByteBuffer]()
+ ab += header.buffer
+ if (buffer != null) {
+ ab += buffer
+ }
+ ab
+ }
+
+ override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
+}
+
+abstract class Message(val typ: Long, val id: Int) {
+ var senderAddress: InetSocketAddress = null
+ var started = false
+ var startTime = -1L
+ var finishTime = -1L
+
+ def size: Int
+
+ def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
+
+ def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
+
+ def timeTaken(): String = (finishTime - startTime).toString + " ms"
+
+ override def toString = "" + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
+}
+
+class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
+extends Message(Message.BUFFER_MESSAGE, id_) {
+
+ val initialSize = currentSize()
+ var gotChunkForSendingOnce = false
+
+ def size = initialSize
+
+ def currentSize() = {
+ if (buffers == null || buffers.isEmpty) {
+ 0
+ } else {
+ buffers.map(_.remaining).reduceLeft(_ + _)
+ }
+ }
+
+ def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
+ if (maxChunkSize <= 0) {
+ throw new Exception("Max chunk size is " + maxChunkSize)
+ }
+
+ if (size == 0 && gotChunkForSendingOnce == false) {
+ val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+
+ while(!buffers.isEmpty) {
+ val buffer = buffers(0)
+ if (buffer.remaining == 0) {
+ buffers -= buffer
+ } else {
+ val newBuffer = if (buffer.remaining <= maxChunkSize) {
+ buffer.duplicate
+ } else {
+ buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
+ }
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+ }
+ None
+ }
+
+ def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
+ // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
+ if (buffers.size > 1) {
+ throw new Exception("Attempting to get chunk from message with multiple data buffers")
+ }
+ val buffer = buffers(0)
+ if (buffer.remaining > 0) {
+ if (buffer.remaining < chunkSize) {
+ throw new Exception("Not enough space in data buffer for receiving chunk")
+ }
+ val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ return Some(newChunk)
+ }
+ None
+ }
+
+ def flip() {
+ buffers.foreach(_.flip)
+ }
+
+ def hasAckId() = (ackId != 0)
+
+ def isCompletelyReceived() = !buffers(0).hasRemaining
+
+ override def toString = {
+ if (hasAckId) {
+ "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
+ } else {
+ "BufferMessage(id = " + id + ", size = " + size + ")"
+ }
+
+ }
+}
+
+object MessageChunkHeader {
+ val HEADER_SIZE = 40
+
+ def create(buffer: ByteBuffer): MessageChunkHeader = {
+ if (buffer.remaining != HEADER_SIZE) {
+ throw new IllegalArgumentException("Cannot convert buffer data to Message")
+ }
+ val typ = buffer.getLong()
+ val id = buffer.getInt()
+ val totalSize = buffer.getInt()
+ val chunkSize = buffer.getInt()
+ val other = buffer.getInt()
+ val ipSize = buffer.getInt()
+ val ipBytes = new Array[Byte](ipSize)
+ buffer.get(ipBytes)
+ val ip = InetAddress.getByAddress(ipBytes)
+ val port = buffer.getInt()
+ new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
+ }
+}
+
+object Message {
+ val BUFFER_MESSAGE = 1111111111L
+
+ var lastId = 1
+
+ def getNewId() = synchronized {
+ lastId += 1
+ if (lastId == 0) lastId += 1
+ lastId
+ }
+
+ def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
+ if (dataBuffers == null) {
+ return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
+ }
+ if (dataBuffers.exists(_ == null)) {
+ throw new Exception("Attempting to create buffer message with null buffer")
+ }
+ return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
+ }
+
+ def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
+ createBufferMessage(dataBuffers, 0)
+
+ def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
+ if (dataBuffer == null) {
+ return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
+ } else {
+ return createBufferMessage(Array(dataBuffer), ackId)
+ }
+ }
+
+ def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
+ createBufferMessage(dataBuffer, 0)
+
+ def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId)
+
+ def create(header: MessageChunkHeader): Message = {
+ val newMessage: Message = header.typ match {
+ case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
+ }
+ newMessage.senderAddress = header.address
+ newMessage
+ }
+}
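
As an illustration of the chunking scheme above (a sketch, not part of this commit): a BufferMessage wrapping a 100 KB payload is drained into 32 KB chunks by repeated calls to getChunkForSending, each chunk carrying a MessageChunkHeader plus at most maxChunkSize bytes of data. The port and sizes below are arbitrary.

import java.net.InetSocketAddress
import java.nio.ByteBuffer
import spark.network._

val payload = ByteBuffer.allocate(100 * 1024)                  // 100 KB, position 0, fully "remaining"
val msg = Message.createBufferMessage(payload)
msg.senderAddress = new InetSocketAddress("localhost", 9999)   // chunk headers embed the sender address
var chunk = msg.getChunkForSending(32 * 1024)
while (chunk.isDefined) {
  println(chunk.get)   // three 32768-byte chunks followed by one 4096-byte chunk
  chunk = msg.getChunkForSending(32 * 1024)
}
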
diff --git a/core/src/main/scala/spark/network/ReceiverTest.scala b/core/src/main/scala/spark/network/ReceiverTest.scala
new file mode 100644
index 0000000000..e1ba7c06c0
--- /dev/null
+++ b/core/src/main/scala/spark/network/ReceiverTest.scala
@@ -0,0 +1,20 @@
+package spark.network
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+
+object ReceiverTest {
+
+ def main(args: Array[String]) {
+ val manager = new ConnectionManager(9999)
+ println("Started connection manager with id = " + manager.id)
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/
+ val buffer = ByteBuffer.wrap("response".getBytes())
+ Some(Message.createBufferMessage(buffer, msg.id))
+ })
+ Thread.currentThread.join()
+ }
+}
+
diff --git a/core/src/main/scala/spark/network/SenderTest.scala b/core/src/main/scala/spark/network/SenderTest.scala
new file mode 100644
index 0000000000..4ab6dd3414
--- /dev/null
+++ b/core/src/main/scala/spark/network/SenderTest.scala
@@ -0,0 +1,53 @@
+package spark.network
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+
+object SenderTest {
+
+ def main(args: Array[String]) {
+
+ if (args.length < 2) {
+ println("Usage: SenderTest <target host> <target port>")
+ System.exit(1)
+ }
+
+ val targetHost = args(0)
+ val targetPort = args(1).toInt
+ val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort)
+
+ val manager = new ConnectionManager(0)
+ println("Started connection manager with id = " + manager.id)
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ println("Received [" + msg + "] from [" + id + "]")
+ None
+ })
+
+ val size = 100 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val targetServer = args(0)
+
+ val count = 100
+ (0 until count).foreach(i => {
+ val dataMessage = Message.createBufferMessage(buffer.duplicate)
+ val startTime = System.currentTimeMillis
+ /*println("Started timer at " + startTime)*/
+ val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match {
+ case Some(response) =>
+ val buffer = response.asInstanceOf[BufferMessage].buffers(0)
+ new String(buffer.array)
+ case None => "none"
+ }
+ val finishTime = System.currentTimeMillis
+ val mb = size / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ /*val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"*/
+ val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" + (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr
+ println(resultStr)
+ })
+ }
+}
+
diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
new file mode 100644
index 0000000000..260547902b
--- /dev/null
+++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
@@ -0,0 +1,66 @@
+package spark.partial
+
+import spark._
+import spark.scheduler.JobListener
+
+/**
+ * A JobListener for an approximate single-result action, such as count() or non-parallel reduce().
+ * This listener waits up to timeout milliseconds and will return a partial answer even if the
+ * complete answer is not available by then.
+ *
+ * This class assumes that the action is performed on an entire RDD[T] via a function that computes
+ * a result of type U for each partition, and that the action returns a partial or complete result
+ * of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedDouble).
+ */
+class ApproximateActionListener[T, U, R](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ evaluator: ApproximateEvaluator[U, R],
+ timeout: Long)
+ extends JobListener {
+
+ val startTime = System.currentTimeMillis()
+ val totalTasks = rdd.splits.size
+ var finishedTasks = 0
+ var failure: Option[Exception] = None // Set if the job has failed (permanently)
+ var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult
+
+ override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
+ evaluator.merge(index, result.asInstanceOf[U])
+ finishedTasks += 1
+ if (finishedTasks == totalTasks) {
+ // If we had already returned a PartialResult, set its final value
+ resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
+ // Notify any waiting thread that may have called getResult
+ this.notifyAll()
+ }
+ }
+
+ override def jobFailed(exception: Exception): Unit = synchronized {
+ failure = Some(exception)
+ this.notifyAll()
+ }
+
+ /**
+ * Waits for up to timeout milliseconds since the listener was created and then returns a
+ * PartialResult with the result so far. This may be complete if the whole job is done.
+ */
+ def getResult(): PartialResult[R] = synchronized {
+ val finishTime = startTime + timeout
+ while (true) {
+ val time = System.currentTimeMillis()
+ if (failure != None) {
+ throw failure.get
+ } else if (finishedTasks == totalTasks) {
+ return new PartialResult(evaluator.currentResult(), true)
+ } else if (time >= finishTime) {
+ resultObject = Some(new PartialResult(evaluator.currentResult(), false))
+ return resultObject.get
+ } else {
+ this.wait(finishTime - time)
+ }
+ }
+ // Should never be reached, but required to keep the compiler happy
+ return null
+ }
+}
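
The listener is driven by DAGScheduler.runApproximateJob, added later in this commit. A rough usage sketch (not part of the commit) for an approximate count with a one-second timeout; `sc` and `dagScheduler` are assumed to be an already-constructed SparkContext and its DAGScheduler, and CountEvaluator, PartialResult and BoundedDouble are the spark.partial classes introduced below.

val rdd = sc.parallelize(1 to 10000000, 200)
val evaluator = new CountEvaluator(rdd.splits.size, 0.95)
val countFunc = (ctx: TaskContext, iter: Iterator[Int]) => iter.size.toLong
// Returns after at most ~1000 ms with whatever bound is available; the PartialResult is
// updated to its final value if the remaining tasks finish later.
val result: PartialResult[BoundedDouble] =
  dagScheduler.runApproximateJob(rdd, countFunc, evaluator, 1000L)
println(result.initialValue)
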
diff --git a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala
new file mode 100644
index 0000000000..4772e43ef0
--- /dev/null
+++ b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala
@@ -0,0 +1,10 @@
+package spark.partial
+
+/**
+ * An object that computes a function incrementally by merging in results of type U from multiple
+ * tasks. Allows partial evaluation at any point by calling currentResult().
+ */
+trait ApproximateEvaluator[U, R] {
+ def merge(outputId: Int, taskResult: U): Unit
+ def currentResult(): R
+}
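
For example, a trivial custom evaluator (a sketch, not in this commit) that merges per-partition maxima and reports the largest value seen so far, with no error bounds:

class MaxEvaluator extends ApproximateEvaluator[Double, Double] {
  private var outputsMerged = 0
  private var max = Double.NegativeInfinity

  // Fold in the maximum computed by one finished task.
  override def merge(outputId: Int, taskResult: Double) {
    outputsMerged += 1
    max = math.max(max, taskResult)
  }

  // Best value seen so far; callers may poll this at any time.
  override def currentResult(): Double = max
}
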
diff --git a/core/src/main/scala/spark/partial/BoundedDouble.scala b/core/src/main/scala/spark/partial/BoundedDouble.scala
new file mode 100644
index 0000000000..463c33d6e2
--- /dev/null
+++ b/core/src/main/scala/spark/partial/BoundedDouble.scala
@@ -0,0 +1,8 @@
+package spark.partial
+
+/**
+ * A Double with error bars on it.
+ */
+class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
+ override def toString(): String = "[%.3f, %.3f]".format(low, high)
+}
diff --git a/core/src/main/scala/spark/partial/CountEvaluator.scala b/core/src/main/scala/spark/partial/CountEvaluator.scala
new file mode 100644
index 0000000000..1bc90d6b39
--- /dev/null
+++ b/core/src/main/scala/spark/partial/CountEvaluator.scala
@@ -0,0 +1,38 @@
+package spark.partial
+
+import cern.jet.stat.Probability
+
+/**
+ * An ApproximateEvaluator for counts.
+ *
+ * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might
+ * be best to make this a special case of GroupedCountEvaluator with one group.
+ */
+class CountEvaluator(totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[Long, BoundedDouble] {
+
+ var outputsMerged = 0
+ var sum: Long = 0
+
+ override def merge(outputId: Int, taskResult: Long) {
+ outputsMerged += 1
+ sum += taskResult
+ }
+
+ override def currentResult(): BoundedDouble = {
+ if (outputsMerged == totalOutputs) {
+ new BoundedDouble(sum, 1.0, sum, sum)
+ } else if (outputsMerged == 0) {
+ new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val mean = (sum + 1 - p) / p
+ val variance = (sum + 1) * (1 - p) / (p * p)
+ val stdev = math.sqrt(variance)
+ val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val low = mean - confFactor * stdev
+ val high = mean + confFactor * stdev
+ new BoundedDouble(mean, confidence, low, high)
+ }
+ }
+}
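
A quick worked example (not part of the commit): with 40 of 100 outputs merged (p = 0.4) and 1,000 items counted so far, the count is extrapolated roughly as follows.

val evaluator = new CountEvaluator(100, 0.95)
(0 until 40).foreach(i => evaluator.merge(i, 25L))   // 40 partitions, 25 items each = 1000 counted
val bound = evaluator.currentResult()
// mean     = (1000 + 1 - 0.4) / 0.4          ≈ 2501.5
// variance = (1000 + 1) * (1 - 0.4) / 0.4^2  ≈ 3753.8, stdev ≈ 61.3
// confFactor ≈ 1.96 at 95% confidence, so the interval is roughly [2381, 2622]
println(bound)   // prints the [low, high] interval via BoundedDouble.toString
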
diff --git a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala
new file mode 100644
index 0000000000..3e631c0efc
--- /dev/null
+++ b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala
@@ -0,0 +1,62 @@
+package spark.partial
+
+import java.util.{HashMap => JHashMap}
+import java.util.{Map => JMap}
+
+import scala.collection.Map
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import cern.jet.stat.Probability
+
+import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+
+/**
+ * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval.
+ */
+class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] {
+
+ var outputsMerged = 0
+ var sums = new OLMap[T] // Sum of counts for each key
+
+ override def merge(outputId: Int, taskResult: OLMap[T]) {
+ outputsMerged += 1
+ val iter = taskResult.object2LongEntrySet.fastIterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ sums.put(entry.getKey, sums.getLong(entry.getKey) + entry.getLongValue)
+ }
+ }
+
+ override def currentResult(): Map[T, BoundedDouble] = {
+ if (outputsMerged == totalOutputs) {
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.object2LongEntrySet.fastIterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val sum = entry.getLongValue()
+ result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum)
+ }
+ result
+ } else if (outputsMerged == 0) {
+ new HashMap[T, BoundedDouble]
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.object2LongEntrySet.fastIterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val sum = entry.getLongValue
+ val mean = (sum + 1 - p) / p
+ val variance = (sum + 1) * (1 - p) / (p * p)
+ val stdev = math.sqrt(variance)
+ val low = mean - confFactor * stdev
+ val high = mean + confFactor * stdev
+ result(entry.getKey) = new BoundedDouble(mean, confidence, low, high)
+ }
+ result
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala
new file mode 100644
index 0000000000..2a9ccba205
--- /dev/null
+++ b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala
@@ -0,0 +1,65 @@
+package spark.partial
+
+import java.util.{HashMap => JHashMap}
+import java.util.{Map => JMap}
+
+import scala.collection.mutable.HashMap
+import scala.collection.Map
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval.
+ */
+class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
+
+ var outputsMerged = 0
+  var sums = new JHashMap[T, StatCounter] // Running StatCounter for each key
+
+ override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) {
+ outputsMerged += 1
+ val iter = taskResult.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val old = sums.get(entry.getKey)
+ if (old != null) {
+ old.merge(entry.getValue)
+ } else {
+ sums.put(entry.getKey, entry.getValue)
+ }
+ }
+ }
+
+ override def currentResult(): Map[T, BoundedDouble] = {
+ if (outputsMerged == totalOutputs) {
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val mean = entry.getValue.mean
+ result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean)
+ }
+ result
+ } else if (outputsMerged == 0) {
+ new HashMap[T, BoundedDouble]
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val studentTCacher = new StudentTCacher(confidence)
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val counter = entry.getValue
+ val mean = counter.mean
+ val stdev = math.sqrt(counter.sampleVariance / counter.count)
+ val confFactor = studentTCacher.get(counter.count)
+ val low = mean - confFactor * stdev
+ val high = mean + confFactor * stdev
+ result(entry.getKey) = new BoundedDouble(mean, confidence, low, high)
+ }
+ result
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala
new file mode 100644
index 0000000000..6a2ec7a7bd
--- /dev/null
+++ b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala
@@ -0,0 +1,72 @@
+package spark.partial
+
+import java.util.{HashMap => JHashMap}
+import java.util.{Map => JMap}
+
+import scala.collection.mutable.HashMap
+import scala.collection.Map
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval.
+ */
+class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
+
+ var outputsMerged = 0
+  var sums = new JHashMap[T, StatCounter] // Running StatCounter for each key
+
+ override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) {
+ outputsMerged += 1
+ val iter = taskResult.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val old = sums.get(entry.getKey)
+ if (old != null) {
+ old.merge(entry.getValue)
+ } else {
+ sums.put(entry.getKey, entry.getValue)
+ }
+ }
+ }
+
+ override def currentResult(): Map[T, BoundedDouble] = {
+ if (outputsMerged == totalOutputs) {
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val sum = entry.getValue.sum
+ result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum)
+ }
+ result
+ } else if (outputsMerged == 0) {
+ new HashMap[T, BoundedDouble]
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val studentTCacher = new StudentTCacher(confidence)
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val counter = entry.getValue
+ val meanEstimate = counter.mean
+ val meanVar = counter.sampleVariance / counter.count
+ val countEstimate = (counter.count + 1 - p) / p
+ val countVar = (counter.count + 1) * (1 - p) / (p * p)
+ val sumEstimate = meanEstimate * countEstimate
+ val sumVar = (meanEstimate * meanEstimate * countVar) +
+ (countEstimate * countEstimate * meanVar) +
+ (meanVar * countVar)
+ val sumStdev = math.sqrt(sumVar)
+ val confFactor = studentTCacher.get(counter.count)
+ val low = sumEstimate - confFactor * sumStdev
+ val high = sumEstimate + confFactor * sumStdev
+ result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high)
+ }
+ result
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/partial/MeanEvaluator.scala b/core/src/main/scala/spark/partial/MeanEvaluator.scala
new file mode 100644
index 0000000000..b8c7cb8863
--- /dev/null
+++ b/core/src/main/scala/spark/partial/MeanEvaluator.scala
@@ -0,0 +1,41 @@
+package spark.partial
+
+import cern.jet.stat.Probability
+
+import spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for means.
+ */
+class MeanEvaluator(totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[StatCounter, BoundedDouble] {
+
+ var outputsMerged = 0
+ var counter = new StatCounter
+
+ override def merge(outputId: Int, taskResult: StatCounter) {
+ outputsMerged += 1
+ counter.merge(taskResult)
+ }
+
+ override def currentResult(): BoundedDouble = {
+ if (outputsMerged == totalOutputs) {
+ new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean)
+ } else if (outputsMerged == 0) {
+ new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else {
+ val mean = counter.mean
+ val stdev = math.sqrt(counter.sampleVariance / counter.count)
+ val confFactor = {
+ if (counter.count > 100) {
+ Probability.normalInverse(1 - (1 - confidence) / 2)
+ } else {
+ Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
+ }
+ }
+ val low = mean - confFactor * stdev
+ val high = mean + confFactor * stdev
+ new BoundedDouble(mean, confidence, low, high)
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/partial/PartialResult.scala b/core/src/main/scala/spark/partial/PartialResult.scala
new file mode 100644
index 0000000000..7095bc8ca1
--- /dev/null
+++ b/core/src/main/scala/spark/partial/PartialResult.scala
@@ -0,0 +1,86 @@
+package spark.partial
+
+class PartialResult[R](initialVal: R, isFinal: Boolean) {
+ private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None
+ private var failure: Option[Exception] = None
+ private var completionHandler: Option[R => Unit] = None
+ private var failureHandler: Option[Exception => Unit] = None
+
+ def initialValue: R = initialVal
+
+ def isInitialValueFinal: Boolean = isFinal
+
+ /**
+ * Blocking method to wait for and return the final value.
+ */
+ def getFinalValue(): R = synchronized {
+ while (finalValue == None && failure == None) {
+ this.wait()
+ }
+ if (finalValue != None) {
+ return finalValue.get
+ } else {
+ throw failure.get
+ }
+ }
+
+ /**
+ * Set a handler to be called when this PartialResult completes. Only one completion handler
+ * is supported per PartialResult.
+ */
+ def onComplete(handler: R => Unit): PartialResult[R] = synchronized {
+ if (completionHandler != None) {
+ throw new UnsupportedOperationException("onComplete cannot be called twice")
+ }
+ completionHandler = Some(handler)
+ if (finalValue != None) {
+ // We already have a final value, so let's call the handler
+ handler(finalValue.get)
+ }
+ return this
+ }
+
+ /**
+ * Set a handler to be called if this PartialResult's job fails. Only one failure handler
+ * is supported per PartialResult.
+ */
+ def onFail(handler: Exception => Unit): Unit = synchronized {
+ if (failureHandler != None) {
+ throw new UnsupportedOperationException("onFail cannot be called twice")
+ }
+ failureHandler = Some(handler)
+ if (failure != None) {
+ // We already have a failure, so let's call the handler
+ handler(failure.get)
+ }
+ }
+
+ private[spark] def setFinalValue(value: R): Unit = synchronized {
+ if (finalValue != None) {
+ throw new UnsupportedOperationException("setFinalValue called twice on a PartialResult")
+ }
+ finalValue = Some(value)
+ // Call the completion handler if it was set
+ completionHandler.foreach(h => h(value))
+ // Notify any threads that may be calling getFinalValue()
+ this.notifyAll()
+ }
+
+ private[spark] def setFailure(exception: Exception): Unit = synchronized {
+ if (failure != None) {
+ throw new UnsupportedOperationException("setFailure called twice on a PartialResult")
+ }
+ failure = Some(exception)
+ // Call the failure handler if it was set
+ failureHandler.foreach(h => h(exception))
+ // Notify any threads that may be calling getFinalValue()
+ this.notifyAll()
+ }
+
+ override def toString: String = synchronized {
+ finalValue match {
+ case Some(value) => "(final: " + value + ")"
+ case None => "(partial: " + initialValue + ")"
+ }
+ }
+}
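
A small usage sketch (not in this commit); note that setFinalValue and setFailure are private[spark], so only the scheduler side completes a PartialResult.

// A result that is already final: the handler fires immediately and getFinalValue() does not block.
val done = new PartialResult[Double](0.75, true)
done.onComplete(v => println("complete: " + v))
println(done.getFinalValue())          // 0.75

// A result that is still partial: toString shows the interim value; getFinalValue() would block
// until the scheduler calls setFinalValue() or setFailure().
val pending = new PartialResult[Double](0.5, false)
println(pending)                       // (partial: 0.5)
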
diff --git a/core/src/main/scala/spark/partial/StudentTCacher.scala b/core/src/main/scala/spark/partial/StudentTCacher.scala
new file mode 100644
index 0000000000..6263ee3518
--- /dev/null
+++ b/core/src/main/scala/spark/partial/StudentTCacher.scala
@@ -0,0 +1,26 @@
+package spark.partial
+
+import cern.jet.stat.Probability
+
+/**
+ * A utility class for caching Student's T distribution values for a given confidence level
+ * and various sample sizes. This is used by the grouped evaluators (GroupedMeanEvaluator and
+ * GroupedSumEvaluator) to efficiently calculate confidence intervals for many keys.
+ */
+class StudentTCacher(confidence: Double) {
+ val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation
+ val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)
+
+ def get(sampleSize: Long): Double = {
+ if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) {
+ normalApprox
+ } else {
+ val size = sampleSize.toInt
+ if (cache(size) < 0) {
+ cache(size) = Probability.studentTInverse(1 - confidence, size - 1)
+ }
+ cache(size)
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/partial/SumEvaluator.scala b/core/src/main/scala/spark/partial/SumEvaluator.scala
new file mode 100644
index 0000000000..0357a6bff8
--- /dev/null
+++ b/core/src/main/scala/spark/partial/SumEvaluator.scala
@@ -0,0 +1,51 @@
+package spark.partial
+
+import cern.jet.stat.Probability
+
+import spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for sums. It estimates the mean and the count and multiplies them
+ * together, then uses the formula for the variance of a product of two independent random
+ * variables to get a variance for the result and compute a confidence interval.
+ */
+class SumEvaluator(totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[StatCounter, BoundedDouble] {
+
+ var outputsMerged = 0
+ var counter = new StatCounter
+
+ override def merge(outputId: Int, taskResult: StatCounter) {
+ outputsMerged += 1
+ counter.merge(taskResult)
+ }
+
+ override def currentResult(): BoundedDouble = {
+ if (outputsMerged == totalOutputs) {
+ new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum)
+ } else if (outputsMerged == 0) {
+ new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val meanEstimate = counter.mean
+ val meanVar = counter.sampleVariance / counter.count
+ val countEstimate = (counter.count + 1 - p) / p
+ val countVar = (counter.count + 1) * (1 - p) / (p * p)
+ val sumEstimate = meanEstimate * countEstimate
+ val sumVar = (meanEstimate * meanEstimate * countVar) +
+ (countEstimate * countEstimate * meanVar) +
+ (meanVar * countVar)
+ val sumStdev = math.sqrt(sumVar)
+ val confFactor = {
+ if (counter.count > 100) {
+ Probability.normalInverse(1 - (1 - confidence) / 2)
+ } else {
+ Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
+ }
+ }
+ val low = sumEstimate - confFactor * sumStdev
+ val high = sumEstimate + confFactor * sumStdev
+ new BoundedDouble(sumEstimate, confidence, low, high)
+ }
+ }
+}
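
For reference (not from the commit), the three-term sumVar above is the standard variance of a product of independent random variables, with X the per-item mean and Y the extrapolated count:

Var(X * Y) = E[X]^2 * Var(Y) + E[Y]^2 * Var(X) + Var(X) * Var(Y)

which corresponds to meanEstimate^2 * countVar + countEstimate^2 * meanVar + meanVar * countVar in the code.
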
diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala
new file mode 100644
index 0000000000..0ecff9ce77
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/ActiveJob.scala
@@ -0,0 +1,18 @@
+package spark.scheduler
+
+import spark.TaskContext
+
+/**
+ * Tracks information about an active job in the DAGScheduler.
+ */
+class ActiveJob(
+ val runId: Int,
+ val finalStage: Stage,
+ val func: (TaskContext, Iterator[_]) => _,
+ val partitions: Array[Int],
+ val listener: JobListener) {
+
+ val numPartitions = partitions.length
+ val finished = Array.fill[Boolean](numPartitions)(false)
+ var numFinished = 0
+}
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
new file mode 100644
index 0000000000..f31e2c65a0
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -0,0 +1,532 @@
+package spark.scheduler
+
+import java.net.URI
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.Future
+import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
+
+import spark._
+import spark.partial.ApproximateActionListener
+import spark.partial.ApproximateEvaluator
+import spark.partial.PartialResult
+import spark.storage.BlockManagerMaster
+import spark.storage.BlockManagerId
+
+/**
+ * The high-level scheduler that implements stage-oriented scheduling. It computes a DAG of stages
+ * for each job, keeps track of which RDDs and stage outputs are materialized, and computes a
+ * minimal schedule to run the job. It then submits the stages to the underlying TaskScheduler,
+ * which only needs to send tasks to the cluster and report task completions and fetch failures
+ * back through the TaskSchedulerListener callbacks.
+ */
+class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
+ taskSched.setListener(this)
+
+ // Called by TaskScheduler to report task completions or failures.
+ override def taskEnded(
+ task: Task[_],
+ reason: TaskEndReason,
+ result: Any,
+ accumUpdates: Map[Long, Any]) {
+ eventQueue.put(CompletionEvent(task, reason, result, accumUpdates))
+ }
+
+ // Called by TaskScheduler when a host fails.
+ override def hostLost(host: String) {
+ eventQueue.put(HostLost(host))
+ }
+
+ // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
+ // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
+ // as more failure events come in
+ val RESUBMIT_TIMEOUT = 50L
+
+ // The time, in millis, to wake up between polls of the completion queue in order to potentially
+ // resubmit failed stages
+ val POLL_TIMEOUT = 10L
+
+ private val lock = new Object // Used for access to the entire DAGScheduler
+
+ private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
+
+ val nextRunId = new AtomicInteger(0)
+
+ val nextStageId = new AtomicInteger(0)
+
+ val idToStage = new HashMap[Int, Stage]
+
+ val shuffleToMapStage = new HashMap[Int, Stage]
+
+ var cacheLocs = new HashMap[Int, Array[List[String]]]
+
+ val env = SparkEnv.get
+ val cacheTracker = env.cacheTracker
+ val mapOutputTracker = env.mapOutputTracker
+
+ val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back;
+ // that's not going to be a realistic assumption in general
+
+ val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
+ val running = new HashSet[Stage] // Stages we are running right now
+ val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
+ val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
+ var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits
+
+ val activeJobs = new HashSet[ActiveJob]
+ val resultStageToJob = new HashMap[Stage, ActiveJob]
+
+ // Start a thread to run the DAGScheduler event loop
+ new Thread("DAGScheduler") {
+ setDaemon(true)
+ override def run() {
+ DAGScheduler.this.run()
+ }
+ }.start()
+
+ def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
+ cacheLocs(rdd.id)
+ }
+
+ def updateCacheLocs() {
+ cacheLocs = cacheTracker.getLocationsSnapshot()
+ }
+
+ /**
+ * Get or create a shuffle map stage for the given shuffle dependency's map side.
+ * The priority value passed in will be used if the stage doesn't already exist with
+ * a lower priority (we assume that priorities always increase across jobs for now).
+ */
+ def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_,_], priority: Int): Stage = {
+ shuffleToMapStage.get(shuffleDep.shuffleId) match {
+ case Some(stage) => stage
+ case None =>
+ val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority)
+ shuffleToMapStage(shuffleDep.shuffleId) = stage
+ stage
+ }
+ }
+
+ /**
+ * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or
+ * as a result stage for the final RDD used directly in an action. The stage will also be given
+ * the provided priority.
+ */
+ def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]], priority: Int): Stage = {
+ // Kind of ugly: need to register RDDs with the cache and map output tracker here
+ // since we can't do it in the RDD constructor because # of splits is unknown
+ logInfo("Registering RDD " + rdd.id + ": " + rdd)
+ cacheTracker.registerRDD(rdd.id, rdd.splits.size)
+ if (shuffleDep != None) {
+ mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
+ }
+ val id = nextStageId.getAndIncrement()
+ val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
+ idToStage(id) = stage
+ stage
+ }
+
+ /**
+ * Get or create the list of parent stages for a given RDD. The stages will be assigned the
+ * provided priority if they haven't already been created with a lower priority.
+ */
+ def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
+ val parents = new HashSet[Stage]
+ val visited = new HashSet[RDD[_]]
+ def visit(r: RDD[_]) {
+ if (!visited(r)) {
+ visited += r
+ // Kind of ugly: need to register RDDs with the cache here since
+ // we can't do it in its constructor because # of splits is unknown
+ logInfo("Registering parent RDD " + r.id + ": " + r)
+ cacheTracker.registerRDD(r.id, r.splits.size)
+ for (dep <- r.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_,_,_] =>
+ parents += getShuffleMapStage(shufDep, priority)
+ case _ =>
+ visit(dep.rdd)
+ }
+ }
+ }
+ }
+ visit(rdd)
+ parents.toList
+ }
+
+ def getMissingParentStages(stage: Stage): List[Stage] = {
+ val missing = new HashSet[Stage]
+ val visited = new HashSet[RDD[_]]
+ def visit(rdd: RDD[_]) {
+ if (!visited(rdd)) {
+ visited += rdd
+ val locs = getCacheLocs(rdd)
+ for (p <- 0 until rdd.splits.size) {
+ if (locs(p) == Nil) {
+ for (dep <- rdd.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_,_,_] =>
+ val mapStage = getShuffleMapStage(shufDep, stage.priority)
+ if (!mapStage.isAvailable) {
+ missing += mapStage
+ }
+ case narrowDep: NarrowDependency[_] =>
+ visit(narrowDep.rdd)
+ }
+ }
+ }
+ }
+ }
+ }
+ visit(stage.rdd)
+ missing.toList
+ }
+
+ def runJob[T, U](
+ finalRdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ allowLocal: Boolean)
+ (implicit m: ClassManifest[U]): Array[U] =
+ {
+ val waiter = new JobWaiter(partitions.size)
+ val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+ eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, waiter))
+ waiter.getResult() match {
+ case JobSucceeded(results: Seq[_]) =>
+ return results.asInstanceOf[Seq[U]].toArray
+ case JobFailed(exception: Exception) =>
+ throw exception
+ }
+ }
+
+ def runApproximateJob[T, U, R](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ evaluator: ApproximateEvaluator[U, R],
+ timeout: Long
+ ): PartialResult[R] =
+ {
+ val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
+ val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+ val partitions = (0 until rdd.splits.size).toArray
+ eventQueue.put(JobSubmitted(rdd, func2, partitions, false, listener))
+ return listener.getResult() // Will throw an exception if the job fails
+ }
+
+ /**
+ * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
+ * events and responds by launching tasks. This runs in a dedicated thread and receives events
+ * via the eventQueue.
+ */
+ def run() = {
+ SparkEnv.set(env)
+
+ while (true) {
+ val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
+ val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
+ if (event != null) {
+ logDebug("Got event of type " + event.getClass.getName)
+ }
+
+ event match {
+ case JobSubmitted(finalRDD, func, partitions, allowLocal, listener) =>
+ val runId = nextRunId.getAndIncrement()
+ val finalStage = newStage(finalRDD, None, runId)
+ val job = new ActiveJob(runId, finalStage, func, partitions, listener)
+ updateCacheLocs()
+ logInfo("Got job " + job.runId + " with " + partitions.length + " output partitions")
+ logInfo("Final stage: " + finalStage)
+ logInfo("Parents of final stage: " + finalStage.parents)
+ logInfo("Missing parents: " + getMissingParentStages(finalStage))
+ if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
+ // Compute very short actions like first() or take() with no parent stages locally.
+ runLocally(job)
+ } else {
+ activeJobs += job
+ resultStageToJob(finalStage) = job
+ submitStage(finalStage)
+ }
+
+ case HostLost(host) =>
+ handleHostLost(host)
+
+ case completion: CompletionEvent =>
+ handleTaskCompletion(completion)
+
+ case null =>
+ // queue.poll() timed out, ignore it
+ }
+
+ // Periodically resubmit failed stages if some map output fetches have failed and we have
+ // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
+ // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
+ // the same time, so we want to make sure we've identified all the reduce tasks that depend
+ // on the failed node.
+ if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
+ logInfo("Resubmitting failed stages")
+ updateCacheLocs()
+ val failed2 = failed.toArray
+ failed.clear()
+ for (stage <- failed2.sortBy(_.priority)) {
+ submitStage(stage)
+ }
+ } else {
+ // TODO: We might want to run this less often, when we are sure that something has become
+ // runnable that wasn't before.
+ logDebug("Checking for newly runnable parent stages")
+ logDebug("running: " + running)
+ logDebug("waiting: " + waiting)
+ logDebug("failed: " + failed)
+ val waiting2 = waiting.toArray
+ waiting.clear()
+ for (stage <- waiting2.sortBy(_.priority)) {
+ submitStage(stage)
+ }
+ }
+ }
+ }
+
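An illustrative sketch, not part of this patch: the run() method above is a single-threaded event loop over a blocking queue, where a null from poll() means the timeout expired and only the periodic housekeeping (such as resubmitting failed stages) should run. The same shape on a hypothetical Event type:

    import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

    object EventLoopSketch {
      sealed trait Event
      case class Submitted(id: Int) extends Event
      case object Shutdown extends Event

      val queue = new LinkedBlockingQueue[Event]
      val POLL_TIMEOUT = 500L

      def run() {
        var done = false
        while (!done) {
          // Wait up to POLL_TIMEOUT ms; poll() returns null when it times out.
          val event = queue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
          event match {
            case Submitted(id) => println("handling job " + id)
            case Shutdown      => done = true
            case null          => // timeout: nothing new arrived, fall through
          }
          // Periodic work (e.g. retrying failed stages) runs here on every iteration.
        }
      }
    }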
+ /**
+ * Run a job on an RDD locally, assuming it has only a single partition and no dependencies.
+ * We run the operation in a separate thread just in case it takes a bunch of time, so that we
+ * don't block the DAGScheduler event loop or other concurrent jobs.
+ */
+ def runLocally(job: ActiveJob) {
+ logInfo("Computing the requested partition locally")
+ new Thread("Local computation of job " + job.runId) {
+ override def run() {
+ try {
+ SparkEnv.set(env)
+ val rdd = job.finalStage.rdd
+ val split = rdd.splits(job.partitions(0))
+ val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+ val result = job.func(taskContext, rdd.iterator(split))
+ job.listener.taskSucceeded(0, result)
+ } catch {
+ case e: Exception =>
+ job.listener.jobFailed(e)
+ }
+ }
+ }.start()
+ }
+
+ def submitStage(stage: Stage) {
+ logDebug("submitStage(" + stage + ")")
+ if (!waiting(stage) && !running(stage) && !failed(stage)) {
+ val missing = getMissingParentStages(stage).sortBy(_.id)
+ logDebug("missing: " + missing)
+ if (missing == Nil) {
+ logInfo("Submitting " + stage + ", which has no missing parents")
+ submitMissingTasks(stage)
+ running += stage
+ } else {
+ for (parent <- missing) {
+ submitStage(parent)
+ }
+ waiting += stage
+ }
+ }
+ }
+
+ def submitMissingTasks(stage: Stage) {
+ logDebug("submitMissingTasks(" + stage + ")")
+ // Get our pending tasks and remember them in our pendingTasks entry
+ val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
+ myPending.clear()
+ var tasks = ArrayBuffer[Task[_]]()
+ if (stage.isShuffleMap) {
+ for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
+ val locs = getPreferredLocs(stage.rdd, p)
+ tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
+ }
+ } else {
+ // This is a final stage; figure out its job's missing partitions
+ val job = resultStageToJob(stage)
+ for (id <- 0 until job.numPartitions if (!job.finished(id))) {
+ val partition = job.partitions(id)
+ val locs = getPreferredLocs(stage.rdd, partition)
+ tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
+ }
+ }
+ if (tasks.size > 0) {
+ logInfo("Submitting " + tasks.size + " missing tasks from " + stage)
+ myPending ++= tasks
+ logDebug("New pending tasks: " + myPending)
+ taskSched.submitTasks(
+ new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority))
+ } else {
+ logDebug("Stage " + stage + " is actually done; %b %d %d".format(
+ stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
+ running -= stage
+ }
+ }
+
+ /**
+ * Responds to a task finishing. This is called inside the event loop so it assumes that it can
+ * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
+ */
+ def handleTaskCompletion(event: CompletionEvent) {
+ val task = event.task
+ val stage = idToStage(task.stageId)
+ event.reason match {
+ case Success =>
+ logInfo("Completed " + task)
+ if (event.accumUpdates != null) {
+ Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
+ }
+ pendingTasks(stage) -= task
+ task match {
+ case rt: ResultTask[_, _] =>
+ resultStageToJob.get(stage) match {
+ case Some(job) =>
+ if (!job.finished(rt.outputId)) {
+ job.finished(rt.outputId) = true
+ job.numFinished += 1
+ job.listener.taskSucceeded(rt.outputId, event.result)
+ // If the whole job has finished, remove it
+ if (job.numFinished == job.numPartitions) {
+ activeJobs -= job
+ resultStageToJob -= stage
+ running -= stage
+ }
+ }
+ case None =>
+ logInfo("Ignoring result from " + rt + " because its job has finished")
+ }
+
+ case smt: ShuffleMapTask =>
+ val stage = idToStage(smt.stageId)
+ val bmAddress = event.result.asInstanceOf[BlockManagerId]
+ val host = bmAddress.ip
+ logInfo("ShuffleMapTask finished with host " + host)
+ if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos
+ stage.addOutputLoc(smt.partition, bmAddress)
+ }
+ if (running.contains(stage) && pendingTasks(stage).isEmpty) {
+ logInfo(stage + " finished; looking for newly runnable stages")
+ running -= stage
+ logInfo("running: " + running)
+ logInfo("waiting: " + waiting)
+ logInfo("failed: " + failed)
+ if (stage.shuffleDep != None) {
+ mapOutputTracker.registerMapOutputs(
+ stage.shuffleDep.get.shuffleId,
+ stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray)
+ }
+ updateCacheLocs()
+ if (stage.outputLocs.count(_ == Nil) != 0) {
+ // Some tasks had failed; let's resubmit this stage
+ // TODO: Lower-level scheduler should also deal with this
+ logInfo("Resubmitting " + stage + " because some of its tasks had failed: " +
+ stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", "))
+ submitStage(stage)
+ } else {
+ val newlyRunnable = new ArrayBuffer[Stage]
+ for (stage <- waiting) {
+ logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage))
+ }
+ for (stage <- waiting if getMissingParentStages(stage) == Nil) {
+ newlyRunnable += stage
+ }
+ waiting --= newlyRunnable
+ running ++= newlyRunnable
+ for (stage <- newlyRunnable.sortBy(_.id)) {
+ submitMissingTasks(stage)
+ }
+ }
+ }
+ }
+
+ case Resubmitted =>
+ logInfo("Resubmitted " + task + ", so marking it as still running")
+ pendingTasks(stage) += task
+
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ // Mark the stage that the reducer was in as unrunnable
+ val failedStage = idToStage(task.stageId)
+ running -= failedStage
+ failed += failedStage
+ // TODO: Cancel running tasks in the stage
+ logInfo("Marking " + failedStage + " for resubmision due to a fetch failure")
+ // Mark the map whose fetch failed as broken in the map stage
+ val mapStage = shuffleToMapStage(shuffleId)
+ mapStage.removeOutputLoc(mapId, bmAddress)
+ mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+ logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
+ failed += mapStage
+ // Remember that a fetch failed now; this is used to resubmit the broken
+ // stages later, after a small wait (to give other tasks the chance to fail)
+ lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
+ // TODO: mark the host as failed only if there were lots of fetch failures on it
+ if (bmAddress != null) {
+ handleHostLost(bmAddress.ip)
+ }
+
+ case _ =>
+ // Non-fetch failure -- probably a bug in the job, so bail out
+ // TODO: Cancel all tasks that are still running
+ resultStageToJob.get(stage) match {
+ case Some(job) =>
+ val error = new SparkException("Task failed: " + task + ", reason: " + event.reason)
+ job.listener.jobFailed(error)
+ activeJobs -= job
+ resultStageToJob -= stage
+ case None =>
+ logInfo("Ignoring result from " + task + " because its job has finished")
+ }
+ }
+ }
+
+ /**
+ * Responds to a host being lost. This is called inside the event loop so it assumes that it can
+ * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside.
+ */
+ def handleHostLost(host: String) {
+ if (!deadHosts.contains(host)) {
+ logInfo("Host lost: " + host)
+ deadHosts += host
+ BlockManagerMaster.notifyADeadHost(host)
+ // TODO: This will be really slow if we keep accumulating shuffle map stages
+ for ((shuffleId, stage) <- shuffleToMapStage) {
+ stage.removeOutputsOnHost(host)
+ val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
+ mapOutputTracker.registerMapOutputs(shuffleId, locs, true)
+ }
+ cacheTracker.cacheLost(host)
+ updateCacheLocs()
+ }
+ }
+
+ def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
+ // If the partition is cached, return the cache locations
+ val cached = getCacheLocs(rdd)(partition)
+ if (cached != Nil) {
+ return cached
+ }
+ // If the RDD has some placement preferences (as is the case for input RDDs), get those
+ val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList
+ if (rddPrefs != Nil) {
+ return rddPrefs
+ }
+ // If the RDD has narrow dependencies, pick the first partition of the first narrow dep
+ // that has any placement preferences. Ideally we would choose based on transfer sizes,
+ // but this will do for now.
+ rdd.dependencies.foreach(_ match {
+ case n: NarrowDependency[_] =>
+ for (inPart <- n.getParents(partition)) {
+ val locs = getPreferredLocs(n.rdd, inPart)
+ if (locs != Nil)
+ return locs;
+ }
+ case _ =>
+ })
+ return Nil
+ }
+
+ def stop() {
+ // TODO: Put a stop event on our queue and break the event loop
+ taskSched.stop()
+ }
+}
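An illustrative sketch, not part of this patch: getPreferredLocs() above tries three sources in order -- cached copies, the RDD's own placement preferences, then the first narrow parent with any preference. The same fallback chain on a hypothetical Node type:

    object PreferredLocsSketch {
      // Hypothetical lineage node: cached block locations, the node's own placement
      // hints (e.g. input block locations), and its narrow parents.
      case class Node(cached: List[String], ownPrefs: List[String], narrowParents: List[Node])

      def preferredLocs(n: Node): List[String] = {
        if (n.cached != Nil) return n.cached     // 1. prefer hosts that already cache the data
        if (n.ownPrefs != Nil) return n.ownPrefs // 2. then the node's own placement hints
        for (parent <- n.narrowParents) {        // 3. else recurse into narrow parents
          val locs = preferredLocs(parent)
          if (locs != Nil) return locs
        }
        Nil
      }
    }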
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
new file mode 100644
index 0000000000..c10abc9202
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -0,0 +1,30 @@
+package spark.scheduler
+
+import scala.collection.mutable.Map
+
+import spark._
+
+/**
+ * Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue
+ * architecture where any thread can post an event (e.g. a task finishing or a new job being
+ * submitted) but there is a single "logic" thread that reads these events and makes decisions.
+ * This greatly simplifies synchronization.
+ */
+sealed trait DAGSchedulerEvent
+
+case class JobSubmitted(
+ finalRDD: RDD[_],
+ func: (TaskContext, Iterator[_]) => _,
+ partitions: Array[Int],
+ allowLocal: Boolean,
+ listener: JobListener)
+ extends DAGSchedulerEvent
+
+case class CompletionEvent(
+ task: Task[_],
+ reason: TaskEndReason,
+ result: Any,
+ accumUpdates: Map[Long, Any])
+ extends DAGSchedulerEvent
+
+case class HostLost(host: String) extends DAGSchedulerEvent
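An illustrative usage sketch, not part of this patch: because DAGSchedulerEvent is a sealed trait, a match over the three event types can be checked for exhaustiveness by the compiler. The helper below is hypothetical and assumes only the case classes defined above plus RDD.id:

    import spark._
    import spark.scheduler._

    object DAGSchedulerEventSketch {
      def describe(event: DAGSchedulerEvent): String = event match {
        case JobSubmitted(finalRDD, _, partitions, allowLocal, _) =>
          "job on RDD " + finalRDD.id + " over " + partitions.length +
            " partitions (allowLocal=" + allowLocal + ")"
        case CompletionEvent(task, reason, _, _) =>
          "task " + task + " ended with reason " + reason
        case HostLost(host) =>
          "host lost: " + host
      }
    }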
diff --git a/core/src/main/scala/spark/scheduler/JobListener.scala b/core/src/main/scala/spark/scheduler/JobListener.scala
new file mode 100644
index 0000000000..d4dd536a7d
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/JobListener.scala
@@ -0,0 +1,11 @@
+package spark.scheduler
+
+/**
+ * Interface used to listen for job completion or failure events after submitting a job to the
+ * DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole
+ * job fails (and no further taskSucceeded events will happen).
+ */
+trait JobListener {
+ def taskSucceeded(index: Int, result: Any)
+ def jobFailed(exception: Exception)
+}
diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala
new file mode 100644
index 0000000000..62b458eccb
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/JobResult.scala
@@ -0,0 +1,9 @@
+package spark.scheduler
+
+/**
+ * A result of a job in the DAGScheduler.
+ */
+sealed trait JobResult
+
+case class JobSucceeded(results: Seq[_]) extends JobResult
+case class JobFailed(exception: Exception) extends JobResult
diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala
new file mode 100644
index 0000000000..be8ec9bd7b
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/JobWaiter.scala
@@ -0,0 +1,43 @@
+package spark.scheduler
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * An object that waits for a DAGScheduler job to complete.
+ */
+class JobWaiter(totalTasks: Int) extends JobListener {
+ private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null)
+ private var finishedTasks = 0
+
+ private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
+ private var jobResult: JobResult = null // If the job is finished, this will be its result
+
+ override def taskSucceeded(index: Int, result: Any) = synchronized {
+ if (jobFinished) {
+ throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
+ }
+ taskResults(index) = result
+ finishedTasks += 1
+ if (finishedTasks == totalTasks) {
+ jobFinished = true
+ jobResult = JobSucceeded(taskResults)
+ this.notifyAll()
+ }
+ }
+
+ override def jobFailed(exception: Exception) = synchronized {
+ if (jobFinished) {
+ throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter")
+ }
+ jobFinished = true
+ jobResult = JobFailed(exception)
+ this.notifyAll()
+ }
+
+ def getResult(): JobResult = synchronized {
+ while (!jobFinished) {
+ this.wait()
+ }
+ return jobResult
+ }
+}
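An illustrative usage sketch, not part of this patch: JobWaiter is the bridge between the blocking runJob() call and the asynchronous event loop -- the caller blocks in getResult() while the scheduler thread reports each task. A minimal, hypothetical driver:

    import spark.scheduler._

    object JobWaiterSketch {
      def main(args: Array[String]) {
        val waiter = new JobWaiter(2)
        // Stand-in for the DAGScheduler thread reporting two task results.
        new Thread() {
          override def run() {
            waiter.taskSucceeded(0, "first")
            waiter.taskSucceeded(1, "second")
          }
        }.start()
        // The calling thread blocks until both tasks have reported in (or the job fails).
        waiter.getResult() match {
          case JobSucceeded(results) => println("got " + results)
          case JobFailed(exception)  => throw exception
        }
      }
    }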
diff --git a/core/src/main/scala/spark/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index 3952bf85b2..d2fab55b5e 100644
--- a/core/src/main/scala/spark/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -1,14 +1,15 @@
-package spark
+package spark.scheduler
+
+import spark._
class ResultTask[T, U](
- runId: Int,
- stageId: Int,
- rdd: RDD[T],
+ stageId: Int,
+ rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
- val partition: Int,
- locs: Seq[String],
+ val partition: Int,
+ @transient locs: Seq[String],
val outputId: Int)
- extends DAGTask[U](runId, stageId) {
+ extends Task[U](stageId) {
val split = rdd.splits(partition)
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
new file mode 100644
index 0000000000..317faa0851
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -0,0 +1,135 @@
+package spark.scheduler
+
+import java.io._
+import java.util.HashMap
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import com.ning.compress.lzf.LZFInputStream
+import com.ning.compress.lzf.LZFOutputStream
+
+import spark._
+import spark.storage._
+
+object ShuffleMapTask {
+ val serializedInfoCache = new HashMap[Int, Array[Byte]]
+ val deserializedInfoCache = new HashMap[Int, (RDD[_], ShuffleDependency[_,_,_])]
+
+ def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = {
+ synchronized {
+ val old = serializedInfoCache.get(stageId)
+ if (old != null) {
+ return old
+ } else {
+ val out = new ByteArrayOutputStream
+ val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
+ objOut.writeObject(rdd)
+ objOut.writeObject(dep)
+ objOut.close()
+ val bytes = out.toByteArray
+ serializedInfoCache.put(stageId, bytes)
+ return bytes
+ }
+ }
+ }
+
+ def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = {
+ synchronized {
+ val old = deserializedInfoCache.get(stageId)
+ if (old != null) {
+ return old
+ } else {
+ val loader = currentThread.getContextClassLoader
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val objIn = new ObjectInputStream(in) {
+ override def resolveClass(desc: ObjectStreamClass) =
+ Class.forName(desc.getName, false, loader)
+ }
+ val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+ val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]]
+ val tuple = (rdd, dep)
+ deserializedInfoCache.put(stageId, tuple)
+ return tuple
+ }
+ }
+ }
+}
+
+class ShuffleMapTask(
+ stageId: Int,
+ var rdd: RDD[_],
+ var dep: ShuffleDependency[_,_,_],
+ var partition: Int,
+ @transient var locs: Seq[String])
+ extends Task[BlockManagerId](stageId)
+ with Externalizable
+ with Logging {
+
+ def this() = this(0, null, null, 0, null)
+
+ var split = if (rdd == null) {
+ null
+ } else {
+ rdd.splits(partition)
+ }
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeInt(stageId)
+ val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ out.writeInt(partition)
+ out.writeObject(split)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ val stageId = in.readInt()
+ val numBytes = in.readInt()
+ val bytes = new Array[Byte](numBytes)
+ in.readFully(bytes)
+ val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
+ rdd = rdd_
+ dep = dep_
+ partition = in.readInt()
+ split = in.readObject().asInstanceOf[Split]
+ }
+
+ override def run(attemptId: Int): BlockManagerId = {
+ val numOutputSplits = dep.partitioner.numPartitions
+ val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
+ val partitioner = dep.partitioner.asInstanceOf[Partitioner]
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+ for (elem <- rdd.iterator(split)) {
+ val (k, v) = elem.asInstanceOf[(Any, Any)]
+ var bucketId = partitioner.getPartition(k)
+ val bucket = buckets(bucketId)
+ var existing = bucket.get(k)
+ if (existing == null) {
+ bucket.put(k, aggregator.createCombiner(v))
+ } else {
+ bucket.put(k, aggregator.mergeValue(existing, v))
+ }
+ }
+ val ser = SparkEnv.get.serializer.newInstance()
+ val blockManager = SparkEnv.get.blockManager
+ for (i <- 0 until numOutputSplits) {
+ val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i
+ val arr = new ArrayBuffer[Any]
+ val iter = buckets(i).entrySet().iterator()
+ while (iter.hasNext()) {
+ val entry = iter.next()
+ arr += ((entry.getKey(), entry.getValue()))
+ }
+ // TODO: This should probably be DISK_ONLY
+ blockManager.put(blockId, arr.iterator, StorageLevel.MEMORY_ONLY, false)
+ }
+ return SparkEnv.get.blockManager.blockManagerId
+ }
+
+ override def preferredLocations: Seq[String] = locs
+
+ override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
+}
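An illustrative sketch, not part of this patch: the core of ShuffleMapTask.run() above is bucketing each record by the partitioner and combining values per key with the dependency's aggregator. The same step on plain collections, with a word-count style combiner and a simple hash standing in for partitioner.getPartition():

    import scala.collection.mutable.HashMap

    object ShuffleBucketsSketch {
      def bucketAndCombine(records: Seq[(String, Int)], numOutputSplits: Int): Array[HashMap[String, Int]] = {
        val buckets = Array.fill(numOutputSplits)(new HashMap[String, Int])
        for ((k, v) <- records) {
          val bucketId = (k.hashCode & Integer.MAX_VALUE) % numOutputSplits // stand-in partitioner
          val bucket = buckets(bucketId)
          bucket.get(k) match {
            case None           => bucket(k) = v            // createCombiner
            case Some(existing) => bucket(k) = existing + v // mergeValue
          }
        }
        buckets
      }
    }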
diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala
new file mode 100644
index 0000000000..cd660c9085
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/Stage.scala
@@ -0,0 +1,86 @@
+package spark.scheduler
+
+import java.net.URI
+
+import spark._
+import spark.storage.BlockManagerId
+
+/**
+ * A stage is a set of independent tasks all computing the same function that need to run as part
+ * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run
+ * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the
+ * DAGScheduler runs these stages in topological order.
+ *
+ * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for
+ * another stage, or a result stage, in which case its tasks directly compute the action that
+ * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes
+ * that each output partition is on.
+ *
+ * Each Stage also has a priority, which is (by default) based on the job it was submitted in.
+ * This allows Stages from earlier jobs to be computed first or recovered faster on failure.
+ */
+class Stage(
+ val id: Int,
+ val rdd: RDD[_],
+ val shuffleDep: Option[ShuffleDependency[_,_,_]], // Output shuffle if stage is a map stage
+ val parents: List[Stage],
+ val priority: Int)
+ extends Logging {
+
+ val isShuffleMap = shuffleDep != None
+ val numPartitions = rdd.splits.size
+ val outputLocs = Array.fill[List[BlockManagerId]](numPartitions)(Nil)
+ var numAvailableOutputs = 0
+
+ private var nextAttemptId = 0
+
+ def isAvailable: Boolean = {
+ if (/*parents.size == 0 &&*/ !isShuffleMap) {
+ true
+ } else {
+ numAvailableOutputs == numPartitions
+ }
+ }
+
+ def addOutputLoc(partition: Int, bmAddress: BlockManagerId) {
+ val prevList = outputLocs(partition)
+ outputLocs(partition) = bmAddress :: prevList
+ if (prevList == Nil)
+ numAvailableOutputs += 1
+ }
+
+ def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) {
+ val prevList = outputLocs(partition)
+ val newList = prevList.filterNot(_ == bmAddress)
+ outputLocs(partition) = newList
+ if (prevList != Nil && newList == Nil) {
+ numAvailableOutputs -= 1
+ }
+ }
+
+ def removeOutputsOnHost(host: String) {
+ var becameUnavailable = false
+ for (partition <- 0 until numPartitions) {
+ val prevList = outputLocs(partition)
+ val newList = prevList.filterNot(_.ip == host)
+ outputLocs(partition) = newList
+ if (prevList != Nil && newList == Nil) {
+ becameUnavailable = true
+ numAvailableOutputs -= 1
+ }
+ }
+ if (becameUnavailable) {
+ logInfo("%s is now unavailable on %s (%d/%d, %s)".format(this, host, numAvailableOutputs, numPartitions, isAvailable))
+ }
+ }
+
+ def newAttemptId(): Int = {
+ val id = nextAttemptId
+ nextAttemptId += 1
+ return id
+ }
+
+ override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]"
+
+ override def hashCode(): Int = id
+}
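An illustrative sketch, not part of this patch: the availability bookkeeping in Stage hinges on numAvailableOutputs counting partitions that have at least one location, so a shuffle map stage becomes available exactly when that count reaches numPartitions. Reduced to plain arrays and hypothetical hosts:

    object OutputLocsSketch {
      def main(args: Array[String]) {
        val numPartitions = 2
        val outputLocs = Array.fill[List[String]](numPartitions)(Nil)
        var numAvailableOutputs = 0
        def addOutputLoc(partition: Int, host: String) {
          if (outputLocs(partition) == Nil) numAvailableOutputs += 1
          outputLocs(partition) = host :: outputLocs(partition)
        }
        addOutputLoc(0, "host-a")
        assert(numAvailableOutputs == 1)             // stage not yet available
        addOutputLoc(1, "host-b")
        assert(numAvailableOutputs == numPartitions) // every partition has an output: available
      }
    }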
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
new file mode 100644
index 0000000000..42325956ba
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -0,0 +1,11 @@
+package spark.scheduler
+
+/**
+ * A task to execute on a worker node.
+ */
+abstract class Task[T](val stageId: Int) extends Serializable {
+ def run(attemptId: Int): T
+ def preferredLocations: Seq[String] = Nil
+
+ var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.
+}
diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala
new file mode 100644
index 0000000000..868ddb237c
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/TaskResult.scala
@@ -0,0 +1,34 @@
+package spark.scheduler
+
+import java.io._
+
+import scala.collection.mutable.Map
+
+// Task result. Also contains updates to accumulator variables.
+// TODO: Use of distributed cache to return result is a hack to get around
+// what seems to be a bug with messages over 60KB in libprocess; fix it
+class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Externalizable {
+ def this() = this(null.asInstanceOf[T], null)
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeObject(value)
+ out.writeInt(accumUpdates.size)
+ for ((key, value) <- accumUpdates) {
+ out.writeLong(key)
+ out.writeObject(value)
+ }
+ }
+
+ override def readExternal(in: ObjectInput) {
+ value = in.readObject().asInstanceOf[T]
+ val numUpdates = in.readInt
+ if (numUpdates == 0) {
+ accumUpdates = null
+ } else {
+ accumUpdates = Map()
+ for (i <- 0 until numUpdates) {
+ accumUpdates(in.readLong()) = in.readObject()
+ }
+ }
+ }
+}
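An illustrative round-trip, not part of this patch: because TaskResult is Externalizable, ordinary Java serialization drives writeExternal/readExternal, which is what carries the value and accumulator updates back from executors. A hypothetical self-check:

    import java.io._
    import scala.collection.mutable.Map

    import spark.scheduler.TaskResult

    object TaskResultSketch {
      def main(args: Array[String]) {
        val original = new TaskResult(42, Map[Long, Any](1L -> 10, 2L -> 20))
        val buf = new ByteArrayOutputStream
        val out = new ObjectOutputStream(buf)
        out.writeObject(original)
        out.close()
        val in = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
        val copy = in.readObject().asInstanceOf[TaskResult[Int]]
        assert(copy.value == 42 && copy.accumUpdates.size == 2)
      }
    }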
diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
new file mode 100644
index 0000000000..cb7c375d97
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
@@ -0,0 +1,27 @@
+package spark.scheduler
+
+/**
+ * Low-level task scheduler interface, implemented by both MesosScheduler and LocalScheduler.
+ * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
+ * and are responsible for sending the tasks to the cluster, running them, retrying if there
+ * are failures, and mitigating stragglers. They return events to the DAGScheduler through
+ * the TaskSchedulerListener interface.
+ */
+trait TaskScheduler {
+ def start(): Unit
+
+ // Wait for registration with Mesos.
+ def waitForRegister(): Unit
+
+ // Disconnect from the cluster.
+ def stop(): Unit
+
+ // Submit a sequence of tasks to run.
+ def submitTasks(taskSet: TaskSet): Unit
+
+ // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
+ def setListener(listener: TaskSchedulerListener): Unit
+
+ // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
+ def defaultParallelism(): Int
+}
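An illustrative stub, not part of this patch: the contract above can be satisfied by a scheduler that runs every task inline and reports straight back through the TaskSchedulerListener, which is enough to show how a TaskScheduler is wired to the DAGScheduler. It assumes spark.Success from TaskEndReason and passes an empty accumulator-update map:

    import scala.collection.mutable.Map

    import spark._
    import spark.scheduler._

    class InlineScheduler extends TaskScheduler with Logging {
      private var listener: TaskSchedulerListener = null

      override def start() {}
      override def waitForRegister() {}
      override def stop() {}
      override def defaultParallelism(): Int = 1

      override def setListener(listener: TaskSchedulerListener) {
        this.listener = listener
      }

      override def submitTasks(taskSet: TaskSet) {
        // Run each task on the calling thread and report success to the listener.
        for ((task, index) <- taskSet.tasks.zipWithIndex) {
          logInfo("Running task " + index + " of task set " + taskSet.id)
          val result = task.run(0)
          listener.taskEnded(task, Success, result, Map[Long, Any]())
        }
      }
    }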
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
new file mode 100644
index 0000000000..a647eec9e4
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -0,0 +1,16 @@
+package spark.scheduler
+
+import scala.collection.mutable.Map
+
+import spark.TaskEndReason
+
+/**
+ * Interface for getting events back from the TaskScheduler.
+ */
+trait TaskSchedulerListener {
+ // A task has finished or failed.
+ def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit
+
+ // A node was lost from the cluster.
+ def hostLost(host: String): Unit
+}
diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala
new file mode 100644
index 0000000000..6f29dd2e9d
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/TaskSet.scala
@@ -0,0 +1,9 @@
+package spark.scheduler
+
+/**
+ * A set of tasks submitted together to the low-level TaskScheduler, usually representing
+ * missing partitions of a particular stage.
+ */
+class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) {
+ val id: String = stageId + "." + attempt
+}
diff --git a/core/src/main/scala/spark/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 3910c7b09e..8339c0ae90 100644
--- a/core/src/main/scala/spark/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -1,16 +1,21 @@
-package spark
+package spark.scheduler.local
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger
+import spark._
+import spark.scheduler._
+
/**
- * A simple Scheduler implementation that runs tasks locally in a thread pool. Optionally the
- * scheduler also allows each task to fail up to maxFailures times, which is useful for testing
- * fault recovery.
+ * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
+ * the scheduler also allows each task to fail up to maxFailures times, which is useful for
+ * testing fault recovery.
*/
-private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGScheduler with Logging {
+class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with Logging {
var attemptId = new AtomicInteger(0)
var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
+ val env = SparkEnv.get
+ var listener: TaskSchedulerListener = null
// TODO: Need to take into account stage priority in scheduling
@@ -18,7 +23,12 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
override def waitForRegister() {}
- override def submitTasks(tasks: Seq[Task[_]], runId: Int) {
+ override def setListener(listener: TaskSchedulerListener) {
+ this.listener = listener
+ }
+
+ override def submitTasks(taskSet: TaskSet) {
+ val tasks = taskSet.tasks
val failCount = new Array[Int](tasks.size)
def submitTask(task: Task[_], idInJob: Int) {
@@ -38,23 +48,14 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
Accumulators.clear
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val startTime = System.currentTimeMillis
- val bytes = ser.serialize(task)
- val timeTaken = System.currentTimeMillis - startTime
- logInfo("Size of task %d is %d bytes and took %d ms to serialize".format(
- idInJob, bytes.size, timeTaken))
- val deserializedTask = ser.deserialize[Task[_]](bytes, currentThread.getContextClassLoader)
+ val bytes = Utils.serialize(task)
+ logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes")
+ val deserializedTask = Utils.deserialize[Task[_]](
+ bytes, Thread.currentThread.getContextClassLoader)
val result: Any = deserializedTask.run(attemptId)
-
- // Serialize and deserialize the result to emulate what the mesos
- // executor does. This is useful to catch serialization errors early
- // on in development (so when users move their local Spark programs
- // to the cluster, they don't get surprised by serialization errors).
- val resultToReturn = ser.deserialize[Any](ser.serialize(result))
val accumUpdates = Accumulators.values
logInfo("Finished task " + idInJob)
- taskEnded(task, Success, resultToReturn, accumUpdates)
+ listener.taskEnded(task, Success, result, accumUpdates)
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
@@ -64,7 +65,7 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
submitTask(task, idInJob)
} else {
// TODO: Do something nicer here to return all the way to the user
- taskEnded(task, new ExceptionFailure(t), null, null)
+ listener.taskEnded(task, new ExceptionFailure(t), null, null)
}
}
}
diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala
new file mode 100644
index 0000000000..8182901ce3
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala
@@ -0,0 +1,364 @@
+package spark.scheduler.mesos
+
+import java.io.{File, FileInputStream, FileOutputStream}
+import java.util.{ArrayList => JArrayList}
+import java.util.{List => JList}
+import java.util.{HashMap => JHashMap}
+import java.util.concurrent._
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.collection.mutable.Map
+import scala.collection.mutable.PriorityQueue
+import scala.collection.JavaConversions._
+import scala.math.Ordering
+
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
+import akka.actor.Channel
+import akka.serialization.RemoteActorSerialization._
+
+import com.google.protobuf.ByteString
+
+import org.apache.mesos.{Scheduler => MScheduler}
+import org.apache.mesos._
+import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
+
+import spark._
+import spark.scheduler._
+
+sealed trait CoarseMesosSchedulerMessage
+case class RegisterSlave(slaveId: String, host: String, port: Int) extends CoarseMesosSchedulerMessage
+case class StatusUpdate(slaveId: String, status: TaskStatus) extends CoarseMesosSchedulerMessage
+case class LaunchTask(slaveId: String, task: MTaskInfo) extends CoarseMesosSchedulerMessage
+case class ReviveOffers() extends CoarseMesosSchedulerMessage
+
+case class FakeOffer(slaveId: String, host: String, cores: Int)
+
+/**
+ * Mesos scheduler that uses coarse-grained tasks and does its own fine-grained scheduling inside
+ * them using Akka actors for messaging. Clients should first call start(), then submit task sets
+ * through the submitTasks method.
+ *
+ * TODO: This is a pretty big hack for now.
+ */
+class CoarseMesosScheduler(
+ sc: SparkContext,
+ master: String,
+ frameworkName: String)
+ extends MesosScheduler(sc, master, frameworkName) {
+
+ val CORES_PER_SLAVE = System.getProperty("spark.coarseMesosScheduler.coresPerSlave", "4").toInt
+
+ class MasterActor extends Actor {
+ val slaveActor = new HashMap[String, ActorRef]
+ val slaveHost = new HashMap[String, String]
+ val freeCores = new HashMap[String, Int]
+
+ def receive = {
+ case RegisterSlave(slaveId, host, port) =>
+ slaveActor(slaveId) = remote.actorFor("WorkerActor", host, port)
+ logInfo("Slave actor: " + slaveActor(slaveId))
+ slaveHost(slaveId) = host
+ freeCores(slaveId) = CORES_PER_SLAVE
+ makeFakeOffers()
+
+ case StatusUpdate(slaveId, status) =>
+ fakeStatusUpdate(status)
+ if (isFinished(status.getState)) {
+ freeCores(slaveId) += 1
+ makeFakeOffers(slaveId)
+ }
+
+ case LaunchTask(slaveId, task) =>
+ freeCores(slaveId) -= 1
+ slaveActor(slaveId) ! LaunchTask(slaveId, task)
+
+ case ReviveOffers() =>
+ logInfo("Reviving offers")
+ makeFakeOffers()
+ }
+
+ // Make fake resource offers for all slaves
+ def makeFakeOffers() {
+ fakeResourceOffers(slaveHost.toSeq.map{case (id, host) => FakeOffer(id, host, freeCores(id))})
+ }
+
+ // Make fake resource offers for all slaves
+ def makeFakeOffers(slaveId: String) {
+ fakeResourceOffers(Seq(FakeOffer(slaveId, slaveHost(slaveId), freeCores(slaveId))))
+ }
+ }
+
+ val masterActor: ActorRef = actorOf(new MasterActor)
+ remote.register("MasterActor", masterActor)
+ masterActor.start()
+
+ val taskIdsOnSlave = new HashMap[String, HashSet[String]]
+
+ /**
+ * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets
+ * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
+ * tasks are balanced across the cluster.
+ */
+ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
+ synchronized {
+ val tasks = offers.map(o => new JArrayList[MTaskInfo])
+ for (i <- 0 until offers.size) {
+ val o = offers.get(i)
+ val slaveId = o.getSlaveId.getValue
+ if (!slaveIdToHost.contains(slaveId)) {
+ slaveIdToHost(slaveId) = o.getHostname
+ hostsAlive += o.getHostname
+ taskIdsOnSlave(slaveId) = new HashSet[String]
+ // Launch an infinite task on the node that will talk to the MasterActor to get fake tasks
+ val cpuRes = Resource.newBuilder()
+ .setName("cpus")
+ .setType(Value.Type.SCALAR)
+ .setScalar(Value.Scalar.newBuilder().setValue(1).build())
+ .build()
+ val task = new WorkerTask(slaveId, o.getHostname)
+ val serializedTask = Utils.serialize(task)
+ tasks(i).add(MTaskInfo.newBuilder()
+ .setTaskId(newTaskId())
+ .setSlaveId(o.getSlaveId)
+ .setExecutor(executorInfo)
+ .setName("worker task")
+ .addResources(cpuRes)
+ .setData(ByteString.copyFrom(serializedTask))
+ .build())
+ }
+ }
+ val filters = Filters.newBuilder().setRefuseSeconds(10).build()
+ for (i <- 0 until offers.size) {
+ d.launchTasks(offers(i).getId(), tasks(i), filters)
+ }
+ }
+ }
+
+ override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
+ val tid = status.getTaskId.getValue
+ var taskSetToUpdate: Option[TaskSetManager] = None
+ var taskFailed = false
+ synchronized {
+ try {
+ taskIdToTaskSetId.get(tid) match {
+ case Some(taskSetId) =>
+ if (activeTaskSets.contains(taskSetId)) {
+ //activeTaskSets(taskSetId).statusUpdate(status)
+ taskSetToUpdate = Some(activeTaskSets(taskSetId))
+ }
+ if (isFinished(status.getState)) {
+ taskIdToTaskSetId.remove(tid)
+ if (taskSetTaskIds.contains(taskSetId)) {
+ taskSetTaskIds(taskSetId) -= tid
+ }
+ val slaveId = taskIdToSlaveId(tid)
+ taskIdToSlaveId -= tid
+ taskIdsOnSlave(slaveId) -= tid
+ }
+ if (status.getState == TaskState.TASK_FAILED) {
+ taskFailed = true
+ }
+ case None =>
+ logInfo("Ignoring update from TID " + tid + " because its task set is gone")
+ }
+ } catch {
+ case e: Exception => logError("Exception in statusUpdate", e)
+ }
+ }
+ // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
+ if (taskSetToUpdate != None) {
+ taskSetToUpdate.get.statusUpdate(status)
+ }
+ if (taskFailed) {
+ // Revive offers if a task had failed for some reason other than host lost
+ reviveOffers()
+ }
+ }
+
+ override def slaveLost(d: SchedulerDriver, s: SlaveID) {
+ logInfo("Slave lost: " + s.getValue)
+ var failedHost: Option[String] = None
+ var lostTids: Option[HashSet[String]] = None
+ synchronized {
+ val slaveId = s.getValue
+ val host = slaveIdToHost(slaveId)
+ if (hostsAlive.contains(host)) {
+ slaveIdsWithExecutors -= slaveId
+ hostsAlive -= host
+ failedHost = Some(host)
+ lostTids = Some(taskIdsOnSlave(slaveId))
+ logInfo("failedHost: " + host)
+ logInfo("lostTids: " + lostTids)
+ taskIdsOnSlave -= slaveId
+ activeTaskSetsQueue.foreach(_.hostLost(host))
+ }
+ }
+ if (failedHost != None) {
+ // Report all the tasks on the failed host as lost, without holding a lock on this
+ for (tid <- lostTids.get; taskSetId <- taskIdToTaskSetId.get(tid)) {
+ // TODO: Maybe call our statusUpdate() instead to clean our internal data structures
+ activeTaskSets(taskSetId).statusUpdate(TaskStatus.newBuilder()
+ .setTaskId(TaskID.newBuilder().setValue(tid).build())
+ .setState(TaskState.TASK_LOST)
+ .build())
+ }
+ // Also report the loss to the DAGScheduler
+ listener.hostLost(failedHost.get)
+ reviveOffers();
+ }
+ }
+
+ override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
+
+ // Check for speculatable tasks in all our active task sets.
+ override def checkSpeculatableTasks() {
+ var shouldRevive = false
+ synchronized {
+ for (ts <- activeTaskSetsQueue) {
+ shouldRevive |= ts.checkSpeculatableTasks()
+ }
+ }
+ if (shouldRevive) {
+ reviveOffers()
+ }
+ }
+
+
+ val lock2 = new Object
+ var firstWait = true
+
+ override def waitForRegister() {
+ lock2.synchronized {
+ if (firstWait) {
+ super.waitForRegister()
+ Thread.sleep(5000)
+ firstWait = false
+ }
+ }
+ }
+
+ def fakeStatusUpdate(status: TaskStatus) {
+ statusUpdate(driver, status)
+ }
+
+ def fakeResourceOffers(offers: Seq[FakeOffer]) {
+ logDebug("fakeResourceOffers: " + offers)
+ val availableCpus = offers.map(_.cores.toDouble).toArray
+ var launchedTask = false
+ for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
+ do {
+ launchedTask = false
+ for (i <- 0 until offers.size if hostsAlive.contains(offers(i).host)) {
+ manager.slaveOffer(offers(i).slaveId, offers(i).host, availableCpus(i)) match {
+ case Some(task) =>
+ val tid = task.getTaskId.getValue
+ val sid = offers(i).slaveId
+ taskIdToTaskSetId(tid) = manager.taskSet.id
+ taskSetTaskIds(manager.taskSet.id) += tid
+ taskIdToSlaveId(tid) = sid
+ taskIdsOnSlave(sid) += tid
+ slaveIdsWithExecutors += sid
+ availableCpus(i) -= getResource(task.getResourcesList(), "cpus")
+ launchedTask = true
+ masterActor ! LaunchTask(sid, task)
+
+ case None => {}
+ }
+ }
+ } while (launchedTask)
+ }
+ }
+
+ override def reviveOffers() {
+ masterActor ! ReviveOffers()
+ }
+}
+
+class WorkerTask(slaveId: String, host: String) extends Task[Unit](-1) {
+ generation = 0
+
+ def run(id: Int): Unit = {
+ val actor = actorOf(new WorkerActor(slaveId, host))
+ if (!remote.isRunning) {
+ remote.start(Utils.localIpAddress, 7078)
+ }
+ remote.register("WorkerActor", actor)
+ actor.start()
+ while (true) {
+ Thread.sleep(10000)
+ }
+ }
+}
+
+class WorkerActor(slaveId: String, host: String) extends Actor with Logging {
+ val env = SparkEnv.get
+ val classLoader = currentThread.getContextClassLoader
+ val threadPool = new ThreadPoolExecutor(
+ 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
+
+ val masterIp: String = System.getProperty("spark.master.host", "localhost")
+ val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
+ val masterActor = remote.actorFor("MasterActor", masterIp, masterPort)
+
+ class TaskRunner(desc: MTaskInfo)
+ extends Runnable {
+ override def run() = {
+ val tid = desc.getTaskId.getValue
+ logInfo("Running task ID " + tid)
+ try {
+ SparkEnv.set(env)
+ Thread.currentThread.setContextClassLoader(classLoader)
+ Accumulators.clear
+ val task = Utils.deserialize[Task[Any]](desc.getData.toByteArray, classLoader)
+ env.mapOutputTracker.updateGeneration(task.generation)
+ val value = task.run(tid.toInt)
+ val accumUpdates = Accumulators.values
+ val result = new TaskResult(value, accumUpdates)
+ masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
+ .setTaskId(desc.getTaskId)
+ .setState(TaskState.TASK_FINISHED)
+ .setData(ByteString.copyFrom(Utils.serialize(result)))
+ .build())
+ logInfo("Finished task ID " + tid)
+ } catch {
+ case ffe: FetchFailedException => {
+ val reason = ffe.toTaskEndReason
+ masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
+ .setTaskId(desc.getTaskId)
+ .setState(TaskState.TASK_FAILED)
+ .setData(ByteString.copyFrom(Utils.serialize(reason)))
+ .build())
+ }
+ case t: Throwable => {
+ val reason = ExceptionFailure(t)
+ masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
+ .setTaskId(desc.getTaskId)
+ .setState(TaskState.TASK_FAILED)
+ .setData(ByteString.copyFrom(Utils.serialize(reason)))
+ .build())
+
+ // TODO: Should we exit the whole executor here? On the one hand, the failed task may
+ // have left some weird state around depending on when the exception was thrown, but on
+ // the other hand, maybe we could detect that when future tasks fail and exit then.
+ logError("Exception in task ID " + tid, t)
+ //System.exit(1)
+ }
+ }
+ }
+ }
+
+ override def preStart {
+ val ref = toRemoteActorRefProtocol(self).toByteArray
+ logInfo("Registering with master")
+ masterActor ! RegisterSlave(slaveId, host, remote.address.getPort)
+ }
+
+ override def receive = {
+ case LaunchTask(slaveId, task) =>
+ threadPool.execute(new TaskRunner(task))
+ }
+}
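An illustrative sketch, not part of this patch: fakeResourceOffers() above (and resourceOffers() in MesosScheduler) uses a do/while fill loop so that each pass places at most one task per offer, spreading tasks round-robin across the cluster instead of packing one node first. The same loop on plain numbers, returning the offer index chosen for each task:

    import scala.collection.mutable.ArrayBuffer

    object RoundRobinFillSketch {
      def assign(offerCpus: Array[Double], cpusPerTask: Double, numTasks: Int): Seq[Int] = {
        val assignment = new ArrayBuffer[Int]
        var remaining = numTasks
        var launchedTask = true
        while (launchedTask && remaining > 0) {
          launchedTask = false
          for (i <- 0 until offerCpus.length if remaining > 0 && offerCpus(i) >= cpusPerTask) {
            offerCpus(i) -= cpusPerTask
            assignment += i // the next task goes to offer i
            remaining -= 1
            launchedTask = true
          }
        }
        assignment
      }
    }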
diff --git a/core/src/main/scala/spark/MesosScheduler.scala b/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala
index a7711e0d35..f72618c03f 100644
--- a/core/src/main/scala/spark/MesosScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala
@@ -1,4 +1,4 @@
-package spark
+package spark.scheduler.mesos
import java.io.{File, FileInputStream, FileOutputStream}
import java.util.{ArrayList => JArrayList}
@@ -17,20 +17,23 @@ import com.google.protobuf.ByteString
import org.apache.mesos.{Scheduler => MScheduler}
import org.apache.mesos._
-import org.apache.mesos.Protos._
+import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
+
+import spark._
+import spark.scheduler._
/**
- * The main Scheduler implementation, which runs jobs on Mesos. Clients should first call start(),
- * then submit tasks through the runTasks method.
+ * The main TaskScheduler implementation, which runs tasks on Mesos. Clients should first call
+ * start(), then submit task sets through the submitTasks method.
*/
-private class MesosScheduler(
+class MesosScheduler(
sc: SparkContext,
master: String,
frameworkName: String)
- extends MScheduler
- with DAGScheduler
+ extends TaskScheduler
+ with MScheduler
with Logging {
-
+
// Environment variables to pass to our executors
val ENV_VARS_TO_SEND_TO_EXECUTORS = Array(
"SPARK_MEM",
@@ -49,55 +52,60 @@ private class MesosScheduler(
}
}
+ // How often to check for speculative tasks
+ val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
+
// Lock used to wait for scheduler to be registered
- private var isRegistered = false
- private val registeredLock = new Object()
+ var isRegistered = false
+ val registeredLock = new Object()
- private val activeJobs = new HashMap[Int, Job]
- private var activeJobsQueue = new ArrayBuffer[Job]
+ val activeTaskSets = new HashMap[String, TaskSetManager]
+ var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
- private val taskIdToJobId = new HashMap[String, Int]
- private val taskIdToSlaveId = new HashMap[String, String]
- private val jobTasks = new HashMap[Int, HashSet[String]]
+ val taskIdToTaskSetId = new HashMap[String, String]
+ val taskIdToSlaveId = new HashMap[String, String]
+ val taskSetTaskIds = new HashMap[String, HashSet[String]]
- // Incrementing job and task IDs
- private var nextJobId = 0
- private var nextTaskId = 0
+ // Incrementing Mesos task IDs
+ var nextTaskId = 0
// Driver for talking to Mesos
var driver: SchedulerDriver = null
- // Which nodes we have executors on
- private val slavesWithExecutors = new HashSet[String]
+ // Which hosts in the cluster are alive (contains hostnames)
+ val hostsAlive = new HashSet[String]
+
+ // Which slave IDs we have executors on
+ val slaveIdsWithExecutors = new HashSet[String]
+
+ val slaveIdToHost = new HashMap[String, String]
// JAR server, if any JARs were added by the user to the SparkContext
var jarServer: HttpServer = null
// URIs of JARs to pass to executor
var jarUris: String = ""
-
+
// Create an ExecutorInfo for our tasks
val executorInfo = createExecutorInfo()
- // Sorts jobs in reverse order of run ID for use in our priority queue (so lower IDs run first)
- private val jobOrdering = new Ordering[Job] {
- override def compare(j1: Job, j2: Job): Int = j2.runId - j1.runId
- }
-
- def newJobId(): Int = this.synchronized {
- val id = nextJobId
- nextJobId += 1
- return id
+ // Listener object to pass upcalls into
+ var listener: TaskSchedulerListener = null
+
+ val mapOutputTracker = SparkEnv.get.mapOutputTracker
+
+ override def setListener(listener: TaskSchedulerListener) {
+ this.listener = listener
}
def newTaskId(): TaskID = {
- val id = "" + nextTaskId;
- nextTaskId += 1;
- return TaskID.newBuilder().setValue(id).build()
+ val id = TaskID.newBuilder().setValue("" + nextTaskId).build()
+ nextTaskId += 1
+ return id
}
override def start() {
- new Thread("Spark scheduler") {
+ new Thread("MesosScheduler driver") {
setDaemon(true)
override def run {
val sched = MesosScheduler.this
@@ -110,12 +118,27 @@ private class MesosScheduler(
case e: Exception => logError("driver.run() failed", e)
}
}
- }.start
+ }.start()
+ if (System.getProperty("spark.speculation", "false") == "true") {
+ new Thread("MesosScheduler speculation check") {
+ setDaemon(true)
+ override def run {
+ waitForRegister()
+ while (true) {
+ try {
+ Thread.sleep(SPECULATION_INTERVAL)
+ } catch { case e: InterruptedException => {} }
+ checkSpeculatableTasks()
+ }
+ }
+ }.start()
+ }
}
def createExecutorInfo(): ExecutorInfo = {
val sparkHome = sc.getSparkHome match {
- case Some(path) => path
+ case Some(path) =>
+ path
case None =>
throw new SparkException("Spark home is not set; set it through the spark.home system " +
"property, the SPARK_HOME environment variable or the SparkContext constructor")
@@ -151,27 +174,26 @@ private class MesosScheduler(
.build()
}
- def submitTasks(tasks: Seq[Task[_]], runId: Int) {
- logInfo("Got a job with " + tasks.size + " tasks")
+ def submitTasks(taskSet: TaskSet) {
+ val tasks = taskSet.tasks
+ logInfo("Adding task set " + taskSet.id + " with " + tasks.size + " tasks")
waitForRegister()
this.synchronized {
- val jobId = newJobId()
- val myJob = new SimpleJob(this, tasks, runId, jobId)
- activeJobs(jobId) = myJob
- activeJobsQueue += myJob
- logInfo("Adding job with ID " + jobId)
- jobTasks(jobId) = HashSet.empty[String]
+ val manager = new TaskSetManager(this, taskSet)
+ activeTaskSets(taskSet.id) = manager
+ activeTaskSetsQueue += manager
+ taskSetTaskIds(taskSet.id) = new HashSet()
}
- driver.reviveOffers();
+ reviveOffers();
}
- def jobFinished(job: Job) {
+ def taskSetFinished(manager: TaskSetManager) {
this.synchronized {
- activeJobs -= job.jobId
- activeJobsQueue -= job
- taskIdToJobId --= jobTasks(job.jobId)
- taskIdToSlaveId --= jobTasks(job.jobId)
- jobTasks.remove(job.jobId)
+ activeTaskSets -= manager.taskSet.id
+ activeTaskSetsQueue -= manager
+ taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+ taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id)
+ taskSetTaskIds.remove(manager.taskSet.id)
}
}
@@ -196,33 +218,40 @@ private class MesosScheduler(
override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
/**
- * Method called by Mesos to offer resources on slaves. We resond by asking our active jobs for
- * tasks in FIFO order. We fill each node with tasks in a round-robin manner so that tasks are
- * balanced across the cluster.
+ * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets
+ * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
+ * tasks are balanced across the cluster.
*/
override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
synchronized {
- val tasks = offers.map(o => new JArrayList[TaskInfo])
+ // Mark each slave as alive and remember its hostname
+ for (o <- offers) {
+ slaveIdToHost(o.getSlaveId.getValue) = o.getHostname
+ hostsAlive += o.getHostname
+ }
+ // Build a list of tasks to assign to each slave
+ val tasks = offers.map(o => new JArrayList[MTaskInfo])
val availableCpus = offers.map(o => getResource(o.getResourcesList(), "cpus"))
val enoughMem = offers.map(o => {
val mem = getResource(o.getResourcesList(), "mem")
val slaveId = o.getSlaveId.getValue
- mem >= EXECUTOR_MEMORY || slavesWithExecutors.contains(slaveId)
+ mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId)
})
var launchedTask = false
- for (job <- activeJobsQueue.sorted(jobOrdering)) {
+ for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
do {
launchedTask = false
for (i <- 0 until offers.size if enoughMem(i)) {
- job.slaveOffer(offers(i), availableCpus(i)) match {
+ val sid = offers(i).getSlaveId.getValue
+ val host = offers(i).getHostname
+ manager.slaveOffer(sid, host, availableCpus(i)) match {
case Some(task) =>
tasks(i).add(task)
val tid = task.getTaskId.getValue
- val sid = offers(i).getSlaveId.getValue
- taskIdToJobId(tid) = job.jobId
- jobTasks(job.jobId) += tid
+ taskIdToTaskSetId(tid) = manager.taskSet.id
+ taskSetTaskIds(manager.taskSet.id) += tid
taskIdToSlaveId(tid) = sid
- slavesWithExecutors += sid
+ slaveIdsWithExecutors += sid
availableCpus(i) -= getResource(task.getResourcesList(), "cpus")
launchedTask = true
@@ -256,53 +285,74 @@ private class MesosScheduler(
}
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
- var jobToUpdate: Option[Job] = None
+ val tid = status.getTaskId.getValue
+ var taskSetToUpdate: Option[TaskSetManager] = None
+ var failedHost: Option[String] = None
+ var taskFailed = false
synchronized {
try {
- val tid = status.getTaskId.getValue
- if (status.getState == TaskState.TASK_LOST
- && taskIdToSlaveId.contains(tid)) {
+ if (status.getState == TaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) {
// We lost the executor on this slave, so remember that it's gone
- slavesWithExecutors -= taskIdToSlaveId(tid)
+ val slaveId = taskIdToSlaveId(tid)
+ val host = slaveIdToHost(slaveId)
+ if (hostsAlive.contains(host)) {
+ slaveIdsWithExecutors -= slaveId
+ hostsAlive -= host
+ activeTaskSetsQueue.foreach(_.hostLost(host))
+ failedHost = Some(host)
+ }
}
- taskIdToJobId.get(tid) match {
- case Some(jobId) =>
- if (activeJobs.contains(jobId)) {
- jobToUpdate = Some(activeJobs(jobId))
+ taskIdToTaskSetId.get(tid) match {
+ case Some(taskSetId) =>
+ if (activeTaskSets.contains(taskSetId)) {
+ //activeTaskSets(taskSetId).statusUpdate(status)
+ taskSetToUpdate = Some(activeTaskSets(taskSetId))
}
if (isFinished(status.getState)) {
- taskIdToJobId.remove(tid)
- if (jobTasks.contains(jobId)) {
- jobTasks(jobId) -= tid
+ taskIdToTaskSetId.remove(tid)
+ if (taskSetTaskIds.contains(taskSetId)) {
+ taskSetTaskIds(taskSetId) -= tid
}
taskIdToSlaveId.remove(tid)
}
+ if (status.getState == TaskState.TASK_FAILED) {
+ taskFailed = true
+ }
case None =>
- logInfo("Ignoring update from TID " + tid + " because its job is gone")
+ logInfo("Ignoring update from TID " + tid + " because its task set is gone")
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
}
}
- for (j <- jobToUpdate) {
- j.statusUpdate(status)
+ // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
+ if (taskSetToUpdate != None) {
+ taskSetToUpdate.get.statusUpdate(status)
+ }
+ if (failedHost != None) {
+ listener.hostLost(failedHost.get)
+ reviveOffers();
+ }
+ if (taskFailed) {
+ // Also revive offers if a task had failed for some reason other than host lost
+ reviveOffers()
}
}
override def error(d: SchedulerDriver, message: String) {
logError("Mesos error: " + message)
synchronized {
- if (activeJobs.size > 0) {
- // Have each job throw a SparkException with the error
- for ((jobId, activeJob) <- activeJobs) {
+ if (activeTaskSets.size > 0) {
+ // Have each task set throw a SparkException with the error
+ for ((taskSetId, manager) <- activeTaskSets) {
try {
- activeJob.error(message)
+ manager.error(message)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
}
} else {
- // No jobs are active but we still got an error. Just exit since this
+ // No task sets are active but we still got an error. Just exit since this
// must mean the error is during registration.
// It might be good to do something smarter here in the future.
System.exit(1)
@@ -373,41 +423,68 @@ private class MesosScheduler(
return Utils.serialize(props.toArray)
}
- override def frameworkMessage(
- d: SchedulerDriver,
- e: ExecutorID,
- s: SlaveID,
- b: Array[Byte]) {}
+ override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
override def slaveLost(d: SchedulerDriver, s: SlaveID) {
- slavesWithExecutors.remove(s.getValue)
+ var failedHost: Option[String] = None
+ synchronized {
+ val slaveId = s.getValue
+ val host = slaveIdToHost(slaveId)
+ if (hostsAlive.contains(host)) {
+ slaveIdsWithExecutors -= slaveId
+ hostsAlive -= host
+ activeTaskSetsQueue.foreach(_.hostLost(host))
+ failedHost = Some(host)
+ }
+ }
+ if (failedHost != None) {
+ listener.hostLost(failedHost.get)
+ reviveOffers();
+ }
}
override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) {
- slavesWithExecutors.remove(s.getValue)
+ logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue))
+ slaveLost(d, s)
}
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
+
+ // Check for speculatable tasks in all our active task sets.
+ def checkSpeculatableTasks() {
+ var shouldRevive = false
+ synchronized {
+ for (ts <- activeTaskSetsQueue) {
+ shouldRevive |= ts.checkSpeculatableTasks()
+ }
+ }
+ if (shouldRevive) {
+ reviveOffers()
+ }
+ }
+
+ def reviveOffers() {
+ driver.reviveOffers()
+ }
}
object MesosScheduler {
/**
- * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
- * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
+ * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
+ * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
* environment variable.
*/
def memoryStringToMb(str: String): Int = {
val lower = str.toLowerCase
if (lower.endsWith("k")) {
- (lower.substring(0, lower.length - 1).toLong / 1024).toInt
+ (lower.substring(0, lower.length-1).toLong / 1024).toInt
} else if (lower.endsWith("m")) {
- lower.substring(0, lower.length - 1).toInt
+ lower.substring(0, lower.length-1).toInt
} else if (lower.endsWith("g")) {
- lower.substring(0, lower.length - 1).toInt * 1024
+ lower.substring(0, lower.length-1).toInt * 1024
} else if (lower.endsWith("t")) {
- lower.substring(0, lower.length - 1).toInt * 1024 * 1024
- } else {
- // no suffix, so it's just a number in bytes
+ lower.substring(0, lower.length-1).toInt * 1024 * 1024
+ } else { // no suffix, so it's just a number in bytes
(lower.toLong / 1024 / 1024).toInt
}
}
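As a quick illustration of the conversion rules in memoryStringToMb above, worked out by hand from the code (not taken from a test suite):

    // Expected results for a few representative SPARK_MEM-style inputs:
    MesosScheduler.memoryStringToMb("300m")       // 300
    MesosScheduler.memoryStringToMb("1g")         // 1024
    MesosScheduler.memoryStringToMb("512k")       // 0   (integer division rounds 512 KB down to 0 MB)
    MesosScheduler.memoryStringToMb("134217728")  // 128 (no suffix, so the value is taken as bytes)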
diff --git a/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala b/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala
new file mode 100644
index 0000000000..af2f80ea66
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala
@@ -0,0 +1,32 @@
+package spark.scheduler.mesos
+
+/**
+ * Information about a running task attempt.
+ */
+class TaskInfo(val taskId: String, val index: Int, val launchTime: Long, val host: String) {
+ var finishTime: Long = 0
+ var failed = false
+
+ def markSuccessful(time: Long = System.currentTimeMillis) {
+ finishTime = time
+ }
+
+ def markFailed(time: Long = System.currentTimeMillis) {
+ finishTime = time
+ failed = true
+ }
+
+ def finished: Boolean = finishTime != 0
+
+ def successful: Boolean = finished && !failed
+
+ def duration: Long = {
+ if (!finished) {
+ throw new UnsupportedOperationException("duration() called on unfinished tasks")
+ } else {
+ finishTime - launchTime
+ }
+ }
+
+ def timeRunning(currentTime: Long): Long = currentTime - launchTime
+}
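A minimal sketch of how the scheduler is expected to drive the new TaskInfo class; the task ID, timestamps and host below are made up for illustration:

    val info = new TaskInfo("tid-1", 0, launchTime = 1000L, host = "host-a")
    info.timeRunning(4000L)     // 3000 ms elapsed so far
    info.markSuccessful(5000L)  // record completion at t = 5000 ms
    info.successful             // true
    info.duration               // 4000 ms (finishTime - launchTime)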
diff --git a/core/src/main/scala/spark/SimpleJob.scala b/core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala
index 01c7efff1e..535c17d9d4 100644
--- a/core/src/main/scala/spark/SimpleJob.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala
@@ -1,28 +1,32 @@
-package spark
+package spark.scheduler.mesos
+import java.util.Arrays
import java.util.{HashMap => JHashMap}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.math.max
+import scala.math.min
import com.google.protobuf.ByteString
import org.apache.mesos._
-import org.apache.mesos.Protos._
+import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
+
+import spark._
+import spark.scheduler._
/**
- * A Job that runs a set of tasks with no interdependencies.
+ * Schedules the tasks within a single TaskSet in the MesosScheduler.
*/
-class SimpleJob(
+class TaskSetManager(
sched: MesosScheduler,
- tasksSeq: Seq[Task[_]],
- runId: Int,
- jobId: Int)
- extends Job(runId, jobId)
- with Logging {
+ val taskSet: TaskSet)
+ extends Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
- val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "5000").toLong
+ val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
// CPUs to request per task
val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
@@ -30,18 +34,20 @@ class SimpleJob(
// Maximum times a task is allowed to fail before failing the job
val MAX_TASK_FAILURES = 4
+ // Quantile of tasks at which to start speculation
+ val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
+ val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
+
// Serializer for closures and tasks.
val ser = SparkEnv.get.closureSerializer.newInstance()
- val callingThread = Thread.currentThread
- val tasks = tasksSeq.toArray
+ val priority = taskSet.priority
+ val tasks = taskSet.tasks
val numTasks = tasks.length
- val launched = new Array[Boolean](numTasks)
+ val copiesRunning = new Array[Int](numTasks)
val finished = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
- val tidToIndex = HashMap[String, Int]()
-
- var tasksLaunched = 0
+ val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
var tasksFinished = 0
// Last time when we launched a preferred task (for delay scheduling)
@@ -62,6 +68,13 @@ class SimpleJob(
// List containing all pending tasks (also used as a stack, as above)
val allPendingTasks = new ArrayBuffer[Int]
+ // Tasks that can be speculated. Since these will be a small fraction of total
+ // tasks, we'll just hold them in a HashSet.
+ val speculatableTasks = new HashSet[Int]
+
+ // Task index, start and finish time for each task attempt (indexed by task ID)
+ val taskInfos = new HashMap[String, TaskInfo]
+
// Did the job fail?
var failed = false
var causeOfFailure = ""
@@ -76,6 +89,12 @@ class SimpleJob(
// exceptions automatically.
val recentExceptions = HashMap[String, (Int, Long)]()
+ // Figure out the current map output tracker generation and set it on all tasks
+ val generation = sched.mapOutputTracker.getGeneration
+ for (t <- tasks) {
+ t.generation = generation
+ }
+
// Add all our tasks to the pending lists. We do this in reverse order
// of task index so that tasks with low indices get launched first.
for (i <- (0 until numTasks).reverse) {
@@ -84,7 +103,7 @@ class SimpleJob(
// Add a task to all the pending-task lists that it should be on.
def addPendingTask(index: Int) {
- val locations = tasks(index).preferredLocations
+ val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
if (locations.size == 0) {
pendingTasksWithNoPrefs += index
} else {
@@ -110,13 +129,37 @@ class SimpleJob(
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
- if (!launched(index) && !finished(index)) {
+ if (copiesRunning(index) == 0 && !finished(index)) {
return Some(index)
}
}
return None
}
+ // Return a speculative task for a given host if any are available. The task should not have an
+ // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
+ // task must have a preference for this host (or no preferred locations at all).
+ def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
+ speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
+ val localTask = speculatableTasks.find { index =>
+ val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
+ val attemptLocs = taskAttempts(index).map(_.host)
+ (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
+ }
+ if (localTask != None) {
+ speculatableTasks -= localTask.get
+ return localTask
+ }
+ if (!localOnly && speculatableTasks.size > 0) {
+ val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host))
+ if (nonLocalTask != None) {
+ speculatableTasks -= nonLocalTask.get
+ return nonLocalTask
+ }
+ }
+ return None
+ }
+
// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
def findTask(host: String, localOnly: Boolean): Option[Int] = {
@@ -129,10 +172,13 @@ class SimpleJob(
return noPrefTask
}
if (!localOnly) {
- return findTaskFromList(allPendingTasks) // Look for non-local task
- } else {
- return None
+ val nonLocalTask = findTaskFromList(allPendingTasks)
+ if (nonLocalTask != None) {
+ return nonLocalTask
+ }
}
+ // Finally, if all else has failed, find a speculative task
+ return findSpeculativeTask(host, localOnly)
}
// Does a host count as a preferred location for a task? This is true if
@@ -144,11 +190,11 @@ class SimpleJob(
}
// Respond to an offer of a single slave from the scheduler by finding a task
- def slaveOffer(offer: Offer, availableCpus: Double): Option[TaskInfo] = {
- if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK) {
+ def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[MTaskInfo] = {
+ if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
val time = System.currentTimeMillis
- val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
- val host = offer.getHostname
+ var localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
+
findTask(host, localOnly) match {
case Some(index) => {
// Found a task; do some bookkeeping and return a Mesos task for it
@@ -156,17 +202,17 @@ class SimpleJob(
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
val preferred = isPreferredLocation(task, host)
- val prefStr = if(preferred) "preferred" else "non-preferred"
- val message =
- "Starting task %d:%d as TID %s on slave %s: %s (%s)".format(
- jobId, index, taskId.getValue, offer.getSlaveId.getValue, host, prefStr)
- logInfo(message)
+ val prefStr = if (preferred) "preferred" else "non-preferred"
+ logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
+ taskSet.id, index, taskId.getValue, slaveId, host, prefStr))
// Do various bookkeeping
- tidToIndex(taskId.getValue) = index
- launched(index) = true
- tasksLaunched += 1
- if (preferred)
+ copiesRunning(index) += 1
+ val info = new TaskInfo(taskId.getValue, index, time, host)
+ taskInfos(taskId.getValue) = info
+ taskAttempts(index) = info :: taskAttempts(index)
+ if (preferred) {
lastPreferredLaunchTime = time
+ }
// Create and return the Mesos task object
val cpuRes = Resource.newBuilder()
.setName("cpus")
@@ -178,13 +224,13 @@ class SimpleJob(
val serializedTask = ser.serialize(task)
val timeTaken = System.currentTimeMillis - startTime
- logInfo("Size of task %d:%d is %d bytes and took %d ms to serialize by %s"
- .format(jobId, index, serializedTask.size, timeTaken, ser.getClass.getName))
+ logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
+ taskSet.id, index, serializedTask.limit, timeTaken))
- val taskName = "task %d:%d".format(jobId, index)
- return Some(TaskInfo.newBuilder()
+ val taskName = "task %s:%d".format(taskSet.id, index)
+ return Some(MTaskInfo.newBuilder()
.setTaskId(taskId)
- .setSlaveId(offer.getSlaveId)
+ .setSlaveId(SlaveID.newBuilder().setValue(slaveId))
.setExecutor(sched.executorInfo)
.setName(taskName)
.addResources(cpuRes)
@@ -213,18 +259,21 @@ class SimpleJob(
def taskFinished(status: TaskStatus) {
val tid = status.getTaskId.getValue
- val index = tidToIndex(tid)
+ val info = taskInfos(tid)
+ val index = info.index
+ info.markSuccessful()
if (!finished(index)) {
tasksFinished += 1
- logInfo("Finished TID %s (progress: %d/%d)".format(tid, tasksFinished, numTasks))
- // Deserialize task result
- val result = ser.deserialize[TaskResult[_]](
- status.getData.toByteArray, getClass.getClassLoader)
- sched.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
+ logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
+ tid, info.duration, tasksFinished, numTasks))
+ // Deserialize task result and pass it to the scheduler
+ val result = ser.deserialize[TaskResult[_]](status.getData.asReadOnlyByteBuffer)
+ sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
// Mark finished and stop if we've finished all the tasks
finished(index) = true
- if (tasksFinished == numTasks)
- sched.jobFinished(this)
+ if (tasksFinished == numTasks) {
+ sched.taskSetFinished(this)
+ }
} else {
logInfo("Ignoring task-finished event for TID " + tid +
" because task " + index + " is already finished")
@@ -233,30 +282,29 @@ class SimpleJob(
def taskLost(status: TaskStatus) {
val tid = status.getTaskId.getValue
- val index = tidToIndex(tid)
+ val info = taskInfos(tid)
+ val index = info.index
+ info.markFailed()
if (!finished(index)) {
- logInfo("Lost TID %s (task %d:%d)".format(tid, jobId, index))
- launched(index) = false
- tasksLaunched -= 1
+ logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+ copiesRunning(index) -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
if (status.getData != null && status.getData.size > 0) {
- val reason = ser.deserialize[TaskEndReason](
- status.getData.toByteArray, getClass.getClassLoader)
+ val reason = ser.deserialize[TaskEndReason](status.getData.asReadOnlyByteBuffer)
reason match {
case fetchFailed: FetchFailed =>
- logInfo("Loss was due to fetch failure from " + fetchFailed.serverUri)
- sched.taskEnded(tasks(index), fetchFailed, null, null)
+ logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+ sched.listener.taskEnded(tasks(index), fetchFailed, null, null)
finished(index) = true
tasksFinished += 1
- if (tasksFinished == numTasks) {
- sched.jobFinished(this)
- }
+ sched.taskSetFinished(this)
return
+
case ef: ExceptionFailure =>
val key = ef.exception.toString
val now = System.currentTimeMillis
- val (printFull, dupCount) =
+ val (printFull, dupCount) = {
if (recentExceptions.contains(key)) {
val (dupCount, printTime) = recentExceptions(key)
if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
@@ -267,32 +315,28 @@ class SimpleJob(
(false, dupCount + 1)
}
} else {
- recentExceptions += Tuple(key, (0, now))
+ recentExceptions(key) = (0, now)
(true, 0)
}
-
+ }
if (printFull) {
- val stackTrace =
- for (elem <- ef.exception.getStackTrace)
- yield "\tat %s".format(elem.toString)
- logInfo("Loss was due to %s\n%s".format(
- ef.exception.toString, stackTrace.mkString("\n")))
+ val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString))
+ logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n")))
} else {
- logInfo("Loss was due to %s [duplicate %d]".format(
- ef.exception.toString, dupCount))
+ logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount))
}
+
case _ => {}
}
}
- // On other failures, re-enqueue the task as pending for a max number of retries
+ // On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
- // Count attempts only on FAILED and LOST state (not on KILLED)
- if (status.getState == TaskState.TASK_FAILED ||
- status.getState == TaskState.TASK_LOST) {
+ // Count failed attempts only on FAILED and LOST state (not on KILLED)
+ if (status.getState == TaskState.TASK_FAILED || status.getState == TaskState.TASK_LOST) {
numFailures(index) += 1
if (numFailures(index) > MAX_TASK_FAILURES) {
- logError("Task %d:%d failed more than %d times; aborting job".format(
- jobId, index, MAX_TASK_FAILURES))
+ logError("Task %s:%d failed more than %d times; aborting job".format(
+ taskSet.id, index, MAX_TASK_FAILURES))
abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES))
}
}
@@ -311,6 +355,71 @@ class SimpleJob(
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.jobFinished(this)
+ sched.taskSetFinished(this)
+ }
+
+ def hostLost(hostname: String) {
+ logInfo("Re-queueing tasks for " + hostname)
+ // If some task has preferred locations only on hostname, put it in the no-prefs list
+ // to avoid the wait from delay scheduling
+ for (index <- getPendingTasksForHost(hostname)) {
+ val newLocs = tasks(index).preferredLocations.toSet & sched.hostsAlive
+ if (newLocs.isEmpty) {
+ pendingTasksWithNoPrefs += index
+ }
+ }
+ // Also re-enqueue any tasks that ran on the failed host if this is a shuffle map stage
+ if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+ for ((tid, info) <- taskInfos if info.host == hostname) {
+ val index = taskInfos(tid).index
+ if (finished(index)) {
+ finished(index) = false
+ copiesRunning(index) -= 1
+ tasksFinished -= 1
+ addPendingTask(index)
+ // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
+ // stage finishes when a total of tasks.size tasks finish.
+ sched.listener.taskEnded(tasks(index), Resubmitted, null, null)
+ }
+ }
+ }
+ }
+
+ /**
+ * Check for tasks to be speculated and return true if there are any. This is called periodically
+ * by the MesosScheduler.
+ *
+ * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
+ * we don't scan the whole task set. It might also help to make this sorted by launch time.
+ */
+ def checkSpeculatableTasks(): Boolean = {
+ // Can't speculate if we only have one task, or if all tasks have finished.
+ if (numTasks == 1 || tasksFinished == numTasks) {
+ return false
+ }
+ var foundTasks = false
+ val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
+ logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
+ if (tasksFinished >= minFinishedForSpeculation) {
+ val time = System.currentTimeMillis()
+ val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
+ Arrays.sort(durations)
+ val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
+ val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
+ // TODO: Threshold should also look at standard deviation of task durations and have a lower
+ // bound based on that.
+ logDebug("Task length threshold for speculation: " + threshold)
+ for ((tid, info) <- taskInfos) {
+ val index = info.index
+ if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+ !speculatableTasks.contains(index)) {
+ logInfo("Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
+ taskSet.id, index, info.host, threshold))
+ speculatableTasks += index
+ foundTasks = true
+ }
+ }
+ }
+ return foundTasks
}
}
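To make the speculation check above concrete, here is a rough worked example with hypothetical numbers, using the default spark.speculation.quantile (0.75) and spark.speculation.multiplier (1.5):

    val numTasks = 100
    val minFinishedForSpeculation = (0.75 * numTasks).floor.toInt  // 75 tasks must finish before speculation starts
    val medianDuration = 2000.0                                    // suppose the median successful task took 2000 ms
    val threshold = math.max(1.5 * medianDuration, 100)            // 3000 ms
    // Any task still running its single attempt after more than ~3000 ms becomes a
    // candidate in speculatableTasks and may be offered to another host.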
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
new file mode 100644
index 0000000000..367c79dd76
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -0,0 +1,507 @@
+package spark.storage
+
+import java.io._
+import java.nio._
+import java.nio.channels.FileChannel.MapMode
+import java.util.{HashMap => JHashMap}
+import java.util.LinkedHashMap
+import java.util.UUID
+import java.util.Collections
+
+import scala.actors._
+import scala.actors.Actor._
+import scala.actors.Future
+import scala.actors.Futures.future
+import scala.actors.remote._
+import scala.actors.remote.RemoteActor._
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
+
+import it.unimi.dsi.fastutil.io._
+
+import spark.CacheTracker
+import spark.Logging
+import spark.Serializer
+import spark.SizeEstimator
+import spark.SparkEnv
+import spark.SparkException
+import spark.Utils
+import spark.network._
+
+class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
+ def this() = this(null, 0)
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeUTF(ip)
+ out.writeInt(port)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ ip = in.readUTF()
+ port = in.readInt()
+ }
+
+ override def toString = "BlockManagerId(" + ip + ", " + port + ")"
+
+ override def hashCode = ip.hashCode * 41 + port
+
+ override def equals(that: Any) = that match {
+ case id: BlockManagerId => port == id.port && ip == id.ip
+ case _ => false
+ }
+}
+
+
+case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message)
+
+
+class BlockLocker(numLockers: Int) {
+ private val hashLocker = Array.fill(numLockers)(new Object())
+
+ def getLock(blockId: String): Object = {
+ return hashLocker(Math.abs(blockId.hashCode % numLockers))
+ }
+}
+
+
+/**
+ * A start towards a block manager class. This will eventually be used for both RDD persistence
+ * and shuffle outputs.
+ *
+ * TODO: Make the communication with the master and peers more robust and log-friendly.
+ */
+class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging {
+
+ private val NUM_LOCKS = 337
+ private val locker = new BlockLocker(NUM_LOCKS)
+
+ private val storageLevels = Collections.synchronizedMap(new JHashMap[String, StorageLevel])
+
+ private val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
+ private val diskStore: BlockStore = new DiskStore(this,
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+
+ val connectionManager = new ConnectionManager(0)
+
+ val connectionManagerId = connectionManager.id
+ val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port)
+
+ // TODO(Haoyuan): This will be removed after cacheTracker is removed from the code base.
+ var cacheTracker: CacheTracker = null
+
+ initLogging()
+
+ initialize()
+
+ /**
+ * Construct a BlockManager with a memory limit set based on system properties.
+ */
+ def this(serializer: Serializer) =
+ this(BlockManager.getMaxMemoryFromSystemProperties(), serializer)
+
+ /**
+ * Initialize the BlockManager. Register with the BlockManagerMaster, and start the
+ * BlockManagerWorker actor.
+ */
+ def initialize() {
+ BlockManagerMaster.mustRegisterBlockManager(
+ RegisterBlockManager(blockManagerId, maxMemory, maxMemory))
+ BlockManagerWorker.startBlockManagerWorker(this)
+ }
+
+ /**
+ * Get locations of the block.
+ */
+ def getLocations(blockId: String): Seq[String] = {
+ val startTimeMs = System.currentTimeMillis
+ var managers: Array[BlockManagerId] = BlockManagerMaster.mustGetLocations(GetLocations(blockId))
+ val locations = managers.map((manager: BlockManagerId) => { manager.ip }).toSeq
+ logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs))
+ return locations
+ }
+
+ /**
+ * Get locations of an array of blocks
+ */
+ def getLocationsMultipleBlockIds(blockIds: Array[String]): Array[Seq[String]] = {
+ val startTimeMs = System.currentTimeMillis
+ val locations = BlockManagerMaster.mustGetLocationsMultipleBlockIds(
+ GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray
+ logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
+ return locations
+ }
+
+ def getLocal(blockId: String): Option[Iterator[Any]] = {
+ logDebug("Getting block " + blockId)
+ locker.getLock(blockId).synchronized {
+
+ // Check storage level of block
+ val level = storageLevels.get(blockId)
+ if (level != null) {
+ logDebug("Level for block " + blockId + " is " + level + " on local machine")
+
+ // Look for the block in memory
+ if (level.useMemory) {
+ logDebug("Getting block " + blockId + " from memory")
+ memoryStore.getValues(blockId) match {
+ case Some(iterator) => {
+ logDebug("Block " + blockId + " found in memory")
+ return Some(iterator)
+ }
+ case None => {
+ logDebug("Block " + blockId + " not found in memory")
+ }
+ }
+ } else {
+ logDebug("Not getting block " + blockId + " from memory")
+ }
+
+ // Look for the block on disk
+ if (level.useDisk) {
+ logDebug("Getting block " + blockId + " from disk")
+ diskStore.getValues(blockId) match {
+ case Some(iterator) => {
+ logDebug("Block " + blockId + " found in disk")
+ return Some(iterator)
+ }
+ case None => {
+ throw new Exception("Block " + blockId + " not found in disk")
+ return None
+ }
+ }
+ } else {
+ logDebug("Not getting block " + blockId + " from disk")
+ }
+
+ } else {
+ logDebug("Level for block " + blockId + " not found")
+ }
+ }
+ return None
+ }
+
+ def getRemote(blockId: String): Option[Iterator[Any]] = {
+ // Get locations of block
+ val locations = BlockManagerMaster.mustGetLocations(GetLocations(blockId))
+
+ // Get block from remote locations
+ for (loc <- locations) {
+ val data = BlockManagerWorker.syncGetBlock(
+ GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
+ if (data != null) {
+ logDebug("Data is not null: " + data)
+ return Some(dataDeserialize(data))
+ }
+ logDebug("Data is null")
+ }
+ logDebug("Data not found")
+ return None
+ }
+
+ /**
+ * Read a block from the block manager.
+ */
+ def get(blockId: String): Option[Iterator[Any]] = {
+ getLocal(blockId).orElse(getRemote(blockId))
+ }
+
+ /**
+ * Read many blocks from block manager using their BlockManagerIds.
+ */
+ def get(blocksByAddress: Seq[(BlockManagerId, Seq[String])]): HashMap[String, Option[Iterator[Any]]] = {
+ logDebug("Getting " + blocksByAddress.map(_._2.size).sum + " blocks")
+ var startTime = System.currentTimeMillis
+ val blocks = new HashMap[String,Option[Iterator[Any]]]()
+ val localBlockIds = new ArrayBuffer[String]()
+ val remoteBlockIds = new ArrayBuffer[String]()
+ val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]()
+
+ // Split local and remote blocks
+ for ((address, blockIds) <- blocksByAddress) {
+ if (address == blockManagerId) {
+ localBlockIds ++= blockIds
+ } else {
+ remoteBlockIds ++= blockIds
+ remoteBlockIdsPerLocation(address) = blockIds
+ }
+ }
+
+ // Start getting remote blocks
+ val remoteBlockFutures = remoteBlockIdsPerLocation.toSeq.map { case (bmId, bIds) =>
+ val cmId = ConnectionManagerId(bmId.ip, bmId.port)
+ val blockMessages = bIds.map(bId => BlockMessage.fromGetBlock(GetBlock(bId)))
+ val blockMessageArray = new BlockMessageArray(blockMessages)
+ val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
+ (cmId, future)
+ }
+ logDebug("Started remote gets for " + remoteBlockIds.size + " blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+
+ // Get the local blocks while remote blocks are being fetched
+ startTime = System.currentTimeMillis
+ localBlockIds.foreach(id => {
+ get(id) match {
+ case Some(block) => {
+ blocks.update(id, Some(block))
+ logDebug("Got local block " + id)
+ }
+ case None => {
+ throw new BlockException(id, "Could not get block " + id + " from local machine")
+ }
+ }
+ })
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+
+ // wait for and gather all the remote blocks
+ for ((cmId, future) <- remoteBlockFutures) {
+ var count = 0
+ val oneBlockId = remoteBlockIdsPerLocation(new BlockManagerId(cmId.host, cmId.port)).first
+ future() match {
+ case Some(message) => {
+ val bufferMessage = message.asInstanceOf[BufferMessage]
+ val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+ blockMessageArray.foreach(blockMessage => {
+ if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
+ throw new BlockException(oneBlockId, "Unexpected message received from " + cmId)
+ }
+ val buffer = blockMessage.getData()
+ val blockId = blockMessage.getId()
+ val block = dataDeserialize(buffer)
+ blocks.update(blockId, Some(block))
+ logDebug("Got remote block " + blockId + " in " + Utils.getUsedTimeMs(startTime))
+ count += 1
+ })
+ }
+ case None => {
+ throw new BlockException(oneBlockId, "Could not get blocks from " + cmId)
+ }
+ }
+ logDebug("Got remote " + count + " blocks from " + cmId.host + " in " + Utils.getUsedTimeMs(startTime) + " ms")
+ }
+
+ logDebug("Got all blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ return blocks
+ }
+
+ /**
+ * Write a new block to the block manager.
+ */
+ def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean = true) {
+ if (!level.useDisk && !level.useMemory) {
+ throw new IllegalArgumentException("Storage level has neither useMemory nor useDisk set")
+ }
+
+ val startTimeMs = System.currentTimeMillis
+ var bytes: ByteBuffer = null
+
+ locker.getLock(blockId).synchronized {
+ logDebug("Put for block " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ + " to get into synchronized block")
+
+ // Check and warn if block with same id already exists
+ if (storageLevels.get(blockId) != null) {
+ logWarning("Block " + blockId + " already exists in local machine")
+ return
+ }
+
+ // Store the storage level
+ storageLevels.put(blockId, level)
+
+ if (level.useMemory && level.useDisk) {
+ // If saving to both memory and disk, then serialize only once
+ memoryStore.putValues(blockId, values, level) match {
+ case Left(newValues) =>
+ diskStore.putValues(blockId, newValues, level) match {
+ case Right(newBytes) => bytes = newBytes
+ case _ => throw new Exception("Unexpected return value")
+ }
+ case Right(newBytes) =>
+ bytes = newBytes
+ diskStore.putBytes(blockId, newBytes, level)
+ }
+ } else if (level.useMemory) {
+ // If saving only to memory
+ memoryStore.putValues(blockId, values, level) match {
+ case Right(newBytes) => bytes = newBytes
+ case _ =>
+ }
+ } else {
+ // If saving only to disk
+ diskStore.putValues(blockId, values, level) match {
+ case Right(newBytes) => bytes = newBytes
+ case _ => throw new Exception("Unexpected return value")
+ }
+ }
+
+ if (tellMaster) {
+ notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0))
+ logDebug("Put block " + blockId + " after notifying the master " + Utils.getUsedTimeMs(startTimeMs))
+ }
+ }
+
+ // Replicate block if required
+ if (level.replication > 1) {
+ if (bytes == null) {
+ bytes = dataSerialize(values) // serialize the block if not already done
+ }
+ replicate(blockId, bytes, level)
+ }
+
+ // TODO(Haoyuan): This code will be removed when CacheTracker is gone.
+ if (blockId.startsWith("rdd")) {
+ notifyTheCacheTracker(blockId)
+ }
+ logDebug("Put block " + blockId + " after notifying the CacheTracker " + Utils.getUsedTimeMs(startTimeMs))
+ }
+
+
+ def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
+ val startTime = System.currentTimeMillis
+ if (!level.useDisk && !level.useMemory) {
+ throw new IllegalArgumentException("Storage level has neither useMemory nor useDisk set")
+ } else if (level.deserialized) {
+ throw new IllegalArgumentException("Storage level cannot have deserialized when putBytes is used")
+ }
+ val replicationFuture = if (level.replication > 1) {
+ future {
+ replicate(blockId, bytes, level)
+ }
+ } else {
+ null
+ }
+
+ locker.getLock(blockId).synchronized {
+ logDebug("PutBytes for block " + blockId + " used " + Utils.getUsedTimeMs(startTime)
+ + " to get into synchronized block")
+ if (storageLevels.get(blockId) != null) {
+ logWarning("Block " + blockId + " already exists")
+ return
+ }
+ storageLevels.put(blockId, level)
+
+ if (level.useMemory) {
+ memoryStore.putBytes(blockId, bytes, level)
+ }
+ if (level.useDisk) {
+ diskStore.putBytes(blockId, bytes, level)
+ }
+ if (tellMaster) {
+ notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0))
+ }
+ }
+
+ if (blockId.startsWith("rdd")) {
+ notifyTheCacheTracker(blockId)
+ }
+
+ if (level.replication > 1) {
+ if (replicationFuture == null) {
+ throw new Exception("Unexpected")
+ }
+ replicationFuture()
+ }
+
+ val finishTime = System.currentTimeMillis
+ if (level.replication > 1) {
+ logDebug("PutBytes with replication took " + (finishTime - startTime) + " ms")
+ } else {
+ logDebug("PutBytes without replication took " + (finishTime - startTime) + " ms")
+ }
+
+ }
+
+ private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
+ val tLevel: StorageLevel =
+ new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
+ var peers: Array[BlockManagerId] = BlockManagerMaster.mustGetPeers(
+ GetPeers(blockManagerId, level.replication - 1))
+ for (peer: BlockManagerId <- peers) {
+ val start = System.nanoTime
+ logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is "
+ + data.array().length + " Bytes. To node: " + peer)
+ if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel),
+ new ConnectionManagerId(peer.ip, peer.port))) {
+ logError("Failed to call syncPutBlock to " + peer)
+ }
+ logDebug("Replicated BlockId " + blockId + " once used " +
+ (System.nanoTime - start) / 1e6 + " ms; The size of the data is " +
+ data.array().length + " bytes.")
+ }
+ }
+
+ // TODO(Haoyuan): This code will be removed when CacheTracker is gone.
+ def notifyTheCacheTracker(key: String) {
+ val rddInfo = key.split(":")
+ val rddId: Int = rddInfo(1).toInt
+ val splitIndex: Int = rddInfo(2).toInt
+ val host = System.getProperty("spark.hostname", Utils.localHostName)
+ cacheTracker.notifyTheCacheTrackerFromBlockManager(spark.AddedToCache(rddId, splitIndex, host))
+ }
+
+ /**
+ * Read a block consisting of a single object.
+ */
+ def getSingle(blockId: String): Option[Any] = {
+ get(blockId).map(_.next)
+ }
+
+ /**
+ * Write a block consisting of a single object.
+ */
+ def putSingle(blockId: String, value: Any, level: StorageLevel) {
+ put(blockId, Iterator(value), level)
+ }
+
+ /**
+ * Drop a block from memory (called when the memory store has reached its limit)
+ */
+ def dropFromMemory(blockId: String) {
+ locker.getLock(blockId).synchronized {
+ val level = storageLevels.get(blockId)
+ if (level == null) {
+ logWarning("Block " + blockId + " cannot be removed from memory as it does not exist")
+ return
+ }
+ if (!level.useMemory) {
+ logWarning("Block " + blockId + " cannot be removed from memory as it is not in memory")
+ return
+ }
+ memoryStore.remove(blockId)
+ if (!level.useDisk) {
+ storageLevels.remove(blockId)
+ } else {
+ val newLevel = level.clone
+ newLevel.useMemory = false
+ storageLevels.remove(blockId)
+ storageLevels.put(blockId, newLevel)
+ }
+ }
+ }
+
+ def dataSerialize(values: Iterator[Any]): ByteBuffer = {
+ /*serializer.newInstance().serializeMany(values)*/
+ val byteStream = new FastByteArrayOutputStream(4096)
+ serializer.newInstance().serializeStream(byteStream).writeAll(values).close()
+ byteStream.trim()
+ ByteBuffer.wrap(byteStream.array)
+ }
+
+ def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = {
+ /*serializer.newInstance().deserializeMany(bytes)*/
+ val ser = serializer.newInstance()
+ return ser.deserializeStream(new FastByteArrayInputStream(bytes.array())).toIterator
+ }
+
+ private def notifyMaster(heartBeat: HeartBeat) {
+ BlockManagerMaster.mustHeartBeat(heartBeat)
+ }
+}
+
+object BlockManager extends Logging {
+ def getMaxMemoryFromSystemProperties(): Long = {
+ val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
+ val bytes = (Runtime.getRuntime.totalMemory * memoryFraction).toLong
+ logInfo("Maximum memory to use: " + bytes)
+ bytes
+ }
+}
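A hedged usage sketch of the new BlockManager API. It assumes a serializer with a no-argument constructor (spark.JavaSerializer is used here) and that a BlockManagerMaster actor is already running so that initialize() can register:

    val bm = new BlockManager(new spark.JavaSerializer)
    // Memory-only, deserialized, no replication, following the
    // StorageLevel(useDisk, useMemory, deserialized, replication) constructor used above.
    val level = new StorageLevel(false, true, true, 1)
    bm.putSingle("test_block_0", Array(1, 2, 3), level)
    bm.getSingle("test_block_0")   // Some(Array(1, 2, 3)) while the block stays in memory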
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
new file mode 100644
index 0000000000..bd94c185e9
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -0,0 +1,516 @@
+package spark.storage
+
+import java.io._
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.util.Random
+
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
+import akka.util.duration._
+
+import spark.Logging
+import spark.Utils
+
+sealed trait ToBlockManagerMaster
+
+case class RegisterBlockManager(
+ blockManagerId: BlockManagerId,
+ maxMemSize: Long,
+ maxDiskSize: Long)
+ extends ToBlockManagerMaster
+
+class HeartBeat(
+ var blockManagerId: BlockManagerId,
+ var blockId: String,
+ var storageLevel: StorageLevel,
+ var deserializedSize: Long,
+ var size: Long)
+ extends ToBlockManagerMaster
+ with Externalizable {
+
+ def this() = this(null, null, null, 0, 0) // For deserialization only
+
+ override def writeExternal(out: ObjectOutput) {
+ blockManagerId.writeExternal(out)
+ out.writeUTF(blockId)
+ storageLevel.writeExternal(out)
+ out.writeInt(deserializedSize.toInt)
+ out.writeInt(size.toInt)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ blockManagerId = new BlockManagerId()
+ blockManagerId.readExternal(in)
+ blockId = in.readUTF()
+ storageLevel = new StorageLevel()
+ storageLevel.readExternal(in)
+ deserializedSize = in.readInt()
+ size = in.readInt()
+ }
+}
+
+object HeartBeat {
+ def apply(blockManagerId: BlockManagerId,
+ blockId: String,
+ storageLevel: StorageLevel,
+ deserializedSize: Long,
+ size: Long): HeartBeat = {
+ new HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size)
+ }
+
+
+ // For pattern-matching
+ def unapply(h: HeartBeat): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
+ Some((h.blockManagerId, h.blockId, h.storageLevel, h.deserializedSize, h.size))
+ }
+}
+
+case class GetLocations(
+ blockId: String)
+ extends ToBlockManagerMaster
+
+case class GetLocationsMultipleBlockIds(
+ blockIds: Array[String])
+ extends ToBlockManagerMaster
+
+case class GetPeers(
+ blockManagerId: BlockManagerId,
+ size: Int)
+ extends ToBlockManagerMaster
+
+case class RemoveHost(
+ host: String)
+ extends ToBlockManagerMaster
+
+class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
+ class BlockManagerInfo(
+ timeMs: Long,
+ maxMem: Long,
+ maxDisk: Long) {
+ private var lastSeenMs = timeMs
+ private var remainedMem = maxMem
+ private var remainedDisk = maxDisk
+ private val blocks = new HashMap[String, StorageLevel]
+
+ def updateLastSeenMs() {
+ lastSeenMs = System.currentTimeMillis() / 1000
+ }
+
+ def addBlock(blockId: String, storageLevel: StorageLevel, deserializedSize: Long, size: Long) =
+ synchronized {
+ updateLastSeenMs()
+
+ if (blocks.contains(blockId)) {
+ val oriLevel: StorageLevel = blocks(blockId)
+
+ if (oriLevel.deserialized) {
+ remainedMem += deserializedSize
+ }
+ if (oriLevel.useMemory) {
+ remainedMem += size
+ }
+ if (oriLevel.useDisk) {
+ remainedDisk += size
+ }
+ }
+
+ blocks += (blockId -> storageLevel)
+
+ if (storageLevel.deserialized) {
+ remainedMem -= deserializedSize
+ }
+ if (storageLevel.useMemory) {
+ remainedMem -= size
+ }
+ if (storageLevel.useDisk) {
+ remainedDisk -= size
+ }
+
+ if (!(storageLevel.deserialized || storageLevel.useMemory || storageLevel.useDisk)) {
+ blocks.remove(blockId)
+ }
+ }
+
+ def getLastSeenMs(): Long = {
+ return lastSeenMs
+ }
+
+ def getRemainedMem(): Long = {
+ return remainedMem
+ }
+
+ def getRemainedDisk(): Long = {
+ return remainedDisk
+ }
+
+ override def toString(): String = {
+ return "BlockManagerInfo " + timeMs + " " + remainedMem + " " + remainedDisk
+ }
+ }
+
+ private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerInfo]
+ private val blockIdMap = new HashMap[String, Pair[Int, HashSet[BlockManagerId]]]
+
+ initLogging()
+
+ def removeHost(host: String) {
+ logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.")
+ logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq)
+ val ip = host.split(":")(0)
+ val port = host.split(":")(1)
+ blockManagerInfo.remove(new BlockManagerId(ip, port.toInt))
+ logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq)
+ self.reply(true)
+ }
+
+ def receive = {
+ case RegisterBlockManager(blockManagerId, maxMemSize, maxDiskSize) =>
+ register(blockManagerId, maxMemSize, maxDiskSize)
+
+ case HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
+ heartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size)
+
+ case GetLocations(blockId) =>
+ getLocations(blockId)
+
+ case GetLocationsMultipleBlockIds(blockIds) =>
+ getLocationsMultipleBlockIds(blockIds)
+
+ case GetPeers(blockManagerId, size) =>
+ getPeers_Deterministic(blockManagerId, size)
+ /*getPeers(blockManagerId, size)*/
+
+ case RemoveHost(host) =>
+ removeHost(host)
+
+ case msg =>
+ logInfo("Got unknown msg: " + msg)
+ }
+
+ private def register(blockManagerId: BlockManagerId, maxMemSize: Long, maxDiskSize: Long) {
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " " + blockManagerId + " "
+ logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
+ logInfo("Got Register Msg from " + blockManagerId)
+ if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
+ logInfo("Got Register Msg from master node, don't register it")
+ } else {
+ blockManagerInfo += (blockManagerId -> new BlockManagerInfo(
+ System.currentTimeMillis() / 1000, maxMemSize, maxDiskSize))
+ }
+ logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs))
+ self.reply(true)
+ }
+
+ private def heartBeat(
+ blockManagerId: BlockManagerId,
+ blockId: String,
+ storageLevel: StorageLevel,
+ deserializedSize: Long,
+ size: Long) {
+
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " " + blockManagerId + " " + blockId + " "
+ logDebug("Got in heartBeat 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
+
+ if (blockId == null) {
+ blockManagerInfo(blockManagerId).updateLastSeenMs()
+ logDebug("Got in heartBeat 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+ self.reply(true)
+ }
+
+ blockManagerInfo(blockManagerId).addBlock(blockId, storageLevel, deserializedSize, size)
+ logDebug("Got in heartBeat 2" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+
+ var locations: HashSet[BlockManagerId] = null
+ if (blockIdMap.contains(blockId)) {
+ locations = blockIdMap(blockId)._2
+ } else {
+ locations = new HashSet[BlockManagerId]
+ blockIdMap += (blockId -> (storageLevel.replication, locations))
+ }
+ logDebug("Got in heartBeat 3" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+
+ if (storageLevel.deserialized || storageLevel.useDisk || storageLevel.useMemory) {
+ locations += blockManagerId
+ } else {
+ locations.remove(blockManagerId)
+ }
+ logDebug("Got in heartBeat 4" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+
+ if (locations.size == 0) {
+ blockIdMap.remove(blockId)
+ }
+
+ logDebug("Got in heartBeat 5" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+ self.reply(true)
+ }
+
+ private def getLocations(blockId: String) {
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " " + blockId + " "
+ logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
+ if (blockIdMap.contains(blockId)) {
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ res.appendAll(blockIdMap(blockId)._2)
+ logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at "
+ + Utils.getUsedTimeMs(startTimeMs))
+ self.reply(res.toSeq)
+ } else {
+ logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs))
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ self.reply(res)
+ }
+ }
+
+ private def getLocationsMultipleBlockIds(blockIds: Array[String]) {
+ def getLocations(blockId: String): Seq[BlockManagerId] = {
+ val tmp = blockId
+ logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp)
+ if (blockIdMap.contains(blockId)) {
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ res.appendAll(blockIdMap(blockId)._2)
+ logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq)
+ return res.toSeq
+ } else {
+ logDebug("Got in getLocationsMultipleBlockIds Sub 2 " + tmp)
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ return res.toSeq
+ }
+ }
+
+ logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq)
+ var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]]
+ for (blockId <- blockIds) {
+ res.append(getLocations(blockId))
+ }
+ logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq)
+ self.reply(res.toSeq)
+ }
+
+ private def getPeers(blockManagerId: BlockManagerId, size: Int) {
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " " + blockManagerId + " "
+ logDebug("Got in getPeers 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
+ var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ res.appendAll(peers)
+ res -= blockManagerId
+ val rand = new Random(System.currentTimeMillis())
+ logDebug("Got in getPeers 1" + tmp + Utils.getUsedTimeMs(startTimeMs))
+ while (res.length > size) {
+ res.remove(rand.nextInt(res.length))
+ }
+ logDebug("Got in getPeers 2" + tmp + Utils.getUsedTimeMs(startTimeMs))
+ self.reply(res.toSeq)
+ }
+
+ private def getPeers_Deterministic(blockManagerId: BlockManagerId, size: Int) {
+ val startTimeMs = System.currentTimeMillis()
+ var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+
+ val peersWithIndices = peers.zipWithIndex
+ val selfIndex = peersWithIndices.find(_._1 == blockManagerId).map(_._2).getOrElse(-1)
+ if (selfIndex == -1) {
+ throw new Exception("Self index for " + blockManagerId + " not found")
+ }
+
+ var index = selfIndex
+ while (res.size < size) {
+ index += 1
+ if (index == selfIndex) {
+ throw new Exception("More peer expected than available")
+ }
+ res += peers(index % peers.size)
+ }
+ val resStr = res.map(_.toString).reduceLeft(_ + ", " + _)
+ logDebug("Got peers for " + blockManagerId + " as [" + resStr + "]")
+ self.reply(res.toSeq)
+ }
+}
+
+object BlockManagerMaster extends Logging {
+ initLogging()
+
+ val AKKA_ACTOR_NAME: String = "BlockMasterManager"
+ val REQUEST_RETRY_INTERVAL_MS = 100
+ val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost")
+ val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt
+ val DEFAULT_MANAGER_IP: String = Utils.localHostName()
+ val DEFAULT_MANAGER_PORT: String = "10902"
+
+ implicit val TIME_OUT_SEC = Actor.Timeout(3000 millis)
+ var masterActor: ActorRef = null
+
+ def startBlockManagerMaster(isMaster: Boolean, isLocal: Boolean) {
+ if (isMaster) {
+ masterActor = actorOf(new BlockManagerMaster(isLocal))
+ remote.register(AKKA_ACTOR_NAME, masterActor)
+ logInfo("Registered BlockManagerMaster Actor: " + DEFAULT_MASTER_IP + ":" + DEFAULT_MASTER_PORT)
+ masterActor.start()
+ } else {
+ masterActor = remote.actorFor(AKKA_ACTOR_NAME, DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT)
+ }
+ }
+
+ def notifyADeadHost(host: String) {
+ (masterActor ? RemoveHost(host + ":" + DEFAULT_MANAGER_PORT)).as[Any] match {
+ case Some(true) =>
+ logInfo("Removed " + host + " successfully. @ notifyADeadHost")
+ case Some(oops) =>
+ logError("Failed @ notifyADeadHost: " + oops)
+ case None =>
+ logError("None @ notifyADeadHost.")
+ }
+ }
+
+ def mustRegisterBlockManager(msg: RegisterBlockManager) {
+ while (! syncRegisterBlockManager(msg)) {
+ logWarning("Failed to register " + msg)
+ Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+ }
+ }
+
+ def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = {
+ //val masterActor = RemoteActor.select(node, name)
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " msg " + msg + " "
+ logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+
+ (masterActor ? msg).as[Any] match {
+ case Some(true) =>
+ logInfo("BlockManager registered successfully @ syncRegisterBlockManager.")
+ logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+ return true
+ case Some(oops) =>
+ logError("Failed @ syncRegisterBlockManager: " + oops)
+ return false
+ case None =>
+ logError("None @ syncRegisterBlockManager.")
+ return false
+ }
+ }
+
+ def mustHeartBeat(msg: HeartBeat) {
+ while (! syncHeartBeat(msg)) {
+ logWarning("Failed to send heartbeat" + msg)
+ Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+ }
+ }
+
+ def syncHeartBeat(msg: HeartBeat): Boolean = {
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " msg " + msg + " "
+ logDebug("Got in syncHeartBeat " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs))
+
+ (masterActor ? msg).as[Any] match {
+ case Some(true) =>
+ logInfo("Heartbeat sent successfully.")
+ logDebug("Got in syncHeartBeat " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs))
+ return true
+ case Some(oops) =>
+ logError("Failed: " + oops)
+ return false
+ case None =>
+ logError("None.")
+ return false
+ }
+ }
+
+ def mustGetLocations(msg: GetLocations): Array[BlockManagerId] = {
+ var res: Array[BlockManagerId] = syncGetLocations(msg)
+ while (res == null) {
+ logInfo("Failed to get locations " + msg)
+ Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+ res = syncGetLocations(msg)
+ }
+ return res
+ }
+
+ def syncGetLocations(msg: GetLocations): Array[BlockManagerId] = {
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " msg " + msg + " "
+ logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+
+ (masterActor ? msg).as[Seq[BlockManagerId]] match {
+ case Some(arr) =>
+ logDebug("GetLocations successfully.")
+ logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+ val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ for (ele <- arr) {
+ res += ele
+ }
+ logDebug("Got in syncGetLocations 2 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+ return res.toArray
+ case None =>
+ logError("GetLocations call returned None.")
+ return null
+ }
+ }
+
+ def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
+ Seq[Seq[BlockManagerId]] = {
+ var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg)
+ while (res == null) {
+ logWarning("Failed to GetLocationsMultipleBlockIds " + msg)
+ Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+ res = syncGetLocationsMultipleBlockIds(msg)
+ }
+ return res
+ }
+
+ def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
+ Seq[Seq[BlockManagerId]] = {
+ val startTimeMs = System.currentTimeMillis
+ val tmp = " msg " + msg + " "
+ logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+
+ (masterActor ? msg).as[Any] match {
+ case Some(arr: Seq[Seq[BlockManagerId]]) =>
+ logDebug("GetLocationsMultipleBlockIds successfully: " + arr)
+ logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+ return arr
+ case Some(oops) =>
+ logError("Failed: " + oops)
+ return null
+ case None =>
+ logInfo("None.")
+ return null
+ }
+ }
+
+ def mustGetPeers(msg: GetPeers): Array[BlockManagerId] = {
+ var res: Array[BlockManagerId] = syncGetPeers(msg)
+ while ((res == null) || (res.length != msg.size)) {
+ logInfo("Failed to get peers " + msg)
+ Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+ res = syncGetPeers(msg)
+ }
+
+ return res
+ }
+
+ def syncGetPeers(msg: GetPeers): Array[BlockManagerId] = {
+ val startTimeMs = System.currentTimeMillis
+ val tmp = " msg " + msg + " "
+ logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+
+ (masterActor ? msg).as[Seq[BlockManagerId]] match {
+ case Some(arr) =>
+ logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+ val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ logInfo("GetPeers successfully: " + arr.length)
+ res.appendAll(arr)
+ logDebug("Got in syncGetPeers 2 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+ return res.toArray
+ case None =>
+ logError("GetPeers call returned None.")
+ return null
+ }
+ }
+}
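The deterministic peer selection in getPeers_Deterministic walks forward from the requesting manager's position in the peer array and wraps around. A standalone sketch of that walk, using plain strings in place of BlockManagerIds and omitting the "more peers expected than available" check:

    def pickPeers(peers: Array[String], self: String, size: Int): Seq[String] = {
      val selfIndex = peers.indexOf(self)
      // Take the next `size` peers after `self`, wrapping around the array.
      (1 to size).map(offset => peers((selfIndex + offset) % peers.length))
    }

    pickPeers(Array("A", "B", "C", "D"), "B", 2)   // Vector(C, D)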
diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
new file mode 100644
index 0000000000..a4cdbd8ddd
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
@@ -0,0 +1,142 @@
+package spark.storage
+
+import java.nio._
+
+import scala.actors._
+import scala.actors.Actor._
+import scala.actors.remote._
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.util.Random
+
+import spark.Logging
+import spark.Utils
+import spark.SparkEnv
+import spark.network._
+
+/**
+ * This should be changed to use an event model later.
+ */
+class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
+ initLogging()
+
+ blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
+
+ def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
+ logDebug("Handling message " + msg)
+ msg match {
+ case bufferMessage: BufferMessage => {
+ try {
+ logDebug("Handling as a buffer message " + bufferMessage)
+ val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage)
+ logDebug("Parsed as a block message array")
+ val responseMessages = blockMessages.map(processBlockMessage _).filter(_ != None).map(_.get)
+ /*logDebug("Processed block messages")*/
+ return Some(new BlockMessageArray(responseMessages).toBufferMessage)
+ } catch {
+ case e: Exception => logError("Exception handling buffer message: " + e.getMessage)
+ return None
+ }
+ }
+ case otherMessage: Any => {
+ logError("Unknown type message received: " + otherMessage)
+ return None
+ }
+ }
+ }
+
+ def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = {
+ blockMessage.getType() match {
+ case BlockMessage.TYPE_PUT_BLOCK => {
+ val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel())
+ logInfo("Received [" + pB + "]")
+ putBlock(pB.id, pB.data, pB.level)
+ return None
+ }
+ case BlockMessage.TYPE_GET_BLOCK => {
+ val gB = new GetBlock(blockMessage.getId())
+ logInfo("Received [" + gB + "]")
+ val buffer = getBlock(gB.id)
+ if (buffer == null) {
+ return None
+ }
+ return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer)))
+ }
+ case _ => return None
+ }
+ }
+
+ private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) {
+ val startTimeMs = System.currentTimeMillis()
+ logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes)
+ blockManager.putBytes(id, bytes, level)
+ logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
+ + " with data size: " + bytes.array().length)
+ }
+
+ private def getBlock(id: String): ByteBuffer = {
+ val startTimeMs = System.currentTimeMillis()
+ logDebug("Getblock " + id + " started from " + startTimeMs)
+ val block = blockManager.get(id)
+ val buffer = block match {
+ case Some(tValues) => {
+ val values = tValues.asInstanceOf[Iterator[Any]]
+ val buffer = blockManager.dataSerialize(values)
+ buffer
+ }
+ case None => {
+ null
+ }
+ }
+ logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
+ + " and got buffer " + buffer)
+ return buffer
+ }
+}
+
+object BlockManagerWorker extends Logging {
+ private var blockManagerWorker: BlockManagerWorker = null
+ private val DATA_TRANSFER_TIME_OUT_MS: Long = 500
+ private val REQUEST_RETRY_INTERVAL_MS: Long = 1000
+
+ initLogging()
+
+ def startBlockManagerWorker(manager: BlockManager) {
+ blockManagerWorker = new BlockManagerWorker(manager)
+ }
+
+ def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = {
+ val blockManager = blockManagerWorker.blockManager
+ val connectionManager = blockManager.connectionManager
+ val serializer = blockManager.serializer
+ val blockMessage = BlockMessage.fromPutBlock(msg)
+ val blockMessageArray = new BlockMessageArray(blockMessage)
+ val resultMessage = connectionManager.sendMessageReliablySync(
+ toConnManagerId, blockMessageArray.toBufferMessage())
+ return (resultMessage != None)
+ }
+
+ def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
+ val blockManager = blockManagerWorker.blockManager
+ val connectionManager = blockManager.connectionManager
+ val serializer = blockManager.serializer
+ val blockMessage = BlockMessage.fromGetBlock(msg)
+ val blockMessageArray = new BlockMessageArray(blockMessage)
+ val responseMessage = connectionManager.sendMessageReliablySync(
+ toConnManagerId, blockMessageArray.toBufferMessage())
+ responseMessage match {
+ case Some(message) => {
+ val bufferMessage = message.asInstanceOf[BufferMessage]
+ logDebug("Response message received " + bufferMessage)
+ BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => {
+ logDebug("Found " + blockMessage)
+ return blockMessage.getData
+ })
+ }
+ case None => logDebug("No response message received"); return null
+ }
+ return null
+ }
+}
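A minimal sketch of fetching a single block from a remote node with the helpers above, assuming a local BlockManager bm has been created and BlockManagerWorker.startBlockManagerWorker(bm) has run on both nodes; the host, port and block ID are hypothetical:

    val remoteId = new ConnectionManagerId("host-b", 9999)
    val bytes = BlockManagerWorker.syncGetBlock(GetBlock("test_block_0"), remoteId)
    val values = if (bytes != null) Some(bm.dataDeserialize(bytes)) else None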
diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala
new file mode 100644
index 0000000000..bb128dce7a
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockMessage.scala
@@ -0,0 +1,219 @@
+package spark.storage
+
+import java.nio._
+
+import scala.collection.mutable.StringBuilder
+import scala.collection.mutable.ArrayBuffer
+
+import spark._
+import spark.network._
+
+case class GetBlock(id: String)
+case class GotBlock(id: String, data: ByteBuffer)
+case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel)
+
+class BlockMessage() extends Logging {
+ // Un-initialized: typ = 0
+ // GetBlock: typ = 1
+ // GotBlock: typ = 2
+ // PutBlock: typ = 3
+ private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
+ private var id: String = null
+ private var data: ByteBuffer = null
+ private var level: StorageLevel = null
+
+ initLogging()
+
+ def set(getBlock: GetBlock) {
+ typ = BlockMessage.TYPE_GET_BLOCK
+ id = getBlock.id
+ }
+
+ def set(gotBlock: GotBlock) {
+ typ = BlockMessage.TYPE_GOT_BLOCK
+ id = gotBlock.id
+ data = gotBlock.data
+ }
+
+ def set(putBlock: PutBlock) {
+ typ = BlockMessage.TYPE_PUT_BLOCK
+ id = putBlock.id
+ data = putBlock.data
+ level = putBlock.level
+ }
+
+ def set(buffer: ByteBuffer) {
+ val startTime = System.currentTimeMillis
+ /*
+ println()
+ println("BlockMessage: ")
+ while(buffer.remaining > 0) {
+ print(buffer.get())
+ }
+ buffer.rewind()
+ println()
+ println()
+ */
+ typ = buffer.getInt()
+ val idLength = buffer.getInt()
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buffer.getChar()
+ }
+ id = idBuilder.toString()
+
+ logDebug("Set from buffer Result: " + typ + " " + id)
+ logDebug("Buffer position is " + buffer.position)
+ if (typ == BlockMessage.TYPE_PUT_BLOCK) {
+
+ val booleanInt = buffer.getInt()
+ val replication = buffer.getInt()
+ level = new StorageLevel(booleanInt, replication)
+
+ val dataLength = buffer.getInt()
+ data = ByteBuffer.allocate(dataLength)
+ if (dataLength != buffer.remaining) {
+ throw new Exception("Error parsing PutBlock buffer: expected " + dataLength +
+ " bytes of data but " + buffer.remaining + " remain")
+ }
+ data.put(buffer)
+ data.flip()
+ logDebug("Set from buffer Result 2: " + level + " " + data)
+ } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
+
+ val dataLength = buffer.getInt()
+ logDebug("Data length is "+ dataLength)
+ logDebug("Buffer position is " + buffer.position)
+ data = ByteBuffer.allocate(dataLength)
+ if (dataLength != buffer.remaining) {
+ throw new Exception("Error parsing GotBlock buffer: expected " + dataLength +
+ " bytes of data but " + buffer.remaining + " remain")
+ }
+ data.put(buffer)
+ data.flip()
+ logDebug("Set from buffer Result 3: " + data)
+ }
+
+ val finishTime = System.currentTimeMillis
+ logDebug("Converted " + id + " from bytebuffer in " + (finishTime - startTime) / 1000.0 + " s")
+ }
+
+ def set(bufferMsg: BufferMessage) {
+ val buffer = bufferMsg.buffers.apply(0)
+ buffer.clear()
+ set(buffer)
+ }
+
+ def getType(): Int = {
+ return typ
+ }
+
+ def getId(): String = {
+ return id
+ }
+
+ def getData(): ByteBuffer = {
+ return data
+ }
+
+ def getLevel(): StorageLevel = {
+ return level
+ }
+
+ def toBufferMessage(): BufferMessage = {
+ val startTime = System.currentTimeMillis
+ val buffers = new ArrayBuffer[ByteBuffer]()
+ var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2)
+ buffer.putInt(typ).putInt(id.length())
+ id.foreach((x: Char) => buffer.putChar(x))
+ buffer.flip()
+ buffers += buffer
+
+ if (typ == BlockMessage.TYPE_PUT_BLOCK) {
+ buffer = ByteBuffer.allocate(8).putInt(level.toInt()).putInt(level.replication)
+ buffer.flip()
+ buffers += buffer
+
+ buffer = ByteBuffer.allocate(4).putInt(data.remaining)
+ buffer.flip()
+ buffers += buffer
+
+ buffers += data
+ } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
+ buffer = ByteBuffer.allocate(4).putInt(data.remaining)
+ buffer.flip()
+ buffers += buffer
+
+ buffers += data
+ }
+
+ logDebug("Start to log buffers.")
+ buffers.foreach((x: ByteBuffer) => logDebug("" + x))
+ /*
+ println()
+ println("BlockMessage: ")
+ buffers.foreach(b => {
+ while(b.remaining > 0) {
+ print(b.get())
+ }
+ b.rewind()
+ })
+ println()
+ println()
+ */
+ val finishTime = System.currentTimeMillis
+ logDebug("Converted " + id + " to buffer message in " + (finishTime - startTime) / 1000.0 + " s")
+ return Message.createBufferMessage(buffers)
+ }
+
+ override def toString(): String = {
+ "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level +
+ ", data = " + (if (data != null) data.remaining.toString else "null") + "]"
+ }
+}
+
+object BlockMessage {
+ val TYPE_NON_INITIALIZED: Int = 0
+ val TYPE_GET_BLOCK: Int = 1
+ val TYPE_GOT_BLOCK: Int = 2
+ val TYPE_PUT_BLOCK: Int = 3
+
+ def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(bufferMessage)
+ newBlockMessage
+ }
+
+ def fromByteBuffer(buffer: ByteBuffer): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(buffer)
+ newBlockMessage
+ }
+
+ def fromGetBlock(getBlock: GetBlock): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(getBlock)
+ newBlockMessage
+ }
+
+ def fromGotBlock(gotBlock: GotBlock): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(gotBlock)
+ newBlockMessage
+ }
+
+ def fromPutBlock(putBlock: PutBlock): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(putBlock)
+ newBlockMessage
+ }
+
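+ // Manual test: round-trips a PutBlock message through toBufferMessage and set, then prints
+ // the id and storage level of the original and the reconstructed message.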
+ def main(args: Array[String]) {
+ val B = new BlockMessage()
+ B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.DISK_AND_MEMORY_2))
+ val bMsg = B.toBufferMessage()
+ val C = new BlockMessage()
+ C.set(bMsg)
+
+ println(B.getId() + " " + B.getLevel())
+ println(C.getId() + " " + C.getLevel())
+ }
+}
diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala
new file mode 100644
index 0000000000..5f411d3488
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockMessageArray.scala
@@ -0,0 +1,140 @@
+package spark.storage
+import java.nio._
+
+import scala.collection.mutable.StringBuilder
+import scala.collection.mutable.ArrayBuffer
+
+import spark._
+import spark.network._
+
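+// Wraps a sequence of BlockMessages so that they can be sent as a single BufferMessage; each
+// message is framed on the wire as an Int size followed by the message's own buffers.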
+class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging {
+
+ def this(bm: BlockMessage) = this(Array(bm))
+
+ def this() = this(null.asInstanceOf[Seq[BlockMessage]])
+
+ def apply(i: Int) = blockMessages(i)
+
+ def iterator = blockMessages.iterator
+
+ def length = blockMessages.length
+
+ initLogging()
+
+ def set(bufferMessage: BufferMessage) {
+ val startTime = System.currentTimeMillis
+ val newBlockMessages = new ArrayBuffer[BlockMessage]()
+ val buffer = bufferMessage.buffers(0)
+ buffer.clear()
+ /*
+ println()
+ println("BlockMessageArray: ")
+ while(buffer.remaining > 0) {
+ print(buffer.get())
+ }
+ buffer.rewind()
+ println()
+ println()
+ */
+ while(buffer.remaining() > 0) {
+ val size = buffer.getInt()
+ logDebug("Creating block message of size " + size + " bytes")
+ val newBuffer = buffer.slice()
+ newBuffer.clear()
+ newBuffer.limit(size)
+ logDebug("Trying to convert buffer " + newBuffer + " to block message")
+ val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer)
+ logDebug("Created " + newBlockMessage)
+ newBlockMessages += newBlockMessage
+ buffer.position(buffer.position() + size)
+ }
+ val finishTime = System.currentTimeMillis
+ logDebug("Converted block message array from buffer message in " + (finishTime - startTime) / 1000.0 + " s")
+ this.blockMessages = newBlockMessages
+ }
+
+ def toBufferMessage(): BufferMessage = {
+ val buffers = new ArrayBuffer[ByteBuffer]()
+
+ blockMessages.foreach(blockMessage => {
+ val bufferMessage = blockMessage.toBufferMessage
+ logDebug("Adding " + blockMessage)
+ val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size)
+ sizeBuffer.flip
+ buffers += sizeBuffer
+ buffers ++= bufferMessage.buffers
+ logDebug("Added " + bufferMessage)
+ })
+
+ logDebug("Buffer list:")
+ buffers.foreach((x: ByteBuffer) => logDebug("" + x))
+ /*
+ println()
+ println("BlockMessageArray: ")
+ buffers.foreach(b => {
+ while(b.remaining > 0) {
+ print(b.get())
+ }
+ b.rewind()
+ })
+ println()
+ println()
+ */
+ return Message.createBufferMessage(buffers)
+ }
+}
+
+object BlockMessageArray {
+
+ def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
+ val newBlockMessageArray = new BlockMessageArray()
+ newBlockMessageArray.set(bufferMessage)
+ newBlockMessageArray
+ }
+
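+ // Manual test: builds an array of alternating PutBlock and GetBlock messages, serializes it
+ // into one buffer, parses it back, and prints the reconstructed messages.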
+ def main(args: Array[String]) {
+ val blockMessages =
+ (0 until 10).map(i => {
+ if (i % 2 == 0) {
+ val buffer = ByteBuffer.allocate(100)
+ buffer.clear
+ BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY))
+ } else {
+ BlockMessage.fromGetBlock(GetBlock(i.toString))
+ }
+ })
+ val blockMessageArray = new BlockMessageArray(blockMessages)
+ println("Block message array created")
+
+ val bufferMessage = blockMessageArray.toBufferMessage
+ println("Converted to buffer message")
+
+ val totalSize = bufferMessage.size
+ val newBuffer = ByteBuffer.allocate(totalSize)
+ newBuffer.clear()
+ bufferMessage.buffers.foreach(buffer => {
+ newBuffer.put(buffer)
+ buffer.rewind()
+ })
+ newBuffer.flip
+ val newBufferMessage = Message.createBufferMessage(newBuffer)
+ println("Copied to new buffer message, size = " + newBufferMessage.size)
+
+ val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage)
+ println("Converted back to block message array")
+ newBlockMessageArray.foreach(blockMessage => {
+ blockMessage.getType() match {
+ case BlockMessage.TYPE_PUT_BLOCK => {
+ val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel())
+ println(pB)
+ }
+ case BlockMessage.TYPE_GET_BLOCK => {
+ val gB = new GetBlock(blockMessage.getId())
+ println(gB)
+ }
+ }
+ })
+ }
+}
+
+
diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala
new file mode 100644
index 0000000000..0584cc2d4f
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockStore.scala
@@ -0,0 +1,282 @@
+package spark.storage
+
+import spark.{Utils, Logging, Serializer, SizeEstimator}
+
+import scala.collection.mutable.ArrayBuffer
+
+import java.io.{File, RandomAccessFile}
+import java.nio.ByteBuffer
+import java.nio.channels.FileChannel.MapMode
+import java.util.{UUID, LinkedHashMap}
+import java.util.concurrent.Executors
+
+import it.unimi.dsi.fastutil.io._
+
+/**
+ * Abstract class to store blocks
+ */
+abstract class BlockStore(blockManager: BlockManager) extends Logging {
+ initLogging()
+
+ def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel)
+
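+ // Puts the given values, returning either an iterator over the stored values (when they are
+ // kept deserialized) or the serialized bytes produced as part of the put.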
+ def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer]
+
+ def getBytes(blockId: String): Option[ByteBuffer]
+
+ def getValues(blockId: String): Option[Iterator[Any]]
+
+ def remove(blockId: String)
+
+ def dataSerialize(values: Iterator[Any]): ByteBuffer = blockManager.dataSerialize(values)
+
+ def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = blockManager.dataDeserialize(bytes)
+}
+
+/**
+ * Class to store blocks in memory, either as deserialized arrays of objects or as serialized
+ * ByteBuffers, dropping least-recently-used blocks when more space is needed
+ */
+class MemoryStore(blockManager: BlockManager, maxMemory: Long)
+ extends BlockStore(blockManager) {
+
+ class Entry(var value: Any, val size: Long, val deserialized: Boolean)
+
+ private val memoryStore = new LinkedHashMap[String, Entry](32, 0.75f, true)
+ private var currentMemory = 0L
+
+ private val blockDropper = Executors.newSingleThreadExecutor()
+
+ def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ if (level.deserialized) {
+ bytes.rewind()
+ val values = dataDeserialize(bytes)
+ val elements = new ArrayBuffer[Any]
+ elements ++= values
+ val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
+ ensureFreeSpace(sizeEstimate)
+ val entry = new Entry(elements, sizeEstimate, true)
+ memoryStore.synchronized { memoryStore.put(blockId, entry) }
+ currentMemory += sizeEstimate
+ logDebug("Block " + blockId + " stored as values to memory")
+ } else {
+ val entry = new Entry(bytes, bytes.array().length, false)
+ ensureFreeSpace(bytes.array.length)
+ memoryStore.synchronized { memoryStore.put(blockId, entry) }
+ currentMemory += bytes.array().length
+ logDebug("Block " + blockId + " stored as " + bytes.array().length + " bytes to memory")
+ }
+ }
+
+ def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = {
+ if (level.deserialized) {
+ val elements = new ArrayBuffer[Any]
+ elements ++= values
+ val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
+ ensureFreeSpace(sizeEstimate)
+ val entry = new Entry(elements, sizeEstimate, true)
+ memoryStore.synchronized { memoryStore.put(blockId, entry) }
+ currentMemory += sizeEstimate
+ logDebug("Block " + blockId + " stored as values to memory")
+ return Left(elements.iterator)
+ } else {
+ val bytes = dataSerialize(values)
+ ensureFreeSpace(bytes.array().length)
+ val entry = new Entry(bytes, bytes.array().length, false)
+ memoryStore.synchronized { memoryStore.put(blockId, entry) }
+ currentMemory += bytes.array().length
+ logDebug("Block " + blockId + " stored as " + bytes.array.length + " bytes to memory")
+ return Right(bytes)
+ }
+ }
+
+ def getBytes(blockId: String): Option[ByteBuffer] = {
+ throw new UnsupportedOperationException("Not implemented")
+ }
+
+ def getValues(blockId: String): Option[Iterator[Any]] = {
+ val entry = memoryStore.synchronized { memoryStore.get(blockId) }
+ if (entry == null) {
+ return None
+ }
+ if (entry.deserialized) {
+ return Some(entry.value.asInstanceOf[ArrayBuffer[Any]].toIterator)
+ } else {
+ return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer]))
+ }
+ }
+
+ def remove(blockId: String) {
+ memoryStore.synchronized {
+ val entry = memoryStore.get(blockId)
+ if (entry != null) {
+ memoryStore.remove(blockId)
+ currentMemory -= entry.size
+ logDebug("Block " + blockId + " of size " + entry.size + " dropped from memory")
+ } else {
+ logWarning("Block " + blockId + " could not be removed as it does not exist")
+ }
+ }
+ }
+
+ private def drop(blockId: String) {
+ blockDropper.submit(new Runnable() {
+ def run() {
+ blockManager.dropFromMemory(blockId)
+ }
+ })
+ }
+
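+ // Drops least-recently-used entries (the underlying LinkedHashMap is access-ordered) until the
+ // requested amount of space fits under maxMemory; the drops themselves happen asynchronously
+ // on the blockDropper thread.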
+ private def ensureFreeSpace(space: Long) {
+ logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
+ space, currentMemory, maxMemory))
+
+ val droppedBlockIds = new ArrayBuffer[String]()
+ var droppedMemory = 0L
+
+ memoryStore.synchronized {
+ val iter = memoryStore.entrySet().iterator()
+ while (maxMemory - (currentMemory - droppedMemory) < space && iter.hasNext) {
+ val pair = iter.next()
+ val blockId = pair.getKey
+ droppedBlockIds += blockId
+ droppedMemory += pair.getValue.size
+ logDebug("Decided to drop " + blockId)
+ }
+ }
+
+ for (blockId <- droppedBlockIds) {
+ drop(blockId)
+ }
+
+ droppedBlockIds.clear
+ }
+}
+
+
+/**
+ * Class to store blocks on disk, one file per block, spread across the configured root directories
+ */
+class DiskStore(blockManager: BlockManager, rootDirs: String)
+ extends BlockStore(blockManager) {
+
+ val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+ val localDirs = createLocalDirs()
+ var lastLocalDirUsed = 0
+
+ addShutdownHook()
+
+ def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ logDebug("Attempting to put block " + blockId)
+ val startTime = System.currentTimeMillis
+ val file = createFile(blockId)
+ if (file != null) {
+ val channel = new RandomAccessFile(file, "rw").getChannel()
+ val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.array.length)
+ buffer.put(bytes.array)
+ channel.close()
+ val finishTime = System.currentTimeMillis
+ logDebug("Block " + blockId + " stored to file of " + bytes.array.length + " bytes to disk in " + (finishTime - startTime) + " ms")
+ } else {
+ logError("File not created for block " + blockId)
+ }
+ }
+
+ def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = {
+ val bytes = dataSerialize(values)
+ logDebug("Converted block " + blockId + " to " + bytes.array.length + " bytes")
+ putBytes(blockId, bytes, level)
+ return Right(bytes)
+ }
+
+ def getBytes(blockId: String): Option[ByteBuffer] = {
+ val file = getFile(blockId)
+ val length = file.length().toInt
+ val channel = new RandomAccessFile(file, "r").getChannel()
+ val bytes = ByteBuffer.allocate(length)
+ // The file is opened read-only, so the mapping must be read-only as well
+ bytes.put(channel.map(MapMode.READ_ONLY, 0, length))
+ channel.close()
+ bytes.flip()
+ return Some(bytes)
+ }
+
+ def getValues(blockId: String): Option[Iterator[Any]] = {
+ val file = getFile(blockId)
+ val length = file.length().toInt
+ val channel = new RandomAccessFile(file, "r").getChannel()
+ val bytes = channel.map(MapMode.READ_ONLY, 0, length)
+ val buffer = dataDeserialize(bytes)
+ channel.close()
+ return Some(buffer)
+ }
+
+ def remove(blockId: String) {
+ throw new UnsupportedOperationException("Not implemented")
+ }
+
+ private def createFile(blockId: String): File = {
+ val file = getFile(blockId)
+ if (file == null) {
+ lastLocalDirUsed = (lastLocalDirUsed + 1) % localDirs.size
+ val newFile = new File(localDirs(lastLocalDirUsed), blockId)
+ newFile.getParentFile.mkdirs()
+ return newFile
+ } else {
+ logError("File for block " + blockId + " already exists on disk, " + file)
+ return null
+ }
+ }
+
+ private def getFile(blockId: String): File = {
+ logDebug("Getting file for block " + blockId)
+ // Search for the file in all the local directories, only one of them should have the file
+ val files = localDirs.map(localDir => new File(localDir, blockId)).filter(_.exists)
+ if (files.size > 1) {
+ throw new Exception("Multiple files for the same block " + blockId + " exist: " +
+ files.map(_.toString).reduceLeft(_ + ", " + _))
+ } else if (files.size == 0) {
+ return null
+ } else {
+ logDebug("Got file " + files(0) + " of size " + files(0).length + " bytes")
+ return files(0)
+ }
+ }
+
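+ // Creates a uniquely named "spark-local-<UUID>" subdirectory under each root directory
+ // (rootDirs may list several, separated by ',', ';' or ':'), retrying up to
+ // MAX_DIR_CREATION_ATTEMPTS times before giving up.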
+ private def createLocalDirs(): Seq[File] = {
+ logDebug("Creating local directories at root dirs '" + rootDirs + "'")
+ rootDirs.split("[;,:]").map(rootDir => {
+ var foundLocalDir: Boolean = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+ var tries = 0
+ while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID()
+ localDir = new File(rootDir, "spark-local-" + localDirUuid)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
+ " attempts to create local dir in " + rootDir)
+ System.exit(1)
+ }
+ logDebug("Created local directory at " + localDir)
+ localDir
+ })
+ }
+
+ private def addShutdownHook() {
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
+ override def run() {
+ logDebug("Shutdown hook called")
+ localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
+ }
+ })
+ }
+}
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
new file mode 100644
index 0000000000..a2833a7090
--- /dev/null
+++ b/core/src/main/scala/spark/storage/StorageLevel.scala
@@ -0,0 +1,78 @@
+package spark.storage
+
+import java.io._
+
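+/**
+ * Describes how a block should be stored: on disk and/or in memory, serialized or as
+ * deserialized objects, and with how many replicas. For network transfer, the three boolean
+ * flags are packed into a single integer bit field (disk = 4, memory = 2, deserialized = 1)
+ * alongside the replication count.
+ */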
+class StorageLevel(
+ var useDisk: Boolean,
+ var useMemory: Boolean,
+ var deserialized: Boolean,
+ var replication: Int = 1)
+ extends Externalizable {
+
+ // TODO: Also add fields for caching priority, dataset ID, and flushing.
+
+ def this(booleanInt: Int, replication: Int) {
+ this(((booleanInt & 4) != 0),
+ ((booleanInt & 2) != 0),
+ ((booleanInt & 1) != 0),
+ replication)
+ }
+
+ def this() = this(false, true, false) // For deserialization
+
+ override def clone(): StorageLevel = new StorageLevel(
+ this.useDisk, this.useMemory, this.deserialized, this.replication)
+
+ override def equals(other: Any): Boolean = other match {
+ case s: StorageLevel =>
+ s.useDisk == useDisk &&
+ s.useMemory == useMemory &&
+ s.deserialized == deserialized &&
+ s.replication == replication
+ case _ =>
+ false
+ }
+
+ // Keep hashCode consistent with equals so that levels behave correctly in hash-based collections
+ override def hashCode(): Int = toInt() * 41 + replication
+
+ def toInt(): Int = {
+ var ret = 0
+ if (useDisk) {
+ ret += 4
+ }
+ if (useMemory) {
+ ret += 2
+ }
+ if (deserialized) {
+ ret += 1
+ }
+ return ret
+ }
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeByte(toInt().toByte)
+ out.writeByte(replication.toByte)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ val flags = in.readByte()
+ useDisk = (flags & 4) != 0
+ useMemory = (flags & 2) != 0
+ deserialized = (flags & 1) != 0
+ replication = in.readByte()
+ }
+
+ override def toString(): String =
+ "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
+}
+
+object StorageLevel {
+ val NONE = new StorageLevel(false, false, false)
+ val DISK_ONLY = new StorageLevel(true, false, false)
+ val MEMORY_ONLY = new StorageLevel(false, true, false)
+ val MEMORY_ONLY_2 = new StorageLevel(false, true, false, 2)
+ val MEMORY_ONLY_DESER = new StorageLevel(false, true, true)
+ val MEMORY_ONLY_DESER_2 = new StorageLevel(false, true, true, 2)
+ val DISK_AND_MEMORY = new StorageLevel(true, true, false)
+ val DISK_AND_MEMORY_2 = new StorageLevel(true, true, false, 2)
+ val DISK_AND_MEMORY_DESER = new StorageLevel(true, true, true)
+ val DISK_AND_MEMORY_DESER_2 = new StorageLevel(true, true, true, 2)
+}
diff --git a/core/src/main/scala/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/spark/util/ByteBufferInputStream.scala
new file mode 100644
index 0000000000..abe2d99dd8
--- /dev/null
+++ b/core/src/main/scala/spark/util/ByteBufferInputStream.scala
@@ -0,0 +1,30 @@
+package spark.util
+
+import java.io.InputStream
+import java.nio.ByteBuffer
+
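+/**
+ * Reads data from a ByteBuffer as an InputStream, advancing the buffer's position as bytes are
+ * consumed.
+ */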
+class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream {
+ override def read(): Int = {
+ if (buffer.remaining() == 0) {
+ -1
+ } else {
+ // Mask to an unsigned value so that bytes >= 0x80 are not returned as negative ints
+ buffer.get() & 0xFF
+ }
+ }
+
+ override def read(dest: Array[Byte]): Int = {
+ read(dest, 0, dest.length)
+ }
+
+ override def read(dest: Array[Byte], offset: Int, length: Int): Int = {
+ if (buffer.remaining() == 0) {
+ // Signal end of stream, as the InputStream contract requires, instead of returning 0
+ return -1
+ }
+ val amountToGet = math.min(buffer.remaining(), length)
+ buffer.get(dest, offset, amountToGet)
+ return amountToGet
+ }
+
+ override def skip(bytes: Long): Long = {
+ val amountToSkip = math.min(bytes, buffer.remaining).toInt
+ buffer.position(buffer.position + amountToSkip)
+ return amountToSkip
+ }
+}
diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala
new file mode 100644
index 0000000000..efb1ae7529
--- /dev/null
+++ b/core/src/main/scala/spark/util/StatCounter.scala
@@ -0,0 +1,89 @@
+package spark.util
+
+/**
+ * A class for tracking the statistics of a set of numbers (count, mean and variance) in a
+ * numerically robust way. Includes support for merging two StatCounters. Based on Welford and
+ * Chan's algorithms described at http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
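+ *
+ * For example, StatCounter(1.0, 2.0, 3.0).merge(StatCounter(4.0, 5.0)) yields the same count,
+ * mean and variance as StatCounter(1.0, 2.0, 3.0, 4.0, 5.0), up to floating-point rounding.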
+ */
+class StatCounter(values: TraversableOnce[Double]) {
+ private var n: Long = 0 // Running count of our values
+ private var mu: Double = 0 // Running mean of our values
+ private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2)
+
+ merge(values)
+
+ def this() = this(Nil)
+
+ def merge(value: Double): StatCounter = {
+ val delta = value - mu
+ n += 1
+ mu += delta / n
+ m2 += delta * (value - mu)
+ this
+ }
+
+ def merge(values: TraversableOnce[Double]): StatCounter = {
+ values.foreach(v => merge(v))
+ this
+ }
+
+ def merge(other: StatCounter): StatCounter = {
+ if (other == this) {
+ merge(other.copy()) // Avoid overwriting fields in a weird order
+ } else {
+ val delta = other.mu - mu
+ if (other.n * 10 < n) {
+ mu = mu + (delta * other.n) / (n + other.n)
+ } else if (n * 10 < other.n) {
+ mu = other.mu - (delta * n) / (n + other.n)
+ } else {
+ mu = (mu * n + other.mu * other.n) / (n + other.n)
+ }
+ m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
+ n += other.n
+ this
+ }
+ }
+
+ def copy(): StatCounter = {
+ val other = new StatCounter
+ other.n = n
+ other.mu = mu
+ other.m2 = m2
+ other
+ }
+
+ def count: Long = n
+
+ def mean: Double = mu
+
+ def sum: Double = n * mu
+
+ def variance: Double = {
+ if (n == 0)
+ Double.NaN
+ else
+ m2 / n
+ }
+
+ def sampleVariance: Double = {
+ if (n <= 1)
+ Double.NaN
+ else
+ m2 / (n - 1)
+ }
+
+ def stdev: Double = math.sqrt(variance)
+
+ def sampleStdev: Double = math.sqrt(sampleVariance)
+
+ override def toString: String = {
+ "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev)
+ }
+}
+
+object StatCounter {
+ def apply(values: TraversableOnce[Double]) = new StatCounter(values)
+
+ def apply(values: Double*) = new StatCounter(values)
+}
diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala
index 60290d14ca..3d170a6e22 100644
--- a/core/src/test/scala/spark/CacheTrackerSuite.scala
+++ b/core/src/test/scala/spark/CacheTrackerSuite.scala
@@ -1,95 +1,103 @@
package spark
import org.scalatest.FunSuite
-import collection.mutable.HashMap
+
+import scala.collection.mutable.HashMap
+
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
class CacheTrackerSuite extends FunSuite {
test("CacheTrackerActor slave initialization & cache status") {
- System.setProperty("spark.master.port", "1345")
+ //System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
- val tracker = new CacheTrackerActor
+ val tracker = actorOf(new CacheTrackerActor)
tracker.start()
- tracker !? SlaveCacheStarted("host001", initialSize)
+ tracker !! SlaveCacheStarted("host001", initialSize)
- assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 0L)))
+ assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 0L)))
- tracker !? StopCacheTracker
+ tracker !! StopCacheTracker
}
test("RegisterRDD") {
- System.setProperty("spark.master.port", "1345")
+ //System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
- val tracker = new CacheTrackerActor
+ val tracker = actorOf(new CacheTrackerActor)
tracker.start()
- tracker !? SlaveCacheStarted("host001", initialSize)
+ tracker !! SlaveCacheStarted("host001", initialSize)
- tracker !? RegisterRDD(1, 3)
- tracker !? RegisterRDD(2, 1)
+ tracker !! RegisterRDD(1, 3)
+ tracker !! RegisterRDD(2, 1)
- assert(getCacheLocations(tracker) == Map(1 -> List(List(), List(), List()), 2 -> List(List())))
+ assert(getCacheLocations(tracker) === Map(1 -> List(List(), List(), List()), 2 -> List(List())))
- tracker !? StopCacheTracker
+ tracker !! StopCacheTracker
}
test("AddedToCache") {
- System.setProperty("spark.master.port", "1345")
+ //System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
- val tracker = new CacheTrackerActor
+ val tracker = actorOf(new CacheTrackerActor)
tracker.start()
- tracker !? SlaveCacheStarted("host001", initialSize)
+ tracker !! SlaveCacheStarted("host001", initialSize)
- tracker !? RegisterRDD(1, 2)
- tracker !? RegisterRDD(2, 1)
+ tracker !! RegisterRDD(1, 2)
+ tracker !! RegisterRDD(2, 1)
- tracker !? AddedToCache(1, 0, "host001", 2L << 15)
- tracker !? AddedToCache(1, 1, "host001", 2L << 11)
- tracker !? AddedToCache(2, 0, "host001", 3L << 10)
+ tracker !! AddedToCache(1, 0, "host001", 2L << 15)
+ tracker !! AddedToCache(1, 1, "host001", 2L << 11)
+ tracker !! AddedToCache(2, 0, "host001", 3L << 10)
- assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
+ assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L)))
- assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
+ assert(getCacheLocations(tracker) ===
+ Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
- tracker !? StopCacheTracker
+ tracker !! StopCacheTracker
}
test("DroppedFromCache") {
- System.setProperty("spark.master.port", "1345")
+ //System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
- val tracker = new CacheTrackerActor
+ val tracker = actorOf(new CacheTrackerActor)
tracker.start()
- tracker !? SlaveCacheStarted("host001", initialSize)
+ tracker !! SlaveCacheStarted("host001", initialSize)
- tracker !? RegisterRDD(1, 2)
- tracker !? RegisterRDD(2, 1)
+ tracker !! RegisterRDD(1, 2)
+ tracker !! RegisterRDD(2, 1)
- tracker !? AddedToCache(1, 0, "host001", 2L << 15)
- tracker !? AddedToCache(1, 1, "host001", 2L << 11)
- tracker !? AddedToCache(2, 0, "host001", 3L << 10)
+ tracker !! AddedToCache(1, 0, "host001", 2L << 15)
+ tracker !! AddedToCache(1, 1, "host001", 2L << 11)
+ tracker !! AddedToCache(2, 0, "host001", 3L << 10)
- assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
- assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
+ assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L)))
+ assert(getCacheLocations(tracker) ===
+ Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
- tracker !? DroppedFromCache(1, 1, "host001", 2L << 11)
+ tracker !! DroppedFromCache(1, 1, "host001", 2L << 11)
- assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 68608L)))
- assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
+ assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 68608L)))
+ assert(getCacheLocations(tracker) ===
+ Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
- tracker !? StopCacheTracker
+ tracker !! StopCacheTracker
}
/**
* Helper function to get cacheLocations from CacheTracker
*/
- def getCacheLocations(tracker: CacheTrackerActor) = tracker !? GetCacheLocations match {
+ def getCacheLocations(tracker: ActorRef) = (tracker ? GetCacheLocations).get match {
case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]].map {
case (i, arr) => (i -> arr.toList)
}
diff --git a/core/src/test/scala/spark/MesosSchedulerSuite.scala b/core/src/test/scala/spark/MesosSchedulerSuite.scala
index 0e6820cbdc..54421225d8 100644
--- a/core/src/test/scala/spark/MesosSchedulerSuite.scala
+++ b/core/src/test/scala/spark/MesosSchedulerSuite.scala
@@ -2,6 +2,8 @@ package spark
import org.scalatest.FunSuite
+import spark.scheduler.mesos.MesosScheduler
+
class MesosSchedulerSuite extends FunSuite {
test("memoryStringToMb"){
diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala
index f31251e509..1ac4737f04 100644
--- a/core/src/test/scala/spark/UtilsSuite.scala
+++ b/core/src/test/scala/spark/UtilsSuite.scala
@@ -2,7 +2,7 @@ package spark
import org.scalatest.FunSuite
import java.io.{ByteArrayOutputStream, ByteArrayInputStream}
-import util.Random
+import scala.util.Random
class UtilsSuite extends FunSuite {
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 08c5a990b4..a2faf7399c 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -33,6 +33,7 @@ object SparkBuild extends Build {
"org.scalatest" %% "scalatest" % "1.6.1" % "test",
"org.scala-tools.testing" %% "scalacheck" % "1.9" % "test"
),
+ parallelExecution in Test := false,
/* Workaround for issue #206 (fixed after SBT 0.11.0) */
watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task,
const(std.TaskExtra.constant(Nil)), aggregate = true, includeRoot = true) apply { _.join.map(_.flatten) }
@@ -57,8 +58,12 @@ object SparkBuild extends Build {
"asm" % "asm-all" % "3.3.1",
"com.google.protobuf" % "protobuf-java" % "2.4.1",
"de.javakaffee" % "kryo-serializers" % "0.9",
+ "se.scalablesolutions.akka" % "akka-actor" % "1.3.1",
+ "se.scalablesolutions.akka" % "akka-remote" % "1.3.1",
+ "se.scalablesolutions.akka" % "akka-slf4j" % "1.3.1",
"org.jboss.netty" % "netty" % "3.2.6.Final",
- "it.unimi.dsi" % "fastutil" % "6.4.2"
+ "it.unimi.dsi" % "fastutil" % "6.4.4",
+ "colt" % "colt" % "1.2.0"
)
) ++ assemblySettings ++ Seq(test in assembly := {})
@@ -68,8 +73,7 @@ object SparkBuild extends Build {
) ++ assemblySettings ++ Seq(test in assembly := {})
def examplesSettings = sharedSettings ++ Seq(
- name := "spark-examples",
- libraryDependencies += "colt" % "colt" % "1.2.0"
+ name := "spark-examples"
)
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")