 bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala | 11
 core/src/main/scala/spark/BlockStoreShuffleFetcher.scala | 70
 core/src/main/scala/spark/BoundedMemoryCache.scala | 3
 core/src/main/scala/spark/CacheTracker.scala | 334
 core/src/main/scala/spark/CoGroupedRDD.scala | 6
 core/src/main/scala/spark/DAGScheduler.scala | 374
 core/src/main/scala/spark/Dependency.scala | 2
 core/src/main/scala/spark/DiskSpillingCache.scala | 75
 core/src/main/scala/spark/DoubleRDDFunctions.scala | 39
 core/src/main/scala/spark/Executor.scala | 26
 core/src/main/scala/spark/FetchFailedException.scala | 8
 core/src/main/scala/spark/JavaSerializer.scala | 39
 core/src/main/scala/spark/Job.scala | 16
 core/src/main/scala/spark/KryoSerializer.scala | 98
 core/src/main/scala/spark/Logging.scala | 7
 core/src/main/scala/spark/MapOutputTracker.scala | 108
 core/src/main/scala/spark/PairRDDFunctions.scala | 56
 core/src/main/scala/spark/ParallelShuffleFetcher.scala | 119
 core/src/main/scala/spark/Partitioner.scala | 1
 core/src/main/scala/spark/PipedRDD.scala | 1
 core/src/main/scala/spark/RDD.scala | 104
 core/src/main/scala/spark/Scheduler.scala | 27
 core/src/main/scala/spark/SequenceFileRDDFunctions.scala | 2
 core/src/main/scala/spark/Serializer.scala | 86
 core/src/main/scala/spark/SerializingCache.scala | 26
 core/src/main/scala/spark/ShuffleMapTask.scala | 56
 core/src/main/scala/spark/ShuffledRDD.scala | 2
 core/src/main/scala/spark/SimpleShuffleFetcher.scala | 46
 core/src/main/scala/spark/SparkContext.scala | 77
 core/src/main/scala/spark/SparkEnv.scala | 79
 core/src/main/scala/spark/Stage.scala | 41
 core/src/main/scala/spark/Task.scala | 9
 core/src/main/scala/spark/TaskContext.scala | 3
 core/src/main/scala/spark/TaskEndReason.scala | 16
 core/src/main/scala/spark/TaskResult.scala | 8
 core/src/main/scala/spark/UnionRDD.scala | 3
 core/src/main/scala/spark/Utils.scala | 35
 core/src/main/scala/spark/network/Connection.scala | 364
 core/src/main/scala/spark/network/ConnectionManager.scala | 467
 core/src/main/scala/spark/network/ConnectionManagerTest.scala | 74
 core/src/main/scala/spark/network/Message.scala | 219
 core/src/main/scala/spark/network/ReceiverTest.scala | 20
 core/src/main/scala/spark/network/SenderTest.scala | 53
 core/src/main/scala/spark/partial/ApproximateActionListener.scala | 66
 core/src/main/scala/spark/partial/ApproximateEvaluator.scala | 10
 core/src/main/scala/spark/partial/BoundedDouble.scala | 8
 core/src/main/scala/spark/partial/CountEvaluator.scala | 38
 core/src/main/scala/spark/partial/GroupedCountEvaluator.scala | 62
 core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala | 65
 core/src/main/scala/spark/partial/GroupedSumEvaluator.scala | 72
 core/src/main/scala/spark/partial/MeanEvaluator.scala | 41
 core/src/main/scala/spark/partial/PartialResult.scala | 86
 core/src/main/scala/spark/partial/StudentTCacher.scala | 26
 core/src/main/scala/spark/partial/SumEvaluator.scala | 51
 core/src/main/scala/spark/scheduler/ActiveJob.scala | 18
 core/src/main/scala/spark/scheduler/DAGScheduler.scala | 532
 core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala | 30
 core/src/main/scala/spark/scheduler/JobListener.scala | 11
 core/src/main/scala/spark/scheduler/JobResult.scala | 9
 core/src/main/scala/spark/scheduler/JobWaiter.scala | 43
 core/src/main/scala/spark/scheduler/ResultTask.scala (renamed from core/src/main/scala/spark/ResultTask.scala) | 15
 core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 135
 core/src/main/scala/spark/scheduler/Stage.scala | 86
 core/src/main/scala/spark/scheduler/Task.scala | 11
 core/src/main/scala/spark/scheduler/TaskResult.scala | 34
 core/src/main/scala/spark/scheduler/TaskScheduler.scala | 27
 core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala | 16
 core/src/main/scala/spark/scheduler/TaskSet.scala | 9
 core/src/main/scala/spark/scheduler/local/LocalScheduler.scala (renamed from core/src/main/scala/spark/LocalScheduler.scala) | 43
 core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala | 364
 core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala (renamed from core/src/main/scala/spark/MesosScheduler.scala) | 271
 core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala | 32
 core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala (renamed from core/src/main/scala/spark/SimpleJob.scala) | 259
 core/src/main/scala/spark/storage/BlockManager.scala | 507
 core/src/main/scala/spark/storage/BlockManagerMaster.scala | 516
 core/src/main/scala/spark/storage/BlockManagerWorker.scala | 142
 core/src/main/scala/spark/storage/BlockMessage.scala | 219
 core/src/main/scala/spark/storage/BlockMessageArray.scala | 140
 core/src/main/scala/spark/storage/BlockStore.scala | 282
 core/src/main/scala/spark/storage/StorageLevel.scala | 78
 core/src/main/scala/spark/util/ByteBufferInputStream.scala | 30
 core/src/main/scala/spark/util/StatCounter.scala | 89
 core/src/test/scala/spark/CacheTrackerSuite.scala | 86
 core/src/test/scala/spark/MesosSchedulerSuite.scala | 2
 core/src/test/scala/spark/UtilsSuite.scala | 2
 project/SparkBuild.scala | 10
 86 files changed, 6374 insertions(+), 1409 deletions(-)
diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
index 7084ff97d9..4c18cb9134 100644
--- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
+++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
@@ -11,6 +11,7 @@ import scala.xml.{XML,NodeSeq}
import scala.collection.mutable.ArrayBuffer
import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
+import java.nio.ByteBuffer
object WikipediaPageRankStandalone {
def main(args: Array[String]) {
@@ -118,23 +119,23 @@ class WPRSerializer extends spark.Serializer {
}
class WPRSerializerInstance extends SerializerInstance {
- def serialize[T](t: T): Array[Byte] = {
+ def serialize[T](t: T): ByteBuffer = {
throw new UnsupportedOperationException()
}
- def deserialize[T](bytes: Array[Byte]): T = {
+ def deserialize[T](bytes: ByteBuffer): T = {
throw new UnsupportedOperationException()
}
- def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+ def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
throw new UnsupportedOperationException()
}
- def outputStream(s: OutputStream): SerializationStream = {
+ def serializeStream(s: OutputStream): SerializationStream = {
new WPRSerializationStream(s)
}
- def inputStream(s: InputStream): DeserializationStream = {
+ def deserializeStream(s: InputStream): DeserializationStream = {
new WPRDeserializationStream(s)
}
}
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
new file mode 100644
index 0000000000..e00a0d80fa
--- /dev/null
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -0,0 +1,70 @@
+package spark
+
+import java.io.EOFException
+import java.net.URL
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import spark.storage.BlockException
+import spark.storage.BlockManagerId
+
+import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+
+
+class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
+ def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
+ logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
+ val ser = SparkEnv.get.serializer.newInstance()
+ val blockManager = SparkEnv.get.blockManager
+
+ val startTime = System.currentTimeMillis
+ val addresses = SparkEnv.get.mapOutputTracker.getServerAddresses(shuffleId)
+ logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
+ shuffleId, reduceId, System.currentTimeMillis - startTime))
+
+ val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[Int]]
+ for ((address, index) <- addresses.zipWithIndex) {
+ splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += index
+ }
+
+ val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map {
+ case (address, splits) =>
+ (address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId)))
+ }
+
+ try {
+ val blockOptions = blockManager.get(blocksByAddress)
+ logDebug("Fetching map output blocks for shuffle %d, reduce %d took %d ms".format(
+ shuffleId, reduceId, System.currentTimeMillis - startTime))
+ blockOptions.foreach(x => {
+ val (blockId, blockOption) = x
+ blockOption match {
+ case Some(block) => {
+ val values = block.asInstanceOf[Iterator[Any]]
+ for(value <- values) {
+ val v = value.asInstanceOf[(K, V)]
+ func(v._1, v._2)
+ }
+ }
+ case None => {
+ throw new BlockException(blockId, "Did not get block " + blockId)
+ }
+ }
+ })
+ } catch {
+ case be: BlockException => {
+ val regex = "shuffledid_([0-9]*)_([0-9]*)_([0-9]]*)".r
+ be.blockId match {
+ case regex(sId, mId, rId) => {
+ val address = addresses(mId.toInt)
+ throw new FetchFailedException(address, sId.toInt, mId.toInt, rId.toInt, be)
+ }
+ case _ => {
+ throw be
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala
index 1162e34ab0..fa5dcee7bb 100644
--- a/core/src/main/scala/spark/BoundedMemoryCache.scala
+++ b/core/src/main/scala/spark/BoundedMemoryCache.scala
@@ -90,7 +90,8 @@ class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
- SparkEnv.get.cacheTracker.dropEntry(datasetId, partition)
+ // TODO: remove BoundedMemoryCache
+ SparkEnv.get.cacheTracker.dropEntry(datasetId.asInstanceOf[(Int, Int)]._2, partition)
}
}
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index 4867829c17..64b4af0ae2 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -1,11 +1,17 @@
package spark
-import scala.actors._
-import scala.actors.Actor._
-import scala.actors.remote._
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
+import akka.util.duration._
+
+import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
+import spark.storage.BlockManager
+import spark.storage.StorageLevel
+
sealed trait CacheTrackerMessage
case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
@@ -18,8 +24,8 @@ case object GetCacheStatus extends CacheTrackerMessage
case object GetCacheLocations extends CacheTrackerMessage
case object StopCacheTracker extends CacheTrackerMessage
-
-class CacheTrackerActor extends DaemonActor with Logging {
+class CacheTrackerActor extends Actor with Logging {
+ // TODO: Should probably store (String, CacheType) tuples
private val locs = new HashMap[Int, Array[List[String]]]
/**
@@ -28,109 +34,93 @@ class CacheTrackerActor extends DaemonActor with Logging {
private val slaveCapacity = new HashMap[String, Long]
private val slaveUsage = new HashMap[String, Long]
- // TODO: Should probably store (String, CacheType) tuples
-
private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
- def act() {
- val port = System.getProperty("spark.master.port").toInt
- RemoteActor.alive(port)
- RemoteActor.register('CacheTracker, self)
- logInfo("Registered actor on port " + port)
-
- loop {
- react {
- case SlaveCacheStarted(host: String, size: Long) =>
- logInfo("Started slave cache (size %s) on %s".format(
- Utils.memoryBytesToString(size), host))
- slaveCapacity.put(host, size)
- slaveUsage.put(host, 0)
- reply('OK)
-
- case RegisterRDD(rddId: Int, numPartitions: Int) =>
- logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
- locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
- reply('OK)
-
- case AddedToCache(rddId, partition, host, size) =>
- if (size > 0) {
- slaveUsage.put(host, getCacheUsage(host) + size)
- logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
- rddId, partition, host, Utils.memoryBytesToString(size),
- Utils.memoryBytesToString(getCacheAvailable(host))))
- } else {
- logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
- }
- locs(rddId)(partition) = host :: locs(rddId)(partition)
- reply('OK)
-
- case DroppedFromCache(rddId, partition, host, size) =>
- if (size > 0) {
- logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
- rddId, partition, host, Utils.memoryBytesToString(size),
- Utils.memoryBytesToString(getCacheAvailable(host))))
- slaveUsage.put(host, getCacheUsage(host) - size)
-
- // Do a sanity check to make sure usage is greater than 0.
- val usage = getCacheUsage(host)
- if (usage < 0) {
- logError("Cache usage on %s is negative (%d)".format(host, usage))
- }
- } else {
- logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
- }
- locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
- reply('OK)
+ def receive = {
+ case SlaveCacheStarted(host: String, size: Long) =>
+ logInfo("Started slave cache (size %s) on %s".format(
+ Utils.memoryBytesToString(size), host))
+ slaveCapacity.put(host, size)
+ slaveUsage.put(host, 0)
+ self.reply(true)
- case MemoryCacheLost(host) =>
- logInfo("Memory cache lost on " + host)
- // TODO: Drop host from the memory locations list of all RDDs
-
- case GetCacheLocations =>
- logInfo("Asked for current cache locations")
- reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())})
-
- case GetCacheStatus =>
- val status = slaveCapacity.map { case (host,capacity) =>
- (host, capacity, getCacheUsage(host))
- }.toSeq
- reply(status)
-
- case StopCacheTracker =>
- reply('OK)
- exit()
+ case RegisterRDD(rddId: Int, numPartitions: Int) =>
+ logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
+ locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
+ self.reply(true)
+
+ case AddedToCache(rddId, partition, host, size) =>
+ slaveUsage.put(host, getCacheUsage(host) + size)
+ logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
+ rddId, partition, host, Utils.memoryBytesToString(size),
+ Utils.memoryBytesToString(getCacheAvailable(host))))
+ locs(rddId)(partition) = host :: locs(rddId)(partition)
+ self.reply(true)
+
+ case DroppedFromCache(rddId, partition, host, size) =>
+ logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
+ rddId, partition, host, Utils.memoryBytesToString(size),
+ Utils.memoryBytesToString(getCacheAvailable(host))))
+ slaveUsage.put(host, getCacheUsage(host) - size)
+ // Do a sanity check to make sure usage is greater than 0.
+ val usage = getCacheUsage(host)
+ if (usage < 0) {
+ logError("Cache usage on %s is negative (%d)".format(host, usage))
}
- }
- }
-}
+ locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
+ self.reply(true)
+ case MemoryCacheLost(host) =>
+ logInfo("Memory cache lost on " + host)
+ for ((id, locations) <- locs) {
+ for (i <- 0 until locations.length) {
+ locations(i) = locations(i).filterNot(_ == host)
+ }
+ }
+ self.reply(true)
-class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
- // Tracker actor on the master, or remote reference to it on workers
- var trackerActor: AbstractActor = null
+ case GetCacheLocations =>
+ logInfo("Asked for current cache locations")
+ self.reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())})
- val registeredRddIds = new HashSet[Int]
+ case GetCacheStatus =>
+ val status = slaveCapacity.map { case (host, capacity) =>
+ (host, capacity, getCacheUsage(host))
+ }.toSeq
+ self.reply(status)
- // Stores map results for various splits locally
- val cache = theCache.newKeySpace()
+ case StopCacheTracker =>
+ logInfo("CacheTrackerActor Server stopped!")
+ self.reply(true)
+ self.exit()
+ }
+}
+class CacheTracker(isMaster: Boolean, blockManager: BlockManager) extends Logging {
+ // Tracker actor on the master, or remote reference to it on workers
+ val ip: String = System.getProperty("spark.master.host", "localhost")
+ val port: Int = System.getProperty("spark.master.port", "7077").toInt
+ val aName: String = "CacheTracker"
+
if (isMaster) {
- val tracker = new CacheTrackerActor
- tracker.start()
- trackerActor = tracker
+ }
+
+ var trackerActor: ActorRef = if (isMaster) {
+ val actor = actorOf(new CacheTrackerActor)
+ remote.register(aName, actor)
+ actor.start()
+ logInfo("Registered CacheTrackerActor actor @ " + ip + ":" + port)
+ actor
} else {
- val host = System.getProperty("spark.master.host")
- val port = System.getProperty("spark.master.port").toInt
- trackerActor = RemoteActor.select(Node(host, port), 'CacheTracker)
+ remote.actorFor(aName, ip, port)
}
- // Report the cache being started.
- trackerActor !? SlaveCacheStarted(Utils.getHost, cache.getCapacity)
+ val registeredRddIds = new HashSet[Int]
// Remembers which splits are currently being loaded (on worker nodes)
- val loading = new HashSet[(Int, Int)]
+ val loading = new HashSet[String]
// Registers an RDD (on master only)
def registerRDD(rddId: Int, numPartitions: Int) {
@@ -138,24 +128,33 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
if (!registeredRddIds.contains(rddId)) {
logInfo("Registering RDD ID " + rddId + " with cache")
registeredRddIds += rddId
- trackerActor !? RegisterRDD(rddId, numPartitions)
+ (trackerActor ? RegisterRDD(rddId, numPartitions)).as[Any] match {
+ case Some(true) =>
+ logInfo("CacheTracker registerRDD " + RegisterRDD(rddId, numPartitions) + " successfully.")
+ case Some(oops) =>
+ logError("CacheTracker registerRDD" + RegisterRDD(rddId, numPartitions) + " failed: " + oops)
+ case None =>
+ logError("CacheTracker registerRDD None. " + RegisterRDD(rddId, numPartitions))
+ throw new SparkException("Internal error: CacheTracker registerRDD None.")
+ }
}
}
}
-
- // Get a snapshot of the currently known locations
- def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
- (trackerActor !? GetCacheLocations) match {
- case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]]
-
- case _ => throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap")
+
+ // For BlockManager.scala only
+ def cacheLost(host: String) {
+ (trackerActor ? MemoryCacheLost(host)).as[Any] match {
+ case Some(true) =>
+ logInfo("CacheTracker successfully removed entries on " + host)
+ case _ =>
+ logError("CacheTracker did not reply to MemoryCacheLost")
}
}
// Get the usage status of slave caches. Each tuple in the returned sequence
// is in the form of (host name, capacity, usage).
def getCacheStatus(): Seq[(String, Long, Long)] = {
- (trackerActor !? GetCacheStatus) match {
+ (trackerActor ? GetCacheStatus) match {
case h: Seq[(String, Long, Long)] => h.asInstanceOf[Seq[(String, Long, Long)]]
case _ =>
@@ -164,75 +163,94 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
}
}
+ // For BlockManager.scala only
+ def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) {
+ (trackerActor ? t).as[Any] match {
+ case Some(true) =>
+ logInfo("CacheTracker notifyTheCacheTrackerFromBlockManager successfully.")
+ case Some(oops) =>
+ logError("CacheTracker notifyTheCacheTrackerFromBlockManager failed: " + oops)
+ case None =>
+ logError("CacheTracker notifyTheCacheTrackerFromBlockManager None.")
+ }
+ }
+
+ // Get a snapshot of the currently known locations
+ def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
+ (trackerActor ? GetCacheLocations).as[Any] match {
+ case Some(h: HashMap[_, _]) =>
+ h.asInstanceOf[HashMap[Int, Array[List[String]]]]
+
+ case _ =>
+ throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap")
+ }
+ }
+
// Gets or computes an RDD split
- def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T]): Iterator[T] = {
- logInfo("Looking for RDD partition %d:%d".format(rdd.id, split.index))
- val cachedVal = cache.get(rdd.id, split.index)
- if (cachedVal != null) {
- // Split is in cache, so just return its values
- logInfo("Found partition in cache!")
- return cachedVal.asInstanceOf[Array[T]].iterator
- } else {
- // Mark the split as loading (unless someone else marks it first)
- val key = (rdd.id, split.index)
- loading.synchronized {
- while (loading.contains(key)) {
- // Someone else is loading it; let's wait for them
- try { loading.wait() } catch { case _ => }
- }
- // See whether someone else has successfully loaded it. The main way this would fail
- // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
- // partition but we didn't want to make space for it. However, that case is unlikely
- // because it's unlikely that two threads would work on the same RDD partition. One
- // downside of the current code is that threads wait serially if this does happen.
- val cachedVal = cache.get(rdd.id, split.index)
- if (cachedVal != null) {
- return cachedVal.asInstanceOf[Array[T]].iterator
- }
- // Nobody's loading it and it's not in the cache; let's load it ourselves
- loading.add(key)
- }
- // If we got here, we have to load the split
- // Tell the master that we're doing so
-
- // TODO: fetch any remote copy of the split that may be available
- logInfo("Computing partition " + split)
- var array: Array[T] = null
- var putResponse: CachePutResponse = null
- try {
- array = rdd.compute(split).toArray(m)
- putResponse = cache.put(rdd.id, split.index, array)
- } finally {
- // Tell other threads that we've finished our attempt to load the key (whether or not
- // we've actually succeeded to put it in the map)
+ def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = {
+ val key = "rdd:%d:%d".format(rdd.id, split.index)
+ logInfo("Cache key is " + key)
+ blockManager.get(key) match {
+ case Some(cachedValues) =>
+ // Split is in cache, so just return its values
+ logInfo("Found partition in cache!")
+ return cachedValues.asInstanceOf[Iterator[T]]
+
+ case None =>
+ // Mark the split as loading (unless someone else marks it first)
loading.synchronized {
- loading.remove(key)
- loading.notifyAll()
+ if (loading.contains(key)) {
+ logInfo("Loading contains " + key + ", waiting...")
+ while (loading.contains(key)) {
+ try {loading.wait()} catch {case _ =>}
+ }
+ logInfo("Loading no longer contains " + key + ", so returning cached result")
+ // See whether someone else has successfully loaded it. The main way this would fail
+ // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
+ // partition but we didn't want to make space for it. However, that case is unlikely
+ // because it's unlikely that two threads would work on the same RDD partition. One
+ // downside of the current code is that threads wait serially if this does happen.
+ blockManager.get(key) match {
+ case Some(values) =>
+ return values.asInstanceOf[Iterator[T]]
+ case None =>
+ logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
+ loading.add(key)
+ }
+ } else {
+ loading.add(key)
+ }
}
- }
-
- putResponse match {
- case CachePutSuccess(size) => {
- // Tell the master that we added the entry. Don't return until it
- // replies so it can properly schedule future tasks that use this RDD.
- trackerActor !? AddedToCache(rdd.id, split.index, Utils.getHost, size)
+ // If we got here, we have to load the split
+ // Tell the master that we're doing so
+ //val host = System.getProperty("spark.hostname", Utils.localHostName)
+ //val future = trackerActor !! AddedToCache(rdd.id, split.index, host)
+ // TODO: fetch any remote copy of the split that may be available
+ // TODO: also register a listener for when it unloads
+ logInfo("Computing partition " + split)
+ try {
+ val values = new ArrayBuffer[Any]
+ values ++= rdd.compute(split)
+ blockManager.put(key, values.iterator, storageLevel, false)
+ //future.apply() // Wait for the reply from the cache tracker
+ return values.iterator.asInstanceOf[Iterator[T]]
+ } finally {
+ loading.synchronized {
+ loading.remove(key)
+ loading.notifyAll()
+ }
}
- case _ => null
- }
- return array.iterator
}
}
// Called by the Cache to report that an entry has been dropped from it
- def dropEntry(datasetId: Any, partition: Int) {
- datasetId match {
- //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here.
- case (cache.keySpaceId, rddId: Int) => trackerActor !! DroppedFromCache(rddId, partition, Utils.getHost)
- }
+ def dropEntry(rddId: Int, partition: Int) {
+    //TODO - do we really want to use '!!' when nobody checks the returned future? '!' seems to be enough here.
+ trackerActor !! DroppedFromCache(rddId, partition, Utils.localHostName())
}
def stop() {
- trackerActor !? StopCacheTracker
+ trackerActor !! StopCacheTracker
registeredRddIds.clear()
trackerActor = null
}
diff --git a/core/src/main/scala/spark/CoGroupedRDD.scala b/core/src/main/scala/spark/CoGroupedRDD.scala
index 93f453bc5e..3543c8afa8 100644
--- a/core/src/main/scala/spark/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/CoGroupedRDD.scala
@@ -22,11 +22,12 @@ class CoGroupAggregator
{ (b1, b2) => b1 ++ b2 })
with Serializable
-class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner)
+class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
val aggr = new CoGroupAggregator
+ @transient
override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
for ((rdd, index) <- rdds.zipWithIndex) {
@@ -67,9 +68,10 @@ class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner)
override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit]
+ val numRdds = split.deps.size
val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
- map.getOrElseUpdate(k, Array.fill(rdds.size)(new ArrayBuffer[Any]))
+ map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any]))
}
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, itsSplit) => {
diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
deleted file mode 100644
index 1b4af9d84c..0000000000
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ /dev/null
@@ -1,374 +0,0 @@
-package spark
-
-import java.util.concurrent.atomic.AtomicInteger
-
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
-
-/**
- * A task created by the DAG scheduler. Knows its stage ID and map ouput tracker generation.
- */
-abstract class DAGTask[T](val runId: Int, val stageId: Int) extends Task[T] {
- val gen = SparkEnv.get.mapOutputTracker.getGeneration
- override def generation: Option[Long] = Some(gen)
-}
-
-/**
- * A completion event passed by the underlying task scheduler to the DAG scheduler.
- */
-case class CompletionEvent(
- task: DAGTask[_],
- reason: TaskEndReason,
- result: Any,
- accumUpdates: Map[Long, Any])
-
-/**
- * Various possible reasons why a DAG task ended. The underlying scheduler is supposed to retry
- * tasks several times for "ephemeral" failures, and only report back failures that require some
- * old stages to be resubmitted, such as shuffle map fetch failures.
- */
-sealed trait TaskEndReason
-case object Success extends TaskEndReason
-case class FetchFailed(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
-case class ExceptionFailure(exception: Throwable) extends TaskEndReason
-case class OtherFailure(message: String) extends TaskEndReason
-
-/**
- * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
- * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal
- * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
- * and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
- */
-private trait DAGScheduler extends Scheduler with Logging {
- // Must be implemented by subclasses to start running a set of tasks. The subclass should also
- // attempt to run different sets of tasks in the order given by runId (lower values first).
- def submitTasks(tasks: Seq[Task[_]], runId: Int): Unit
-
- // Must be called by subclasses to report task completions or failures.
- def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]) {
- lock.synchronized {
- val dagTask = task.asInstanceOf[DAGTask[_]]
- eventQueues.get(dagTask.runId) match {
- case Some(queue) =>
- queue += CompletionEvent(dagTask, reason, result, accumUpdates)
- lock.notifyAll()
- case None =>
- logInfo("Ignoring completion event for DAG job " + dagTask.runId + " because it's gone")
- }
- }
- }
-
- // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
- // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
- // as more failure events come in
- val RESUBMIT_TIMEOUT = 2000L
-
- // The time, in millis, to wake up between polls of the completion queue in order to potentially
- // resubmit failed stages
- val POLL_TIMEOUT = 500L
-
- private val lock = new Object // Used for access to the entire DAGScheduler
-
- private val eventQueues = new HashMap[Int, Queue[CompletionEvent]] // Indexed by run ID
-
- val nextRunId = new AtomicInteger(0)
-
- val nextStageId = new AtomicInteger(0)
-
- val idToStage = new HashMap[Int, Stage]
-
- val shuffleToMapStage = new HashMap[Int, Stage]
-
- var cacheLocs = new HashMap[Int, Array[List[String]]]
-
- val env = SparkEnv.get
- val cacheTracker = env.cacheTracker
- val mapOutputTracker = env.mapOutputTracker
-
- def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
- cacheLocs(rdd.id)
- }
-
- def updateCacheLocs() {
- cacheLocs = cacheTracker.getLocationsSnapshot()
- }
-
- def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {
- shuffleToMapStage.get(shuf.shuffleId) match {
- case Some(stage) => stage
- case None =>
- val stage = newStage(shuf.rdd, Some(shuf))
- shuffleToMapStage(shuf.shuffleId) = stage
- stage
- }
- }
-
- def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
- // Kind of ugly: need to register RDDs with the cache and map output tracker here
- // since we can't do it in the RDD constructor because # of splits is unknown
- cacheTracker.registerRDD(rdd.id, rdd.splits.size)
- if (shuffleDep != None) {
- mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
- }
- val id = nextStageId.getAndIncrement()
- val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
- idToStage(id) = stage
- stage
- }
-
- def getParentStages(rdd: RDD[_]): List[Stage] = {
- val parents = new HashSet[Stage]
- val visited = new HashSet[RDD[_]]
- def visit(r: RDD[_]) {
- if (!visited(r)) {
- visited += r
- // Kind of ugly: need to register RDDs with the cache here since
- // we can't do it in its constructor because # of splits is unknown
- cacheTracker.registerRDD(r.id, r.splits.size)
- for (dep <- r.dependencies) {
- dep match {
- case shufDep: ShuffleDependency[_,_,_] =>
- parents += getShuffleMapStage(shufDep)
- case _ =>
- visit(dep.rdd)
- }
- }
- }
- }
- visit(rdd)
- parents.toList
- }
-
- def getMissingParentStages(stage: Stage): List[Stage] = {
- val missing = new HashSet[Stage]
- val visited = new HashSet[RDD[_]]
- def visit(rdd: RDD[_]) {
- if (!visited(rdd)) {
- visited += rdd
- val locs = getCacheLocs(rdd)
- for (p <- 0 until rdd.splits.size) {
- if (locs(p) == Nil) {
- for (dep <- rdd.dependencies) {
- dep match {
- case shufDep: ShuffleDependency[_,_,_] =>
- val stage = getShuffleMapStage(shufDep)
- if (!stage.isAvailable) {
- missing += stage
- }
- case narrowDep: NarrowDependency[_] =>
- visit(narrowDep.rdd)
- }
- }
- }
- }
- }
- }
- visit(stage.rdd)
- missing.toList
- }
-
- override def runJob[T, U](
- finalRdd: RDD[T],
- func: (TaskContext, Iterator[T]) => U,
- partitions: Seq[Int],
- allowLocal: Boolean)
- (implicit m: ClassManifest[U]): Array[U] = {
- lock.synchronized {
- val runId = nextRunId.getAndIncrement()
-
- val outputParts = partitions.toArray
- val numOutputParts: Int = partitions.size
- val finalStage = newStage(finalRdd, None)
- val results = new Array[U](numOutputParts)
- val finished = new Array[Boolean](numOutputParts)
- var numFinished = 0
-
- val waiting = new HashSet[Stage] // stages we need to run whose parents aren't done
- val running = new HashSet[Stage] // stages we are running right now
- val failed = new HashSet[Stage] // stages that must be resubmitted due to fetch failures
- val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // missing tasks from each stage
- var lastFetchFailureTime: Long = 0 // used to wait a bit to avoid repeated resubmits
-
- SparkEnv.set(env)
-
- updateCacheLocs()
-
- logInfo("Final stage: " + finalStage)
- logInfo("Parents of final stage: " + finalStage.parents)
- logInfo("Missing parents: " + getMissingParentStages(finalStage))
-
- // Optimization for short actions like first() and take() that can be computed locally
- // without shipping tasks to the cluster.
- if (allowLocal && finalStage.parents.size == 0 && numOutputParts == 1) {
- logInfo("Computing the requested partition locally")
- val split = finalRdd.splits(outputParts(0))
- val taskContext = new TaskContext(finalStage.id, outputParts(0), 0)
- return Array(func(taskContext, finalRdd.iterator(split)))
- }
-
- // Register the job ID so that we can get completion events for it
- eventQueues(runId) = new Queue[CompletionEvent]
-
- def submitStage(stage: Stage) {
- if (!waiting(stage) && !running(stage)) {
- val missing = getMissingParentStages(stage)
- if (missing == Nil) {
- logInfo("Submitting " + stage + ", which has no missing parents")
- submitMissingTasks(stage)
- running += stage
- } else {
- for (parent <- missing) {
- submitStage(parent)
- }
- waiting += stage
- }
- }
- }
-
- def submitMissingTasks(stage: Stage) {
- // Get our pending tasks and remember them in our pendingTasks entry
- val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
- var tasks = ArrayBuffer[Task[_]]()
- if (stage == finalStage) {
- for (id <- 0 until numOutputParts if (!finished(id))) {
- val part = outputParts(id)
- val locs = getPreferredLocs(finalRdd, part)
- tasks += new ResultTask(runId, finalStage.id, finalRdd, func, part, locs, id)
- }
- } else {
- for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
- val locs = getPreferredLocs(stage.rdd, p)
- tasks += new ShuffleMapTask(runId, stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
- }
- }
- myPending ++= tasks
- submitTasks(tasks, runId)
- }
-
- submitStage(finalStage)
-
- while (numFinished != numOutputParts) {
- val eventOption = waitForEvent(runId, POLL_TIMEOUT)
- val time = System.currentTimeMillis // TODO: use a pluggable clock for testability
-
- // If we got an event off the queue, mark the task done or react to a fetch failure
- if (eventOption != None) {
- val evt = eventOption.get
- val stage = idToStage(evt.task.stageId)
- pendingTasks(stage) -= evt.task
- if (evt.reason == Success) {
- // A task ended
- logInfo("Completed " + evt.task)
- Accumulators.add(evt.accumUpdates)
- evt.task match {
- case rt: ResultTask[_, _] =>
- results(rt.outputId) = evt.result.asInstanceOf[U]
- finished(rt.outputId) = true
- numFinished += 1
- case smt: ShuffleMapTask =>
- val stage = idToStage(smt.stageId)
- stage.addOutputLoc(smt.partition, evt.result.asInstanceOf[String])
- if (running.contains(stage) && pendingTasks(stage).isEmpty) {
- logInfo(stage + " finished; looking for newly runnable stages")
- running -= stage
- if (stage.shuffleDep != None) {
- mapOutputTracker.registerMapOutputs(
- stage.shuffleDep.get.shuffleId,
- stage.outputLocs.map(_.head).toArray)
- }
- updateCacheLocs()
- val newlyRunnable = new ArrayBuffer[Stage]
- for (stage <- waiting if getMissingParentStages(stage) == Nil) {
- newlyRunnable += stage
- }
- waiting --= newlyRunnable
- running ++= newlyRunnable
- for (stage <- newlyRunnable) {
- submitMissingTasks(stage)
- }
- }
- }
- } else {
- evt.reason match {
- case FetchFailed(serverUri, shuffleId, mapId, reduceId) =>
- // Mark the stage that the reducer was in as unrunnable
- val failedStage = idToStage(evt.task.stageId)
- running -= failedStage
- failed += failedStage
- // TODO: Cancel running tasks in the stage
- logInfo("Marking " + failedStage + " for resubmision due to a fetch failure")
- // Mark the map whose fetch failed as broken in the map stage
- val mapStage = shuffleToMapStage(shuffleId)
- mapStage.removeOutputLoc(mapId, serverUri)
- mapOutputTracker.unregisterMapOutput(shuffleId, mapId, serverUri)
- logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
- failed += mapStage
- // Remember that a fetch failed now; this is used to resubmit the broken
- // stages later, after a small wait (to give other tasks the chance to fail)
- lastFetchFailureTime = time
- // TODO: If there are a lot of fetch failures on the same node, maybe mark all
- // outputs on the node as dead.
- case _ =>
- // Non-fetch failure -- probably a bug in the job, so bail out
- throw new SparkException("Task failed: " + evt.task + ", reason: " + evt.reason)
- // TODO: Cancel all tasks that are still running
- }
- }
- } // end if (evt != null)
-
- // If fetches have failed recently and we've waited for the right timeout,
- // resubmit all the failed stages
- if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
- logInfo("Resubmitting failed stages")
- updateCacheLocs()
- for (stage <- failed) {
- submitStage(stage)
- }
- failed.clear()
- }
- }
-
- eventQueues -= runId
- return results
- }
- }
-
- def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
- // If the partition is cached, return the cache locations
- val cached = getCacheLocs(rdd)(partition)
- if (cached != Nil) {
- return cached
- }
- // If the RDD has some placement preferences (as is the case for input RDDs), get those
- val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList
- if (rddPrefs != Nil) {
- return rddPrefs
- }
- // If the RDD has narrow dependencies, pick the first partition of the first narrow dep
- // that has any placement preferences. Ideally we would choose based on transfer sizes,
- // but this will do for now.
- rdd.dependencies.foreach(_ match {
- case n: NarrowDependency[_] =>
- for (inPart <- n.getParents(partition)) {
- val locs = getPreferredLocs(n.rdd, inPart)
- if (locs != Nil)
- return locs;
- }
- case _ =>
- })
- return Nil
- }
-
- // Assumes that lock is held on entrance, but will release it to wait for the next event.
- def waitForEvent(runId: Int, timeout: Long): Option[CompletionEvent] = {
- val endTime = System.currentTimeMillis() + timeout // TODO: Use pluggable clock for testing
- while (eventQueues(runId).isEmpty) {
- val time = System.currentTimeMillis()
- if (time >= endTime) {
- return None
- } else {
- lock.wait(endTime - time)
- }
- }
- return Some(eventQueues(runId).dequeue())
- }
-}
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index d93c84924a..c0ff94acc6 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -8,7 +8,7 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd, false) {
class ShuffleDependency[K, V, C](
val shuffleId: Int,
- rdd: RDD[(K, V)],
+ @transient rdd: RDD[(K, V)],
val aggregator: Aggregator[K, V, C],
val partitioner: Partitioner)
extends Dependency(rdd, true)
diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala
deleted file mode 100644
index e11466eb64..0000000000
--- a/core/src/main/scala/spark/DiskSpillingCache.scala
+++ /dev/null
@@ -1,75 +0,0 @@
-package spark
-
-import java.io.File
-import java.io.{FileOutputStream,FileInputStream}
-import java.io.IOException
-import java.util.LinkedHashMap
-import java.util.UUID
-
-// TODO: cache into a separate directory using Utils.createTempDir
-// TODO: clean up disk cache afterwards
-class DiskSpillingCache extends BoundedMemoryCache {
- private val diskMap = new LinkedHashMap[(Any, Int), File](32, 0.75f, true)
-
- override def get(datasetId: Any, partition: Int): Any = {
- synchronized {
- val ser = SparkEnv.get.serializer.newInstance()
- super.get(datasetId, partition) match {
- case bytes: Any => // found in memory
- ser.deserialize(bytes.asInstanceOf[Array[Byte]])
-
- case _ => diskMap.get((datasetId, partition)) match {
- case file: Any => // found on disk
- try {
- val startTime = System.currentTimeMillis
- val bytes = new Array[Byte](file.length.toInt)
- new FileInputStream(file).read(bytes)
- val timeTaken = System.currentTimeMillis - startTime
- logInfo("Reading key (%s, %d) of size %d bytes from disk took %d ms".format(
- datasetId, partition, file.length, timeTaken))
- super.put(datasetId, partition, bytes)
- ser.deserialize(bytes.asInstanceOf[Array[Byte]])
- } catch {
- case e: IOException =>
- logWarning("Failed to read key (%s, %d) from disk at %s: %s".format(
- datasetId, partition, file.getPath(), e.getMessage()))
- diskMap.remove((datasetId, partition)) // remove dead entry
- null
- }
-
- case _ => // not found
- null
- }
- }
- }
- }
-
- override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
- var ser = SparkEnv.get.serializer.newInstance()
- super.put(datasetId, partition, ser.serialize(value))
- }
-
- /**
- * Spill the given entry to disk. Assumes that a lock is held on the
- * DiskSpillingCache. Assumes that entry.value is a byte array.
- */
- override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
- logInfo("Spilling key (%s, %d) of size %d to make space".format(
- datasetId, partition, entry.size))
- val cacheDir = System.getProperty(
- "spark.diskSpillingCache.cacheDir",
- System.getProperty("java.io.tmpdir"))
- val file = new File(cacheDir, "spark-dsc-" + UUID.randomUUID.toString)
- try {
- val stream = new FileOutputStream(file)
- stream.write(entry.value.asInstanceOf[Array[Byte]])
- stream.close()
- diskMap.put((datasetId, partition), file)
- } catch {
- case e: IOException =>
- logWarning("Failed to spill key (%s, %d) to disk at %s: %s".format(
- datasetId, partition, file.getPath(), e.getMessage()))
- // Do nothing and let the entry be discarded
- }
- }
-}
diff --git a/core/src/main/scala/spark/DoubleRDDFunctions.scala b/core/src/main/scala/spark/DoubleRDDFunctions.scala
new file mode 100644
index 0000000000..1fbf66b7de
--- /dev/null
+++ b/core/src/main/scala/spark/DoubleRDDFunctions.scala
@@ -0,0 +1,39 @@
+package spark
+
+import spark.partial.BoundedDouble
+import spark.partial.MeanEvaluator
+import spark.partial.PartialResult
+import spark.partial.SumEvaluator
+
+import spark.util.StatCounter
+
+/**
+ * Extra functions available on RDDs of Doubles through an implicit conversion.
+ */
+class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
+ def sum(): Double = {
+ self.reduce(_ + _)
+ }
+
+ def stats(): StatCounter = {
+ self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
+ }
+
+ def mean(): Double = stats().mean
+
+ def variance(): Double = stats().variance
+
+ def stdev(): Double = stats().stdev
+
+ def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+ val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
+ val evaluator = new MeanEvaluator(self.splits.size, confidence)
+ self.context.runApproximateJob(self, processPartition, evaluator, timeout)
+ }
+
+ def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+ val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
+ val evaluator = new SumEvaluator(self.splits.size, confidence)
+ self.context.runApproximateJob(self, processPartition, evaluator, timeout)
+ }
+}
diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala
index c795b6c351..af9eb9c878 100644
--- a/core/src/main/scala/spark/Executor.scala
+++ b/core/src/main/scala/spark/Executor.scala
@@ -10,9 +10,10 @@ import scala.collection.mutable.ArrayBuffer
import com.google.protobuf.ByteString
import org.apache.mesos._
-import org.apache.mesos.Protos._
+import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
import spark.broadcast._
+import spark.scheduler._
/**
* The Mesos executor for Spark.
@@ -29,6 +30,9 @@ class Executor extends org.apache.mesos.Executor with Logging {
executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) {
+ // Make sure the local hostname we report matches Mesos's name for this host
+ Utils.setCustomHostname(slaveInfo.getHostname())
+
// Read spark.* system properties from executor arg
val props = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
for ((key, value) <- props) {
@@ -39,7 +43,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
RemoteActor.classLoader = getClass.getClassLoader
// Initialize Spark environment (using system properties read above)
- env = SparkEnv.createFromSystemProperties(false)
+ env = SparkEnv.createFromSystemProperties(false, false)
SparkEnv.set(env)
// Old stuff that isn't yet using env
Broadcast.initialize(false)
@@ -57,11 +61,11 @@ class Executor extends org.apache.mesos.Executor with Logging {
override def reregistered(d: ExecutorDriver, s: SlaveInfo) {}
- override def launchTask(d: ExecutorDriver, task: TaskInfo) {
+ override def launchTask(d: ExecutorDriver, task: MTaskInfo) {
threadPool.execute(new TaskRunner(task, d))
}
- class TaskRunner(info: TaskInfo, d: ExecutorDriver)
+ class TaskRunner(info: MTaskInfo, d: ExecutorDriver)
extends Runnable {
override def run() = {
val tid = info.getTaskId.getValue
@@ -74,11 +78,11 @@ class Executor extends org.apache.mesos.Executor with Logging {
.setState(TaskState.TASK_RUNNING)
.build())
try {
+ SparkEnv.set(env)
+ Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear
- val task = ser.deserialize[Task[Any]](info.getData.toByteArray, classLoader)
- for (gen <- task.generation) {// Update generation if any is set
- env.mapOutputTracker.updateGeneration(gen)
- }
+ val task = ser.deserialize[Task[Any]](info.getData.asReadOnlyByteBuffer, classLoader)
+ env.mapOutputTracker.updateGeneration(task.generation)
val value = task.run(tid.toInt)
val accumUpdates = Accumulators.values
val result = new TaskResult(value, accumUpdates)
@@ -105,9 +109,11 @@ class Executor extends org.apache.mesos.Executor with Logging {
.setData(ByteString.copyFrom(ser.serialize(reason)))
.build())
- // TODO: Handle errors in tasks less dramatically
+ // TODO: Should we exit the whole executor here? On the one hand, the failed task may
+ // have left some weird state around depending on when the exception was thrown, but on
+ // the other hand, maybe we could detect that when future tasks fail and exit then.
logError("Exception in task ID " + tid, t)
- System.exit(1)
+ //System.exit(1)
}
}
}
diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala
index a3c4e7873d..55512f4481 100644
--- a/core/src/main/scala/spark/FetchFailedException.scala
+++ b/core/src/main/scala/spark/FetchFailedException.scala
@@ -1,7 +1,9 @@
package spark
+import spark.storage.BlockManagerId
+
class FetchFailedException(
- val serverUri: String,
+ val bmAddress: BlockManagerId,
val shuffleId: Int,
val mapId: Int,
val reduceId: Int,
@@ -9,10 +11,10 @@ class FetchFailedException(
extends Exception {
override def getMessage(): String =
- "Fetch failed: %s %d %d %d".format(serverUri, shuffleId, mapId, reduceId)
+ "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
override def getCause(): Throwable = cause
def toTaskEndReason: TaskEndReason =
- FetchFailed(serverUri, shuffleId, mapId, reduceId)
+ FetchFailed(bmAddress, shuffleId, mapId, reduceId)
}
diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala
index 80f615eeb0..ec5c33d1df 100644
--- a/core/src/main/scala/spark/JavaSerializer.scala
+++ b/core/src/main/scala/spark/JavaSerializer.scala
@@ -1,6 +1,7 @@
package spark
import java.io._
+import java.nio.ByteBuffer
class JavaSerializationStream(out: OutputStream) extends SerializationStream {
val objOut = new ObjectOutputStream(out)
@@ -9,10 +10,11 @@ class JavaSerializationStream(out: OutputStream) extends SerializationStream {
def close() { objOut.close() }
}
-class JavaDeserializationStream(in: InputStream) extends DeserializationStream {
+class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
+extends DeserializationStream {
val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) =
- Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
+ Class.forName(desc.getName, false, loader)
}
def readObject[T](): T = objIn.readObject().asInstanceOf[T]
@@ -20,35 +22,36 @@ class JavaDeserializationStream(in: InputStream) extends DeserializationStream {
}
class JavaSerializerInstance extends SerializerInstance {
- def serialize[T](t: T): Array[Byte] = {
+ def serialize[T](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
- val out = outputStream(bos)
+ val out = serializeStream(bos)
out.writeObject(t)
out.close()
- bos.toByteArray
+ ByteBuffer.wrap(bos.toByteArray)
}
- def deserialize[T](bytes: Array[Byte]): T = {
- val bis = new ByteArrayInputStream(bytes)
- val in = inputStream(bis)
+ def deserialize[T](bytes: ByteBuffer): T = {
+ val bis = new ByteArrayInputStream(bytes.array())
+ val in = deserializeStream(bis)
in.readObject().asInstanceOf[T]
}
- def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
- val bis = new ByteArrayInputStream(bytes)
- val ois = new ObjectInputStream(bis) {
- override def resolveClass(desc: ObjectStreamClass) =
- Class.forName(desc.getName, false, loader)
- }
- return ois.readObject.asInstanceOf[T]
+ def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+ val bis = new ByteArrayInputStream(bytes.array())
+ val in = deserializeStream(bis, loader)
+ in.readObject().asInstanceOf[T]
}
- def outputStream(s: OutputStream): SerializationStream = {
+ def serializeStream(s: OutputStream): SerializationStream = {
new JavaSerializationStream(s)
}
- def inputStream(s: InputStream): DeserializationStream = {
- new JavaDeserializationStream(s)
+ def deserializeStream(s: InputStream): DeserializationStream = {
+ new JavaDeserializationStream(s, currentThread.getContextClassLoader)
+ }
+
+ def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
+ new JavaDeserializationStream(s, loader)
}
}
diff --git a/core/src/main/scala/spark/Job.scala b/core/src/main/scala/spark/Job.scala
deleted file mode 100644
index b7b0361c62..0000000000
--- a/core/src/main/scala/spark/Job.scala
+++ /dev/null
@@ -1,16 +0,0 @@
-package spark
-
-import org.apache.mesos._
-import org.apache.mesos.Protos._
-
-/**
- * Class representing a parallel job in MesosScheduler. Schedules the job by implementing various
- * callbacks.
- */
-abstract class Job(val runId: Int, val jobId: Int) {
- def slaveOffer(s: Offer, availableCpus: Double): Option[TaskInfo]
-
- def statusUpdate(t: TaskStatus): Unit
-
- def error(message: String): Unit
-}
diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 5693613d6d..65d0532bd5 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -12,6 +12,8 @@ import com.esotericsoftware.kryo.{Serializer => KSerializer}
import com.esotericsoftware.kryo.serialize.ClassSerializer
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
+import spark.storage._
+
/**
* Zig-zag encoder used to write object sizes to serialization streams.
* Based on Kryo's integer encoder.
@@ -64,57 +66,90 @@ object ZigZag {
}
}
-class KryoSerializationStream(kryo: Kryo, buf: ByteBuffer, out: OutputStream)
+class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream)
extends SerializationStream {
val channel = Channels.newChannel(out)
def writeObject[T](t: T) {
- kryo.writeClassAndObject(buf, t)
- ZigZag.writeInt(buf.position(), out)
- buf.flip()
- channel.write(buf)
- buf.clear()
+ kryo.writeClassAndObject(threadBuffer, t)
+ ZigZag.writeInt(threadBuffer.position(), out)
+ threadBuffer.flip()
+ channel.write(threadBuffer)
+ threadBuffer.clear()
}
def flush() { out.flush() }
def close() { out.close() }
}
-class KryoDeserializationStream(buf: ObjectBuffer, in: InputStream)
+class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream)
extends DeserializationStream {
def readObject[T](): T = {
val len = ZigZag.readInt(in)
- buf.readClassAndObject(in, len).asInstanceOf[T]
+ objectBuffer.readClassAndObject(in, len).asInstanceOf[T]
}
def close() { in.close() }
}
class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
- val buf = ks.threadBuf.get()
+ val kryo = ks.kryo
+ val threadBuffer = ks.threadBuffer.get()
+ val objectBuffer = ks.objectBuffer.get()
- def serialize[T](t: T): Array[Byte] = {
- buf.writeClassAndObject(t)
+ def serialize[T](t: T): ByteBuffer = {
+ // Write it to our thread-local scratch buffer first to figure out the size, then return a new
+ // ByteBuffer of the appropriate size
+ threadBuffer.clear()
+ kryo.writeClassAndObject(threadBuffer, t)
+ val newBuf = ByteBuffer.allocate(threadBuffer.position)
+ threadBuffer.flip()
+ newBuf.put(threadBuffer)
+ newBuf.flip()
+ newBuf
}
- def deserialize[T](bytes: Array[Byte]): T = {
- buf.readClassAndObject(bytes).asInstanceOf[T]
+ def deserialize[T](bytes: ByteBuffer): T = {
+ kryo.readClassAndObject(bytes).asInstanceOf[T]
}
- def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
- val oldClassLoader = ks.kryo.getClassLoader
- ks.kryo.setClassLoader(loader)
- val obj = buf.readClassAndObject(bytes).asInstanceOf[T]
- ks.kryo.setClassLoader(oldClassLoader)
+ def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+ val oldClassLoader = kryo.getClassLoader
+ kryo.setClassLoader(loader)
+ val obj = kryo.readClassAndObject(bytes).asInstanceOf[T]
+ kryo.setClassLoader(oldClassLoader)
obj
}
- def outputStream(s: OutputStream): SerializationStream = {
- new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s)
+ def serializeStream(s: OutputStream): SerializationStream = {
+ threadBuffer.clear()
+ new KryoSerializationStream(kryo, threadBuffer, s)
+ }
+
+ def deserializeStream(s: InputStream): DeserializationStream = {
+ new KryoDeserializationStream(objectBuffer, s)
}
- def inputStream(s: InputStream): DeserializationStream = {
- new KryoDeserializationStream(buf, s)
+ override def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
+ threadBuffer.clear()
+ while (iterator.hasNext) {
+ val element = iterator.next()
+ // TODO: Do we also want to write the object's size? Doesn't seem necessary.
+ kryo.writeClassAndObject(threadBuffer, element)
+ }
+ val newBuf = ByteBuffer.allocate(threadBuffer.position)
+ threadBuffer.flip()
+ newBuf.put(threadBuffer)
+ newBuf.flip()
+ newBuf
+ }
+
+ override def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
+ buffer.rewind()
+ new Iterator[Any] {
+ override def hasNext: Boolean = buffer.remaining > 0
+ override def next(): Any = kryo.readClassAndObject(buffer)
+ }
}
}
@@ -126,20 +161,17 @@ trait KryoRegistrator {
class KryoSerializer extends Serializer with Logging {
val kryo = createKryo()
- val bufferSize =
- System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+ val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024
- val threadBuf = new ThreadLocal[ObjectBuffer] {
+ val objectBuffer = new ThreadLocal[ObjectBuffer] {
override def initialValue = new ObjectBuffer(kryo, bufferSize)
}
- val threadByteBuf = new ThreadLocal[ByteBuffer] {
+ val threadBuffer = new ThreadLocal[ByteBuffer] {
override def initialValue = ByteBuffer.allocate(bufferSize)
}
def createKryo(): Kryo = {
- // This is used so we can serialize/deserialize objects without a zero-arg
- // constructor.
val kryo = new KryoReflectionFactorySupport()
// Register some commonly used classes
@@ -148,14 +180,20 @@ class KryoSerializer extends Serializer with Logging {
Array(1), Array(1.0), Array(1.0f), Array(1L), Array(""), Array(("", "")),
Array(new java.lang.Object), Array(1.toByte), Array(true), Array('c'),
// Specialized Tuple2s
- ("", ""), (1, 1), (1.0, 1.0), (1L, 1L),
+ ("", ""), ("", 1), (1, 1), (1.0, 1.0), (1L, 1L),
(1, 1.0), (1.0, 1), (1L, 1.0), (1.0, 1L), (1, 1L), (1L, 1),
// Scala collections
List(1), mutable.ArrayBuffer(1),
// Options and Either
Some(1), Left(1), Right(1),
// Higher-dimensional tuples
- (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1)
+ (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1),
+ None,
+ ByteBuffer.allocate(1),
+ StorageLevel.MEMORY_ONLY_DESER,
+ PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY_DESER),
+ GotBlock("1", ByteBuffer.allocate(1)),
+ GetBlock("1")
)
for (obj <- toRegister) {
kryo.register(obj.getClass)
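Note: the thread-local buffers above are sized by the spark.kryoserializer.buffer.mb property, whose default this commit raises to 32 MB. A hedged configuration sketch; the value is illustrative and must be set before the serializer is instantiated:

    // Interpreted as megabytes by the code above.
    System.setProperty("spark.kryoserializer.buffer.mb", "64")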
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
index 0d11ab9cbd..54bd57f6d3 100644
--- a/core/src/main/scala/spark/Logging.scala
+++ b/core/src/main/scala/spark/Logging.scala
@@ -28,9 +28,11 @@ trait Logging {
}
// Log methods that take only a String
- def logInfo(msg: => String) = if (log.isInfoEnabled) log.info(msg)
+ def logInfo(msg: => String) = if (log.isInfoEnabled /*&& msg.contains("job finished in")*/) log.info(msg)
def logDebug(msg: => String) = if (log.isDebugEnabled) log.debug(msg)
+
+ def logTrace(msg: => String) = if (log.isTraceEnabled) log.trace(msg)
def logWarning(msg: => String) = if (log.isWarnEnabled) log.warn(msg)
@@ -43,6 +45,9 @@ trait Logging {
def logDebug(msg: => String, throwable: Throwable) =
if (log.isDebugEnabled) log.debug(msg)
+ def logTrace(msg: => String, throwable: Throwable) =
+ if (log.isTraceEnabled) log.trace(msg)
+
def logWarning(msg: => String, throwable: Throwable) =
if (log.isWarnEnabled) log.warn(msg, throwable)
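Note: a small usage sketch for the new logTrace overloads; the class below is hypothetical and simply mixes in the spark.Logging trait shown above:

    class BlockFetcher extends spark.Logging {   // hypothetical example class
      def fetch(id: String) {
        logTrace("starting fetch of " + id)
        try {
          // ... perform the fetch ...
        } catch {
          case e: Exception => logTrace("fetch of " + id + " failed", e)
        }
      }
    }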
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index a934c5a02f..d938a6eb62 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -2,80 +2,80 @@ package spark
import java.util.concurrent.ConcurrentHashMap
-import scala.actors._
-import scala.actors.Actor._
-import scala.actors.remote._
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
+import akka.util.duration._
+
import scala.collection.mutable.HashSet
+import spark.storage.BlockManagerId
+
sealed trait MapOutputTrackerMessage
case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
case object StopMapOutputTracker extends MapOutputTrackerMessage
-class MapOutputTrackerActor(serverUris: ConcurrentHashMap[Int, Array[String]])
-extends DaemonActor with Logging {
- def act() {
- val port = System.getProperty("spark.master.port").toInt
- RemoteActor.alive(port)
- RemoteActor.register('MapOutputTracker, self)
- logInfo("Registered actor on port " + port)
-
- loop {
- react {
- case GetMapOutputLocations(shuffleId: Int) =>
- logInfo("Asked to get map output locations for shuffle " + shuffleId)
- reply(serverUris.get(shuffleId))
-
- case StopMapOutputTracker =>
- reply('OK)
- exit()
- }
- }
+class MapOutputTrackerActor(bmAddresses: ConcurrentHashMap[Int, Array[BlockManagerId]])
+extends Actor with Logging {
+ def receive = {
+ case GetMapOutputLocations(shuffleId: Int) =>
+ logInfo("Asked to get map output locations for shuffle " + shuffleId)
+ self.reply(bmAddresses.get(shuffleId))
+
+ case StopMapOutputTracker =>
+ logInfo("MapOutputTrackerActor stopped!")
+ self.reply(true)
+ self.exit()
}
}
class MapOutputTracker(isMaster: Boolean) extends Logging {
- var trackerActor: AbstractActor = null
+ val ip: String = System.getProperty("spark.master.host", "localhost")
+ val port: Int = System.getProperty("spark.master.port", "7077").toInt
+ val aName: String = "MapOutputTracker"
- private var serverUris = new ConcurrentHashMap[Int, Array[String]]
+ private var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
private var generation: Long = 0
private var generationLock = new java.lang.Object
-
- if (isMaster) {
- val tracker = new MapOutputTrackerActor(serverUris)
- tracker.start()
- trackerActor = tracker
+
+ var trackerActor: ActorRef = if (isMaster) {
+ val actor = actorOf(new MapOutputTrackerActor(bmAddresses))
+ remote.register(aName, actor)
+ logInfo("Registered MapOutputTrackerActor actor @ " + ip + ":" + port)
+ actor
} else {
- val host = System.getProperty("spark.master.host")
- val port = System.getProperty("spark.master.port").toInt
- trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
+ remote.actorFor(aName, ip, port)
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (serverUris.get(shuffleId) != null) {
+ if (bmAddresses.get(shuffleId) != null) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
- serverUris.put(shuffleId, new Array[String](numMaps))
+ bmAddresses.put(shuffleId, new Array[BlockManagerId](numMaps))
}
- def registerMapOutput(shuffleId: Int, mapId: Int, serverUri: String) {
- var array = serverUris.get(shuffleId)
+ def registerMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+ var array = bmAddresses.get(shuffleId)
array.synchronized {
- array(mapId) = serverUri
+ array(mapId) = bmAddress
}
}
- def registerMapOutputs(shuffleId: Int, locs: Array[String]) {
- serverUris.put(shuffleId, Array[String]() ++ locs)
+ def registerMapOutputs(shuffleId: Int, locs: Array[BlockManagerId], changeGeneration: Boolean = false) {
+ bmAddresses.put(shuffleId, Array[BlockManagerId]() ++ locs)
+ if (changeGeneration) {
+ incrementGeneration()
+ }
}
- def unregisterMapOutput(shuffleId: Int, mapId: Int, serverUri: String) {
- var array = serverUris.get(shuffleId)
+ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+ var array = bmAddresses.get(shuffleId)
if (array != null) {
array.synchronized {
- if (array(mapId) == serverUri) {
+ if (array(mapId) == bmAddress) {
array(mapId) = null
}
}
@@ -89,10 +89,10 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
val fetching = new HashSet[Int]
// Called on possibly remote nodes to get the server URIs for a given shuffle
- def getServerUris(shuffleId: Int): Array[String] = {
- val locs = serverUris.get(shuffleId)
+ def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = {
+ val locs = bmAddresses.get(shuffleId)
if (locs == null) {
- logInfo("Don't have map outputs for " + shuffleId + ", fetching them")
+      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")

fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done
@@ -103,15 +103,17 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
case _ =>
}
}
- return serverUris.get(shuffleId)
+ return bmAddresses.get(shuffleId)
} else {
fetching += shuffleId
}
}
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
- val fetched = (trackerActor !? GetMapOutputLocations(shuffleId)).asInstanceOf[Array[String]]
- serverUris.put(shuffleId, fetched)
+ val fetched = (trackerActor ? GetMapOutputLocations(shuffleId)).as[Array[BlockManagerId]].get
+
+ logInfo("Got the output locations")
+ bmAddresses.put(shuffleId, fetched)
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
@@ -121,14 +123,10 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
return locs
}
}
-
- def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String = {
- "%s/shuffle/%s/%s/%s".format(serverUri, shuffleId, mapId, reduceId)
- }
def stop() {
- trackerActor !? StopMapOutputTracker
- serverUris.clear()
+ trackerActor !! StopMapOutputTracker
+ bmAddresses.clear()
trackerActor = null
}
@@ -153,7 +151,7 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
- serverUris = new ConcurrentHashMap[Int, Array[String]]
+ bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
generation = newGen
}
}
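Note: a hedged usage sketch of the tracker API above, from the master's point of view; BlockManagerId construction is not shown in this diff, so those calls are left as comments:

    val tracker = new MapOutputTracker(isMaster = true)
    tracker.registerShuffle(shuffleId = 0, numMaps = 2)
    // As each map task finishes, the scheduler would record its output location:
    //   tracker.registerMapOutput(0, mapId, bmAddress)      // bmAddress: BlockManagerId
    // Reduce tasks (possibly on other nodes) then resolve the locations with:
    //   val addresses = tracker.getServerAddresses(0)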
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 8b63d1aba1..ff6764e0a2 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -4,14 +4,14 @@ import java.io.EOFException
import java.net.URL
import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong
-import java.util.HashSet
-import java.util.Random
+import java.util.{HashMap => JHashMap}
import java.util.Date
import java.text.SimpleDateFormat
+import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.Map
import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable
@@ -34,7 +34,9 @@ import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
import org.apache.hadoop.mapreduce.TaskAttemptID
import org.apache.hadoop.mapreduce.TaskAttemptContext
-import SparkContext._
+import spark.SparkContext._
+import spark.partial.BoundedDouble
+import spark.partial.PartialResult
/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -43,19 +45,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
self: RDD[(K, V)])
extends Logging
with Serializable {
-
- def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = {
- def mergeMaps(m1: HashMap[K, V], m2: HashMap[K, V]): HashMap[K, V] = {
- for ((k, v) <- m2) {
- m1.get(k) match {
- case None => m1(k) = v
- case Some(w) => m1(k) = func(w, v)
- }
- }
- return m1
- }
- self.map(pair => HashMap(pair)).reduce(mergeMaps)
- }
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
@@ -77,6 +66,39 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = {
combineByKey[V]((v: V) => v, func, func, numSplits)
}
+
+ def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
+ def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
+ val map = new JHashMap[K, V]
+ for ((k, v) <- iter) {
+ val old = map.get(k)
+ map.put(k, if (old == null) v else func(old, v))
+ }
+ Iterator(map)
+ }
+
+ def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = {
+ for ((k, v) <- m2) {
+ val old = m1.get(k)
+ m1.put(k, if (old == null) v else func(old, v))
+ }
+ return m1
+ }
+
+ self.mapPartitions(reducePartition).reduce(mergeMaps)
+ }
+
+ // Alias for backwards compatibility
+ def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func)
+
+ // TODO: This should probably be a distributed version
+ def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
+
+ // TODO: This should probably be a distributed version
+ def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
+ : PartialResult[Map[K, BoundedDouble]] = {
+ self.map(_._1).countByValueApprox(timeout, confidence)
+ }
def groupByKey(numSplits: Int): RDD[(K, Seq[V])] = {
def createCombiner(v: V) = ArrayBuffer(v)
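Note: a usage sketch for the locally-reducing helpers added above; `sc` is an assumed, already-constructed SparkContext and the data is illustrative:

    import spark.SparkContext._

    val pairs  = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
    val sums   = pairs.reduceByKeyLocally(_ + _)   // Map(a -> 4, b -> 2), computed on the driver
    val counts = pairs.countByKey()                // Map(a -> 2, b -> 1)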
diff --git a/core/src/main/scala/spark/ParallelShuffleFetcher.scala b/core/src/main/scala/spark/ParallelShuffleFetcher.scala
deleted file mode 100644
index 19eb288e84..0000000000
--- a/core/src/main/scala/spark/ParallelShuffleFetcher.scala
+++ /dev/null
@@ -1,119 +0,0 @@
-package spark
-
-import java.io.ByteArrayInputStream
-import java.io.EOFException
-import java.net.URL
-import java.util.concurrent.LinkedBlockingQueue
-import java.util.concurrent.TimeUnit
-import java.util.concurrent.atomic.AtomicBoolean
-import java.util.concurrent.atomic.AtomicInteger
-import java.util.concurrent.atomic.AtomicReference
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-
-import it.unimi.dsi.fastutil.io.FastBufferedInputStream
-
-
-class ParallelShuffleFetcher extends ShuffleFetcher with Logging {
- val parallelFetches = System.getProperty("spark.parallel.fetches", "3").toInt
-
- def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
- logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-
- // Figure out a list of input IDs (mapper IDs) for each server
- val ser = SparkEnv.get.serializer.newInstance()
- val inputsByUri = new HashMap[String, ArrayBuffer[Int]]
- val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
- for ((serverUri, index) <- serverUris.zipWithIndex) {
- inputsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
- }
-
- // Randomize them and put them in a LinkedBlockingQueue
- val serverQueue = new LinkedBlockingQueue[(String, ArrayBuffer[Int])]
- for (pair <- Utils.randomize(inputsByUri)) {
- serverQueue.put(pair)
- }
-
- // Create a queue to hold the fetched data
- val resultQueue = new LinkedBlockingQueue[Array[Byte]]
-
- // Atomic variables to communicate failures and # of fetches done
- var failure = new AtomicReference[FetchFailedException](null)
-
- // Start multiple threads to do the fetching (TODO: may be possible to do it asynchronously)
- for (i <- 0 until parallelFetches) {
- new Thread("Fetch thread " + i + " for reduce " + reduceId) {
- override def run() {
- while (true) {
- val pair = serverQueue.poll()
- if (pair == null)
- return
- val (serverUri, inputIds) = pair
- //logInfo("Pulled out server URI " + serverUri)
- for (i <- inputIds) {
- if (failure.get != null)
- return
- val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
- logInfo("Starting HTTP request for " + url)
- try {
- val conn = new URL(url).openConnection()
- conn.connect()
- val len = conn.getContentLength()
- if (len == -1) {
- throw new SparkException("Content length was not specified by server")
- }
- val buf = new Array[Byte](len)
- val in = new FastBufferedInputStream(conn.getInputStream())
- var pos = 0
- while (pos < len) {
- val n = in.read(buf, pos, len-pos)
- if (n == -1) {
- throw new SparkException("EOF before reading the expected " + len + " bytes")
- } else {
- pos += n
- }
- }
- // Done reading everything
- resultQueue.put(buf)
- in.close()
- } catch {
- case e: Exception =>
- logError("Fetch failed from " + url, e)
- failure.set(new FetchFailedException(serverUri, shuffleId, i, reduceId, e))
- return
- }
- }
- //logInfo("Done with server URI " + serverUri)
- }
- }
- }.start()
- }
-
- // Wait for results from the threads (either a failure or all servers done)
- var resultsDone = 0
- var totalResults = inputsByUri.map{case (uri, inputs) => inputs.size}.sum
- while (failure.get == null && resultsDone < totalResults) {
- try {
- val result = resultQueue.poll(100, TimeUnit.MILLISECONDS)
- if (result != null) {
- //logInfo("Pulled out a result")
- val in = ser.inputStream(new ByteArrayInputStream(result))
- try {
- while (true) {
- val pair = in.readObject().asInstanceOf[(K, V)]
- func(pair._1, pair._2)
- }
- } catch {
- case e: EOFException => {} // TODO: cleaner way to detect EOF, such as a sentinel
- }
- resultsDone += 1
- //logInfo("Results done = " + resultsDone)
- }
- } catch { case e: InterruptedException => {} }
- }
- if (failure.get != null) {
- throw failure.get
- }
- }
-}
diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index ac61fe3b54..8f3f0f5e15 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -70,4 +70,3 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
false
}
}
-
diff --git a/core/src/main/scala/spark/PipedRDD.scala b/core/src/main/scala/spark/PipedRDD.scala
index 8a5de3d7e9..9e0a01b5f9 100644
--- a/core/src/main/scala/spark/PipedRDD.scala
+++ b/core/src/main/scala/spark/PipedRDD.scala
@@ -3,6 +3,7 @@ package spark
import java.io.PrintWriter
import java.util.StringTokenizer
+import scala.collection.Map
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index fa53d9be2c..22dcc27bad 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -4,11 +4,14 @@ import java.io.EOFException
import java.net.URL
import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong
-import java.util.HashSet
import java.util.Random
import java.util.Date
+import java.util.{HashMap => JHashMap}
import scala.collection.mutable.ArrayBuffer
+import scala.collection.Map
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions.mapAsScalaMap
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
@@ -22,6 +25,14 @@ import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.TextOutputFormat
+import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+
+import spark.partial.BoundedDouble
+import spark.partial.CountEvaluator
+import spark.partial.GroupedCountEvaluator
+import spark.partial.PartialResult
+import spark.storage.StorageLevel
+
import SparkContext._
/**
@@ -61,19 +72,32 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
// Get a unique ID for this RDD
val id = sc.newRddId()
- // Variables relating to caching
- private var shouldCache = false
+ // Variables relating to persistence
+ private var storageLevel: StorageLevel = StorageLevel.NONE
- // Change this RDD's caching
- def cache(): RDD[T] = {
- shouldCache = true
+ // Change this RDD's storage level
+ def persist(newLevel: StorageLevel): RDD[T] = {
+ // TODO: Handle changes of StorageLevel
+ if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) {
+ throw new UnsupportedOperationException(
+ "Cannot change storage level of an RDD after it was already assigned a level")
+ }
+ storageLevel = newLevel
this
}
+
+ // Turn on the default caching level for this RDD
+ def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY_DESER)
+
+ // Turn on the default caching level for this RDD
+ def cache(): RDD[T] = persist()
+
+ def getStorageLevel = storageLevel
// Read this RDD; will read from cache if applicable, or otherwise compute
final def iterator(split: Split): Iterator[T] = {
- if (shouldCache) {
- SparkEnv.get.cacheTracker.getOrCompute[T](this, split)
+ if (storageLevel != StorageLevel.NONE) {
+ SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel)
} else {
compute(split)
}
@@ -162,6 +186,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
Array.concat(results: _*)
}
+ def toArray(): Array[T] = collect()
+
def reduce(f: (T, T) => T): T = {
val cleanF = sc.clean(f)
val reducePartition: Iterator[T] => Option[T] = iter => {
@@ -222,7 +248,67 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
}).sum
}
- def toArray(): Array[T] = collect()
+ /**
+ * Approximate version of count() that returns a potentially incomplete result after a timeout.
+ */
+ def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+ val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) =>
+ var result = 0L
+ while (iter.hasNext) {
+ result += 1L
+ iter.next
+ }
+ result
+ }
+ val evaluator = new CountEvaluator(splits.size, confidence)
+ sc.runApproximateJob(this, countElements, evaluator, timeout)
+ }
+
+ /**
+   * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
+   * combine step happens locally on the master, equivalent to running a single reduce task.
+ *
+ * TODO: This should perhaps be distributed by default.
+ */
+ def countByValue(): Map[T, Long] = {
+ def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = {
+ val map = new OLMap[T]
+ while (iter.hasNext) {
+ val v = iter.next()
+ map.put(v, map.getLong(v) + 1L)
+ }
+ Iterator(map)
+ }
+ def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = {
+ val iter = m2.object2LongEntrySet.fastIterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue)
+ }
+ return m1
+ }
+ val myResult = mapPartitions(countPartition).reduce(mergeMaps)
+ myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map
+ }
+
+ /**
+ * Approximate version of countByValue().
+ */
+ def countByValueApprox(
+ timeout: Long,
+ confidence: Double = 0.95
+ ): PartialResult[Map[T, BoundedDouble]] = {
+ val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
+ val map = new OLMap[T]
+ while (iter.hasNext) {
+ val v = iter.next()
+ map.put(v, map.getLong(v) + 1L)
+ }
+ map
+ }
+ val evaluator = new GroupedCountEvaluator[T](splits.size, confidence)
+ sc.runApproximateJob(this, countPartition, evaluator, timeout)
+ }
/**
* Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
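Note: a usage sketch of persist() and the new counting operations; `sc` is an assumed SparkContext, and MEMORY_ONLY_DESER is the default level named in this diff:

    import spark.SparkContext._
    import spark.storage.StorageLevel

    val nums   = sc.parallelize(1 to 100000).persist(StorageLevel.MEMORY_ONLY_DESER)
    val exact  = nums.count()
    val approx = nums.countApprox(timeout = 1000, confidence = 0.90)  // PartialResult[BoundedDouble]
    val byMod  = nums.map(_ % 10).countByValue()                      // Map[Int, Long]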
diff --git a/core/src/main/scala/spark/Scheduler.scala b/core/src/main/scala/spark/Scheduler.scala
deleted file mode 100644
index 6c7e569313..0000000000
--- a/core/src/main/scala/spark/Scheduler.scala
+++ /dev/null
@@ -1,27 +0,0 @@
-package spark
-
-/**
- * Scheduler trait, implemented by both MesosScheduler and LocalScheduler.
- */
-private trait Scheduler {
- def start()
-
- // Wait for registration with Mesos.
- def waitForRegister()
-
- /**
- * Run a function on some partitions of an RDD, returning an array of results. The allowLocal
- * flag specifies whether the scheduler is allowed to run the job on the master machine rather
- * than shipping it to the cluster, for actions that create short jobs such as first() and take().
- */
- def runJob[T, U: ClassManifest](
- rdd: RDD[T],
- func: (TaskContext, Iterator[T]) => U,
- partitions: Seq[Int],
- allowLocal: Boolean): Array[U]
-
- def stop()
-
- // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
- def defaultParallelism(): Int
-}
diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
index b213ca9dcb..9da73c4b02 100644
--- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
@@ -44,7 +44,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
}
// TODO: use something like WritableConverter to avoid reflection
}
- c.asInstanceOf[Class[ _ <: Writable]]
+ c.asInstanceOf[Class[_ <: Writable]]
}
def saveAsSequenceFile(path: String) {
diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala
index 2429bbfeb9..61a70beaf1 100644
--- a/core/src/main/scala/spark/Serializer.scala
+++ b/core/src/main/scala/spark/Serializer.scala
@@ -1,6 +1,12 @@
package spark
-import java.io.{InputStream, OutputStream}
+import java.io.{EOFException, InputStream, OutputStream}
+import java.nio.ByteBuffer
+import java.nio.channels.Channels
+
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
+import spark.util.ByteBufferInputStream
/**
* A serializer. Because some serialization libraries are not thread safe, this class is used to
@@ -14,11 +20,31 @@ trait Serializer {
* An instance of the serializer, for use by one thread at a time.
*/
trait SerializerInstance {
- def serialize[T](t: T): Array[Byte]
- def deserialize[T](bytes: Array[Byte]): T
- def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T
- def outputStream(s: OutputStream): SerializationStream
- def inputStream(s: InputStream): DeserializationStream
+ def serialize[T](t: T): ByteBuffer
+
+ def deserialize[T](bytes: ByteBuffer): T
+
+ def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T
+
+ def serializeStream(s: OutputStream): SerializationStream
+
+ def deserializeStream(s: InputStream): DeserializationStream
+
+ def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
+ // Default implementation uses serializeStream
+ val stream = new FastByteArrayOutputStream()
+ serializeStream(stream).writeAll(iterator)
+ val buffer = ByteBuffer.allocate(stream.position.toInt)
+ buffer.put(stream.array, 0, stream.position.toInt)
+ buffer.flip()
+ buffer
+ }
+
+ def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
+ // Default implementation uses deserializeStream
+ buffer.rewind()
+ deserializeStream(new ByteBufferInputStream(buffer)).toIterator
+ }
}
/**
@@ -28,6 +54,13 @@ trait SerializationStream {
def writeObject[T](t: T): Unit
def flush(): Unit
def close(): Unit
+
+ def writeAll[T](iter: Iterator[T]): SerializationStream = {
+ while (iter.hasNext) {
+ writeObject(iter.next())
+ }
+ this
+ }
}
/**
@@ -36,4 +69,45 @@ trait SerializationStream {
trait DeserializationStream {
def readObject[T](): T
def close(): Unit
+
+ /**
+ * Read the elements of this stream through an iterator. This can only be called once, as
+ * reading each element will consume data from the input source.
+ */
+ def toIterator: Iterator[Any] = new Iterator[Any] {
+ var gotNext = false
+ var finished = false
+ var nextValue: Any = null
+
+ private def getNext() {
+ try {
+ nextValue = readObject[Any]()
+ } catch {
+ case eof: EOFException =>
+ finished = true
+ }
+ gotNext = true
+ }
+
+ override def hasNext: Boolean = {
+ if (!gotNext) {
+ getNext()
+ }
+ if (finished) {
+ close()
+ }
+ !finished
+ }
+
+ override def next(): Any = {
+ if (!gotNext) {
+ getNext()
+ }
+ if (finished) {
+ throw new NoSuchElementException("End of stream")
+ }
+ gotNext = false
+ nextValue
+ }
+ }
}
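Note: a round-trip sketch for the SerializerInstance API above, assuming the existing JavaSerializer (referenced later in this commit) implements it:

    val instance = new spark.JavaSerializer().newInstance()
    val buffer   = instance.serializeMany(Iterator("a", "b", "c"))
    val back     = instance.deserializeMany(buffer).toList   // List("a", "b", "c"), typed as Any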
diff --git a/core/src/main/scala/spark/SerializingCache.scala b/core/src/main/scala/spark/SerializingCache.scala
deleted file mode 100644
index 3d192f2403..0000000000
--- a/core/src/main/scala/spark/SerializingCache.scala
+++ /dev/null
@@ -1,26 +0,0 @@
-package spark
-
-import java.io._
-
-/**
- * Wrapper around a BoundedMemoryCache that stores serialized objects as byte arrays in order to
- * reduce storage cost and GC overhead
- */
-class SerializingCache extends Cache with Logging {
- val bmc = new BoundedMemoryCache
-
- override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
- val ser = SparkEnv.get.serializer.newInstance()
- bmc.put(datasetId, partition, ser.serialize(value))
- }
-
- override def get(datasetId: Any, partition: Int): Any = {
- val bytes = bmc.get(datasetId, partition)
- if (bytes != null) {
- val ser = SparkEnv.get.serializer.newInstance()
- return ser.deserialize(bytes.asInstanceOf[Array[Byte]])
- } else {
- return null
- }
- }
-}
diff --git a/core/src/main/scala/spark/ShuffleMapTask.scala b/core/src/main/scala/spark/ShuffleMapTask.scala
deleted file mode 100644
index 5fc59af06c..0000000000
--- a/core/src/main/scala/spark/ShuffleMapTask.scala
+++ /dev/null
@@ -1,56 +0,0 @@
-package spark
-
-import java.io.BufferedOutputStream
-import java.io.FileOutputStream
-import java.io.ObjectOutputStream
-import java.util.{HashMap => JHashMap}
-
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
-class ShuffleMapTask(
- runId: Int,
- stageId: Int,
- rdd: RDD[_],
- dep: ShuffleDependency[_,_,_],
- val partition: Int,
- locs: Seq[String])
- extends DAGTask[String](runId, stageId)
- with Logging {
-
- val split = rdd.splits(partition)
-
- override def run (attemptId: Int): String = {
- val numOutputSplits = dep.partitioner.numPartitions
- val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
- val partitioner = dep.partitioner.asInstanceOf[Partitioner]
- val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
- for (elem <- rdd.iterator(split)) {
- val (k, v) = elem.asInstanceOf[(Any, Any)]
- var bucketId = partitioner.getPartition(k)
- val bucket = buckets(bucketId)
- var existing = bucket.get(k)
- if (existing == null) {
- bucket.put(k, aggregator.createCombiner(v))
- } else {
- bucket.put(k, aggregator.mergeValue(existing, v))
- }
- }
- val ser = SparkEnv.get.serializer.newInstance()
- for (i <- 0 until numOutputSplits) {
- val file = SparkEnv.get.shuffleManager.getOutputFile(dep.shuffleId, partition, i)
- val out = ser.outputStream(new FastBufferedOutputStream(new FileOutputStream(file)))
- val iter = buckets(i).entrySet().iterator()
- while (iter.hasNext()) {
- val entry = iter.next()
- out.writeObject((entry.getKey, entry.getValue))
- }
- // TODO: have some kind of EOF marker
- out.close()
- }
- return SparkEnv.get.shuffleManager.getServerUri
- }
-
- override def preferredLocations: Seq[String] = locs
-
- override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
-}
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index 5efc8cf50b..5434197eca 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -8,7 +8,7 @@ class ShuffledRDDSplit(val idx: Int) extends Split {
}
class ShuffledRDD[K, V, C](
- parent: RDD[(K, V)],
+ @transient parent: RDD[(K, V)],
aggregator: Aggregator[K, V, C],
part : Partitioner)
extends RDD[(K, C)](parent.context) {
diff --git a/core/src/main/scala/spark/SimpleShuffleFetcher.scala b/core/src/main/scala/spark/SimpleShuffleFetcher.scala
deleted file mode 100644
index 196c64cf1f..0000000000
--- a/core/src/main/scala/spark/SimpleShuffleFetcher.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-package spark
-
-import java.io.EOFException
-import java.net.URL
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-
-import it.unimi.dsi.fastutil.io.FastBufferedInputStream
-
-class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
- def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
- logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
- val ser = SparkEnv.get.serializer.newInstance()
- val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
- val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
- for ((serverUri, index) <- serverUris.zipWithIndex) {
- splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
- }
- for ((serverUri, inputIds) <- Utils.randomize(splitsByUri)) {
- for (i <- inputIds) {
- try {
- val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
- // TODO: multithreaded fetch
- // TODO: would be nice to retry multiple times
- val inputStream = ser.inputStream(
- new FastBufferedInputStream(new URL(url).openStream()))
- try {
- while (true) {
- val pair = inputStream.readObject().asInstanceOf[(K, V)]
- func(pair._1, pair._2)
- }
- } finally {
- inputStream.close()
- }
- } catch {
- case e: EOFException => {} // We currently assume EOF means we read the whole thing
- case other: Exception => {
- logError("Fetch failed", other)
- throw new FetchFailedException(serverUri, shuffleId, i, reduceId, other)
- }
- }
- }
- }
- }
-}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 6e019d6e7f..7a9a70fee0 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -3,6 +3,9 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
+import akka.actor.Actor
+import akka.actor.Actor._
+
import scala.actors.remote.RemoteActor
import scala.collection.mutable.ArrayBuffer
@@ -32,6 +35,15 @@ import org.apache.mesos.MesosNativeLibrary
import spark.broadcast._
+import spark.partial.ApproximateEvaluator
+import spark.partial.PartialResult
+
+import spark.scheduler.DAGScheduler
+import spark.scheduler.TaskScheduler
+import spark.scheduler.local.LocalScheduler
+import spark.scheduler.mesos.MesosScheduler
+import spark.scheduler.mesos.CoarseMesosScheduler
+
class SparkContext(
master: String,
frameworkName: String,
@@ -54,14 +66,19 @@ class SparkContext(
if (RemoteActor.classLoader == null) {
RemoteActor.classLoader = getClass.getClassLoader
}
+
+ remote.start(System.getProperty("spark.master.host"),
+ System.getProperty("spark.master.port").toInt)
+ private val isLocal = master.startsWith("local") // TODO: better check for local
+
// Create the Spark execution environment (cache, map output tracker, etc)
- val env = SparkEnv.createFromSystemProperties(true)
+ val env = SparkEnv.createFromSystemProperties(true, isLocal)
SparkEnv.set(env)
Broadcast.initialize(true)
// Create and start the scheduler
- private var scheduler: Scheduler = {
+ private var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
@@ -74,13 +91,17 @@ class SparkContext(
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
new LocalScheduler(threads.toInt, maxFailures.toInt)
case _ =>
- MesosNativeLibrary.load()
- new MesosScheduler(this, master, frameworkName)
+ System.loadLibrary("mesos")
+ if (System.getProperty("spark.mesos.coarse", "false") == "true") {
+ new CoarseMesosScheduler(this, master, frameworkName)
+ } else {
+ new MesosScheduler(this, master, frameworkName)
+ }
}
}
- scheduler.start()
+ taskScheduler.start()
- private val isLocal = scheduler.isInstanceOf[LocalScheduler]
+ private var dagScheduler = new DAGScheduler(taskScheduler)
// Methods for creating RDDs
@@ -237,19 +258,21 @@ class SparkContext(
// Stop the SparkContext
def stop() {
- scheduler.stop()
- scheduler = null
+ dagScheduler.stop()
+ dagScheduler = null
+ taskScheduler = null
// TODO: Broadcast.stop(), Cache.stop()?
env.mapOutputTracker.stop()
env.cacheTracker.stop()
env.shuffleFetcher.stop()
env.shuffleManager.stop()
+ env.connectionManager.stop()
SparkEnv.set(null)
}
- // Wait for the scheduler to be registered
+ // Wait for the scheduler to be registered with the cluster manager
def waitForRegister() {
- scheduler.waitForRegister()
+ taskScheduler.waitForRegister()
}
// Get Spark's home location from either a value set through the constructor,
@@ -281,7 +304,7 @@ class SparkContext(
): Array[U] = {
logInfo("Starting job...")
val start = System.nanoTime
- val result = scheduler.runJob(rdd, func, partitions, allowLocal)
+ val result = dagScheduler.runJob(rdd, func, partitions, allowLocal)
logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
result
}
@@ -306,6 +329,22 @@ class SparkContext(
runJob(rdd, func, 0 until rdd.splits.size, false)
}
+ /**
+ * Run a job that can return approximate results.
+ */
+ def runApproximateJob[T, U, R](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ evaluator: ApproximateEvaluator[U, R],
+ timeout: Long
+ ): PartialResult[R] = {
+ logInfo("Starting job...")
+ val start = System.nanoTime
+ val result = dagScheduler.runApproximateJob(rdd, func, evaluator, timeout)
+ logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
+ result
+ }
+
// Clean a closure to make it ready to serialized and send to tasks
// (removes unreferenced variables in $outer's, updates REPL variables)
private[spark] def clean[F <: AnyRef](f: F): F = {
@@ -314,7 +353,7 @@ class SparkContext(
}
// Default level of parallelism to use when not given by user (e.g. for reduce tasks)
- def defaultParallelism: Int = scheduler.defaultParallelism
+ def defaultParallelism: Int = taskScheduler.defaultParallelism
// Default min number of splits for Hadoop RDDs when not given by user
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
@@ -349,15 +388,23 @@ object SparkContext {
}
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
+
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
-
- implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](rdd: RDD[(K, V)]) =
+
+ implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
+ rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
- implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
+ implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
+ rdd: RDD[(K, V)]) =
new OrderedRDDFunctions(rdd)
+ implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd)
+
+ implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
+ new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
+
// Implicit conversions to common Writable types, for saveAsSequenceFile
implicit def intToIntWritable(i: Int) = new IntWritable(i)
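Note: the scheduler selection above is controlled by a system property; a hedged configuration sketch (the property must be set before the SparkContext is constructed, and the value shown is illustrative):

    System.setProperty("spark.mesos.coarse", "true")   // selects CoarseMesosScheduler above
    // Leave it unset (or "false") to get the fine-grained MesosScheduler, then
    // create the SparkContext with a mesos:// master URL as usual.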
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index cd752f8b65..897a5ef82d 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -1,14 +1,26 @@
package spark
+import akka.actor.Actor
+
+import spark.storage.BlockManager
+import spark.storage.BlockManagerMaster
+import spark.network.ConnectionManager
+
class SparkEnv (
- val cache: Cache,
- val serializer: Serializer,
- val closureSerializer: Serializer,
- val cacheTracker: CacheTracker,
- val mapOutputTracker: MapOutputTracker,
- val shuffleFetcher: ShuffleFetcher,
- val shuffleManager: ShuffleManager
-)
+ val cache: Cache,
+ val serializer: Serializer,
+ val closureSerializer: Serializer,
+ val cacheTracker: CacheTracker,
+ val mapOutputTracker: MapOutputTracker,
+ val shuffleFetcher: ShuffleFetcher,
+ val shuffleManager: ShuffleManager,
+ val blockManager: BlockManager,
+ val connectionManager: ConnectionManager
+ ) {
+
+ /** No-parameter constructor for unit tests. */
+ def this() = this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null)
+}
object SparkEnv {
private val env = new ThreadLocal[SparkEnv]
@@ -21,36 +33,55 @@ object SparkEnv {
env.get()
}
- def createFromSystemProperties(isMaster: Boolean): SparkEnv = {
- val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
- val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
-
- val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
+ def createFromSystemProperties(isMaster: Boolean, isLocal: Boolean): SparkEnv = {
+ val serializerClass = System.getProperty("spark.serializer", "spark.KryoSerializer")
val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
+
+ BlockManagerMaster.startBlockManagerMaster(isMaster, isLocal)
+
+ var blockManager = new BlockManager(serializer)
+
+ val connectionManager = blockManager.connectionManager
+
+ val shuffleManager = new ShuffleManager()
val closureSerializerClass =
System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
val closureSerializer =
Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer]
+ val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
+ val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
- val cacheTracker = new CacheTracker(isMaster, cache)
+ val cacheTracker = new CacheTracker(isMaster, blockManager)
+ blockManager.cacheTracker = cacheTracker
val mapOutputTracker = new MapOutputTracker(isMaster)
val shuffleFetcherClass =
- System.getProperty("spark.shuffle.fetcher", "spark.SimpleShuffleFetcher")
+ System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
val shuffleFetcher =
Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher]
- val shuffleMgr = new ShuffleManager()
+ /*
+ if (System.getProperty("spark.stream.distributed", "false") == "true") {
+ val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]]
+ if (isLocal || !isMaster) {
+ (new Thread() {
+ override def run() {
+ println("Wait started")
+ Thread.sleep(60000)
+ println("Wait ended")
+ val receiverClass = Class.forName("spark.stream.TestStreamReceiver4")
+ val constructor = receiverClass.getConstructor(blockManagerClass)
+ val receiver = constructor.newInstance(blockManager)
+ receiver.asInstanceOf[Thread].start()
+ }
+ }).start()
+ }
+ }
+ */
- new SparkEnv(
- cache,
- serializer,
- closureSerializer,
- cacheTracker,
- mapOutputTracker,
- shuffleFetcher,
- shuffleMgr)
+ new SparkEnv(cache, serializer, closureSerializer, cacheTracker, mapOutputTracker, shuffleFetcher,
+ shuffleManager, blockManager, connectionManager)
}
}
diff --git a/core/src/main/scala/spark/Stage.scala b/core/src/main/scala/spark/Stage.scala
deleted file mode 100644
index 9452ea3a8e..0000000000
--- a/core/src/main/scala/spark/Stage.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-package spark
-
-class Stage(
- val id: Int,
- val rdd: RDD[_],
- val shuffleDep: Option[ShuffleDependency[_,_,_]],
- val parents: List[Stage]) {
-
- val isShuffleMap = shuffleDep != None
- val numPartitions = rdd.splits.size
- val outputLocs = Array.fill[List[String]](numPartitions)(Nil)
- var numAvailableOutputs = 0
-
- def isAvailable: Boolean = {
- if (parents.size == 0 && !isShuffleMap) {
- true
- } else {
- numAvailableOutputs == numPartitions
- }
- }
-
- def addOutputLoc(partition: Int, host: String) {
- val prevList = outputLocs(partition)
- outputLocs(partition) = host :: prevList
- if (prevList == Nil)
- numAvailableOutputs += 1
- }
-
- def removeOutputLoc(partition: Int, host: String) {
- val prevList = outputLocs(partition)
- val newList = prevList.filterNot(_ == host)
- outputLocs(partition) = newList
- if (prevList != Nil && newList == Nil) {
- numAvailableOutputs -= 1
- }
- }
-
- override def toString = "Stage " + id
-
- override def hashCode(): Int = id
-}
diff --git a/core/src/main/scala/spark/Task.scala b/core/src/main/scala/spark/Task.scala
deleted file mode 100644
index bc3b374344..0000000000
--- a/core/src/main/scala/spark/Task.scala
+++ /dev/null
@@ -1,9 +0,0 @@
-package spark
-
-class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable
-
-abstract class Task[T] extends Serializable {
- def run(id: Int): T
- def preferredLocations: Seq[String] = Nil
- def generation: Option[Long] = None
-}
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
new file mode 100644
index 0000000000..7a6214aab6
--- /dev/null
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -0,0 +1,3 @@
+package spark
+
+class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable
diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala
new file mode 100644
index 0000000000..6e4eb25ed4
--- /dev/null
+++ b/core/src/main/scala/spark/TaskEndReason.scala
@@ -0,0 +1,16 @@
+package spark
+
+import spark.storage.BlockManagerId
+
+/**
+ * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry
+ * tasks several times for "ephemeral" failures, and only report back failures that require some
+ * old stages to be resubmitted, such as shuffle map fetch failures.
+ */
+sealed trait TaskEndReason
+
+case object Success extends TaskEndReason
+case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
+case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
+case class ExceptionFailure(exception: Throwable) extends TaskEndReason
+case class OtherFailure(message: String) extends TaskEndReason
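Note: a sketch of consuming the sealed trait above; the describe helper is hypothetical and only illustrates an exhaustive match:

    import spark._

    def describe(reason: TaskEndReason): String = reason match {
      case Success             => "task succeeded"
      case Resubmitted         => "task finished earlier but was lost; resubmitted"
      case FetchFailed(bm, shuffleId, mapId, reduceId) =>
        "fetch failed for shuffle " + shuffleId + ", map " + mapId + " served by " + bm
      case ExceptionFailure(e) => "exception: " + e
      case OtherFailure(msg)   => msg
    }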
diff --git a/core/src/main/scala/spark/TaskResult.scala b/core/src/main/scala/spark/TaskResult.scala
deleted file mode 100644
index 2b7fd1a4b2..0000000000
--- a/core/src/main/scala/spark/TaskResult.scala
+++ /dev/null
@@ -1,8 +0,0 @@
-package spark
-
-import scala.collection.mutable.Map
-
-// Task result. Also contains updates to accumulator variables.
-// TODO: Use of distributed cache to return result is a hack to get around
-// what seems to be a bug with messages over 60KB in libprocess; fix it
-private class TaskResult[T](val value: T, val accumUpdates: Map[Long, Any]) extends Serializable
diff --git a/core/src/main/scala/spark/UnionRDD.scala b/core/src/main/scala/spark/UnionRDD.scala
index 4c0f255e6b..17522e2bbb 100644
--- a/core/src/main/scala/spark/UnionRDD.scala
+++ b/core/src/main/scala/spark/UnionRDD.scala
@@ -33,7 +33,8 @@ class UnionRDD[T: ClassManifest](
override def splits = splits_
- @transient override val dependencies = {
+ @transient
+ override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for ((rdd, index) <- rdds.zipWithIndex) {
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index cfd6dc8b2a..742e60b176 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -118,6 +118,23 @@ object Utils {
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
*/
def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress
+
+ private var customHostname: Option[String] = None
+
+ /**
+   * Allow setting a custom hostname because when we run on Mesos we need to use the same
+   * hostname that Mesos reports to the master.
+ */
+ def setCustomHostname(hostname: String) {
+ customHostname = Some(hostname)
+ }
+
+ /**
+ * Get the local machine's hostname
+ */
+ def localHostName(): String = {
+ customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
+ }
/**
* Returns a standard ThreadFactory except all threads are daemons.
@@ -142,6 +159,14 @@ object Utils {
return threadPool
}
+
+ /**
+   * Return a string describing how much time has elapsed since the given start time, in
+   * milliseconds. The start time should also be in milliseconds (e.g. System.currentTimeMillis).
+ */
+ def getUsedTimeMs(startTimeMs: Long): String = {
+ return " " + (System.currentTimeMillis - startTimeMs) + " ms "
+ }
/**
* Wrapper over newFixedThreadPool.
@@ -155,16 +180,6 @@ object Utils {
}
/**
- * Get the local machine's hostname.
- */
- def localHostName(): String = InetAddress.getLocalHost.getHostName
-
- /**
- * Get current host
- */
- def getHost = System.getProperty("spark.hostname", localHostName())
-
- /**
* Delete a file or directory and its contents recursively.
*/
def deleteRecursively(file: File) {
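Note: a small usage sketch for the helpers added above; the hostname and the sleep are purely illustrative:

    spark.Utils.setCustomHostname("worker1.example.com")
    val start = System.currentTimeMillis
    Thread.sleep(50)   // stand-in for real work
    println("ran on " + spark.Utils.localHostName() + ", took" + spark.Utils.getUsedTimeMs(start))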
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
new file mode 100644
index 0000000000..4546dfa0fa
--- /dev/null
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -0,0 +1,364 @@
+package spark.network
+
+import spark._
+
+import scala.collection.mutable.{HashMap, Queue, ArrayBuffer}
+
+import java.io._
+import java.nio._
+import java.nio.channels._
+import java.nio.channels.spi._
+import java.net._
+
+
+abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging {
+
+ channel.configureBlocking(false)
+ channel.socket.setTcpNoDelay(true)
+ channel.socket.setReuseAddress(true)
+ channel.socket.setKeepAlive(true)
+ /*channel.socket.setReceiveBufferSize(32768) */
+
+ var onCloseCallback: Connection => Unit = null
+ var onExceptionCallback: (Connection, Exception) => Unit = null
+ var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
+
+ lazy val remoteAddress = getRemoteAddress()
+ lazy val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
+
+ def key() = channel.keyFor(selector)
+
+ def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
+
+ def read() {
+ throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString)
+ }
+
+ def write() {
+ throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString)
+ }
+
+ def close() {
+ key.cancel()
+ channel.close()
+ callOnCloseCallback()
+ }
+
+ def onClose(callback: Connection => Unit) {onCloseCallback = callback}
+
+ def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback}
+
+ def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback}
+
+ def callOnExceptionCallback(e: Exception) {
+ if (onExceptionCallback != null) {
+ onExceptionCallback(this, e)
+ } else {
+ logError("Error in connection to " + remoteConnectionManagerId +
+ " and OnExceptionCallback not registered", e)
+ }
+ }
+
+ def callOnCloseCallback() {
+ if (onCloseCallback != null) {
+ onCloseCallback(this)
+ } else {
+ logWarning("Connection to " + remoteConnectionManagerId +
+        " closed and onCloseCallback not registered")
+ }
+
+ }
+
+ def changeConnectionKeyInterest(ops: Int) {
+ if (onKeyInterestChangeCallback != null) {
+ onKeyInterestChangeCallback(this, ops)
+ } else {
+ throw new Exception("OnKeyInterestChangeCallback not registered")
+ }
+ }
+
+ def printRemainingBuffer(buffer: ByteBuffer) {
+ val bytes = new Array[Byte](buffer.remaining)
+ val curPosition = buffer.position
+ buffer.get(bytes)
+ bytes.foreach(x => print(x + " "))
+ buffer.position(curPosition)
+ print(" (" + bytes.size + ")")
+ }
+
+ def printBuffer(buffer: ByteBuffer, position: Int, length: Int) {
+ val bytes = new Array[Byte](length)
+ val curPosition = buffer.position
+ buffer.position(position)
+ buffer.get(bytes)
+ bytes.foreach(x => print(x + " "))
+ print(" (" + position + ", " + length + ")")
+ buffer.position(curPosition)
+ }
+
+}
+
+
+class SendingConnection(val address: InetSocketAddress, selector_ : Selector)
+extends Connection(SocketChannel.open, selector_) {
+
+ class Outbox(fair: Int = 0) {
+ val messages = new Queue[Message]()
+ val defaultChunkSize = 65536 //32768 //16384
+ var nextMessageToBeUsed = 0
+
+ def addMessage(message: Message): Unit = {
+ messages.synchronized{
+ /*messages += message*/
+ messages.enqueue(message)
+ logInfo("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]")
+ }
+ }
+
+ def getChunk(): Option[MessageChunk] = {
+ fair match {
+ case 0 => getChunkFIFO()
+ case 1 => getChunkRR()
+ case _ => throw new Exception("Unexpected fairness policy in outbox")
+ }
+ }
+
+ private def getChunkFIFO(): Option[MessageChunk] = {
+ /*logInfo("Using FIFO")*/
+ messages.synchronized {
+ while (!messages.isEmpty) {
+ val message = messages(0)
+ val chunk = message.getChunkForSending(defaultChunkSize)
+ if (chunk.isDefined) {
+            messages += message // this is probably incorrect; it won't behave as FIFO
+ if (!message.started) logDebug("Starting to send [" + message + "]")
+ message.started = true
+ return chunk
+ }
+ /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
+ logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
+ }
+ }
+ None
+ }
+
+ private def getChunkRR(): Option[MessageChunk] = {
+ messages.synchronized {
+ while (!messages.isEmpty) {
+ /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
+ /*val message = messages(nextMessageToBeUsed)*/
+ val message = messages.dequeue
+ val chunk = message.getChunkForSending(defaultChunkSize)
+ if (chunk.isDefined) {
+ messages.enqueue(message)
+ nextMessageToBeUsed = nextMessageToBeUsed + 1
+ if (!message.started) {
+ logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]")
+ message.started = true
+ message.startTime = System.currentTimeMillis
+ }
+ logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
+ return chunk
+ }
+ /*messages -= message*/
+ message.finishTime = System.currentTimeMillis
+ logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
+ }
+ }
+ None
+ }
+ }
+
+ val outbox = new Outbox(1)
+ val currentBuffers = new ArrayBuffer[ByteBuffer]()
+
+ /*channel.socket.setSendBufferSize(256 * 1024)*/
+
+ override def getRemoteAddress() = address
+
+ def send(message: Message) {
+ outbox.synchronized {
+ outbox.addMessage(message)
+ if (channel.isConnected) {
+ changeConnectionKeyInterest(SelectionKey.OP_WRITE)
+ }
+ }
+ }
+
+ def connect() {
+ try{
+ channel.connect(address)
+ channel.register(selector, SelectionKey.OP_CONNECT)
+ logInfo("Initiating connection to [" + address + "]")
+ } catch {
+ case e: Exception => {
+ logError("Error connecting to " + address, e)
+ callOnExceptionCallback(e)
+ }
+ }
+ }
+
+ def finishConnect() {
+ try {
+ channel.finishConnect
+ changeConnectionKeyInterest(SelectionKey.OP_WRITE)
+ logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
+ } catch {
+ case e: Exception => {
+ logWarning("Error finishing connection to " + address, e)
+ callOnExceptionCallback(e)
+ }
+ }
+ }
+
+ override def write() {
+ try{
+ while(true) {
+ if (currentBuffers.size == 0) {
+ outbox.synchronized {
+ outbox.getChunk match {
+ case Some(chunk) => {
+ currentBuffers ++= chunk.buffers
+ }
+ case None => {
+ changeConnectionKeyInterest(0)
+ /*key.interestOps(0)*/
+ return
+ }
+ }
+ }
+ }
+
+ if (currentBuffers.size > 0) {
+ val buffer = currentBuffers(0)
+ val remainingBytes = buffer.remaining
+ val writtenBytes = channel.write(buffer)
+ if (buffer.remaining == 0) {
+ currentBuffers -= buffer
+ }
+ if (writtenBytes < remainingBytes) {
+ return
+ }
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logWarning("Error writing in connection to " + remoteConnectionManagerId, e)
+ callOnExceptionCallback(e)
+ close()
+ }
+ }
+ }
+}
+
+
+class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
+extends Connection(channel_, selector_) {
+
+ class Inbox() {
+ val messages = new HashMap[Int, BufferMessage]()
+
+ def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
+
+ def createNewMessage: BufferMessage = {
+ val newMessage = Message.create(header).asInstanceOf[BufferMessage]
+ newMessage.started = true
+ newMessage.startTime = System.currentTimeMillis
+ logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]")
+ messages += ((newMessage.id, newMessage))
+ newMessage
+ }
+
+ val message = messages.getOrElseUpdate(header.id, createNewMessage)
+ logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]")
+ message.getChunkForReceiving(header.chunkSize)
+ }
+
+ def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
+ messages.get(chunk.header.id)
+ }
+
+ def removeMessage(message: Message) {
+ messages -= message.id
+ }
+ }
+
+ val inbox = new Inbox()
+ val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
+ var onReceiveCallback: (Connection , Message) => Unit = null
+ var currentChunk: MessageChunk = null
+
+ channel.register(selector, SelectionKey.OP_READ)
+
+ override def read() {
+ try {
+ while (true) {
+ if (currentChunk == null) {
+ val headerBytesRead = channel.read(headerBuffer)
+ if (headerBytesRead == -1) {
+ close()
+ return
+ }
+ if (headerBuffer.remaining > 0) {
+ return
+ }
+ headerBuffer.flip
+ if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
+ throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
+ }
+ val header = MessageChunkHeader.create(headerBuffer)
+ headerBuffer.clear()
+ header.typ match {
+ case Message.BUFFER_MESSAGE => {
+ if (header.totalSize == 0) {
+ if (onReceiveCallback != null) {
+ onReceiveCallback(this, Message.create(header))
+ }
+ currentChunk = null
+ return
+ } else {
+ currentChunk = inbox.getChunk(header).orNull
+ }
+ }
+ case _ => throw new Exception("Message of unknown type received")
+ }
+ }
+
+ if (currentChunk == null) throw new Exception("No message chunk to receive data")
+
+ val bytesRead = channel.read(currentChunk.buffer)
+ if (bytesRead == 0) {
+ return
+ } else if (bytesRead == -1) {
+ close()
+ return
+ }
+
+ /*logDebug("Read " + bytesRead + " bytes for the buffer")*/
+
+ if (currentChunk.buffer.remaining == 0) {
+ /*println("Filled buffer at " + System.currentTimeMillis)*/
+ val bufferMessage = inbox.getMessageForChunk(currentChunk).get
+ if (bufferMessage.isCompletelyReceived) {
+ bufferMessage.flip
+ bufferMessage.finishTime = System.currentTimeMillis
+ logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken)
+ if (onReceiveCallback != null) {
+ onReceiveCallback(this, bufferMessage)
+ }
+ inbox.removeMessage(bufferMessage)
+ }
+ currentChunk = null
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logWarning("Error reading from connection to " + remoteConnectionManagerId, e)
+ callOnExceptionCallback(e)
+ close()
+ }
+ }
+ }
+
+ def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
+}
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
new file mode 100644
index 0000000000..e9f254d0f3
--- /dev/null
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -0,0 +1,467 @@
+package spark.network
+
+import spark._
+
+import scala.actors.Future
+import scala.actors.Futures.future
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.SynchronizedMap
+import scala.collection.mutable.SynchronizedQueue
+import scala.collection.mutable.Queue
+import scala.collection.mutable.ArrayBuffer
+
+import java.io._
+import java.nio._
+import java.nio.channels._
+import java.nio.channels.spi._
+import java.net._
+import java.util.concurrent.Executors
+
+case class ConnectionManagerId(val host: String, val port: Int) {
+ def toSocketAddress() = new InetSocketAddress(host, port)
+}
+
+object ConnectionManagerId {
+ def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
+ new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
+ }
+}
+
+class ConnectionManager(port: Int) extends Logging {
+
+ case class MessageStatus(message: Message, connectionManagerId: ConnectionManagerId) {
+ var ackMessage: Option[Message] = None
+ var attempted = false
+ var acked = false
+ }
+
+ val selector = SelectorProvider.provider.openSelector()
+ /*val handleMessageExecutor = new ThreadPoolExecutor(4, 4, 600, TimeUnit.SECONDS, new LinkedBlockingQueue()) */
+ val handleMessageExecutor = Executors.newFixedThreadPool(4)
+ val serverChannel = ServerSocketChannel.open()
+ val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
+ val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
+ val messageStatuses = new HashMap[Int, MessageStatus]
+ val connectionRequests = new SynchronizedQueue[SendingConnection]
+ val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
+ val sendMessageRequests = new Queue[(Message, SendingConnection)]
+
+ var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message] = null
+
+ serverChannel.configureBlocking(false)
+ serverChannel.socket.setReuseAddress(true)
+ serverChannel.socket.setReceiveBufferSize(256 * 1024)
+
+ serverChannel.socket.bind(new InetSocketAddress(port))
+ serverChannel.register(selector, SelectionKey.OP_ACCEPT)
+
+ val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
+ logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
+
+ val thisInstance = this
+ var selectorThread = new Thread("connection-manager-thread") {
+ override def run() {
+ thisInstance.run()
+ }
+ }
+ selectorThread.setDaemon(true)
+ selectorThread.start()
+
+ def run() {
+ try {
+ var interrupted = false
+ while(!interrupted) {
+ while(!connectionRequests.isEmpty) {
+ val sendingConnection = connectionRequests.dequeue
+ sendingConnection.connect()
+ addConnection(sendingConnection)
+ }
+ sendMessageRequests.synchronized {
+ while(!sendMessageRequests.isEmpty) {
+ val (message, connection) = sendMessageRequests.dequeue
+ connection.send(message)
+ }
+ }
+
+ while(!keyInterestChangeRequests.isEmpty) {
+ val (key, ops) = keyInterestChangeRequests.dequeue
+ val connection = connectionsByKey(key)
+ val lastOps = key.interestOps()
+ key.interestOps(ops)
+
+ def intToOpStr(op: Int): String = {
+ val opStrs = ArrayBuffer[String]()
+ if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
+ if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
+ if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
+ if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
+ if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+ }
+
+ logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId +
+ "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+
+ }
+
+ val selectedKeysCount = selector.select()
+ if (selectedKeysCount == 0) logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
+
+ interrupted = selectorThread.isInterrupted
+
+ val selectedKeys = selector.selectedKeys().iterator()
+ while (selectedKeys.hasNext()) {
+ val key = selectedKeys.next.asInstanceOf[SelectionKey]
+ selectedKeys.remove()
+ if (key.isValid) {
+ if (key.isAcceptable) {
+ acceptConnection(key)
+ } else
+ if (key.isConnectable) {
+ connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect()
+ } else
+ if (key.isReadable) {
+ connectionsByKey(key).read()
+ } else
+ if (key.isWritable) {
+ connectionsByKey(key).write()
+ }
+ }
+ }
+ }
+ } catch {
+ case e: Exception => logError("Error in select loop", e)
+ }
+ }
+
+ def acceptConnection(key: SelectionKey) {
+ val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
+ val newChannel = serverChannel.accept()
+ val newConnection = new ReceivingConnection(newChannel, selector)
+ newConnection.onReceive(receiveMessage)
+ newConnection.onClose(removeConnection)
+ addConnection(newConnection)
+ logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
+ }
+
+ def addConnection(connection: Connection) {
+ connectionsByKey += ((connection.key, connection))
+ if (connection.isInstanceOf[SendingConnection]) {
+ val sendingConnection = connection.asInstanceOf[SendingConnection]
+ connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection))
+ }
+ connection.onKeyInterestChange(changeConnectionKeyInterest)
+ connection.onException(handleConnectionError)
+ connection.onClose(removeConnection)
+ }
+
+ def removeConnection(connection: Connection) {
+ /*logInfo("Removing connection")*/
+ connectionsByKey -= connection.key
+ if (connection.isInstanceOf[SendingConnection]) {
+ val sendingConnection = connection.asInstanceOf[SendingConnection]
+ val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId
+ logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
+
+ connectionsById -= sendingConnectionManagerId
+
+ messageStatuses.synchronized {
+ messageStatuses
+ .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
+ logInfo("Notifying " + status)
+ status.synchronized {
+ status.attempted = true
+ status.acked = false
+ status.notifyAll
+ }
+ })
+
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
+ })
+ }
+ } else if (connection.isInstanceOf[ReceivingConnection]) {
+ val receivingConnection = connection.asInstanceOf[ReceivingConnection]
+ val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId
+ logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
+
+ val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull
+ if (sendingConnectionManagerId == null) {
+ logError("Corresponding SendingConnectionManagerId not found")
+ return
+ }
+ logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId)
+
+ val sendingConnection = connectionsById(sendingConnectionManagerId)
+ sendingConnection.close()
+ connectionsById -= sendingConnectionManagerId
+
+ messageStatuses.synchronized {
+ messageStatuses
+ .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
+ logInfo("Notifying " + status)
+ status.synchronized {
+ status.attempted = true
+ status.acked = false
+ status.notifyAll
+ }
+ })
+
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
+ })
+ }
+ }
+ }
+
+ def handleConnectionError(connection: Connection, e: Exception) {
+ logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
+ removeConnection(connection)
+ }
+
+ def changeConnectionKeyInterest(connection: Connection, ops: Int) {
+ keyInterestChangeRequests += ((connection.key, ops))
+ }
+
+ def receiveMessage(connection: Connection, message: Message) {
+ val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
+ logInfo("Received [" + message + "] from [" + connectionManagerId + "]")
+ val runnable = new Runnable() {
+ val creationTime = System.currentTimeMillis
+ def run() {
+ logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ handleMessage(connectionManagerId, message)
+ logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ }
+ }
+ handleMessageExecutor.execute(runnable)
+ /*handleMessage(connection, message)*/
+ }
+
+ private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+ logInfo("Handling [" + message + "] from [" + connectionManagerId + "]")
+ message match {
+ case bufferMessage: BufferMessage => {
+ if (bufferMessage.hasAckId) {
+ val sentMessageStatus = messageStatuses.synchronized {
+ messageStatuses.get(bufferMessage.ackId) match {
+ case Some(status) => {
+ messageStatuses -= bufferMessage.ackId
+ status
+ }
+ case None => {
+ throw new Exception("Could not find reference for received ack message " + message.id)
+ null
+ }
+ }
+ }
+ sentMessageStatus.synchronized {
+ sentMessageStatus.ackMessage = Some(message)
+ sentMessageStatus.attempted = true
+ sentMessageStatus.acked = true
+ sentMessageStatus.notifyAll
+ }
+ } else {
+ val ackMessage = if (onReceiveCallback != null) {
+ logDebug("Calling back")
+ onReceiveCallback(bufferMessage, connectionManagerId)
+ } else {
+ logWarning("Not calling back as callback is null")
+ None
+ }
+
+ if (ackMessage.isDefined) {
+ if (!ackMessage.get.isInstanceOf[BufferMessage]) {
+ logWarning("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
+ } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
+ logWarning("Response to " + bufferMessage + " does not have ack id set")
+ ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
+ }
+ }
+
+ sendMessage(connectionManagerId, ackMessage.getOrElse {
+ Message.createBufferMessage(bufferMessage.id)
+ })
+ }
+ }
+ case _ => throw new Exception("Unknown type message received")
+ }
+ }
+
+ private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+ def startNewConnection(): SendingConnection = {
+ val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
+ val newConnection = new SendingConnection(inetSocketAddress, selector)
+ connectionRequests += newConnection
+ newConnection
+ }
+ val connection = connectionsById.getOrElse(connectionManagerId, startNewConnection)
+ message.senderAddress = id.toSocketAddress()
+ logInfo("Sending [" + message + "] to [" + connectionManagerId + "]")
+ /*connection.send(message)*/
+ sendMessageRequests.synchronized {
+ sendMessageRequests += ((message, connection))
+ }
+ selector.wakeup()
+ }
+
+ def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message): Future[Option[Message]] = {
+ val messageStatus = new MessageStatus(message, connectionManagerId)
+ messageStatuses.synchronized {
+ messageStatuses += ((message.id, messageStatus))
+ }
+ sendMessage(connectionManagerId, message)
+ future {
+ messageStatus.synchronized {
+ if (!messageStatus.attempted) {
+ logTrace("Waiting, " + messageStatuses.size + " statuses" )
+ messageStatus.wait()
+ logTrace("Done waiting")
+ }
+ }
+ messageStatus.ackMessage
+ }
+ }
+
+ def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = {
+ sendMessageReliably(connectionManagerId, message)()
+ }
+
+ def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
+ onReceiveCallback = callback
+ }
+
+ def stop() {
+ selectorThread.interrupt()
+ selectorThread.join()
+ selector.close()
+ val connections = connectionsByKey.values
+ connections.foreach(_.close())
+ if (connectionsByKey.size != 0) {
+ logWarning("All connections not cleaned up")
+ }
+ handleMessageExecutor.shutdown()
+ logInfo("ConnectionManager stopped")
+ }
+}
+
+
+object ConnectionManager {
+
+ def main(args: Array[String]) {
+
+ val manager = new ConnectionManager(9999)
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ println("Received [" + msg + "] from [" + id + "]")
+ None
+ })
+
+ /*testSequentialSending(manager)*/
+ /*System.gc()*/
+
+ /*testParallelSending(manager)*/
+ /*System.gc()*/
+
+ /*testParallelDecreasingSending(manager)*/
+ /*System.gc()*/
+
+ testContinuousSending(manager)
+ System.gc()
+ }
+
+ def testSequentialSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Sequential Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliablySync(manager.id, bufferMessage)
+ })
+ println("--------------------------")
+ println()
+ }
+
+ def testParallelSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Parallel Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val startTime = System.currentTimeMillis
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {if (!f().isDefined) println("Failed")})
+ val finishTime = System.currentTimeMillis
+
+ val mb = size * count / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("--------------------------")
+ println("Started at " + startTime + ", finished at " + finishTime)
+ println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)")
+ println("--------------------------")
+ println()
+ }
+
+ def testParallelDecreasingSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Parallel Decreasing Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+ val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte)))
+ buffers.foreach(_.flip)
+ val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0
+
+ val startTime = System.currentTimeMillis
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {if (!f().isDefined) println("Failed")})
+ val finishTime = System.currentTimeMillis
+
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("--------------------------")
+ /*println("Started at " + startTime + ", finished at " + finishTime) */
+ println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
+ println("--------------------------")
+ println()
+ }
+
+ def testContinuousSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Continuous Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val startTime = System.currentTimeMillis
+ while(true) {
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {if (!f().isDefined) println("Failed")})
+ val finishTime = System.currentTimeMillis
+ Thread.sleep(1000)
+ val mb = size * count / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("--------------------------")
+ println()
+ }
+ }
+}
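
For reference, a minimal request/response sketch showing how onReceiveMessage and
sendMessageReliablySync fit together in a single JVM (illustrative only, not part of the
patch; the object name and the port are arbitrary). The ReceiverTest and SenderTest files
below exercise the same pattern across two processes.

    package spark.network

    import java.nio.ByteBuffer

    object ConnectionManagerExample {
      def main(args: Array[String]) {
        val server = new ConnectionManager(9999)   // arbitrary port
        val client = new ConnectionManager(0)      // 0 lets the OS pick a free port

        // The server answers every message with a small reply carrying the request id,
        // which is what makes sendMessageReliablySync on the client return Some(...)
        server.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
          Some(Message.createBufferMessage(ByteBuffer.wrap("pong".getBytes), msg.id))
        })

        val request = Message.createBufferMessage(ByteBuffer.wrap("ping".getBytes))
        client.sendMessageReliablySync(server.id, request) match {
          case Some(reply) =>
            println("reply: " + new String(reply.asInstanceOf[BufferMessage].buffers(0).array))
          case None =>
            println("send failed")
        }

        client.stop()
        server.stop()
      }
    }
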
diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
new file mode 100644
index 0000000000..5d21bb793f
--- /dev/null
+++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
@@ -0,0 +1,74 @@
+package spark.network
+
+import spark._
+import spark.SparkContext._
+
+import scala.io.Source
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+
+object ConnectionManagerTest extends Logging {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>")
+ System.exit(1)
+ }
+
+ if (args(0).startsWith("local")) {
+ println("This runs only on a mesos cluster")
+ }
+
+ val sc = new SparkContext(args(0), "ConnectionManagerTest")
+ val slavesFile = Source.fromFile(args(1))
+ val slaves = slavesFile.mkString.split("\n")
+ slavesFile.close()
+
+ /*println("Slaves")*/
+ /*slaves.foreach(println)*/
+
+ val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map(
+ i => SparkEnv.get.connectionManager.id).collect()
+ println("\nSlave ConnectionManagerIds")
+ slaveConnManagerIds.foreach(println)
+ println
+
+ val count = 10
+ (0 until count).foreach(i => {
+ val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => {
+ val connManager = SparkEnv.get.connectionManager
+ val thisConnManagerId = connManager.id
+ connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ logInfo("Received [" + msg + "] from [" + id + "]")
+ None
+ })
+
+ val size = 100 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val startTime = System.currentTimeMillis
+ val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
+ connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
+ })
+ val results = futures.map(f => f())
+ val finishTime = System.currentTimeMillis
+ Thread.sleep(5000)
+
+ val mb = size * results.size / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
+ logInfo(resultStr)
+ resultStr
+ }).collect()
+
+ println("---------------------")
+ println("Run " + i)
+ resultStrs.foreach(println)
+ println("---------------------")
+ })
+ }
+}
+
diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala
new file mode 100644
index 0000000000..2e85803679
--- /dev/null
+++ b/core/src/main/scala/spark/network/Message.scala
@@ -0,0 +1,219 @@
+package spark.network
+
+import spark._
+
+import scala.collection.mutable.ArrayBuffer
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+import java.net.InetSocketAddress
+
+class MessageChunkHeader(
+ val typ: Long,
+ val id: Int,
+ val totalSize: Int,
+ val chunkSize: Int,
+ val other: Int,
+ val address: InetSocketAddress) {
+ lazy val buffer = {
+ val ip = address.getAddress.getAddress()
+ val port = address.getPort()
+ ByteBuffer.
+ allocate(MessageChunkHeader.HEADER_SIZE).
+ putLong(typ).
+ putInt(id).
+ putInt(totalSize).
+ putInt(chunkSize).
+ putInt(other).
+ putInt(ip.size).
+ put(ip).
+ putInt(port).
+ position(MessageChunkHeader.HEADER_SIZE).
+ flip.asInstanceOf[ByteBuffer]
+ }
+
+ override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
+ " and sizes " + totalSize + " / " + chunkSize + " bytes"
+}
+
+class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
+ val size = if (buffer == null) 0 else buffer.remaining
+ lazy val buffers = {
+ val ab = new ArrayBuffer[ByteBuffer]()
+ ab += header.buffer
+ if (buffer != null) {
+ ab += buffer
+ }
+ ab
+ }
+
+ override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
+}
+
+abstract class Message(val typ: Long, val id: Int) {
+ var senderAddress: InetSocketAddress = null
+ var started = false
+ var startTime = -1L
+ var finishTime = -1L
+
+ def size: Int
+
+ def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
+
+ def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
+
+ def timeTaken(): String = (finishTime - startTime).toString + " ms"
+
+ override def toString = "" + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
+}
+
+class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
+extends Message(Message.BUFFER_MESSAGE, id_) {
+
+ val initialSize = currentSize()
+ var gotChunkForSendingOnce = false
+
+ def size = initialSize
+
+ def currentSize() = {
+ if (buffers == null || buffers.isEmpty) {
+ 0
+ } else {
+ buffers.map(_.remaining).reduceLeft(_ + _)
+ }
+ }
+
+ def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
+ if (maxChunkSize <= 0) {
+ throw new Exception("Max chunk size is " + maxChunkSize)
+ }
+
+ if (size == 0 && gotChunkForSendingOnce == false) {
+ val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+
+ while(!buffers.isEmpty) {
+ val buffer = buffers(0)
+ if (buffer.remaining == 0) {
+ buffers -= buffer
+ } else {
+ val newBuffer = if (buffer.remaining <= maxChunkSize) {
+ buffer.duplicate
+ } else {
+ buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
+ }
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+ }
+ None
+ }
+
+ def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
+ // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
+ if (buffers.size > 1) {
+ throw new Exception("Attempting to get chunk from message with multiple data buffers")
+ }
+ val buffer = buffers(0)
+ if (buffer.remaining > 0) {
+ if (buffer.remaining < chunkSize) {
+ throw new Exception("Not enough space in data buffer for receiving chunk")
+ }
+ val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ return Some(newChunk)
+ }
+ None
+ }
+
+ def flip() {
+ buffers.foreach(_.flip)
+ }
+
+ def hasAckId() = (ackId != 0)
+
+ def isCompletelyReceived() = !buffers(0).hasRemaining
+
+ override def toString = {
+ if (hasAckId) {
+ "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
+ } else {
+ "BufferMessage(id = " + id + ", size = " + size + ")"
+ }
+
+ }
+}
+
+object MessageChunkHeader {
+ val HEADER_SIZE = 40
+
+ def create(buffer: ByteBuffer): MessageChunkHeader = {
+ if (buffer.remaining != HEADER_SIZE) {
+ throw new IllegalArgumentException("Cannot convert buffer data to Message")
+ }
+ val typ = buffer.getLong()
+ val id = buffer.getInt()
+ val totalSize = buffer.getInt()
+ val chunkSize = buffer.getInt()
+ val other = buffer.getInt()
+ val ipSize = buffer.getInt()
+ val ipBytes = new Array[Byte](ipSize)
+ buffer.get(ipBytes)
+ val ip = InetAddress.getByAddress(ipBytes)
+ val port = buffer.getInt()
+ new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
+ }
+}
+
+object Message {
+ val BUFFER_MESSAGE = 1111111111L
+
+ var lastId = 1
+
+ def getNewId() = synchronized {
+ lastId += 1
+ if (lastId == 0) lastId += 1
+ lastId
+ }
+
+ def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
+ if (dataBuffers == null) {
+ return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
+ }
+ if (dataBuffers.exists(_ == null)) {
+ throw new Exception("Attempting to create buffer message with null buffer")
+ }
+ return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
+ }
+
+ def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
+ createBufferMessage(dataBuffers, 0)
+
+ def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
+ if (dataBuffer == null) {
+ return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
+ } else {
+ return createBufferMessage(Array(dataBuffer), ackId)
+ }
+ }
+
+ def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
+ createBufferMessage(dataBuffer, 0)
+
+ def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId)
+
+ def create(header: MessageChunkHeader): Message = {
+ val newMessage: Message = header.typ match {
+ case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
+ }
+ newMessage.senderAddress = header.address
+ newMessage
+ }
+}
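
To make the chunking logic above concrete, here is a small sketch (illustrative only; the
object name is made up) that slices a 10-byte BufferMessage into 4-byte chunks with
getChunkForSending. The receiving side reassembles chunks through getChunkForReceiving,
driven by the headers.

    package spark.network

    import java.net.InetSocketAddress
    import java.nio.ByteBuffer

    object MessageChunkingExample {
      def main(args: Array[String]) {
        val msg = Message.createBufferMessage(ByteBuffer.wrap(new Array[Byte](10)))
        // senderAddress is normally filled in by ConnectionManager.sendMessage
        msg.senderAddress = new InetSocketAddress("localhost", 0)

        // Pull 4-byte chunks until the payload is exhausted: sizes 4, 4, 2, then None
        var chunk = msg.getChunkForSending(4)
        while (chunk.isDefined) {
          println("chunk of " + chunk.get.buffer.remaining + " bytes")
          chunk = msg.getChunkForSending(4)
        }
      }
    }
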
diff --git a/core/src/main/scala/spark/network/ReceiverTest.scala b/core/src/main/scala/spark/network/ReceiverTest.scala
new file mode 100644
index 0000000000..e1ba7c06c0
--- /dev/null
+++ b/core/src/main/scala/spark/network/ReceiverTest.scala
@@ -0,0 +1,20 @@
+package spark.network
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+
+object ReceiverTest {
+
+ def main(args: Array[String]) {
+ val manager = new ConnectionManager(9999)
+ println("Started connection manager with id = " + manager.id)
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/
+ val buffer = ByteBuffer.wrap("response".getBytes())
+ Some(Message.createBufferMessage(buffer, msg.id))
+ })
+ Thread.currentThread.join()
+ }
+}
+
diff --git a/core/src/main/scala/spark/network/SenderTest.scala b/core/src/main/scala/spark/network/SenderTest.scala
new file mode 100644
index 0000000000..4ab6dd3414
--- /dev/null
+++ b/core/src/main/scala/spark/network/SenderTest.scala
@@ -0,0 +1,53 @@
+package spark.network
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+
+object SenderTest {
+
+ def main(args: Array[String]) {
+
+ if (args.length < 2) {
+ println("Usage: SenderTest <target host> <target port>")
+ System.exit(1)
+ }
+
+ val targetHost = args(0)
+ val targetPort = args(1).toInt
+ val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort)
+
+ val manager = new ConnectionManager(0)
+ println("Started connection manager with id = " + manager.id)
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ println("Received [" + msg + "] from [" + id + "]")
+ None
+ })
+
+ val size = 100 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val targetServer = args(0)
+
+ val count = 100
+ (0 until count).foreach(i => {
+ val dataMessage = Message.createBufferMessage(buffer.duplicate)
+ val startTime = System.currentTimeMillis
+ /*println("Started timer at " + startTime)*/
+ val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match {
+ case Some(response) =>
+ val buffer = response.asInstanceOf[BufferMessage].buffers(0)
+ new String(buffer.array)
+ case None => "none"
+ }
+ val finishTime = System.currentTimeMillis
+ val mb = size / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ /*val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"*/
+ val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" + (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr
+ println(resultStr)
+ })
+ }
+}
+
diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
new file mode 100644
index 0000000000..260547902b
--- /dev/null
+++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
@@ -0,0 +1,66 @@
+package spark.partial
+
+import spark._
+import spark.scheduler.JobListener
+
+/**
+ * A JobListener for an approximate single-result action, such as count() or non-parallel reduce().
+ * This listener waits up to timeout milliseconds and will return a partial answer even if the
+ * complete answer is not available by then.
+ *
+ * This class assumes that the action is performed on an entire RDD[T] via a function that computes
+ * a result of type U for each partition, and that the action returns a partial or complete result
+ * of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedDouble).
+ */
+class ApproximateActionListener[T, U, R](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ evaluator: ApproximateEvaluator[U, R],
+ timeout: Long)
+ extends JobListener {
+
+ val startTime = System.currentTimeMillis()
+ val totalTasks = rdd.splits.size
+ var finishedTasks = 0
+ var failure: Option[Exception] = None // Set if the job has failed (permanently)
+ var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult
+
+ override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
+ evaluator.merge(index, result.asInstanceOf[U])
+ finishedTasks += 1
+ if (finishedTasks == totalTasks) {
+ // If we had already returned a PartialResult, set its final value
+ resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
+ // Notify any waiting thread that may have called getResult
+ this.notifyAll()
+ }
+ }
+
+ override def jobFailed(exception: Exception): Unit = synchronized {
+ failure = Some(exception)
+ this.notifyAll()
+ }
+
+ /**
+ * Waits for up to timeout milliseconds since the listener was created and then returns a
+ * PartialResult with the result so far. This may be complete if the whole job is done.
+ */
+ def getResult(): PartialResult[R] = synchronized {
+ val finishTime = startTime + timeout
+ while (true) {
+ val time = System.currentTimeMillis()
+ if (failure != None) {
+ throw failure.get
+ } else if (finishedTasks == totalTasks) {
+ return new PartialResult(evaluator.currentResult(), true)
+ } else if (time >= finishTime) {
+ resultObject = Some(new PartialResult(evaluator.currentResult(), false))
+ return resultObject.get
+ } else {
+ this.wait(finishTime - time)
+ }
+ }
+ // Should never be reached, but required to keep the compiler happy
+ return null
+ }
+}
diff --git a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala
new file mode 100644
index 0000000000..4772e43ef0
--- /dev/null
+++ b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala
@@ -0,0 +1,10 @@
+package spark.partial
+
+/**
+ * An object that computes a function incrementally by merging in results of type U from multiple
+ * tasks. Allows partial evaluation at any point by calling currentResult().
+ */
+trait ApproximateEvaluator[U, R] {
+ def merge(outputId: Int, taskResult: U): Unit
+ def currentResult(): R
+}
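
As a toy illustration of the contract (a made-up class, not part of the patch): an evaluator
that merges per-task counts and extrapolates linearly by the fraction of outputs seen so far.
The evaluators that follow refine exactly this idea with proper confidence bounds.

    package spark.partial

    class NaiveCountEvaluator(totalOutputs: Int) extends ApproximateEvaluator[Long, Double] {
      var outputsMerged = 0
      var sum: Long = 0

      def merge(outputId: Int, taskResult: Long) {
        outputsMerged += 1
        sum += taskResult
      }

      // Scale the partial count by the inverse of the fraction of outputs merged
      def currentResult(): Double =
        if (outputsMerged == 0) 0.0 else sum.toDouble * totalOutputs / outputsMerged
    }
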
diff --git a/core/src/main/scala/spark/partial/BoundedDouble.scala b/core/src/main/scala/spark/partial/BoundedDouble.scala
new file mode 100644
index 0000000000..463c33d6e2
--- /dev/null
+++ b/core/src/main/scala/spark/partial/BoundedDouble.scala
@@ -0,0 +1,8 @@
+package spark.partial
+
+/**
+ * A Double with error bars on it.
+ */
+class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
+ override def toString(): String = "[%.3f, %.3f]".format(low, high)
+}
diff --git a/core/src/main/scala/spark/partial/CountEvaluator.scala b/core/src/main/scala/spark/partial/CountEvaluator.scala
new file mode 100644
index 0000000000..1bc90d6b39
--- /dev/null
+++ b/core/src/main/scala/spark/partial/CountEvaluator.scala
@@ -0,0 +1,38 @@
+package spark.partial
+
+import cern.jet.stat.Probability
+
+/**
+ * An ApproximateEvaluator for counts.
+ *
+ * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might
+ * be best to make this a special case of GroupedCountEvaluator with one group.
+ */
+class CountEvaluator(totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[Long, BoundedDouble] {
+
+ var outputsMerged = 0
+ var sum: Long = 0
+
+ override def merge(outputId: Int, taskResult: Long) {
+ outputsMerged += 1
+ sum += taskResult
+ }
+
+ override def currentResult(): BoundedDouble = {
+ if (outputsMerged == totalOutputs) {
+ new BoundedDouble(sum, 1.0, sum, sum)
+ } else if (outputsMerged == 0) {
+ new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val mean = (sum + 1 - p) / p
+ val variance = (sum + 1) * (1 - p) / (p * p)
+ val stdev = math.sqrt(variance)
+ val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val low = mean - confFactor * stdev
+ val high = mean + confFactor * stdev
+ new BoundedDouble(mean, confidence, low, high)
+ }
+ }
+}
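
Plugging example numbers into the formulas above (a REPL-style sketch; the figures are made
up and the 95% normal quantile is hard-coded instead of calling Colt's Probability):

    val sum = 500L                                // items counted by the finished tasks
    val p = 10.0 / 100                            // 10 of 100 outputs merged
    val mean = (sum + 1 - p) / p                  // ~5009.0, the estimated total count
    val variance = (sum + 1) * (1 - p) / (p * p)  // ~45090.0
    val stdev = math.sqrt(variance)               // ~212.3
    val confFactor = 1.96                         // Probability.normalInverse(0.975), approx.
    println("[" + (mean - confFactor * stdev) + ", " + (mean + confFactor * stdev) + "]")
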
diff --git a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala
new file mode 100644
index 0000000000..3e631c0efc
--- /dev/null
+++ b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala
@@ -0,0 +1,62 @@
+package spark.partial
+
+import java.util.{HashMap => JHashMap}
+import java.util.{Map => JMap}
+
+import scala.collection.Map
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import cern.jet.stat.Probability
+
+import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+
+/**
+ * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval.
+ */
+class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] {
+
+ var outputsMerged = 0
+ var sums = new OLMap[T] // Sum of counts for each key
+
+ override def merge(outputId: Int, taskResult: OLMap[T]) {
+ outputsMerged += 1
+ val iter = taskResult.object2LongEntrySet.fastIterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ sums.put(entry.getKey, sums.getLong(entry.getKey) + entry.getLongValue)
+ }
+ }
+
+ override def currentResult(): Map[T, BoundedDouble] = {
+ if (outputsMerged == totalOutputs) {
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.object2LongEntrySet.fastIterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val sum = entry.getLongValue()
+ result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum)
+ }
+ result
+ } else if (outputsMerged == 0) {
+ new HashMap[T, BoundedDouble]
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.object2LongEntrySet.fastIterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val sum = entry.getLongValue
+ val mean = (sum + 1 - p) / p
+ val variance = (sum + 1) * (1 - p) / (p * p)
+ val stdev = math.sqrt(variance)
+ val low = mean - confFactor * stdev
+ val high = mean + confFactor * stdev
+ result(entry.getKey) = new BoundedDouble(mean, confidence, low, high)
+ }
+ result
+ }
+ }
+}
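
The per-key version applies the same extrapolation independently to each key. Sketched with a
plain Scala Map for brevity (the real code uses fastutil's Object2LongOpenHashMap for speed;
numbers are made up):

    val p = 0.25                                   // 1 of 4 outputs merged
    val partialCounts = Map("a" -> 30L, "b" -> 5L)
    val estimates = partialCounts.map { case (k, sum) =>
      val mean = (sum + 1 - p) / p
      val stdev = math.sqrt((sum + 1) * (1 - p) / (p * p))
      (k, (mean - 1.96 * stdev, mean + 1.96 * stdev))  // 95% interval per key
    }
    println(estimates)
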
diff --git a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala
new file mode 100644
index 0000000000..2a9ccba205
--- /dev/null
+++ b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala
@@ -0,0 +1,65 @@
+package spark.partial
+
+import java.util.{HashMap => JHashMap}
+import java.util.{Map => JMap}
+
+import scala.collection.mutable.HashMap
+import scala.collection.Map
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval.
+ */
+class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
+
+ var outputsMerged = 0
+ var sums = new JHashMap[T, StatCounter] // StatCounter for each key, merged across tasks
+
+ override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) {
+ outputsMerged += 1
+ val iter = taskResult.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val old = sums.get(entry.getKey)
+ if (old != null) {
+ old.merge(entry.getValue)
+ } else {
+ sums.put(entry.getKey, entry.getValue)
+ }
+ }
+ }
+
+ override def currentResult(): Map[T, BoundedDouble] = {
+ if (outputsMerged == totalOutputs) {
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val mean = entry.getValue.mean
+ result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean)
+ }
+ result
+ } else if (outputsMerged == 0) {
+ new HashMap[T, BoundedDouble]
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val studentTCacher = new StudentTCacher(confidence)
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val counter = entry.getValue
+ val mean = counter.mean
+ val stdev = math.sqrt(counter.sampleVariance / counter.count)
+ val confFactor = studentTCacher.get(counter.count)
+ val low = mean - confFactor * stdev
+ val high = mean + confFactor * stdev
+ result(entry.getKey) = new BoundedDouble(mean, confidence, low, high)
+ }
+ result
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala
new file mode 100644
index 0000000000..6a2ec7a7bd
--- /dev/null
+++ b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala
@@ -0,0 +1,72 @@
+package spark.partial
+
+import java.util.{HashMap => JHashMap}
+import java.util.{Map => JMap}
+
+import scala.collection.mutable.HashMap
+import scala.collection.Map
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval.
+ */
+class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
+
+ var outputsMerged = 0
+ var sums = new JHashMap[T, StatCounter] // StatCounter for each key, merged across tasks
+
+ override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) {
+ outputsMerged += 1
+ val iter = taskResult.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val old = sums.get(entry.getKey)
+ if (old != null) {
+ old.merge(entry.getValue)
+ } else {
+ sums.put(entry.getKey, entry.getValue)
+ }
+ }
+ }
+
+ override def currentResult(): Map[T, BoundedDouble] = {
+ if (outputsMerged == totalOutputs) {
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val sum = entry.getValue.sum
+ result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum)
+ }
+ result
+ } else if (outputsMerged == 0) {
+ new HashMap[T, BoundedDouble]
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val studentTCacher = new StudentTCacher(confidence)
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val counter = entry.getValue
+ val meanEstimate = counter.mean
+ val meanVar = counter.sampleVariance / counter.count
+ val countEstimate = (counter.count + 1 - p) / p
+ val countVar = (counter.count + 1) * (1 - p) / (p * p)
+ val sumEstimate = meanEstimate * countEstimate
+ val sumVar = (meanEstimate * meanEstimate * countVar) +
+ (countEstimate * countEstimate * meanVar) +
+ (meanVar * countVar)
+ val sumStdev = math.sqrt(sumVar)
+ val confFactor = studentTCacher.get(counter.count)
+ val low = sumEstimate - confFactor * sumStdev
+ val high = sumEstimate + confFactor * sumStdev
+ result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high)
+ }
+ result
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/partial/MeanEvaluator.scala b/core/src/main/scala/spark/partial/MeanEvaluator.scala
new file mode 100644
index 0000000000..b8c7cb8863
--- /dev/null
+++ b/core/src/main/scala/spark/partial/MeanEvaluator.scala
@@ -0,0 +1,41 @@
+package spark.partial
+
+import cern.jet.stat.Probability
+
+import spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for means.
+ */
+class MeanEvaluator(totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[StatCounter, BoundedDouble] {
+
+ var outputsMerged = 0
+ var counter = new StatCounter
+
+ override def merge(outputId: Int, taskResult: StatCounter) {
+ outputsMerged += 1
+ counter.merge(taskResult)
+ }
+
+ override def currentResult(): BoundedDouble = {
+ if (outputsMerged == totalOutputs) {
+ new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean)
+ } else if (outputsMerged == 0) {
+ new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else {
+ val mean = counter.mean
+ val stdev = math.sqrt(counter.sampleVariance / counter.count)
+ val confFactor = {
+ if (counter.count > 100) {
+ Probability.normalInverse(1 - (1 - confidence) / 2)
+ } else {
+ Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
+ }
+ }
+ val low = mean - confFactor * stdev
+ val high = mean + confFactor * stdev
+ new BoundedDouble(mean, confidence, low, high)
+ }
+ }
+}
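
With example numbers (made up), the interval for a partially merged mean looks like this;
when 100 or fewer data values have been merged, the Student t quantile is used instead of
the normal one:

    val count = 10                                  // data values merged so far
    val mean = 4.0
    val sampleVariance = 2.5
    val stdev = math.sqrt(sampleVariance / count)   // standard error of the mean, 0.5 here
    val confFactor = 2.26                           // Probability.studentTInverse(0.05, 9), approx.
    println("[" + (mean - confFactor * stdev) + ", " + (mean + confFactor * stdev) + "]")
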
diff --git a/core/src/main/scala/spark/partial/PartialResult.scala b/core/src/main/scala/spark/partial/PartialResult.scala
new file mode 100644
index 0000000000..7095bc8ca1
--- /dev/null
+++ b/core/src/main/scala/spark/partial/PartialResult.scala
@@ -0,0 +1,86 @@
+package spark.partial
+
+class PartialResult[R](initialVal: R, isFinal: Boolean) {
+ private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None
+ private var failure: Option[Exception] = None
+ private var completionHandler: Option[R => Unit] = None
+ private var failureHandler: Option[Exception => Unit] = None
+
+ def initialValue: R = initialVal
+
+ def isInitialValueFinal: Boolean = isFinal
+
+ /**
+ * Blocking method to wait for and return the final value.
+ */
+ def getFinalValue(): R = synchronized {
+ while (finalValue == None && failure == None) {
+ this.wait()
+ }
+ if (finalValue != None) {
+ return finalValue.get
+ } else {
+ throw failure.get
+ }
+ }
+
+ /**
+ * Set a handler to be called when this PartialResult completes. Only one completion handler
+ * is supported per PartialResult.
+ */
+ def onComplete(handler: R => Unit): PartialResult[R] = synchronized {
+ if (completionHandler != None) {
+ throw new UnsupportedOperationException("onComplete cannot be called twice")
+ }
+ completionHandler = Some(handler)
+ if (finalValue != None) {
+ // We already have a final value, so let's call the handler
+ handler(finalValue.get)
+ }
+ return this
+ }
+
+ /**
+ * Set a handler to be called if this PartialResult's job fails. Only one failure handler
+ * is supported per PartialResult.
+ */
+ def onFail(handler: Exception => Unit): Unit = synchronized {
+ if (failureHandler != None) {
+ throw new UnsupportedOperationException("onFail cannot be called twice")
+ }
+ failureHandler = Some(handler)
+ if (failure != None) {
+ // We already have a failure, so let's call the handler
+ handler(failure.get)
+ }
+ }
+
+ private[spark] def setFinalValue(value: R): Unit = synchronized {
+ if (finalValue != None) {
+ throw new UnsupportedOperationException("setFinalValue called twice on a PartialResult")
+ }
+ finalValue = Some(value)
+ // Call the completion handler if it was set
+ completionHandler.foreach(h => h(value))
+ // Notify any threads that may be calling getFinalValue()
+ this.notifyAll()
+ }
+
+ private[spark] def setFailure(exception: Exception): Unit = synchronized {
+ if (failure != None) {
+ throw new UnsupportedOperationException("setFailure called twice on a PartialResult")
+ }
+ failure = Some(exception)
+ // Call the failure handler if it was set
+ failureHandler.foreach(h => h(exception))
+ // Notify any threads that may be calling getFinalValue()
+ this.notifyAll()
+ }
+
+ override def toString: String = synchronized {
+ finalValue match {
+ case Some(value) => "(final: " + value + ")"
+ case None => "(partial: " + initialValue + ")"
+ }
+ }
+}
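
A small sketch of the PartialResult lifecycle as the scheduler drives it (the object name is
made up; setFinalValue is normally called by ApproximateActionListener when the job finishes):

    package spark.partial

    object PartialResultExample {
      def main(args: Array[String]) {
        val partial = new PartialResult[Double](42.0, false)  // initial, non-final estimate
        partial.onComplete(v => println("final value: " + v))

        // Simulate the job finishing on another thread; getFinalValue() blocks until then
        new Thread() {
          override def run() {
            Thread.sleep(100)
            partial.setFinalValue(43.0)
          }
        }.start()

        println("initial estimate: " + partial.initialValue)
        println("waited for final: " + partial.getFinalValue())
      }
    }
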
diff --git a/core/src/main/scala/spark/partial/StudentTCacher.scala b/core/src/main/scala/spark/partial/StudentTCacher.scala
new file mode 100644
index 0000000000..6263ee3518
--- /dev/null
+++ b/core/src/main/scala/spark/partial/StudentTCacher.scala
@@ -0,0 +1,26 @@
+package spark.partial
+
+import cern.jet.stat.Probability
+
+/**
+ * A utility class for caching Student's T distribution values for a given confidence level
+ * and various sample sizes. This is used by the grouped mean and sum evaluators to
+ * efficiently calculate confidence intervals for many keys.
+ */
+class StudentTCacher(confidence: Double) {
+ val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation
+ val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)
+
+ def get(sampleSize: Long): Double = {
+ if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) {
+ normalApprox
+ } else {
+ val size = sampleSize.toInt
+ if (cache(size) < 0) {
+ cache(size) = Probability.studentTInverse(1 - confidence, size - 1)
+ }
+ cache(size)
+ }
+ }
+}
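
A quick sketch of the caching behaviour (the returned values depend on Colt's quantile
functions, so only the shape of the calls is shown):

    val cacher = new StudentTCacher(0.95)
    cacher.get(5)     // Student t quantile for 4 degrees of freedom, computed and cached
    cacher.get(5)     // second call for the same sample size is served from the cache
    cacher.get(5000)  // at or above 100 samples, the precomputed normal quantile is returned
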
diff --git a/core/src/main/scala/spark/partial/SumEvaluator.scala b/core/src/main/scala/spark/partial/SumEvaluator.scala
new file mode 100644
index 0000000000..0357a6bff8
--- /dev/null
+++ b/core/src/main/scala/spark/partial/SumEvaluator.scala
@@ -0,0 +1,51 @@
+package spark.partial
+
+import cern.jet.stat.Probability
+
+import spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for sums. It estimates the mean and the count and multiplies them
+ * together, then uses the formula for the variance of a product of two independent random
+ * variables to get a variance for the result and compute a confidence interval.
+ */
+class SumEvaluator(totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[StatCounter, BoundedDouble] {
+
+ var outputsMerged = 0
+ var counter = new StatCounter
+
+ override def merge(outputId: Int, taskResult: StatCounter) {
+ outputsMerged += 1
+ counter.merge(taskResult)
+ }
+
+ override def currentResult(): BoundedDouble = {
+ if (outputsMerged == totalOutputs) {
+ new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum)
+ } else if (outputsMerged == 0) {
+ new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val meanEstimate = counter.mean
+ val meanVar = counter.sampleVariance / counter.count
+ val countEstimate = (counter.count + 1 - p) / p
+ val countVar = (counter.count + 1) * (1 - p) / (p * p)
+ val sumEstimate = meanEstimate * countEstimate
+ val sumVar = (meanEstimate * meanEstimate * countVar) +
+ (countEstimate * countEstimate * meanVar) +
+ (meanVar * countVar)
+ val sumStdev = math.sqrt(sumVar)
+ val confFactor = {
+ if (counter.count > 100) {
+ Probability.normalInverse(1 - (1 - confidence) / 2)
+ } else {
+ Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
+ }
+ }
+ val low = sumEstimate - confFactor * sumStdev
+ val high = sumEstimate + confFactor * sumStdev
+ new BoundedDouble(sumEstimate, confidence, low, high)
+ }
+ }
+}
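
A worked instance of the product-variance formula above (a REPL-style sketch with made-up
figures; the 95% normal quantile is hard-coded):

    val p = 0.25                                    // 20 of 80 outputs merged
    val count = 1000L                               // data values seen so far
    val meanEstimate = 10.0
    val meanVar = 4.0 / count                       // sampleVariance / count = 0.004
    val countEstimate = (count + 1 - p) / p         // 4003.0 estimated total count
    val countVar = (count + 1) * (1 - p) / (p * p)  // 12012.0
    val sumEstimate = meanEstimate * countEstimate  // 40030.0
    val sumVar = meanEstimate * meanEstimate * countVar +
      countEstimate * countEstimate * meanVar +
      meanVar * countVar
    println(sumEstimate + " +/- " + 1.96 * math.sqrt(sumVar))
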
diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala
new file mode 100644
index 0000000000..0ecff9ce77
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/ActiveJob.scala
@@ -0,0 +1,18 @@
+package spark.scheduler
+
+import spark.TaskContext
+
+/**
+ * Tracks information about an active job in the DAGScheduler.
+ */
+class ActiveJob(
+ val runId: Int,
+ val finalStage: Stage,
+ val func: (TaskContext, Iterator[_]) => _,
+ val partitions: Array[Int],
+ val listener: JobListener) {
+
+ val numPartitions = partitions.length
+ val finished = Array.fill[Boolean](numPartitions)(false)
+ var numFinished = 0
+}
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
new file mode 100644
index 0000000000..f31e2c65a0
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -0,0 +1,532 @@
+package spark.scheduler
+
+import java.net.URI
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.Future
+import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
+
+import spark._
+import spark.partial.ApproximateActionListener
+import spark.partial.ApproximateEvaluator
+import spark.partial.PartialResult
+import spark.storage.BlockManagerMaster
+import spark.storage.BlockManagerId
+
+/**
+ * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
+ * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal
+ * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
+ * and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
+ */
+class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
+ taskSched.setListener(this)
+
+ // Called by TaskScheduler to report task completions or failures.
+ override def taskEnded(
+ task: Task[_],
+ reason: TaskEndReason,
+ result: Any,
+ accumUpdates: Map[Long, Any]) {
+ eventQueue.put(CompletionEvent(task, reason, result, accumUpdates))
+ }
+
+ // Called by TaskScheduler when a host fails.
+ override def hostLost(host: String) {
+ eventQueue.put(HostLost(host))
+ }
+
+ // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
+ // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
+ // as more failure events come in
+ val RESUBMIT_TIMEOUT = 50L
+
+ // The time, in millis, to wake up between polls of the completion queue in order to potentially
+ // resubmit failed stages
+ val POLL_TIMEOUT = 10L
+
+ private val lock = new Object // Used for access to the entire DAGScheduler
+
+ private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
+
+ val nextRunId = new AtomicInteger(0)
+
+ val nextStageId = new AtomicInteger(0)
+
+ val idToStage = new HashMap[Int, Stage]
+
+ val shuffleToMapStage = new HashMap[Int, Stage]
+
+ var cacheLocs = new HashMap[Int, Array[List[String]]]
+
+ val env = SparkEnv.get
+ val cacheTracker = env.cacheTracker
+ val mapOutputTracker = env.mapOutputTracker
+
+ val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back;
+ // that's not going to be a realistic assumption in general
+
+ val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
+ val running = new HashSet[Stage] // Stages we are running right now
+ val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
+ val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
+ var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits
+
+ val activeJobs = new HashSet[ActiveJob]
+ val resultStageToJob = new HashMap[Stage, ActiveJob]
+
+ // Start a thread to run the DAGScheduler event loop
+ new Thread("DAGScheduler") {
+ setDaemon(true)
+ override def run() {
+ DAGScheduler.this.run()
+ }
+ }.start()
+
+ def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
+ cacheLocs(rdd.id)
+ }
+
+ def updateCacheLocs() {
+ cacheLocs = cacheTracker.getLocationsSnapshot()
+ }
+
+ /**
+ * Get or create a shuffle map stage for the given shuffle dependency's map side.
+ * The priority value passed in will be used if the stage doesn't already exist with
+ * a lower priority (we assume that priorities always increase across jobs for now).
+ */
+ def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_,_], priority: Int): Stage = {
+ shuffleToMapStage.get(shuffleDep.shuffleId) match {
+ case Some(stage) => stage
+ case None =>
+ val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority)
+ shuffleToMapStage(shuffleDep.shuffleId) = stage
+ stage
+ }
+ }
+
+ /**
+ * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or
+ * as a result stage for the final RDD used directly in an action. The stage will also be given
+ * the provided priority.
+ */
+ def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]], priority: Int): Stage = {
+ // Kind of ugly: need to register RDDs with the cache and map output tracker here
+ // since we can't do it in the RDD constructor because # of splits is unknown
+ logInfo("Registering RDD " + rdd.id + ": " + rdd)
+ cacheTracker.registerRDD(rdd.id, rdd.splits.size)
+ if (shuffleDep != None) {
+ mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
+ }
+ val id = nextStageId.getAndIncrement()
+ val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
+ idToStage(id) = stage
+ stage
+ }
+
+ /**
+ * Get or create the list of parent stages for a given RDD. The stages will be assigned the
+ * provided priority if they haven't already been created with a lower priority.
+ */
+ def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
+ val parents = new HashSet[Stage]
+ val visited = new HashSet[RDD[_]]
+ def visit(r: RDD[_]) {
+ if (!visited(r)) {
+ visited += r
+ // Kind of ugly: need to register RDDs with the cache here since
+ // we can't do it in its constructor because # of splits is unknown
+ logInfo("Registering parent RDD " + r.id + ": " + r)
+ cacheTracker.registerRDD(r.id, r.splits.size)
+ for (dep <- r.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_,_,_] =>
+ parents += getShuffleMapStage(shufDep, priority)
+ case _ =>
+ visit(dep.rdd)
+ }
+ }
+ }
+ }
+ visit(rdd)
+ parents.toList
+ }
+
+ def getMissingParentStages(stage: Stage): List[Stage] = {
+ val missing = new HashSet[Stage]
+ val visited = new HashSet[RDD[_]]
+ def visit(rdd: RDD[_]) {
+ if (!visited(rdd)) {
+ visited += rdd
+ val locs = getCacheLocs(rdd)
+ for (p <- 0 until rdd.splits.size) {
+ if (locs(p) == Nil) {
+ for (dep <- rdd.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_,_,_] =>
+ val mapStage = getShuffleMapStage(shufDep, stage.priority)
+ if (!mapStage.isAvailable) {
+ missing += mapStage
+ }
+ case narrowDep: NarrowDependency[_] =>
+ visit(narrowDep.rdd)
+ }
+ }
+ }
+ }
+ }
+ }
+ visit(stage.rdd)
+ missing.toList
+ }
+
+ def runJob[T, U](
+ finalRdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ allowLocal: Boolean)
+ (implicit m: ClassManifest[U]): Array[U] =
+ {
+ val waiter = new JobWaiter(partitions.size)
+ val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+ eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, waiter))
+ waiter.getResult() match {
+ case JobSucceeded(results: Seq[_]) =>
+ return results.asInstanceOf[Seq[U]].toArray
+ case JobFailed(exception: Exception) =>
+ throw exception
+ }
+ }
+
+ def runApproximateJob[T, U, R](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ evaluator: ApproximateEvaluator[U, R],
+ timeout: Long
+ ): PartialResult[R] =
+ {
+ val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
+ val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+ val partitions = (0 until rdd.splits.size).toArray
+ eventQueue.put(JobSubmitted(rdd, func2, partitions, false, listener))
+ return listener.getResult() // Will throw an exception if the job fails
+ }
+
+ /**
+ * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
+ * events and responds by launching tasks. This runs in a dedicated thread and receives events
+ * via the eventQueue.
+ */
+ def run() = {
+ SparkEnv.set(env)
+
+ while (true) {
+ val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
+ val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
+ if (event != null) {
+ logDebug("Got event of type " + event.getClass.getName)
+ }
+
+ event match {
+ case JobSubmitted(finalRDD, func, partitions, allowLocal, listener) =>
+ val runId = nextRunId.getAndIncrement()
+ val finalStage = newStage(finalRDD, None, runId)
+ val job = new ActiveJob(runId, finalStage, func, partitions, listener)
+ updateCacheLocs()
+ logInfo("Got job " + job.runId + " with " + partitions.length + " output partitions")
+ logInfo("Final stage: " + finalStage)
+ logInfo("Parents of final stage: " + finalStage.parents)
+ logInfo("Missing parents: " + getMissingParentStages(finalStage))
+ if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
+ // Compute very short actions like first() or take() with no parent stages locally.
+ runLocally(job)
+ } else {
+ activeJobs += job
+ resultStageToJob(finalStage) = job
+ submitStage(finalStage)
+ }
+
+ case HostLost(host) =>
+ handleHostLost(host)
+
+ case completion: CompletionEvent =>
+ handleTaskCompletion(completion)
+
+ case null =>
+ // queue.poll() timed out, ignore it
+ }
+
+ // Periodically resubmit failed stages if some map output fetches have failed and we have
+ // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
+ // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
+ // the same time, so we want to make sure we've identified all the reduce tasks that depend
+ // on the failed node.
+ if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
+ logInfo("Resubmitting failed stages")
+ updateCacheLocs()
+ val failed2 = failed.toArray
+ failed.clear()
+ for (stage <- failed2.sortBy(_.priority)) {
+ submitStage(stage)
+ }
+ } else {
+ // TODO: We might want to run this less often, when we are sure that something has become
+ // runnable that wasn't before.
+ logDebug("Checking for newly runnable parent stages")
+ logDebug("running: " + running)
+ logDebug("waiting: " + waiting)
+ logDebug("failed: " + failed)
+ val waiting2 = waiting.toArray
+ waiting.clear()
+ for (stage <- waiting2.sortBy(_.priority)) {
+ submitStage(stage)
+ }
+ }
+ }
+ }
+
+ /**
+ * Run a job on an RDD locally, assuming it has only a single partition and no dependencies.
+ * We run the operation in a separate thread in case it takes a long time, so that we
+ * don't block the DAGScheduler event loop or other concurrent jobs.
+ */
+ def runLocally(job: ActiveJob) {
+ logInfo("Computing the requested partition locally")
+ new Thread("Local computation of job " + job.runId) {
+ override def run() {
+ try {
+ SparkEnv.set(env)
+ val rdd = job.finalStage.rdd
+ val split = rdd.splits(job.partitions(0))
+ val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+ val result = job.func(taskContext, rdd.iterator(split))
+ job.listener.taskSucceeded(0, result)
+ } catch {
+ case e: Exception =>
+ job.listener.jobFailed(e)
+ }
+ }
+ }.start()
+ }
+
+ def submitStage(stage: Stage) {
+ logDebug("submitStage(" + stage + ")")
+ if (!waiting(stage) && !running(stage) && !failed(stage)) {
+ val missing = getMissingParentStages(stage).sortBy(_.id)
+ logDebug("missing: " + missing)
+ if (missing == Nil) {
+ logInfo("Submitting " + stage + ", which has no missing parents")
+ submitMissingTasks(stage)
+ running += stage
+ } else {
+ for (parent <- missing) {
+ submitStage(parent)
+ }
+ waiting += stage
+ }
+ }
+ }
+
+ def submitMissingTasks(stage: Stage) {
+ logDebug("submitMissingTasks(" + stage + ")")
+ // Get our pending tasks and remember them in our pendingTasks entry
+ val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
+ myPending.clear()
+ var tasks = ArrayBuffer[Task[_]]()
+ if (stage.isShuffleMap) {
+ for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
+ val locs = getPreferredLocs(stage.rdd, p)
+ tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
+ }
+ } else {
+ // This is a final stage; figure out its job's missing partitions
+ val job = resultStageToJob(stage)
+ for (id <- 0 until job.numPartitions if (!job.finished(id))) {
+ val partition = job.partitions(id)
+ val locs = getPreferredLocs(stage.rdd, partition)
+ tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
+ }
+ }
+ if (tasks.size > 0) {
+ logInfo("Submitting " + tasks.size + " missing tasks from " + stage)
+ myPending ++= tasks
+ logDebug("New pending tasks: " + myPending)
+ taskSched.submitTasks(
+ new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority))
+ } else {
+ logDebug("Stage " + stage + " is actually done; %b %d %d".format(
+ stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
+ running -= stage
+ }
+ }
+
+ /**
+ * Responds to a task finishing. This is called inside the event loop so it assumes that it can
+ * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
+ */
+ def handleTaskCompletion(event: CompletionEvent) {
+ val task = event.task
+ val stage = idToStage(task.stageId)
+ event.reason match {
+ case Success =>
+ logInfo("Completed " + task)
+ if (event.accumUpdates != null) {
+ Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
+ }
+ pendingTasks(stage) -= task
+ task match {
+ case rt: ResultTask[_, _] =>
+ resultStageToJob.get(stage) match {
+ case Some(job) =>
+ if (!job.finished(rt.outputId)) {
+ job.finished(rt.outputId) = true
+ job.numFinished += 1
+ job.listener.taskSucceeded(rt.outputId, event.result)
+ // If the whole job has finished, remove it
+ if (job.numFinished == job.numPartitions) {
+ activeJobs -= job
+ resultStageToJob -= stage
+ running -= stage
+ }
+ }
+ case None =>
+ logInfo("Ignoring result from " + rt + " because its job has finished")
+ }
+
+ case smt: ShuffleMapTask =>
+ val stage = idToStage(smt.stageId)
+ val bmAddress = event.result.asInstanceOf[BlockManagerId]
+ val host = bmAddress.ip
+ logInfo("ShuffleMapTask finished with host " + host)
+ if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos
+ stage.addOutputLoc(smt.partition, bmAddress)
+ }
+ if (running.contains(stage) && pendingTasks(stage).isEmpty) {
+ logInfo(stage + " finished; looking for newly runnable stages")
+ running -= stage
+ logInfo("running: " + running)
+ logInfo("waiting: " + waiting)
+ logInfo("failed: " + failed)
+ if (stage.shuffleDep != None) {
+ mapOutputTracker.registerMapOutputs(
+ stage.shuffleDep.get.shuffleId,
+ stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray)
+ }
+ updateCacheLocs()
+ if (stage.outputLocs.count(_ == Nil) != 0) {
+ // Some tasks had failed; let's resubmit this stage
+ // TODO: Lower-level scheduler should also deal with this
+ logInfo("Resubmitting " + stage + " because some of its tasks had failed: " +
+ stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", "))
+ submitStage(stage)
+ } else {
+ val newlyRunnable = new ArrayBuffer[Stage]
+ for (stage <- waiting) {
+ logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage))
+ }
+ for (stage <- waiting if getMissingParentStages(stage) == Nil) {
+ newlyRunnable += stage
+ }
+ waiting --= newlyRunnable
+ running ++= newlyRunnable
+ for (stage <- newlyRunnable.sortBy(_.id)) {
+ submitMissingTasks(stage)
+ }
+ }
+ }
+ }
+
+ case Resubmitted =>
+ logInfo("Resubmitted " + task + ", so marking it as still running")
+ pendingTasks(stage) += task
+
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ // Mark the stage that the reducer was in as unrunnable
+ val failedStage = idToStage(task.stageId)
+ running -= failedStage
+ failed += failedStage
+ // TODO: Cancel running tasks in the stage
+ logInfo("Marking " + failedStage + " for resubmision due to a fetch failure")
+ // Mark the map whose fetch failed as broken in the map stage
+ val mapStage = shuffleToMapStage(shuffleId)
+ mapStage.removeOutputLoc(mapId, bmAddress)
+ mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+ logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
+ failed += mapStage
+ // Remember that a fetch failed now; this is used to resubmit the broken
+ // stages later, after a small wait (to give other tasks the chance to fail)
+ lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
+ // TODO: mark the host as failed only if there were lots of fetch failures on it
+ if (bmAddress != null) {
+ handleHostLost(bmAddress.ip)
+ }
+
+ case _ =>
+ // Non-fetch failure -- probably a bug in the job, so bail out
+ // TODO: Cancel all tasks that are still running
+ resultStageToJob.get(stage) match {
+ case Some(job) =>
+ val error = new SparkException("Task failed: " + task + ", reason: " + event.reason)
+ job.listener.jobFailed(error)
+ activeJobs -= job
+ resultStageToJob -= stage
+ case None =>
+ logInfo("Ignoring result from " + task + " because its job has finished")
+ }
+ }
+ }
+
+ /**
+ * Responds to a host being lost. This is called inside the event loop so it assumes that it can
+ * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside.
+ */
+ def handleHostLost(host: String) {
+ if (!deadHosts.contains(host)) {
+ logInfo("Host lost: " + host)
+ deadHosts += host
+ BlockManagerMaster.notifyADeadHost(host)
+ // TODO: This will be really slow if we keep accumulating shuffle map stages
+ for ((shuffleId, stage) <- shuffleToMapStage) {
+ stage.removeOutputsOnHost(host)
+ val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
+ mapOutputTracker.registerMapOutputs(shuffleId, locs, true)
+ }
+ cacheTracker.cacheLost(host)
+ updateCacheLocs()
+ }
+ }
+
+ def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
+ // If the partition is cached, return the cache locations
+ val cached = getCacheLocs(rdd)(partition)
+ if (cached != Nil) {
+ return cached
+ }
+ // If the RDD has some placement preferences (as is the case for input RDDs), get those
+ val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList
+ if (rddPrefs != Nil) {
+ return rddPrefs
+ }
+ // If the RDD has narrow dependencies, pick the first partition of the first narrow dep
+ // that has any placement preferences. Ideally we would choose based on transfer sizes,
+ // but this will do for now.
+ rdd.dependencies.foreach(_ match {
+ case n: NarrowDependency[_] =>
+ for (inPart <- n.getParents(partition)) {
+ val locs = getPreferredLocs(n.rdd, inPart)
+ if (locs != Nil)
+ return locs;
+ }
+ case _ =>
+ })
+ return Nil
+ }
+
+ def stop() {
+ // TODO: Put a stop event on our queue and break the event loop
+ taskSched.stop()
+ }
+}
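For reference, a minimal standalone sketch of the event-queue pattern used by run() above: a single logic thread polls a blocking queue with a timeout so that periodic work (such as resubmitting failed stages) can still run when no events arrive. The MiniEvent/MiniLoop names and the 500 ms timeout are illustrative, not part of this patch.

import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

sealed trait MiniEvent
case class Submitted(desc: String) extends MiniEvent
case object Shutdown extends MiniEvent

class MiniLoop {
  private val queue = new LinkedBlockingQueue[MiniEvent]
  private val POLL_TIMEOUT = 500L  // ms; mirrors the poll-with-timeout above

  def post(e: MiniEvent) { queue.put(e) }  // any thread may post events

  def run() {
    var done = false
    while (!done) {
      queue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS) match {
        case Submitted(desc) => println("handling " + desc)
        case Shutdown        => done = true
        case null            => ()  // poll timed out; periodic bookkeeping would go here
      }
    }
  }
}

object MiniLoopSketch {
  def main(args: Array[String]) {
    val loop = new MiniLoop
    val t = new Thread(new Runnable { def run() { loop.run() } })
    t.start()
    loop.post(Submitted("job 0"))
    loop.post(Shutdown)
    t.join()
  }
}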
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
new file mode 100644
index 0000000000..c10abc9202
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -0,0 +1,30 @@
+package spark.scheduler
+
+import scala.collection.mutable.Map
+
+import spark._
+
+/**
+ * Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue
+ * architecture where any thread can post an event (e.g. a task finishing or a new job being
+ * submitted) but there is a single "logic" thread that reads these events and makes decisions.
+ * This greatly simplifies synchronization.
+ */
+sealed trait DAGSchedulerEvent
+
+case class JobSubmitted(
+ finalRDD: RDD[_],
+ func: (TaskContext, Iterator[_]) => _,
+ partitions: Array[Int],
+ allowLocal: Boolean,
+ listener: JobListener)
+ extends DAGSchedulerEvent
+
+case class CompletionEvent(
+ task: Task[_],
+ reason: TaskEndReason,
+ result: Any,
+ accumUpdates: Map[Long, Any])
+ extends DAGSchedulerEvent
+
+case class HostLost(host: String) extends DAGSchedulerEvent
diff --git a/core/src/main/scala/spark/scheduler/JobListener.scala b/core/src/main/scala/spark/scheduler/JobListener.scala
new file mode 100644
index 0000000000..d4dd536a7d
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/JobListener.scala
@@ -0,0 +1,11 @@
+package spark.scheduler
+
+/**
+ * Interface used to listen for job completion or failure events after submitting a job to the
+ * DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole
+ * job fails (and no further taskSucceeded events will happen).
+ */
+trait JobListener {
+ def taskSucceeded(index: Int, result: Any)
+ def jobFailed(exception: Exception)
+}
diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala
new file mode 100644
index 0000000000..62b458eccb
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/JobResult.scala
@@ -0,0 +1,9 @@
+package spark.scheduler
+
+/**
+ * A result of a job in the DAGScheduler.
+ */
+sealed trait JobResult
+
+case class JobSucceeded(results: Seq[_]) extends JobResult
+case class JobFailed(exception: Exception) extends JobResult
diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala
new file mode 100644
index 0000000000..be8ec9bd7b
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/JobWaiter.scala
@@ -0,0 +1,43 @@
+package spark.scheduler
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * An object that waits for a DAGScheduler job to complete.
+ */
+class JobWaiter(totalTasks: Int) extends JobListener {
+ private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null)
+ private var finishedTasks = 0
+
+ private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
+ private var jobResult: JobResult = null // If the job is finished, this will be its result
+
+ override def taskSucceeded(index: Int, result: Any) = synchronized {
+ if (jobFinished) {
+ throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
+ }
+ taskResults(index) = result
+ finishedTasks += 1
+ if (finishedTasks == totalTasks) {
+ jobFinished = true
+ jobResult = JobSucceeded(taskResults)
+ this.notifyAll()
+ }
+ }
+
+ override def jobFailed(exception: Exception) = synchronized {
+ if (jobFinished) {
+ throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter")
+ }
+ jobFinished = true
+ jobResult = JobFailed(exception)
+ this.notifyAll()
+ }
+
+ def getResult(): JobResult = synchronized {
+ while (!jobFinished) {
+ this.wait()
+ }
+ return jobResult
+ }
+}
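A small usage sketch for the waiter above, assuming the spark.scheduler classes from this patch are on the classpath: completions are reported from another thread through the JobListener interface and getResult() blocks until the whole job finishes. The worker thread and result strings are illustrative.

import spark.scheduler._

object JobWaiterSketch {
  def main(args: Array[String]) {
    val waiter = new JobWaiter(2)
    new Thread(new Runnable {
      def run() {
        // In the real scheduler these calls come from handleTaskCompletion()
        waiter.taskSucceeded(0, "partition-0 result")
        waiter.taskSucceeded(1, "partition-1 result")
      }
    }).start()
    waiter.getResult() match {
      case JobSucceeded(results) => println("got " + results.size + " results")
      case JobFailed(exception)  => throw exception
    }
  }
}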
diff --git a/core/src/main/scala/spark/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index 3952bf85b2..d2fab55b5e 100644
--- a/core/src/main/scala/spark/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -1,14 +1,15 @@
-package spark
+package spark.scheduler
+
+import spark._
class ResultTask[T, U](
- runId: Int,
- stageId: Int,
- rdd: RDD[T],
+ stageId: Int,
+ rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
- val partition: Int,
- locs: Seq[String],
+ val partition: Int,
+ @transient locs: Seq[String],
val outputId: Int)
- extends DAGTask[U](runId, stageId) {
+ extends Task[U](stageId) {
val split = rdd.splits(partition)
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
new file mode 100644
index 0000000000..317faa0851
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -0,0 +1,135 @@
+package spark.scheduler
+
+import java.io._
+import java.util.HashMap
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import com.ning.compress.lzf.LZFInputStream
+import com.ning.compress.lzf.LZFOutputStream
+
+import spark._
+import spark.storage._
+
+object ShuffleMapTask {
+ val serializedInfoCache = new HashMap[Int, Array[Byte]]
+ val deserializedInfoCache = new HashMap[Int, (RDD[_], ShuffleDependency[_,_,_])]
+
+ def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = {
+ synchronized {
+ val old = serializedInfoCache.get(stageId)
+ if (old != null) {
+ return old
+ } else {
+ val out = new ByteArrayOutputStream
+ val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
+ objOut.writeObject(rdd)
+ objOut.writeObject(dep)
+ objOut.close()
+ val bytes = out.toByteArray
+ serializedInfoCache.put(stageId, bytes)
+ return bytes
+ }
+ }
+ }
+
+ def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = {
+ synchronized {
+ val old = deserializedInfoCache.get(stageId)
+ if (old != null) {
+ return old
+ } else {
+ val loader = currentThread.getContextClassLoader
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val objIn = new ObjectInputStream(in) {
+ override def resolveClass(desc: ObjectStreamClass) =
+ Class.forName(desc.getName, false, loader)
+ }
+ val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+ val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]]
+ val tuple = (rdd, dep)
+ deserializedInfoCache.put(stageId, tuple)
+ return tuple
+ }
+ }
+ }
+}
+
+class ShuffleMapTask(
+ stageId: Int,
+ var rdd: RDD[_],
+ var dep: ShuffleDependency[_,_,_],
+ var partition: Int,
+ @transient var locs: Seq[String])
+ extends Task[BlockManagerId](stageId)
+ with Externalizable
+ with Logging {
+
+ def this() = this(0, null, null, 0, null)
+
+ var split = if (rdd == null) {
+ null
+ } else {
+ rdd.splits(partition)
+ }
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeInt(stageId)
+ val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ out.writeInt(partition)
+ out.writeObject(split)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ val stageId = in.readInt()
+ val numBytes = in.readInt()
+ val bytes = new Array[Byte](numBytes)
+ in.readFully(bytes)
+ val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
+ rdd = rdd_
+ dep = dep_
+ partition = in.readInt()
+ split = in.readObject().asInstanceOf[Split]
+ }
+
+ override def run(attemptId: Int): BlockManagerId = {
+ val numOutputSplits = dep.partitioner.numPartitions
+ val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
+ val partitioner = dep.partitioner.asInstanceOf[Partitioner]
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+ for (elem <- rdd.iterator(split)) {
+ val (k, v) = elem.asInstanceOf[(Any, Any)]
+ var bucketId = partitioner.getPartition(k)
+ val bucket = buckets(bucketId)
+ var existing = bucket.get(k)
+ if (existing == null) {
+ bucket.put(k, aggregator.createCombiner(v))
+ } else {
+ bucket.put(k, aggregator.mergeValue(existing, v))
+ }
+ }
+ val ser = SparkEnv.get.serializer.newInstance()
+ val blockManager = SparkEnv.get.blockManager
+ for (i <- 0 until numOutputSplits) {
+ val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i
+ val arr = new ArrayBuffer[Any]
+ val iter = buckets(i).entrySet().iterator()
+ while (iter.hasNext()) {
+ val entry = iter.next()
+ arr += ((entry.getKey(), entry.getValue()))
+ }
+ // TODO: This should probably be DISK_ONLY
+ blockManager.put(blockId, arr.iterator, StorageLevel.MEMORY_ONLY, false)
+ }
+ return SparkEnv.get.blockManager.blockManagerId
+ }
+
+ override def preferredLocations: Seq[String] = locs
+
+ override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
+}
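A standalone sketch of the map-side bucketing done in run() above: keys are hashed into numOutputSplits buckets and values are merged per key, which is the role of the partitioner and aggregator in the real task. The hash partitioner and the integer-sum combiner below are stand-ins, not the patch's code.

import scala.collection.mutable.HashMap

object BucketingSketch {
  def main(args: Array[String]) {
    val numOutputSplits = 3
    def getPartition(key: Any): Int = {
      val mod = key.hashCode % numOutputSplits
      if (mod < 0) mod + numOutputSplits else mod   // keep the bucket id non-negative
    }
    val buckets = Array.fill(numOutputSplits)(new HashMap[String, Int])
    val records = Seq(("a", 1), ("b", 2), ("a", 3), ("c", 4))
    for ((k, v) <- records) {
      val bucket = buckets(getPartition(k))
      // createCombiner for a new key, mergeValue otherwise
      bucket(k) = bucket.getOrElse(k, 0) + v
    }
    buckets.zipWithIndex.foreach { case (b, i) => println("bucket " + i + ": " + b) }
  }
}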
diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala
new file mode 100644
index 0000000000..cd660c9085
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/Stage.scala
@@ -0,0 +1,86 @@
+package spark.scheduler
+
+import java.net.URI
+
+import spark._
+import spark.storage.BlockManagerId
+
+/**
+ * A stage is a set of independent tasks all computing the same function that need to run as part
+ * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run
+ * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the
+ * DAGScheduler runs these stages in topological order.
+ *
+ * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for
+ * another stage, or a result stage, in which case its tasks directly compute the action that
+ * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes
+ * that each output partition is on.
+ *
+ * Each Stage also has a priority, which is (by default) based on the job it was submitted in.
+ * This allows Stages from earlier jobs to be computed first or recovered faster on failure.
+ */
+class Stage(
+ val id: Int,
+ val rdd: RDD[_],
+ val shuffleDep: Option[ShuffleDependency[_,_,_]], // Output shuffle if stage is a map stage
+ val parents: List[Stage],
+ val priority: Int)
+ extends Logging {
+
+ val isShuffleMap = shuffleDep != None
+ val numPartitions = rdd.splits.size
+ val outputLocs = Array.fill[List[BlockManagerId]](numPartitions)(Nil)
+ var numAvailableOutputs = 0
+
+ private var nextAttemptId = 0
+
+ def isAvailable: Boolean = {
+ if (/*parents.size == 0 &&*/ !isShuffleMap) {
+ true
+ } else {
+ numAvailableOutputs == numPartitions
+ }
+ }
+
+ def addOutputLoc(partition: Int, bmAddress: BlockManagerId) {
+ val prevList = outputLocs(partition)
+ outputLocs(partition) = bmAddress :: prevList
+ if (prevList == Nil)
+ numAvailableOutputs += 1
+ }
+
+ def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) {
+ val prevList = outputLocs(partition)
+ val newList = prevList.filterNot(_ == bmAddress)
+ outputLocs(partition) = newList
+ if (prevList != Nil && newList == Nil) {
+ numAvailableOutputs -= 1
+ }
+ }
+
+ def removeOutputsOnHost(host: String) {
+ var becameUnavailable = false
+ for (partition <- 0 until numPartitions) {
+ val prevList = outputLocs(partition)
+ val newList = prevList.filterNot(_.ip == host)
+ outputLocs(partition) = newList
+ if (prevList != Nil && newList == Nil) {
+ becameUnavailable = true
+ numAvailableOutputs -= 1
+ }
+ }
+ if (becameUnavailable) {
+ logInfo("%s is now unavailable on %s (%d/%d, %s)".format(this, host, numAvailableOutputs, numPartitions, isAvailable))
+ }
+ }
+
+ def newAttemptId(): Int = {
+ val id = nextAttemptId
+ nextAttemptId += 1
+ return id
+ }
+
+ override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]"
+
+ override def hashCode(): Int = id
+}
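A minimal sketch of the per-partition output-location bookkeeping in Stage: every partition keeps a list of locations holding its output, and a shuffle map stage counts as available only once each partition has at least one. Plain host strings stand in for BlockManagerId here.

object OutputLocsSketch {
  val numPartitions = 2
  val outputLocs = Array.fill[List[String]](numPartitions)(Nil)
  var numAvailableOutputs = 0

  def addOutputLoc(partition: Int, host: String) {
    val prev = outputLocs(partition)
    outputLocs(partition) = host :: prev
    if (prev == Nil) numAvailableOutputs += 1
  }

  def removeOutputsOnHost(host: String) {
    for (p <- 0 until numPartitions) {
      val prev = outputLocs(p)
      val next = prev.filterNot(_ == host)
      outputLocs(p) = next
      if (prev != Nil && next == Nil) numAvailableOutputs -= 1
    }
  }

  def isAvailable: Boolean = numAvailableOutputs == numPartitions

  def main(args: Array[String]) {
    addOutputLoc(0, "hostA"); addOutputLoc(1, "hostA")
    println(isAvailable)            // true
    removeOutputsOnHost("hostA")
    println(isAvailable)            // false: both partitions must be recomputed
  }
}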
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
new file mode 100644
index 0000000000..42325956ba
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -0,0 +1,11 @@
+package spark.scheduler
+
+/**
+ * A task to execute on a worker node.
+ */
+abstract class Task[T](val stageId: Int) extends Serializable {
+ def run(attemptId: Int): T
+ def preferredLocations: Seq[String] = Nil
+
+ var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.
+}
diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala
new file mode 100644
index 0000000000..868ddb237c
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/TaskResult.scala
@@ -0,0 +1,34 @@
+package spark.scheduler
+
+import java.io._
+
+import scala.collection.mutable.Map
+
+// Task result. Also contains updates to accumulator variables.
+// TODO: Use of distributed cache to return result is a hack to get around
+// what seems to be a bug with messages over 60KB in libprocess; fix it
+class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Externalizable {
+ def this() = this(null.asInstanceOf[T], null)
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeObject(value)
+ out.writeInt(accumUpdates.size)
+ for ((key, value) <- accumUpdates) {
+ out.writeLong(key)
+ out.writeObject(value)
+ }
+ }
+
+ override def readExternal(in: ObjectInput) {
+ value = in.readObject().asInstanceOf[T]
+ val numUpdates = in.readInt
+ if (numUpdates == 0) {
+ accumUpdates = null
+ } else {
+ accumUpdates = Map()
+ for (i <- 0 until numUpdates) {
+ accumUpdates(in.readLong()) = in.readObject()
+ }
+ }
+ }
+}
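A short sketch of the Externalizable round trip that TaskResult relies on: writeExternal/readExternal are driven by ObjectOutputStream/ObjectInputStream, and the no-arg constructor is required for deserialization. The MiniResult class is illustrative only.

import java.io._

class MiniResult(var value: String) extends Externalizable {
  def this() = this(null)                      // no-arg constructor required by Externalizable
  override def writeExternal(out: ObjectOutput) { out.writeObject(value) }
  override def readExternal(in: ObjectInput) { value = in.readObject().asInstanceOf[String] }
}

object MiniResultSketch {
  def main(args: Array[String]) {
    val bytes = new ByteArrayOutputStream
    val oos = new ObjectOutputStream(bytes)
    oos.writeObject(new MiniResult("forty-two"))
    oos.close()
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    val back = ois.readObject().asInstanceOf[MiniResult]
    println(back.value)                        // forty-two
  }
}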
diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
new file mode 100644
index 0000000000..cb7c375d97
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
@@ -0,0 +1,27 @@
+package spark.scheduler
+
+/**
+ * Low-level task scheduler interface, implemented by both MesosScheduler and LocalScheduler.
+ * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
+ * and are responsible for sending the tasks to the cluster, running them, retrying if there
+ * are failures, and mitigating stragglers. They return events to the DAGScheduler through
+ * the TaskSchedulerListener interface.
+ */
+trait TaskScheduler {
+ def start(): Unit
+
+ // Wait for registration with Mesos.
+ def waitForRegister(): Unit
+
+ // Disconnect from the cluster.
+ def stop(): Unit
+
+ // Submit a sequence of tasks to run.
+ def submitTasks(taskSet: TaskSet): Unit
+
+ // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
+ def setListener(listener: TaskSchedulerListener): Unit
+
+ // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
+ def defaultParallelism(): Int
+}
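To make the contract above concrete, here is an illustrative no-op implementation, assuming the spark.scheduler classes from this patch; a real scheduler would dispatch the tasks and later report results through listener.taskEnded().

import spark.scheduler._

class NoOpTaskScheduler extends TaskScheduler {
  private var listener: TaskSchedulerListener = null

  def start() {}
  def waitForRegister() {}
  def stop() {}

  def submitTasks(taskSet: TaskSet) {
    // A real implementation would launch these and call listener.taskEnded() per task
    println("ignoring " + taskSet.tasks.size + " tasks from stage " + taskSet.stageId)
  }

  def setListener(listener: TaskSchedulerListener) { this.listener = listener }

  def defaultParallelism(): Int = 1
}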
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
new file mode 100644
index 0000000000..a647eec9e4
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -0,0 +1,16 @@
+package spark.scheduler
+
+import scala.collection.mutable.Map
+
+import spark.TaskEndReason
+
+/**
+ * Interface for getting events back from the TaskScheduler.
+ */
+trait TaskSchedulerListener {
+ // A task has finished or failed.
+ def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit
+
+ // A node was lost from the cluster.
+ def hostLost(host: String): Unit
+}
diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala
new file mode 100644
index 0000000000..6f29dd2e9d
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/TaskSet.scala
@@ -0,0 +1,9 @@
+package spark.scheduler
+
+/**
+ * A set of tasks submitted together to the low-level TaskScheduler, usually representing
+ * missing partitions of a particular stage.
+ */
+class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) {
+ val id: String = stageId + "." + attempt
+}
diff --git a/core/src/main/scala/spark/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 3910c7b09e..8339c0ae90 100644
--- a/core/src/main/scala/spark/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -1,16 +1,21 @@
-package spark
+package spark.scheduler.local
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger
+import spark._
+import spark.scheduler._
+
/**
- * A simple Scheduler implementation that runs tasks locally in a thread pool. Optionally the
- * scheduler also allows each task to fail up to maxFailures times, which is useful for testing
- * fault recovery.
+ * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
+ * the scheduler also allows each task to fail up to maxFailures times, which is useful for
+ * testing fault recovery.
*/
-private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGScheduler with Logging {
+class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with Logging {
var attemptId = new AtomicInteger(0)
var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
+ val env = SparkEnv.get
+ var listener: TaskSchedulerListener = null
// TODO: Need to take into account stage priority in scheduling
@@ -18,7 +23,12 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
override def waitForRegister() {}
- override def submitTasks(tasks: Seq[Task[_]], runId: Int) {
+ override def setListener(listener: TaskSchedulerListener) {
+ this.listener = listener
+ }
+
+ override def submitTasks(taskSet: TaskSet) {
+ val tasks = taskSet.tasks
val failCount = new Array[Int](tasks.size)
def submitTask(task: Task[_], idInJob: Int) {
@@ -38,23 +48,14 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
Accumulators.clear
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val startTime = System.currentTimeMillis
- val bytes = ser.serialize(task)
- val timeTaken = System.currentTimeMillis - startTime
- logInfo("Size of task %d is %d bytes and took %d ms to serialize".format(
- idInJob, bytes.size, timeTaken))
- val deserializedTask = ser.deserialize[Task[_]](bytes, currentThread.getContextClassLoader)
+ val bytes = Utils.serialize(task)
+ logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes")
+ val deserializedTask = Utils.deserialize[Task[_]](
+ bytes, Thread.currentThread.getContextClassLoader)
val result: Any = deserializedTask.run(attemptId)
-
- // Serialize and deserialize the result to emulate what the mesos
- // executor does. This is useful to catch serialization errors early
- // on in development (so when users move their local Spark programs
- // to the cluster, they don't get surprised by serialization errors).
- val resultToReturn = ser.deserialize[Any](ser.serialize(result))
val accumUpdates = Accumulators.values
logInfo("Finished task " + idInJob)
- taskEnded(task, Success, resultToReturn, accumUpdates)
+ listener.taskEnded(task, Success, result, accumUpdates)
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
@@ -64,7 +65,7 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
submitTask(task, idInJob)
} else {
// TODO: Do something nicer here to return all the way to the user
- taskEnded(task, new ExceptionFailure(t), null, null)
+ listener.taskEnded(task, new ExceptionFailure(t), null, null)
}
}
}
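A self-contained sketch of the serialize-then-deserialize step LocalScheduler performs on each task (Utils.serialize / Utils.deserialize in the patch). The round trip surfaces java.io.NotSerializableException locally and gives each run its own copy of thread-local state such as accumulators; the helpers and the Map stand-in below are simplified, not the patch's code.

import java.io._

object SerializeRoundTrip {
  def serialize(o: AnyRef): Array[Byte] = {
    val bos = new ByteArrayOutputStream
    val out = new ObjectOutputStream(bos)
    out.writeObject(o); out.close()
    bos.toByteArray
  }

  def deserialize[T](bytes: Array[Byte]): T = {
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
    in.readObject().asInstanceOf[T]
  }

  def main(args: Array[String]) {
    val task = Map("stage" -> 1, "partition" -> 0)      // stand-in for a Task[_]
    val copy = deserialize[Map[String, Int]](serialize(task))
    println(copy)                                       // a distinct, equal copy
  }
}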
diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala
new file mode 100644
index 0000000000..8182901ce3
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala
@@ -0,0 +1,364 @@
+package spark.scheduler.mesos
+
+import java.io.{File, FileInputStream, FileOutputStream}
+import java.util.{ArrayList => JArrayList}
+import java.util.{List => JList}
+import java.util.{HashMap => JHashMap}
+import java.util.concurrent._
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.collection.mutable.Map
+import scala.collection.mutable.PriorityQueue
+import scala.collection.JavaConversions._
+import scala.math.Ordering
+
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
+import akka.actor.Channel
+import akka.serialization.RemoteActorSerialization._
+
+import com.google.protobuf.ByteString
+
+import org.apache.mesos.{Scheduler => MScheduler}
+import org.apache.mesos._
+import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
+
+import spark._
+import spark.scheduler._
+
+sealed trait CoarseMesosSchedulerMessage
+case class RegisterSlave(slaveId: String, host: String, port: Int) extends CoarseMesosSchedulerMessage
+case class StatusUpdate(slaveId: String, status: TaskStatus) extends CoarseMesosSchedulerMessage
+case class LaunchTask(slaveId: String, task: MTaskInfo) extends CoarseMesosSchedulerMessage
+case class ReviveOffers() extends CoarseMesosSchedulerMessage
+
+case class FakeOffer(slaveId: String, host: String, cores: Int)
+
+/**
+ * Mesos scheduler that uses coarse-grained tasks and does its own fine-grained scheduling inside
+ * them using Akka actors for messaging. Clients should first call start(), then submit task sets
+ * through the runTasks method.
+ *
+ * TODO: This is a pretty big hack for now.
+ */
+class CoarseMesosScheduler(
+ sc: SparkContext,
+ master: String,
+ frameworkName: String)
+ extends MesosScheduler(sc, master, frameworkName) {
+
+ val CORES_PER_SLAVE = System.getProperty("spark.coarseMesosScheduler.coresPerSlave", "4").toInt
+
+ class MasterActor extends Actor {
+ val slaveActor = new HashMap[String, ActorRef]
+ val slaveHost = new HashMap[String, String]
+ val freeCores = new HashMap[String, Int]
+
+ def receive = {
+ case RegisterSlave(slaveId, host, port) =>
+ slaveActor(slaveId) = remote.actorFor("WorkerActor", host, port)
+ logInfo("Slave actor: " + slaveActor(slaveId))
+ slaveHost(slaveId) = host
+ freeCores(slaveId) = CORES_PER_SLAVE
+ makeFakeOffers()
+
+ case StatusUpdate(slaveId, status) =>
+ fakeStatusUpdate(status)
+ if (isFinished(status.getState)) {
+ freeCores(slaveId) += 1
+ makeFakeOffers(slaveId)
+ }
+
+ case LaunchTask(slaveId, task) =>
+ freeCores(slaveId) -= 1
+ slaveActor(slaveId) ! LaunchTask(slaveId, task)
+
+ case ReviveOffers() =>
+ logInfo("Reviving offers")
+ makeFakeOffers()
+ }
+
+ // Make fake resource offers for all slaves
+ def makeFakeOffers() {
+ fakeResourceOffers(slaveHost.toSeq.map{case (id, host) => FakeOffer(id, host, freeCores(id))})
+ }
+
+ // Make fake resource offers for a single slave
+ def makeFakeOffers(slaveId: String) {
+ fakeResourceOffers(Seq(FakeOffer(slaveId, slaveHost(slaveId), freeCores(slaveId))))
+ }
+ }
+
+ val masterActor: ActorRef = actorOf(new MasterActor)
+ remote.register("MasterActor", masterActor)
+ masterActor.start()
+
+ val taskIdsOnSlave = new HashMap[String, HashSet[String]]
+
+ /**
+ * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets
+ * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
+ * tasks are balanced across the cluster.
+ */
+ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
+ synchronized {
+ val tasks = offers.map(o => new JArrayList[MTaskInfo])
+ for (i <- 0 until offers.size) {
+ val o = offers.get(i)
+ val slaveId = o.getSlaveId.getValue
+ if (!slaveIdToHost.contains(slaveId)) {
+ slaveIdToHost(slaveId) = o.getHostname
+ hostsAlive += o.getHostname
+ taskIdsOnSlave(slaveId) = new HashSet[String]
+ // Launch an infinite task on the node that will talk to the MasterActor to get fake tasks
+ val cpuRes = Resource.newBuilder()
+ .setName("cpus")
+ .setType(Value.Type.SCALAR)
+ .setScalar(Value.Scalar.newBuilder().setValue(1).build())
+ .build()
+ val task = new WorkerTask(slaveId, o.getHostname)
+ val serializedTask = Utils.serialize(task)
+ tasks(i).add(MTaskInfo.newBuilder()
+ .setTaskId(newTaskId())
+ .setSlaveId(o.getSlaveId)
+ .setExecutor(executorInfo)
+ .setName("worker task")
+ .addResources(cpuRes)
+ .setData(ByteString.copyFrom(serializedTask))
+ .build())
+ }
+ }
+ val filters = Filters.newBuilder().setRefuseSeconds(10).build()
+ for (i <- 0 until offers.size) {
+ d.launchTasks(offers(i).getId(), tasks(i), filters)
+ }
+ }
+ }
+
+ override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
+ val tid = status.getTaskId.getValue
+ var taskSetToUpdate: Option[TaskSetManager] = None
+ var taskFailed = false
+ synchronized {
+ try {
+ taskIdToTaskSetId.get(tid) match {
+ case Some(taskSetId) =>
+ if (activeTaskSets.contains(taskSetId)) {
+ //activeTaskSets(taskSetId).statusUpdate(status)
+ taskSetToUpdate = Some(activeTaskSets(taskSetId))
+ }
+ if (isFinished(status.getState)) {
+ taskIdToTaskSetId.remove(tid)
+ if (taskSetTaskIds.contains(taskSetId)) {
+ taskSetTaskIds(taskSetId) -= tid
+ }
+ val slaveId = taskIdToSlaveId(tid)
+ taskIdToSlaveId -= tid
+ taskIdsOnSlave(slaveId) -= tid
+ }
+ if (status.getState == TaskState.TASK_FAILED) {
+ taskFailed = true
+ }
+ case None =>
+ logInfo("Ignoring update from TID " + tid + " because its task set is gone")
+ }
+ } catch {
+ case e: Exception => logError("Exception in statusUpdate", e)
+ }
+ }
+ // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
+ if (taskSetToUpdate != None) {
+ taskSetToUpdate.get.statusUpdate(status)
+ }
+ if (taskFailed) {
+ // Revive offers if a task had failed for some reason other than host lost
+ reviveOffers()
+ }
+ }
+
+ override def slaveLost(d: SchedulerDriver, s: SlaveID) {
+ logInfo("Slave lost: " + s.getValue)
+ var failedHost: Option[String] = None
+ var lostTids: Option[HashSet[String]] = None
+ synchronized {
+ val slaveId = s.getValue
+ val host = slaveIdToHost(slaveId)
+ if (hostsAlive.contains(host)) {
+ slaveIdsWithExecutors -= slaveId
+ hostsAlive -= host
+ failedHost = Some(host)
+ lostTids = Some(taskIdsOnSlave(slaveId))
+ logInfo("failedHost: " + host)
+ logInfo("lostTids: " + lostTids)
+ taskIdsOnSlave -= slaveId
+ activeTaskSetsQueue.foreach(_.hostLost(host))
+ }
+ }
+ if (failedHost != None) {
+ // Report all the tasks on the failed host as lost, without holding a lock on this
+ for (tid <- lostTids.get; taskSetId <- taskIdToTaskSetId.get(tid)) {
+ // TODO: Maybe call our statusUpdate() instead to clean our internal data structures
+ activeTaskSets(taskSetId).statusUpdate(TaskStatus.newBuilder()
+ .setTaskId(TaskID.newBuilder().setValue(tid).build())
+ .setState(TaskState.TASK_LOST)
+ .build())
+ }
+ // Also report the loss to the DAGScheduler
+ listener.hostLost(failedHost.get)
+ reviveOffers();
+ }
+ }
+
+ override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
+
+ // Check for speculatable tasks in all our active task sets.
+ override def checkSpeculatableTasks() {
+ var shouldRevive = false
+ synchronized {
+ for (ts <- activeTaskSetsQueue) {
+ shouldRevive |= ts.checkSpeculatableTasks()
+ }
+ }
+ if (shouldRevive) {
+ reviveOffers()
+ }
+ }
+
+
+ val lock2 = new Object
+ var firstWait = true
+
+ override def waitForRegister() {
+ lock2.synchronized {
+ if (firstWait) {
+ super.waitForRegister()
+ Thread.sleep(5000)
+ firstWait = false
+ }
+ }
+ }
+
+ def fakeStatusUpdate(status: TaskStatus) {
+ statusUpdate(driver, status)
+ }
+
+ def fakeResourceOffers(offers: Seq[FakeOffer]) {
+ logDebug("fakeResourceOffers: " + offers)
+ val availableCpus = offers.map(_.cores.toDouble).toArray
+ var launchedTask = false
+ for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
+ do {
+ launchedTask = false
+ for (i <- 0 until offers.size if hostsAlive.contains(offers(i).host)) {
+ manager.slaveOffer(offers(i).slaveId, offers(i).host, availableCpus(i)) match {
+ case Some(task) =>
+ val tid = task.getTaskId.getValue
+ val sid = offers(i).slaveId
+ taskIdToTaskSetId(tid) = manager.taskSet.id
+ taskSetTaskIds(manager.taskSet.id) += tid
+ taskIdToSlaveId(tid) = sid
+ taskIdsOnSlave(sid) += tid
+ slaveIdsWithExecutors += sid
+ availableCpus(i) -= getResource(task.getResourcesList(), "cpus")
+ launchedTask = true
+ masterActor ! LaunchTask(sid, task)
+
+ case None => {}
+ }
+ }
+ } while (launchedTask)
+ }
+ }
+
+ override def reviveOffers() {
+ masterActor ! ReviveOffers()
+ }
+}
+
+class WorkerTask(slaveId: String, host: String) extends Task[Unit](-1) {
+ generation = 0
+
+ def run(id: Int): Unit = {
+ val actor = actorOf(new WorkerActor(slaveId, host))
+ if (!remote.isRunning) {
+ remote.start(Utils.localIpAddress, 7078)
+ }
+ remote.register("WorkerActor", actor)
+ actor.start()
+ while (true) {
+ Thread.sleep(10000)
+ }
+ }
+}
+
+class WorkerActor(slaveId: String, host: String) extends Actor with Logging {
+ val env = SparkEnv.get
+ val classLoader = currentThread.getContextClassLoader
+ val threadPool = new ThreadPoolExecutor(
+ 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
+
+ val masterIp: String = System.getProperty("spark.master.host", "localhost")
+ val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
+ val masterActor = remote.actorFor("MasterActor", masterIp, masterPort)
+
+ class TaskRunner(desc: MTaskInfo)
+ extends Runnable {
+ override def run() = {
+ val tid = desc.getTaskId.getValue
+ logInfo("Running task ID " + tid)
+ try {
+ SparkEnv.set(env)
+ Thread.currentThread.setContextClassLoader(classLoader)
+ Accumulators.clear
+ val task = Utils.deserialize[Task[Any]](desc.getData.toByteArray, classLoader)
+ env.mapOutputTracker.updateGeneration(task.generation)
+ val value = task.run(tid.toInt)
+ val accumUpdates = Accumulators.values
+ val result = new TaskResult(value, accumUpdates)
+ masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
+ .setTaskId(desc.getTaskId)
+ .setState(TaskState.TASK_FINISHED)
+ .setData(ByteString.copyFrom(Utils.serialize(result)))
+ .build())
+ logInfo("Finished task ID " + tid)
+ } catch {
+ case ffe: FetchFailedException => {
+ val reason = ffe.toTaskEndReason
+ masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
+ .setTaskId(desc.getTaskId)
+ .setState(TaskState.TASK_FAILED)
+ .setData(ByteString.copyFrom(Utils.serialize(reason)))
+ .build())
+ }
+ case t: Throwable => {
+ val reason = ExceptionFailure(t)
+ masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
+ .setTaskId(desc.getTaskId)
+ .setState(TaskState.TASK_FAILED)
+ .setData(ByteString.copyFrom(Utils.serialize(reason)))
+ .build())
+
+ // TODO: Should we exit the whole executor here? On the one hand, the failed task may
+ // have left some weird state around depending on when the exception was thrown, but on
+ // the other hand, maybe we could detect that when future tasks fail and exit then.
+ logError("Exception in task ID " + tid, t)
+ //System.exit(1)
+ }
+ }
+ }
+ }
+
+ override def preStart {
+ val ref = toRemoteActorRefProtocol(self).toByteArray
+ logInfo("Registering with master")
+ masterActor ! RegisterSlave(slaveId, host, remote.address.getPort)
+ }
+
+ override def receive = {
+ case LaunchTask(slaveId, task) =>
+ threadPool.execute(new TaskRunner(task))
+ }
+}
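A simplified sketch of the round-robin offer filling used in resourceOffers() and fakeResourceOffers() above: keep sweeping over the offers, placing at most one task per offer per pass, until a full pass places nothing, so tasks spread across the cluster instead of saturating the first offer. The one-CPU-per-task assumption and the host/CPU values are illustrative.

object RoundRobinSketch {
  def main(args: Array[String]) {
    val hosts = Array("hostA", "hostB", "hostC")
    val availableCpus = Array(2.0, 1.0, 3.0)
    var remainingTasks = 5
    val assigned = Array.fill(hosts.length)(0)
    var launchedTask = false
    do {
      launchedTask = false
      for (i <- 0 until hosts.length if remainingTasks > 0 && availableCpus(i) >= 1.0) {
        availableCpus(i) -= 1.0      // claim one CPU for one task on this offer
        assigned(i) += 1
        remainingTasks -= 1
        launchedTask = true
      }
    } while (launchedTask)
    println(hosts.zip(assigned).mkString(", "))   // (hostA,2), (hostB,1), (hostC,2)
  }
}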
diff --git a/core/src/main/scala/spark/MesosScheduler.scala b/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala
index a7711e0d35..f72618c03f 100644
--- a/core/src/main/scala/spark/MesosScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala
@@ -1,4 +1,4 @@
-package spark
+package spark.scheduler.mesos
import java.io.{File, FileInputStream, FileOutputStream}
import java.util.{ArrayList => JArrayList}
@@ -17,20 +17,23 @@ import com.google.protobuf.ByteString
import org.apache.mesos.{Scheduler => MScheduler}
import org.apache.mesos._
-import org.apache.mesos.Protos._
+import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
+
+import spark._
+import spark.scheduler._
/**
- * The main Scheduler implementation, which runs jobs on Mesos. Clients should first call start(),
- * then submit tasks through the runTasks method.
+ * The main TaskScheduler implementation, which runs tasks on Mesos. Clients should first call
+ * start(), then submit task sets through the submitTasks method.
*/
-private class MesosScheduler(
+class MesosScheduler(
sc: SparkContext,
master: String,
frameworkName: String)
- extends MScheduler
- with DAGScheduler
+ extends TaskScheduler
+ with MScheduler
with Logging {
-
+
// Environment variables to pass to our executors
val ENV_VARS_TO_SEND_TO_EXECUTORS = Array(
"SPARK_MEM",
@@ -49,55 +52,60 @@ private class MesosScheduler(
}
}
+ // How often to check for speculative tasks
+ val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
+
// Lock used to wait for scheduler to be registered
- private var isRegistered = false
- private val registeredLock = new Object()
+ var isRegistered = false
+ val registeredLock = new Object()
- private val activeJobs = new HashMap[Int, Job]
- private var activeJobsQueue = new ArrayBuffer[Job]
+ val activeTaskSets = new HashMap[String, TaskSetManager]
+ var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
- private val taskIdToJobId = new HashMap[String, Int]
- private val taskIdToSlaveId = new HashMap[String, String]
- private val jobTasks = new HashMap[Int, HashSet[String]]
+ val taskIdToTaskSetId = new HashMap[String, String]
+ val taskIdToSlaveId = new HashMap[String, String]
+ val taskSetTaskIds = new HashMap[String, HashSet[String]]
- // Incrementing job and task IDs
- private var nextJobId = 0
- private var nextTaskId = 0
+ // Incrementing Mesos task IDs
+ var nextTaskId = 0
// Driver for talking to Mesos
var driver: SchedulerDriver = null
- // Which nodes we have executors on
- private val slavesWithExecutors = new HashSet[String]
+ // Which hosts in the cluster are alive (contains hostnames)
+ val hostsAlive = new HashSet[String]
+
+ // Which slave IDs we have executors on
+ val slaveIdsWithExecutors = new HashSet[String]
+
+ val slaveIdToHost = new HashMap[String, String]
// JAR server, if any JARs were added by the user to the SparkContext
var jarServer: HttpServer = null
// URIs of JARs to pass to executor
var jarUris: String = ""
-
+
// Create an ExecutorInfo for our tasks
val executorInfo = createExecutorInfo()
- // Sorts jobs in reverse order of run ID for use in our priority queue (so lower IDs run first)
- private val jobOrdering = new Ordering[Job] {
- override def compare(j1: Job, j2: Job): Int = j2.runId - j1.runId
- }
-
- def newJobId(): Int = this.synchronized {
- val id = nextJobId
- nextJobId += 1
- return id
+ // Listener object to pass upcalls into
+ var listener: TaskSchedulerListener = null
+
+ val mapOutputTracker = SparkEnv.get.mapOutputTracker
+
+ override def setListener(listener: TaskSchedulerListener) {
+ this.listener = listener
}
def newTaskId(): TaskID = {
- val id = "" + nextTaskId;
- nextTaskId += 1;
- return TaskID.newBuilder().setValue(id).build()
+ val id = TaskID.newBuilder().setValue("" + nextTaskId).build()
+ nextTaskId += 1
+ return id
}
override def start() {
- new Thread("Spark scheduler") {
+ new Thread("MesosScheduler driver") {
setDaemon(true)
override def run {
val sched = MesosScheduler.this
@@ -110,12 +118,27 @@ private class MesosScheduler(
case e: Exception => logError("driver.run() failed", e)
}
}
- }.start
+ }.start()
+ if (System.getProperty("spark.speculation", "false") == "true") {
+ new Thread("MesosScheduler speculation check") {
+ setDaemon(true)
+ override def run {
+ waitForRegister()
+ while (true) {
+ try {
+ Thread.sleep(SPECULATION_INTERVAL)
+ } catch { case e: InterruptedException => {} }
+ checkSpeculatableTasks()
+ }
+ }
+ }.start()
+ }
}
def createExecutorInfo(): ExecutorInfo = {
val sparkHome = sc.getSparkHome match {
- case Some(path) => path
+ case Some(path) =>
+ path
case None =>
throw new SparkException("Spark home is not set; set it through the spark.home system " +
"property, the SPARK_HOME environment variable or the SparkContext constructor")
@@ -151,27 +174,26 @@ private class MesosScheduler(
.build()
}
- def submitTasks(tasks: Seq[Task[_]], runId: Int) {
- logInfo("Got a job with " + tasks.size + " tasks")
+ def submitTasks(taskSet: TaskSet) {
+ val tasks = taskSet.tasks
+ logInfo("Adding task set " + taskSet.id + " with " + tasks.size + " tasks")
waitForRegister()
this.synchronized {
- val jobId = newJobId()
- val myJob = new SimpleJob(this, tasks, runId, jobId)
- activeJobs(jobId) = myJob
- activeJobsQueue += myJob
- logInfo("Adding job with ID " + jobId)
- jobTasks(jobId) = HashSet.empty[String]
+ val manager = new TaskSetManager(this, taskSet)
+ activeTaskSets(taskSet.id) = manager
+ activeTaskSetsQueue += manager
+ taskSetTaskIds(taskSet.id) = new HashSet()
}
- driver.reviveOffers();
+ reviveOffers();
}
- def jobFinished(job: Job) {
+ def taskSetFinished(manager: TaskSetManager) {
this.synchronized {
- activeJobs -= job.jobId
- activeJobsQueue -= job
- taskIdToJobId --= jobTasks(job.jobId)
- taskIdToSlaveId --= jobTasks(job.jobId)
- jobTasks.remove(job.jobId)
+ activeTaskSets -= manager.taskSet.id
+ activeTaskSetsQueue -= manager
+ taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+ taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id)
+ taskSetTaskIds.remove(manager.taskSet.id)
}
}
@@ -196,33 +218,40 @@ private class MesosScheduler(
override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
/**
- * Method called by Mesos to offer resources on slaves. We resond by asking our active jobs for
- * tasks in FIFO order. We fill each node with tasks in a round-robin manner so that tasks are
- * balanced across the cluster.
+ * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets
+ * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
+ * tasks are balanced across the cluster.
*/
override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
synchronized {
- val tasks = offers.map(o => new JArrayList[TaskInfo])
+ // Mark each slave as alive and remember its hostname
+ for (o <- offers) {
+ slaveIdToHost(o.getSlaveId.getValue) = o.getHostname
+ hostsAlive += o.getHostname
+ }
+ // Build a list of tasks to assign to each slave
+ val tasks = offers.map(o => new JArrayList[MTaskInfo])
val availableCpus = offers.map(o => getResource(o.getResourcesList(), "cpus"))
val enoughMem = offers.map(o => {
val mem = getResource(o.getResourcesList(), "mem")
val slaveId = o.getSlaveId.getValue
- mem >= EXECUTOR_MEMORY || slavesWithExecutors.contains(slaveId)
+ mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId)
})
var launchedTask = false
- for (job <- activeJobsQueue.sorted(jobOrdering)) {
+ for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
do {
launchedTask = false
for (i <- 0 until offers.size if enoughMem(i)) {
- job.slaveOffer(offers(i), availableCpus(i)) match {
+ val sid = offers(i).getSlaveId.getValue
+ val host = offers(i).getHostname
+ manager.slaveOffer(sid, host, availableCpus(i)) match {
case Some(task) =>
tasks(i).add(task)
val tid = task.getTaskId.getValue
- val sid = offers(i).getSlaveId.getValue
- taskIdToJobId(tid) = job.jobId
- jobTasks(job.jobId) += tid
+ taskIdToTaskSetId(tid) = manager.taskSet.id
+ taskSetTaskIds(manager.taskSet.id) += tid
taskIdToSlaveId(tid) = sid
- slavesWithExecutors += sid
+ slaveIdsWithExecutors += sid
availableCpus(i) -= getResource(task.getResourcesList(), "cpus")
launchedTask = true
@@ -256,53 +285,74 @@ private class MesosScheduler(
}
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
- var jobToUpdate: Option[Job] = None
+ val tid = status.getTaskId.getValue
+ var taskSetToUpdate: Option[TaskSetManager] = None
+ var failedHost: Option[String] = None
+ var taskFailed = false
synchronized {
try {
- val tid = status.getTaskId.getValue
- if (status.getState == TaskState.TASK_LOST
- && taskIdToSlaveId.contains(tid)) {
+ if (status.getState == TaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) {
// We lost the executor on this slave, so remember that it's gone
- slavesWithExecutors -= taskIdToSlaveId(tid)
+ val slaveId = taskIdToSlaveId(tid)
+ val host = slaveIdToHost(slaveId)
+ if (hostsAlive.contains(host)) {
+ slaveIdsWithExecutors -= slaveId
+ hostsAlive -= host
+ activeTaskSetsQueue.foreach(_.hostLost(host))
+ failedHost = Some(host)
+ }
}
- taskIdToJobId.get(tid) match {
- case Some(jobId) =>
- if (activeJobs.contains(jobId)) {
- jobToUpdate = Some(activeJobs(jobId))
+ taskIdToTaskSetId.get(tid) match {
+ case Some(taskSetId) =>
+ if (activeTaskSets.contains(taskSetId)) {
+ //activeTaskSets(taskSetId).statusUpdate(status)
+ taskSetToUpdate = Some(activeTaskSets(taskSetId))
}
if (isFinished(status.getState)) {
- taskIdToJobId.remove(tid)
- if (jobTasks.contains(jobId)) {
- jobTasks(jobId) -= tid
+ taskIdToTaskSetId.remove(tid)
+ if (taskSetTaskIds.contains(taskSetId)) {
+ taskSetTaskIds(taskSetId) -= tid
}
taskIdToSlaveId.remove(tid)
}
+ if (status.getState == TaskState.TASK_FAILED) {
+ taskFailed = true
+ }
case None =>
- logInfo("Ignoring update from TID " + tid + " because its job is gone")
+ logInfo("Ignoring update from TID " + tid + " because its task set is gone")
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
}
}
- for (j <- jobToUpdate) {
- j.statusUpdate(status)
+ // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
+ if (taskSetToUpdate != None) {
+ taskSetToUpdate.get.statusUpdate(status)
+ }
+ if (failedHost != None) {
+ listener.hostLost(failedHost.get)
+ reviveOffers();
+ }
+ if (taskFailed) {
+ // Also revive offers if a task had failed for some reason other than host lost
+ reviveOffers()
}
}
override def error(d: SchedulerDriver, message: String) {
logError("Mesos error: " + message)
synchronized {
- if (activeJobs.size > 0) {
- // Have each job throw a SparkException with the error
- for ((jobId, activeJob) <- activeJobs) {
+ if (activeTaskSets.size > 0) {
+ // Have each task set throw a SparkException with the error
+ for ((taskSetId, manager) <- activeTaskSets) {
try {
- activeJob.error(message)
+ manager.error(message)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
}
} else {
- // No jobs are active but we still got an error. Just exit since this
+ // No task sets are active but we still got an error. Just exit since this
// must mean the error is during registration.
// It might be good to do something smarter here in the future.
System.exit(1)
@@ -373,41 +423,68 @@ private class MesosScheduler(
return Utils.serialize(props.toArray)
}
- override def frameworkMessage(
- d: SchedulerDriver,
- e: ExecutorID,
- s: SlaveID,
- b: Array[Byte]) {}
+ override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
override def slaveLost(d: SchedulerDriver, s: SlaveID) {
- slavesWithExecutors.remove(s.getValue)
+ var failedHost: Option[String] = None
+ synchronized {
+ val slaveId = s.getValue
+ val host = slaveIdToHost(slaveId)
+ if (hostsAlive.contains(host)) {
+ slaveIdsWithExecutors -= slaveId
+ hostsAlive -= host
+ activeTaskSetsQueue.foreach(_.hostLost(host))
+ failedHost = Some(host)
+ }
+ }
+ if (failedHost != None) {
+ listener.hostLost(failedHost.get)
+ reviveOffers();
+ }
}
override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) {
- slavesWithExecutors.remove(s.getValue)
+ logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue))
+ slaveLost(d, s)
}
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
+
+ // Check for speculatable tasks in all our active task sets.
+ def checkSpeculatableTasks() {
+ var shouldRevive = false
+ synchronized {
+ for (ts <- activeTaskSetsQueue) {
+ shouldRevive |= ts.checkSpeculatableTasks()
+ }
+ }
+ if (shouldRevive) {
+ reviveOffers()
+ }
+ }
+
+ def reviveOffers() {
+ driver.reviveOffers()
+ }
}
object MesosScheduler {
/**
- * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
- * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
+ * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
+ * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
* environment variable.
*/
def memoryStringToMb(str: String): Int = {
val lower = str.toLowerCase
if (lower.endsWith("k")) {
- (lower.substring(0, lower.length - 1).toLong / 1024).toInt
+ (lower.substring(0, lower.length-1).toLong / 1024).toInt
} else if (lower.endsWith("m")) {
- lower.substring(0, lower.length - 1).toInt
+ lower.substring(0, lower.length-1).toInt
} else if (lower.endsWith("g")) {
- lower.substring(0, lower.length - 1).toInt * 1024
+ lower.substring(0, lower.length-1).toInt * 1024
} else if (lower.endsWith("t")) {
- lower.substring(0, lower.length - 1).toInt * 1024 * 1024
- } else {
- // no suffix, so it's just a number in bytes
+ lower.substring(0, lower.length-1).toInt * 1024 * 1024
+ } else { // no suffix, so it's just a number in bytes
(lower.toLong / 1024 / 1024).toInt
}
}
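A quick worked check of the memoryStringToMb() conversions above; the function is restated in simplified form so the sketch runs on its own, and the sample inputs are illustrative.

object MemoryStringSketch {
  def memoryStringToMb(str: String): Int = {
    val lower = str.toLowerCase
    if (lower.endsWith("k")) (lower.dropRight(1).toLong / 1024).toInt
    else if (lower.endsWith("m")) lower.dropRight(1).toInt
    else if (lower.endsWith("g")) lower.dropRight(1).toInt * 1024
    else if (lower.endsWith("t")) lower.dropRight(1).toInt * 1024 * 1024
    else (lower.toLong / 1024 / 1024).toInt       // plain number of bytes
  }

  def main(args: Array[String]) {
    println(memoryStringToMb("512k"))        // 0 (rounds down below 1 MB)
    println(memoryStringToMb("300m"))        // 300
    println(memoryStringToMb("2g"))          // 2048
    println(memoryStringToMb("134217728"))   // 128
  }
}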
diff --git a/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala b/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala
new file mode 100644
index 0000000000..af2f80ea66
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala
@@ -0,0 +1,32 @@
+package spark.scheduler.mesos
+
+/**
+ * Information about a running task attempt.
+ */
+class TaskInfo(val taskId: String, val index: Int, val launchTime: Long, val host: String) {
+ var finishTime: Long = 0
+ var failed = false
+
+ def markSuccessful(time: Long = System.currentTimeMillis) {
+ finishTime = time
+ }
+
+ def markFailed(time: Long = System.currentTimeMillis) {
+ finishTime = time
+ failed = true
+ }
+
+ def finished: Boolean = finishTime != 0
+
+ def successful: Boolean = finished && !failed
+
+ def duration: Long = {
+ if (!finished) {
+ throw new UnsupportedOperationException("duration() called on unfinished tasks")
+ } else {
+ finishTime - launchTime
+ }
+ }
+
+ def timeRunning(currentTime: Long): Long = currentTime - launchTime
+}
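
A rough illustration of the TaskInfo lifecycle introduced above (the task id, timestamps and host are made-up example values, not part of the patch):

    // Hypothetical usage: track one attempt from launch to completion.
    val info = new TaskInfo("tid-1", index = 0, launchTime = 1000L, host = "host-a")
    info.timeRunning(4000L)              // 3000 ms elapsed while still running
    info.markSuccessful(5000L)
    assert(info.finished && info.successful)
    assert(info.duration == 4000L)       // finishTime - launchTime
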
diff --git a/core/src/main/scala/spark/SimpleJob.scala b/core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala
index 01c7efff1e..535c17d9d4 100644
--- a/core/src/main/scala/spark/SimpleJob.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala
@@ -1,28 +1,32 @@
-package spark
+package spark.scheduler.mesos
+import java.util.Arrays
import java.util.{HashMap => JHashMap}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.math.max
+import scala.math.min
import com.google.protobuf.ByteString
import org.apache.mesos._
-import org.apache.mesos.Protos._
+import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
+
+import spark._
+import spark.scheduler._
/**
- * A Job that runs a set of tasks with no interdependencies.
+ * Schedules the tasks within a single TaskSet in the MesosScheduler.
*/
-class SimpleJob(
+class TaskSetManager(
sched: MesosScheduler,
- tasksSeq: Seq[Task[_]],
- runId: Int,
- jobId: Int)
- extends Job(runId, jobId)
- with Logging {
+ val taskSet: TaskSet)
+ extends Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
- val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "5000").toLong
+ val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
// CPUs to request per task
val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
@@ -30,18 +34,20 @@ class SimpleJob(
// Maximum times a task is allowed to fail before failing the job
val MAX_TASK_FAILURES = 4
+ // Quantile of tasks at which to start speculation
+ val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
+ val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
+
// Serializer for closures and tasks.
val ser = SparkEnv.get.closureSerializer.newInstance()
- val callingThread = Thread.currentThread
- val tasks = tasksSeq.toArray
+ val priority = taskSet.priority
+ val tasks = taskSet.tasks
val numTasks = tasks.length
- val launched = new Array[Boolean](numTasks)
+ val copiesRunning = new Array[Int](numTasks)
val finished = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
- val tidToIndex = HashMap[String, Int]()
-
- var tasksLaunched = 0
+ val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
var tasksFinished = 0
// Last time when we launched a preferred task (for delay scheduling)
@@ -62,6 +68,13 @@ class SimpleJob(
// List containing all pending tasks (also used as a stack, as above)
val allPendingTasks = new ArrayBuffer[Int]
+ // Tasks that can be speculated. Since these will be a small fraction of total
+ // tasks, we'll just hold them in a HashSet.
+ val speculatableTasks = new HashSet[Int]
+
+ // Task index, start and finish time for each task attempt (indexed by task ID)
+ val taskInfos = new HashMap[String, TaskInfo]
+
// Did the job fail?
var failed = false
var causeOfFailure = ""
@@ -76,6 +89,12 @@ class SimpleJob(
// exceptions automatically.
val recentExceptions = HashMap[String, (Int, Long)]()
+ // Figure out the current map output tracker generation and set it on all tasks
+ val generation = sched.mapOutputTracker.getGeneration
+ for (t <- tasks) {
+ t.generation = generation
+ }
+
// Add all our tasks to the pending lists. We do this in reverse order
// of task index so that tasks with low indices get launched first.
for (i <- (0 until numTasks).reverse) {
@@ -84,7 +103,7 @@ class SimpleJob(
// Add a task to all the pending-task lists that it should be on.
def addPendingTask(index: Int) {
- val locations = tasks(index).preferredLocations
+ val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
if (locations.size == 0) {
pendingTasksWithNoPrefs += index
} else {
@@ -110,13 +129,37 @@ class SimpleJob(
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
- if (!launched(index) && !finished(index)) {
+ if (copiesRunning(index) == 0 && !finished(index)) {
return Some(index)
}
}
return None
}
+ // Return a speculative task for a given host if any are available. The task should not have an
+ // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
+ // task must have a preference for this host (or no preferred locations at all).
+ def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
+ speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
+ val localTask = speculatableTasks.find { index =>
+ val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
+ val attemptLocs = taskAttempts(index).map(_.host)
+ (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
+ }
+ if (localTask != None) {
+ speculatableTasks -= localTask.get
+ return localTask
+ }
+ if (!localOnly && speculatableTasks.size > 0) {
+ val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host))
+ if (nonLocalTask != None) {
+ speculatableTasks -= nonLocalTask.get
+ return nonLocalTask
+ }
+ }
+ return None
+ }
+
// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
def findTask(host: String, localOnly: Boolean): Option[Int] = {
@@ -129,10 +172,13 @@ class SimpleJob(
return noPrefTask
}
if (!localOnly) {
- return findTaskFromList(allPendingTasks) // Look for non-local task
- } else {
- return None
+ val nonLocalTask = findTaskFromList(allPendingTasks)
+ if (nonLocalTask != None) {
+ return nonLocalTask
+ }
}
+ // Finally, if all else has failed, find a speculative task
+ return findSpeculativeTask(host, localOnly)
}
// Does a host count as a preferred location for a task? This is true if
@@ -144,11 +190,11 @@ class SimpleJob(
}
// Respond to an offer of a single slave from the scheduler by finding a task
- def slaveOffer(offer: Offer, availableCpus: Double): Option[TaskInfo] = {
- if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK) {
+ def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[MTaskInfo] = {
+ if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
val time = System.currentTimeMillis
- val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
- val host = offer.getHostname
+ var localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
+
findTask(host, localOnly) match {
case Some(index) => {
// Found a task; do some bookkeeping and return a Mesos task for it
@@ -156,17 +202,17 @@ class SimpleJob(
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
val preferred = isPreferredLocation(task, host)
- val prefStr = if(preferred) "preferred" else "non-preferred"
- val message =
- "Starting task %d:%d as TID %s on slave %s: %s (%s)".format(
- jobId, index, taskId.getValue, offer.getSlaveId.getValue, host, prefStr)
- logInfo(message)
+ val prefStr = if (preferred) "preferred" else "non-preferred"
+ logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
+ taskSet.id, index, taskId.getValue, slaveId, host, prefStr))
// Do various bookkeeping
- tidToIndex(taskId.getValue) = index
- launched(index) = true
- tasksLaunched += 1
- if (preferred)
+ copiesRunning(index) += 1
+ val info = new TaskInfo(taskId.getValue, index, time, host)
+ taskInfos(taskId.getValue) = info
+ taskAttempts(index) = info :: taskAttempts(index)
+ if (preferred) {
lastPreferredLaunchTime = time
+ }
// Create and return the Mesos task object
val cpuRes = Resource.newBuilder()
.setName("cpus")
@@ -178,13 +224,13 @@ class SimpleJob(
val serializedTask = ser.serialize(task)
val timeTaken = System.currentTimeMillis - startTime
- logInfo("Size of task %d:%d is %d bytes and took %d ms to serialize by %s"
- .format(jobId, index, serializedTask.size, timeTaken, ser.getClass.getName))
+ logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
+ taskSet.id, index, serializedTask.limit, timeTaken))
- val taskName = "task %d:%d".format(jobId, index)
- return Some(TaskInfo.newBuilder()
+ val taskName = "task %s:%d".format(taskSet.id, index)
+ return Some(MTaskInfo.newBuilder()
.setTaskId(taskId)
- .setSlaveId(offer.getSlaveId)
+ .setSlaveId(SlaveID.newBuilder().setValue(slaveId))
.setExecutor(sched.executorInfo)
.setName(taskName)
.addResources(cpuRes)
@@ -213,18 +259,21 @@ class SimpleJob(
def taskFinished(status: TaskStatus) {
val tid = status.getTaskId.getValue
- val index = tidToIndex(tid)
+ val info = taskInfos(tid)
+ val index = info.index
+ info.markSuccessful()
if (!finished(index)) {
tasksFinished += 1
- logInfo("Finished TID %s (progress: %d/%d)".format(tid, tasksFinished, numTasks))
- // Deserialize task result
- val result = ser.deserialize[TaskResult[_]](
- status.getData.toByteArray, getClass.getClassLoader)
- sched.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
+ logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
+ tid, info.duration, tasksFinished, numTasks))
+ // Deserialize task result and pass it to the scheduler
+ val result = ser.deserialize[TaskResult[_]](status.getData.asReadOnlyByteBuffer)
+ sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
// Mark finished and stop if we've finished all the tasks
finished(index) = true
- if (tasksFinished == numTasks)
- sched.jobFinished(this)
+ if (tasksFinished == numTasks) {
+ sched.taskSetFinished(this)
+ }
} else {
logInfo("Ignoring task-finished event for TID " + tid +
" because task " + index + " is already finished")
@@ -233,30 +282,29 @@ class SimpleJob(
def taskLost(status: TaskStatus) {
val tid = status.getTaskId.getValue
- val index = tidToIndex(tid)
+ val info = taskInfos(tid)
+ val index = info.index
+ info.markFailed()
if (!finished(index)) {
- logInfo("Lost TID %s (task %d:%d)".format(tid, jobId, index))
- launched(index) = false
- tasksLaunched -= 1
+ logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+ copiesRunning(index) -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
if (status.getData != null && status.getData.size > 0) {
- val reason = ser.deserialize[TaskEndReason](
- status.getData.toByteArray, getClass.getClassLoader)
+ val reason = ser.deserialize[TaskEndReason](status.getData.asReadOnlyByteBuffer)
reason match {
case fetchFailed: FetchFailed =>
- logInfo("Loss was due to fetch failure from " + fetchFailed.serverUri)
- sched.taskEnded(tasks(index), fetchFailed, null, null)
+ logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+ sched.listener.taskEnded(tasks(index), fetchFailed, null, null)
finished(index) = true
tasksFinished += 1
- if (tasksFinished == numTasks) {
- sched.jobFinished(this)
- }
+ sched.taskSetFinished(this)
return
+
case ef: ExceptionFailure =>
val key = ef.exception.toString
val now = System.currentTimeMillis
- val (printFull, dupCount) =
+ val (printFull, dupCount) = {
if (recentExceptions.contains(key)) {
val (dupCount, printTime) = recentExceptions(key)
if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
@@ -267,32 +315,28 @@ class SimpleJob(
(false, dupCount + 1)
}
} else {
- recentExceptions += Tuple(key, (0, now))
+ recentExceptions(key) = (0, now)
(true, 0)
}
-
+ }
if (printFull) {
- val stackTrace =
- for (elem <- ef.exception.getStackTrace)
- yield "\tat %s".format(elem.toString)
- logInfo("Loss was due to %s\n%s".format(
- ef.exception.toString, stackTrace.mkString("\n")))
+ val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString))
+ logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n")))
} else {
- logInfo("Loss was due to %s [duplicate %d]".format(
- ef.exception.toString, dupCount))
+ logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount))
}
+
case _ => {}
}
}
- // On other failures, re-enqueue the task as pending for a max number of retries
+ // On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
- // Count attempts only on FAILED and LOST state (not on KILLED)
- if (status.getState == TaskState.TASK_FAILED ||
- status.getState == TaskState.TASK_LOST) {
+ // Count failed attempts only on FAILED and LOST state (not on KILLED)
+ if (status.getState == TaskState.TASK_FAILED || status.getState == TaskState.TASK_LOST) {
numFailures(index) += 1
if (numFailures(index) > MAX_TASK_FAILURES) {
- logError("Task %d:%d failed more than %d times; aborting job".format(
- jobId, index, MAX_TASK_FAILURES))
+ logError("Task %s:%d failed more than %d times; aborting job".format(
+ taskSet.id, index, MAX_TASK_FAILURES))
abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES))
}
}
@@ -311,6 +355,71 @@ class SimpleJob(
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.jobFinished(this)
+ sched.taskSetFinished(this)
+ }
+
+ def hostLost(hostname: String) {
+ logInfo("Re-queueing tasks for " + hostname)
+ // If some task has preferred locations only on hostname, put it in the no-prefs list
+ // to avoid the wait from delay scheduling
+ for (index <- getPendingTasksForHost(hostname)) {
+ val newLocs = tasks(index).preferredLocations.toSet & sched.hostsAlive
+ if (newLocs.isEmpty) {
+ pendingTasksWithNoPrefs += index
+ }
+ }
+ // Also re-enqueue any tasks that ran on the failed host if this is a shuffle map stage
+ if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+ for ((tid, info) <- taskInfos if info.host == hostname) {
+ val index = taskInfos(tid).index
+ if (finished(index)) {
+ finished(index) = false
+ copiesRunning(index) -= 1
+ tasksFinished -= 1
+ addPendingTask(index)
+ // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
+ // stage finishes when a total of tasks.size tasks finish.
+ sched.listener.taskEnded(tasks(index), Resubmitted, null, null)
+ }
+ }
+ }
+ }
+
+ /**
+ * Check for tasks to be speculated and return true if there are any. This is called periodically
+ * by the MesosScheduler.
+ *
+ * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
+ * we don't scan the whole task set. It might also help to make this sorted by launch time.
+ */
+ def checkSpeculatableTasks(): Boolean = {
+ // Can't speculate if we only have one task, or if all tasks have finished.
+ if (numTasks == 1 || tasksFinished == numTasks) {
+ return false
+ }
+ var foundTasks = false
+ val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
+ logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
+ if (tasksFinished >= minFinishedForSpeculation) {
+ val time = System.currentTimeMillis()
+ val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
+ Arrays.sort(durations)
+ val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
+ val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
+ // TODO: Threshold should also look at standard deviation of task durations and have a lower
+ // bound based on that.
+ logDebug("Task length threshold for speculation: " + threshold)
+ for ((tid, info) <- taskInfos) {
+ val index = info.index
+ if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+ !speculatableTasks.contains(index)) {
+ logInfo("Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
+ taskSet.id, index, info.host, threshold))
+ speculatableTasks += index
+ foundTasks = true
+ }
+ }
+ }
+ return foundTasks
}
}
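
The speculation check above reduces to a simple threshold rule; the arithmetic can be sketched standalone as follows (the constants and durations are assumed example values):

    // Standalone sketch of the rule in checkSpeculatableTasks (illustrative numbers).
    import java.util.Arrays
    import scala.math.{max, min}

    val numTasks = 8
    val quantile = 0.75      // spark.speculation.quantile
    val multiplier = 1.5     // spark.speculation.multiplier
    val durations = Array(900L, 1000L, 1100L, 1200L, 1300L, 1400L)  // successful task durations, in ms
    Arrays.sort(durations)

    val minFinishedForSpeculation = (quantile * numTasks).floor.toInt                     // 6
    if (durations.length >= minFinishedForSpeculation) {
      val median = durations(min((0.5 * numTasks).round.toInt, durations.length - 1))     // 1300 ms
      val threshold = max(multiplier * median, 100)                                       // 1950 ms
      // A task with exactly one running copy that has run longer than `threshold`
      // becomes a speculation candidate and is re-offered to other hosts.
    }
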
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
new file mode 100644
index 0000000000..367c79dd76
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -0,0 +1,507 @@
+package spark.storage
+
+import java.io._
+import java.nio._
+import java.nio.channels.FileChannel.MapMode
+import java.util.{HashMap => JHashMap}
+import java.util.LinkedHashMap
+import java.util.UUID
+import java.util.Collections
+
+import scala.actors._
+import scala.actors.Actor._
+import scala.actors.Future
+import scala.actors.Futures.future
+import scala.actors.remote._
+import scala.actors.remote.RemoteActor._
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
+
+import it.unimi.dsi.fastutil.io._
+
+import spark.CacheTracker
+import spark.Logging
+import spark.Serializer
+import spark.SizeEstimator
+import spark.SparkEnv
+import spark.SparkException
+import spark.Utils
+import spark.network._
+
+class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
+ def this() = this(null, 0)
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeUTF(ip)
+ out.writeInt(port)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ ip = in.readUTF()
+ port = in.readInt()
+ }
+
+ override def toString = "BlockManagerId(" + ip + ", " + port + ")"
+
+ override def hashCode = ip.hashCode * 41 + port
+
+ override def equals(that: Any) = that match {
+ case id: BlockManagerId => port == id.port && ip == id.ip
+ case _ => false
+ }
+}
+
+
+case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message)
+
+
+class BlockLocker(numLockers: Int) {
+ private val hashLocker = Array.fill(numLockers)(new Object())
+
+ def getLock(blockId: String): Object = {
+ return hashLocker(Math.abs(blockId.hashCode % numLockers))
+ }
+}
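
BlockLocker above is a small lock-striping helper; a minimal self-contained sketch of how the rest of this file uses it (the wrapper function is ours, for illustration only):

    // Every block operation synchronizes on the stripe for its id, so operations on
    // different ids usually run in parallel while operations on one id serialize.
    val locker = new BlockLocker(337)
    def withBlockLock[T](blockId: String)(body: => T): T =
      locker.getLock(blockId).synchronized { body }

    withBlockLock("my_block_0") {
      // read or update the state for "my_block_0" here
    }
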
+
+
+/**
+ * A start towards a block manager class. This will eventually be used for both RDD persistence
+ * and shuffle outputs.
+ *
+ * TODO: Should make the communication code with the master and peers more robust and log-friendly.
+ */
+class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging {
+
+ private val NUM_LOCKS = 337
+ private val locker = new BlockLocker(NUM_LOCKS)
+
+ private val storageLevels = Collections.synchronizedMap(new JHashMap[String, StorageLevel])
+
+ private val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
+ private val diskStore: BlockStore = new DiskStore(this,
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+
+ val connectionManager = new ConnectionManager(0)
+
+ val connectionManagerId = connectionManager.id
+ val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port)
+
+ // TODO(Haoyuan): This will be removed after cacheTracker is removed from the code base.
+ var cacheTracker: CacheTracker = null
+
+ initLogging()
+
+ initialize()
+
+ /**
+ * Construct a BlockManager with a memory limit set based on system properties.
+ */
+ def this(serializer: Serializer) =
+ this(BlockManager.getMaxMemoryFromSystemProperties(), serializer)
+
+ /**
+ * Initialize the BlockManager. Register with the BlockManagerMaster, and start the
+ * BlockManagerWorker actor.
+ */
+ def initialize() {
+ BlockManagerMaster.mustRegisterBlockManager(
+ RegisterBlockManager(blockManagerId, maxMemory, maxMemory))
+ BlockManagerWorker.startBlockManagerWorker(this)
+ }
+
+ /**
+ * Get locations of the block.
+ */
+ def getLocations(blockId: String): Seq[String] = {
+ val startTimeMs = System.currentTimeMillis
+ var managers: Array[BlockManagerId] = BlockManagerMaster.mustGetLocations(GetLocations(blockId))
+ val locations = managers.map((manager: BlockManagerId) => { manager.ip }).toSeq
+ logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs))
+ return locations
+ }
+
+ /**
+ * Get locations of an array of blocks
+ */
+ def getLocationsMultipleBlockIds(blockIds: Array[String]): Array[Seq[String]] = {
+ val startTimeMs = System.currentTimeMillis
+ val locations = BlockManagerMaster.mustGetLocationsMultipleBlockIds(
+ GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray
+ logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
+ return locations
+ }
+
+ def getLocal(blockId: String): Option[Iterator[Any]] = {
+ logDebug("Getting block " + blockId)
+ locker.getLock(blockId).synchronized {
+
+ // Check storage level of block
+ val level = storageLevels.get(blockId)
+ if (level != null) {
+ logDebug("Level for block " + blockId + " is " + level + " on local machine")
+
+ // Look for the block in memory
+ if (level.useMemory) {
+ logDebug("Getting block " + blockId + " from memory")
+ memoryStore.getValues(blockId) match {
+ case Some(iterator) => {
+ logDebug("Block " + blockId + " found in memory")
+ return Some(iterator)
+ }
+ case None => {
+ logDebug("Block " + blockId + " not found in memory")
+ }
+ }
+ } else {
+ logDebug("Not getting block " + blockId + " from memory")
+ }
+
+ // Look for the block on disk
+ if (level.useDisk) {
+ logDebug("Getting block " + blockId + " from disk")
+ diskStore.getValues(blockId) match {
+ case Some(iterator) => {
+ logDebug("Block " + blockId + " found in disk")
+ return Some(iterator)
+ }
+ case None => {
+ throw new Exception("Block " + blockId + " not found in disk")
+ return None
+ }
+ }
+ } else {
+ logDebug("Not getting block " + blockId + " from disk")
+ }
+
+ } else {
+ logDebug("Level for block " + blockId + " not found")
+ }
+ }
+ return None
+ }
+
+ def getRemote(blockId: String): Option[Iterator[Any]] = {
+ // Get locations of block
+ val locations = BlockManagerMaster.mustGetLocations(GetLocations(blockId))
+
+ // Get block from remote locations
+ for (loc <- locations) {
+ val data = BlockManagerWorker.syncGetBlock(
+ GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
+ if (data != null) {
+ logDebug("Data is not null: " + data)
+ return Some(dataDeserialize(data))
+ }
+ logDebug("Data is null")
+ }
+ logDebug("Data not found")
+ return None
+ }
+
+ /**
+ * Read a block from the block manager.
+ */
+ def get(blockId: String): Option[Iterator[Any]] = {
+ getLocal(blockId).orElse(getRemote(blockId))
+ }
+
+ /**
+ * Read many blocks from block manager using their BlockManagerIds.
+ */
+ def get(blocksByAddress: Seq[(BlockManagerId, Seq[String])]): HashMap[String, Option[Iterator[Any]]] = {
+ logDebug("Getting " + blocksByAddress.map(_._2.size).sum + " blocks")
+ var startTime = System.currentTimeMillis
+ val blocks = new HashMap[String,Option[Iterator[Any]]]()
+ val localBlockIds = new ArrayBuffer[String]()
+ val remoteBlockIds = new ArrayBuffer[String]()
+ val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]()
+
+ // Split local and remote blocks
+ for ((address, blockIds) <- blocksByAddress) {
+ if (address == blockManagerId) {
+ localBlockIds ++= blockIds
+ } else {
+ remoteBlockIds ++= blockIds
+ remoteBlockIdsPerLocation(address) = blockIds
+ }
+ }
+
+ // Start getting remote blocks
+ val remoteBlockFutures = remoteBlockIdsPerLocation.toSeq.map { case (bmId, bIds) =>
+ val cmId = ConnectionManagerId(bmId.ip, bmId.port)
+ val blockMessages = bIds.map(bId => BlockMessage.fromGetBlock(GetBlock(bId)))
+ val blockMessageArray = new BlockMessageArray(blockMessages)
+ val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
+ (cmId, future)
+ }
+ logDebug("Started remote gets for " + remoteBlockIds.size + " blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+
+ // Get the local blocks while remote blocks are being fetched
+ startTime = System.currentTimeMillis
+ localBlockIds.foreach(id => {
+ get(id) match {
+ case Some(block) => {
+ blocks.update(id, Some(block))
+ logDebug("Got local block " + id)
+ }
+ case None => {
+ throw new BlockException(id, "Could not get block " + id + " from local machine")
+ }
+ }
+ })
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+
+ // wait for and gather all the remote blocks
+ for ((cmId, future) <- remoteBlockFutures) {
+ var count = 0
+ val oneBlockId = remoteBlockIdsPerLocation(new BlockManagerId(cmId.host, cmId.port)).first
+ future() match {
+ case Some(message) => {
+ val bufferMessage = message.asInstanceOf[BufferMessage]
+ val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+ blockMessageArray.foreach(blockMessage => {
+ if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
+ throw new BlockException(oneBlockId, "Unexpected message received from " + cmId)
+ }
+ val buffer = blockMessage.getData()
+ val blockId = blockMessage.getId()
+ val block = dataDeserialize(buffer)
+ blocks.update(blockId, Some(block))
+ logDebug("Got remote block " + blockId + " in " + Utils.getUsedTimeMs(startTime))
+ count += 1
+ })
+ }
+ case None => {
+ throw new BlockException(oneBlockId, "Could not get blocks from " + cmId)
+ }
+ }
+ logDebug("Got remote " + count + " blocks from " + cmId.host + " in " + Utils.getUsedTimeMs(startTime) + " ms")
+ }
+
+ logDebug("Got all blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ return blocks
+ }
+
+ /**
+ * Write a new block to the block manager.
+ */
+ def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean = true) {
+ if (!level.useDisk && !level.useMemory) {
+ throw new IllegalArgumentException("Storage level has neither useMemory nor useDisk set")
+ }
+
+ val startTimeMs = System.currentTimeMillis
+ var bytes: ByteBuffer = null
+
+ locker.getLock(blockId).synchronized {
+ logDebug("Put for block " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ + " to get into synchronized block")
+
+ // Check and warn if block with same id already exists
+ if (storageLevels.get(blockId) != null) {
+ logWarning("Block " + blockId + " already exists in local machine")
+ return
+ }
+
+ // Store the storage level
+ storageLevels.put(blockId, level)
+
+ if (level.useMemory && level.useDisk) {
+ // If saving to both memory and disk, then serialize only once
+ memoryStore.putValues(blockId, values, level) match {
+ case Left(newValues) =>
+ diskStore.putValues(blockId, newValues, level) match {
+ case Right(newBytes) => bytes = newBytes
+ case _ => throw new Exception("Unexpected return value")
+ }
+ case Right(newBytes) =>
+ bytes = newBytes
+ diskStore.putBytes(blockId, newBytes, level)
+ }
+ } else if (level.useMemory) {
+ // If only saving to memory
+ memoryStore.putValues(blockId, values, level) match {
+ case Right(newBytes) => bytes = newBytes
+ case _ =>
+ }
+ } else {
+ // If only saving to disk
+ diskStore.putValues(blockId, values, level) match {
+ case Right(newBytes) => bytes = newBytes
+ case _ => throw new Exception("Unexpected return value")
+ }
+ }
+
+ if (tellMaster) {
+ notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0))
+ logDebug("Put block " + blockId + " after notifying the master " + Utils.getUsedTimeMs(startTimeMs))
+ }
+ }
+
+ // Replicate block if required
+ if (level.replication > 1) {
+ if (bytes == null) {
+ bytes = dataSerialize(values) // serialize the block if not already done
+ }
+ replicate(blockId, bytes, level)
+ }
+
+ // TODO(Haoyuan): This code will be removed when CacheTracker is gone.
+ if (blockId.startsWith("rdd")) {
+ notifyTheCacheTracker(blockId)
+ }
+ logDebug("Put block " + blockId + " after notifying the CacheTracker " + Utils.getUsedTimeMs(startTimeMs))
+ }
+
+
+ def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
+ val startTime = System.currentTimeMillis
+ if (!level.useDisk && !level.useMemory) {
+ throw new IllegalArgumentException("Storage level has neither useMemory nor useDisk set")
+ } else if (level.deserialized) {
+ throw new IllegalArgumentException("Storage level cannot have deserialized set when putBytes is used")
+ }
+ val replicationFuture = if (level.replication > 1) {
+ future {
+ replicate(blockId, bytes, level)
+ }
+ } else {
+ null
+ }
+
+ locker.getLock(blockId).synchronized {
+ logDebug("PutBytes for block " + blockId + " used " + Utils.getUsedTimeMs(startTime)
+ + " to get into synchronized block")
+ if (storageLevels.get(blockId) != null) {
+ logWarning("Block " + blockId + " already exists")
+ return
+ }
+ storageLevels.put(blockId, level)
+
+ if (level.useMemory) {
+ memoryStore.putBytes(blockId, bytes, level)
+ }
+ if (level.useDisk) {
+ diskStore.putBytes(blockId, bytes, level)
+ }
+ if (tellMaster) {
+ notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0))
+ }
+ }
+
+ if (blockId.startsWith("rdd")) {
+ notifyTheCacheTracker(blockId)
+ }
+
+ if (level.replication > 1) {
+ if (replicationFuture == null) {
+ throw new Exception("Unexpected")
+ }
+ replicationFuture()
+ }
+
+ val finishTime = System.currentTimeMillis
+ if (level.replication > 1) {
+ logDebug("PutBytes with replication took " + (finishTime - startTime) + " ms")
+ } else {
+ logDebug("PutBytes without replication took " + (finishTime - startTime) + " ms")
+ }
+
+ }
+
+ private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
+ val tLevel: StorageLevel =
+ new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
+ var peers: Array[BlockManagerId] = BlockManagerMaster.mustGetPeers(
+ GetPeers(blockManagerId, level.replication - 1))
+ for (peer: BlockManagerId <- peers) {
+ val start = System.nanoTime
+ logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is "
+ + data.array().length + " Bytes. To node: " + peer)
+ if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel),
+ new ConnectionManagerId(peer.ip, peer.port))) {
+ logError("Failed to call syncPutBlock to " + peer)
+ }
+ logDebug("Replicated BlockId " + blockId + " once used " +
+ (System.nanoTime - start) / 1e6 + " ms; The size of the data is " +
+ data.array().length + " bytes.")
+ }
+ }
+
+ // TODO(Haoyuan): This code will be removed when CacheTracker is gone.
+ def notifyTheCacheTracker(key: String) {
+ val rddInfo = key.split(":")
+ val rddId: Int = rddInfo(1).toInt
+ val splitIndex: Int = rddInfo(2).toInt
+ val host = System.getProperty("spark.hostname", Utils.localHostName)
+ cacheTracker.notifyTheCacheTrackerFromBlockManager(spark.AddedToCache(rddId, splitIndex, host))
+ }
+
+ /**
+ * Read a block consisting of a single object.
+ */
+ def getSingle(blockId: String): Option[Any] = {
+ get(blockId).map(_.next)
+ }
+
+ /**
+ * Write a block consisting of a single object.
+ */
+ def putSingle(blockId: String, value: Any, level: StorageLevel) {
+ put(blockId, Iterator(value), level)
+ }
+
+ /**
+ * Drop block from memory (called when the memory store has reached its limit).
+ */
+ def dropFromMemory(blockId: String) {
+ locker.getLock(blockId).synchronized {
+ val level = storageLevels.get(blockId)
+ if (level == null) {
+ logWarning("Block " + blockId + " cannot be removed from memory as it does not exist")
+ return
+ }
+ if (!level.useMemory) {
+ logWarning("Block " + blockId + " cannot be removed from memory as it is not in memory")
+ return
+ }
+ memoryStore.remove(blockId)
+ if (!level.useDisk) {
+ storageLevels.remove(blockId)
+ } else {
+ val newLevel = level.clone
+ newLevel.useMemory = false
+ storageLevels.remove(blockId)
+ storageLevels.put(blockId, newLevel)
+ }
+ }
+ }
+
+ def dataSerialize(values: Iterator[Any]): ByteBuffer = {
+ /*serializer.newInstance().serializeMany(values)*/
+ val byteStream = new FastByteArrayOutputStream(4096)
+ serializer.newInstance().serializeStream(byteStream).writeAll(values).close()
+ byteStream.trim()
+ ByteBuffer.wrap(byteStream.array)
+ }
+
+ def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = {
+ /*serializer.newInstance().deserializeMany(bytes)*/
+ val ser = serializer.newInstance()
+ return ser.deserializeStream(new FastByteArrayInputStream(bytes.array())).toIterator
+ }
+
+ private def notifyMaster(heartBeat: HeartBeat) {
+ BlockManagerMaster.mustHeartBeat(heartBeat)
+ }
+}
+
+object BlockManager extends Logging {
+ def getMaxMemoryFromSystemProperties(): Long = {
+ val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
+ val bytes = (Runtime.getRuntime.totalMemory * memoryFraction).toLong
+ logInfo("Maximum memory to use: " + bytes)
+ bytes
+ }
+}
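
To make the read/write surface of the new BlockManager concrete, here is a hedged usage sketch; the serializer choice and block id are assumptions for illustration, and the master and worker are assumed to be running already:

    // Hypothetical driver-side usage of the BlockManager API above.
    // Assumed: spark.JavaSerializer has a no-arg constructor and implements spark.Serializer.
    val serializer: Serializer = new spark.JavaSerializer
    val manager = new BlockManager(serializer)   // memory limit from spark.storage.memoryFraction

    // Write a single value with the memory-and-disk, 2-replica level used elsewhere in this patch.
    manager.putSingle("my_block_0", Array(1, 2, 3), StorageLevel.DISK_AND_MEMORY_2)

    // Read it back: getLocal is tried first, then getRemote against peers reported by the master.
    manager.getSingle("my_block_0") match {
      case Some(value) => println("got " + value)
      case None        => println("block not found anywhere")
    }
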
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
new file mode 100644
index 0000000000..bd94c185e9
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -0,0 +1,516 @@
+package spark.storage
+
+import java.io._
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.util.Random
+
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
+import akka.util.duration._
+
+import spark.Logging
+import spark.Utils
+
+sealed trait ToBlockManagerMaster
+
+case class RegisterBlockManager(
+ blockManagerId: BlockManagerId,
+ maxMemSize: Long,
+ maxDiskSize: Long)
+ extends ToBlockManagerMaster
+
+class HeartBeat(
+ var blockManagerId: BlockManagerId,
+ var blockId: String,
+ var storageLevel: StorageLevel,
+ var deserializedSize: Long,
+ var size: Long)
+ extends ToBlockManagerMaster
+ with Externalizable {
+
+ def this() = this(null, null, null, 0, 0) // For deserialization only
+
+ override def writeExternal(out: ObjectOutput) {
+ blockManagerId.writeExternal(out)
+ out.writeUTF(blockId)
+ storageLevel.writeExternal(out)
+ out.writeInt(deserializedSize.toInt)
+ out.writeInt(size.toInt)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ blockManagerId = new BlockManagerId()
+ blockManagerId.readExternal(in)
+ blockId = in.readUTF()
+ storageLevel = new StorageLevel()
+ storageLevel.readExternal(in)
+ deserializedSize = in.readInt()
+ size = in.readInt()
+ }
+}
+
+object HeartBeat {
+ def apply(blockManagerId: BlockManagerId,
+ blockId: String,
+ storageLevel: StorageLevel,
+ deserializedSize: Long,
+ size: Long): HeartBeat = {
+ new HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size)
+ }
+
+
+ // For pattern-matching
+ def unapply(h: HeartBeat): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
+ Some((h.blockManagerId, h.blockId, h.storageLevel, h.deserializedSize, h.size))
+ }
+}
+
+case class GetLocations(
+ blockId: String)
+ extends ToBlockManagerMaster
+
+case class GetLocationsMultipleBlockIds(
+ blockIds: Array[String])
+ extends ToBlockManagerMaster
+
+case class GetPeers(
+ blockManagerId: BlockManagerId,
+ size: Int)
+ extends ToBlockManagerMaster
+
+case class RemoveHost(
+ host: String)
+ extends ToBlockManagerMaster
+
+class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
+ class BlockManagerInfo(
+ timeMs: Long,
+ maxMem: Long,
+ maxDisk: Long) {
+ private var lastSeenMs = timeMs
+ private var remainedMem = maxMem
+ private var remainedDisk = maxDisk
+ private val blocks = new HashMap[String, StorageLevel]
+
+ def updateLastSeenMs() {
+ lastSeenMs = System.currentTimeMillis() / 1000
+ }
+
+ def addBlock(blockId: String, storageLevel: StorageLevel, deserializedSize: Long, size: Long) =
+ synchronized {
+ updateLastSeenMs()
+
+ if (blocks.contains(blockId)) {
+ val oriLevel: StorageLevel = blocks(blockId)
+
+ if (oriLevel.deserialized) {
+ remainedMem += deserializedSize
+ }
+ if (oriLevel.useMemory) {
+ remainedMem += size
+ }
+ if (oriLevel.useDisk) {
+ remainedDisk += size
+ }
+ }
+
+ blocks += (blockId -> storageLevel)
+
+ if (storageLevel.deserialized) {
+ remainedMem -= deserializedSize
+ }
+ if (storageLevel.useMemory) {
+ remainedMem -= size
+ }
+ if (storageLevel.useDisk) {
+ remainedDisk -= size
+ }
+
+ if (!(storageLevel.deserialized || storageLevel.useMemory || storageLevel.useDisk)) {
+ blocks.remove(blockId)
+ }
+ }
+
+ def getLastSeenMs(): Long = {
+ return lastSeenMs
+ }
+
+ def getRemainedMem(): Long = {
+ return remainedMem
+ }
+
+ def getRemainedDisk(): Long = {
+ return remainedDisk
+ }
+
+ override def toString(): String = {
+ return "BlockManagerInfo " + timeMs + " " + remainedMem + " " + remainedDisk
+ }
+ }
+
+ private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerInfo]
+ private val blockIdMap = new HashMap[String, Pair[Int, HashSet[BlockManagerId]]]
+
+ initLogging()
+
+ def removeHost(host: String) {
+ logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.")
+ logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq)
+ val ip = host.split(":")(0)
+ val port = host.split(":")(1)
+ blockManagerInfo.remove(new BlockManagerId(ip, port.toInt))
+ logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq)
+ self.reply(true)
+ }
+
+ def receive = {
+ case RegisterBlockManager(blockManagerId, maxMemSize, maxDiskSize) =>
+ register(blockManagerId, maxMemSize, maxDiskSize)
+
+ case HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
+ heartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size)
+
+ case GetLocations(blockId) =>
+ getLocations(blockId)
+
+ case GetLocationsMultipleBlockIds(blockIds) =>
+ getLocationsMultipleBlockIds(blockIds)
+
+ case GetPeers(blockManagerId, size) =>
+ getPeers_Deterministic(blockManagerId, size)
+ /*getPeers(blockManagerId, size)*/
+
+ case RemoveHost(host) =>
+ removeHost(host)
+
+ case msg =>
+ logInfo("Got unknown msg: " + msg)
+ }
+
+ private def register(blockManagerId: BlockManagerId, maxMemSize: Long, maxDiskSize: Long) {
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " " + blockManagerId + " "
+ logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
+ logInfo("Got Register Msg from " + blockManagerId)
+ if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
+ logInfo("Got Register Msg from master node, don't register it")
+ } else {
+ blockManagerInfo += (blockManagerId -> new BlockManagerInfo(
+ System.currentTimeMillis() / 1000, maxMemSize, maxDiskSize))
+ }
+ logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs))
+ self.reply(true)
+ }
+
+ private def heartBeat(
+ blockManagerId: BlockManagerId,
+ blockId: String,
+ storageLevel: StorageLevel,
+ deserializedSize: Long,
+ size: Long) {
+
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " " + blockManagerId + " " + blockId + " "
+ logDebug("Got in heartBeat 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
+
+ if (blockId == null) {
+ blockManagerInfo(blockManagerId).updateLastSeenMs()
+ logDebug("Got in heartBeat 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+ self.reply(true)
+ }
+
+ blockManagerInfo(blockManagerId).addBlock(blockId, storageLevel, deserializedSize, size)
+ logDebug("Got in heartBeat 2" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+
+ var locations: HashSet[BlockManagerId] = null
+ if (blockIdMap.contains(blockId)) {
+ locations = blockIdMap(blockId)._2
+ } else {
+ locations = new HashSet[BlockManagerId]
+ blockIdMap += (blockId -> (storageLevel.replication, locations))
+ }
+ logDebug("Got in heartBeat 3" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+
+ if (storageLevel.deserialized || storageLevel.useDisk || storageLevel.useMemory) {
+ locations += blockManagerId
+ } else {
+ locations.remove(blockManagerId)
+ }
+ logDebug("Got in heartBeat 4" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+
+ if (locations.size == 0) {
+ blockIdMap.remove(blockId)
+ }
+
+ logDebug("Got in heartBeat 5" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+ self.reply(true)
+ }
+
+ private def getLocations(blockId: String) {
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " " + blockId + " "
+ logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
+ if (blockIdMap.contains(blockId)) {
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ res.appendAll(blockIdMap(blockId)._2)
+ logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at "
+ + Utils.getUsedTimeMs(startTimeMs))
+ self.reply(res.toSeq)
+ } else {
+ logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs))
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ self.reply(res)
+ }
+ }
+
+ private def getLocationsMultipleBlockIds(blockIds: Array[String]) {
+ def getLocations(blockId: String): Seq[BlockManagerId] = {
+ val tmp = blockId
+ logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp)
+ if (blockIdMap.contains(blockId)) {
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ res.appendAll(blockIdMap(blockId)._2)
+ logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq)
+ return res.toSeq
+ } else {
+ logDebug("Got in getLocationsMultipleBlockIds Sub 2 " + tmp)
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ return res.toSeq
+ }
+ }
+
+ logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq)
+ var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]]
+ for (blockId <- blockIds) {
+ res.append(getLocations(blockId))
+ }
+ logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq)
+ self.reply(res.toSeq)
+ }
+
+ private def getPeers(blockManagerId: BlockManagerId, size: Int) {
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " " + blockManagerId + " "
+ logDebug("Got in getPeers 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
+ var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ res.appendAll(peers)
+ res -= blockManagerId
+ val rand = new Random(System.currentTimeMillis())
+ logDebug("Got in getPeers 1" + tmp + Utils.getUsedTimeMs(startTimeMs))
+ while (res.length > size) {
+ res.remove(rand.nextInt(res.length))
+ }
+ logDebug("Got in getPeers 2" + tmp + Utils.getUsedTimeMs(startTimeMs))
+ self.reply(res.toSeq)
+ }
+
+ private def getPeers_Deterministic(blockManagerId: BlockManagerId, size: Int) {
+ val startTimeMs = System.currentTimeMillis()
+ var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+
+ val peersWithIndices = peers.zipWithIndex
+ val selfIndex = peersWithIndices.find(_._1 == blockManagerId).map(_._2).getOrElse(-1)
+ if (selfIndex == -1) {
+ throw new Exception("Self index for " + blockManagerId + " not found")
+ }
+
+ var index = selfIndex
+ while (res.size < size) {
+ index += 1
+ if (index == selfIndex) {
+ throw new Exception("More peers expected than available")
+ }
+ res += peers(index % peers.size)
+ }
+ val resStr = res.map(_.toString).reduceLeft(_ + ", " + _)
+ logDebug("Got peers for " + blockManagerId + " as [" + resStr + "]")
+ self.reply(res.toSeq)
+ }
+}
+
+object BlockManagerMaster extends Logging {
+ initLogging()
+
+ val AKKA_ACTOR_NAME: String = "BlockMasterManager"
+ val REQUEST_RETRY_INTERVAL_MS = 100
+ val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost")
+ val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt
+ val DEFAULT_MANAGER_IP: String = Utils.localHostName()
+ val DEFAULT_MANAGER_PORT: String = "10902"
+
+ implicit val TIME_OUT_SEC = Actor.Timeout(3000 millis)
+ var masterActor: ActorRef = null
+
+ def startBlockManagerMaster(isMaster: Boolean, isLocal: Boolean) {
+ if (isMaster) {
+ masterActor = actorOf(new BlockManagerMaster(isLocal))
+ remote.register(AKKA_ACTOR_NAME, masterActor)
+ logInfo("Registered BlockManagerMaster Actor: " + DEFAULT_MASTER_IP + ":" + DEFAULT_MASTER_PORT)
+ masterActor.start()
+ } else {
+ masterActor = remote.actorFor(AKKA_ACTOR_NAME, DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT)
+ }
+ }
+
+ def notifyADeadHost(host: String) {
+ (masterActor ? RemoveHost(host + ":" + DEFAULT_MANAGER_PORT)).as[Any] match {
+ case Some(true) =>
+ logInfo("Removed " + host + " successfully. @ notifyADeadHost")
+ case Some(oops) =>
+ logError("Failed @ notifyADeadHost: " + oops)
+ case None =>
+ logError("None @ notifyADeadHost.")
+ }
+ }
+
+ def mustRegisterBlockManager(msg: RegisterBlockManager) {
+ while (! syncRegisterBlockManager(msg)) {
+ logWarning("Failed to register " + msg)
+ Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+ }
+ }
+
+ def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = {
+ //val masterActor = RemoteActor.select(node, name)
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " msg " + msg + " "
+ logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+
+ (masterActor ? msg).as[Any] match {
+ case Some(true) =>
+ logInfo("BlockManager registered successfully @ syncRegisterBlockManager.")
+ logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+ return true
+ case Some(oops) =>
+ logError("Failed @ syncRegisterBlockManager: " + oops)
+ return false
+ case None =>
+ logError("None @ syncRegisterBlockManager.")
+ return false
+ }
+ }
+
+ def mustHeartBeat(msg: HeartBeat) {
+ while (! syncHeartBeat(msg)) {
+ logWarning("Failed to send heartbeat " + msg)
+ Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+ }
+ }
+
+ def syncHeartBeat(msg: HeartBeat): Boolean = {
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " msg " + msg + " "
+ logDebug("Got in syncHeartBeat " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs))
+
+ (masterActor ? msg).as[Any] match {
+ case Some(true) =>
+ logInfo("Heartbeat sent successfully.")
+ logDebug("Got in syncHeartBeat " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs))
+ return true
+ case Some(oops) =>
+ logError("Failed: " + oops)
+ return false
+ case None =>
+ logError("None.")
+ return false
+ }
+ }
+
+ def mustGetLocations(msg: GetLocations): Array[BlockManagerId] = {
+ var res: Array[BlockManagerId] = syncGetLocations(msg)
+ while (res == null) {
+ logInfo("Failed to get locations " + msg)
+ Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+ res = syncGetLocations(msg)
+ }
+ return res
+ }
+
+ def syncGetLocations(msg: GetLocations): Array[BlockManagerId] = {
+ val startTimeMs = System.currentTimeMillis()
+ val tmp = " msg " + msg + " "
+ logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+
+ (masterActor ? msg).as[Seq[BlockManagerId]] match {
+ case Some(arr) =>
+ logDebug("GetLocations successfully.")
+ logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+ val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ for (ele <- arr) {
+ res += ele
+ }
+ logDebug("Got in syncGetLocations 2 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+ return res.toArray
+ case None =>
+ logError("GetLocations call returned None.")
+ return null
+ }
+ }
+
+ def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
+ Seq[Seq[BlockManagerId]] = {
+ var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg)
+ while (res == null) {
+ logWarning("Failed to GetLocationsMultipleBlockIds " + msg)
+ Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+ res = syncGetLocationsMultipleBlockIds(msg)
+ }
+ return res
+ }
+
+ def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
+ Seq[Seq[BlockManagerId]] = {
+ val startTimeMs = System.currentTimeMillis
+ val tmp = " msg " + msg + " "
+ logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+
+ (masterActor ? msg).as[Any] match {
+ case Some(arr: Seq[Seq[BlockManagerId]]) =>
+ logDebug("GetLocationsMultipleBlockIds successfully: " + arr)
+ logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+ return arr
+ case Some(oops) =>
+ logError("Failed: " + oops)
+ return null
+ case None =>
+ logInfo("None.")
+ return null
+ }
+ }
+
+ def mustGetPeers(msg: GetPeers): Array[BlockManagerId] = {
+ var res: Array[BlockManagerId] = syncGetPeers(msg)
+ while ((res == null) || (res.length != msg.size)) {
+ logInfo("Failed to get peers " + msg)
+ Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+ res = syncGetPeers(msg)
+ }
+
+ return res
+ }
+
+ def syncGetPeers(msg: GetPeers): Array[BlockManagerId] = {
+ val startTimeMs = System.currentTimeMillis
+ val tmp = " msg " + msg + " "
+ logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+
+ (masterActor ? msg).as[Seq[BlockManagerId]] match {
+ case Some(arr) =>
+ logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+ val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ logInfo("GetPeers successfully: " + arr.length)
+ res.appendAll(arr)
+ logDebug("Got in syncGetPeers 2 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+ return res.toArray
+ case None =>
+ logError("GetPeers call returned None.")
+ return null
+ }
+ }
+}
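
Putting the master protocol together, a hedged sketch of a worker-side interaction with the wrapper object above (the address, ports and sizes are made-up examples):

    // Assumes startBlockManagerMaster(isMaster = false, isLocal = false) has already run,
    // so masterActor points at the remote BlockManagerMaster actor.
    val id = new BlockManagerId("10.0.0.5", 10902)

    // Register once at startup; mustRegisterBlockManager retries until the master replies.
    BlockManagerMaster.mustRegisterBlockManager(RegisterBlockManager(id, 1L << 30, 10L << 30))

    // Report a newly stored block so the master can update its block -> locations map.
    BlockManagerMaster.mustHeartBeat(HeartBeat(id, "my_block_0", StorageLevel.DISK_AND_MEMORY_2, 0L, 1024L))

    // Ask where a block lives before fetching it from a peer.
    val locations: Array[BlockManagerId] = BlockManagerMaster.mustGetLocations(GetLocations("my_block_0"))
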
diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
new file mode 100644
index 0000000000..a4cdbd8ddd
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
@@ -0,0 +1,142 @@
+package spark.storage
+
+import java.nio._
+
+import scala.actors._
+import scala.actors.Actor._
+import scala.actors.remote._
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.util.Random
+
+import spark.Logging
+import spark.Utils
+import spark.SparkEnv
+import spark.network._
+
+/**
+ * This should be changed to use an event model later.
+ */
+class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
+ initLogging()
+
+ blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
+
+ def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
+ logDebug("Handling message " + msg)
+ msg match {
+ case bufferMessage: BufferMessage => {
+ try {
+ logDebug("Handling as a buffer message " + bufferMessage)
+ val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage)
+ logDebug("Parsed as a block message array")
+ val responseMessages = blockMessages.map(processBlockMessage _).filter(_ != None).map(_.get)
+ /*logDebug("Processed block messages")*/
+ return Some(new BlockMessageArray(responseMessages).toBufferMessage)
+ } catch {
+ case e: Exception => logError("Exception handling buffer message: " + e.getMessage)
+ return None
+ }
+ }
+ case otherMessage: Any => {
+ logError("Unknown type message received: " + otherMessage)
+ return None
+ }
+ }
+ }
+
+ def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = {
+ blockMessage.getType() match {
+ case BlockMessage.TYPE_PUT_BLOCK => {
+ val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel())
+ logInfo("Received [" + pB + "]")
+ putBlock(pB.id, pB.data, pB.level)
+ return None
+ }
+ case BlockMessage.TYPE_GET_BLOCK => {
+ val gB = new GetBlock(blockMessage.getId())
+ logInfo("Received [" + gB + "]")
+ val buffer = getBlock(gB.id)
+ if (buffer == null) {
+ return None
+ }
+ return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer)))
+ }
+ case _ => return None
+ }
+ }
+
+ private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) {
+ val startTimeMs = System.currentTimeMillis()
+ logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes)
+ blockManager.putBytes(id, bytes, level)
+ logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
+ + " with data size: " + bytes.array().length)
+ }
+
+ private def getBlock(id: String): ByteBuffer = {
+ val startTimeMs = System.currentTimeMillis()
+ logDebug("Getblock " + id + " started from " + startTimeMs)
+ val block = blockManager.get(id)
+ val buffer = block match {
+ case Some(tValues) => {
+ val values = tValues.asInstanceOf[Iterator[Any]]
+ val buffer = blockManager.dataSerialize(values)
+ buffer
+ }
+ case None => {
+ null
+ }
+ }
+ logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
+ + " and got buffer " + buffer)
+ return buffer
+ }
+}
+
+object BlockManagerWorker extends Logging {
+ private var blockManagerWorker: BlockManagerWorker = null
+ private val DATA_TRANSFER_TIME_OUT_MS: Long = 500
+ private val REQUEST_RETRY_INTERVAL_MS: Long = 1000
+
+ initLogging()
+
+ def startBlockManagerWorker(manager: BlockManager) {
+ blockManagerWorker = new BlockManagerWorker(manager)
+ }
+
+ def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = {
+ val blockManager = blockManagerWorker.blockManager
+ val connectionManager = blockManager.connectionManager
+ val serializer = blockManager.serializer
+ val blockMessage = BlockMessage.fromPutBlock(msg)
+ val blockMessageArray = new BlockMessageArray(blockMessage)
+ val resultMessage = connectionManager.sendMessageReliablySync(
+ toConnManagerId, blockMessageArray.toBufferMessage())
+ return (resultMessage != None)
+ }
+
+ def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
+ val blockManager = blockManagerWorker.blockManager
+ val connectionManager = blockManager.connectionManager
+ val serializer = blockManager.serializer
+ val blockMessage = BlockMessage.fromGetBlock(msg)
+ val blockMessageArray = new BlockMessageArray(blockMessage)
+ val responseMessage = connectionManager.sendMessageReliablySync(
+ toConnManagerId, blockMessageArray.toBufferMessage())
+ responseMessage match {
+ case Some(message) => {
+ val bufferMessage = message.asInstanceOf[BufferMessage]
+ logDebug("Response message received " + bufferMessage)
+ BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => {
+ logDebug("Found " + blockMessage)
+ return blockMessage.getData
+ })
+ }
+ case None => logDebug("No response message received"); return null
+ }
+ return null
+ }
+}
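
A brief sketch of a peer-to-peer transfer using the synchronous helpers above; the peer address and payload are illustrative, and startBlockManagerWorker is assumed to have already run on this node:

    import java.nio.ByteBuffer

    val peer = new ConnectionManagerId("10.0.0.6", 9999)
    val payload = ByteBuffer.wrap("hello".getBytes)

    // Push a serialized block to the peer; true means the peer acknowledged the message.
    val ok = BlockManagerWorker.syncPutBlock(
      PutBlock("my_block_0", payload, StorageLevel.DISK_AND_MEMORY_2), peer)

    // Pull a block back; a null result means the peer did not have it.
    val fetched: ByteBuffer = BlockManagerWorker.syncGetBlock(GetBlock("my_block_0"), peer)
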
diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala
new file mode 100644
index 0000000000..bb128dce7a
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockMessage.scala
@@ -0,0 +1,219 @@
+package spark.storage
+
+import java.nio._
+
+import scala.collection.mutable.StringBuilder
+import scala.collection.mutable.ArrayBuffer
+
+import spark._
+import spark.network._
+
+case class GetBlock(id: String)
+case class GotBlock(id: String, data: ByteBuffer)
+case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel)
+
+class BlockMessage() extends Logging{
+ // Un-initialized: typ = 0
+ // GetBlock: typ = 1
+ // GotBlock: typ = 2
+ // PutBlock: typ = 3
+ private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
+ private var id: String = null
+ private var data: ByteBuffer = null
+ private var level: StorageLevel = null
+
+ initLogging()
+
+ def set(getBlock: GetBlock) {
+ typ = BlockMessage.TYPE_GET_BLOCK
+ id = getBlock.id
+ }
+
+ def set(gotBlock: GotBlock) {
+ typ = BlockMessage.TYPE_GOT_BLOCK
+ id = gotBlock.id
+ data = gotBlock.data
+ }
+
+ def set(putBlock: PutBlock) {
+ typ = BlockMessage.TYPE_PUT_BLOCK
+ id = putBlock.id
+ data = putBlock.data
+ level = putBlock.level
+ }
+
+ def set(buffer: ByteBuffer) {
+ val startTime = System.currentTimeMillis
+ /*
+ println()
+ println("BlockMessage: ")
+ while(buffer.remaining > 0) {
+ print(buffer.get())
+ }
+ buffer.rewind()
+ println()
+ println()
+ */
+ typ = buffer.getInt()
+ val idLength = buffer.getInt()
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buffer.getChar()
+ }
+ id = idBuilder.toString()
+
+ logDebug("Set from buffer Result: " + typ + " " + id)
+ logDebug("Buffer position is " + buffer.position)
+ if (typ == BlockMessage.TYPE_PUT_BLOCK) {
+
+ val booleanInt = buffer.getInt()
+ val replication = buffer.getInt()
+ level = new StorageLevel(booleanInt, replication)
+
+ val dataLength = buffer.getInt()
+ data = ByteBuffer.allocate(dataLength)
+ if (dataLength != buffer.remaining) {
+ throw new Exception("Error parsing buffer")
+ }
+ data.put(buffer)
+ data.flip()
+ logDebug("Set from buffer Result 2: " + level + " " + data)
+ } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
+
+ val dataLength = buffer.getInt()
+ logDebug("Data length is "+ dataLength)
+ logDebug("Buffer position is " + buffer.position)
+ data = ByteBuffer.allocate(dataLength)
+ if (dataLength != buffer.remaining) {
+ throw new Exception("Error parsing buffer")
+ }
+ data.put(buffer)
+ data.flip()
+ logDebug("Set from buffer Result 3: " + data)
+ }
+
+ val finishTime = System.currentTimeMillis
+ logDebug("Converted " + id + " from bytebuffer in " + (finishTime - startTime) / 1000.0 + " s")
+ }
+
+ def set(bufferMsg: BufferMessage) {
+ val buffer = bufferMsg.buffers.apply(0)
+ buffer.clear()
+ set(buffer)
+ }
+
+ def getType(): Int = {
+ return typ
+ }
+
+ def getId(): String = {
+ return id
+ }
+
+ def getData(): ByteBuffer = {
+ return data
+ }
+
+ def getLevel(): StorageLevel = {
+ return level
+ }
+
+ def toBufferMessage(): BufferMessage = {
+ val startTime = System.currentTimeMillis
+ val buffers = new ArrayBuffer[ByteBuffer]()
+ var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2)
+ buffer.putInt(typ).putInt(id.length())
+ id.foreach((x: Char) => buffer.putChar(x))
+ buffer.flip()
+ buffers += buffer
+
+ if (typ == BlockMessage.TYPE_PUT_BLOCK) {
+ buffer = ByteBuffer.allocate(8).putInt(level.toInt()).putInt(level.replication)
+ buffer.flip()
+ buffers += buffer
+
+ buffer = ByteBuffer.allocate(4).putInt(data.remaining)
+ buffer.flip()
+ buffers += buffer
+
+ buffers += data
+ } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
+ buffer = ByteBuffer.allocate(4).putInt(data.remaining)
+ buffer.flip()
+ buffers += buffer
+
+ buffers += data
+ }
+
+ logDebug("Start to log buffers.")
+ buffers.foreach((x: ByteBuffer) => logDebug("" + x))
+ /*
+ println()
+ println("BlockMessage: ")
+ buffers.foreach(b => {
+ while(b.remaining > 0) {
+ print(b.get())
+ }
+ b.rewind()
+ })
+ println()
+ println()
+ */
+ val finishTime = System.currentTimeMillis
+ logDebug("Converted " + id + " to buffer message in " + (finishTime - startTime) / 1000.0 + " s")
+ return Message.createBufferMessage(buffers)
+ }
+
+ override def toString(): String = {
+ "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level +
+ ", data = " + (if (data != null) data.remaining.toString else "null") + "]"
+ }
+}
+
+object BlockMessage {
+ val TYPE_NON_INITIALIZED: Int = 0
+ val TYPE_GET_BLOCK: Int = 1
+ val TYPE_GOT_BLOCK: Int = 2
+ val TYPE_PUT_BLOCK: Int = 3
+
+ def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(bufferMessage)
+ newBlockMessage
+ }
+
+ def fromByteBuffer(buffer: ByteBuffer): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(buffer)
+ newBlockMessage
+ }
+
+ def fromGetBlock(getBlock: GetBlock): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(getBlock)
+ newBlockMessage
+ }
+
+ def fromGotBlock(gotBlock: GotBlock): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(gotBlock)
+ newBlockMessage
+ }
+
+ def fromPutBlock(putBlock: PutBlock): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(putBlock)
+ newBlockMessage
+ }
+
+ def main(args: Array[String]) {
+ val B = new BlockMessage()
+ B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.DISK_AND_MEMORY_2))
+ val bMsg = B.toBufferMessage()
+ val C = new BlockMessage()
+ C.set(bMsg)
+
+ println(B.getId() + " " + B.getLevel())
+ println(C.getId() + " " + C.getLevel())
+ }
+}
diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala
new file mode 100644
index 0000000000..5f411d3488
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockMessageArray.scala
@@ -0,0 +1,140 @@
+package spark.storage
+import java.nio._
+
+import scala.collection.mutable.StringBuilder
+import scala.collection.mutable.ArrayBuffer
+
+import spark._
+import spark.network._
+
+class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging {
+
+ def this(bm: BlockMessage) = this(Array(bm))
+
+ def this() = this(null.asInstanceOf[Seq[BlockMessage]])
+
+ def apply(i: Int) = blockMessages(i)
+
+ def iterator = blockMessages.iterator
+
+ def length = blockMessages.length
+
+ initLogging()
+
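+ // The single buffer holds repeated frames of [size: Int] followed by that many bytes of a
+ // serialized BlockMessage; toBufferMessage writes this framing and set() parses it back.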
+ def set(bufferMessage: BufferMessage) {
+ val startTime = System.currentTimeMillis
+ val newBlockMessages = new ArrayBuffer[BlockMessage]()
+ val buffer = bufferMessage.buffers(0)
+ buffer.clear()
+ /*
+ println()
+ println("BlockMessageArray: ")
+ while(buffer.remaining > 0) {
+ print(buffer.get())
+ }
+ buffer.rewind()
+ println()
+ println()
+ */
+ while(buffer.remaining() > 0) {
+ val size = buffer.getInt()
+ logDebug("Creating block message of size " + size + " bytes")
+ val newBuffer = buffer.slice()
+ newBuffer.clear()
+ newBuffer.limit(size)
+ logDebug("Trying to convert buffer " + newBuffer + " to block message")
+ val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer)
+ logDebug("Created " + newBlockMessage)
+ newBlockMessages += newBlockMessage
+ buffer.position(buffer.position() + size)
+ }
+ val finishTime = System.currentTimeMillis
+ logDebug("Converted block message array from buffer message in " + (finishTime - startTime) / 1000.0 + " s")
+ this.blockMessages = newBlockMessages
+ }
+
+ def toBufferMessage(): BufferMessage = {
+ val buffers = new ArrayBuffer[ByteBuffer]()
+
+ blockMessages.foreach(blockMessage => {
+ val bufferMessage = blockMessage.toBufferMessage
+ logDebug("Adding " + blockMessage)
+ val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size)
+ sizeBuffer.flip
+ buffers += sizeBuffer
+ buffers ++= bufferMessage.buffers
+ logDebug("Added " + bufferMessage)
+ })
+
+ logDebug("Buffer list:")
+ buffers.foreach((x: ByteBuffer) => logDebug("" + x))
+ /*
+ println()
+ println("BlockMessageArray: ")
+ buffers.foreach(b => {
+ while(b.remaining > 0) {
+ print(b.get())
+ }
+ b.rewind()
+ })
+ println()
+ println()
+ */
+ return Message.createBufferMessage(buffers)
+ }
+}
+
+object BlockMessageArray {
+
+ def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
+ val newBlockMessageArray = new BlockMessageArray()
+ newBlockMessageArray.set(bufferMessage)
+ newBlockMessageArray
+ }
+
+ def main(args: Array[String]) {
+ val blockMessages =
+ (0 until 10).map(i => {
+ if (i % 2 == 0) {
+ val buffer = ByteBuffer.allocate(100)
+ buffer.clear
+ BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY))
+ } else {
+ BlockMessage.fromGetBlock(GetBlock(i.toString))
+ }
+ })
+ val blockMessageArray = new BlockMessageArray(blockMessages)
+ println("Block message array created")
+
+ val bufferMessage = blockMessageArray.toBufferMessage
+ println("Converted to buffer message")
+
+ val totalSize = bufferMessage.size
+ val newBuffer = ByteBuffer.allocate(totalSize)
+ newBuffer.clear()
+ bufferMessage.buffers.foreach(buffer => {
+ newBuffer.put(buffer)
+ buffer.rewind()
+ })
+ newBuffer.flip
+ val newBufferMessage = Message.createBufferMessage(newBuffer)
+ println("Copied to new buffer message, size = " + newBufferMessage.size)
+
+ val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage)
+ println("Converted back to block message array")
+ newBlockMessageArray.foreach(blockMessage => {
+ blockMessage.getType() match {
+ case BlockMessage.TYPE_PUT_BLOCK => {
+ val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel())
+ println(pB)
+ }
+ case BlockMessage.TYPE_GET_BLOCK => {
+ val gB = new GetBlock(blockMessage.getId())
+ println(gB)
+ }
+ }
+ })
+ }
+}
+
+
diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala
new file mode 100644
index 0000000000..0584cc2d4f
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockStore.scala
@@ -0,0 +1,282 @@
+package spark.storage
+
+import spark.{Utils, Logging, Serializer, SizeEstimator}
+
+import scala.collection.mutable.ArrayBuffer
+
+import java.io.{File, RandomAccessFile}
+import java.nio.ByteBuffer
+import java.nio.channels.FileChannel.MapMode
+import java.util.{UUID, LinkedHashMap}
+import java.util.concurrent.Executors
+
+import it.unimi.dsi.fastutil.io._
+
+/**
+ * Abstract class to store blocks
+ */
+abstract class BlockStore(blockManager: BlockManager) extends Logging {
+ initLogging()
+
+ def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel)
+
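+ // Returns Left(iterator of values) when the block is kept deserialized, or
+ // Right(serialized bytes) when it is serialized on the way in.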
+ def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer]
+
+ def getBytes(blockId: String): Option[ByteBuffer]
+
+ def getValues(blockId: String): Option[Iterator[Any]]
+
+ def remove(blockId: String)
+
+ def dataSerialize(values: Iterator[Any]): ByteBuffer = blockManager.dataSerialize(values)
+
+ def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = blockManager.dataDeserialize(bytes)
+}
+
+/**
+ * Class to store blocks in memory
+ */
+class MemoryStore(blockManager: BlockManager, maxMemory: Long)
+ extends BlockStore(blockManager) {
+
+ class Entry(var value: Any, val size: Long, val deserialized: Boolean)
+
+ private val memoryStore = new LinkedHashMap[String, Entry](32, 0.75f, true)
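+ // accessOrder = true keeps iteration in least-recently-accessed order, which
+ // ensureFreeSpace() relies on to pick blocks to drop first.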
+ private var currentMemory = 0L
+
+ private val blockDropper = Executors.newSingleThreadExecutor()
+
+ def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ if (level.deserialized) {
+ bytes.rewind()
+ val values = dataDeserialize(bytes)
+ val elements = new ArrayBuffer[Any]
+ elements ++= values
+ val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
+ ensureFreeSpace(sizeEstimate)
+ val entry = new Entry(elements, sizeEstimate, true)
+ memoryStore.synchronized { memoryStore.put(blockId, entry) }
+ currentMemory += sizeEstimate
+ logDebug("Block " + blockId + " stored as values to memory")
+ } else {
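+ // NOTE: this assumes a heap-backed buffer; ByteBuffer.array() throws for a direct buffer,
+ // and bytes.remaining would be the safer measure of the block's size.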
+ val entry = new Entry(bytes, bytes.array().length, false)
+ ensureFreeSpace(bytes.array.length)
+ memoryStore.synchronized { memoryStore.put(blockId, entry) }
+ currentMemory += bytes.array().length
+ logDebug("Block " + blockId + " stored as " + bytes.array().length + " bytes to memory")
+ }
+ }
+
+ def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = {
+ if (level.deserialized) {
+ val elements = new ArrayBuffer[Any]
+ elements ++= values
+ val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
+ ensureFreeSpace(sizeEstimate)
+ val entry = new Entry(elements, sizeEstimate, true)
+ memoryStore.synchronized { memoryStore.put(blockId, entry) }
+ currentMemory += sizeEstimate
+ logDebug("Block " + blockId + " stored as values to memory")
+ return Left(elements.iterator)
+ } else {
+ val bytes = dataSerialize(values)
+ ensureFreeSpace(bytes.array().length)
+ val entry = new Entry(bytes, bytes.array().length, false)
+ memoryStore.synchronized { memoryStore.put(blockId, entry) }
+ currentMemory += bytes.array().length
+ logDebug("Block " + blockId + " stored as " + bytes.array.length + " bytes to memory")
+ return Right(bytes)
+ }
+ }
+
+ def getBytes(blockId: String): Option[ByteBuffer] = {
+ throw new UnsupportedOperationException("Not implemented")
+ }
+
+ def getValues(blockId: String): Option[Iterator[Any]] = {
+ val entry = memoryStore.synchronized { memoryStore.get(blockId) }
+ if (entry == null) {
+ return None
+ }
+ if (entry.deserialized) {
+ return Some(entry.value.asInstanceOf[ArrayBuffer[Any]].toIterator)
+ } else {
+ return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer]))
+ }
+ }
+
+ def remove(blockId: String) {
+ memoryStore.synchronized {
+ val entry = memoryStore.get(blockId)
+ if (entry != null) {
+ memoryStore.remove(blockId)
+ currentMemory -= entry.size
+ logDebug("Block " + blockId + " of size " + entry.size + " dropped from memory")
+ } else {
+ logWarning("Block " + blockId + " could not be removed as it doesnt exist")
+ }
+ }
+ }
+
+ private def drop(blockId: String) {
+ blockDropper.submit(new Runnable() {
+ def run() {
+ blockManager.dropFromMemory(blockId)
+ }
+ })
+ }
+
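+ // Scans the access-ordered map and marks least-recently-used blocks for dropping until
+ // `space` bytes fit under maxMemory; the actual drops happen asynchronously via drop().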
+ private def ensureFreeSpace(space: Long) {
+ logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
+ space, currentMemory, maxMemory))
+
+ val droppedBlockIds = new ArrayBuffer[String]()
+ var droppedMemory = 0L
+
+ memoryStore.synchronized {
+ val iter = memoryStore.entrySet().iterator()
+ while (maxMemory - (currentMemory - droppedMemory) < space && iter.hasNext) {
+ val pair = iter.next()
+ val blockId = pair.getKey
+ droppedBlockIds += blockId
+ droppedMemory += pair.getValue.size
+ logDebug("Decided to drop " + blockId)
+ }
+ }
+
+ for (blockId <- droppedBlockIds) {
+ drop(blockId)
+ }
+
+ droppedBlockIds.clear
+ }
+}
+
+
+/**
+ * Class to store blocks on disk
+ */
+class DiskStore(blockManager: BlockManager, rootDirs: String)
+ extends BlockStore(blockManager) {
+
+ val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+ val localDirs = createLocalDirs()
+ var lastLocalDirUsed = 0
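+ // New block files are spread across localDirs round-robin by createFile().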
+
+ addShutdownHook()
+
+ def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ logDebug("Attempting to put block " + blockId)
+ val startTime = System.currentTimeMillis
+ val file = createFile(blockId)
+ if (file != null) {
+ val channel = new RandomAccessFile(file, "rw").getChannel()
+ val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.array.length)
+ buffer.put(bytes.array)
+ channel.close()
+ val finishTime = System.currentTimeMillis
+ logDebug("Block " + blockId + " stored to file of " + bytes.array.length + " bytes to disk in " + (finishTime - startTime) + " ms")
+ } else {
+ logError("File not created for block " + blockId)
+ }
+ }
+
+ def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = {
+ val bytes = dataSerialize(values)
+ logDebug("Converted block " + blockId + " to " + bytes.array.length + " bytes")
+ putBytes(blockId, bytes, level)
+ return Right(bytes)
+ }
+
+ def getBytes(blockId: String): Option[ByteBuffer] = {
+ val file = getFile(blockId)
+ val length = file.length().toInt
+ val channel = new RandomAccessFile(file, "r").getChannel()
+ val bytes = ByteBuffer.allocate(length)
+ // The channel was opened read-only, so map it READ_ONLY (READ_WRITE would throw); flip so
+ // callers read from position 0, and close the channel once the bytes are copied out.
+ bytes.put(channel.map(MapMode.READ_ONLY, 0, length))
+ bytes.flip()
+ channel.close()
+ return Some(bytes)
+ }
+
+ def getValues(blockId: String): Option[Iterator[Any]] = {
+ val file = getFile(blockId)
+ val length = file.length().toInt
+ val channel = new RandomAccessFile(file, "r").getChannel()
+ val bytes = channel.map(MapMode.READ_ONLY, 0, length)
+ val buffer = dataDeserialize(bytes)
+ channel.close()
+ return Some(buffer)
+ }
+
+ def remove(blockId: String) {
+ throw new UnsupportedOperationException("Not implemented")
+ }
+
+ private def createFile(blockId: String): File = {
+ val file = getFile(blockId)
+ if (file == null) {
+ lastLocalDirUsed = (lastLocalDirUsed + 1) % localDirs.size
+ val newFile = new File(localDirs(lastLocalDirUsed), blockId)
+ newFile.getParentFile.mkdirs()
+ return newFile
+ } else {
+ logError("File for block " + blockId + " already exists on disk, " + file)
+ return null
+ }
+ }
+
+ private def getFile(blockId: String): File = {
+ logDebug("Getting file for block " + blockId)
+ // Search for the file in all the local directories, only one of them should have the file
+ val files = localDirs.map(localDir => new File(localDir, blockId)).filter(_.exists)
+ if (files.size > 1) {
+ throw new Exception("Multiple files for same block " + blockId + " exists: " +
+ files.map(_.toString).reduceLeft(_ + ", " + _))
+ return null
+ } else if (files.size == 0) {
+ return null
+ } else {
+ logDebug("Got file " + files(0) + " of size " + files(0).length + " bytes")
+ return files(0)
+ }
+ }
+
+ private def createLocalDirs(): Seq[File] = {
+ logDebug("Creating local directories at root dirs '" + rootDirs + "'")
+ rootDirs.split("[;,:]").map(rootDir => {
+ var foundLocalDir: Boolean = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+ var tries = 0
+ while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID()
+ localDir = new File(rootDir, "spark-local-" + localDirUuid)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
+ " attempts to create local dir in " + rootDir)
+ System.exit(1)
+ }
+ logDebug("Created local directory at " + localDir)
+ localDir
+ })
+ }
+
+ private def addShutdownHook() {
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
+ override def run() {
+ logDebug("Shutdown hook called")
+ localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
+ }
+ })
+ }
+}
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
new file mode 100644
index 0000000000..a2833a7090
--- /dev/null
+++ b/core/src/main/scala/spark/storage/StorageLevel.scala
@@ -0,0 +1,78 @@
+package spark.storage
+
+import java.io._
+
+class StorageLevel(
+ var useDisk: Boolean,
+ var useMemory: Boolean,
+ var deserialized: Boolean,
+ var replication: Int = 1)
+ extends Externalizable {
+
+ // TODO: Also add fields for caching priority, dataset ID, and flushing.
+
+ def this(booleanInt: Int, replication: Int) {
+ this(((booleanInt & 4) != 0),
+ ((booleanInt & 2) != 0),
+ ((booleanInt & 1) != 0),
+ replication)
+ }
+
+ def this() = this(false, true, false) // For deserialization
+
+ override def clone(): StorageLevel = new StorageLevel(
+ this.useDisk, this.useMemory, this.deserialized, this.replication)
+
+ override def equals(other: Any): Boolean = other match {
+ case s: StorageLevel =>
+ s.useDisk == useDisk &&
+ s.useMemory == useMemory &&
+ s.deserialized == deserialized &&
+ s.replication == replication
+ case _ =>
+ false
+ }
+
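+ // Flags are bit-packed into a single Int: 4 = useDisk, 2 = useMemory, 1 = deserialized
+ // (the inverse of the Int-based constructor above).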
+ def toInt(): Int = {
+ var ret = 0
+ if (useDisk) {
+ ret += 4
+ }
+ if (useMemory) {
+ ret += 2
+ }
+ if (deserialized) {
+ ret += 1
+ }
+ return ret
+ }
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeByte(toInt().toByte)
+ out.writeByte(replication.toByte)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ val flags = in.readByte()
+ useDisk = (flags & 4) != 0
+ useMemory = (flags & 2) != 0
+ deserialized = (flags & 1) != 0
+ replication = in.readByte()
+ }
+
+ override def toString(): String =
+ "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
+}
+
+object StorageLevel {
+ val NONE = new StorageLevel(false, false, false)
+ val DISK_ONLY = new StorageLevel(true, false, false)
+ val MEMORY_ONLY = new StorageLevel(false, true, false)
+ val MEMORY_ONLY_2 = new StorageLevel(false, true, false, 2)
+ val MEMORY_ONLY_DESER = new StorageLevel(false, true, true)
+ val MEMORY_ONLY_DESER_2 = new StorageLevel(false, true, true, 2)
+ val DISK_AND_MEMORY = new StorageLevel(true, true, false)
+ val DISK_AND_MEMORY_2 = new StorageLevel(true, true, false, 2)
+ val DISK_AND_MEMORY_DESER = new StorageLevel(true, true, true)
+ val DISK_AND_MEMORY_DESER_2 = new StorageLevel(true, true, true, 2)
+}
diff --git a/core/src/main/scala/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/spark/util/ByteBufferInputStream.scala
new file mode 100644
index 0000000000..abe2d99dd8
--- /dev/null
+++ b/core/src/main/scala/spark/util/ByteBufferInputStream.scala
@@ -0,0 +1,30 @@
+package spark.util
+
+import java.io.InputStream
+import java.nio.ByteBuffer
+
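+// Presents a ByteBuffer as a read-only InputStream; every read advances the buffer's position.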
+class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream {
+ override def read(): Int = {
+ if (buffer.remaining() == 0) {
+ -1
+ } else {
+ // Mask to an unsigned value so bytes >= 0x80 are not mistaken for end-of-stream (-1).
+ buffer.get() & 0xFF
+ }
+ }
+
+ override def read(dest: Array[Byte]): Int = {
+ read(dest, 0, dest.length)
+ }
+
+ override def read(dest: Array[Byte], offset: Int, length: Int): Int = {
+ if (length > 0 && buffer.remaining() == 0) {
+ // Per the InputStream contract, signal end-of-stream rather than returning 0.
+ return -1
+ }
+ val amountToGet = math.min(buffer.remaining(), length)
+ buffer.get(dest, offset, amountToGet)
+ return amountToGet
+ }
+
+ override def skip(bytes: Long): Long = {
+ val amountToSkip = math.min(bytes, buffer.remaining).toInt
+ buffer.position(buffer.position + amountToSkip)
+ return amountToSkip
+ }
+}
diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala
new file mode 100644
index 0000000000..efb1ae7529
--- /dev/null
+++ b/core/src/main/scala/spark/util/StatCounter.scala
@@ -0,0 +1,89 @@
+package spark.util
+
+/**
+ * A class for tracking the statistics of a set of numbers (count, mean and variance) in a
+ * numerically robust way. Includes support for merging two StatCounters. Based on Welford and
+ * Chan's algorithms described at http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
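+ *
+ * A small illustrative example (the values here are made up, not from the original code):
+ * {{{
+ *   val s = StatCounter(1.0, 2.0, 3.0)   // count = 3, mean = 2.0
+ *   s.merge(StatCounter(4.0, 5.0))       // now covers all five values; mean = 3.0, sum = 15.0
+ * }}}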
+ */
+class StatCounter(values: TraversableOnce[Double]) {
+ private var n: Long = 0 // Running count of our values
+ private var mu: Double = 0 // Running mean of our values
+ private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2)
+
+ merge(values)
+
+ def this() = this(Nil)
+
+ def merge(value: Double): StatCounter = {
+ val delta = value - mu
+ n += 1
+ mu += delta / n
+ m2 += delta * (value - mu)
+ this
+ }
+
+ def merge(values: TraversableOnce[Double]): StatCounter = {
+ values.foreach(v => merge(v))
+ this
+ }
+
+ def merge(other: StatCounter): StatCounter = {
+ if (other == this) {
+ merge(other.copy()) // Avoid overwriting fields in a weird order
+ } else {
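+ // Chan et al.'s pairwise combine: the branches pick the more numerically stable form of
+ // the merged mean when one side's count dominates the other.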
+ val delta = other.mu - mu
+ if (other.n * 10 < n) {
+ mu = mu + (delta * other.n) / (n + other.n)
+ } else if (n * 10 < other.n) {
+ mu = other.mu - (delta * n) / (n + other.n)
+ } else {
+ mu = (mu * n + other.mu * other.n) / (n + other.n)
+ }
+ m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
+ n += other.n
+ this
+ }
+ }
+
+ def copy(): StatCounter = {
+ val other = new StatCounter
+ other.n = n
+ other.mu = mu
+ other.m2 = m2
+ other
+ }
+
+ def count: Long = n
+
+ def mean: Double = mu
+
+ def sum: Double = n * mu
+
+ def variance: Double = {
+ if (n == 0)
+ Double.NaN
+ else
+ m2 / n
+ }
+
+ def sampleVariance: Double = {
+ if (n <= 1)
+ Double.NaN
+ else
+ m2 / (n - 1)
+ }
+
+ def stdev: Double = math.sqrt(variance)
+
+ def sampleStdev: Double = math.sqrt(sampleVariance)
+
+ override def toString: String = {
+ "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev)
+ }
+}
+
+object StatCounter {
+ def apply(values: TraversableOnce[Double]) = new StatCounter(values)
+
+ def apply(values: Double*) = new StatCounter(values)
+}
diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala
index 60290d14ca..3d170a6e22 100644
--- a/core/src/test/scala/spark/CacheTrackerSuite.scala
+++ b/core/src/test/scala/spark/CacheTrackerSuite.scala
@@ -1,95 +1,103 @@
package spark
import org.scalatest.FunSuite
-import collection.mutable.HashMap
+
+import scala.collection.mutable.HashMap
+
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
class CacheTrackerSuite extends FunSuite {
test("CacheTrackerActor slave initialization & cache status") {
- System.setProperty("spark.master.port", "1345")
+ //System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
- val tracker = new CacheTrackerActor
+ val tracker = actorOf(new CacheTrackerActor)
tracker.start()
- tracker !? SlaveCacheStarted("host001", initialSize)
+ tracker !! SlaveCacheStarted("host001", initialSize)
- assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 0L)))
+ assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 0L)))
- tracker !? StopCacheTracker
+ tracker !! StopCacheTracker
}
test("RegisterRDD") {
- System.setProperty("spark.master.port", "1345")
+ //System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
- val tracker = new CacheTrackerActor
+ val tracker = actorOf(new CacheTrackerActor)
tracker.start()
- tracker !? SlaveCacheStarted("host001", initialSize)
+ tracker !! SlaveCacheStarted("host001", initialSize)
- tracker !? RegisterRDD(1, 3)
- tracker !? RegisterRDD(2, 1)
+ tracker !! RegisterRDD(1, 3)
+ tracker !! RegisterRDD(2, 1)
- assert(getCacheLocations(tracker) == Map(1 -> List(List(), List(), List()), 2 -> List(List())))
+ assert(getCacheLocations(tracker) === Map(1 -> List(List(), List(), List()), 2 -> List(List())))
- tracker !? StopCacheTracker
+ tracker !! StopCacheTracker
}
test("AddedToCache") {
- System.setProperty("spark.master.port", "1345")
+ //System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
- val tracker = new CacheTrackerActor
+ val tracker = actorOf(new CacheTrackerActor)
tracker.start()
- tracker !? SlaveCacheStarted("host001", initialSize)
+ tracker !! SlaveCacheStarted("host001", initialSize)
- tracker !? RegisterRDD(1, 2)
- tracker !? RegisterRDD(2, 1)
+ tracker !! RegisterRDD(1, 2)
+ tracker !! RegisterRDD(2, 1)
- tracker !? AddedToCache(1, 0, "host001", 2L << 15)
- tracker !? AddedToCache(1, 1, "host001", 2L << 11)
- tracker !? AddedToCache(2, 0, "host001", 3L << 10)
+ tracker !! AddedToCache(1, 0, "host001", 2L << 15)
+ tracker !! AddedToCache(1, 1, "host001", 2L << 11)
+ tracker !! AddedToCache(2, 0, "host001", 3L << 10)
- assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
+ assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L)))
- assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
+ assert(getCacheLocations(tracker) ===
+ Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
- tracker !? StopCacheTracker
+ tracker !! StopCacheTracker
}
test("DroppedFromCache") {
- System.setProperty("spark.master.port", "1345")
+ //System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
- val tracker = new CacheTrackerActor
+ val tracker = actorOf(new CacheTrackerActor)
tracker.start()
- tracker !? SlaveCacheStarted("host001", initialSize)
+ tracker !! SlaveCacheStarted("host001", initialSize)
- tracker !? RegisterRDD(1, 2)
- tracker !? RegisterRDD(2, 1)
+ tracker !! RegisterRDD(1, 2)
+ tracker !! RegisterRDD(2, 1)
- tracker !? AddedToCache(1, 0, "host001", 2L << 15)
- tracker !? AddedToCache(1, 1, "host001", 2L << 11)
- tracker !? AddedToCache(2, 0, "host001", 3L << 10)
+ tracker !! AddedToCache(1, 0, "host001", 2L << 15)
+ tracker !! AddedToCache(1, 1, "host001", 2L << 11)
+ tracker !! AddedToCache(2, 0, "host001", 3L << 10)
- assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
- assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
+ assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L)))
+ assert(getCacheLocations(tracker) ===
+ Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
- tracker !? DroppedFromCache(1, 1, "host001", 2L << 11)
+ tracker !! DroppedFromCache(1, 1, "host001", 2L << 11)
- assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 68608L)))
- assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
+ assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 68608L)))
+ assert(getCacheLocations(tracker) ===
+ Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
- tracker !? StopCacheTracker
+ tracker !! StopCacheTracker
}
/**
* Helper function to get cacheLocations from CacheTracker
*/
- def getCacheLocations(tracker: CacheTrackerActor) = tracker !? GetCacheLocations match {
+ def getCacheLocations(tracker: ActorRef) = (tracker ? GetCacheLocations).get match {
case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]].map {
case (i, arr) => (i -> arr.toList)
}
diff --git a/core/src/test/scala/spark/MesosSchedulerSuite.scala b/core/src/test/scala/spark/MesosSchedulerSuite.scala
index 0e6820cbdc..54421225d8 100644
--- a/core/src/test/scala/spark/MesosSchedulerSuite.scala
+++ b/core/src/test/scala/spark/MesosSchedulerSuite.scala
@@ -2,6 +2,8 @@ package spark
import org.scalatest.FunSuite
+import spark.scheduler.mesos.MesosScheduler
+
class MesosSchedulerSuite extends FunSuite {
test("memoryStringToMb"){
diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala
index f31251e509..1ac4737f04 100644
--- a/core/src/test/scala/spark/UtilsSuite.scala
+++ b/core/src/test/scala/spark/UtilsSuite.scala
@@ -2,7 +2,7 @@ package spark
import org.scalatest.FunSuite
import java.io.{ByteArrayOutputStream, ByteArrayInputStream}
-import util.Random
+import scala.util.Random
class UtilsSuite extends FunSuite {
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 08c5a990b4..a2faf7399c 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -33,6 +33,7 @@ object SparkBuild extends Build {
"org.scalatest" %% "scalatest" % "1.6.1" % "test",
"org.scala-tools.testing" %% "scalacheck" % "1.9" % "test"
),
+ parallelExecution in Test := false,
/* Workaround for issue #206 (fixed after SBT 0.11.0) */
watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task,
const(std.TaskExtra.constant(Nil)), aggregate = true, includeRoot = true) apply { _.join.map(_.flatten) }
@@ -57,8 +58,12 @@ object SparkBuild extends Build {
"asm" % "asm-all" % "3.3.1",
"com.google.protobuf" % "protobuf-java" % "2.4.1",
"de.javakaffee" % "kryo-serializers" % "0.9",
+ "se.scalablesolutions.akka" % "akka-actor" % "1.3.1",
+ "se.scalablesolutions.akka" % "akka-remote" % "1.3.1",
+ "se.scalablesolutions.akka" % "akka-slf4j" % "1.3.1",
"org.jboss.netty" % "netty" % "3.2.6.Final",
- "it.unimi.dsi" % "fastutil" % "6.4.2"
+ "it.unimi.dsi" % "fastutil" % "6.4.4",
+ "colt" % "colt" % "1.2.0"
)
) ++ assemblySettings ++ Seq(test in assembly := {})
@@ -68,8 +73,7 @@ object SparkBuild extends Build {
) ++ assemblySettings ++ Seq(test in assembly := {})
def examplesSettings = sharedSettings ++ Seq(
- name := "spark-examples",
- libraryDependencies += "colt" % "colt" % "1.2.0"
+ name := "spark-examples"
)
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")