author     Matei Zaharia <matei@eecs.berkeley.edu>  2012-06-17 14:27:45 -0700
committer  Matei Zaharia <matei@eecs.berkeley.edu>  2012-06-17 14:27:45 -0700
commit     94d77f83d3c6486e2eefd41dacb90ec0ed2633a3 (patch)
tree       cade164584d85d81243c88d385143f2b58ce4e50
parent     f46e8672492d1f23ae2f12881cef52064164e38e (diff)
Revert "Merge branch 'master' into dev"
This reverts commit f58da6164eaf13dd986a39a40535975096b71b44, reversing changes made to 4449eb97834ed6191dc0937d255c475191895980.
-rw-r--r--  bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala  11
-rw-r--r--  core/src/main/scala/spark/BlockStoreShuffleFetcher.scala  70
-rw-r--r--  core/src/main/scala/spark/BoundedMemoryCache.scala  3
-rw-r--r--  core/src/main/scala/spark/CacheTracker.scala  334
-rw-r--r--  core/src/main/scala/spark/CoGroupedRDD.scala  6
-rw-r--r--  core/src/main/scala/spark/DAGScheduler.scala  374
-rw-r--r--  core/src/main/scala/spark/Dependency.scala  2
-rw-r--r--  core/src/main/scala/spark/DiskSpillingCache.scala  75
-rw-r--r--  core/src/main/scala/spark/DoubleRDDFunctions.scala  39
-rw-r--r--  core/src/main/scala/spark/Executor.scala  26
-rw-r--r--  core/src/main/scala/spark/FetchFailedException.scala  8
-rw-r--r--  core/src/main/scala/spark/JavaSerializer.scala  39
-rw-r--r--  core/src/main/scala/spark/Job.scala  16
-rw-r--r--  core/src/main/scala/spark/KryoSerializer.scala  98
-rw-r--r--  core/src/main/scala/spark/LocalScheduler.scala (renamed from core/src/main/scala/spark/scheduler/local/LocalScheduler.scala)  43
-rw-r--r--  core/src/main/scala/spark/Logging.scala  7
-rw-r--r--  core/src/main/scala/spark/MapOutputTracker.scala  108
-rw-r--r--  core/src/main/scala/spark/MesosScheduler.scala (renamed from core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala)  271
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala  56
-rw-r--r--  core/src/main/scala/spark/ParallelShuffleFetcher.scala  119
-rw-r--r--  core/src/main/scala/spark/Partitioner.scala  1
-rw-r--r--  core/src/main/scala/spark/PipedRDD.scala  1
-rw-r--r--  core/src/main/scala/spark/RDD.scala  104
-rw-r--r--  core/src/main/scala/spark/ResultTask.scala (renamed from core/src/main/scala/spark/scheduler/ResultTask.scala)  15
-rw-r--r--  core/src/main/scala/spark/Scheduler.scala  27
-rw-r--r--  core/src/main/scala/spark/SequenceFileRDDFunctions.scala  2
-rw-r--r--  core/src/main/scala/spark/Serializer.scala  86
-rw-r--r--  core/src/main/scala/spark/SerializingCache.scala  26
-rw-r--r--  core/src/main/scala/spark/ShuffleMapTask.scala  56
-rw-r--r--  core/src/main/scala/spark/ShuffledRDD.scala  2
-rw-r--r--  core/src/main/scala/spark/SimpleJob.scala (renamed from core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala)  259
-rw-r--r--  core/src/main/scala/spark/SimpleShuffleFetcher.scala  46
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala  83
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala  79
-rw-r--r--  core/src/main/scala/spark/Stage.scala  41
-rw-r--r--  core/src/main/scala/spark/Task.scala  9
-rw-r--r--  core/src/main/scala/spark/TaskContext.scala  3
-rw-r--r--  core/src/main/scala/spark/TaskEndReason.scala  16
-rw-r--r--  core/src/main/scala/spark/TaskResult.scala  8
-rw-r--r--  core/src/main/scala/spark/UnionRDD.scala  3
-rw-r--r--  core/src/main/scala/spark/Utils.scala  35
-rw-r--r--  core/src/main/scala/spark/network/Connection.scala  364
-rw-r--r--  core/src/main/scala/spark/network/ConnectionManager.scala  468
-rw-r--r--  core/src/main/scala/spark/network/ConnectionManagerTest.scala  74
-rw-r--r--  core/src/main/scala/spark/network/Message.scala  219
-rw-r--r--  core/src/main/scala/spark/network/ReceiverTest.scala  20
-rw-r--r--  core/src/main/scala/spark/network/SenderTest.scala  53
-rw-r--r--  core/src/main/scala/spark/partial/ApproximateActionListener.scala  66
-rw-r--r--  core/src/main/scala/spark/partial/ApproximateEvaluator.scala  10
-rw-r--r--  core/src/main/scala/spark/partial/BoundedDouble.scala  8
-rw-r--r--  core/src/main/scala/spark/partial/CountEvaluator.scala  38
-rw-r--r--  core/src/main/scala/spark/partial/GroupedCountEvaluator.scala  62
-rw-r--r--  core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala  65
-rw-r--r--  core/src/main/scala/spark/partial/GroupedSumEvaluator.scala  72
-rw-r--r--  core/src/main/scala/spark/partial/MeanEvaluator.scala  41
-rw-r--r--  core/src/main/scala/spark/partial/PartialResult.scala  86
-rw-r--r--  core/src/main/scala/spark/partial/StudentTCacher.scala  26
-rw-r--r--  core/src/main/scala/spark/partial/SumEvaluator.scala  51
-rw-r--r--  core/src/main/scala/spark/scheduler/ActiveJob.scala  18
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGScheduler.scala  535
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala  30
-rw-r--r--  core/src/main/scala/spark/scheduler/JobListener.scala  11
-rw-r--r--  core/src/main/scala/spark/scheduler/JobResult.scala  9
-rw-r--r--  core/src/main/scala/spark/scheduler/JobWaiter.scala  43
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala  142
-rw-r--r--  core/src/main/scala/spark/scheduler/Stage.scala  86
-rw-r--r--  core/src/main/scala/spark/scheduler/Task.scala  11
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskResult.scala  34
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskScheduler.scala  27
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala  16
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskSet.scala  9
-rw-r--r--  core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala  364
-rw-r--r--  core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala  32
-rw-r--r--  core/src/main/scala/spark/storage/BlockManager.scala  588
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerMaster.scala  517
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerWorker.scala  142
-rw-r--r--  core/src/main/scala/spark/storage/BlockMessage.scala  219
-rw-r--r--  core/src/main/scala/spark/storage/BlockMessageArray.scala  140
-rw-r--r--  core/src/main/scala/spark/storage/BlockStore.scala  291
-rw-r--r--  core/src/main/scala/spark/storage/StorageLevel.scala  80
-rw-r--r--  core/src/main/scala/spark/util/ByteBufferInputStream.scala  30
-rw-r--r--  core/src/main/scala/spark/util/StatCounter.scala  89
-rw-r--r--  core/src/test/scala/spark/CacheTrackerSuite.scala  86
-rw-r--r--  core/src/test/scala/spark/MesosSchedulerSuite.scala  2
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala  6
-rw-r--r--  core/src/test/scala/spark/UtilsSuite.scala  2
-rw-r--r--  core/src/test/scala/spark/storage/BlockManagerSuite.scala  212
-rw-r--r--  project/SparkBuild.scala  10
-rwxr-xr-x  sbt/sbt  2
-rw-r--r--  sbt/sbt-launch-0.11.1.jar  bin  0 -> 1041757 bytes
-rw-r--r--  sbt/sbt-launch-0.11.3-2.jar  bin  1096763 -> 0 bytes
91 files changed, 1413 insertions(+), 6700 deletions(-)
diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
index ed8ace3a57..8ce7abd03f 100644
--- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
+++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
@@ -11,7 +11,6 @@ import scala.xml.{XML,NodeSeq}
import scala.collection.mutable.ArrayBuffer
import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
-import java.nio.ByteBuffer
object WikipediaPageRankStandalone {
def main(args: Array[String]) {
@@ -118,23 +117,23 @@ class WPRSerializer extends spark.Serializer {
}
class WPRSerializerInstance extends SerializerInstance {
- def serialize[T](t: T): ByteBuffer = {
+ def serialize[T](t: T): Array[Byte] = {
throw new UnsupportedOperationException()
}
- def deserialize[T](bytes: ByteBuffer): T = {
+ def deserialize[T](bytes: Array[Byte]): T = {
throw new UnsupportedOperationException()
}
- def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
throw new UnsupportedOperationException()
}
- def serializeStream(s: OutputStream): SerializationStream = {
+ def outputStream(s: OutputStream): SerializationStream = {
new WPRSerializationStream(s)
}
- def deserializeStream(s: InputStream): DeserializationStream = {
+ def inputStream(s: InputStream): DeserializationStream = {
new WPRDeserializationStream(s)
}
}
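
The hunk above reverts WPRSerializerInstance to the older serializer interface, where serialize and deserialize work on Array[Byte] and the stream factories are named outputStream and inputStream. For orientation, a minimal sketch of a concrete instance against that reverted interface, reusing the Java-serialization stream classes that appear later in this diff (the class name here is illustrative only):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, OutputStream}

    // Sketch only: assumes the SerializerInstance shape shown in this revert.
    class SketchSerializerInstance extends SerializerInstance {
      def serialize[T](t: T): Array[Byte] = {
        val bos = new ByteArrayOutputStream()
        val out = outputStream(bos)
        out.writeObject(t)
        out.close()
        bos.toByteArray
      }
      def deserialize[T](bytes: Array[Byte]): T =
        inputStream(new ByteArrayInputStream(bytes)).readObject[T]()
      def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T =
        deserialize(bytes) // the custom class loader is ignored in this sketch
      def outputStream(s: OutputStream): SerializationStream = new JavaSerializationStream(s)
      def inputStream(s: InputStream): DeserializationStream = new JavaDeserializationStream(s)
    }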
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
deleted file mode 100644
index e00a0d80fa..0000000000
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ /dev/null
@@ -1,70 +0,0 @@
-package spark
-
-import java.io.EOFException
-import java.net.URL
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-
-import spark.storage.BlockException
-import spark.storage.BlockManagerId
-
-import it.unimi.dsi.fastutil.io.FastBufferedInputStream
-
-
-class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
- def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
- logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
- val ser = SparkEnv.get.serializer.newInstance()
- val blockManager = SparkEnv.get.blockManager
-
- val startTime = System.currentTimeMillis
- val addresses = SparkEnv.get.mapOutputTracker.getServerAddresses(shuffleId)
- logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
- shuffleId, reduceId, System.currentTimeMillis - startTime))
-
- val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[Int]]
- for ((address, index) <- addresses.zipWithIndex) {
- splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += index
- }
-
- val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map {
- case (address, splits) =>
- (address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId)))
- }
-
- try {
- val blockOptions = blockManager.get(blocksByAddress)
- logDebug("Fetching map output blocks for shuffle %d, reduce %d took %d ms".format(
- shuffleId, reduceId, System.currentTimeMillis - startTime))
- blockOptions.foreach(x => {
- val (blockId, blockOption) = x
- blockOption match {
- case Some(block) => {
- val values = block.asInstanceOf[Iterator[Any]]
- for(value <- values) {
- val v = value.asInstanceOf[(K, V)]
- func(v._1, v._2)
- }
- }
- case None => {
- throw new BlockException(blockId, "Did not get block " + blockId)
- }
- }
- })
- } catch {
- case be: BlockException => {
- val regex = "shuffledid_([0-9]*)_([0-9]*)_([0-9]]*)".r
- be.blockId match {
- case regex(sId, mId, rId) => {
- val address = addresses(mId.toInt)
- throw new FetchFailedException(address, sId.toInt, mId.toInt, rId.toInt, be)
- }
- case _ => {
- throw be
- }
- }
- }
- }
- }
-}
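
Note that the deleted fetcher builds block IDs with "shuffleid_%d_%d_%d" but parses failures with "shuffledid_([0-9]*)_([0-9]*)_([0-9]]*)": the prefix has an extra 'd' and the stray bracket restricts the last group to a single digit, so the FetchFailedException branch could not have matched the IDs it generated. A corrected pattern, as a small sketch:

    // Sketch: a regex that matches the IDs built above, "shuffleid_<shuffleId>_<mapId>_<reduceId>".
    val shuffleBlockId = "shuffleid_([0-9]+)_([0-9]+)_([0-9]+)".r

    "shuffleid_3_7_2" match {
      case shuffleBlockId(sId, mId, rId) => println((sId.toInt, mId.toInt, rId.toInt)) // (3,7,2)
      case _                             => println("not a shuffle block id")
    }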
diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala
index fa5dcee7bb..1162e34ab0 100644
--- a/core/src/main/scala/spark/BoundedMemoryCache.scala
+++ b/core/src/main/scala/spark/BoundedMemoryCache.scala
@@ -90,8 +90,7 @@ class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
- // TODO: remove BoundedMemoryCache
- SparkEnv.get.cacheTracker.dropEntry(datasetId.asInstanceOf[(Int, Int)]._2, partition)
+ SparkEnv.get.cacheTracker.dropEntry(datasetId, partition)
}
}
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index 64b4af0ae2..4867829c17 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -1,17 +1,11 @@
package spark
-import akka.actor._
-import akka.actor.Actor
-import akka.actor.Actor._
-import akka.util.duration._
-
-import scala.collection.mutable.ArrayBuffer
+import scala.actors._
+import scala.actors.Actor._
+import scala.actors.remote._
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
-import spark.storage.BlockManager
-import spark.storage.StorageLevel
-
sealed trait CacheTrackerMessage
case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
@@ -24,8 +18,8 @@ case object GetCacheStatus extends CacheTrackerMessage
case object GetCacheLocations extends CacheTrackerMessage
case object StopCacheTracker extends CacheTrackerMessage
-class CacheTrackerActor extends Actor with Logging {
- // TODO: Should probably store (String, CacheType) tuples
+
+class CacheTrackerActor extends DaemonActor with Logging {
private val locs = new HashMap[Int, Array[List[String]]]
/**
@@ -34,93 +28,109 @@ class CacheTrackerActor extends Actor with Logging {
private val slaveCapacity = new HashMap[String, Long]
private val slaveUsage = new HashMap[String, Long]
+ // TODO: Should probably store (String, CacheType) tuples
+
private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
- def receive = {
- case SlaveCacheStarted(host: String, size: Long) =>
- logInfo("Started slave cache (size %s) on %s".format(
- Utils.memoryBytesToString(size), host))
- slaveCapacity.put(host, size)
- slaveUsage.put(host, 0)
- self.reply(true)
-
- case RegisterRDD(rddId: Int, numPartitions: Int) =>
- logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
- locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
- self.reply(true)
-
- case AddedToCache(rddId, partition, host, size) =>
- slaveUsage.put(host, getCacheUsage(host) + size)
- logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
- rddId, partition, host, Utils.memoryBytesToString(size),
- Utils.memoryBytesToString(getCacheAvailable(host))))
- locs(rddId)(partition) = host :: locs(rddId)(partition)
- self.reply(true)
+ def act() {
+ val port = System.getProperty("spark.master.port").toInt
+ RemoteActor.alive(port)
+ RemoteActor.register('CacheTracker, self)
+ logInfo("Registered actor on port " + port)
+
+ loop {
+ react {
+ case SlaveCacheStarted(host: String, size: Long) =>
+ logInfo("Started slave cache (size %s) on %s".format(
+ Utils.memoryBytesToString(size), host))
+ slaveCapacity.put(host, size)
+ slaveUsage.put(host, 0)
+ reply('OK)
+
+ case RegisterRDD(rddId: Int, numPartitions: Int) =>
+ logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
+ locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
+ reply('OK)
+
+ case AddedToCache(rddId, partition, host, size) =>
+ if (size > 0) {
+ slaveUsage.put(host, getCacheUsage(host) + size)
+ logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
+ rddId, partition, host, Utils.memoryBytesToString(size),
+ Utils.memoryBytesToString(getCacheAvailable(host))))
+ } else {
+ logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
+ }
+ locs(rddId)(partition) = host :: locs(rddId)(partition)
+ reply('OK)
+
+ case DroppedFromCache(rddId, partition, host, size) =>
+ if (size > 0) {
+ logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
+ rddId, partition, host, Utils.memoryBytesToString(size),
+ Utils.memoryBytesToString(getCacheAvailable(host))))
+ slaveUsage.put(host, getCacheUsage(host) - size)
+
+ // Do a sanity check to make sure usage is greater than 0.
+ val usage = getCacheUsage(host)
+ if (usage < 0) {
+ logError("Cache usage on %s is negative (%d)".format(host, usage))
+ }
+ } else {
+ logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
+ }
+ locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
+ reply('OK)
- case DroppedFromCache(rddId, partition, host, size) =>
- logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
- rddId, partition, host, Utils.memoryBytesToString(size),
- Utils.memoryBytesToString(getCacheAvailable(host))))
- slaveUsage.put(host, getCacheUsage(host) - size)
- // Do a sanity check to make sure usage is greater than 0.
- val usage = getCacheUsage(host)
- if (usage < 0) {
- logError("Cache usage on %s is negative (%d)".format(host, usage))
+ case MemoryCacheLost(host) =>
+ logInfo("Memory cache lost on " + host)
+ // TODO: Drop host from the memory locations list of all RDDs
+
+ case GetCacheLocations =>
+ logInfo("Asked for current cache locations")
+ reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())})
+
+ case GetCacheStatus =>
+ val status = slaveCapacity.map { case (host,capacity) =>
+ (host, capacity, getCacheUsage(host))
+ }.toSeq
+ reply(status)
+
+ case StopCacheTracker =>
+ reply('OK)
+ exit()
}
- locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
- self.reply(true)
+ }
+ }
+}
- case MemoryCacheLost(host) =>
- logInfo("Memory cache lost on " + host)
- for ((id, locations) <- locs) {
- for (i <- 0 until locations.length) {
- locations(i) = locations(i).filterNot(_ == host)
- }
- }
- self.reply(true)
- case GetCacheLocations =>
- logInfo("Asked for current cache locations")
- self.reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())})
+class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
+ // Tracker actor on the master, or remote reference to it on workers
+ var trackerActor: AbstractActor = null
- case GetCacheStatus =>
- val status = slaveCapacity.map { case (host, capacity) =>
- (host, capacity, getCacheUsage(host))
- }.toSeq
- self.reply(status)
+ val registeredRddIds = new HashSet[Int]
- case StopCacheTracker =>
- logInfo("CacheTrackerActor Server stopped!")
- self.reply(true)
- self.exit()
- }
-}
+ // Stores map results for various splits locally
+ val cache = theCache.newKeySpace()
-class CacheTracker(isMaster: Boolean, blockManager: BlockManager) extends Logging {
- // Tracker actor on the master, or remote reference to it on workers
- val ip: String = System.getProperty("spark.master.host", "localhost")
- val port: Int = System.getProperty("spark.master.port", "7077").toInt
- val aName: String = "CacheTracker"
-
if (isMaster) {
- }
-
- var trackerActor: ActorRef = if (isMaster) {
- val actor = actorOf(new CacheTrackerActor)
- remote.register(aName, actor)
- actor.start()
- logInfo("Registered CacheTrackerActor actor @ " + ip + ":" + port)
- actor
+ val tracker = new CacheTrackerActor
+ tracker.start()
+ trackerActor = tracker
} else {
- remote.actorFor(aName, ip, port)
+ val host = System.getProperty("spark.master.host")
+ val port = System.getProperty("spark.master.port").toInt
+ trackerActor = RemoteActor.select(Node(host, port), 'CacheTracker)
}
- val registeredRddIds = new HashSet[Int]
+ // Report the cache being started.
+ trackerActor !? SlaveCacheStarted(Utils.getHost, cache.getCapacity)
// Remembers which splits are currently being loaded (on worker nodes)
- val loading = new HashSet[String]
+ val loading = new HashSet[(Int, Int)]
// Registers an RDD (on master only)
def registerRDD(rddId: Int, numPartitions: Int) {
@@ -128,33 +138,24 @@ class CacheTracker(isMaster: Boolean, blockManager: BlockManager) extends Loggin
if (!registeredRddIds.contains(rddId)) {
logInfo("Registering RDD ID " + rddId + " with cache")
registeredRddIds += rddId
- (trackerActor ? RegisterRDD(rddId, numPartitions)).as[Any] match {
- case Some(true) =>
- logInfo("CacheTracker registerRDD " + RegisterRDD(rddId, numPartitions) + " successfully.")
- case Some(oops) =>
- logError("CacheTracker registerRDD" + RegisterRDD(rddId, numPartitions) + " failed: " + oops)
- case None =>
- logError("CacheTracker registerRDD None. " + RegisterRDD(rddId, numPartitions))
- throw new SparkException("Internal error: CacheTracker registerRDD None.")
- }
+ trackerActor !? RegisterRDD(rddId, numPartitions)
}
}
}
-
- // For BlockManager.scala only
- def cacheLost(host: String) {
- (trackerActor ? MemoryCacheLost(host)).as[Any] match {
- case Some(true) =>
- logInfo("CacheTracker successfully removed entries on " + host)
- case _ =>
- logError("CacheTracker did not reply to MemoryCacheLost")
+
+ // Get a snapshot of the currently known locations
+ def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
+ (trackerActor !? GetCacheLocations) match {
+ case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]]
+
+ case _ => throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap")
}
}
// Get the usage status of slave caches. Each tuple in the returned sequence
// is in the form of (host name, capacity, usage).
def getCacheStatus(): Seq[(String, Long, Long)] = {
- (trackerActor ? GetCacheStatus) match {
+ (trackerActor !? GetCacheStatus) match {
case h: Seq[(String, Long, Long)] => h.asInstanceOf[Seq[(String, Long, Long)]]
case _ =>
@@ -163,94 +164,75 @@ class CacheTracker(isMaster: Boolean, blockManager: BlockManager) extends Loggin
}
}
- // For BlockManager.scala only
- def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) {
- (trackerActor ? t).as[Any] match {
- case Some(true) =>
- logInfo("CacheTracker notifyTheCacheTrackerFromBlockManager successfully.")
- case Some(oops) =>
- logError("CacheTracker notifyTheCacheTrackerFromBlockManager failed: " + oops)
- case None =>
- logError("CacheTracker notifyTheCacheTrackerFromBlockManager None.")
- }
- }
-
- // Get a snapshot of the currently known locations
- def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
- (trackerActor ? GetCacheLocations).as[Any] match {
- case Some(h: HashMap[_, _]) =>
- h.asInstanceOf[HashMap[Int, Array[List[String]]]]
-
- case _ =>
- throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap")
- }
- }
-
// Gets or computes an RDD split
- def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = {
- val key = "rdd:%d:%d".format(rdd.id, split.index)
- logInfo("Cache key is " + key)
- blockManager.get(key) match {
- case Some(cachedValues) =>
- // Split is in cache, so just return its values
- logInfo("Found partition in cache!")
- return cachedValues.asInstanceOf[Iterator[T]]
-
- case None =>
- // Mark the split as loading (unless someone else marks it first)
+ def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T]): Iterator[T] = {
+ logInfo("Looking for RDD partition %d:%d".format(rdd.id, split.index))
+ val cachedVal = cache.get(rdd.id, split.index)
+ if (cachedVal != null) {
+ // Split is in cache, so just return its values
+ logInfo("Found partition in cache!")
+ return cachedVal.asInstanceOf[Array[T]].iterator
+ } else {
+ // Mark the split as loading (unless someone else marks it first)
+ val key = (rdd.id, split.index)
+ loading.synchronized {
+ while (loading.contains(key)) {
+ // Someone else is loading it; let's wait for them
+ try { loading.wait() } catch { case _ => }
+ }
+ // See whether someone else has successfully loaded it. The main way this would fail
+ // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
+ // partition but we didn't want to make space for it. However, that case is unlikely
+ // because it's unlikely that two threads would work on the same RDD partition. One
+ // downside of the current code is that threads wait serially if this does happen.
+ val cachedVal = cache.get(rdd.id, split.index)
+ if (cachedVal != null) {
+ return cachedVal.asInstanceOf[Array[T]].iterator
+ }
+ // Nobody's loading it and it's not in the cache; let's load it ourselves
+ loading.add(key)
+ }
+ // If we got here, we have to load the split
+ // Tell the master that we're doing so
+
+ // TODO: fetch any remote copy of the split that may be available
+ logInfo("Computing partition " + split)
+ var array: Array[T] = null
+ var putResponse: CachePutResponse = null
+ try {
+ array = rdd.compute(split).toArray(m)
+ putResponse = cache.put(rdd.id, split.index, array)
+ } finally {
+ // Tell other threads that we've finished our attempt to load the key (whether or not
+ // we've actually succeeded to put it in the map)
loading.synchronized {
- if (loading.contains(key)) {
- logInfo("Loading contains " + key + ", waiting...")
- while (loading.contains(key)) {
- try {loading.wait()} catch {case _ =>}
- }
- logInfo("Loading no longer contains " + key + ", so returning cached result")
- // See whether someone else has successfully loaded it. The main way this would fail
- // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
- // partition but we didn't want to make space for it. However, that case is unlikely
- // because it's unlikely that two threads would work on the same RDD partition. One
- // downside of the current code is that threads wait serially if this does happen.
- blockManager.get(key) match {
- case Some(values) =>
- return values.asInstanceOf[Iterator[T]]
- case None =>
- logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
- loading.add(key)
- }
- } else {
- loading.add(key)
- }
+ loading.remove(key)
+ loading.notifyAll()
}
- // If we got here, we have to load the split
- // Tell the master that we're doing so
- //val host = System.getProperty("spark.hostname", Utils.localHostName)
- //val future = trackerActor !! AddedToCache(rdd.id, split.index, host)
- // TODO: fetch any remote copy of the split that may be available
- // TODO: also register a listener for when it unloads
- logInfo("Computing partition " + split)
- try {
- val values = new ArrayBuffer[Any]
- values ++= rdd.compute(split)
- blockManager.put(key, values.iterator, storageLevel, false)
- //future.apply() // Wait for the reply from the cache tracker
- return values.iterator.asInstanceOf[Iterator[T]]
- } finally {
- loading.synchronized {
- loading.remove(key)
- loading.notifyAll()
- }
+ }
+
+ putResponse match {
+ case CachePutSuccess(size) => {
+ // Tell the master that we added the entry. Don't return until it
+ // replies so it can properly schedule future tasks that use this RDD.
+ trackerActor !? AddedToCache(rdd.id, split.index, Utils.getHost, size)
}
+ case _ => null
+ }
+ return array.iterator
}
}
// Called by the Cache to report that an entry has been dropped from it
- def dropEntry(rddId: Int, partition: Int) {
- //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here.
- trackerActor !! DroppedFromCache(rddId, partition, Utils.localHostName())
+ def dropEntry(datasetId: Any, partition: Int) {
+ datasetId match {
+ //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to be enough here.
+ case (cache.keySpaceId, rddId: Int) => trackerActor !! DroppedFromCache(rddId, partition, Utils.getHost)
+ }
}
def stop() {
- trackerActor !! StopCacheTracker
+ trackerActor !? StopCacheTracker
registeredRddIds.clear()
trackerActor = null
}
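
The reverted getOrCompute above guards each (rddId, splitIndex) key with a shared `loading` set: one thread computes the partition while the others wait on the set's monitor and then re-check the cache. A self-contained sketch of that double-checked loading pattern, using hypothetical names and an Int-array payload in place of the real cache types:

    import scala.collection.mutable.{HashMap, HashSet}

    object LoadingGuard {
      private val cache   = new HashMap[(Int, Int), Array[Int]]
      private val loading = new HashSet[(Int, Int)]

      def getOrCompute(key: (Int, Int))(compute: => Array[Int]): Array[Int] = {
        loading.synchronized {
          while (loading.contains(key)) loading.wait() // someone else is computing this key
          cache.get(key) match {
            case Some(v) => return v                   // their result is now cached; reuse it
            case None    => loading.add(key)           // otherwise we compute it ourselves
          }
        }
        try {
          val value = compute
          loading.synchronized { cache(key) = value }
          value
        } finally {
          loading.synchronized { loading.remove(key); loading.notifyAll() }
        }
      }
    }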
diff --git a/core/src/main/scala/spark/CoGroupedRDD.scala b/core/src/main/scala/spark/CoGroupedRDD.scala
index 3543c8afa8..93f453bc5e 100644
--- a/core/src/main/scala/spark/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/CoGroupedRDD.scala
@@ -22,12 +22,11 @@ class CoGroupAggregator
{ (b1, b2) => b1 ++ b2 })
with Serializable
-class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
+class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
val aggr = new CoGroupAggregator
- @transient
override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
for ((rdd, index) <- rdds.zipWithIndex) {
@@ -68,10 +67,9 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit]
- val numRdds = split.deps.size
val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
- map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any]))
+ map.getOrElseUpdate(k, Array.fill(rdds.size)(new ArrayBuffer[Any]))
}
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, itsSplit) => {
diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
new file mode 100644
index 0000000000..1b4af9d84c
--- /dev/null
+++ b/core/src/main/scala/spark/DAGScheduler.scala
@@ -0,0 +1,374 @@
+package spark
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
+
+/**
+ * A task created by the DAG scheduler. Knows its stage ID and map output tracker generation.
+ */
+abstract class DAGTask[T](val runId: Int, val stageId: Int) extends Task[T] {
+ val gen = SparkEnv.get.mapOutputTracker.getGeneration
+ override def generation: Option[Long] = Some(gen)
+}
+
+/**
+ * A completion event passed by the underlying task scheduler to the DAG scheduler.
+ */
+case class CompletionEvent(
+ task: DAGTask[_],
+ reason: TaskEndReason,
+ result: Any,
+ accumUpdates: Map[Long, Any])
+
+/**
+ * Various possible reasons why a DAG task ended. The underlying scheduler is supposed to retry
+ * tasks several times for "ephemeral" failures, and only report back failures that require some
+ * old stages to be resubmitted, such as shuffle map fetch failures.
+ */
+sealed trait TaskEndReason
+case object Success extends TaskEndReason
+case class FetchFailed(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
+case class ExceptionFailure(exception: Throwable) extends TaskEndReason
+case class OtherFailure(message: String) extends TaskEndReason
+
+/**
+ * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
+ * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal
+ * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
+ * and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
+ */
+private trait DAGScheduler extends Scheduler with Logging {
+ // Must be implemented by subclasses to start running a set of tasks. The subclass should also
+ // attempt to run different sets of tasks in the order given by runId (lower values first).
+ def submitTasks(tasks: Seq[Task[_]], runId: Int): Unit
+
+ // Must be called by subclasses to report task completions or failures.
+ def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]) {
+ lock.synchronized {
+ val dagTask = task.asInstanceOf[DAGTask[_]]
+ eventQueues.get(dagTask.runId) match {
+ case Some(queue) =>
+ queue += CompletionEvent(dagTask, reason, result, accumUpdates)
+ lock.notifyAll()
+ case None =>
+ logInfo("Ignoring completion event for DAG job " + dagTask.runId + " because it's gone")
+ }
+ }
+ }
+
+ // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
+ // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
+ // as more failure events come in
+ val RESUBMIT_TIMEOUT = 2000L
+
+ // The time, in millis, to wake up between polls of the completion queue in order to potentially
+ // resubmit failed stages
+ val POLL_TIMEOUT = 500L
+
+ private val lock = new Object // Used for access to the entire DAGScheduler
+
+ private val eventQueues = new HashMap[Int, Queue[CompletionEvent]] // Indexed by run ID
+
+ val nextRunId = new AtomicInteger(0)
+
+ val nextStageId = new AtomicInteger(0)
+
+ val idToStage = new HashMap[Int, Stage]
+
+ val shuffleToMapStage = new HashMap[Int, Stage]
+
+ var cacheLocs = new HashMap[Int, Array[List[String]]]
+
+ val env = SparkEnv.get
+ val cacheTracker = env.cacheTracker
+ val mapOutputTracker = env.mapOutputTracker
+
+ def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
+ cacheLocs(rdd.id)
+ }
+
+ def updateCacheLocs() {
+ cacheLocs = cacheTracker.getLocationsSnapshot()
+ }
+
+ def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {
+ shuffleToMapStage.get(shuf.shuffleId) match {
+ case Some(stage) => stage
+ case None =>
+ val stage = newStage(shuf.rdd, Some(shuf))
+ shuffleToMapStage(shuf.shuffleId) = stage
+ stage
+ }
+ }
+
+ def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
+ // Kind of ugly: need to register RDDs with the cache and map output tracker here
+ // since we can't do it in the RDD constructor because # of splits is unknown
+ cacheTracker.registerRDD(rdd.id, rdd.splits.size)
+ if (shuffleDep != None) {
+ mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
+ }
+ val id = nextStageId.getAndIncrement()
+ val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
+ idToStage(id) = stage
+ stage
+ }
+
+ def getParentStages(rdd: RDD[_]): List[Stage] = {
+ val parents = new HashSet[Stage]
+ val visited = new HashSet[RDD[_]]
+ def visit(r: RDD[_]) {
+ if (!visited(r)) {
+ visited += r
+ // Kind of ugly: need to register RDDs with the cache here since
+ // we can't do it in its constructor because # of splits is unknown
+ cacheTracker.registerRDD(r.id, r.splits.size)
+ for (dep <- r.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_,_,_] =>
+ parents += getShuffleMapStage(shufDep)
+ case _ =>
+ visit(dep.rdd)
+ }
+ }
+ }
+ }
+ visit(rdd)
+ parents.toList
+ }
+
+ def getMissingParentStages(stage: Stage): List[Stage] = {
+ val missing = new HashSet[Stage]
+ val visited = new HashSet[RDD[_]]
+ def visit(rdd: RDD[_]) {
+ if (!visited(rdd)) {
+ visited += rdd
+ val locs = getCacheLocs(rdd)
+ for (p <- 0 until rdd.splits.size) {
+ if (locs(p) == Nil) {
+ for (dep <- rdd.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_,_,_] =>
+ val stage = getShuffleMapStage(shufDep)
+ if (!stage.isAvailable) {
+ missing += stage
+ }
+ case narrowDep: NarrowDependency[_] =>
+ visit(narrowDep.rdd)
+ }
+ }
+ }
+ }
+ }
+ }
+ visit(stage.rdd)
+ missing.toList
+ }
+
+ override def runJob[T, U](
+ finalRdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ allowLocal: Boolean)
+ (implicit m: ClassManifest[U]): Array[U] = {
+ lock.synchronized {
+ val runId = nextRunId.getAndIncrement()
+
+ val outputParts = partitions.toArray
+ val numOutputParts: Int = partitions.size
+ val finalStage = newStage(finalRdd, None)
+ val results = new Array[U](numOutputParts)
+ val finished = new Array[Boolean](numOutputParts)
+ var numFinished = 0
+
+ val waiting = new HashSet[Stage] // stages we need to run whose parents aren't done
+ val running = new HashSet[Stage] // stages we are running right now
+ val failed = new HashSet[Stage] // stages that must be resubmitted due to fetch failures
+ val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // missing tasks from each stage
+ var lastFetchFailureTime: Long = 0 // used to wait a bit to avoid repeated resubmits
+
+ SparkEnv.set(env)
+
+ updateCacheLocs()
+
+ logInfo("Final stage: " + finalStage)
+ logInfo("Parents of final stage: " + finalStage.parents)
+ logInfo("Missing parents: " + getMissingParentStages(finalStage))
+
+ // Optimization for short actions like first() and take() that can be computed locally
+ // without shipping tasks to the cluster.
+ if (allowLocal && finalStage.parents.size == 0 && numOutputParts == 1) {
+ logInfo("Computing the requested partition locally")
+ val split = finalRdd.splits(outputParts(0))
+ val taskContext = new TaskContext(finalStage.id, outputParts(0), 0)
+ return Array(func(taskContext, finalRdd.iterator(split)))
+ }
+
+ // Register the job ID so that we can get completion events for it
+ eventQueues(runId) = new Queue[CompletionEvent]
+
+ def submitStage(stage: Stage) {
+ if (!waiting(stage) && !running(stage)) {
+ val missing = getMissingParentStages(stage)
+ if (missing == Nil) {
+ logInfo("Submitting " + stage + ", which has no missing parents")
+ submitMissingTasks(stage)
+ running += stage
+ } else {
+ for (parent <- missing) {
+ submitStage(parent)
+ }
+ waiting += stage
+ }
+ }
+ }
+
+ def submitMissingTasks(stage: Stage) {
+ // Get our pending tasks and remember them in our pendingTasks entry
+ val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
+ var tasks = ArrayBuffer[Task[_]]()
+ if (stage == finalStage) {
+ for (id <- 0 until numOutputParts if (!finished(id))) {
+ val part = outputParts(id)
+ val locs = getPreferredLocs(finalRdd, part)
+ tasks += new ResultTask(runId, finalStage.id, finalRdd, func, part, locs, id)
+ }
+ } else {
+ for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
+ val locs = getPreferredLocs(stage.rdd, p)
+ tasks += new ShuffleMapTask(runId, stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
+ }
+ }
+ myPending ++= tasks
+ submitTasks(tasks, runId)
+ }
+
+ submitStage(finalStage)
+
+ while (numFinished != numOutputParts) {
+ val eventOption = waitForEvent(runId, POLL_TIMEOUT)
+ val time = System.currentTimeMillis // TODO: use a pluggable clock for testability
+
+ // If we got an event off the queue, mark the task done or react to a fetch failure
+ if (eventOption != None) {
+ val evt = eventOption.get
+ val stage = idToStage(evt.task.stageId)
+ pendingTasks(stage) -= evt.task
+ if (evt.reason == Success) {
+ // A task ended
+ logInfo("Completed " + evt.task)
+ Accumulators.add(evt.accumUpdates)
+ evt.task match {
+ case rt: ResultTask[_, _] =>
+ results(rt.outputId) = evt.result.asInstanceOf[U]
+ finished(rt.outputId) = true
+ numFinished += 1
+ case smt: ShuffleMapTask =>
+ val stage = idToStage(smt.stageId)
+ stage.addOutputLoc(smt.partition, evt.result.asInstanceOf[String])
+ if (running.contains(stage) && pendingTasks(stage).isEmpty) {
+ logInfo(stage + " finished; looking for newly runnable stages")
+ running -= stage
+ if (stage.shuffleDep != None) {
+ mapOutputTracker.registerMapOutputs(
+ stage.shuffleDep.get.shuffleId,
+ stage.outputLocs.map(_.head).toArray)
+ }
+ updateCacheLocs()
+ val newlyRunnable = new ArrayBuffer[Stage]
+ for (stage <- waiting if getMissingParentStages(stage) == Nil) {
+ newlyRunnable += stage
+ }
+ waiting --= newlyRunnable
+ running ++= newlyRunnable
+ for (stage <- newlyRunnable) {
+ submitMissingTasks(stage)
+ }
+ }
+ }
+ } else {
+ evt.reason match {
+ case FetchFailed(serverUri, shuffleId, mapId, reduceId) =>
+ // Mark the stage that the reducer was in as unrunnable
+ val failedStage = idToStage(evt.task.stageId)
+ running -= failedStage
+ failed += failedStage
+ // TODO: Cancel running tasks in the stage
+ logInfo("Marking " + failedStage + " for resubmision due to a fetch failure")
+ // Mark the map whose fetch failed as broken in the map stage
+ val mapStage = shuffleToMapStage(shuffleId)
+ mapStage.removeOutputLoc(mapId, serverUri)
+ mapOutputTracker.unregisterMapOutput(shuffleId, mapId, serverUri)
+ logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
+ failed += mapStage
+ // Remember that a fetch failed now; this is used to resubmit the broken
+ // stages later, after a small wait (to give other tasks the chance to fail)
+ lastFetchFailureTime = time
+ // TODO: If there are a lot of fetch failures on the same node, maybe mark all
+ // outputs on the node as dead.
+ case _ =>
+ // Non-fetch failure -- probably a bug in the job, so bail out
+ throw new SparkException("Task failed: " + evt.task + ", reason: " + evt.reason)
+ // TODO: Cancel all tasks that are still running
+ }
+ }
+ } // end if (evt != null)
+
+ // If fetches have failed recently and we've waited for the right timeout,
+ // resubmit all the failed stages
+ if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
+ logInfo("Resubmitting failed stages")
+ updateCacheLocs()
+ for (stage <- failed) {
+ submitStage(stage)
+ }
+ failed.clear()
+ }
+ }
+
+ eventQueues -= runId
+ return results
+ }
+ }
+
+ def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
+ // If the partition is cached, return the cache locations
+ val cached = getCacheLocs(rdd)(partition)
+ if (cached != Nil) {
+ return cached
+ }
+ // If the RDD has some placement preferences (as is the case for input RDDs), get those
+ val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList
+ if (rddPrefs != Nil) {
+ return rddPrefs
+ }
+ // If the RDD has narrow dependencies, pick the first partition of the first narrow dep
+ // that has any placement preferences. Ideally we would choose based on transfer sizes,
+ // but this will do for now.
+ rdd.dependencies.foreach(_ match {
+ case n: NarrowDependency[_] =>
+ for (inPart <- n.getParents(partition)) {
+ val locs = getPreferredLocs(n.rdd, inPart)
+ if (locs != Nil)
+ return locs;
+ }
+ case _ =>
+ })
+ return Nil
+ }
+
+ // Assumes that lock is held on entrance, but will release it to wait for the next event.
+ def waitForEvent(runId: Int, timeout: Long): Option[CompletionEvent] = {
+ val endTime = System.currentTimeMillis() + timeout // TODO: Use pluggable clock for testing
+ while (eventQueues(runId).isEmpty) {
+ val time = System.currentTimeMillis()
+ if (time >= endTime) {
+ return None
+ } else {
+ lock.wait(endTime - time)
+ }
+ }
+ return Some(eventQueues(runId).dequeue())
+ }
+}
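
The trait above leaves only the task-launching side abstract: a backend implements submitTasks and reports completions through taskEnded. A hedged stub of that wiring, which runs every task inline on the calling thread (it assumes the code lives in package spark, since the trait is package-private, and that the underlying Scheduler trait requires nothing further; real backends also handle retries and threading):

    // Sketch only: a trivial inline backend for the DAGScheduler trait above.
    private class InlineScheduler extends DAGScheduler {
      override def submitTasks(tasks: Seq[Task[_]], runId: Int) {
        for (task <- tasks) {
          val result = task.run(0)                    // attempt id 0; no failure handling here
          taskEnded(task, Success, result, Accumulators.values)
        }
      }
    }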
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index c0ff94acc6..d93c84924a 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -8,7 +8,7 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd, false) {
class ShuffleDependency[K, V, C](
val shuffleId: Int,
- @transient rdd: RDD[(K, V)],
+ rdd: RDD[(K, V)],
val aggregator: Aggregator[K, V, C],
val partitioner: Partitioner)
extends Dependency(rdd, true)
diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala
new file mode 100644
index 0000000000..e11466eb64
--- /dev/null
+++ b/core/src/main/scala/spark/DiskSpillingCache.scala
@@ -0,0 +1,75 @@
+package spark
+
+import java.io.File
+import java.io.{FileOutputStream,FileInputStream}
+import java.io.IOException
+import java.util.LinkedHashMap
+import java.util.UUID
+
+// TODO: cache into a separate directory using Utils.createTempDir
+// TODO: clean up disk cache afterwards
+class DiskSpillingCache extends BoundedMemoryCache {
+ private val diskMap = new LinkedHashMap[(Any, Int), File](32, 0.75f, true)
+
+ override def get(datasetId: Any, partition: Int): Any = {
+ synchronized {
+ val ser = SparkEnv.get.serializer.newInstance()
+ super.get(datasetId, partition) match {
+ case bytes: Any => // found in memory
+ ser.deserialize(bytes.asInstanceOf[Array[Byte]])
+
+ case _ => diskMap.get((datasetId, partition)) match {
+ case file: Any => // found on disk
+ try {
+ val startTime = System.currentTimeMillis
+ val bytes = new Array[Byte](file.length.toInt)
+ new FileInputStream(file).read(bytes)
+ val timeTaken = System.currentTimeMillis - startTime
+ logInfo("Reading key (%s, %d) of size %d bytes from disk took %d ms".format(
+ datasetId, partition, file.length, timeTaken))
+ super.put(datasetId, partition, bytes)
+ ser.deserialize(bytes.asInstanceOf[Array[Byte]])
+ } catch {
+ case e: IOException =>
+ logWarning("Failed to read key (%s, %d) from disk at %s: %s".format(
+ datasetId, partition, file.getPath(), e.getMessage()))
+ diskMap.remove((datasetId, partition)) // remove dead entry
+ null
+ }
+
+ case _ => // not found
+ null
+ }
+ }
+ }
+ }
+
+ override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
+ var ser = SparkEnv.get.serializer.newInstance()
+ super.put(datasetId, partition, ser.serialize(value))
+ }
+
+ /**
+ * Spill the given entry to disk. Assumes that a lock is held on the
+ * DiskSpillingCache. Assumes that entry.value is a byte array.
+ */
+ override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
+ logInfo("Spilling key (%s, %d) of size %d to make space".format(
+ datasetId, partition, entry.size))
+ val cacheDir = System.getProperty(
+ "spark.diskSpillingCache.cacheDir",
+ System.getProperty("java.io.tmpdir"))
+ val file = new File(cacheDir, "spark-dsc-" + UUID.randomUUID.toString)
+ try {
+ val stream = new FileOutputStream(file)
+ stream.write(entry.value.asInstanceOf[Array[Byte]])
+ stream.close()
+ diskMap.put((datasetId, partition), file)
+ } catch {
+ case e: IOException =>
+ logWarning("Failed to spill key (%s, %d) to disk at %s: %s".format(
+ datasetId, partition, file.getPath(), e.getMessage()))
+ // Do nothing and let the entry be discarded
+ }
+ }
+}
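
The spill location comes from the spark.diskSpillingCache.cacheDir system property, falling back to java.io.tmpdir. A short usage sketch (the directory is an example; the no-argument constructor matches the class header above):

    // Point the spilling cache at a dedicated scratch directory before constructing it.
    System.setProperty("spark.diskSpillingCache.cacheDir", "/mnt/spark-scratch") // example path
    val cache = new DiskSpillingCache()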
diff --git a/core/src/main/scala/spark/DoubleRDDFunctions.scala b/core/src/main/scala/spark/DoubleRDDFunctions.scala
deleted file mode 100644
index 1fbf66b7de..0000000000
--- a/core/src/main/scala/spark/DoubleRDDFunctions.scala
+++ /dev/null
@@ -1,39 +0,0 @@
-package spark
-
-import spark.partial.BoundedDouble
-import spark.partial.MeanEvaluator
-import spark.partial.PartialResult
-import spark.partial.SumEvaluator
-
-import spark.util.StatCounter
-
-/**
- * Extra functions available on RDDs of Doubles through an implicit conversion.
- */
-class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
- def sum(): Double = {
- self.reduce(_ + _)
- }
-
- def stats(): StatCounter = {
- self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
- }
-
- def mean(): Double = stats().mean
-
- def variance(): Double = stats().variance
-
- def stdev(): Double = stats().stdev
-
- def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
- val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
- val evaluator = new MeanEvaluator(self.splits.size, confidence)
- self.context.runApproximateJob(self, processPartition, evaluator, timeout)
- }
-
- def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
- val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
- val evaluator = new SumEvaluator(self.splits.size, confidence)
- self.context.runApproximateJob(self, processPartition, evaluator, timeout)
- }
-}
diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala
index af9eb9c878..c795b6c351 100644
--- a/core/src/main/scala/spark/Executor.scala
+++ b/core/src/main/scala/spark/Executor.scala
@@ -10,10 +10,9 @@ import scala.collection.mutable.ArrayBuffer
import com.google.protobuf.ByteString
import org.apache.mesos._
-import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
+import org.apache.mesos.Protos._
import spark.broadcast._
-import spark.scheduler._
/**
* The Mesos executor for Spark.
@@ -30,9 +29,6 @@ class Executor extends org.apache.mesos.Executor with Logging {
executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) {
- // Make sure the local hostname we report matches Mesos's name for this host
- Utils.setCustomHostname(slaveInfo.getHostname())
-
// Read spark.* system properties from executor arg
val props = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
for ((key, value) <- props) {
@@ -43,7 +39,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
RemoteActor.classLoader = getClass.getClassLoader
// Initialize Spark environment (using system properties read above)
- env = SparkEnv.createFromSystemProperties(false, false)
+ env = SparkEnv.createFromSystemProperties(false)
SparkEnv.set(env)
// Old stuff that isn't yet using env
Broadcast.initialize(false)
@@ -61,11 +57,11 @@ class Executor extends org.apache.mesos.Executor with Logging {
override def reregistered(d: ExecutorDriver, s: SlaveInfo) {}
- override def launchTask(d: ExecutorDriver, task: MTaskInfo) {
+ override def launchTask(d: ExecutorDriver, task: TaskInfo) {
threadPool.execute(new TaskRunner(task, d))
}
- class TaskRunner(info: MTaskInfo, d: ExecutorDriver)
+ class TaskRunner(info: TaskInfo, d: ExecutorDriver)
extends Runnable {
override def run() = {
val tid = info.getTaskId.getValue
@@ -78,11 +74,11 @@ class Executor extends org.apache.mesos.Executor with Logging {
.setState(TaskState.TASK_RUNNING)
.build())
try {
- SparkEnv.set(env)
- Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear
- val task = ser.deserialize[Task[Any]](info.getData.asReadOnlyByteBuffer, classLoader)
- env.mapOutputTracker.updateGeneration(task.generation)
+ val task = ser.deserialize[Task[Any]](info.getData.toByteArray, classLoader)
+ for (gen <- task.generation) {// Update generation if any is set
+ env.mapOutputTracker.updateGeneration(gen)
+ }
val value = task.run(tid.toInt)
val accumUpdates = Accumulators.values
val result = new TaskResult(value, accumUpdates)
@@ -109,11 +105,9 @@ class Executor extends org.apache.mesos.Executor with Logging {
.setData(ByteString.copyFrom(ser.serialize(reason)))
.build())
- // TODO: Should we exit the whole executor here? On the one hand, the failed task may
- // have left some weird state around depending on when the exception was thrown, but on
- // the other hand, maybe we could detect that when future tasks fail and exit then.
+ // TODO: Handle errors in tasks less dramatically
logError("Exception in task ID " + tid, t)
- //System.exit(1)
+ System.exit(1)
}
}
}
diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala
index 55512f4481..a3c4e7873d 100644
--- a/core/src/main/scala/spark/FetchFailedException.scala
+++ b/core/src/main/scala/spark/FetchFailedException.scala
@@ -1,9 +1,7 @@
package spark
-import spark.storage.BlockManagerId
-
class FetchFailedException(
- val bmAddress: BlockManagerId,
+ val serverUri: String,
val shuffleId: Int,
val mapId: Int,
val reduceId: Int,
@@ -11,10 +9,10 @@ class FetchFailedException(
extends Exception {
override def getMessage(): String =
- "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
+ "Fetch failed: %s %d %d %d".format(serverUri, shuffleId, mapId, reduceId)
override def getCause(): Throwable = cause
def toTaskEndReason: TaskEndReason =
- FetchFailed(bmAddress, shuffleId, mapId, reduceId)
+ FetchFailed(serverUri, shuffleId, mapId, reduceId)
}
diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala
index ec5c33d1df..80f615eeb0 100644
--- a/core/src/main/scala/spark/JavaSerializer.scala
+++ b/core/src/main/scala/spark/JavaSerializer.scala
@@ -1,7 +1,6 @@
package spark
import java.io._
-import java.nio.ByteBuffer
class JavaSerializationStream(out: OutputStream) extends SerializationStream {
val objOut = new ObjectOutputStream(out)
@@ -10,11 +9,10 @@ class JavaSerializationStream(out: OutputStream) extends SerializationStream {
def close() { objOut.close() }
}
-class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
-extends DeserializationStream {
+class JavaDeserializationStream(in: InputStream) extends DeserializationStream {
val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) =
- Class.forName(desc.getName, false, loader)
+ Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
}
def readObject[T](): T = objIn.readObject().asInstanceOf[T]
@@ -22,36 +20,35 @@ extends DeserializationStream {
}
class JavaSerializerInstance extends SerializerInstance {
- def serialize[T](t: T): ByteBuffer = {
+ def serialize[T](t: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
- val out = serializeStream(bos)
+ val out = outputStream(bos)
out.writeObject(t)
out.close()
- ByteBuffer.wrap(bos.toByteArray)
+ bos.toByteArray
}
- def deserialize[T](bytes: ByteBuffer): T = {
- val bis = new ByteArrayInputStream(bytes.array())
- val in = deserializeStream(bis)
+ def deserialize[T](bytes: Array[Byte]): T = {
+ val bis = new ByteArrayInputStream(bytes)
+ val in = inputStream(bis)
in.readObject().asInstanceOf[T]
}
- def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
- val bis = new ByteArrayInputStream(bytes.array())
- val in = deserializeStream(bis, loader)
- in.readObject().asInstanceOf[T]
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+ val bis = new ByteArrayInputStream(bytes)
+ val ois = new ObjectInputStream(bis) {
+ override def resolveClass(desc: ObjectStreamClass) =
+ Class.forName(desc.getName, false, loader)
+ }
+ return ois.readObject.asInstanceOf[T]
}
- def serializeStream(s: OutputStream): SerializationStream = {
+ def outputStream(s: OutputStream): SerializationStream = {
new JavaSerializationStream(s)
}
- def deserializeStream(s: InputStream): DeserializationStream = {
- new JavaDeserializationStream(s, currentThread.getContextClassLoader)
- }
-
- def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
- new JavaDeserializationStream(s, loader)
+ def inputStream(s: InputStream): DeserializationStream = {
+ new JavaDeserializationStream(s)
}
}
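
With the reverted interface, a serializer instance round-trips plain byte arrays. A small usage sketch, assuming JavaSerializer exposes a newInstance() factory as Serializer implementations do elsewhere in this diff (e.g. SparkEnv.get.serializer.newInstance()):

    // Round-trip an object through the Java serializer (sketch).
    val ser = new JavaSerializer().newInstance()
    val bytes: Array[Byte] = ser.serialize(List(1, 2, 3))
    val back = ser.deserialize[List[Int]](bytes)
    assert(back == List(1, 2, 3))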
diff --git a/core/src/main/scala/spark/Job.scala b/core/src/main/scala/spark/Job.scala
new file mode 100644
index 0000000000..b7b0361c62
--- /dev/null
+++ b/core/src/main/scala/spark/Job.scala
@@ -0,0 +1,16 @@
+package spark
+
+import org.apache.mesos._
+import org.apache.mesos.Protos._
+
+/**
+ * Class representing a parallel job in MesosScheduler. Schedules the job by implementing various
+ * callbacks.
+ */
+abstract class Job(val runId: Int, val jobId: Int) {
+ def slaveOffer(s: Offer, availableCpus: Double): Option[TaskInfo]
+
+ def statusUpdate(t: TaskStatus): Unit
+
+ def error(message: String): Unit
+}
diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 65d0532bd5..5693613d6d 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -12,8 +12,6 @@ import com.esotericsoftware.kryo.{Serializer => KSerializer}
import com.esotericsoftware.kryo.serialize.ClassSerializer
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
-import spark.storage._
-
/**
* Zig-zag encoder used to write object sizes to serialization streams.
* Based on Kryo's integer encoder.
@@ -66,90 +64,57 @@ object ZigZag {
}
}
-class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream)
+class KryoSerializationStream(kryo: Kryo, buf: ByteBuffer, out: OutputStream)
extends SerializationStream {
val channel = Channels.newChannel(out)
def writeObject[T](t: T) {
- kryo.writeClassAndObject(threadBuffer, t)
- ZigZag.writeInt(threadBuffer.position(), out)
- threadBuffer.flip()
- channel.write(threadBuffer)
- threadBuffer.clear()
+ kryo.writeClassAndObject(buf, t)
+ ZigZag.writeInt(buf.position(), out)
+ buf.flip()
+ channel.write(buf)
+ buf.clear()
}
def flush() { out.flush() }
def close() { out.close() }
}
-class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream)
+class KryoDeserializationStream(buf: ObjectBuffer, in: InputStream)
extends DeserializationStream {
def readObject[T](): T = {
val len = ZigZag.readInt(in)
- objectBuffer.readClassAndObject(in, len).asInstanceOf[T]
+ buf.readClassAndObject(in, len).asInstanceOf[T]
}
def close() { in.close() }
}
class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
- val kryo = ks.kryo
- val threadBuffer = ks.threadBuffer.get()
- val objectBuffer = ks.objectBuffer.get()
+ val buf = ks.threadBuf.get()
- def serialize[T](t: T): ByteBuffer = {
- // Write it to our thread-local scratch buffer first to figure out the size, then return a new
- // ByteBuffer of the appropriate size
- threadBuffer.clear()
- kryo.writeClassAndObject(threadBuffer, t)
- val newBuf = ByteBuffer.allocate(threadBuffer.position)
- threadBuffer.flip()
- newBuf.put(threadBuffer)
- newBuf.flip()
- newBuf
+ def serialize[T](t: T): Array[Byte] = {
+ buf.writeClassAndObject(t)
}
- def deserialize[T](bytes: ByteBuffer): T = {
- kryo.readClassAndObject(bytes).asInstanceOf[T]
+ def deserialize[T](bytes: Array[Byte]): T = {
+ buf.readClassAndObject(bytes).asInstanceOf[T]
}
- def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
- val oldClassLoader = kryo.getClassLoader
- kryo.setClassLoader(loader)
- val obj = kryo.readClassAndObject(bytes).asInstanceOf[T]
- kryo.setClassLoader(oldClassLoader)
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+ val oldClassLoader = ks.kryo.getClassLoader
+ ks.kryo.setClassLoader(loader)
+ val obj = buf.readClassAndObject(bytes).asInstanceOf[T]
+ ks.kryo.setClassLoader(oldClassLoader)
obj
}
- def serializeStream(s: OutputStream): SerializationStream = {
- threadBuffer.clear()
- new KryoSerializationStream(kryo, threadBuffer, s)
- }
-
- def deserializeStream(s: InputStream): DeserializationStream = {
- new KryoDeserializationStream(objectBuffer, s)
+ def outputStream(s: OutputStream): SerializationStream = {
+ new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s)
}
- override def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
- threadBuffer.clear()
- while (iterator.hasNext) {
- val element = iterator.next()
- // TODO: Do we also want to write the object's size? Doesn't seem necessary.
- kryo.writeClassAndObject(threadBuffer, element)
- }
- val newBuf = ByteBuffer.allocate(threadBuffer.position)
- threadBuffer.flip()
- newBuf.put(threadBuffer)
- newBuf.flip()
- newBuf
- }
-
- override def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
- buffer.rewind()
- new Iterator[Any] {
- override def hasNext: Boolean = buffer.remaining > 0
- override def next(): Any = kryo.readClassAndObject(buffer)
- }
+ def inputStream(s: InputStream): DeserializationStream = {
+ new KryoDeserializationStream(buf, s)
}
}
@@ -161,17 +126,20 @@ trait KryoRegistrator {
class KryoSerializer extends Serializer with Logging {
val kryo = createKryo()
- val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024
+ val bufferSize =
+ System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
- val objectBuffer = new ThreadLocal[ObjectBuffer] {
+ val threadBuf = new ThreadLocal[ObjectBuffer] {
override def initialValue = new ObjectBuffer(kryo, bufferSize)
}
- val threadBuffer = new ThreadLocal[ByteBuffer] {
+ val threadByteBuf = new ThreadLocal[ByteBuffer] {
override def initialValue = ByteBuffer.allocate(bufferSize)
}
def createKryo(): Kryo = {
+ // This is used so we can serialize/deserialize objects without a zero-arg
+ // constructor.
val kryo = new KryoReflectionFactorySupport()
// Register some commonly used classes
@@ -180,20 +148,14 @@ class KryoSerializer extends Serializer with Logging {
Array(1), Array(1.0), Array(1.0f), Array(1L), Array(""), Array(("", "")),
Array(new java.lang.Object), Array(1.toByte), Array(true), Array('c'),
// Specialized Tuple2s
- ("", ""), ("", 1), (1, 1), (1.0, 1.0), (1L, 1L),
+ ("", ""), (1, 1), (1.0, 1.0), (1L, 1L),
(1, 1.0), (1.0, 1), (1L, 1.0), (1.0, 1L), (1, 1L), (1L, 1),
// Scala collections
List(1), mutable.ArrayBuffer(1),
// Options and Either
Some(1), Left(1), Right(1),
// Higher-dimensional tuples
- (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1),
- None,
- ByteBuffer.allocate(1),
- StorageLevel.MEMORY_ONLY_DESER,
- PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY_DESER),
- GotBlock("1", ByteBuffer.allocate(1)),
- GetBlock("1")
+ (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1)
)
for (obj <- toRegister) {
kryo.register(obj.getClass)
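
The file above keeps the ZigZag encoder that writes each object's size before the object itself, so the deserializer knows how many bytes to read. As a self-contained illustration of the general zig-zag/varint idea (this is the textbook scheme, not a copy of the ZigZag object, whose body is outside this hunk):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, OutputStream}

    object ZigZagSketch {
      // Encode a signed Int as a zig-zag varint: small magnitudes use few bytes.
      def writeInt(n: Int, out: OutputStream) {
        var v = (n << 1) ^ (n >> 31)     // zig-zag: fold the sign bit into the low bit
        while ((v & 0xFFFFFF80) != 0) {
          out.write((v & 0x7F) | 0x80)   // 7 payload bits plus a continuation bit
          v >>>= 7
        }
        out.write(v)
      }

      def readInt(in: InputStream): Int = {
        var v = 0
        var shift = 0
        var b = in.read()
        while ((b & 0x80) != 0) {
          v |= (b & 0x7F) << shift
          shift += 7
          b = in.read()
        }
        v |= b << shift
        (v >>> 1) ^ -(v & 1)             // undo the zig-zag mapping
      }

      def main(args: Array[String]) {
        val bos = new ByteArrayOutputStream()
        Seq(0, 1, -1, 300, -300, Int.MaxValue).foreach(writeInt(_, bos))
        val in = new ByteArrayInputStream(bos.toByteArray)
        println((1 to 6).map(_ => readInt(in)))
        // Vector(0, 1, -1, 300, -300, 2147483647)
      }
    }
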
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/LocalScheduler.scala
index 8339c0ae90..3910c7b09e 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/LocalScheduler.scala
@@ -1,21 +1,16 @@
-package spark.scheduler.local
+package spark
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger
-import spark._
-import spark.scheduler._
-
/**
- * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
- * the scheduler also allows each task to fail up to maxFailures times, which is useful for
- * testing fault recovery.
+ * A simple Scheduler implementation that runs tasks locally in a thread pool. Optionally the
+ * scheduler also allows each task to fail up to maxFailures times, which is useful for testing
+ * fault recovery.
*/
-class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with Logging {
+private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGScheduler with Logging {
var attemptId = new AtomicInteger(0)
var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
- val env = SparkEnv.get
- var listener: TaskSchedulerListener = null
// TODO: Need to take into account stage priority in scheduling
@@ -23,12 +18,7 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
override def waitForRegister() {}
- override def setListener(listener: TaskSchedulerListener) {
- this.listener = listener
- }
-
- override def submitTasks(taskSet: TaskSet) {
- val tasks = taskSet.tasks
+ override def submitTasks(tasks: Seq[Task[_]], runId: Int) {
val failCount = new Array[Int](tasks.size)
def submitTask(task: Task[_], idInJob: Int) {
@@ -48,14 +38,23 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
Accumulators.clear
- val bytes = Utils.serialize(task)
- logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes")
- val deserializedTask = Utils.deserialize[Task[_]](
- bytes, Thread.currentThread.getContextClassLoader)
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val startTime = System.currentTimeMillis
+ val bytes = ser.serialize(task)
+ val timeTaken = System.currentTimeMillis - startTime
+ logInfo("Size of task %d is %d bytes and took %d ms to serialize".format(
+ idInJob, bytes.size, timeTaken))
+ val deserializedTask = ser.deserialize[Task[_]](bytes, currentThread.getContextClassLoader)
val result: Any = deserializedTask.run(attemptId)
+
+ // Serialize and deserialize the result to emulate what the mesos
+ // executor does. This is useful to catch serialization errors early
+ // on in development (so when users move their local Spark programs
+ // to the cluster, they don't get surprised by serialization errors).
+ val resultToReturn = ser.deserialize[Any](ser.serialize(result))
val accumUpdates = Accumulators.values
logInfo("Finished task " + idInJob)
- listener.taskEnded(task, Success, result, accumUpdates)
+ taskEnded(task, Success, resultToReturn, accumUpdates)
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
@@ -65,7 +64,7 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
submitTask(task, idInJob)
} else {
// TODO: Do something nicer here to return all the way to the user
- listener.taskEnded(task, new ExceptionFailure(t), null, null)
+ taskEnded(task, new ExceptionFailure(t), null, null)
}
}
}
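
The notable change above is that even the local scheduler now serializes the task and its result and immediately deserializes them, so serialization bugs surface during local development rather than on a cluster. The round trip, sketched with plain Java serialization (the real code goes through SparkEnv.get.closureSerializer, which is configurable and not shown here):

    import java.io._

    object RoundTripSketch {
      // Serialize a value and read it straight back, as the local scheduler does
      // for tasks and results.
      def roundTrip[T](value: T): T = {
        val bos = new ByteArrayOutputStream()
        val oos = new ObjectOutputStream(bos)
        oos.writeObject(value)
        oos.close()
        val ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
        ois.readObject().asInstanceOf[T]
      }

      def main(args: Array[String]) {
        println(roundTrip(Map("a" -> 1, "b" -> 2)))  // serializable: prints the map
        // roundTrip(new Object)                     // would throw NotSerializableException
      }
    }
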
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
index 54bd57f6d3..0d11ab9cbd 100644
--- a/core/src/main/scala/spark/Logging.scala
+++ b/core/src/main/scala/spark/Logging.scala
@@ -28,11 +28,9 @@ trait Logging {
}
// Log methods that take only a String
- def logInfo(msg: => String) = if (log.isInfoEnabled /*&& msg.contains("job finished in")*/) log.info(msg)
+ def logInfo(msg: => String) = if (log.isInfoEnabled) log.info(msg)
def logDebug(msg: => String) = if (log.isDebugEnabled) log.debug(msg)
-
- def logTrace(msg: => String) = if (log.isTraceEnabled) log.trace(msg)
def logWarning(msg: => String) = if (log.isWarnEnabled) log.warn(msg)
@@ -45,9 +43,6 @@ trait Logging {
def logDebug(msg: => String, throwable: Throwable) =
if (log.isDebugEnabled) log.debug(msg)
- def logTrace(msg: => String, throwable: Throwable) =
- if (log.isTraceEnabled) log.trace(msg)
-
def logWarning(msg: => String, throwable: Throwable) =
if (log.isWarnEnabled) log.warn(msg, throwable)
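
Worth noting in the signatures above: the message parameter is by-name (msg: => String), so the string is only built when the corresponding log level is enabled. A tiny sketch of the same trick, under an assumed debugEnabled flag:

    object LazyLogSketch {
      var debugEnabled = false

      // `msg` is evaluated only if the level check passes.
      def logDebug(msg: => String) { if (debugEnabled) println(msg) }

      def main(args: Array[String]) {
        logDebug { println("building expensive message"); "sum = " + (1 to 1000000).sum }
        // Nothing is printed: the block above never ran, because debugEnabled is false.
      }
    }
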
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index d938a6eb62..a934c5a02f 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -2,80 +2,80 @@ package spark
import java.util.concurrent.ConcurrentHashMap
-import akka.actor._
-import akka.actor.Actor
-import akka.actor.Actor._
-import akka.util.duration._
-
+import scala.actors._
+import scala.actors.Actor._
+import scala.actors.remote._
import scala.collection.mutable.HashSet
-import spark.storage.BlockManagerId
-
sealed trait MapOutputTrackerMessage
case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
case object StopMapOutputTracker extends MapOutputTrackerMessage
-class MapOutputTrackerActor(bmAddresses: ConcurrentHashMap[Int, Array[BlockManagerId]])
-extends Actor with Logging {
- def receive = {
- case GetMapOutputLocations(shuffleId: Int) =>
- logInfo("Asked to get map output locations for shuffle " + shuffleId)
- self.reply(bmAddresses.get(shuffleId))
-
- case StopMapOutputTracker =>
- logInfo("MapOutputTrackerActor stopped!")
- self.reply(true)
- self.exit()
+class MapOutputTrackerActor(serverUris: ConcurrentHashMap[Int, Array[String]])
+extends DaemonActor with Logging {
+ def act() {
+ val port = System.getProperty("spark.master.port").toInt
+ RemoteActor.alive(port)
+ RemoteActor.register('MapOutputTracker, self)
+ logInfo("Registered actor on port " + port)
+
+ loop {
+ react {
+ case GetMapOutputLocations(shuffleId: Int) =>
+ logInfo("Asked to get map output locations for shuffle " + shuffleId)
+ reply(serverUris.get(shuffleId))
+
+ case StopMapOutputTracker =>
+ reply('OK)
+ exit()
+ }
+ }
}
}
class MapOutputTracker(isMaster: Boolean) extends Logging {
- val ip: String = System.getProperty("spark.master.host", "localhost")
- val port: Int = System.getProperty("spark.master.port", "7077").toInt
- val aName: String = "MapOutputTracker"
+ var trackerActor: AbstractActor = null
- private var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
+ private var serverUris = new ConcurrentHashMap[Int, Array[String]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
private var generation: Long = 0
private var generationLock = new java.lang.Object
-
- var trackerActor: ActorRef = if (isMaster) {
- val actor = actorOf(new MapOutputTrackerActor(bmAddresses))
- remote.register(aName, actor)
- logInfo("Registered MapOutputTrackerActor actor @ " + ip + ":" + port)
- actor
+
+ if (isMaster) {
+ val tracker = new MapOutputTrackerActor(serverUris)
+ tracker.start()
+ trackerActor = tracker
} else {
- remote.actorFor(aName, ip, port)
+ val host = System.getProperty("spark.master.host")
+ val port = System.getProperty("spark.master.port").toInt
+ trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (bmAddresses.get(shuffleId) != null) {
+ if (serverUris.get(shuffleId) != null) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
- bmAddresses.put(shuffleId, new Array[BlockManagerId](numMaps))
+ serverUris.put(shuffleId, new Array[String](numMaps))
}
- def registerMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var array = bmAddresses.get(shuffleId)
+ def registerMapOutput(shuffleId: Int, mapId: Int, serverUri: String) {
+ var array = serverUris.get(shuffleId)
array.synchronized {
- array(mapId) = bmAddress
+ array(mapId) = serverUri
}
}
- def registerMapOutputs(shuffleId: Int, locs: Array[BlockManagerId], changeGeneration: Boolean = false) {
- bmAddresses.put(shuffleId, Array[BlockManagerId]() ++ locs)
- if (changeGeneration) {
- incrementGeneration()
- }
+ def registerMapOutputs(shuffleId: Int, locs: Array[String]) {
+ serverUris.put(shuffleId, Array[String]() ++ locs)
}
- def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var array = bmAddresses.get(shuffleId)
+ def unregisterMapOutput(shuffleId: Int, mapId: Int, serverUri: String) {
+ var array = serverUris.get(shuffleId)
if (array != null) {
array.synchronized {
- if (array(mapId) == bmAddress) {
+ if (array(mapId) == serverUri) {
array(mapId) = null
}
}
@@ -89,10 +89,10 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
val fetching = new HashSet[Int]
// Called on possibly remote nodes to get the server URIs for a given shuffle
- def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = {
- val locs = bmAddresses.get(shuffleId)
+ def getServerUris(shuffleId: Int): Array[String] = {
+ val locs = serverUris.get(shuffleId)
if (locs == null) {
- logInfo("Don't have map outputs for shuffe " + shuffleId + ", fetching them")
+ logInfo("Don't have map outputs for " + shuffleId + ", fetching them")
fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done
@@ -103,17 +103,15 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
case _ =>
}
}
- return bmAddresses.get(shuffleId)
+ return serverUris.get(shuffleId)
} else {
fetching += shuffleId
}
}
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
- val fetched = (trackerActor ? GetMapOutputLocations(shuffleId)).as[Array[BlockManagerId]].get
-
- logInfo("Got the output locations")
- bmAddresses.put(shuffleId, fetched)
+ val fetched = (trackerActor !? GetMapOutputLocations(shuffleId)).asInstanceOf[Array[String]]
+ serverUris.put(shuffleId, fetched)
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
@@ -123,10 +121,14 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
return locs
}
}
+
+ def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String = {
+ "%s/shuffle/%s/%s/%s".format(serverUri, shuffleId, mapId, reduceId)
+ }
def stop() {
- trackerActor !! StopMapOutputTracker
- bmAddresses.clear()
+ trackerActor !? StopMapOutputTracker
+ serverUris.clear()
trackerActor = null
}
@@ -151,7 +153,7 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
- bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
+ serverUris = new ConcurrentHashMap[Int, Array[String]]
generation = newGen
}
}
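
After the revert, map outputs are addressed by plain server URI strings rather than BlockManagerIds, and reducers build fetch URLs with getMapOutputUri. A standalone copy of that helper with a made-up host and port:

    object MapOutputUriSketch {
      def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String =
        "%s/shuffle/%s/%s/%s".format(serverUri, shuffleId, mapId, reduceId)

      def main(args: Array[String]) {
        // "http://worker1:33000" is an illustrative value, not a real default.
        println(getMapOutputUri("http://worker1:33000", 0, 3, 7))
        // prints: http://worker1:33000/shuffle/0/3/7
      }
    }
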
diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala b/core/src/main/scala/spark/MesosScheduler.scala
index f72618c03f..a7711e0d35 100644
--- a/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala
+++ b/core/src/main/scala/spark/MesosScheduler.scala
@@ -1,4 +1,4 @@
-package spark.scheduler.mesos
+package spark
import java.io.{File, FileInputStream, FileOutputStream}
import java.util.{ArrayList => JArrayList}
@@ -17,23 +17,20 @@ import com.google.protobuf.ByteString
import org.apache.mesos.{Scheduler => MScheduler}
import org.apache.mesos._
-import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
-
-import spark._
-import spark.scheduler._
+import org.apache.mesos.Protos._
/**
- * The main TaskScheduler implementation, which runs tasks on Mesos. Clients should first call
- * start(), then submit task sets through the runTasks method.
+ * The main Scheduler implementation, which runs jobs on Mesos. Clients should first call start(),
+ * then submit tasks through the runTasks method.
*/
-class MesosScheduler(
+private class MesosScheduler(
sc: SparkContext,
master: String,
frameworkName: String)
- extends TaskScheduler
- with MScheduler
+ extends MScheduler
+ with DAGScheduler
with Logging {
-
+
// Environment variables to pass to our executors
val ENV_VARS_TO_SEND_TO_EXECUTORS = Array(
"SPARK_MEM",
@@ -52,60 +49,55 @@ class MesosScheduler(
}
}
- // How often to check for speculative tasks
- val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
-
// Lock used to wait for scheduler to be registered
- var isRegistered = false
- val registeredLock = new Object()
+ private var isRegistered = false
+ private val registeredLock = new Object()
- val activeTaskSets = new HashMap[String, TaskSetManager]
- var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
+ private val activeJobs = new HashMap[Int, Job]
+ private var activeJobsQueue = new ArrayBuffer[Job]
- val taskIdToTaskSetId = new HashMap[String, String]
- val taskIdToSlaveId = new HashMap[String, String]
- val taskSetTaskIds = new HashMap[String, HashSet[String]]
+ private val taskIdToJobId = new HashMap[String, Int]
+ private val taskIdToSlaveId = new HashMap[String, String]
+ private val jobTasks = new HashMap[Int, HashSet[String]]
- // Incrementing Mesos task IDs
- var nextTaskId = 0
+ // Incrementing job and task IDs
+ private var nextJobId = 0
+ private var nextTaskId = 0
// Driver for talking to Mesos
var driver: SchedulerDriver = null
- // Which hosts in the cluster are alive (contains hostnames)
- val hostsAlive = new HashSet[String]
-
- // Which slave IDs we have executors on
- val slaveIdsWithExecutors = new HashSet[String]
-
- val slaveIdToHost = new HashMap[String, String]
+ // Which nodes we have executors on
+ private val slavesWithExecutors = new HashSet[String]
// JAR server, if any JARs were added by the user to the SparkContext
var jarServer: HttpServer = null
// URIs of JARs to pass to executor
var jarUris: String = ""
-
+
// Create an ExecutorInfo for our tasks
val executorInfo = createExecutorInfo()
- // Listener object to pass upcalls into
- var listener: TaskSchedulerListener = null
-
- val mapOutputTracker = SparkEnv.get.mapOutputTracker
-
- override def setListener(listener: TaskSchedulerListener) {
- this.listener = listener
+ // Sorts jobs in reverse order of run ID for use in our priority queue (so lower IDs run first)
+ private val jobOrdering = new Ordering[Job] {
+ override def compare(j1: Job, j2: Job): Int = j2.runId - j1.runId
+ }
+
+ def newJobId(): Int = this.synchronized {
+ val id = nextJobId
+ nextJobId += 1
+ return id
}
def newTaskId(): TaskID = {
- val id = TaskID.newBuilder().setValue("" + nextTaskId).build()
- nextTaskId += 1
- return id
+ val id = "" + nextTaskId;
+ nextTaskId += 1;
+ return TaskID.newBuilder().setValue(id).build()
}
override def start() {
- new Thread("MesosScheduler driver") {
+ new Thread("Spark scheduler") {
setDaemon(true)
override def run {
val sched = MesosScheduler.this
@@ -118,27 +110,12 @@ class MesosScheduler(
case e: Exception => logError("driver.run() failed", e)
}
}
- }.start()
- if (System.getProperty("spark.speculation", "false") == "true") {
- new Thread("MesosScheduler speculation check") {
- setDaemon(true)
- override def run {
- waitForRegister()
- while (true) {
- try {
- Thread.sleep(SPECULATION_INTERVAL)
- } catch { case e: InterruptedException => {} }
- checkSpeculatableTasks()
- }
- }
- }.start()
- }
+ }.start
}
def createExecutorInfo(): ExecutorInfo = {
val sparkHome = sc.getSparkHome match {
- case Some(path) =>
- path
+ case Some(path) => path
case None =>
throw new SparkException("Spark home is not set; set it through the spark.home system " +
"property, the SPARK_HOME environment variable or the SparkContext constructor")
@@ -174,26 +151,27 @@ class MesosScheduler(
.build()
}
- def submitTasks(taskSet: TaskSet) {
- val tasks = taskSet.tasks
- logInfo("Adding task set " + taskSet.id + " with " + tasks.size + " tasks")
+ def submitTasks(tasks: Seq[Task[_]], runId: Int) {
+ logInfo("Got a job with " + tasks.size + " tasks")
waitForRegister()
this.synchronized {
- val manager = new TaskSetManager(this, taskSet)
- activeTaskSets(taskSet.id) = manager
- activeTaskSetsQueue += manager
- taskSetTaskIds(taskSet.id) = new HashSet()
+ val jobId = newJobId()
+ val myJob = new SimpleJob(this, tasks, runId, jobId)
+ activeJobs(jobId) = myJob
+ activeJobsQueue += myJob
+ logInfo("Adding job with ID " + jobId)
+ jobTasks(jobId) = HashSet.empty[String]
}
- reviveOffers();
+ driver.reviveOffers();
}
- def taskSetFinished(manager: TaskSetManager) {
+ def jobFinished(job: Job) {
this.synchronized {
- activeTaskSets -= manager.taskSet.id
- activeTaskSetsQueue -= manager
- taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
- taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id)
- taskSetTaskIds.remove(manager.taskSet.id)
+ activeJobs -= job.jobId
+ activeJobsQueue -= job
+ taskIdToJobId --= jobTasks(job.jobId)
+ taskIdToSlaveId --= jobTasks(job.jobId)
+ jobTasks.remove(job.jobId)
}
}
@@ -218,40 +196,33 @@ class MesosScheduler(
override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
/**
- * Method called by Mesos to offer resources on slaves. We resond by asking our active task sets
- * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
- * tasks are balanced across the cluster.
+ * Method called by Mesos to offer resources on slaves. We respond by asking our active jobs for
+ * tasks in FIFO order. We fill each node with tasks in a round-robin manner so that tasks are
+ * balanced across the cluster.
*/
override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
synchronized {
- // Mark each slave as alive and remember its hostname
- for (o <- offers) {
- slaveIdToHost(o.getSlaveId.getValue) = o.getHostname
- hostsAlive += o.getHostname
- }
- // Build a list of tasks to assign to each slave
- val tasks = offers.map(o => new JArrayList[MTaskInfo])
+ val tasks = offers.map(o => new JArrayList[TaskInfo])
val availableCpus = offers.map(o => getResource(o.getResourcesList(), "cpus"))
val enoughMem = offers.map(o => {
val mem = getResource(o.getResourcesList(), "mem")
val slaveId = o.getSlaveId.getValue
- mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId)
+ mem >= EXECUTOR_MEMORY || slavesWithExecutors.contains(slaveId)
})
var launchedTask = false
- for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
+ for (job <- activeJobsQueue.sorted(jobOrdering)) {
do {
launchedTask = false
for (i <- 0 until offers.size if enoughMem(i)) {
- val sid = offers(i).getSlaveId.getValue
- val host = offers(i).getHostname
- manager.slaveOffer(sid, host, availableCpus(i)) match {
+ job.slaveOffer(offers(i), availableCpus(i)) match {
case Some(task) =>
tasks(i).add(task)
val tid = task.getTaskId.getValue
- taskIdToTaskSetId(tid) = manager.taskSet.id
- taskSetTaskIds(manager.taskSet.id) += tid
+ val sid = offers(i).getSlaveId.getValue
+ taskIdToJobId(tid) = job.jobId
+ jobTasks(job.jobId) += tid
taskIdToSlaveId(tid) = sid
- slaveIdsWithExecutors += sid
+ slavesWithExecutors += sid
availableCpus(i) -= getResource(task.getResourcesList(), "cpus")
launchedTask = true
@@ -285,74 +256,53 @@ class MesosScheduler(
}
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
- val tid = status.getTaskId.getValue
- var taskSetToUpdate: Option[TaskSetManager] = None
- var failedHost: Option[String] = None
- var taskFailed = false
+ var jobToUpdate: Option[Job] = None
synchronized {
try {
- if (status.getState == TaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) {
+ val tid = status.getTaskId.getValue
+ if (status.getState == TaskState.TASK_LOST
+ && taskIdToSlaveId.contains(tid)) {
// We lost the executor on this slave, so remember that it's gone
- val slaveId = taskIdToSlaveId(tid)
- val host = slaveIdToHost(slaveId)
- if (hostsAlive.contains(host)) {
- slaveIdsWithExecutors -= slaveId
- hostsAlive -= host
- activeTaskSetsQueue.foreach(_.hostLost(host))
- failedHost = Some(host)
- }
+ slavesWithExecutors -= taskIdToSlaveId(tid)
}
- taskIdToTaskSetId.get(tid) match {
- case Some(taskSetId) =>
- if (activeTaskSets.contains(taskSetId)) {
- //activeTaskSets(taskSetId).statusUpdate(status)
- taskSetToUpdate = Some(activeTaskSets(taskSetId))
+ taskIdToJobId.get(tid) match {
+ case Some(jobId) =>
+ if (activeJobs.contains(jobId)) {
+ jobToUpdate = Some(activeJobs(jobId))
}
if (isFinished(status.getState)) {
- taskIdToTaskSetId.remove(tid)
- if (taskSetTaskIds.contains(taskSetId)) {
- taskSetTaskIds(taskSetId) -= tid
+ taskIdToJobId.remove(tid)
+ if (jobTasks.contains(jobId)) {
+ jobTasks(jobId) -= tid
}
taskIdToSlaveId.remove(tid)
}
- if (status.getState == TaskState.TASK_FAILED) {
- taskFailed = true
- }
case None =>
- logInfo("Ignoring update from TID " + tid + " because its task set is gone")
+ logInfo("Ignoring update from TID " + tid + " because its job is gone")
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
}
}
- // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
- if (taskSetToUpdate != None) {
- taskSetToUpdate.get.statusUpdate(status)
- }
- if (failedHost != None) {
- listener.hostLost(failedHost.get)
- reviveOffers();
- }
- if (taskFailed) {
- // Also revive offers if a task had failed for some reason other than host lost
- reviveOffers()
+ for (j <- jobToUpdate) {
+ j.statusUpdate(status)
}
}
override def error(d: SchedulerDriver, message: String) {
logError("Mesos error: " + message)
synchronized {
- if (activeTaskSets.size > 0) {
- // Have each task set throw a SparkException with the error
- for ((taskSetId, manager) <- activeTaskSets) {
+ if (activeJobs.size > 0) {
+ // Have each job throw a SparkException with the error
+ for ((jobId, activeJob) <- activeJobs) {
try {
- manager.error(message)
+ activeJob.error(message)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
}
} else {
- // No task sets are active but we still got an error. Just exit since this
+ // No jobs are active but we still got an error. Just exit since this
// must mean the error is during registration.
// It might be good to do something smarter here in the future.
System.exit(1)
@@ -423,68 +373,41 @@ class MesosScheduler(
return Utils.serialize(props.toArray)
}
- override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
+ override def frameworkMessage(
+ d: SchedulerDriver,
+ e: ExecutorID,
+ s: SlaveID,
+ b: Array[Byte]) {}
override def slaveLost(d: SchedulerDriver, s: SlaveID) {
- var failedHost: Option[String] = None
- synchronized {
- val slaveId = s.getValue
- val host = slaveIdToHost(slaveId)
- if (hostsAlive.contains(host)) {
- slaveIdsWithExecutors -= slaveId
- hostsAlive -= host
- activeTaskSetsQueue.foreach(_.hostLost(host))
- failedHost = Some(host)
- }
- }
- if (failedHost != None) {
- listener.hostLost(failedHost.get)
- reviveOffers();
- }
+ slavesWithExecutors.remove(s.getValue)
}
override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) {
- logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue))
- slaveLost(d, s)
+ slavesWithExecutors.remove(s.getValue)
}
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
-
- // Check for speculatable tasks in all our active jobs.
- def checkSpeculatableTasks() {
- var shouldRevive = false
- synchronized {
- for (ts <- activeTaskSetsQueue) {
- shouldRevive |= ts.checkSpeculatableTasks()
- }
- }
- if (shouldRevive) {
- reviveOffers()
- }
- }
-
- def reviveOffers() {
- driver.reviveOffers()
- }
}
object MesosScheduler {
/**
- * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
- * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
+ * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
+ * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
* environment variable.
*/
def memoryStringToMb(str: String): Int = {
val lower = str.toLowerCase
if (lower.endsWith("k")) {
- (lower.substring(0, lower.length-1).toLong / 1024).toInt
+ (lower.substring(0, lower.length - 1).toLong / 1024).toInt
} else if (lower.endsWith("m")) {
- lower.substring(0, lower.length-1).toInt
+ lower.substring(0, lower.length - 1).toInt
} else if (lower.endsWith("g")) {
- lower.substring(0, lower.length-1).toInt * 1024
+ lower.substring(0, lower.length - 1).toInt * 1024
} else if (lower.endsWith("t")) {
- lower.substring(0, lower.length-1).toInt * 1024 * 1024
- } else {// no suffix, so it's just a number in bytes
+ lower.substring(0, lower.length - 1).toInt * 1024 * 1024
+ } else {
+ // no suffix, so it's just a number in bytes
(lower.toLong / 1024 / 1024).toInt
}
}
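
memoryStringToMb only gets a cosmetic cleanup in this revert; its behavior is easiest to see with a few inputs. A standalone copy with sample values (note the integer division, so anything under 1 MB rounds down to 0):

    object MemoryStringSketch {
      def memoryStringToMb(str: String): Int = {
        val lower = str.toLowerCase
        if (lower.endsWith("k")) (lower.substring(0, lower.length - 1).toLong / 1024).toInt
        else if (lower.endsWith("m")) lower.substring(0, lower.length - 1).toInt
        else if (lower.endsWith("g")) lower.substring(0, lower.length - 1).toInt * 1024
        else if (lower.endsWith("t")) lower.substring(0, lower.length - 1).toInt * 1024 * 1024
        else (lower.toLong / 1024 / 1024).toInt  // no suffix: a plain byte count
      }

      def main(args: Array[String]) {
        println(Seq("512k", "300m", "2g", "1T", "134217728").map(memoryStringToMb))
        // List(0, 300, 2048, 1048576, 128)
      }
    }
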
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 270447712b..e880f9872f 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -4,14 +4,14 @@ import java.io.EOFException
import java.net.URL
import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong
-import java.util.{HashMap => JHashMap}
+import java.util.HashSet
+import java.util.Random
import java.util.Date
import java.text.SimpleDateFormat
-import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.Map
import scala.collection.mutable.HashMap
-import scala.collection.JavaConversions._
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable
@@ -34,9 +34,7 @@ import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
import org.apache.hadoop.mapreduce.TaskAttemptID
import org.apache.hadoop.mapreduce.TaskAttemptContext
-import spark.SparkContext._
-import spark.partial.BoundedDouble
-import spark.partial.PartialResult
+import SparkContext._
/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -45,6 +43,19 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
self: RDD[(K, V)])
extends Logging
with Serializable {
+
+ def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = {
+ def mergeMaps(m1: HashMap[K, V], m2: HashMap[K, V]): HashMap[K, V] = {
+ for ((k, v) <- m2) {
+ m1.get(k) match {
+ case None => m1(k) = v
+ case Some(w) => m1(k) = func(w, v)
+ }
+ }
+ return m1
+ }
+ self.map(pair => HashMap(pair)).reduce(mergeMaps)
+ }
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
@@ -64,39 +75,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = {
combineByKey[V]((v: V) => v, func, func, partitioner)
}
-
- def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
- def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
- val map = new JHashMap[K, V]
- for ((k, v) <- iter) {
- val old = map.get(k)
- map.put(k, if (old == null) v else func(old, v))
- }
- Iterator(map)
- }
-
- def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = {
- for ((k, v) <- m2) {
- val old = m1.get(k)
- m1.put(k, if (old == null) v else func(old, v))
- }
- return m1
- }
-
- self.mapPartitions(reducePartition).reduce(mergeMaps)
- }
-
- // Alias for backwards compatibility
- def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func)
-
- // TODO: This should probably be a distributed version
- def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
-
- // TODO: This should probably be a distributed version
- def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
- : PartialResult[Map[K, BoundedDouble]] = {
- self.map(_._1).countByValueApprox(timeout, confidence)
- }
def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = {
reduceByKey(new HashPartitioner(numSplits), func)
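
The restored reduceByKeyToDriver turns every pair into a one-entry HashMap and merges maps pairwise with the user's function. The merge step, shown on plain local collections rather than an RDD:

    import scala.collection.mutable.HashMap

    object MergeMapsSketch {
      def mergeMaps[K, V](m1: HashMap[K, V], m2: HashMap[K, V], func: (V, V) => V): HashMap[K, V] = {
        for ((k, v) <- m2) {
          m1.get(k) match {
            case None    => m1(k) = v
            case Some(w) => m1(k) = func(w, v)
          }
        }
        m1
      }

      def main(args: Array[String]) {
        val pairs = Seq(("a", 1), ("b", 2), ("a", 3), ("b", 4))
        val result = pairs.map(p => HashMap(p))
          .reduce((m1, m2) => mergeMaps(m1, m2, (x: Int, y: Int) => x + y))
        println(result)  // Map(a -> 4, b -> 6), ordering may differ
      }
    }
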
diff --git a/core/src/main/scala/spark/ParallelShuffleFetcher.scala b/core/src/main/scala/spark/ParallelShuffleFetcher.scala
new file mode 100644
index 0000000000..19eb288e84
--- /dev/null
+++ b/core/src/main/scala/spark/ParallelShuffleFetcher.scala
@@ -0,0 +1,119 @@
+package spark
+
+import java.io.ByteArrayInputStream
+import java.io.EOFException
+import java.net.URL
+import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.atomic.AtomicReference
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+
+
+class ParallelShuffleFetcher extends ShuffleFetcher with Logging {
+ val parallelFetches = System.getProperty("spark.parallel.fetches", "3").toInt
+
+ def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
+ logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
+
+ // Figure out a list of input IDs (mapper IDs) for each server
+ val ser = SparkEnv.get.serializer.newInstance()
+ val inputsByUri = new HashMap[String, ArrayBuffer[Int]]
+ val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
+ for ((serverUri, index) <- serverUris.zipWithIndex) {
+ inputsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
+ }
+
+ // Randomize them and put them in a LinkedBlockingQueue
+ val serverQueue = new LinkedBlockingQueue[(String, ArrayBuffer[Int])]
+ for (pair <- Utils.randomize(inputsByUri)) {
+ serverQueue.put(pair)
+ }
+
+ // Create a queue to hold the fetched data
+ val resultQueue = new LinkedBlockingQueue[Array[Byte]]
+
+ // Atomic variables to communicate failures and # of fetches done
+ var failure = new AtomicReference[FetchFailedException](null)
+
+ // Start multiple threads to do the fetching (TODO: may be possible to do it asynchronously)
+ for (i <- 0 until parallelFetches) {
+ new Thread("Fetch thread " + i + " for reduce " + reduceId) {
+ override def run() {
+ while (true) {
+ val pair = serverQueue.poll()
+ if (pair == null)
+ return
+ val (serverUri, inputIds) = pair
+ //logInfo("Pulled out server URI " + serverUri)
+ for (i <- inputIds) {
+ if (failure.get != null)
+ return
+ val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
+ logInfo("Starting HTTP request for " + url)
+ try {
+ val conn = new URL(url).openConnection()
+ conn.connect()
+ val len = conn.getContentLength()
+ if (len == -1) {
+ throw new SparkException("Content length was not specified by server")
+ }
+ val buf = new Array[Byte](len)
+ val in = new FastBufferedInputStream(conn.getInputStream())
+ var pos = 0
+ while (pos < len) {
+ val n = in.read(buf, pos, len-pos)
+ if (n == -1) {
+ throw new SparkException("EOF before reading the expected " + len + " bytes")
+ } else {
+ pos += n
+ }
+ }
+ // Done reading everything
+ resultQueue.put(buf)
+ in.close()
+ } catch {
+ case e: Exception =>
+ logError("Fetch failed from " + url, e)
+ failure.set(new FetchFailedException(serverUri, shuffleId, i, reduceId, e))
+ return
+ }
+ }
+ //logInfo("Done with server URI " + serverUri)
+ }
+ }
+ }.start()
+ }
+
+ // Wait for results from the threads (either a failure or all servers done)
+ var resultsDone = 0
+ var totalResults = inputsByUri.map{case (uri, inputs) => inputs.size}.sum
+ while (failure.get == null && resultsDone < totalResults) {
+ try {
+ val result = resultQueue.poll(100, TimeUnit.MILLISECONDS)
+ if (result != null) {
+ //logInfo("Pulled out a result")
+ val in = ser.inputStream(new ByteArrayInputStream(result))
+ try {
+ while (true) {
+ val pair = in.readObject().asInstanceOf[(K, V)]
+ func(pair._1, pair._2)
+ }
+ } catch {
+ case e: EOFException => {} // TODO: cleaner way to detect EOF, such as a sentinel
+ }
+ resultsDone += 1
+ //logInfo("Results done = " + resultsDone)
+ }
+ } catch { case e: InterruptedException => {} }
+ }
+ if (failure.get != null) {
+ throw failure.get
+ }
+ }
+}
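
The new ParallelShuffleFetcher is a small producer/consumer setup: a few threads drain a queue of servers, push fetched byte arrays onto a result queue, and record the first failure in an AtomicReference that both sides check. The same skeleton with the HTTP fetch replaced by a stand-in string, so it runs on its own:

    import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
    import java.util.concurrent.atomic.AtomicReference

    object ParallelFetchSketch {
      def main(args: Array[String]) {
        val work = new LinkedBlockingQueue[Integer]()
        (1 to 10).foreach(i => work.put(i))
        val results = new LinkedBlockingQueue[String]()
        val failure = new AtomicReference[Exception](null)

        for (i <- 0 until 3) {
          new Thread("fetch thread " + i) {
            override def run() {
              while (true) {
                val item = work.poll()
                if (item == null) return           // queue drained: this thread is done
                if (failure.get != null) return    // another thread already failed
                try {
                  results.put("fetched item " + item)  // stand-in for the HTTP fetch
                } catch {
                  case e: Exception => failure.set(e); return
                }
              }
            }
          }.start()
        }

        // The calling thread polls for results until everything arrives or something fails.
        var resultsDone = 0
        while (failure.get == null && resultsDone < 10) {
          val r = results.poll(100, TimeUnit.MILLISECONDS)
          if (r != null) { println(r); resultsDone += 1 }
        }
        if (failure.get != null) throw failure.get
      }
    }
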
diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index 0e45ebd35c..024a4580ac 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -71,3 +71,4 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
false
}
}
+
diff --git a/core/src/main/scala/spark/PipedRDD.scala b/core/src/main/scala/spark/PipedRDD.scala
index 9e0a01b5f9..8a5de3d7e9 100644
--- a/core/src/main/scala/spark/PipedRDD.scala
+++ b/core/src/main/scala/spark/PipedRDD.scala
@@ -3,7 +3,6 @@ package spark
import java.io.PrintWriter
import java.util.StringTokenizer
-import scala.collection.Map
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 1191523ccc..4c4b2ee30d 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -4,14 +4,11 @@ import java.io.EOFException
import java.net.URL
import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong
+import java.util.HashSet
import java.util.Random
import java.util.Date
-import java.util.{HashMap => JHashMap}
import scala.collection.mutable.ArrayBuffer
-import scala.collection.Map
-import scala.collection.mutable.HashMap
-import scala.collection.JavaConversions.mapAsScalaMap
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
@@ -25,14 +22,6 @@ import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.TextOutputFormat
-import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
-
-import spark.partial.BoundedDouble
-import spark.partial.CountEvaluator
-import spark.partial.GroupedCountEvaluator
-import spark.partial.PartialResult
-import spark.storage.StorageLevel
-
import SparkContext._
/**
@@ -72,32 +61,19 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
// Get a unique ID for this RDD
val id = sc.newRddId()
- // Variables relating to persistence
- private var storageLevel: StorageLevel = StorageLevel.NONE
+ // Variables relating to caching
+ private var shouldCache = false
- // Change this RDD's storage level
- def persist(newLevel: StorageLevel): RDD[T] = {
- // TODO: Handle changes of StorageLevel
- if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) {
- throw new UnsupportedOperationException(
- "Cannot change storage level of an RDD after it was already assigned a level")
- }
- storageLevel = newLevel
+ // Change this RDD's caching
+ def cache(): RDD[T] = {
+ shouldCache = true
this
}
-
- // Turn on the default caching level for this RDD
- def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY_DESER)
-
- // Turn on the default caching level for this RDD
- def cache(): RDD[T] = persist()
-
- def getStorageLevel = storageLevel
// Read this RDD; will read from cache if applicable, or otherwise compute
final def iterator(split: Split): Iterator[T] = {
- if (storageLevel != StorageLevel.NONE) {
- SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel)
+ if (shouldCache) {
+ SparkEnv.get.cacheTracker.getOrCompute[T](this, split)
} else {
compute(split)
}
@@ -186,8 +162,6 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
Array.concat(results: _*)
}
- def toArray(): Array[T] = collect()
-
def reduce(f: (T, T) => T): T = {
val cleanF = sc.clean(f)
val reducePartition: Iterator[T] => Option[T] = iter => {
@@ -248,67 +222,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
}).sum
}
- /**
- * Approximate version of count() that returns a potentially incomplete result after a timeout.
- */
- def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
- val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) =>
- var result = 0L
- while (iter.hasNext) {
- result += 1L
- iter.next
- }
- result
- }
- val evaluator = new CountEvaluator(splits.size, confidence)
- sc.runApproximateJob(this, countElements, evaluator, timeout)
- }
-
- /**
- * Count elements equal to each value, returning a map of (value, count) pairs. The final combine
- * step happens locally on the master, equivalent to running a single reduce task.
- *
- * TODO: This should perhaps be distributed by default.
- */
- def countByValue(): Map[T, Long] = {
- def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = {
- val map = new OLMap[T]
- while (iter.hasNext) {
- val v = iter.next()
- map.put(v, map.getLong(v) + 1L)
- }
- Iterator(map)
- }
- def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = {
- val iter = m2.object2LongEntrySet.fastIterator()
- while (iter.hasNext) {
- val entry = iter.next()
- m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue)
- }
- return m1
- }
- val myResult = mapPartitions(countPartition).reduce(mergeMaps)
- myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map
- }
-
- /**
- * Approximate version of countByValue().
- */
- def countByValueApprox(
- timeout: Long,
- confidence: Double = 0.95
- ): PartialResult[Map[T, BoundedDouble]] = {
- val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
- val map = new OLMap[T]
- while (iter.hasNext) {
- val v = iter.next()
- map.put(v, map.getLong(v) + 1L)
- }
- map
- }
- val evaluator = new GroupedCountEvaluator[T](splits.size, confidence)
- sc.runApproximateJob(this, countPartition, evaluator, timeout)
- }
+ def toArray(): Array[T] = collect()
/**
* Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
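
The revert replaces StorageLevel-based persist() with a single shouldCache flag: iterator() either asks the cache tracker for the partition or recomputes it. The same shape on a local stand-in class (the HashMap here is an assumption standing in for the cluster-wide CacheTracker):

    import scala.collection.mutable.HashMap

    class SketchDataset(val id: Int, data: Array[Array[Int]]) {
      private var shouldCache = false
      private val cached = new HashMap[(Int, Int), Array[Int]]

      def cache(): SketchDataset = { shouldCache = true; this }

      // Serve a cached copy when caching is on; otherwise recompute the split.
      def iterator(split: Int): Iterator[Int] =
        if (shouldCache)
          cached.getOrElseUpdate((id, split), compute(split).toArray).iterator
        else
          compute(split)

      private def compute(split: Int): Iterator[Int] = {
        println("computing split " + split)
        data(split).iterator
      }
    }

    object SketchDatasetTest {
      def main(args: Array[String]) {
        val d = new SketchDataset(0, Array(Array(1, 2), Array(3, 4))).cache()
        d.iterator(0).foreach(println)  // prints "computing split 0", then 1 and 2
        d.iterator(0).foreach(println)  // served from the cache: just 1 and 2
      }
    }
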
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/ResultTask.scala
index d2fab55b5e..3952bf85b2 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/ResultTask.scala
@@ -1,15 +1,14 @@
-package spark.scheduler
-
-import spark._
+package spark
class ResultTask[T, U](
- stageId: Int,
- rdd: RDD[T],
+ runId: Int,
+ stageId: Int,
+ rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
- val partition: Int,
- @transient locs: Seq[String],
+ val partition: Int,
+ locs: Seq[String],
val outputId: Int)
- extends Task[U](stageId) {
+ extends DAGTask[U](runId, stageId) {
val split = rdd.splits(partition)
diff --git a/core/src/main/scala/spark/Scheduler.scala b/core/src/main/scala/spark/Scheduler.scala
new file mode 100644
index 0000000000..6c7e569313
--- /dev/null
+++ b/core/src/main/scala/spark/Scheduler.scala
@@ -0,0 +1,27 @@
+package spark
+
+/**
+ * Scheduler trait, implemented by both MesosScheduler and LocalScheduler.
+ */
+private trait Scheduler {
+ def start()
+
+ // Wait for registration with Mesos.
+ def waitForRegister()
+
+ /**
+ * Run a function on some partitions of an RDD, returning an array of results. The allowLocal
+ * flag specifies whether the scheduler is allowed to run the job on the master machine rather
+ * than shipping it to the cluster, for actions that create short jobs such as first() and take().
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ allowLocal: Boolean): Array[U]
+
+ def stop()
+
+ // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
+ def defaultParallelism(): Int
+}
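
The runJob signature in the trait above is the whole contract between SparkContext and a scheduler: hand it a function per partition, get back one result per partition. A toy in-process version with the RDD and TaskContext parameters swapped for plain Seq-based stand-ins, so it compiles on its own:

    class InProcessScheduler {
      def start() {}
      def waitForRegister() {}
      def stop() {}
      def defaultParallelism(): Int = Runtime.getRuntime.availableProcessors

      // Runs every requested "partition" sequentially in the calling thread.
      def runJob[T, U: ClassManifest](
          data: Seq[Seq[T]],
          func: Iterator[T] => U,
          partitions: Seq[Int],
          allowLocal: Boolean): Array[U] =
        partitions.map(p => func(data(p).iterator)).toArray
    }

    object InProcessSchedulerTest {
      def main(args: Array[String]) {
        val sched = new InProcessScheduler
        val data = Seq(Seq(1, 2, 3), Seq(4, 5, 6))
        println(sched.runJob(data, (it: Iterator[Int]) => it.sum, Seq(0, 1), allowLocal = true).mkString(", "))
        // 6, 15
      }
    }
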
diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
index 9da73c4b02..b213ca9dcb 100644
--- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
@@ -44,7 +44,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
}
// TODO: use something like WritableConverter to avoid reflection
}
- c.asInstanceOf[Class[_ <: Writable]]
+ c.asInstanceOf[Class[ _ <: Writable]]
}
def saveAsSequenceFile(path: String) {
diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala
index 61a70beaf1..2429bbfeb9 100644
--- a/core/src/main/scala/spark/Serializer.scala
+++ b/core/src/main/scala/spark/Serializer.scala
@@ -1,12 +1,6 @@
package spark
-import java.io.{EOFException, InputStream, OutputStream}
-import java.nio.ByteBuffer
-import java.nio.channels.Channels
-
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-
-import spark.util.ByteBufferInputStream
+import java.io.{InputStream, OutputStream}
/**
* A serializer. Because some serialization libraries are not thread safe, this class is used to
@@ -20,31 +14,11 @@ trait Serializer {
* An instance of the serializer, for use by one thread at a time.
*/
trait SerializerInstance {
- def serialize[T](t: T): ByteBuffer
-
- def deserialize[T](bytes: ByteBuffer): T
-
- def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T
-
- def serializeStream(s: OutputStream): SerializationStream
-
- def deserializeStream(s: InputStream): DeserializationStream
-
- def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
- // Default implementation uses serializeStream
- val stream = new FastByteArrayOutputStream()
- serializeStream(stream).writeAll(iterator)
- val buffer = ByteBuffer.allocate(stream.position.toInt)
- buffer.put(stream.array, 0, stream.position.toInt)
- buffer.flip()
- buffer
- }
-
- def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
- // Default implementation uses deserializeStream
- buffer.rewind()
- deserializeStream(new ByteBufferInputStream(buffer)).toIterator
- }
+ def serialize[T](t: T): Array[Byte]
+ def deserialize[T](bytes: Array[Byte]): T
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T
+ def outputStream(s: OutputStream): SerializationStream
+ def inputStream(s: InputStream): DeserializationStream
}
/**
@@ -54,13 +28,6 @@ trait SerializationStream {
def writeObject[T](t: T): Unit
def flush(): Unit
def close(): Unit
-
- def writeAll[T](iter: Iterator[T]): SerializationStream = {
- while (iter.hasNext) {
- writeObject(iter.next())
- }
- this
- }
}
/**
@@ -69,45 +36,4 @@ trait SerializationStream {
trait DeserializationStream {
def readObject[T](): T
def close(): Unit
-
- /**
- * Read the elements of this stream through an iterator. This can only be called once, as
- * reading each element will consume data from the input source.
- */
- def toIterator: Iterator[Any] = new Iterator[Any] {
- var gotNext = false
- var finished = false
- var nextValue: Any = null
-
- private def getNext() {
- try {
- nextValue = readObject[Any]()
- } catch {
- case eof: EOFException =>
- finished = true
- }
- gotNext = true
- }
-
- override def hasNext: Boolean = {
- if (!gotNext) {
- getNext()
- }
- if (finished) {
- close()
- }
- !finished
- }
-
- override def next(): Any = {
- if (!gotNext) {
- getNext()
- }
- if (finished) {
- throw new NoSuchElementException("End of stream")
- }
- gotNext = false
- nextValue
- }
- }
}
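
The comment at the top of this file explains the design: because some serialization libraries are not thread safe, a Serializer acts as a factory and each thread works through its own SerializerInstance. The per-thread pattern, sketched with a placeholder codec (the class names here are made up; the KryoSerializer hunk earlier does the same thing with a ThreadLocal[ObjectBuffer]):

    class NonThreadSafeCodec {
      // Placeholder for a library object that must not be shared across threads.
      def encode(s: String): Array[Byte] = s.getBytes("UTF-8")
    }

    object PerThreadCodec {
      private val local = new ThreadLocal[NonThreadSafeCodec] {
        override def initialValue = new NonThreadSafeCodec
      }
      def get: NonThreadSafeCodec = local.get()
    }

    object PerThreadCodecTest {
      def main(args: Array[String]) {
        val threads = (1 to 4).map { i =>
          new Thread("worker-" + i) {
            override def run() {
              // Each thread lazily creates and then reuses its own codec instance.
              println(Thread.currentThread.getName + " encoded " +
                PerThreadCodec.get.encode("hello").length + " bytes")
            }
          }
        }
        threads.foreach(_.start())
        threads.foreach(_.join())
      }
    }
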
diff --git a/core/src/main/scala/spark/SerializingCache.scala b/core/src/main/scala/spark/SerializingCache.scala
new file mode 100644
index 0000000000..3d192f2403
--- /dev/null
+++ b/core/src/main/scala/spark/SerializingCache.scala
@@ -0,0 +1,26 @@
+package spark
+
+import java.io._
+
+/**
+ * Wrapper around a BoundedMemoryCache that stores serialized objects as byte arrays in order to
+ * reduce storage cost and GC overhead
+ */
+class SerializingCache extends Cache with Logging {
+ val bmc = new BoundedMemoryCache
+
+ override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
+ val ser = SparkEnv.get.serializer.newInstance()
+ bmc.put(datasetId, partition, ser.serialize(value))
+ }
+
+ override def get(datasetId: Any, partition: Int): Any = {
+ val bytes = bmc.get(datasetId, partition)
+ if (bytes != null) {
+ val ser = SparkEnv.get.serializer.newInstance()
+ return ser.deserialize(bytes.asInstanceOf[Array[Byte]])
+ } else {
+ return null
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/ShuffleMapTask.scala b/core/src/main/scala/spark/ShuffleMapTask.scala
new file mode 100644
index 0000000000..5fc59af06c
--- /dev/null
+++ b/core/src/main/scala/spark/ShuffleMapTask.scala
@@ -0,0 +1,56 @@
+package spark
+
+import java.io.BufferedOutputStream
+import java.io.FileOutputStream
+import java.io.ObjectOutputStream
+import java.util.{HashMap => JHashMap}
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+class ShuffleMapTask(
+ runId: Int,
+ stageId: Int,
+ rdd: RDD[_],
+ dep: ShuffleDependency[_,_,_],
+ val partition: Int,
+ locs: Seq[String])
+ extends DAGTask[String](runId, stageId)
+ with Logging {
+
+ val split = rdd.splits(partition)
+
+ override def run (attemptId: Int): String = {
+ val numOutputSplits = dep.partitioner.numPartitions
+ val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
+ val partitioner = dep.partitioner.asInstanceOf[Partitioner]
+ val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
+ for (elem <- rdd.iterator(split)) {
+ val (k, v) = elem.asInstanceOf[(Any, Any)]
+ var bucketId = partitioner.getPartition(k)
+ val bucket = buckets(bucketId)
+ var existing = bucket.get(k)
+ if (existing == null) {
+ bucket.put(k, aggregator.createCombiner(v))
+ } else {
+ bucket.put(k, aggregator.mergeValue(existing, v))
+ }
+ }
+ val ser = SparkEnv.get.serializer.newInstance()
+ for (i <- 0 until numOutputSplits) {
+ val file = SparkEnv.get.shuffleManager.getOutputFile(dep.shuffleId, partition, i)
+ val out = ser.outputStream(new FastBufferedOutputStream(new FileOutputStream(file)))
+ val iter = buckets(i).entrySet().iterator()
+ while (iter.hasNext()) {
+ val entry = iter.next()
+ out.writeObject((entry.getKey, entry.getValue))
+ }
+ // TODO: have some kind of EOF marker
+ out.close()
+ }
+ return SparkEnv.get.shuffleManager.getServerUri
+ }
+
+ override def preferredLocations: Seq[String] = locs
+
+ override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
+}
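
ShuffleMapTask's core loop hashes each record into one of numOutputSplits buckets and pre-combines values per key before anything is written out. The same loop on local data, with summation standing in for the Aggregator's createCombiner/mergeValue and a modulo hash standing in for the Partitioner:

    import java.util.{HashMap => JHashMap}

    object ShuffleBucketSketch {
      def main(args: Array[String]) {
        val numOutputSplits = 2
        val records = Seq(("a", 1), ("b", 2), ("a", 3), ("c", 4), ("b", 5))
        val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])

        for ((k, v) <- records) {
          // Stand-in for partitioner.getPartition(k).
          val bucketId = math.abs(k.hashCode) % numOutputSplits
          val bucket = buckets(bucketId)
          val existing = bucket.get(k)
          if (existing == null) {
            bucket.put(k, v)                               // createCombiner: first value seen
          } else {
            bucket.put(k, existing.asInstanceOf[Int] + v)  // mergeValue: here, a running sum
          }
        }

        // The real task would now serialize each bucket to its own shuffle file.
        for (i <- 0 until numOutputSplits) {
          println("bucket " + i + ": " + buckets(i))
        }
      }
    }
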
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index 5434197eca..5efc8cf50b 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -8,7 +8,7 @@ class ShuffledRDDSplit(val idx: Int) extends Split {
}
class ShuffledRDD[K, V, C](
- @transient parent: RDD[(K, V)],
+ parent: RDD[(K, V)],
aggregator: Aggregator[K, V, C],
part : Partitioner)
extends RDD[(K, C)](parent.context) {
diff --git a/core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala b/core/src/main/scala/spark/SimpleJob.scala
index 535c17d9d4..01c7efff1e 100644
--- a/core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala
+++ b/core/src/main/scala/spark/SimpleJob.scala
@@ -1,32 +1,28 @@
-package spark.scheduler.mesos
+package spark
-import java.util.Arrays
import java.util.{HashMap => JHashMap}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
import com.google.protobuf.ByteString
import org.apache.mesos._
-import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
-
-import spark._
-import spark.scheduler._
+import org.apache.mesos.Protos._
/**
- * Schedules the tasks within a single TaskSet in the MesosScheduler.
+ * A Job that runs a set of tasks with no interdependencies.
*/
-class TaskSetManager(
+class SimpleJob(
sched: MesosScheduler,
- val taskSet: TaskSet)
- extends Logging {
+ tasksSeq: Seq[Task[_]],
+ runId: Int,
+ jobId: Int)
+ extends Job(runId, jobId)
+ with Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
- val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
+ val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "5000").toLong
// CPUs to request per task
val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
@@ -34,20 +30,18 @@ class TaskSetManager(
// Maximum times a task is allowed to fail before failing the job
val MAX_TASK_FAILURES = 4
- // Quantile of tasks at which to start speculation
- val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
- val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
-
// Serializer for closures and tasks.
val ser = SparkEnv.get.closureSerializer.newInstance()
- val priority = taskSet.priority
- val tasks = taskSet.tasks
+ val callingThread = Thread.currentThread
+ val tasks = tasksSeq.toArray
val numTasks = tasks.length
- val copiesRunning = new Array[Int](numTasks)
+ val launched = new Array[Boolean](numTasks)
val finished = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
- val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
+ val tidToIndex = HashMap[String, Int]()
+
+ var tasksLaunched = 0
var tasksFinished = 0
// Last time when we launched a preferred task (for delay scheduling)
@@ -68,13 +62,6 @@ class TaskSetManager(
// List containing all pending tasks (also used as a stack, as above)
val allPendingTasks = new ArrayBuffer[Int]
- // Tasks that can be specualted. Since these will be a small fraction of total
- // tasks, we'll just hold them in a HaskSet.
- val speculatableTasks = new HashSet[Int]
-
- // Task index, start and finish time for each task attempt (indexed by task ID)
- val taskInfos = new HashMap[String, TaskInfo]
-
// Did the job fail?
var failed = false
var causeOfFailure = ""
@@ -89,12 +76,6 @@ class TaskSetManager(
// exceptions automatically.
val recentExceptions = HashMap[String, (Int, Long)]()
- // Figure out the current map output tracker generation and set it on all tasks
- val generation = sched.mapOutputTracker.getGeneration
- for (t <- tasks) {
- t.generation = generation
- }
-
// Add all our tasks to the pending lists. We do this in reverse order
// of task index so that tasks with low indices get launched first.
for (i <- (0 until numTasks).reverse) {
@@ -103,7 +84,7 @@ class TaskSetManager(
// Add a task to all the pending-task lists that it should be on.
def addPendingTask(index: Int) {
- val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
+ val locations = tasks(index).preferredLocations
if (locations.size == 0) {
pendingTasksWithNoPrefs += index
} else {
@@ -129,37 +110,13 @@ class TaskSetManager(
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
- if (copiesRunning(index) == 0 && !finished(index)) {
+ if (!launched(index) && !finished(index)) {
return Some(index)
}
}
return None
}
- // Return a speculative task for a given host if any are available. The task should not have an
- // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
- // task must have a preference for this host (or no preferred locations at all).
- def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
- speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
- val localTask = speculatableTasks.find { index =>
- val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
- val attemptLocs = taskAttempts(index).map(_.host)
- (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
- }
- if (localTask != None) {
- speculatableTasks -= localTask.get
- return localTask
- }
- if (!localOnly && speculatableTasks.size > 0) {
- val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host))
- if (nonLocalTask != None) {
- speculatableTasks -= nonLocalTask.get
- return nonLocalTask
- }
- }
- return None
- }
-
// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
def findTask(host: String, localOnly: Boolean): Option[Int] = {
@@ -172,13 +129,10 @@ class TaskSetManager(
return noPrefTask
}
if (!localOnly) {
- val nonLocalTask = findTaskFromList(allPendingTasks)
- if (nonLocalTask != None) {
- return nonLocalTask
- }
+ return findTaskFromList(allPendingTasks) // Look for non-local task
+ } else {
+ return None
}
- // Finally, if all else has failed, find a speculative task
- return findSpeculativeTask(host, localOnly)
}
// Does a host count as a preferred location for a task? This is true if
@@ -190,11 +144,11 @@ class TaskSetManager(
}
// Respond to an offer of a single slave from the scheduler by finding a task
- def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[MTaskInfo] = {
- if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
+ def slaveOffer(offer: Offer, availableCpus: Double): Option[TaskInfo] = {
+ if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK) {
val time = System.currentTimeMillis
- var localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
-
+ val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
+ val host = offer.getHostname
findTask(host, localOnly) match {
case Some(index) => {
// Found a task; do some bookkeeping and return a Mesos task for it
@@ -202,17 +156,17 @@ class TaskSetManager(
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
val preferred = isPreferredLocation(task, host)
- val prefStr = if (preferred) "preferred" else "non-preferred"
- logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
- taskSet.id, index, taskId.getValue, slaveId, host, prefStr))
+ val prefStr = if(preferred) "preferred" else "non-preferred"
+ val message =
+ "Starting task %d:%d as TID %s on slave %s: %s (%s)".format(
+ jobId, index, taskId.getValue, offer.getSlaveId.getValue, host, prefStr)
+ logInfo(message)
// Do various bookkeeping
- copiesRunning(index) += 1
- val info = new TaskInfo(taskId.getValue, index, time, host)
- taskInfos(taskId.getValue) = info
- taskAttempts(index) = info :: taskAttempts(index)
- if (preferred) {
+ tidToIndex(taskId.getValue) = index
+ launched(index) = true
+ tasksLaunched += 1
+ if (preferred)
lastPreferredLaunchTime = time
- }
// Create and return the Mesos task object
val cpuRes = Resource.newBuilder()
.setName("cpus")
@@ -224,13 +178,13 @@ class TaskSetManager(
val serializedTask = ser.serialize(task)
val timeTaken = System.currentTimeMillis - startTime
- logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
- taskSet.id, index, serializedTask.limit, timeTaken))
+ logInfo("Size of task %d:%d is %d bytes and took %d ms to serialize by %s"
+ .format(jobId, index, serializedTask.size, timeTaken, ser.getClass.getName))
- val taskName = "task %s:%d".format(taskSet.id, index)
- return Some(MTaskInfo.newBuilder()
+ val taskName = "task %d:%d".format(jobId, index)
+ return Some(TaskInfo.newBuilder()
.setTaskId(taskId)
- .setSlaveId(SlaveID.newBuilder().setValue(slaveId))
+ .setSlaveId(offer.getSlaveId)
.setExecutor(sched.executorInfo)
.setName(taskName)
.addResources(cpuRes)
@@ -259,21 +213,18 @@ class TaskSetManager(
def taskFinished(status: TaskStatus) {
val tid = status.getTaskId.getValue
- val info = taskInfos(tid)
- val index = info.index
- info.markSuccessful()
+ val index = tidToIndex(tid)
if (!finished(index)) {
tasksFinished += 1
- logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
- tid, info.duration, tasksFinished, numTasks))
- // Deserialize task result and pass it to the scheduler
- val result = ser.deserialize[TaskResult[_]](status.getData.asReadOnlyByteBuffer)
- sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
+ logInfo("Finished TID %s (progress: %d/%d)".format(tid, tasksFinished, numTasks))
+ // Deserialize task result
+ val result = ser.deserialize[TaskResult[_]](
+ status.getData.toByteArray, getClass.getClassLoader)
+ sched.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
// Mark finished and stop if we've finished all the tasks
finished(index) = true
- if (tasksFinished == numTasks) {
- sched.taskSetFinished(this)
- }
+ if (tasksFinished == numTasks)
+ sched.jobFinished(this)
} else {
logInfo("Ignoring task-finished event for TID " + tid +
" because task " + index + " is already finished")
@@ -282,29 +233,30 @@ class TaskSetManager(
def taskLost(status: TaskStatus) {
val tid = status.getTaskId.getValue
- val info = taskInfos(tid)
- val index = info.index
- info.markFailed()
+ val index = tidToIndex(tid)
if (!finished(index)) {
- logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
- copiesRunning(index) -= 1
+ logInfo("Lost TID %s (task %d:%d)".format(tid, jobId, index))
+ launched(index) = false
+ tasksLaunched -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
if (status.getData != null && status.getData.size > 0) {
- val reason = ser.deserialize[TaskEndReason](status.getData.asReadOnlyByteBuffer)
+ val reason = ser.deserialize[TaskEndReason](
+ status.getData.toByteArray, getClass.getClassLoader)
reason match {
case fetchFailed: FetchFailed =>
- logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.listener.taskEnded(tasks(index), fetchFailed, null, null)
+ logInfo("Loss was due to fetch failure from " + fetchFailed.serverUri)
+ sched.taskEnded(tasks(index), fetchFailed, null, null)
finished(index) = true
tasksFinished += 1
- sched.taskSetFinished(this)
+ if (tasksFinished == numTasks) {
+ sched.jobFinished(this)
+ }
return
-
case ef: ExceptionFailure =>
val key = ef.exception.toString
val now = System.currentTimeMillis
- val (printFull, dupCount) = {
+ val (printFull, dupCount) =
if (recentExceptions.contains(key)) {
val (dupCount, printTime) = recentExceptions(key)
if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
@@ -315,28 +267,32 @@ class TaskSetManager(
(false, dupCount + 1)
}
} else {
- recentExceptions(key) = (0, now)
+ recentExceptions += Tuple(key, (0, now))
(true, 0)
}
- }
+
if (printFull) {
- val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString))
- logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n")))
+ val stackTrace =
+ for (elem <- ef.exception.getStackTrace)
+ yield "\tat %s".format(elem.toString)
+ logInfo("Loss was due to %s\n%s".format(
+ ef.exception.toString, stackTrace.mkString("\n")))
} else {
- logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount))
+ logInfo("Loss was due to %s [duplicate %d]".format(
+ ef.exception.toString, dupCount))
}
-
case _ => {}
}
}
- // On non-fetch failures, re-enqueue the task as pending for a max number of retries
+ // On other failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
- // Count failed attempts only on FAILED and LOST state (not on KILLED)
- if (status.getState == TaskState.TASK_FAILED || status.getState == TaskState.TASK_LOST) {
+ // Count attempts only on FAILED and LOST state (not on KILLED)
+ if (status.getState == TaskState.TASK_FAILED ||
+ status.getState == TaskState.TASK_LOST) {
numFailures(index) += 1
if (numFailures(index) > MAX_TASK_FAILURES) {
- logError("Task %s:%d failed more than %d times; aborting job".format(
- taskSet.id, index, MAX_TASK_FAILURES))
+ logError("Task %d:%d failed more than %d times; aborting job".format(
+ jobId, index, MAX_TASK_FAILURES))
abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES))
}
}
@@ -355,71 +311,6 @@ class TaskSetManager(
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.taskSetFinished(this)
- }
-
- def hostLost(hostname: String) {
- logInfo("Re-queueing tasks for " + hostname)
- // If some task has preferred locations only on hostname, put it in the no-prefs list
- // to avoid the wait from delay scheduling
- for (index <- getPendingTasksForHost(hostname)) {
- val newLocs = tasks(index).preferredLocations.toSet & sched.hostsAlive
- if (newLocs.isEmpty) {
- pendingTasksWithNoPrefs += index
- }
- }
- // Also re-enqueue any tasks that ran on the failed host if this is a shuffle map stage
- if (tasks(0).isInstanceOf[ShuffleMapTask]) {
- for ((tid, info) <- taskInfos if info.host == hostname) {
- val index = taskInfos(tid).index
- if (finished(index)) {
- finished(index) = false
- copiesRunning(index) -= 1
- tasksFinished -= 1
- addPendingTask(index)
- // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
- // stage finishes when a total of tasks.size tasks finish.
- sched.listener.taskEnded(tasks(index), Resubmitted, null, null)
- }
- }
- }
- }
-
- /**
- * Check for tasks to be speculated and return true if there are any. This is called periodically
- * by the MesosScheduler.
- *
- * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
- * we don't scan the whole task set. It might also help to make this sorted by launch time.
- */
- def checkSpeculatableTasks(): Boolean = {
- // Can't speculate if we only have one task, or if all tasks have finished.
- if (numTasks == 1 || tasksFinished == numTasks) {
- return false
- }
- var foundTasks = false
- val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
- logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
- if (tasksFinished >= minFinishedForSpeculation) {
- val time = System.currentTimeMillis()
- val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
- Arrays.sort(durations)
- val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
- val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
- // TODO: Threshold should also look at standard deviation of task durations and have a lower
- // bound based on that.
- logDebug("Task length threshold for speculation: " + threshold)
- for ((tid, info) <- taskInfos) {
- val index = info.index
- if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
- !speculatableTasks.contains(index)) {
- logInfo("Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
- taskSet.id, index, info.host, threshold))
- speculatableTasks += index
- foundTasks = true
- }
- }
- }
- return foundTasks
+ sched.jobFinished(this)
}
}
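The slaveOffer logic restored above implements a simple form of delay scheduling: as long as the last data-local ("preferred") launch happened within LOCALITY_WAIT milliseconds, only local tasks are offered; once that window expires, non-local tasks are allowed too. Below is a minimal standalone sketch of just that check, not part of the patch; the property name spark.locality.wait and the 3000 ms default are illustrative assumptions.

object DelaySchedulingSketch {
  // Assumed property name and default; the real constant lives in the scheduler.
  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong

  // True while we should still insist on data-local tasks for this task set.
  def localOnly(lastPreferredLaunchTime: Long,
                now: Long = System.currentTimeMillis): Boolean =
    now - lastPreferredLaunchTime < LOCALITY_WAIT

  def main(args: Array[String]) {
    val lastLocalLaunch = System.currentTimeMillis - 5000
    // Prints false: 5 seconds have passed, so non-local tasks may be launched as well.
    println("insist on local tasks: " + localOnly(lastLocalLaunch))
  }
}
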
diff --git a/core/src/main/scala/spark/SimpleShuffleFetcher.scala b/core/src/main/scala/spark/SimpleShuffleFetcher.scala
new file mode 100644
index 0000000000..196c64cf1f
--- /dev/null
+++ b/core/src/main/scala/spark/SimpleShuffleFetcher.scala
@@ -0,0 +1,46 @@
+package spark
+
+import java.io.EOFException
+import java.net.URL
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+
+class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
+ def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
+ logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
+ val ser = SparkEnv.get.serializer.newInstance()
+ val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
+ val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
+ for ((serverUri, index) <- serverUris.zipWithIndex) {
+ splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
+ }
+ for ((serverUri, inputIds) <- Utils.randomize(splitsByUri)) {
+ for (i <- inputIds) {
+ try {
+ val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
+ // TODO: multithreaded fetch
+ // TODO: would be nice to retry multiple times
+ val inputStream = ser.inputStream(
+ new FastBufferedInputStream(new URL(url).openStream()))
+ try {
+ while (true) {
+ val pair = inputStream.readObject().asInstanceOf[(K, V)]
+ func(pair._1, pair._2)
+ }
+ } finally {
+ inputStream.close()
+ }
+ } catch {
+ case e: EOFException => {} // We currently assume EOF means we read the whole thing
+ case other: Exception => {
+ logError("Fetch failed", other)
+ throw new FetchFailedException(serverUri, shuffleId, i, reduceId, other)
+ }
+ }
+ }
+ }
+ }
+}
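SimpleShuffleFetcher above groups a shuffle's map outputs by the server URI that holds them and then walks the servers in random order, so reducers spread their first requests across hosts instead of all hitting map 0's server at once. Here is a self-contained sketch of that grouping and randomization step, with scala.util.Random standing in for Utils.randomize and a println in place of the actual HTTP fetch; the URIs are made up.

import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.Random

object FetchOrderSketch {
  def main(args: Array[String]) {
    // One entry per map task: the URI of the server that holds that map's output.
    val serverUris = Seq("http://a:8080", "http://b:8080", "http://a:8080", "http://c:8080")
    val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
    for ((uri, mapId) <- serverUris.zipWithIndex) {
      splitsByUri.getOrElseUpdate(uri, ArrayBuffer()) += mapId
    }
    // Visit the servers in random order, fetching each of that server's outputs in turn.
    for ((uri, mapIds) <- Random.shuffle(splitsByUri.toSeq); mapId <- mapIds) {
      println("would fetch shuffle output of map " + mapId + " from " + uri)
    }
  }
}
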
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index b43aca2b97..6e019d6e7f 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -3,9 +3,6 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
-import akka.actor.Actor
-import akka.actor.Actor._
-
import scala.actors.remote.RemoteActor
import scala.collection.mutable.ArrayBuffer
@@ -35,17 +32,6 @@ import org.apache.mesos.MesosNativeLibrary
import spark.broadcast._
-import spark.partial.ApproximateEvaluator
-import spark.partial.PartialResult
-
-import spark.scheduler.ShuffleMapTask
-import spark.scheduler.DAGScheduler
-import spark.scheduler.TaskScheduler
-import spark.scheduler.local.LocalScheduler
-import spark.scheduler.mesos.MesosScheduler
-import spark.scheduler.mesos.CoarseMesosScheduler
-import spark.storage.BlockManagerMaster
-
class SparkContext(
master: String,
frameworkName: String,
@@ -68,19 +54,14 @@ class SparkContext(
if (RemoteActor.classLoader == null) {
RemoteActor.classLoader = getClass.getClassLoader
}
-
- remote.start(System.getProperty("spark.master.host"),
- System.getProperty("spark.master.port").toInt)
- private val isLocal = master.startsWith("local") // TODO: better check for local
-
// Create the Spark execution environment (cache, map output tracker, etc)
- val env = SparkEnv.createFromSystemProperties(true, isLocal)
+ val env = SparkEnv.createFromSystemProperties(true)
SparkEnv.set(env)
Broadcast.initialize(true)
// Create and start the scheduler
- private var taskScheduler: TaskScheduler = {
+ private var scheduler: Scheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
@@ -93,17 +74,13 @@ class SparkContext(
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
new LocalScheduler(threads.toInt, maxFailures.toInt)
case _ =>
- System.loadLibrary("mesos")
- if (System.getProperty("spark.mesos.coarse", "false") == "true") {
- new CoarseMesosScheduler(this, master, frameworkName)
- } else {
- new MesosScheduler(this, master, frameworkName)
- }
+ MesosNativeLibrary.load()
+ new MesosScheduler(this, master, frameworkName)
}
}
- taskScheduler.start()
+ scheduler.start()
- private var dagScheduler = new DAGScheduler(taskScheduler)
+ private val isLocal = scheduler.isInstanceOf[LocalScheduler]
// Methods for creating RDDs
@@ -260,25 +237,19 @@ class SparkContext(
// Stop the SparkContext
def stop() {
- remote.shutdownServerModule()
- dagScheduler.stop()
- dagScheduler = null
- taskScheduler = null
+ scheduler.stop()
+ scheduler = null
// TODO: Broadcast.stop(), Cache.stop()?
env.mapOutputTracker.stop()
env.cacheTracker.stop()
env.shuffleFetcher.stop()
env.shuffleManager.stop()
- env.blockManager.stop()
- BlockManagerMaster.stopBlockManagerMaster()
- env.connectionManager.stop()
SparkEnv.set(null)
- ShuffleMapTask.clearCache()
}
- // Wait for the scheduler to be registered with the cluster manager
+ // Wait for the scheduler to be registered
def waitForRegister() {
- taskScheduler.waitForRegister()
+ scheduler.waitForRegister()
}
// Get Spark's home location from either a value set through the constructor,
@@ -310,7 +281,7 @@ class SparkContext(
): Array[U] = {
logInfo("Starting job...")
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, func, partitions, allowLocal)
+ val result = scheduler.runJob(rdd, func, partitions, allowLocal)
logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
result
}
@@ -335,22 +306,6 @@ class SparkContext(
runJob(rdd, func, 0 until rdd.splits.size, false)
}
- /**
- * Run a job that can return approximate results.
- */
- def runApproximateJob[T, U, R](
- rdd: RDD[T],
- func: (TaskContext, Iterator[T]) => U,
- evaluator: ApproximateEvaluator[U, R],
- timeout: Long
- ): PartialResult[R] = {
- logInfo("Starting job...")
- val start = System.nanoTime
- val result = dagScheduler.runApproximateJob(rdd, func, evaluator, timeout)
- logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
- result
- }
-
// Clean a closure to make it ready to serialized and send to tasks
// (removes unreferenced variables in $outer's, updates REPL variables)
private[spark] def clean[F <: AnyRef](f: F): F = {
@@ -359,7 +314,7 @@ class SparkContext(
}
// Default level of parallelism to use when not given by user (e.g. for reduce tasks)
- def defaultParallelism: Int = taskScheduler.defaultParallelism
+ def defaultParallelism: Int = scheduler.defaultParallelism
// Default min number of splits for Hadoop RDDs when not given by user
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
@@ -394,23 +349,15 @@ object SparkContext {
}
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
-
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
-
- implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
- rdd: RDD[(K, V)]) =
+
+ implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
- implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
- rdd: RDD[(K, V)]) =
+ implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
new OrderedRDDFunctions(rdd)
- implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd)
-
- implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
- new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
-
// Implicit conversions to common Writable types, for saveAsSequenceFile
implicit def intToIntWritable(i: Int) = new IntWritable(i)
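The scheduler selection restored in SparkContext above is driven by pattern matching on the master string: "local" and "local[N]" run everything in-process, "local[N, M]" also caps task failures (used in tests), and any other string is handed to the Mesos scheduler. The small sketch below shows that dispatch with the scheduler classes replaced by description strings so it runs on its own; only LOCAL_N_REGEX appears verbatim in the patch, and the second regex is an assumed reconstruction of the local[N, maxRetries] form.

object MasterUrlSketch {
  val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
  // Assumed shape of the local[N, maxRetries] pattern mentioned in the patch's comment.
  val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r

  def schedulerFor(master: String): String = master match {
    case "local"                => "LocalScheduler with 1 thread"
    case LOCAL_N_REGEX(threads) => "LocalScheduler with " + threads + " threads"
    case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
      "LocalScheduler with " + threads + " threads, " + maxFailures + " max failures"
    case _                      => "MesosScheduler against " + master
  }

  def main(args: Array[String]) {
    Seq("local", "local[4]", "local[4, 2]", "mesos://master:5050").foreach(m =>
      println(m + " -> " + schedulerFor(m)))
  }
}
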
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 897a5ef82d..cd752f8b65 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -1,26 +1,14 @@
package spark
-import akka.actor.Actor
-
-import spark.storage.BlockManager
-import spark.storage.BlockManagerMaster
-import spark.network.ConnectionManager
-
class SparkEnv (
- val cache: Cache,
- val serializer: Serializer,
- val closureSerializer: Serializer,
- val cacheTracker: CacheTracker,
- val mapOutputTracker: MapOutputTracker,
- val shuffleFetcher: ShuffleFetcher,
- val shuffleManager: ShuffleManager,
- val blockManager: BlockManager,
- val connectionManager: ConnectionManager
- ) {
-
- /** No-parameter constructor for unit tests. */
- def this() = this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null)
-}
+ val cache: Cache,
+ val serializer: Serializer,
+ val closureSerializer: Serializer,
+ val cacheTracker: CacheTracker,
+ val mapOutputTracker: MapOutputTracker,
+ val shuffleFetcher: ShuffleFetcher,
+ val shuffleManager: ShuffleManager
+)
object SparkEnv {
private val env = new ThreadLocal[SparkEnv]
@@ -33,55 +21,36 @@ object SparkEnv {
env.get()
}
- def createFromSystemProperties(isMaster: Boolean, isLocal: Boolean): SparkEnv = {
- val serializerClass = System.getProperty("spark.serializer", "spark.KryoSerializer")
- val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
-
- BlockManagerMaster.startBlockManagerMaster(isMaster, isLocal)
-
- var blockManager = new BlockManager(serializer)
-
- val connectionManager = blockManager.connectionManager
+ def createFromSystemProperties(isMaster: Boolean): SparkEnv = {
+ val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
+ val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
- val shuffleManager = new ShuffleManager()
+ val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
+ val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
val closureSerializerClass =
System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
val closureSerializer =
Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer]
- val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
- val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
- val cacheTracker = new CacheTracker(isMaster, blockManager)
- blockManager.cacheTracker = cacheTracker
+ val cacheTracker = new CacheTracker(isMaster, cache)
val mapOutputTracker = new MapOutputTracker(isMaster)
val shuffleFetcherClass =
- System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
+ System.getProperty("spark.shuffle.fetcher", "spark.SimpleShuffleFetcher")
val shuffleFetcher =
Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher]
- /*
- if (System.getProperty("spark.stream.distributed", "false") == "true") {
- val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]]
- if (isLocal || !isMaster) {
- (new Thread() {
- override def run() {
- println("Wait started")
- Thread.sleep(60000)
- println("Wait ended")
- val receiverClass = Class.forName("spark.stream.TestStreamReceiver4")
- val constructor = receiverClass.getConstructor(blockManagerClass)
- val receiver = constructor.newInstance(blockManager)
- receiver.asInstanceOf[Thread].start()
- }
- }).start()
- }
- }
- */
+ val shuffleMgr = new ShuffleManager()
- new SparkEnv(cache, serializer, closureSerializer, cacheTracker, mapOutputTracker, shuffleFetcher,
- shuffleManager, blockManager, connectionManager)
+ new SparkEnv(
+ cache,
+ serializer,
+ closureSerializer,
+ cacheTracker,
+ mapOutputTracker,
+ shuffleFetcher,
+ shuffleMgr)
}
}
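createFromSystemProperties wires SparkEnv together with a simple plugin pattern: each component's class name is read from a system property and the class is instantiated reflectively through its no-argument constructor, which is how spark.cache.class, spark.serializer, spark.closure.serializer and spark.shuffle.fetcher are resolved above. A minimal sketch of the same pattern with made-up names (the Greeter trait and the sketch.greeter.class property are not Spark's):

trait Greeter { def greet(): String }
class EnglishGreeter extends Greeter { def greet() = "hello" }

object PluginSketch {
  def main(args: Array[String]) {
    // Read a class name from a property, falling back to a default implementation.
    val className = System.getProperty("sketch.greeter.class", "EnglishGreeter")
    // Instantiate it via the no-arg constructor, the same way SparkEnv builds its components.
    val greeter = Class.forName(className).newInstance().asInstanceOf[Greeter]
    println(greeter.greet())
  }
}

Running with -Dsketch.greeter.class=SomeOtherGreeter swaps the implementation without recompiling, which is the point of the property-driven lookup.
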
diff --git a/core/src/main/scala/spark/Stage.scala b/core/src/main/scala/spark/Stage.scala
new file mode 100644
index 0000000000..9452ea3a8e
--- /dev/null
+++ b/core/src/main/scala/spark/Stage.scala
@@ -0,0 +1,41 @@
+package spark
+
+class Stage(
+ val id: Int,
+ val rdd: RDD[_],
+ val shuffleDep: Option[ShuffleDependency[_,_,_]],
+ val parents: List[Stage]) {
+
+ val isShuffleMap = shuffleDep != None
+ val numPartitions = rdd.splits.size
+ val outputLocs = Array.fill[List[String]](numPartitions)(Nil)
+ var numAvailableOutputs = 0
+
+ def isAvailable: Boolean = {
+ if (parents.size == 0 && !isShuffleMap) {
+ true
+ } else {
+ numAvailableOutputs == numPartitions
+ }
+ }
+
+ def addOutputLoc(partition: Int, host: String) {
+ val prevList = outputLocs(partition)
+ outputLocs(partition) = host :: prevList
+ if (prevList == Nil)
+ numAvailableOutputs += 1
+ }
+
+ def removeOutputLoc(partition: Int, host: String) {
+ val prevList = outputLocs(partition)
+ val newList = prevList.filterNot(_ == host)
+ outputLocs(partition) = newList
+ if (prevList != Nil && newList == Nil) {
+ numAvailableOutputs -= 1
+ }
+ }
+
+ override def toString = "Stage " + id
+
+ override def hashCode(): Int = id
+}
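Stage tracks, for each partition, the list of hosts that still hold that partition's output; a partition counts toward numAvailableOutputs while at least one host remains, and the stage is available once every partition has a location. The standalone illustration below mirrors that bookkeeping outside of Spark (the fixed partition count and host names are just for the example).

object OutputLocSketch {
  val numPartitions = 3
  val outputLocs = Array.fill[List[String]](numPartitions)(Nil)
  var numAvailableOutputs = 0

  def addOutputLoc(partition: Int, host: String) {
    if (outputLocs(partition) == Nil) numAvailableOutputs += 1  // first copy of this output
    outputLocs(partition) = host :: outputLocs(partition)
  }

  def removeOutputLoc(partition: Int, host: String) {
    val prevList = outputLocs(partition)
    val newList = prevList.filterNot(_ == host)
    outputLocs(partition) = newList
    if (prevList != Nil && newList == Nil) numAvailableOutputs -= 1  // last copy lost
  }

  def main(args: Array[String]) {
    addOutputLoc(0, "hostA"); addOutputLoc(1, "hostA"); addOutputLoc(2, "hostB")
    println("stage available: " + (numAvailableOutputs == numPartitions)) // true
    removeOutputLoc(2, "hostB") // hostB fails
    println("stage available: " + (numAvailableOutputs == numPartitions)) // false
  }
}
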
diff --git a/core/src/main/scala/spark/Task.scala b/core/src/main/scala/spark/Task.scala
new file mode 100644
index 0000000000..bc3b374344
--- /dev/null
+++ b/core/src/main/scala/spark/Task.scala
@@ -0,0 +1,9 @@
+package spark
+
+class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable
+
+abstract class Task[T] extends Serializable {
+ def run(id: Int): T
+ def preferredLocations: Seq[String] = Nil
+ def generation: Option[Long] = None
+}
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
deleted file mode 100644
index 7a6214aab6..0000000000
--- a/core/src/main/scala/spark/TaskContext.scala
+++ /dev/null
@@ -1,3 +0,0 @@
-package spark
-
-class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable
diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala
deleted file mode 100644
index 6e4eb25ed4..0000000000
--- a/core/src/main/scala/spark/TaskEndReason.scala
+++ /dev/null
@@ -1,16 +0,0 @@
-package spark
-
-import spark.storage.BlockManagerId
-
-/**
- * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry
- * tasks several times for "ephemeral" failures, and only report back failures that require some
- * old stages to be resubmitted, such as shuffle map fetch failures.
- */
-sealed trait TaskEndReason
-
-case object Success extends TaskEndReason
-case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
-case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
-case class ExceptionFailure(exception: Throwable) extends TaskEndReason
-case class OtherFailure(message: String) extends TaskEndReason
diff --git a/core/src/main/scala/spark/TaskResult.scala b/core/src/main/scala/spark/TaskResult.scala
new file mode 100644
index 0000000000..2b7fd1a4b2
--- /dev/null
+++ b/core/src/main/scala/spark/TaskResult.scala
@@ -0,0 +1,8 @@
+package spark
+
+import scala.collection.mutable.Map
+
+// Task result. Also contains updates to accumulator variables.
+// TODO: Use of distributed cache to return result is a hack to get around
+// what seems to be a bug with messages over 60KB in libprocess; fix it
+private class TaskResult[T](val value: T, val accumUpdates: Map[Long, Any]) extends Serializable
diff --git a/core/src/main/scala/spark/UnionRDD.scala b/core/src/main/scala/spark/UnionRDD.scala
index 17522e2bbb..4c0f255e6b 100644
--- a/core/src/main/scala/spark/UnionRDD.scala
+++ b/core/src/main/scala/spark/UnionRDD.scala
@@ -33,8 +33,7 @@ class UnionRDD[T: ClassManifest](
override def splits = splits_
- @transient
- override val dependencies = {
+ @transient override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for ((rdd, index) <- rdds.zipWithIndex) {
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 89624eb370..68ccab24db 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -124,23 +124,6 @@ object Utils {
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
*/
def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress
-
- private var customHostname: Option[String] = None
-
- /**
- * Allow setting a custom host name because when we run on Mesos we need to use the same
- * hostname it reports to the master.
- */
- def setCustomHostname(hostname: String) {
- customHostname = Some(hostname)
- }
-
- /**
- * Get the local machine's hostname
- */
- def localHostName(): String = {
- customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
- }
/**
* Returns a standard ThreadFactory except all threads are daemons.
@@ -165,14 +148,6 @@ object Utils {
return threadPool
}
-
- /**
- * Return the string to tell how long has passed in seconds. The passing parameter should be in
- * millisecond.
- */
- def getUsedTimeMs(startTimeMs: Long): String = {
- return " " + (System.currentTimeMillis - startTimeMs) + " ms "
- }
/**
* Wrapper over newFixedThreadPool.
@@ -186,6 +161,16 @@ object Utils {
}
/**
+ * Get the local machine's hostname.
+ */
+ def localHostName(): String = InetAddress.getLocalHost.getHostName
+
+ /**
+   * Get the current host's name: the spark.hostname property if set, otherwise the local hostname.
+ */
+ def getHost = System.getProperty("spark.hostname", localHostName())
+
+ /**
* Delete a file or directory and its contents recursively.
*/
def deleteRecursively(file: File) {
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
deleted file mode 100644
index 4546dfa0fa..0000000000
--- a/core/src/main/scala/spark/network/Connection.scala
+++ /dev/null
@@ -1,364 +0,0 @@
-package spark.network
-
-import spark._
-
-import scala.collection.mutable.{HashMap, Queue, ArrayBuffer}
-
-import java.io._
-import java.nio._
-import java.nio.channels._
-import java.nio.channels.spi._
-import java.net._
-
-
-abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging {
-
- channel.configureBlocking(false)
- channel.socket.setTcpNoDelay(true)
- channel.socket.setReuseAddress(true)
- channel.socket.setKeepAlive(true)
- /*channel.socket.setReceiveBufferSize(32768) */
-
- var onCloseCallback: Connection => Unit = null
- var onExceptionCallback: (Connection, Exception) => Unit = null
- var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
-
- lazy val remoteAddress = getRemoteAddress()
- lazy val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
-
- def key() = channel.keyFor(selector)
-
- def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
-
- def read() {
- throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString)
- }
-
- def write() {
- throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString)
- }
-
- def close() {
- key.cancel()
- channel.close()
- callOnCloseCallback()
- }
-
- def onClose(callback: Connection => Unit) {onCloseCallback = callback}
-
- def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback}
-
- def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback}
-
- def callOnExceptionCallback(e: Exception) {
- if (onExceptionCallback != null) {
- onExceptionCallback(this, e)
- } else {
- logError("Error in connection to " + remoteConnectionManagerId +
- " and OnExceptionCallback not registered", e)
- }
- }
-
- def callOnCloseCallback() {
- if (onCloseCallback != null) {
- onCloseCallback(this)
- } else {
- logWarning("Connection to " + remoteConnectionManagerId +
- " closed and OnExceptionCallback not registered")
- }
-
- }
-
- def changeConnectionKeyInterest(ops: Int) {
- if (onKeyInterestChangeCallback != null) {
- onKeyInterestChangeCallback(this, ops)
- } else {
- throw new Exception("OnKeyInterestChangeCallback not registered")
- }
- }
-
- def printRemainingBuffer(buffer: ByteBuffer) {
- val bytes = new Array[Byte](buffer.remaining)
- val curPosition = buffer.position
- buffer.get(bytes)
- bytes.foreach(x => print(x + " "))
- buffer.position(curPosition)
- print(" (" + bytes.size + ")")
- }
-
- def printBuffer(buffer: ByteBuffer, position: Int, length: Int) {
- val bytes = new Array[Byte](length)
- val curPosition = buffer.position
- buffer.position(position)
- buffer.get(bytes)
- bytes.foreach(x => print(x + " "))
- print(" (" + position + ", " + length + ")")
- buffer.position(curPosition)
- }
-
-}
-
-
-class SendingConnection(val address: InetSocketAddress, selector_ : Selector)
-extends Connection(SocketChannel.open, selector_) {
-
- class Outbox(fair: Int = 0) {
- val messages = new Queue[Message]()
- val defaultChunkSize = 65536 //32768 //16384
- var nextMessageToBeUsed = 0
-
- def addMessage(message: Message): Unit = {
- messages.synchronized{
- /*messages += message*/
- messages.enqueue(message)
- logInfo("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]")
- }
- }
-
- def getChunk(): Option[MessageChunk] = {
- fair match {
- case 0 => getChunkFIFO()
- case 1 => getChunkRR()
- case _ => throw new Exception("Unexpected fairness policy in outbox")
- }
- }
-
- private def getChunkFIFO(): Option[MessageChunk] = {
- /*logInfo("Using FIFO")*/
- messages.synchronized {
- while (!messages.isEmpty) {
- val message = messages(0)
- val chunk = message.getChunkForSending(defaultChunkSize)
- if (chunk.isDefined) {
- messages += message // this is probably incorrect, it wont work as fifo
- if (!message.started) logDebug("Starting to send [" + message + "]")
- message.started = true
- return chunk
- }
- /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
- logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
- }
- }
- None
- }
-
- private def getChunkRR(): Option[MessageChunk] = {
- messages.synchronized {
- while (!messages.isEmpty) {
- /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
- /*val message = messages(nextMessageToBeUsed)*/
- val message = messages.dequeue
- val chunk = message.getChunkForSending(defaultChunkSize)
- if (chunk.isDefined) {
- messages.enqueue(message)
- nextMessageToBeUsed = nextMessageToBeUsed + 1
- if (!message.started) {
- logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]")
- message.started = true
- message.startTime = System.currentTimeMillis
- }
- logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
- return chunk
- }
- /*messages -= message*/
- message.finishTime = System.currentTimeMillis
- logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
- }
- }
- None
- }
- }
-
- val outbox = new Outbox(1)
- val currentBuffers = new ArrayBuffer[ByteBuffer]()
-
- /*channel.socket.setSendBufferSize(256 * 1024)*/
-
- override def getRemoteAddress() = address
-
- def send(message: Message) {
- outbox.synchronized {
- outbox.addMessage(message)
- if (channel.isConnected) {
- changeConnectionKeyInterest(SelectionKey.OP_WRITE)
- }
- }
- }
-
- def connect() {
- try{
- channel.connect(address)
- channel.register(selector, SelectionKey.OP_CONNECT)
- logInfo("Initiating connection to [" + address + "]")
- } catch {
- case e: Exception => {
- logError("Error connecting to " + address, e)
- callOnExceptionCallback(e)
- }
- }
- }
-
- def finishConnect() {
- try {
- channel.finishConnect
- changeConnectionKeyInterest(SelectionKey.OP_WRITE)
- logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
- } catch {
- case e: Exception => {
- logWarning("Error finishing connection to " + address, e)
- callOnExceptionCallback(e)
- }
- }
- }
-
- override def write() {
- try{
- while(true) {
- if (currentBuffers.size == 0) {
- outbox.synchronized {
- outbox.getChunk match {
- case Some(chunk) => {
- currentBuffers ++= chunk.buffers
- }
- case None => {
- changeConnectionKeyInterest(0)
- /*key.interestOps(0)*/
- return
- }
- }
- }
- }
-
- if (currentBuffers.size > 0) {
- val buffer = currentBuffers(0)
- val remainingBytes = buffer.remaining
- val writtenBytes = channel.write(buffer)
- if (buffer.remaining == 0) {
- currentBuffers -= buffer
- }
- if (writtenBytes < remainingBytes) {
- return
- }
- }
- }
- } catch {
- case e: Exception => {
- logWarning("Error writing in connection to " + remoteConnectionManagerId, e)
- callOnExceptionCallback(e)
- close()
- }
- }
- }
-}
-
-
-class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
-extends Connection(channel_, selector_) {
-
- class Inbox() {
- val messages = new HashMap[Int, BufferMessage]()
-
- def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
-
- def createNewMessage: BufferMessage = {
- val newMessage = Message.create(header).asInstanceOf[BufferMessage]
- newMessage.started = true
- newMessage.startTime = System.currentTimeMillis
- logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]")
- messages += ((newMessage.id, newMessage))
- newMessage
- }
-
- val message = messages.getOrElseUpdate(header.id, createNewMessage)
- logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]")
- message.getChunkForReceiving(header.chunkSize)
- }
-
- def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
- messages.get(chunk.header.id)
- }
-
- def removeMessage(message: Message) {
- messages -= message.id
- }
- }
-
- val inbox = new Inbox()
- val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
- var onReceiveCallback: (Connection , Message) => Unit = null
- var currentChunk: MessageChunk = null
-
- channel.register(selector, SelectionKey.OP_READ)
-
- override def read() {
- try {
- while (true) {
- if (currentChunk == null) {
- val headerBytesRead = channel.read(headerBuffer)
- if (headerBytesRead == -1) {
- close()
- return
- }
- if (headerBuffer.remaining > 0) {
- return
- }
- headerBuffer.flip
- if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
- throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
- }
- val header = MessageChunkHeader.create(headerBuffer)
- headerBuffer.clear()
- header.typ match {
- case Message.BUFFER_MESSAGE => {
- if (header.totalSize == 0) {
- if (onReceiveCallback != null) {
- onReceiveCallback(this, Message.create(header))
- }
- currentChunk = null
- return
- } else {
- currentChunk = inbox.getChunk(header).orNull
- }
- }
- case _ => throw new Exception("Message of unknown type received")
- }
- }
-
- if (currentChunk == null) throw new Exception("No message chunk to receive data")
-
- val bytesRead = channel.read(currentChunk.buffer)
- if (bytesRead == 0) {
- return
- } else if (bytesRead == -1) {
- close()
- return
- }
-
- /*logDebug("Read " + bytesRead + " bytes for the buffer")*/
-
- if (currentChunk.buffer.remaining == 0) {
- /*println("Filled buffer at " + System.currentTimeMillis)*/
- val bufferMessage = inbox.getMessageForChunk(currentChunk).get
- if (bufferMessage.isCompletelyReceived) {
- bufferMessage.flip
- bufferMessage.finishTime = System.currentTimeMillis
- logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken)
- if (onReceiveCallback != null) {
- onReceiveCallback(this, bufferMessage)
- }
- inbox.removeMessage(bufferMessage)
- }
- currentChunk = null
- }
- }
- } catch {
- case e: Exception => {
- logWarning("Error reading from connection to " + remoteConnectionManagerId, e)
- callOnExceptionCallback(e)
- close()
- }
- }
- }
-
- def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
-}
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
deleted file mode 100644
index 3222187990..0000000000
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ /dev/null
@@ -1,468 +0,0 @@
-package spark.network
-
-import spark._
-
-import scala.actors.Future
-import scala.actors.Futures.future
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.SynchronizedMap
-import scala.collection.mutable.SynchronizedQueue
-import scala.collection.mutable.Queue
-import scala.collection.mutable.ArrayBuffer
-
-import java.io._
-import java.nio._
-import java.nio.channels._
-import java.nio.channels.spi._
-import java.net._
-import java.util.concurrent.Executors
-
-case class ConnectionManagerId(val host: String, val port: Int) {
- def toSocketAddress() = new InetSocketAddress(host, port)
-}
-
-object ConnectionManagerId {
- def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
- new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
- }
-}
-
-class ConnectionManager(port: Int) extends Logging {
-
- case class MessageStatus(message: Message, connectionManagerId: ConnectionManagerId) {
- var ackMessage: Option[Message] = None
- var attempted = false
- var acked = false
- }
-
- val selector = SelectorProvider.provider.openSelector()
- val handleMessageExecutor = Executors.newFixedThreadPool(4)
- val serverChannel = ServerSocketChannel.open()
- val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
- val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
- val messageStatuses = new HashMap[Int, MessageStatus]
- val connectionRequests = new SynchronizedQueue[SendingConnection]
- val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
- val sendMessageRequests = new Queue[(Message, SendingConnection)]
-
- var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
-
- serverChannel.configureBlocking(false)
- serverChannel.socket.setReuseAddress(true)
- serverChannel.socket.setReceiveBufferSize(256 * 1024)
-
- serverChannel.socket.bind(new InetSocketAddress(port))
- serverChannel.register(selector, SelectionKey.OP_ACCEPT)
-
- val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
- logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
-
- val thisInstance = this
- val selectorThread = new Thread("connection-manager-thread") {
- override def run() {
- thisInstance.run()
- }
- }
- selectorThread.setDaemon(true)
- selectorThread.start()
-
- def run() {
- try {
- var interrupted = false
- while(!interrupted) {
- while(!connectionRequests.isEmpty) {
- val sendingConnection = connectionRequests.dequeue
- sendingConnection.connect()
- addConnection(sendingConnection)
- }
- sendMessageRequests.synchronized {
- while(!sendMessageRequests.isEmpty) {
- val (message, connection) = sendMessageRequests.dequeue
- connection.send(message)
- }
- }
-
- while(!keyInterestChangeRequests.isEmpty) {
- val (key, ops) = keyInterestChangeRequests.dequeue
- val connection = connectionsByKey(key)
- val lastOps = key.interestOps()
- key.interestOps(ops)
-
- def intToOpStr(op: Int): String = {
- val opStrs = ArrayBuffer[String]()
- if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
- if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
- if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
- if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
- if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
- }
-
- logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId +
- "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
-
- }
-
- val selectedKeysCount = selector.select()
- if (selectedKeysCount == 0) logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
-
- interrupted = selectorThread.isInterrupted
-
- val selectedKeys = selector.selectedKeys().iterator()
- while (selectedKeys.hasNext()) {
- val key = selectedKeys.next.asInstanceOf[SelectionKey]
- selectedKeys.remove()
- if (key.isValid) {
- if (key.isAcceptable) {
- acceptConnection(key)
- } else
- if (key.isConnectable) {
- connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect()
- } else
- if (key.isReadable) {
- connectionsByKey(key).read()
- } else
- if (key.isWritable) {
- connectionsByKey(key).write()
- }
- }
- }
- }
- } catch {
- case e: Exception => logError("Error in select loop", e)
- }
- }
-
- def acceptConnection(key: SelectionKey) {
- val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
- val newChannel = serverChannel.accept()
- val newConnection = new ReceivingConnection(newChannel, selector)
- newConnection.onReceive(receiveMessage)
- newConnection.onClose(removeConnection)
- addConnection(newConnection)
- logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
- }
-
- def addConnection(connection: Connection) {
- connectionsByKey += ((connection.key, connection))
- if (connection.isInstanceOf[SendingConnection]) {
- val sendingConnection = connection.asInstanceOf[SendingConnection]
- connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection))
- }
- connection.onKeyInterestChange(changeConnectionKeyInterest)
- connection.onException(handleConnectionError)
- connection.onClose(removeConnection)
- }
-
- def removeConnection(connection: Connection) {
- /*logInfo("Removing connection")*/
- connectionsByKey -= connection.key
- if (connection.isInstanceOf[SendingConnection]) {
- val sendingConnection = connection.asInstanceOf[SendingConnection]
- val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId
- logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
-
- connectionsById -= sendingConnectionManagerId
-
- messageStatuses.synchronized {
- messageStatuses
- .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
- logInfo("Notifying " + status)
- status.synchronized {
- status.attempted = true
- status.acked = false
- status.notifyAll
- }
- })
-
- messageStatuses.retain((i, status) => {
- status.connectionManagerId != sendingConnectionManagerId
- })
- }
- } else if (connection.isInstanceOf[ReceivingConnection]) {
- val receivingConnection = connection.asInstanceOf[ReceivingConnection]
- val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId
- logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
-
- val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull
- if (sendingConnectionManagerId == null) {
- logError("Corresponding SendingConnectionManagerId not found")
- return
- }
- logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId)
-
- val sendingConnection = connectionsById(sendingConnectionManagerId)
- sendingConnection.close()
- connectionsById -= sendingConnectionManagerId
-
- messageStatuses.synchronized {
- messageStatuses
- .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
- logInfo("Notifying " + status)
- status.synchronized {
- status.attempted = true
- status.acked = false
- status.notifyAll
- }
- })
-
- messageStatuses.retain((i, status) => {
- status.connectionManagerId != sendingConnectionManagerId
- })
- }
- }
- }
-
- def handleConnectionError(connection: Connection, e: Exception) {
- logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
- removeConnection(connection)
- }
-
- def changeConnectionKeyInterest(connection: Connection, ops: Int) {
- keyInterestChangeRequests += ((connection.key, ops))
- }
-
- def receiveMessage(connection: Connection, message: Message) {
- val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
- logInfo("Received [" + message + "] from [" + connectionManagerId + "]")
- val runnable = new Runnable() {
- val creationTime = System.currentTimeMillis
- def run() {
- logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
- handleMessage(connectionManagerId, message)
- logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
- }
- }
- handleMessageExecutor.execute(runnable)
- /*handleMessage(connection, message)*/
- }
-
- private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) {
- logInfo("Handling [" + message + "] from [" + connectionManagerId + "]")
- message match {
- case bufferMessage: BufferMessage => {
- if (bufferMessage.hasAckId) {
- val sentMessageStatus = messageStatuses.synchronized {
- messageStatuses.get(bufferMessage.ackId) match {
- case Some(status) => {
- messageStatuses -= bufferMessage.ackId
- status
- }
- case None => {
- throw new Exception("Could not find reference for received ack message " + message.id)
- null
- }
- }
- }
- sentMessageStatus.synchronized {
- sentMessageStatus.ackMessage = Some(message)
- sentMessageStatus.attempted = true
- sentMessageStatus.acked = true
- sentMessageStatus.notifyAll
- }
- } else {
- val ackMessage = if (onReceiveCallback != null) {
- logDebug("Calling back")
- onReceiveCallback(bufferMessage, connectionManagerId)
- } else {
- logWarning("Not calling back as callback is null")
- None
- }
-
- if (ackMessage.isDefined) {
- if (!ackMessage.get.isInstanceOf[BufferMessage]) {
- logWarning("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
- } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
- logWarning("Response to " + bufferMessage + " does not have ack id set")
- ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
- }
- }
-
- sendMessage(connectionManagerId, ackMessage.getOrElse {
- Message.createBufferMessage(bufferMessage.id)
- })
- }
- }
- case _ => throw new Exception("Unknown type message received")
- }
- }
-
- private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
- def startNewConnection(): SendingConnection = {
- val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
- val newConnection = new SendingConnection(inetSocketAddress, selector)
- connectionRequests += newConnection
- newConnection
- }
- val connection = connectionsById.getOrElse(connectionManagerId, startNewConnection)
- message.senderAddress = id.toSocketAddress()
- logInfo("Sending [" + message + "] to [" + connectionManagerId + "]")
- /*connection.send(message)*/
- sendMessageRequests.synchronized {
- sendMessageRequests += ((message, connection))
- }
- selector.wakeup()
- }
-
- def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message): Future[Option[Message]] = {
- val messageStatus = new MessageStatus(message, connectionManagerId)
- messageStatuses.synchronized {
- messageStatuses += ((message.id, messageStatus))
- }
- sendMessage(connectionManagerId, message)
- future {
- messageStatus.synchronized {
- if (!messageStatus.attempted) {
- logTrace("Waiting, " + messageStatuses.size + " statuses" )
- messageStatus.wait()
- logTrace("Done waiting")
- }
- }
- messageStatus.ackMessage
- }
- }
-
- def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = {
- sendMessageReliably(connectionManagerId, message)()
- }
-
- def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
- onReceiveCallback = callback
- }
-
- def stop() {
- if (!selectorThread.isAlive) {
- selectorThread.interrupt()
- selectorThread.join()
- selector.close()
- val connections = connectionsByKey.values
- connections.foreach(_.close())
- if (connectionsByKey.size != 0) {
- logWarning("All connections not cleaned up")
- }
- handleMessageExecutor.shutdown()
- logInfo("ConnectionManager stopped")
- }
- }
-}
-
-
-object ConnectionManager {
-
- def main(args: Array[String]) {
-
- val manager = new ConnectionManager(9999)
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- println("Received [" + msg + "] from [" + id + "]")
- None
- })
-
- /*testSequentialSending(manager)*/
- /*System.gc()*/
-
- /*testParallelSending(manager)*/
- /*System.gc()*/
-
- /*testParallelDecreasingSending(manager)*/
- /*System.gc()*/
-
- testContinuousSending(manager)
- System.gc()
- }
-
- def testSequentialSending(manager: ConnectionManager) {
- println("--------------------------")
- println("Sequential Sending")
- println("--------------------------")
- val size = 10 * 1024 * 1024
- val count = 10
-
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
-
- (0 until count).map(i => {
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliablySync(manager.id, bufferMessage)
- })
- println("--------------------------")
- println()
- }
-
- def testParallelSending(manager: ConnectionManager) {
- println("--------------------------")
- println("Parallel Sending")
- println("--------------------------")
- val size = 10 * 1024 * 1024
- val count = 10
-
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
-
- val startTime = System.currentTimeMillis
- (0 until count).map(i => {
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliably(manager.id, bufferMessage)
- }).foreach(f => {if (!f().isDefined) println("Failed")})
- val finishTime = System.currentTimeMillis
-
- val mb = size * count / 1024.0 / 1024.0
- val ms = finishTime - startTime
- val tput = mb * 1000.0 / ms
- println("--------------------------")
- println("Started at " + startTime + ", finished at " + finishTime)
- println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)")
- println("--------------------------")
- println()
- }
-
- def testParallelDecreasingSending(manager: ConnectionManager) {
- println("--------------------------")
- println("Parallel Decreasing Sending")
- println("--------------------------")
- val size = 10 * 1024 * 1024
- val count = 10
- val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte)))
- buffers.foreach(_.flip)
- val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0
-
- val startTime = System.currentTimeMillis
- (0 until count).map(i => {
- val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
- manager.sendMessageReliably(manager.id, bufferMessage)
- }).foreach(f => {if (!f().isDefined) println("Failed")})
- val finishTime = System.currentTimeMillis
-
- val ms = finishTime - startTime
- val tput = mb * 1000.0 / ms
- println("--------------------------")
- /*println("Started at " + startTime + ", finished at " + finishTime) */
- println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
- println("--------------------------")
- println()
- }
-
- def testContinuousSending(manager: ConnectionManager) {
- println("--------------------------")
- println("Continuous Sending")
- println("--------------------------")
- val size = 10 * 1024 * 1024
- val count = 10
-
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
-
- val startTime = System.currentTimeMillis
- while(true) {
- (0 until count).map(i => {
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliably(manager.id, bufferMessage)
- }).foreach(f => {if (!f().isDefined) println("Failed")})
- val finishTime = System.currentTimeMillis
- Thread.sleep(1000)
- val mb = size * count / 1024.0 / 1024.0
- val ms = finishTime - startTime
- val tput = mb * 1000.0 / ms
- println("--------------------------")
- println()
- }
- }
-}
diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
deleted file mode 100644
index 5d21bb793f..0000000000
--- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala
+++ /dev/null
@@ -1,74 +0,0 @@
-package spark.network
-
-import spark._
-import spark.SparkContext._
-
-import scala.io.Source
-
-import java.nio.ByteBuffer
-import java.net.InetAddress
-
-object ConnectionManagerTest extends Logging{
- def main(args: Array[String]) {
- if (args.length < 2) {
- println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>")
- System.exit(1)
- }
-
- if (args(0).startsWith("local")) {
- println("This runs only on a mesos cluster")
- }
-
- val sc = new SparkContext(args(0), "ConnectionManagerTest")
- val slavesFile = Source.fromFile(args(1))
- val slaves = slavesFile.mkString.split("\n")
- slavesFile.close()
-
- /*println("Slaves")*/
- /*slaves.foreach(println)*/
-
- val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map(
- i => SparkEnv.get.connectionManager.id).collect()
- println("\nSlave ConnectionManagerIds")
- slaveConnManagerIds.foreach(println)
- println
-
- val count = 10
- (0 until count).foreach(i => {
- val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => {
- val connManager = SparkEnv.get.connectionManager
- val thisConnManagerId = connManager.id
- connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- logInfo("Received [" + msg + "] from [" + id + "]")
- None
- })
-
- val size = 100 * 1024 * 1024
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
-
- val startTime = System.currentTimeMillis
- val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => {
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
- connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
- })
- val results = futures.map(f => f())
- val finishTime = System.currentTimeMillis
- Thread.sleep(5000)
-
- val mb = size * results.size / 1024.0 / 1024.0
- val ms = finishTime - startTime
- val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
- logInfo(resultStr)
- resultStr
- }).collect()
-
- println("---------------------")
- println("Run " + i)
- resultStrs.foreach(println)
- println("---------------------")
- })
- }
-}
-
diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala
deleted file mode 100644
index 2e85803679..0000000000
--- a/core/src/main/scala/spark/network/Message.scala
+++ /dev/null
@@ -1,219 +0,0 @@
-package spark.network
-
-import spark._
-
-import scala.collection.mutable.ArrayBuffer
-
-import java.nio.ByteBuffer
-import java.net.InetAddress
-import java.net.InetSocketAddress
-
-class MessageChunkHeader(
- val typ: Long,
- val id: Int,
- val totalSize: Int,
- val chunkSize: Int,
- val other: Int,
- val address: InetSocketAddress) {
- lazy val buffer = {
- val ip = address.getAddress.getAddress()
- val port = address.getPort()
- ByteBuffer.
- allocate(MessageChunkHeader.HEADER_SIZE).
- putLong(typ).
- putInt(id).
- putInt(totalSize).
- putInt(chunkSize).
- putInt(other).
- putInt(ip.size).
- put(ip).
- putInt(port).
- position(MessageChunkHeader.HEADER_SIZE).
- flip.asInstanceOf[ByteBuffer]
- }
-
- override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
- " and sizes " + totalSize + " / " + chunkSize + " bytes"
-}
-
-class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
- val size = if (buffer == null) 0 else buffer.remaining
- lazy val buffers = {
- val ab = new ArrayBuffer[ByteBuffer]()
- ab += header.buffer
- if (buffer != null) {
- ab += buffer
- }
- ab
- }
-
- override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
-}
-
-abstract class Message(val typ: Long, val id: Int) {
- var senderAddress: InetSocketAddress = null
- var started = false
- var startTime = -1L
- var finishTime = -1L
-
- def size: Int
-
- def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
-
- def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
-
- def timeTaken(): String = (finishTime - startTime).toString + " ms"
-
- override def toString = "" + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
-}
-
-class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
-extends Message(Message.BUFFER_MESSAGE, id_) {
-
- val initialSize = currentSize()
- var gotChunkForSendingOnce = false
-
- def size = initialSize
-
- def currentSize() = {
- if (buffers == null || buffers.isEmpty) {
- 0
- } else {
- buffers.map(_.remaining).reduceLeft(_ + _)
- }
- }
-
- def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
- if (maxChunkSize <= 0) {
- throw new Exception("Max chunk size is " + maxChunkSize)
- }
-
- if (size == 0 && gotChunkForSendingOnce == false) {
- val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
- gotChunkForSendingOnce = true
- return Some(newChunk)
- }
-
- while(!buffers.isEmpty) {
- val buffer = buffers(0)
- if (buffer.remaining == 0) {
- buffers -= buffer
- } else {
- val newBuffer = if (buffer.remaining <= maxChunkSize) {
- buffer.duplicate
- } else {
- buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
- }
- buffer.position(buffer.position + newBuffer.remaining)
- val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
- gotChunkForSendingOnce = true
- return Some(newChunk)
- }
- }
- None
- }
-
- def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
- // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
- if (buffers.size > 1) {
- throw new Exception("Attempting to get chunk from message with multiple data buffers")
- }
- val buffer = buffers(0)
- if (buffer.remaining > 0) {
- if (buffer.remaining < chunkSize) {
- throw new Exception("Not enough space in data buffer for receiving chunk")
- }
- val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
- buffer.position(buffer.position + newBuffer.remaining)
- val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
- return Some(newChunk)
- }
- None
- }
-
- def flip() {
- buffers.foreach(_.flip)
- }
-
- def hasAckId() = (ackId != 0)
-
- def isCompletelyReceived() = !buffers(0).hasRemaining
-
- override def toString = {
- if (hasAckId) {
- "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
- } else {
- "BufferMessage(id = " + id + ", size = " + size + ")"
- }
-
- }
-}
-
-object MessageChunkHeader {
- val HEADER_SIZE = 40
-
- def create(buffer: ByteBuffer): MessageChunkHeader = {
- if (buffer.remaining != HEADER_SIZE) {
- throw new IllegalArgumentException("Cannot convert buffer data to Message")
- }
- val typ = buffer.getLong()
- val id = buffer.getInt()
- val totalSize = buffer.getInt()
- val chunkSize = buffer.getInt()
- val other = buffer.getInt()
- val ipSize = buffer.getInt()
- val ipBytes = new Array[Byte](ipSize)
- buffer.get(ipBytes)
- val ip = InetAddress.getByAddress(ipBytes)
- val port = buffer.getInt()
- new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
- }
-}
-
-object Message {
- val BUFFER_MESSAGE = 1111111111L
-
- var lastId = 1
-
- def getNewId() = synchronized {
- lastId += 1
- if (lastId == 0) lastId += 1
- lastId
- }
-
- def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
- if (dataBuffers == null) {
- return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
- }
- if (dataBuffers.exists(_ == null)) {
- throw new Exception("Attempting to create buffer message with null buffer")
- }
- return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
- }
-
- def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
- createBufferMessage(dataBuffers, 0)
-
- def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
- if (dataBuffer == null) {
- return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
- } else {
- return createBufferMessage(Array(dataBuffer), ackId)
- }
- }
-
- def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
- createBufferMessage(dataBuffer, 0)
-
- def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId)
-
- def create(header: MessageChunkHeader): Message = {
- val newMessage: Message = header.typ match {
- case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
- }
- newMessage.senderAddress = header.address
- newMessage
- }
-}
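For reference, a self-contained sketch of the fixed 40-byte header layout that MessageChunkHeader.buffer and MessageChunkHeader.create above agree on: a long type tag, five ints, the raw IP bytes (4 for IPv4) preceded by a length int, then an int port, padded out to HEADER_SIZE. The object and method names below are illustrative, not part of the deleted code:

import java.net.InetSocketAddress
import java.nio.ByteBuffer

// Sketch of packing a chunk header into a fixed 40-byte buffer.
object HeaderPackingSketch {
  val HeaderSize = 40

  def pack(typ: Long, id: Int, totalSize: Int, chunkSize: Int, other: Int,
           address: InetSocketAddress): ByteBuffer = {
    val ip = address.getAddress.getAddress
    val buf = ByteBuffer.allocate(HeaderSize)
    buf.putLong(typ).putInt(id).putInt(totalSize).putInt(chunkSize).putInt(other)
    buf.putInt(ip.length).put(ip).putInt(address.getPort)
    buf.position(HeaderSize)   // pad out to the fixed header size
    buf.flip()
    buf
  }

  def main(args: Array[String]) {
    val header = pack(1111111111L, 42, 1024, 512, 0, new InetSocketAddress("127.0.0.1", 9999))
    println("Packed header of " + header.remaining + " bytes")   // 40
  }
}
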
diff --git a/core/src/main/scala/spark/network/ReceiverTest.scala b/core/src/main/scala/spark/network/ReceiverTest.scala
deleted file mode 100644
index e1ba7c06c0..0000000000
--- a/core/src/main/scala/spark/network/ReceiverTest.scala
+++ /dev/null
@@ -1,20 +0,0 @@
-package spark.network
-
-import java.nio.ByteBuffer
-import java.net.InetAddress
-
-object ReceiverTest {
-
- def main(args: Array[String]) {
- val manager = new ConnectionManager(9999)
- println("Started connection manager with id = " + manager.id)
-
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/
- val buffer = ByteBuffer.wrap("response".getBytes())
- Some(Message.createBufferMessage(buffer, msg.id))
- })
- Thread.currentThread.join()
- }
-}
-
diff --git a/core/src/main/scala/spark/network/SenderTest.scala b/core/src/main/scala/spark/network/SenderTest.scala
deleted file mode 100644
index 4ab6dd3414..0000000000
--- a/core/src/main/scala/spark/network/SenderTest.scala
+++ /dev/null
@@ -1,53 +0,0 @@
-package spark.network
-
-import java.nio.ByteBuffer
-import java.net.InetAddress
-
-object SenderTest {
-
- def main(args: Array[String]) {
-
- if (args.length < 2) {
- println("Usage: SenderTest <target host> <target port>")
- System.exit(1)
- }
-
- val targetHost = args(0)
- val targetPort = args(1).toInt
- val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort)
-
- val manager = new ConnectionManager(0)
- println("Started connection manager with id = " + manager.id)
-
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- println("Received [" + msg + "] from [" + id + "]")
- None
- })
-
- val size = 100 * 1024 * 1024
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
-
- val targetServer = args(0)
-
- val count = 100
- (0 until count).foreach(i => {
- val dataMessage = Message.createBufferMessage(buffer.duplicate)
- val startTime = System.currentTimeMillis
- /*println("Started timer at " + startTime)*/
- val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match {
- case Some(response) =>
- val buffer = response.asInstanceOf[BufferMessage].buffers(0)
- new String(buffer.array)
- case None => "none"
- }
- val finishTime = System.currentTimeMillis
- val mb = size / 1024.0 / 1024.0
- val ms = finishTime - startTime
- /*val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"*/
- val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" + (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr
- println(resultStr)
- })
- }
-}
-
diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
deleted file mode 100644
index 260547902b..0000000000
--- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala
+++ /dev/null
@@ -1,66 +0,0 @@
-package spark.partial
-
-import spark._
-import spark.scheduler.JobListener
-
-/**
- * A JobListener for an approximate single-result action, such as count() or non-parallel reduce().
- * This listener waits up to timeout milliseconds and will return a partial answer even if the
- * complete answer is not available by then.
- *
- * This class assumes that the action is performed on an entire RDD[T] via a function that computes
- * a result of type U for each partition, and that the action returns a partial or complete result
- * of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedDouble).
- */
-class ApproximateActionListener[T, U, R](
- rdd: RDD[T],
- func: (TaskContext, Iterator[T]) => U,
- evaluator: ApproximateEvaluator[U, R],
- timeout: Long)
- extends JobListener {
-
- val startTime = System.currentTimeMillis()
- val totalTasks = rdd.splits.size
- var finishedTasks = 0
- var failure: Option[Exception] = None // Set if the job has failed (permanently)
- var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult
-
- override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
- evaluator.merge(index, result.asInstanceOf[U])
- finishedTasks += 1
- if (finishedTasks == totalTasks) {
- // If we had already returned a PartialResult, set its final value
- resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
- // Notify any waiting thread that may have called getResult
- this.notifyAll()
- }
- }
-
- override def jobFailed(exception: Exception): Unit = synchronized {
- failure = Some(exception)
- this.notifyAll()
- }
-
- /**
- * Waits for up to timeout milliseconds since the listener was created and then returns a
- * PartialResult with the result so far. This may be complete if the whole job is done.
- */
- def getResult(): PartialResult[R] = synchronized {
- val finishTime = startTime + timeout
- while (true) {
- val time = System.currentTimeMillis()
- if (failure != None) {
- throw failure.get
- } else if (finishedTasks == totalTasks) {
- return new PartialResult(evaluator.currentResult(), true)
- } else if (time >= finishTime) {
- resultObject = Some(new PartialResult(evaluator.currentResult(), false))
- return resultObject.get
- } else {
- this.wait(finishTime - time)
- }
- }
- // Should never be reached, but required to keep the compiler happy
- return null
- }
-}
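getResult() above combines a deadline with Object.wait/notifyAll: it blocks until either the job finishes or the timeout elapses, then returns whatever is available. A minimal generic sketch of that deadline-wait pattern; the class below is illustrative and not from the Spark sources:

// Sketch of the "wait until done or deadline" pattern used by getResult() above.
class DeadlineWaiter {
  private var done = false

  def markDone(): Unit = synchronized {
    done = true
    this.notifyAll()
  }

  /** Returns true if markDone() was called before the deadline passed. */
  def awaitUntil(deadlineMillis: Long): Boolean = synchronized {
    var now = System.currentTimeMillis()
    while (!done && now < deadlineMillis) {
      this.wait(deadlineMillis - now)
      now = System.currentTimeMillis()
    }
    done
  }
}
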
diff --git a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala
deleted file mode 100644
index 4772e43ef0..0000000000
--- a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala
+++ /dev/null
@@ -1,10 +0,0 @@
-package spark.partial
-
-/**
- * An object that computes a function incrementally by merging in results of type U from multiple
- * tasks. Allows partial evaluation at any point by calling currentResult().
- */
-trait ApproximateEvaluator[U, R] {
- def merge(outputId: Int, taskResult: U): Unit
- def currentResult(): R
-}
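A minimal sketch of implementing the trait above, assuming the ApproximateEvaluator definition exactly as shown in this hunk; the class name is illustrative:

import spark.partial.ApproximateEvaluator

// An exact (error-bar-free) evaluator whose partial result is simply the
// running total of the task results merged so far.
class RunningSumEvaluator extends ApproximateEvaluator[Long, Long] {
  private var total = 0L
  override def merge(outputId: Int, taskResult: Long) { total += taskResult }
  override def currentResult(): Long = total
}
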
diff --git a/core/src/main/scala/spark/partial/BoundedDouble.scala b/core/src/main/scala/spark/partial/BoundedDouble.scala
deleted file mode 100644
index 463c33d6e2..0000000000
--- a/core/src/main/scala/spark/partial/BoundedDouble.scala
+++ /dev/null
@@ -1,8 +0,0 @@
-package spark.partial
-
-/**
- * A Double with error bars on it.
- */
-class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
- override def toString(): String = "[%.3f, %.3f]".format(low, high)
-}
diff --git a/core/src/main/scala/spark/partial/CountEvaluator.scala b/core/src/main/scala/spark/partial/CountEvaluator.scala
deleted file mode 100644
index 1bc90d6b39..0000000000
--- a/core/src/main/scala/spark/partial/CountEvaluator.scala
+++ /dev/null
@@ -1,38 +0,0 @@
-package spark.partial
-
-import cern.jet.stat.Probability
-
-/**
- * An ApproximateEvaluator for counts.
- *
- * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might
- * be best to make this a special case of GroupedCountEvaluator with one group.
- */
-class CountEvaluator(totalOutputs: Int, confidence: Double)
- extends ApproximateEvaluator[Long, BoundedDouble] {
-
- var outputsMerged = 0
- var sum: Long = 0
-
- override def merge(outputId: Int, taskResult: Long) {
- outputsMerged += 1
- sum += taskResult
- }
-
- override def currentResult(): BoundedDouble = {
- if (outputsMerged == totalOutputs) {
- new BoundedDouble(sum, 1.0, sum, sum)
- } else if (outputsMerged == 0) {
- new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
- } else {
- val p = outputsMerged.toDouble / totalOutputs
- val mean = (sum + 1 - p) / p
- val variance = (sum + 1) * (1 - p) / (p * p)
- val stdev = math.sqrt(variance)
- val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
- val low = mean - confFactor * stdev
- val high = mean + confFactor * stdev
- new BoundedDouble(mean, confidence, low, high)
- }
- }
-}
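The else-branch of currentResult() above extrapolates the final count from the fraction of partitions merged so far. A standalone sketch of that arithmetic, with the inverse-normal factor hard-coded to roughly 1.96 for 95% confidence instead of calling cern.jet.stat.Probability; the object and method names are illustrative:

// Extrapolated count and confidence bounds, following the formulas above.
object CountExtrapolation {
  def estimate(sum: Long, outputsMerged: Int, totalOutputs: Int): (Double, Double, Double) = {
    val p = outputsMerged.toDouble / totalOutputs       // fraction of partitions seen
    val mean = (sum + 1 - p) / p                        // extrapolated total count
    val variance = (sum + 1) * (1 - p) / (p * p)
    val stdev = math.sqrt(variance)
    val z = 1.96                                        // assumption: 95% confidence
    (mean, mean - z * stdev, mean + z * stdev)
  }
}
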
diff --git a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala
deleted file mode 100644
index 3e631c0efc..0000000000
--- a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala
+++ /dev/null
@@ -1,62 +0,0 @@
-package spark.partial
-
-import java.util.{HashMap => JHashMap}
-import java.util.{Map => JMap}
-
-import scala.collection.Map
-import scala.collection.mutable.HashMap
-import scala.collection.JavaConversions.mapAsScalaMap
-
-import cern.jet.stat.Probability
-
-import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
-
-/**
- * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval.
- */
-class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double)
- extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] {
-
- var outputsMerged = 0
- var sums = new OLMap[T] // Sum of counts for each key
-
- override def merge(outputId: Int, taskResult: OLMap[T]) {
- outputsMerged += 1
- val iter = taskResult.object2LongEntrySet.fastIterator()
- while (iter.hasNext) {
- val entry = iter.next()
- sums.put(entry.getKey, sums.getLong(entry.getKey) + entry.getLongValue)
- }
- }
-
- override def currentResult(): Map[T, BoundedDouble] = {
- if (outputsMerged == totalOutputs) {
- val result = new JHashMap[T, BoundedDouble](sums.size)
- val iter = sums.object2LongEntrySet.fastIterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val sum = entry.getLongValue()
- result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum)
- }
- result
- } else if (outputsMerged == 0) {
- new HashMap[T, BoundedDouble]
- } else {
- val p = outputsMerged.toDouble / totalOutputs
- val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
- val result = new JHashMap[T, BoundedDouble](sums.size)
- val iter = sums.object2LongEntrySet.fastIterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val sum = entry.getLongValue
- val mean = (sum + 1 - p) / p
- val variance = (sum + 1) * (1 - p) / (p * p)
- val stdev = math.sqrt(variance)
- val low = mean - confFactor * stdev
- val high = mean + confFactor * stdev
- result(entry.getKey) = new BoundedDouble(mean, confidence, low, high)
- }
- result
- }
- }
-}
diff --git a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala
deleted file mode 100644
index 2a9ccba205..0000000000
--- a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala
+++ /dev/null
@@ -1,65 +0,0 @@
-package spark.partial
-
-import java.util.{HashMap => JHashMap}
-import java.util.{Map => JMap}
-
-import scala.collection.mutable.HashMap
-import scala.collection.Map
-import scala.collection.JavaConversions.mapAsScalaMap
-
-import spark.util.StatCounter
-
-/**
- * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval.
- */
-class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double)
- extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
-
- var outputsMerged = 0
-  var sums = new JHashMap[T, StatCounter] // Running StatCounter for each key
-
- override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) {
- outputsMerged += 1
- val iter = taskResult.entrySet.iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val old = sums.get(entry.getKey)
- if (old != null) {
- old.merge(entry.getValue)
- } else {
- sums.put(entry.getKey, entry.getValue)
- }
- }
- }
-
- override def currentResult(): Map[T, BoundedDouble] = {
- if (outputsMerged == totalOutputs) {
- val result = new JHashMap[T, BoundedDouble](sums.size)
- val iter = sums.entrySet.iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val mean = entry.getValue.mean
- result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean)
- }
- result
- } else if (outputsMerged == 0) {
- new HashMap[T, BoundedDouble]
- } else {
- val p = outputsMerged.toDouble / totalOutputs
- val studentTCacher = new StudentTCacher(confidence)
- val result = new JHashMap[T, BoundedDouble](sums.size)
- val iter = sums.entrySet.iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val counter = entry.getValue
- val mean = counter.mean
- val stdev = math.sqrt(counter.sampleVariance / counter.count)
- val confFactor = studentTCacher.get(counter.count)
- val low = mean - confFactor * stdev
- val high = mean + confFactor * stdev
- result(entry.getKey) = new BoundedDouble(mean, confidence, low, high)
- }
- result
- }
- }
-}
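merge() above folds each task's per-key statistics into a running map, either merging with an existing entry or inserting a new one. A simplified sketch of that pattern with a tiny stand-in for spark.util.StatCounter; all names below are illustrative:

import scala.collection.mutable.HashMap

// Minimal stand-in for StatCounter: count and sum only.
class MiniStat(var count: Long = 0, var sum: Double = 0.0) {
  def merge(other: MiniStat): MiniStat = { count += other.count; sum += other.sum; this }
  def mean: Double = if (count == 0) 0.0 else sum / count
}

object MergeByKey {
  // Fold one task's per-key stats into the running totals.
  def mergeInto[K](sums: HashMap[K, MiniStat], taskResult: Map[K, MiniStat]): Unit = {
    for ((key, stat) <- taskResult) {
      sums.get(key) match {
        case Some(old) => old.merge(stat)   // key seen before: fold in the new stats
        case None      => sums(key) = stat  // first time we see this key
      }
    }
  }
}
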
diff --git a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala
deleted file mode 100644
index 6a2ec7a7bd..0000000000
--- a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala
+++ /dev/null
@@ -1,72 +0,0 @@
-package spark.partial
-
-import java.util.{HashMap => JHashMap}
-import java.util.{Map => JMap}
-
-import scala.collection.mutable.HashMap
-import scala.collection.Map
-import scala.collection.JavaConversions.mapAsScalaMap
-
-import spark.util.StatCounter
-
-/**
- * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval.
- */
-class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double)
- extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
-
- var outputsMerged = 0
-  var sums = new JHashMap[T, StatCounter] // Running StatCounter for each key
-
- override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) {
- outputsMerged += 1
- val iter = taskResult.entrySet.iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val old = sums.get(entry.getKey)
- if (old != null) {
- old.merge(entry.getValue)
- } else {
- sums.put(entry.getKey, entry.getValue)
- }
- }
- }
-
- override def currentResult(): Map[T, BoundedDouble] = {
- if (outputsMerged == totalOutputs) {
- val result = new JHashMap[T, BoundedDouble](sums.size)
- val iter = sums.entrySet.iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val sum = entry.getValue.sum
- result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum)
- }
- result
- } else if (outputsMerged == 0) {
- new HashMap[T, BoundedDouble]
- } else {
- val p = outputsMerged.toDouble / totalOutputs
- val studentTCacher = new StudentTCacher(confidence)
- val result = new JHashMap[T, BoundedDouble](sums.size)
- val iter = sums.entrySet.iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val counter = entry.getValue
- val meanEstimate = counter.mean
- val meanVar = counter.sampleVariance / counter.count
- val countEstimate = (counter.count + 1 - p) / p
- val countVar = (counter.count + 1) * (1 - p) / (p * p)
- val sumEstimate = meanEstimate * countEstimate
- val sumVar = (meanEstimate * meanEstimate * countVar) +
- (countEstimate * countEstimate * meanVar) +
- (meanVar * countVar)
- val sumStdev = math.sqrt(sumVar)
- val confFactor = studentTCacher.get(counter.count)
- val low = sumEstimate - confFactor * sumStdev
- val high = sumEstimate + confFactor * sumStdev
- result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high)
- }
- result
- }
- }
-}
diff --git a/core/src/main/scala/spark/partial/MeanEvaluator.scala b/core/src/main/scala/spark/partial/MeanEvaluator.scala
deleted file mode 100644
index b8c7cb8863..0000000000
--- a/core/src/main/scala/spark/partial/MeanEvaluator.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-package spark.partial
-
-import cern.jet.stat.Probability
-
-import spark.util.StatCounter
-
-/**
- * An ApproximateEvaluator for means.
- */
-class MeanEvaluator(totalOutputs: Int, confidence: Double)
- extends ApproximateEvaluator[StatCounter, BoundedDouble] {
-
- var outputsMerged = 0
- var counter = new StatCounter
-
- override def merge(outputId: Int, taskResult: StatCounter) {
- outputsMerged += 1
- counter.merge(taskResult)
- }
-
- override def currentResult(): BoundedDouble = {
- if (outputsMerged == totalOutputs) {
- new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean)
- } else if (outputsMerged == 0) {
- new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
- } else {
- val mean = counter.mean
- val stdev = math.sqrt(counter.sampleVariance / counter.count)
- val confFactor = {
- if (counter.count > 100) {
- Probability.normalInverse(1 - (1 - confidence) / 2)
- } else {
- Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
- }
- }
- val low = mean - confFactor * stdev
- val high = mean + confFactor * stdev
- new BoundedDouble(mean, confidence, low, high)
- }
- }
-}
diff --git a/core/src/main/scala/spark/partial/PartialResult.scala b/core/src/main/scala/spark/partial/PartialResult.scala
deleted file mode 100644
index 7095bc8ca1..0000000000
--- a/core/src/main/scala/spark/partial/PartialResult.scala
+++ /dev/null
@@ -1,86 +0,0 @@
-package spark.partial
-
-class PartialResult[R](initialVal: R, isFinal: Boolean) {
- private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None
- private var failure: Option[Exception] = None
- private var completionHandler: Option[R => Unit] = None
- private var failureHandler: Option[Exception => Unit] = None
-
- def initialValue: R = initialVal
-
- def isInitialValueFinal: Boolean = isFinal
-
- /**
- * Blocking method to wait for and return the final value.
- */
- def getFinalValue(): R = synchronized {
- while (finalValue == None && failure == None) {
- this.wait()
- }
- if (finalValue != None) {
- return finalValue.get
- } else {
- throw failure.get
- }
- }
-
- /**
- * Set a handler to be called when this PartialResult completes. Only one completion handler
- * is supported per PartialResult.
- */
- def onComplete(handler: R => Unit): PartialResult[R] = synchronized {
- if (completionHandler != None) {
- throw new UnsupportedOperationException("onComplete cannot be called twice")
- }
- completionHandler = Some(handler)
- if (finalValue != None) {
- // We already have a final value, so let's call the handler
- handler(finalValue.get)
- }
- return this
- }
-
- /**
- * Set a handler to be called if this PartialResult's job fails. Only one failure handler
- * is supported per PartialResult.
- */
- def onFail(handler: Exception => Unit): Unit = synchronized {
- if (failureHandler != None) {
- throw new UnsupportedOperationException("onFail cannot be called twice")
- }
- failureHandler = Some(handler)
- if (failure != None) {
- // We already have a failure, so let's call the handler
- handler(failure.get)
- }
- }
-
- private[spark] def setFinalValue(value: R): Unit = synchronized {
- if (finalValue != None) {
- throw new UnsupportedOperationException("setFinalValue called twice on a PartialResult")
- }
- finalValue = Some(value)
- // Call the completion handler if it was set
- completionHandler.foreach(h => h(value))
- // Notify any threads that may be calling getFinalValue()
- this.notifyAll()
- }
-
- private[spark] def setFailure(exception: Exception): Unit = synchronized {
- if (failure != None) {
- throw new UnsupportedOperationException("setFailure called twice on a PartialResult")
- }
- failure = Some(exception)
- // Call the failure handler if it was set
- failureHandler.foreach(h => h(exception))
- // Notify any threads that may be calling getFinalValue()
- this.notifyAll()
- }
-
- override def toString: String = synchronized {
- finalValue match {
- case Some(value) => "(final: " + value + ")"
- case None => "(partial: " + initialValue + ")"
- }
- }
-}
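A brief usage sketch of the PartialResult API above, assuming the class exactly as shown; the values are made up, and setFinalValue is private[spark], so it only appears in a comment:

object PartialResultUsage {
  def main(args: Array[String]) {
    val pr = new spark.partial.PartialResult[Double](42.0, false)   // partial initial value
    pr.onComplete(v => println("final value: " + v))
    // Elsewhere, the job would eventually call pr.setFinalValue(...);
    // getFinalValue() blocks until that happens (or the job fails):
    // val v = pr.getFinalValue()
    println("initial: " + pr.initialValue + ", final known? " + pr.isInitialValueFinal)
  }
}
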
diff --git a/core/src/main/scala/spark/partial/StudentTCacher.scala b/core/src/main/scala/spark/partial/StudentTCacher.scala
deleted file mode 100644
index 6263ee3518..0000000000
--- a/core/src/main/scala/spark/partial/StudentTCacher.scala
+++ /dev/null
@@ -1,26 +0,0 @@
-package spark.partial
-
-import cern.jet.stat.Probability
-
-/**
- * A utility class for caching Student's T distribution values for a given confidence level
- * and various sample sizes. This is used by the grouped mean and sum evaluators to efficiently
- * calculate confidence intervals for many keys.
- */
-class StudentTCacher(confidence: Double) {
- val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation
- val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2)
- val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)
-
- def get(sampleSize: Long): Double = {
- if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) {
- normalApprox
- } else {
- val size = sampleSize.toInt
- if (cache(size) < 0) {
- cache(size) = Probability.studentTInverse(1 - confidence, size - 1)
- }
- cache(size)
- }
- }
-}
diff --git a/core/src/main/scala/spark/partial/SumEvaluator.scala b/core/src/main/scala/spark/partial/SumEvaluator.scala
deleted file mode 100644
index 0357a6bff8..0000000000
--- a/core/src/main/scala/spark/partial/SumEvaluator.scala
+++ /dev/null
@@ -1,51 +0,0 @@
-package spark.partial
-
-import cern.jet.stat.Probability
-
-import spark.util.StatCounter
-
-/**
- * An ApproximateEvaluator for sums. It estimates the mean and the count and multiplies them
- * together, then uses the formula for the variance of two independent random variables to get
- * a variance for the result and compute a confidence interval.
- */
-class SumEvaluator(totalOutputs: Int, confidence: Double)
- extends ApproximateEvaluator[StatCounter, BoundedDouble] {
-
- var outputsMerged = 0
- var counter = new StatCounter
-
- override def merge(outputId: Int, taskResult: StatCounter) {
- outputsMerged += 1
- counter.merge(taskResult)
- }
-
- override def currentResult(): BoundedDouble = {
- if (outputsMerged == totalOutputs) {
- new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum)
- } else if (outputsMerged == 0) {
- new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
- } else {
- val p = outputsMerged.toDouble / totalOutputs
- val meanEstimate = counter.mean
- val meanVar = counter.sampleVariance / counter.count
- val countEstimate = (counter.count + 1 - p) / p
- val countVar = (counter.count + 1) * (1 - p) / (p * p)
- val sumEstimate = meanEstimate * countEstimate
- val sumVar = (meanEstimate * meanEstimate * countVar) +
- (countEstimate * countEstimate * meanVar) +
- (meanVar * countVar)
- val sumStdev = math.sqrt(sumVar)
- val confFactor = {
- if (counter.count > 100) {
- Probability.normalInverse(1 - (1 - confidence) / 2)
- } else {
- Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
- }
- }
- val low = sumEstimate - confFactor * sumStdev
- val high = sumEstimate + confFactor * sumStdev
- new BoundedDouble(sumEstimate, confidence, low, high)
- }
- }
-}
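The partial branch of currentResult() above treats the sum as mean x count, two independent estimates, and combines their variances with the product formula. A standalone sketch of that calculation, with the confidence factor hard-coded to roughly 1.96 (95%) instead of cern.jet.stat.Probability; names are illustrative:

// Extrapolated sum and confidence bounds, following the formulas above.
object SumExtrapolation {
  def estimate(mean: Double, meanVar: Double, count: Long, p: Double): (Double, Double, Double) = {
    val countEstimate = (count + 1 - p) / p
    val countVar = (count + 1) * (1 - p) / (p * p)
    val sumEstimate = mean * countEstimate
    val sumVar = (mean * mean * countVar) +
      (countEstimate * countEstimate * meanVar) +
      (meanVar * countVar)                      // variance of a product of independent estimates
    val stdev = math.sqrt(sumVar)
    val z = 1.96                                // assumption: 95% confidence
    (sumEstimate, sumEstimate - z * stdev, sumEstimate + z * stdev)
  }
}
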
diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala
deleted file mode 100644
index 0ecff9ce77..0000000000
--- a/core/src/main/scala/spark/scheduler/ActiveJob.scala
+++ /dev/null
@@ -1,18 +0,0 @@
-package spark.scheduler
-
-import spark.TaskContext
-
-/**
- * Tracks information about an active job in the DAGScheduler.
- */
-class ActiveJob(
- val runId: Int,
- val finalStage: Stage,
- val func: (TaskContext, Iterator[_]) => _,
- val partitions: Array[Int],
- val listener: JobListener) {
-
- val numPartitions = partitions.length
- val finished = Array.fill[Boolean](numPartitions)(false)
- var numFinished = 0
-}
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
deleted file mode 100644
index f9d53d3b5d..0000000000
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ /dev/null
@@ -1,535 +0,0 @@
-package spark.scheduler
-
-import java.net.URI
-import java.util.concurrent.atomic.AtomicInteger
-import java.util.concurrent.Future
-import java.util.concurrent.LinkedBlockingQueue
-import java.util.concurrent.TimeUnit
-
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
-
-import spark._
-import spark.partial.ApproximateActionListener
-import spark.partial.ApproximateEvaluator
-import spark.partial.PartialResult
-import spark.storage.BlockManagerMaster
-import spark.storage.BlockManagerId
-
-/**
- * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
- * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a
- * minimal schedule to run the job. It then submits stages as TaskSets to the underlying
- * TaskScheduler, which reports task completions, failures and lost hosts back through the
- * TaskSchedulerListener interface implemented here.
- */
-class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
- taskSched.setListener(this)
-
- // Called by TaskScheduler to report task completions or failures.
- override def taskEnded(
- task: Task[_],
- reason: TaskEndReason,
- result: Any,
- accumUpdates: Map[Long, Any]) {
- eventQueue.put(CompletionEvent(task, reason, result, accumUpdates))
- }
-
- // Called by TaskScheduler when a host fails.
- override def hostLost(host: String) {
- eventQueue.put(HostLost(host))
- }
-
- // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
- // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
- // as more failure events come in
- val RESUBMIT_TIMEOUT = 50L
-
- // The time, in millis, to wake up between polls of the completion queue in order to potentially
- // resubmit failed stages
- val POLL_TIMEOUT = 10L
-
- private val lock = new Object // Used for access to the entire DAGScheduler
-
- private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
-
- val nextRunId = new AtomicInteger(0)
-
- val nextStageId = new AtomicInteger(0)
-
- val idToStage = new HashMap[Int, Stage]
-
- val shuffleToMapStage = new HashMap[Int, Stage]
-
- var cacheLocs = new HashMap[Int, Array[List[String]]]
-
- val env = SparkEnv.get
- val cacheTracker = env.cacheTracker
- val mapOutputTracker = env.mapOutputTracker
-
- val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back;
- // that's not going to be a realistic assumption in general
-
- val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
- val running = new HashSet[Stage] // Stages we are running right now
- val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
- val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
- var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits
-
- val activeJobs = new HashSet[ActiveJob]
- val resultStageToJob = new HashMap[Stage, ActiveJob]
-
- // Start a thread to run the DAGScheduler event loop
- new Thread("DAGScheduler") {
- setDaemon(true)
- override def run() {
- DAGScheduler.this.run()
- }
- }.start()
-
- def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
- cacheLocs(rdd.id)
- }
-
- def updateCacheLocs() {
- cacheLocs = cacheTracker.getLocationsSnapshot()
- }
-
- /**
- * Get or create a shuffle map stage for the given shuffle dependency's map side.
- * The priority value passed in will be used if the stage doesn't already exist with
- * a lower priority (we assume that priorities always increase across jobs for now).
- */
- def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_,_], priority: Int): Stage = {
- shuffleToMapStage.get(shuffleDep.shuffleId) match {
- case Some(stage) => stage
- case None =>
- val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority)
- shuffleToMapStage(shuffleDep.shuffleId) = stage
- stage
- }
- }
-
- /**
- * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or
- * as a result stage for the final RDD used directly in an action. The stage will also be given
- * the provided priority.
- */
- def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]], priority: Int): Stage = {
- // Kind of ugly: need to register RDDs with the cache and map output tracker here
- // since we can't do it in the RDD constructor because # of splits is unknown
- logInfo("Registering RDD " + rdd.id + ": " + rdd)
- cacheTracker.registerRDD(rdd.id, rdd.splits.size)
- if (shuffleDep != None) {
- mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
- }
- val id = nextStageId.getAndIncrement()
- val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
- idToStage(id) = stage
- stage
- }
-
- /**
- * Get or create the list of parent stages for a given RDD. The stages will be assigned the
- * provided priority if they haven't already been created with a lower priority.
- */
- def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
- val parents = new HashSet[Stage]
- val visited = new HashSet[RDD[_]]
- def visit(r: RDD[_]) {
- if (!visited(r)) {
- visited += r
- // Kind of ugly: need to register RDDs with the cache here since
- // we can't do it in its constructor because # of splits is unknown
- logInfo("Registering parent RDD " + r.id + ": " + r)
- cacheTracker.registerRDD(r.id, r.splits.size)
- for (dep <- r.dependencies) {
- dep match {
- case shufDep: ShuffleDependency[_,_,_] =>
- parents += getShuffleMapStage(shufDep, priority)
- case _ =>
- visit(dep.rdd)
- }
- }
- }
- }
- visit(rdd)
- parents.toList
- }
-
- def getMissingParentStages(stage: Stage): List[Stage] = {
- val missing = new HashSet[Stage]
- val visited = new HashSet[RDD[_]]
- def visit(rdd: RDD[_]) {
- if (!visited(rdd)) {
- visited += rdd
- val locs = getCacheLocs(rdd)
- for (p <- 0 until rdd.splits.size) {
- if (locs(p) == Nil) {
- for (dep <- rdd.dependencies) {
- dep match {
- case shufDep: ShuffleDependency[_,_,_] =>
- val mapStage = getShuffleMapStage(shufDep, stage.priority)
- if (!mapStage.isAvailable) {
- missing += mapStage
- }
- case narrowDep: NarrowDependency[_] =>
- visit(narrowDep.rdd)
- }
- }
- }
- }
- }
- }
- visit(stage.rdd)
- missing.toList
- }
-
- def runJob[T, U](
- finalRdd: RDD[T],
- func: (TaskContext, Iterator[T]) => U,
- partitions: Seq[Int],
- allowLocal: Boolean)
- (implicit m: ClassManifest[U]): Array[U] =
- {
- if (partitions.size == 0) {
- return new Array[U](0)
- }
- val waiter = new JobWaiter(partitions.size)
- val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, waiter))
- waiter.getResult() match {
- case JobSucceeded(results: Seq[_]) =>
- return results.asInstanceOf[Seq[U]].toArray
- case JobFailed(exception: Exception) =>
- throw exception
- }
- }
-
- def runApproximateJob[T, U, R](
- rdd: RDD[T],
- func: (TaskContext, Iterator[T]) => U,
- evaluator: ApproximateEvaluator[U, R],
- timeout: Long
- ): PartialResult[R] =
- {
- val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
- val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- val partitions = (0 until rdd.splits.size).toArray
- eventQueue.put(JobSubmitted(rdd, func2, partitions, false, listener))
- return listener.getResult() // Will throw an exception if the job fails
- }
-
- /**
- * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
- * events and responds by launching tasks. This runs in a dedicated thread and receives events
- * via the eventQueue.
- */
- def run() = {
- SparkEnv.set(env)
-
- while (true) {
- val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
- val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
- if (event != null) {
- logDebug("Got event of type " + event.getClass.getName)
- }
-
- event match {
- case JobSubmitted(finalRDD, func, partitions, allowLocal, listener) =>
- val runId = nextRunId.getAndIncrement()
- val finalStage = newStage(finalRDD, None, runId)
- val job = new ActiveJob(runId, finalStage, func, partitions, listener)
- updateCacheLocs()
- logInfo("Got job " + job.runId + " with " + partitions.length + " output partitions")
- logInfo("Final stage: " + finalStage)
- logInfo("Parents of final stage: " + finalStage.parents)
- logInfo("Missing parents: " + getMissingParentStages(finalStage))
- if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
- // Compute very short actions like first() or take() with no parent stages locally.
- runLocally(job)
- } else {
- activeJobs += job
- resultStageToJob(finalStage) = job
- submitStage(finalStage)
- }
-
- case HostLost(host) =>
- handleHostLost(host)
-
- case completion: CompletionEvent =>
- handleTaskCompletion(completion)
-
- case null =>
- // queue.poll() timed out, ignore it
- }
-
- // Periodically resubmit failed stages if some map output fetches have failed and we have
- // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
- // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
- // the same time, so we want to make sure we've identified all the reduce tasks that depend
- // on the failed node.
- if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
- logInfo("Resubmitting failed stages")
- updateCacheLocs()
- val failed2 = failed.toArray
- failed.clear()
- for (stage <- failed2.sortBy(_.priority)) {
- submitStage(stage)
- }
- } else {
- // TODO: We might want to run this less often, when we are sure that something has become
- // runnable that wasn't before.
- logDebug("Checking for newly runnable parent stages")
- logDebug("running: " + running)
- logDebug("waiting: " + waiting)
- logDebug("failed: " + failed)
- val waiting2 = waiting.toArray
- waiting.clear()
- for (stage <- waiting2.sortBy(_.priority)) {
- submitStage(stage)
- }
- }
- }
- }
-
- /**
- * Run a job on an RDD locally, assuming it has only a single partition and no dependencies.
- * We run the operation in a separate thread just in case it takes a bunch of time, so that we
- * don't block the DAGScheduler event loop or other concurrent jobs.
- */
- def runLocally(job: ActiveJob) {
- logInfo("Computing the requested partition locally")
- new Thread("Local computation of job " + job.runId) {
- override def run() {
- try {
- SparkEnv.set(env)
- val rdd = job.finalStage.rdd
- val split = rdd.splits(job.partitions(0))
- val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
- val result = job.func(taskContext, rdd.iterator(split))
- job.listener.taskSucceeded(0, result)
- } catch {
- case e: Exception =>
- job.listener.jobFailed(e)
- }
- }
- }.start()
- }
-
- def submitStage(stage: Stage) {
- logDebug("submitStage(" + stage + ")")
- if (!waiting(stage) && !running(stage) && !failed(stage)) {
- val missing = getMissingParentStages(stage).sortBy(_.id)
- logDebug("missing: " + missing)
- if (missing == Nil) {
- logInfo("Submitting " + stage + ", which has no missing parents")
- submitMissingTasks(stage)
- running += stage
- } else {
- for (parent <- missing) {
- submitStage(parent)
- }
- waiting += stage
- }
- }
- }
-
- def submitMissingTasks(stage: Stage) {
- logDebug("submitMissingTasks(" + stage + ")")
- // Get our pending tasks and remember them in our pendingTasks entry
- val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
- myPending.clear()
- var tasks = ArrayBuffer[Task[_]]()
- if (stage.isShuffleMap) {
- for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
- val locs = getPreferredLocs(stage.rdd, p)
- tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
- }
- } else {
- // This is a final stage; figure out its job's missing partitions
- val job = resultStageToJob(stage)
- for (id <- 0 until job.numPartitions if (!job.finished(id))) {
- val partition = job.partitions(id)
- val locs = getPreferredLocs(stage.rdd, partition)
- tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
- }
- }
- if (tasks.size > 0) {
- logInfo("Submitting " + tasks.size + " missing tasks from " + stage)
- myPending ++= tasks
- logDebug("New pending tasks: " + myPending)
- taskSched.submitTasks(
- new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority))
- } else {
- logDebug("Stage " + stage + " is actually done; %b %d %d".format(
- stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
- running -= stage
- }
- }
-
- /**
- * Responds to a task finishing. This is called inside the event loop so it assumes that it can
- * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
- */
- def handleTaskCompletion(event: CompletionEvent) {
- val task = event.task
- val stage = idToStage(task.stageId)
- event.reason match {
- case Success =>
- logInfo("Completed " + task)
- if (event.accumUpdates != null) {
- Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
- }
- pendingTasks(stage) -= task
- task match {
- case rt: ResultTask[_, _] =>
- resultStageToJob.get(stage) match {
- case Some(job) =>
- if (!job.finished(rt.outputId)) {
- job.finished(rt.outputId) = true
- job.numFinished += 1
- job.listener.taskSucceeded(rt.outputId, event.result)
- // If the whole job has finished, remove it
- if (job.numFinished == job.numPartitions) {
- activeJobs -= job
- resultStageToJob -= stage
- running -= stage
- }
- }
- case None =>
- logInfo("Ignoring result from " + rt + " because its job has finished")
- }
-
- case smt: ShuffleMapTask =>
- val stage = idToStage(smt.stageId)
- val bmAddress = event.result.asInstanceOf[BlockManagerId]
- val host = bmAddress.ip
- logInfo("ShuffleMapTask finished with host " + host)
- if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos
- stage.addOutputLoc(smt.partition, bmAddress)
- }
- if (running.contains(stage) && pendingTasks(stage).isEmpty) {
- logInfo(stage + " finished; looking for newly runnable stages")
- running -= stage
- logInfo("running: " + running)
- logInfo("waiting: " + waiting)
- logInfo("failed: " + failed)
- if (stage.shuffleDep != None) {
- mapOutputTracker.registerMapOutputs(
- stage.shuffleDep.get.shuffleId,
- stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray)
- }
- updateCacheLocs()
- if (stage.outputLocs.count(_ == Nil) != 0) {
- // Some tasks had failed; let's resubmit this stage
- // TODO: Lower-level scheduler should also deal with this
- logInfo("Resubmitting " + stage + " because some of its tasks had failed: " +
- stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", "))
- submitStage(stage)
- } else {
- val newlyRunnable = new ArrayBuffer[Stage]
- for (stage <- waiting) {
- logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage))
- }
- for (stage <- waiting if getMissingParentStages(stage) == Nil) {
- newlyRunnable += stage
- }
- waiting --= newlyRunnable
- running ++= newlyRunnable
- for (stage <- newlyRunnable.sortBy(_.id)) {
- submitMissingTasks(stage)
- }
- }
- }
- }
-
- case Resubmitted =>
- logInfo("Resubmitted " + task + ", so marking it as still running")
- pendingTasks(stage) += task
-
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
- // Mark the stage that the reducer was in as unrunnable
- val failedStage = idToStage(task.stageId)
- running -= failedStage
- failed += failedStage
- // TODO: Cancel running tasks in the stage
-        logInfo("Marking " + failedStage + " for resubmission due to a fetch failure")
- // Mark the map whose fetch failed as broken in the map stage
- val mapStage = shuffleToMapStage(shuffleId)
- mapStage.removeOutputLoc(mapId, bmAddress)
- mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
- logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
- failed += mapStage
- // Remember that a fetch failed now; this is used to resubmit the broken
- // stages later, after a small wait (to give other tasks the chance to fail)
- lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
- // TODO: mark the host as failed only if there were lots of fetch failures on it
- if (bmAddress != null) {
- handleHostLost(bmAddress.ip)
- }
-
- case _ =>
- // Non-fetch failure -- probably a bug in the job, so bail out
- // TODO: Cancel all tasks that are still running
- resultStageToJob.get(stage) match {
- case Some(job) =>
- val error = new SparkException("Task failed: " + task + ", reason: " + event.reason)
- job.listener.jobFailed(error)
- activeJobs -= job
- resultStageToJob -= stage
- case None =>
- logInfo("Ignoring result from " + task + " because its job has finished")
- }
- }
- }
-
- /**
- * Responds to a host being lost. This is called inside the event loop so it assumes that it can
- * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside.
- */
- def handleHostLost(host: String) {
- if (!deadHosts.contains(host)) {
- logInfo("Host lost: " + host)
- deadHosts += host
- BlockManagerMaster.notifyADeadHost(host)
- // TODO: This will be really slow if we keep accumulating shuffle map stages
- for ((shuffleId, stage) <- shuffleToMapStage) {
- stage.removeOutputsOnHost(host)
- val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
- mapOutputTracker.registerMapOutputs(shuffleId, locs, true)
- }
- cacheTracker.cacheLost(host)
- updateCacheLocs()
- }
- }
-
- def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
- // If the partition is cached, return the cache locations
- val cached = getCacheLocs(rdd)(partition)
- if (cached != Nil) {
- return cached
- }
- // If the RDD has some placement preferences (as is the case for input RDDs), get those
- val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList
- if (rddPrefs != Nil) {
- return rddPrefs
- }
- // If the RDD has narrow dependencies, pick the first partition of the first narrow dep
- // that has any placement preferences. Ideally we would choose based on transfer sizes,
- // but this will do for now.
- rdd.dependencies.foreach(_ match {
- case n: NarrowDependency[_] =>
- for (inPart <- n.getParents(partition)) {
- val locs = getPreferredLocs(n.rdd, inPart)
- if (locs != Nil)
- return locs;
- }
- case _ =>
- })
- return Nil
- }
-
- def stop() {
- // TODO: Put a stop event on our queue and break the event loop
- taskSched.stop()
- }
-}
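A highly simplified sketch of the submitStage() recursion above: a stage is launched only once all of its parents are available; otherwise its parents are submitted first and the stage parks in the waiting set. Everything below (ToyStage included) is illustrative, not the deleted scheduler:

import scala.collection.mutable.HashSet

// Toy stand-in for Stage: an id, parent stages, and an availability flag.
class ToyStage(val id: Int, val parents: List[ToyStage]) {
  var available = false
}

object StageSubmission {
  val waiting = new HashSet[ToyStage]
  val running = new HashSet[ToyStage]

  def submitStage(stage: ToyStage) {
    if (!waiting(stage) && !running(stage)) {
      val missing = stage.parents.filterNot(_.available)
      if (missing.isEmpty) {
        println("Submitting stage " + stage.id)   // all parents available: launch it
        running += stage
      } else {
        missing.foreach(submitStage)              // submit parents first
        waiting += stage                          // and wait for them to finish
      }
    }
  }

  def main(args: Array[String]) {
    val parent = new ToyStage(0, Nil)
    val child = new ToyStage(1, List(parent))
    submitStage(child)   // submits stage 0 and parks stage 1 in waiting
  }
}
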
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
deleted file mode 100644
index c10abc9202..0000000000
--- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
+++ /dev/null
@@ -1,30 +0,0 @@
-package spark.scheduler
-
-import scala.collection.mutable.Map
-
-import spark._
-
-/**
- * Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue
- * architecture where any thread can post an event (e.g. a task finishing or a new job being
- * submitted) but there is a single "logic" thread that reads these events and makes the decisions.
- * This greatly simplifies synchronization.
- */
-sealed trait DAGSchedulerEvent
-
-case class JobSubmitted(
- finalRDD: RDD[_],
- func: (TaskContext, Iterator[_]) => _,
- partitions: Array[Int],
- allowLocal: Boolean,
- listener: JobListener)
- extends DAGSchedulerEvent
-
-case class CompletionEvent(
- task: Task[_],
- reason: TaskEndReason,
- result: Any,
- accumUpdates: Map[Long, Any])
- extends DAGSchedulerEvent
-
-case class HostLost(host: String) extends DAGSchedulerEvent
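The event loop in DAGScheduler.run() dispatches on these sealed events with a single pattern match. A compact sketch of that dispatch, assuming the case classes exactly as shown in this hunk; the surrounding object is illustrative:

import spark.scheduler._

// Sketch of how the single scheduler thread dispatches on the sealed events above.
object EventDispatchSketch {
  def dispatch(event: DAGSchedulerEvent): String = event match {
    case JobSubmitted(rdd, _, partitions, _, _) =>
      "new job over " + partitions.length + " partitions of RDD " + rdd.id
    case CompletionEvent(task, reason, _, _) =>
      "task " + task + " ended: " + reason
    case HostLost(host) =>
      "host lost: " + host
  }
}
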
diff --git a/core/src/main/scala/spark/scheduler/JobListener.scala b/core/src/main/scala/spark/scheduler/JobListener.scala
deleted file mode 100644
index d4dd536a7d..0000000000
--- a/core/src/main/scala/spark/scheduler/JobListener.scala
+++ /dev/null
@@ -1,11 +0,0 @@
-package spark.scheduler
-
-/**
- * Interface used to listen for job completion or failure events after submitting a job to the
- * DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole
- * job fails (and no further taskSucceeded events will happen).
- */
-trait JobListener {
- def taskSucceeded(index: Int, result: Any)
- def jobFailed(exception: Exception)
-}
diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala
deleted file mode 100644
index 62b458eccb..0000000000
--- a/core/src/main/scala/spark/scheduler/JobResult.scala
+++ /dev/null
@@ -1,9 +0,0 @@
-package spark.scheduler
-
-/**
- * A result of a job in the DAGScheduler.
- */
-sealed trait JobResult
-
-case class JobSucceeded(results: Seq[_]) extends JobResult
-case class JobFailed(exception: Exception) extends JobResult
diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala
deleted file mode 100644
index be8ec9bd7b..0000000000
--- a/core/src/main/scala/spark/scheduler/JobWaiter.scala
+++ /dev/null
@@ -1,43 +0,0 @@
-package spark.scheduler
-
-import scala.collection.mutable.ArrayBuffer
-
-/**
- * An object that waits for a DAGScheduler job to complete.
- */
-class JobWaiter(totalTasks: Int) extends JobListener {
- private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null)
- private var finishedTasks = 0
-
- private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
- private var jobResult: JobResult = null // If the job is finished, this will be its result
-
- override def taskSucceeded(index: Int, result: Any) = synchronized {
- if (jobFinished) {
- throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
- }
- taskResults(index) = result
- finishedTasks += 1
- if (finishedTasks == totalTasks) {
- jobFinished = true
- jobResult = JobSucceeded(taskResults)
- this.notifyAll()
- }
- }
-
- override def jobFailed(exception: Exception) = synchronized {
- if (jobFinished) {
- throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter")
- }
- jobFinished = true
- jobResult = JobFailed(exception)
- this.notifyAll()
- }
-
- def getResult(): JobResult = synchronized {
- while (!jobFinished) {
- this.wait()
- }
- return jobResult
- }
-}
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
deleted file mode 100644
index 79cca0f294..0000000000
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ /dev/null
@@ -1,142 +0,0 @@
-package spark.scheduler
-
-import java.io._
-import java.util.HashMap
-import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-
-import scala.collection.mutable.ArrayBuffer
-
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
-import com.ning.compress.lzf.LZFInputStream
-import com.ning.compress.lzf.LZFOutputStream
-
-import spark._
-import spark.storage._
-
-object ShuffleMapTask {
- val serializedInfoCache = new HashMap[Int, Array[Byte]]
- val deserializedInfoCache = new HashMap[Int, (RDD[_], ShuffleDependency[_,_,_])]
-
- def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = {
- synchronized {
- val old = serializedInfoCache.get(stageId)
- if (old != null) {
- return old
- } else {
- val out = new ByteArrayOutputStream
- val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
- objOut.writeObject(rdd)
- objOut.writeObject(dep)
- objOut.close()
- val bytes = out.toByteArray
- serializedInfoCache.put(stageId, bytes)
- return bytes
- }
- }
- }
-
- def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = {
- synchronized {
- val old = deserializedInfoCache.get(stageId)
- if (old != null) {
- return old
- } else {
- val loader = currentThread.getContextClassLoader
- val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
- val objIn = new ObjectInputStream(in) {
- override def resolveClass(desc: ObjectStreamClass) =
- Class.forName(desc.getName, false, loader)
- }
- val rdd = objIn.readObject().asInstanceOf[RDD[_]]
- val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]]
- val tuple = (rdd, dep)
- deserializedInfoCache.put(stageId, tuple)
- return tuple
- }
- }
- }
-
- def clearCache() {
- synchronized {
- serializedInfoCache.clear()
- deserializedInfoCache.clear()
- }
- }
-}
-
-class ShuffleMapTask(
- stageId: Int,
- var rdd: RDD[_],
- var dep: ShuffleDependency[_,_,_],
- var partition: Int,
- @transient var locs: Seq[String])
- extends Task[BlockManagerId](stageId)
- with Externalizable
- with Logging {
-
- def this() = this(0, null, null, 0, null)
-
- var split = if (rdd == null) {
- null
- } else {
- rdd.splits(partition)
- }
-
- override def writeExternal(out: ObjectOutput) {
- out.writeInt(stageId)
- val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
- out.writeInt(bytes.length)
- out.write(bytes)
- out.writeInt(partition)
- out.writeObject(split)
- }
-
- override def readExternal(in: ObjectInput) {
- val stageId = in.readInt()
- val numBytes = in.readInt()
- val bytes = new Array[Byte](numBytes)
- in.readFully(bytes)
- val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
- rdd = rdd_
- dep = dep_
- partition = in.readInt()
- split = in.readObject().asInstanceOf[Split]
- }
-
- override def run(attemptId: Int): BlockManagerId = {
- val numOutputSplits = dep.partitioner.numPartitions
- val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
- val partitioner = dep.partitioner.asInstanceOf[Partitioner]
- val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
- for (elem <- rdd.iterator(split)) {
- val (k, v) = elem.asInstanceOf[(Any, Any)]
- var bucketId = partitioner.getPartition(k)
- val bucket = buckets(bucketId)
- var existing = bucket.get(k)
- if (existing == null) {
- bucket.put(k, aggregator.createCombiner(v))
- } else {
- bucket.put(k, aggregator.mergeValue(existing, v))
- }
- }
- val ser = SparkEnv.get.serializer.newInstance()
- val blockManager = SparkEnv.get.blockManager
- for (i <- 0 until numOutputSplits) {
- val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i
- val arr = new ArrayBuffer[Any]
- val iter = buckets(i).entrySet().iterator()
- while (iter.hasNext()) {
- val entry = iter.next()
- arr += ((entry.getKey(), entry.getValue()))
- }
- // TODO: This should probably be DISK_ONLY
- blockManager.put(blockId, arr.iterator, StorageLevel.MEMORY_ONLY, false)
- }
- return SparkEnv.get.blockManager.blockManagerId
- }
-
- override def preferredLocations: Seq[String] = locs
-
- override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
-}
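The run() method above hashes each record into one in-memory bucket per output partition and combines values for the same key. A simplified, self-contained sketch of that bucketing step, in which the "aggregator" is just list-append and the partitioner is a plain hash; all names are illustrative:

import scala.collection.mutable.{ArrayBuffer, HashMap}

// Sketch of the shuffle-side bucketing loop: one map per output partition.
object ShuffleBucketingSketch {
  def bucketize(records: Iterator[(String, Int)], numOutputSplits: Int): Array[HashMap[String, ArrayBuffer[Int]]] = {
    val buckets = Array.fill(numOutputSplits)(new HashMap[String, ArrayBuffer[Int]])
    for ((k, v) <- records) {
      val bucketId = math.abs(k.hashCode) % numOutputSplits   // stand-in for the Partitioner
      val bucket = buckets(bucketId)
      bucket.getOrElseUpdate(k, new ArrayBuffer[Int]) += v    // "combine" = append the value
    }
    buckets
  }
}
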
diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala
deleted file mode 100644
index cd660c9085..0000000000
--- a/core/src/main/scala/spark/scheduler/Stage.scala
+++ /dev/null
@@ -1,86 +0,0 @@
-package spark.scheduler
-
-import java.net.URI
-
-import spark._
-import spark.storage.BlockManagerId
-
-/**
- * A stage is a set of independent tasks all computing the same function that need to run as part
- * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run
- * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the
- * DAGScheduler runs these stages in topological order.
- *
- * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for
- * another stage, or a result stage, in which case its tasks directly compute the action that
- * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes
- * that each output partition is on.
- *
- * Each Stage also has a priority, which is (by default) based on the job it was submitted in.
- * This allows Stages from earlier jobs to be computed first or recovered faster on failure.
- */
-class Stage(
- val id: Int,
- val rdd: RDD[_],
- val shuffleDep: Option[ShuffleDependency[_,_,_]], // Output shuffle if stage is a map stage
- val parents: List[Stage],
- val priority: Int)
- extends Logging {
-
- val isShuffleMap = shuffleDep != None
- val numPartitions = rdd.splits.size
- val outputLocs = Array.fill[List[BlockManagerId]](numPartitions)(Nil)
- var numAvailableOutputs = 0
-
- private var nextAttemptId = 0
-
- def isAvailable: Boolean = {
- if (/*parents.size == 0 &&*/ !isShuffleMap) {
- true
- } else {
- numAvailableOutputs == numPartitions
- }
- }
-
- def addOutputLoc(partition: Int, bmAddress: BlockManagerId) {
- val prevList = outputLocs(partition)
- outputLocs(partition) = bmAddress :: prevList
- if (prevList == Nil)
- numAvailableOutputs += 1
- }
-
- def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) {
- val prevList = outputLocs(partition)
- val newList = prevList.filterNot(_ == bmAddress)
- outputLocs(partition) = newList
- if (prevList != Nil && newList == Nil) {
- numAvailableOutputs -= 1
- }
- }
-
- def removeOutputsOnHost(host: String) {
- var becameUnavailable = false
- for (partition <- 0 until numPartitions) {
- val prevList = outputLocs(partition)
- val newList = prevList.filterNot(_.ip == host)
- outputLocs(partition) = newList
- if (prevList != Nil && newList == Nil) {
- becameUnavailable = true
- numAvailableOutputs -= 1
- }
- }
- if (becameUnavailable) {
- logInfo("%s is now unavailable on %s (%d/%d, %s)".format(this, host, numAvailableOutputs, numPartitions, isAvailable))
- }
- }
-
- def newAttemptId(): Int = {
- val id = nextAttemptId
- nextAttemptId += 1
- return id
- }
-
- override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]"
-
- override def hashCode(): Int = id
-}
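
The output-location bookkeeping above is what drives isAvailable: a shuffle map stage becomes available once every partition has at least one registered output location. A self-contained sketch of the same counting rule, using hypothetical host strings instead of BlockManagerId:

    // Availability means every partition has at least one registered output location.
    object StageAvailabilitySketch {
      def main(args: Array[String]) {
        val numPartitions = 3
        val outputLocs = Array.fill[List[String]](numPartitions)(Nil)
        var numAvailableOutputs = 0

        def addOutputLoc(partition: Int, host: String) {
          val prev = outputLocs(partition)
          outputLocs(partition) = host :: prev
          if (prev == Nil) numAvailableOutputs += 1
        }
        def isAvailable = numAvailableOutputs == numPartitions

        addOutputLoc(0, "host1"); addOutputLoc(1, "host2")
        println(isAvailable)        // false: partition 2 has no output yet
        addOutputLoc(2, "host1")
        println(isAvailable)        // true
      }
    }
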
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
deleted file mode 100644
index 42325956ba..0000000000
--- a/core/src/main/scala/spark/scheduler/Task.scala
+++ /dev/null
@@ -1,11 +0,0 @@
-package spark.scheduler
-
-/**
- * A task to execute on a worker node.
- */
-abstract class Task[T](val stageId: Int) extends Serializable {
- def run(attemptId: Int): T
- def preferredLocations: Seq[String] = Nil
-
- var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.
-}
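
A trivial concrete Task, shown only as a usage sketch and assuming the deleted spark.scheduler.Task class above is still on the classpath:

    import spark.scheduler.Task

    // Echoes its stage id and attempt id; illustrative only.
    class EchoTask(stageId: Int) extends Task[String](stageId) {
      def run(attemptId: Int): String = "stage " + stageId + ", attempt " + attemptId
    }

    object EchoTaskDemo {
      def main(args: Array[String]) {
        println(new EchoTask(7).run(0))   // stage 7, attempt 0
      }
    }
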
diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala
deleted file mode 100644
index 868ddb237c..0000000000
--- a/core/src/main/scala/spark/scheduler/TaskResult.scala
+++ /dev/null
@@ -1,34 +0,0 @@
-package spark.scheduler
-
-import java.io._
-
-import scala.collection.mutable.Map
-
-// Task result. Also contains updates to accumulator variables.
-// TODO: Use of distributed cache to return result is a hack to get around
-// what seems to be a bug with messages over 60KB in libprocess; fix it
-class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Externalizable {
- def this() = this(null.asInstanceOf[T], null)
-
- override def writeExternal(out: ObjectOutput) {
- out.writeObject(value)
- out.writeInt(accumUpdates.size)
- for ((key, value) <- accumUpdates) {
- out.writeLong(key)
- out.writeObject(value)
- }
- }
-
- override def readExternal(in: ObjectInput) {
- value = in.readObject().asInstanceOf[T]
- val numUpdates = in.readInt
- if (numUpdates == 0) {
- accumUpdates = null
- } else {
- accumUpdates = Map()
- for (i <- 0 until numUpdates) {
- accumUpdates(in.readLong()) = in.readObject()
- }
- }
- }
-}
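
A round-trip sketch of the Externalizable encoding above (value first, then the accumulator-update count and entries), assuming the deleted spark.scheduler.TaskResult class is on the classpath:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
    import scala.collection.mutable.Map
    import spark.scheduler.TaskResult

    object TaskResultRoundTrip {
      def main(args: Array[String]) {
        val original = new TaskResult[Int](42, Map(1L -> 10, 2L -> "ok"))
        val bytes = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(bytes)
        out.writeObject(original)              // exercises writeExternal
        out.close()
        val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
        val copy = in.readObject().asInstanceOf[TaskResult[Int]]   // exercises readExternal
        println(copy.value + " " + copy.accumUpdates)              // 42 Map(1 -> 10, 2 -> ok)
      }
    }
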
diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
deleted file mode 100644
index cb7c375d97..0000000000
--- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala
+++ /dev/null
@@ -1,27 +0,0 @@
-package spark.scheduler
-
-/**
- * Low-level task scheduler interface, implemented by both MesosScheduler and LocalScheduler.
- * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
- * and are responsible for sending the tasks to the cluster, running them, retrying if there
- * are failures, and mitigating stragglers. They return events to the DAGScheduler through
- * the TaskSchedulerListener interface.
- */
-trait TaskScheduler {
- def start(): Unit
-
- // Wait for registration with Mesos.
- def waitForRegister(): Unit
-
- // Disconnect from the cluster.
- def stop(): Unit
-
- // Submit a sequence of tasks to run.
- def submitTasks(taskSet: TaskSet): Unit
-
- // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
- def setListener(listener: TaskSchedulerListener): Unit
-
- // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
- def defaultParallelism(): Int
-}
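
A toy in-process implementation of this contract, useful only to illustrate the call order (setListener before submitTasks, results reported through taskEnded). It assumes spark.Success is the TaskEndReason for a completed task, as the deleted TaskEndReason.scala suggests:

    import scala.collection.mutable.Map
    import spark.Success
    import spark.scheduler.{TaskScheduler, TaskSchedulerListener, TaskSet}

    // Runs every task inline on the calling thread; no retries, no straggler handling.
    class InlineScheduler extends TaskScheduler {
      private var listener: TaskSchedulerListener = null

      def start() {}
      def waitForRegister() {}
      def stop() {}
      def setListener(l: TaskSchedulerListener) { listener = l }
      def defaultParallelism(): Int = 1

      def submitTasks(taskSet: TaskSet) {
        for ((task, index) <- taskSet.tasks.zipWithIndex) {
          val result = task.run(index)
          // The contract guarantees setListener was called before this point.
          listener.taskEnded(task, Success, result, Map[Long, Any]())
        }
      }
    }
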
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
deleted file mode 100644
index a647eec9e4..0000000000
--- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
+++ /dev/null
@@ -1,16 +0,0 @@
-package spark.scheduler
-
-import scala.collection.mutable.Map
-
-import spark.TaskEndReason
-
-/**
- * Interface for getting events back from the TaskScheduler.
- */
-trait TaskSchedulerListener {
- // A task has finished or failed.
- def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit
-
- // A node was lost from the cluster.
- def hostLost(host: String): Unit
-}
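
A minimal listener satisfying this interface, assuming the deleted traits above are on the classpath; it only prints the events it receives:

    import scala.collection.mutable.Map
    import spark.TaskEndReason
    import spark.scheduler.{Task, TaskSchedulerListener}

    class LoggingListener extends TaskSchedulerListener {
      def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]) {
        println("task of stage " + task.stageId + " ended: " + reason + ", result = " + result)
      }
      def hostLost(host: String) {
        println("lost host " + host)
      }
    }
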
diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala
deleted file mode 100644
index 6f29dd2e9d..0000000000
--- a/core/src/main/scala/spark/scheduler/TaskSet.scala
+++ /dev/null
@@ -1,9 +0,0 @@
-package spark.scheduler
-
-/**
- * A set of tasks submitted together to the low-level TaskScheduler, usually representing
- * missing partitions of a particular stage.
- */
-class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) {
- val id: String = stageId + "." + attempt
-}
diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala
deleted file mode 100644
index 8182901ce3..0000000000
--- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala
+++ /dev/null
@@ -1,364 +0,0 @@
-package spark.scheduler.mesos
-
-import java.io.{File, FileInputStream, FileOutputStream}
-import java.util.{ArrayList => JArrayList}
-import java.util.{List => JList}
-import java.util.{HashMap => JHashMap}
-import java.util.concurrent._
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.Map
-import scala.collection.mutable.PriorityQueue
-import scala.collection.JavaConversions._
-import scala.math.Ordering
-
-import akka.actor._
-import akka.actor.Actor
-import akka.actor.Actor._
-import akka.actor.Channel
-import akka.serialization.RemoteActorSerialization._
-
-import com.google.protobuf.ByteString
-
-import org.apache.mesos.{Scheduler => MScheduler}
-import org.apache.mesos._
-import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
-
-import spark._
-import spark.scheduler._
-
-sealed trait CoarseMesosSchedulerMessage
-case class RegisterSlave(slaveId: String, host: String, port: Int) extends CoarseMesosSchedulerMessage
-case class StatusUpdate(slaveId: String, status: TaskStatus) extends CoarseMesosSchedulerMessage
-case class LaunchTask(slaveId: String, task: MTaskInfo) extends CoarseMesosSchedulerMessage
-case class ReviveOffers() extends CoarseMesosSchedulerMessage
-
-case class FakeOffer(slaveId: String, host: String, cores: Int)
-
-/**
- * Mesos scheduler that uses coarse-grained tasks and does its own fine-grained scheduling inside
- * them using Akka actors for messaging. Clients should first call start(), then submit task sets
- * through the submitTasks method.
- *
- * TODO: This is a pretty big hack for now.
- */
-class CoarseMesosScheduler(
- sc: SparkContext,
- master: String,
- frameworkName: String)
- extends MesosScheduler(sc, master, frameworkName) {
-
- val CORES_PER_SLAVE = System.getProperty("spark.coarseMesosScheduler.coresPerSlave", "4").toInt
-
- class MasterActor extends Actor {
- val slaveActor = new HashMap[String, ActorRef]
- val slaveHost = new HashMap[String, String]
- val freeCores = new HashMap[String, Int]
-
- def receive = {
- case RegisterSlave(slaveId, host, port) =>
- slaveActor(slaveId) = remote.actorFor("WorkerActor", host, port)
- logInfo("Slave actor: " + slaveActor(slaveId))
- slaveHost(slaveId) = host
- freeCores(slaveId) = CORES_PER_SLAVE
- makeFakeOffers()
-
- case StatusUpdate(slaveId, status) =>
- fakeStatusUpdate(status)
- if (isFinished(status.getState)) {
- freeCores(slaveId) += 1
- makeFakeOffers(slaveId)
- }
-
- case LaunchTask(slaveId, task) =>
- freeCores(slaveId) -= 1
- slaveActor(slaveId) ! LaunchTask(slaveId, task)
-
- case ReviveOffers() =>
- logInfo("Reviving offers")
- makeFakeOffers()
- }
-
- // Make fake resource offers for all slaves
- def makeFakeOffers() {
- fakeResourceOffers(slaveHost.toSeq.map{case (id, host) => FakeOffer(id, host, freeCores(id))})
- }
-
- // Make a fake resource offer for a single slave
- def makeFakeOffers(slaveId: String) {
- fakeResourceOffers(Seq(FakeOffer(slaveId, slaveHost(slaveId), freeCores(slaveId))))
- }
- }
-
- val masterActor: ActorRef = actorOf(new MasterActor)
- remote.register("MasterActor", masterActor)
- masterActor.start()
-
- val taskIdsOnSlave = new HashMap[String, HashSet[String]]
-
- /**
- * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets
- * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
- * tasks are balanced across the cluster.
- */
- override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
- synchronized {
- val tasks = offers.map(o => new JArrayList[MTaskInfo])
- for (i <- 0 until offers.size) {
- val o = offers.get(i)
- val slaveId = o.getSlaveId.getValue
- if (!slaveIdToHost.contains(slaveId)) {
- slaveIdToHost(slaveId) = o.getHostname
- hostsAlive += o.getHostname
- taskIdsOnSlave(slaveId) = new HashSet[String]
- // Launch an infinite task on the node that will talk to the MasterActor to get fake tasks
- val cpuRes = Resource.newBuilder()
- .setName("cpus")
- .setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(1).build())
- .build()
- val task = new WorkerTask(slaveId, o.getHostname)
- val serializedTask = Utils.serialize(task)
- tasks(i).add(MTaskInfo.newBuilder()
- .setTaskId(newTaskId())
- .setSlaveId(o.getSlaveId)
- .setExecutor(executorInfo)
- .setName("worker task")
- .addResources(cpuRes)
- .setData(ByteString.copyFrom(serializedTask))
- .build())
- }
- }
- val filters = Filters.newBuilder().setRefuseSeconds(10).build()
- for (i <- 0 until offers.size) {
- d.launchTasks(offers(i).getId(), tasks(i), filters)
- }
- }
- }
-
- override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
- val tid = status.getTaskId.getValue
- var taskSetToUpdate: Option[TaskSetManager] = None
- var taskFailed = false
- synchronized {
- try {
- taskIdToTaskSetId.get(tid) match {
- case Some(taskSetId) =>
- if (activeTaskSets.contains(taskSetId)) {
- //activeTaskSets(taskSetId).statusUpdate(status)
- taskSetToUpdate = Some(activeTaskSets(taskSetId))
- }
- if (isFinished(status.getState)) {
- taskIdToTaskSetId.remove(tid)
- if (taskSetTaskIds.contains(taskSetId)) {
- taskSetTaskIds(taskSetId) -= tid
- }
- val slaveId = taskIdToSlaveId(tid)
- taskIdToSlaveId -= tid
- taskIdsOnSlave(slaveId) -= tid
- }
- if (status.getState == TaskState.TASK_FAILED) {
- taskFailed = true
- }
- case None =>
- logInfo("Ignoring update from TID " + tid + " because its task set is gone")
- }
- } catch {
- case e: Exception => logError("Exception in statusUpdate", e)
- }
- }
- // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
- if (taskSetToUpdate != None) {
- taskSetToUpdate.get.statusUpdate(status)
- }
- if (taskFailed) {
- // Revive offers if a task had failed for some reason other than host lost
- reviveOffers()
- }
- }
-
- override def slaveLost(d: SchedulerDriver, s: SlaveID) {
- logInfo("Slave lost: " + s.getValue)
- var failedHost: Option[String] = None
- var lostTids: Option[HashSet[String]] = None
- synchronized {
- val slaveId = s.getValue
- val host = slaveIdToHost(slaveId)
- if (hostsAlive.contains(host)) {
- slaveIdsWithExecutors -= slaveId
- hostsAlive -= host
- failedHost = Some(host)
- lostTids = Some(taskIdsOnSlave(slaveId))
- logInfo("failedHost: " + host)
- logInfo("lostTids: " + lostTids)
- taskIdsOnSlave -= slaveId
- activeTaskSetsQueue.foreach(_.hostLost(host))
- }
- }
- if (failedHost != None) {
- // Report all the tasks on the failed host as lost, without holding a lock on this
- for (tid <- lostTids.get; taskSetId <- taskIdToTaskSetId.get(tid)) {
- // TODO: Maybe call our statusUpdate() instead to clean our internal data structures
- activeTaskSets(taskSetId).statusUpdate(TaskStatus.newBuilder()
- .setTaskId(TaskID.newBuilder().setValue(tid).build())
- .setState(TaskState.TASK_LOST)
- .build())
- }
- // Also report the loss to the DAGScheduler
- listener.hostLost(failedHost.get)
- reviveOffers();
- }
- }
-
- override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
-
- // Check for speculatable tasks in all our active jobs.
- override def checkSpeculatableTasks() {
- var shouldRevive = false
- synchronized {
- for (ts <- activeTaskSetsQueue) {
- shouldRevive |= ts.checkSpeculatableTasks()
- }
- }
- if (shouldRevive) {
- reviveOffers()
- }
- }
-
-
- val lock2 = new Object
- var firstWait = true
-
- override def waitForRegister() {
- lock2.synchronized {
- if (firstWait) {
- super.waitForRegister()
- Thread.sleep(5000)
- firstWait = false
- }
- }
- }
-
- def fakeStatusUpdate(status: TaskStatus) {
- statusUpdate(driver, status)
- }
-
- def fakeResourceOffers(offers: Seq[FakeOffer]) {
- logDebug("fakeResourceOffers: " + offers)
- val availableCpus = offers.map(_.cores.toDouble).toArray
- var launchedTask = false
- for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
- do {
- launchedTask = false
- for (i <- 0 until offers.size if hostsAlive.contains(offers(i).host)) {
- manager.slaveOffer(offers(i).slaveId, offers(i).host, availableCpus(i)) match {
- case Some(task) =>
- val tid = task.getTaskId.getValue
- val sid = offers(i).slaveId
- taskIdToTaskSetId(tid) = manager.taskSet.id
- taskSetTaskIds(manager.taskSet.id) += tid
- taskIdToSlaveId(tid) = sid
- taskIdsOnSlave(sid) += tid
- slaveIdsWithExecutors += sid
- availableCpus(i) -= getResource(task.getResourcesList(), "cpus")
- launchedTask = true
- masterActor ! LaunchTask(sid, task)
-
- case None => {}
- }
- }
- } while (launchedTask)
- }
- }
-
- override def reviveOffers() {
- masterActor ! ReviveOffers()
- }
-}
-
-class WorkerTask(slaveId: String, host: String) extends Task[Unit](-1) {
- generation = 0
-
- def run(id: Int): Unit = {
- val actor = actorOf(new WorkerActor(slaveId, host))
- if (!remote.isRunning) {
- remote.start(Utils.localIpAddress, 7078)
- }
- remote.register("WorkerActor", actor)
- actor.start()
- while (true) {
- Thread.sleep(10000)
- }
- }
-}
-
-class WorkerActor(slaveId: String, host: String) extends Actor with Logging {
- val env = SparkEnv.get
- val classLoader = currentThread.getContextClassLoader
- val threadPool = new ThreadPoolExecutor(
- 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
-
- val masterIp: String = System.getProperty("spark.master.host", "localhost")
- val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
- val masterActor = remote.actorFor("MasterActor", masterIp, masterPort)
-
- class TaskRunner(desc: MTaskInfo)
- extends Runnable {
- override def run() = {
- val tid = desc.getTaskId.getValue
- logInfo("Running task ID " + tid)
- try {
- SparkEnv.set(env)
- Thread.currentThread.setContextClassLoader(classLoader)
- Accumulators.clear
- val task = Utils.deserialize[Task[Any]](desc.getData.toByteArray, classLoader)
- env.mapOutputTracker.updateGeneration(task.generation)
- val value = task.run(tid.toInt)
- val accumUpdates = Accumulators.values
- val result = new TaskResult(value, accumUpdates)
- masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
- .setTaskId(desc.getTaskId)
- .setState(TaskState.TASK_FINISHED)
- .setData(ByteString.copyFrom(Utils.serialize(result)))
- .build())
- logInfo("Finished task ID " + tid)
- } catch {
- case ffe: FetchFailedException => {
- val reason = ffe.toTaskEndReason
- masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
- .setTaskId(desc.getTaskId)
- .setState(TaskState.TASK_FAILED)
- .setData(ByteString.copyFrom(Utils.serialize(reason)))
- .build())
- }
- case t: Throwable => {
- val reason = ExceptionFailure(t)
- masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
- .setTaskId(desc.getTaskId)
- .setState(TaskState.TASK_FAILED)
- .setData(ByteString.copyFrom(Utils.serialize(reason)))
- .build())
-
- // TODO: Should we exit the whole executor here? On the one hand, the failed task may
- // have left some weird state around depending on when the exception was thrown, but on
- // the other hand, maybe we could detect that when future tasks fail and exit then.
- logError("Exception in task ID " + tid, t)
- //System.exit(1)
- }
- }
- }
- }
-
- override def preStart {
- val ref = toRemoteActorRefProtocol(self).toByteArray
- logInfo("Registering with master")
- masterActor ! RegisterSlave(slaveId, host, remote.address.getPort)
- }
-
- override def receive = {
- case LaunchTask(slaveId, task) =>
- threadPool.execute(new TaskRunner(task))
- }
-}
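
The core of fakeResourceOffers above is a round-robin loop: keep offering each host's free cores to the pending task sets until no further task can be launched. A standalone sketch of that loop, with a hypothetical PendingTasks queue standing in for a TaskSetManager:

    import scala.collection.mutable.Queue

    object FakeOfferSketch {
      // Illustrative stand-in for a TaskSetManager: a queue of unit-cost tasks.
      class PendingTasks(names: String*) {
        private val queue = Queue(names: _*)
        // Offer one core on `host`; return the launched task name, if any.
        def slaveOffer(host: String, freeCores: Double): Option[String] =
          if (freeCores >= 1 && queue.nonEmpty) Some(queue.dequeue()) else None
      }

      def main(args: Array[String]) {
        val hosts = Array("host1", "host2")
        val freeCores = Array(2.0, 1.0)
        val manager = new PendingTasks("t0", "t1", "t2", "t3")
        var launchedTask = false
        do {
          launchedTask = false
          for (i <- 0 until hosts.length) {
            manager.slaveOffer(hosts(i), freeCores(i)) match {
              case Some(name) =>
                freeCores(i) -= 1
                launchedTask = true
                println(name + " -> " + hosts(i))
              case None =>
            }
          }
        } while (launchedTask)   // t0->host1, t1->host2, t2->host1; t3 stays pending
      }
    }
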
diff --git a/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala b/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala
deleted file mode 100644
index af2f80ea66..0000000000
--- a/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-package spark.scheduler.mesos
-
-/**
- * Information about a running task attempt.
- */
-class TaskInfo(val taskId: String, val index: Int, val launchTime: Long, val host: String) {
- var finishTime: Long = 0
- var failed = false
-
- def markSuccessful(time: Long = System.currentTimeMillis) {
- finishTime = time
- }
-
- def markFailed(time: Long = System.currentTimeMillis) {
- finishTime = time
- failed = true
- }
-
- def finished: Boolean = finishTime != 0
-
- def successful: Boolean = finished && !failed
-
- def duration: Long = {
- if (!finished) {
- throw new UnsupportedOperationException("duration() called on unfinished tasks")
- } else {
- finishTime - launchTime
- }
- }
-
- def timeRunning(currentTime: Long): Long = currentTime - launchTime
-}
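
A short usage sketch, assuming the deleted TaskInfo class above is on the classpath:

    import spark.scheduler.mesos.TaskInfo

    object TaskInfoDemo {
      def main(args: Array[String]) {
        val launch = System.currentTimeMillis
        val info = new TaskInfo("tid-1", 0, launch, "host1")
        println(info.finished)                         // false while still running
        info.markSuccessful(launch + 250)
        println(info.successful + " " + info.duration) // true 250
      }
    }
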
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
deleted file mode 100644
index 9e4816f7ce..0000000000
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ /dev/null
@@ -1,588 +0,0 @@
-package spark.storage
-
-import java.io._
-import java.nio._
-import java.nio.channels.FileChannel.MapMode
-import java.util.{HashMap => JHashMap}
-import java.util.LinkedHashMap
-import java.util.UUID
-import java.util.Collections
-
-import scala.actors._
-import scala.actors.Actor._
-import scala.actors.Future
-import scala.actors.Futures.future
-import scala.actors.remote._
-import scala.actors.remote.RemoteActor._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.JavaConversions._
-
-import it.unimi.dsi.fastutil.io._
-
-import spark.CacheTracker
-import spark.Logging
-import spark.Serializer
-import spark.SizeEstimator
-import spark.SparkEnv
-import spark.SparkException
-import spark.Utils
-import spark.util.ByteBufferInputStream
-import spark.network._
-
-class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
- def this() = this(null, 0)
-
- override def writeExternal(out: ObjectOutput) {
- out.writeUTF(ip)
- out.writeInt(port)
- }
-
- override def readExternal(in: ObjectInput) {
- ip = in.readUTF()
- port = in.readInt()
- }
-
- override def toString = "BlockManagerId(" + ip + ", " + port + ")"
-
- override def hashCode = ip.hashCode * 41 + port
-
- override def equals(that: Any) = that match {
- case id: BlockManagerId => port == id.port && ip == id.ip
- case _ => false
- }
-}
-
-
-case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message)
-
-
-class BlockLocker(numLockers: Int) {
- private val hashLocker = Array.fill(numLockers)(new Object())
-
- def getLock(blockId: String): Object = {
- return hashLocker(Math.abs(blockId.hashCode % numLockers))
- }
-}
-
-
-
-class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging {
-
- case class BlockInfo(level: StorageLevel, tellMaster: Boolean)
-
- private val NUM_LOCKS = 337
- private val locker = new BlockLocker(NUM_LOCKS)
-
- private val blockInfo = Collections.synchronizedMap(new JHashMap[String, BlockInfo])
- private val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
- private val diskStore: BlockStore = new DiskStore(this,
- System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
-
- val connectionManager = new ConnectionManager(0)
-
- val connectionManagerId = connectionManager.id
- val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port)
-
- // TODO: This will be removed after cacheTracker is removed from the code base.
- var cacheTracker: CacheTracker = null
-
- initLogging()
-
- initialize()
-
- /**
- * Construct a BlockManager with a memory limit set based on system properties.
- */
- def this(serializer: Serializer) =
- this(BlockManager.getMaxMemoryFromSystemProperties(), serializer)
-
- /**
- * Initialize the BlockManager. Register with the BlockManagerMaster, and start the
- * BlockManagerWorker actor.
- */
- private def initialize() {
- BlockManagerMaster.mustRegisterBlockManager(
- RegisterBlockManager(blockManagerId, maxMemory, maxMemory))
- BlockManagerWorker.startBlockManagerWorker(this)
- }
-
- /**
- * Get storage level of local block. If no info exists for the block, then returns null.
- */
- def getLevel(blockId: String): StorageLevel = {
- val info = blockInfo.get(blockId)
- if (info != null) info.level else null
- }
-
- /**
- * Change storage level for a local block and tell the master if necessary.
- * If the new level is invalid, the block info (if it exists) will be silently removed.
- */
- def setLevel(blockId: String, level: StorageLevel, tellMaster: Boolean = true) {
- if (level == null) {
- throw new IllegalArgumentException("Storage level is null")
- }
-
- // If there was earlier info about the block, then use earlier tellMaster
- val oldInfo = blockInfo.get(blockId)
- val newTellMaster = if (oldInfo != null) oldInfo.tellMaster else tellMaster
- if (oldInfo != null && oldInfo.tellMaster != tellMaster) {
- logWarning("Ignoring tellMaster setting as it is different from earlier setting")
- }
-
- // If level is valid, store the block info, else remove the block info
- if (level.isValid) {
- blockInfo.put(blockId, new BlockInfo(level, newTellMaster))
- logDebug("Info for block " + blockId + " updated with new level as " + level)
- } else {
- blockInfo.remove(blockId)
- logDebug("Info for block " + blockId + " removed as new level is null or invalid")
- }
-
- // Tell master if necessary
- if (newTellMaster) {
- logDebug("Told master about block " + blockId)
- notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0))
- } else {
- logDebug("Did not tell master about block " + blockId)
- }
- }
-
- /**
- * Get locations of the block.
- */
- def getLocations(blockId: String): Seq[String] = {
- val startTimeMs = System.currentTimeMillis
- var managers: Array[BlockManagerId] = BlockManagerMaster.mustGetLocations(GetLocations(blockId))
- val locations = managers.map((manager: BlockManagerId) => { manager.ip }).toSeq
- logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs))
- return locations
- }
-
- /**
- * Get locations of an array of blocks.
- */
- def getLocations(blockIds: Array[String]): Array[Seq[String]] = {
- val startTimeMs = System.currentTimeMillis
- val locations = BlockManagerMaster.mustGetLocationsMultipleBlockIds(
- GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray
- logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
- return locations
- }
-
- /**
- * Get block from local block manager.
- */
- def getLocal(blockId: String): Option[Iterator[Any]] = {
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- logDebug("Getting local block " + blockId)
- locker.getLock(blockId).synchronized {
-
- // Check storage level of block
- val level = getLevel(blockId)
- if (level != null) {
- logDebug("Level for block " + blockId + " is " + level + " on local machine")
-
- // Look for the block in memory
- if (level.useMemory) {
- logDebug("Getting block " + blockId + " from memory")
- memoryStore.getValues(blockId) match {
- case Some(iterator) => {
- logDebug("Block " + blockId + " found in memory")
- return Some(iterator)
- }
- case None => {
- logDebug("Block " + blockId + " not found in memory")
- }
- }
- } else {
- logDebug("Not getting block " + blockId + " from memory")
- }
-
- // Look for block in disk
- if (level.useDisk) {
- logDebug("Getting block " + blockId + " from disk")
- diskStore.getValues(blockId) match {
- case Some(iterator) => {
- logDebug("Block " + blockId + " found in disk")
- return Some(iterator)
- }
- case None => {
- throw new Exception("Block " + blockId + " not found in disk")
- return None
- }
- }
- } else {
- logDebug("Not getting block " + blockId + " from disk")
- }
-
- } else {
- logDebug("Level for block " + blockId + " not found")
- }
- }
- return None
- }
-
- /**
- * Get block from remote block managers.
- */
- def getRemote(blockId: String): Option[Iterator[Any]] = {
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- logDebug("Getting remote block " + blockId)
- // Get locations of block
- val locations = BlockManagerMaster.mustGetLocations(GetLocations(blockId))
-
- // Get block from remote locations
- for (loc <- locations) {
- logDebug("Getting remote block " + blockId + " from " + loc)
- val data = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
- if (data != null) {
- logDebug("Data is not null: " + data)
- return Some(dataDeserialize(data))
- }
- logDebug("Data is null")
- }
- logDebug("Data not found")
- return None
- }
-
- /**
- * Get a block from the block manager (either local or remote).
- */
- def get(blockId: String): Option[Iterator[Any]] = {
- getLocal(blockId).orElse(getRemote(blockId))
- }
-
- /**
- * Get many blocks from local and remote block manager using their BlockManagerIds.
- */
- def get(blocksByAddress: Seq[(BlockManagerId, Seq[String])]): HashMap[String, Option[Iterator[Any]]] = {
- if (blocksByAddress == null) {
- throw new IllegalArgumentException("BlocksByAddress is null")
- }
- logDebug("Getting " + blocksByAddress.map(_._2.size).sum + " blocks")
- var startTime = System.currentTimeMillis
- val blocks = new HashMap[String,Option[Iterator[Any]]]()
- val localBlockIds = new ArrayBuffer[String]()
- val remoteBlockIds = new ArrayBuffer[String]()
- val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]()
-
- // Split local and remote blocks
- for ((address, blockIds) <- blocksByAddress) {
- if (address == blockManagerId) {
- localBlockIds ++= blockIds
- } else {
- remoteBlockIds ++= blockIds
- remoteBlockIdsPerLocation(address) = blockIds
- }
- }
-
- // Start getting remote blocks
- val remoteBlockFutures = remoteBlockIdsPerLocation.toSeq.map { case (bmId, bIds) =>
- val cmId = ConnectionManagerId(bmId.ip, bmId.port)
- val blockMessages = bIds.map(bId => BlockMessage.fromGetBlock(GetBlock(bId)))
- val blockMessageArray = new BlockMessageArray(blockMessages)
- val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
- (cmId, future)
- }
- logDebug("Started remote gets for " + remoteBlockIds.size + " blocks in " +
- Utils.getUsedTimeMs(startTime) + " ms")
-
- // Get the local blocks while remote blocks are being fetched
- startTime = System.currentTimeMillis
- localBlockIds.foreach(id => {
- get(id) match {
- case Some(block) => {
- blocks.update(id, Some(block))
- logDebug("Got local block " + id)
- }
- case None => {
- throw new BlockException(id, "Could not get block " + id + " from local machine")
- }
- }
- })
- logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
-
- // wait for and gather all the remote blocks
- for ((cmId, future) <- remoteBlockFutures) {
- var count = 0
- val oneBlockId = remoteBlockIdsPerLocation(new BlockManagerId(cmId.host, cmId.port)).first
- future() match {
- case Some(message) => {
- val bufferMessage = message.asInstanceOf[BufferMessage]
- val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
- blockMessageArray.foreach(blockMessage => {
- if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
- throw new BlockException(oneBlockId, "Unexpected message received from " + cmId)
- }
- val buffer = blockMessage.getData()
- val blockId = blockMessage.getId()
- val block = dataDeserialize(buffer)
- blocks.update(blockId, Some(block))
- logDebug("Got remote block " + blockId + " in " + Utils.getUsedTimeMs(startTime))
- count += 1
- })
- }
- case None => {
- throw new BlockException(oneBlockId, "Could not get blocks from " + cmId)
- }
- }
- logDebug("Got remote " + count + " blocks from " + cmId.host + " in " +
- Utils.getUsedTimeMs(startTime) + " ms")
- }
-
- logDebug("Got all blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
- return blocks
- }
-
- /**
- * Put a new block of values to the block manager.
- */
- def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean = true) {
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- if (values == null) {
- throw new IllegalArgumentException("Values is null")
- }
- if (level == null || !level.isValid) {
- throw new IllegalArgumentException("Storage level is null or invalid")
- }
-
- val startTimeMs = System.currentTimeMillis
- var bytes: ByteBuffer = null
-
- locker.getLock(blockId).synchronized {
- logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
- + " to get into synchronized block")
-
- // Check and warn if block with same id already exists
- if (getLevel(blockId) != null) {
- logWarning("Block " + blockId + " already exists in local machine")
- return
- }
-
- if (level.useMemory && level.useDisk) {
- // If saving to both memory and disk, then serialize only once
- memoryStore.putValues(blockId, values, level) match {
- case Left(newValues) =>
- diskStore.putValues(blockId, newValues, level) match {
- case Right(newBytes) => bytes = newBytes
- case _ => throw new Exception("Unexpected return value")
- }
- case Right(newBytes) =>
- bytes = newBytes
- diskStore.putBytes(blockId, newBytes, level)
- }
- } else if (level.useMemory) {
- // If only save to memory
- memoryStore.putValues(blockId, values, level) match {
- case Right(newBytes) => bytes = newBytes
- case _ =>
- }
- } else {
- // If only save to disk
- diskStore.putValues(blockId, values, level) match {
- case Right(newBytes) => bytes = newBytes
- case _ => throw new Exception("Unexpected return value")
- }
- }
-
- // Store the storage level
- setLevel(blockId, level, tellMaster)
- }
- logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
-
- // Replicate block if required
- if (level.replication > 1) {
- if (bytes == null) {
- bytes = dataSerialize(values) // serialize the block if not already done
- }
- replicate(blockId, bytes, level)
- }
-
- // TODO: This code will be removed when CacheTracker is gone.
- if (blockId.startsWith("rdd")) {
- notifyTheCacheTracker(blockId)
- }
- logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs))
- }
-
-
- /**
- * Put a new block of serialized bytes to the block manager.
- */
- def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- if (bytes == null) {
- throw new IllegalArgumentException("Bytes is null")
- }
- if (level == null || !level.isValid) {
- throw new IllegalArgumentException("Storage level is null or invalid")
- }
-
- val startTimeMs = System.currentTimeMillis
-
- // Initiate the replication before storing it locally. This is faster as
- // data is already serialized and ready for sending
- val replicationFuture = if (level.replication > 1) {
- future {
- replicate(blockId, bytes, level)
- }
- } else {
- null
- }
-
- locker.getLock(blockId).synchronized {
- logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
- + " to get into synchronized block")
- if (getLevel(blockId) != null) {
- logWarning("Block " + blockId + " already exists")
- return
- }
-
- if (level.useMemory) {
- memoryStore.putBytes(blockId, bytes, level)
- }
- if (level.useDisk) {
- diskStore.putBytes(blockId, bytes, level)
- }
-
- // Store the storage level
- setLevel(blockId, level, tellMaster)
- }
-
- // TODO: This code will be removed when CacheTracker is gone.
- if (blockId.startsWith("rdd")) {
- notifyTheCacheTracker(blockId)
- }
-
- // If replication had started, then wait for it to finish
- if (level.replication > 1) {
- if (replicationFuture == null) {
- throw new Exception("Unexpected")
- }
- replicationFuture()
- }
-
- val finishTime = System.currentTimeMillis
- if (level.replication > 1) {
- logDebug("PutBytes for block " + blockId + " with replication took " +
- Utils.getUsedTimeMs(startTimeMs))
- } else {
- logDebug("PutBytes for block " + blockId + " without replication took " +
- Utils.getUsedTimeMs(startTimeMs))
- }
- }
-
- /**
- * Replicate block to another node.
- */
-
- private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
- val tLevel: StorageLevel =
- new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
- var peers: Array[BlockManagerId] = BlockManagerMaster.mustGetPeers(
- GetPeers(blockManagerId, level.replication - 1))
- for (peer: BlockManagerId <- peers) {
- val start = System.nanoTime
- logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is "
- + data.array().length + " Bytes. To node: " + peer)
- if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel),
- new ConnectionManagerId(peer.ip, peer.port))) {
- logError("Failed to call syncPutBlock to " + peer)
- }
- logDebug("Replicated BlockId " + blockId + " once used " +
- (System.nanoTime - start) / 1e6 + " ms; The size of the data is " +
- data.array().length + " bytes.")
- }
- }
-
- // TODO: This code will be removed when CacheTracker is gone.
- private def notifyTheCacheTracker(key: String) {
- val rddInfo = key.split(":")
- val rddId: Int = rddInfo(1).toInt
- val splitIndex: Int = rddInfo(2).toInt
- val host = System.getProperty("spark.hostname", Utils.localHostName)
- cacheTracker.notifyTheCacheTrackerFromBlockManager(spark.AddedToCache(rddId, splitIndex, host))
- }
-
- /**
- * Read a block consisting of a single object.
- */
- def getSingle(blockId: String): Option[Any] = {
- get(blockId).map(_.next)
- }
-
- /**
- * Write a block consisting of a single object.
- */
- def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) {
- put(blockId, Iterator(value), level, tellMaster)
- }
-
- /**
- * Drop a block from memory (called when the memory store has reached its limit)
- */
- def dropFromMemory(blockId: String) {
- locker.getLock(blockId).synchronized {
- val level = getLevel(blockId)
- if (level == null) {
- logWarning("Block " + blockId + " cannot be removed from memory as it does not exist")
- return
- }
- if (!level.useMemory) {
- logWarning("Block " + blockId + " cannot be removed from memory as it is not in memory")
- return
- }
- memoryStore.remove(blockId)
- val newLevel = new StorageLevel(level.useDisk, false, level.deserialized, level.replication)
- setLevel(blockId, newLevel)
- }
- }
-
- def dataSerialize(values: Iterator[Any]): ByteBuffer = {
- /*serializer.newInstance().serializeMany(values)*/
- val byteStream = new FastByteArrayOutputStream(4096)
- serializer.newInstance().serializeStream(byteStream).writeAll(values).close()
- byteStream.trim()
- ByteBuffer.wrap(byteStream.array)
- }
-
- def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = {
- /*serializer.newInstance().deserializeMany(bytes)*/
- val ser = serializer.newInstance()
- bytes.rewind()
- return ser.deserializeStream(new ByteBufferInputStream(bytes)).toIterator
- }
-
- private def notifyMaster(heartBeat: HeartBeat) {
- BlockManagerMaster.mustHeartBeat(heartBeat)
- }
-
- def stop() {
- connectionManager.stop()
- blockInfo.clear()
- memoryStore.clear()
- diskStore.clear()
- logInfo("BlockManager stopped")
- }
-}
-
-
-object BlockManager extends Logging {
- def getMaxMemoryFromSystemProperties(): Long = {
- val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
- val bytes = (Runtime.getRuntime.totalMemory * memoryFraction).toLong
- logInfo("Maximum memory to use: " + bytes)
- bytes
- }
-}
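
put() and getLocal() above implement a simple two-tier policy: write to whichever stores the StorageLevel enables, and read from memory before falling back to disk. A standalone sketch of that policy, with a hypothetical Level case class and two in-memory maps standing in for StorageLevel, MemoryStore and DiskStore:

    import scala.collection.mutable.HashMap

    object TieredStoreSketch {
      case class Level(useMemory: Boolean, useDisk: Boolean)

      val memoryStore = new HashMap[String, Any]
      val diskStore = new HashMap[String, Any]

      def put(blockId: String, value: Any, level: Level) {
        if (level.useMemory) memoryStore(blockId) = value
        if (level.useDisk) diskStore(blockId) = value
      }

      // Look in memory first, then fall back to disk, mirroring getLocal() above.
      def get(blockId: String, level: Level): Option[Any] =
        (if (level.useMemory) memoryStore.get(blockId) else None)
          .orElse(if (level.useDisk) diskStore.get(blockId) else None)

      def main(args: Array[String]) {
        val memAndDisk = Level(useMemory = true, useDisk = true)
        put("rdd_0_0", "cached partition", memAndDisk)
        memoryStore.remove("rdd_0_0")        // simulate dropFromMemory
        println(get("rdd_0_0", memAndDisk))  // Some(cached partition), served from disk
      }
    }
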
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
deleted file mode 100644
index d8400a1f65..0000000000
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ /dev/null
@@ -1,517 +0,0 @@
-package spark.storage
-
-import java.io._
-import java.util.{HashMap => JHashMap}
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.util.Random
-
-import akka.actor._
-import akka.actor.Actor
-import akka.actor.Actor._
-import akka.util.duration._
-
-import spark.Logging
-import spark.Utils
-
-sealed trait ToBlockManagerMaster
-
-case class RegisterBlockManager(
- blockManagerId: BlockManagerId,
- maxMemSize: Long,
- maxDiskSize: Long)
- extends ToBlockManagerMaster
-
-class HeartBeat(
- var blockManagerId: BlockManagerId,
- var blockId: String,
- var storageLevel: StorageLevel,
- var deserializedSize: Long,
- var size: Long)
- extends ToBlockManagerMaster
- with Externalizable {
-
- def this() = this(null, null, null, 0, 0) // For deserialization only
-
- override def writeExternal(out: ObjectOutput) {
- blockManagerId.writeExternal(out)
- out.writeUTF(blockId)
- storageLevel.writeExternal(out)
- out.writeInt(deserializedSize.toInt)
- out.writeInt(size.toInt)
- }
-
- override def readExternal(in: ObjectInput) {
- blockManagerId = new BlockManagerId()
- blockManagerId.readExternal(in)
- blockId = in.readUTF()
- storageLevel = new StorageLevel()
- storageLevel.readExternal(in)
- deserializedSize = in.readInt()
- size = in.readInt()
- }
-}
-
-object HeartBeat {
- def apply(blockManagerId: BlockManagerId,
- blockId: String,
- storageLevel: StorageLevel,
- deserializedSize: Long,
- size: Long): HeartBeat = {
- new HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size)
- }
-
-
- // For pattern-matching
- def unapply(h: HeartBeat): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
- Some((h.blockManagerId, h.blockId, h.storageLevel, h.deserializedSize, h.size))
- }
-}
-
-case class GetLocations(
- blockId: String)
- extends ToBlockManagerMaster
-
-case class GetLocationsMultipleBlockIds(
- blockIds: Array[String])
- extends ToBlockManagerMaster
-
-case class GetPeers(
- blockManagerId: BlockManagerId,
- size: Int)
- extends ToBlockManagerMaster
-
-case class RemoveHost(
- host: String)
- extends ToBlockManagerMaster
-
-
-class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
-
- class BlockManagerInfo(
- timeMs: Long,
- maxMem: Long,
- maxDisk: Long) {
- private var lastSeenMs = timeMs
- private var remainedMem = maxMem
- private var remainedDisk = maxDisk
- private val blocks = new JHashMap[String, StorageLevel]
-
- def updateLastSeenMs() {
- lastSeenMs = System.currentTimeMillis() / 1000
- }
-
- def addBlock(blockId: String, storageLevel: StorageLevel, deserializedSize: Long, size: Long) =
- synchronized {
- updateLastSeenMs()
-
- if (blocks.containsKey(blockId)) {
- val oriLevel: StorageLevel = blocks.get(blockId)
-
- if (oriLevel.deserialized) {
- remainedMem += deserializedSize
- }
- if (oriLevel.useMemory) {
- remainedMem += size
- }
- if (oriLevel.useDisk) {
- remainedDisk += size
- }
- }
-
- if (storageLevel.isValid) {
- blocks.put(blockId, storageLevel)
- if (storageLevel.deserialized) {
- remainedMem -= deserializedSize
- }
- if (storageLevel.useMemory) {
- remainedMem -= size
- }
- if (storageLevel.useDisk) {
- remainedDisk -= size
- }
- } else {
- blocks.remove(blockId)
- }
- }
-
- def getLastSeenMs(): Long = {
- return lastSeenMs
- }
-
- def getRemainedMem(): Long = {
- return remainedMem
- }
-
- def getRemainedDisk(): Long = {
- return remainedDisk
- }
-
- override def toString(): String = {
- return "BlockManagerInfo " + timeMs + " " + remainedMem + " " + remainedDisk
- }
-
- def clear() {
- blocks.clear()
- }
- }
-
- private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerInfo]
- private val blockInfo = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]]
-
- initLogging()
-
- def removeHost(host: String) {
- logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.")
- logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq)
- val ip = host.split(":")(0)
- val port = host.split(":")(1)
- blockManagerInfo.remove(new BlockManagerId(ip, port.toInt))
- logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq)
- self.reply(true)
- }
-
- def receive = {
- case RegisterBlockManager(blockManagerId, maxMemSize, maxDiskSize) =>
- register(blockManagerId, maxMemSize, maxDiskSize)
-
- case HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
- heartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size)
-
- case GetLocations(blockId) =>
- getLocations(blockId)
-
- case GetLocationsMultipleBlockIds(blockIds) =>
- getLocationsMultipleBlockIds(blockIds)
-
- case GetPeers(blockManagerId, size) =>
- getPeers_Deterministic(blockManagerId, size)
- /*getPeers(blockManagerId, size)*/
-
- case RemoveHost(host) =>
- removeHost(host)
-
- case msg =>
- logInfo("Got unknown msg: " + msg)
- }
-
- private def register(blockManagerId: BlockManagerId, maxMemSize: Long, maxDiskSize: Long) {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockManagerId + " "
- logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
- logInfo("Got Register Msg from " + blockManagerId)
- if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
- logInfo("Got Register Msg from master node, don't register it")
- } else {
- blockManagerInfo += (blockManagerId -> new BlockManagerInfo(
- System.currentTimeMillis() / 1000, maxMemSize, maxDiskSize))
- }
- logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs))
- self.reply(true)
- }
-
- private def heartBeat(
- blockManagerId: BlockManagerId,
- blockId: String,
- storageLevel: StorageLevel,
- deserializedSize: Long,
- size: Long) {
-
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockManagerId + " " + blockId + " "
-
- if (blockId == null) {
- blockManagerInfo(blockManagerId).updateLastSeenMs()
- logDebug("Got in heartBeat 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
- self.reply(true)
- }
-
- blockManagerInfo(blockManagerId).addBlock(blockId, storageLevel, deserializedSize, size)
-
- var locations: HashSet[BlockManagerId] = null
- if (blockInfo.containsKey(blockId)) {
- locations = blockInfo.get(blockId)._2
- } else {
- locations = new HashSet[BlockManagerId]
- blockInfo.put(blockId, (storageLevel.replication, locations))
- }
-
- if (storageLevel.isValid) {
- locations += blockManagerId
- } else {
- locations.remove(blockManagerId)
- }
-
- if (locations.size == 0) {
- blockInfo.remove(blockId)
- }
- self.reply(true)
- }
-
- private def getLocations(blockId: String) {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockId + " "
- logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
- if (blockInfo.containsKey(blockId)) {
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(blockInfo.get(blockId)._2)
- logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at "
- + Utils.getUsedTimeMs(startTimeMs))
- self.reply(res.toSeq)
- } else {
- logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs))
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- self.reply(res)
- }
- }
-
- private def getLocationsMultipleBlockIds(blockIds: Array[String]) {
- def getLocations(blockId: String): Seq[BlockManagerId] = {
- val tmp = blockId
- logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp)
- if (blockInfo.containsKey(blockId)) {
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(blockInfo.get(blockId)._2)
- logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq)
- return res.toSeq
- } else {
- logDebug("Got in getLocationsMultipleBlockIds Sub 2 " + tmp)
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- return res.toSeq
- }
- }
-
- logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq)
- var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]]
- for (blockId <- blockIds) {
- res.append(getLocations(blockId))
- }
- logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq)
- self.reply(res.toSeq)
- }
-
- private def getPeers(blockManagerId: BlockManagerId, size: Int) {
- var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(peers)
- res -= blockManagerId
- val rand = new Random(System.currentTimeMillis())
- while (res.length > size) {
- res.remove(rand.nextInt(res.length))
- }
- self.reply(res.toSeq)
- }
-
- private def getPeers_Deterministic(blockManagerId: BlockManagerId, size: Int) {
- var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
-
- val peersWithIndices = peers.zipWithIndex
- val selfIndex = peersWithIndices.find(_._1 == blockManagerId).map(_._2).getOrElse(-1)
- if (selfIndex == -1) {
- throw new Exception("Self index for " + blockManagerId + " not found")
- }
-
- var index = selfIndex
- while (res.size < size) {
- index += 1
- if (index == selfIndex) {
- throw new Exception("More peer expected than available")
- }
- res += peers(index % peers.size)
- }
- val resStr = res.map(_.toString).reduceLeft(_ + ", " + _)
- self.reply(res.toSeq)
- }
-}
-
-object BlockManagerMaster extends Logging {
- initLogging()
-
- val AKKA_ACTOR_NAME: String = "BlockMasterManager"
- val REQUEST_RETRY_INTERVAL_MS = 100
- val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost")
- val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt
- val DEFAULT_MANAGER_IP: String = Utils.localHostName()
- val DEFAULT_MANAGER_PORT: String = "10902"
-
- implicit val TIME_OUT_SEC = Actor.Timeout(3000 millis)
- var masterActor: ActorRef = null
-
- def startBlockManagerMaster(isMaster: Boolean, isLocal: Boolean) {
- if (isMaster) {
- masterActor = actorOf(new BlockManagerMaster(isLocal))
- remote.register(AKKA_ACTOR_NAME, masterActor)
- logInfo("Registered BlockManagerMaster Actor: " + DEFAULT_MASTER_IP + ":" + DEFAULT_MASTER_PORT)
- masterActor.start()
- } else {
- masterActor = remote.actorFor(AKKA_ACTOR_NAME, DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT)
- }
- }
-
- def stopBlockManagerMaster() {
- if (masterActor != null) {
- masterActor.stop()
- masterActor = null
- logInfo("BlockManagerMaster stopped")
- }
- }
-
- def notifyADeadHost(host: String) {
- (masterActor ? RemoveHost(host + ":" + DEFAULT_MANAGER_PORT)).as[Any] match {
- case Some(true) =>
- logInfo("Removed " + host + " successfully. @ notifyADeadHost")
- case Some(oops) =>
- logError("Failed @ notifyADeadHost: " + oops)
- case None =>
- logError("None @ notifyADeadHost.")
- }
- }
-
- def mustRegisterBlockManager(msg: RegisterBlockManager) {
- while (! syncRegisterBlockManager(msg)) {
- logWarning("Failed to register " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- }
- }
-
- def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = {
- //val masterActor = RemoteActor.select(node, name)
- val startTimeMs = System.currentTimeMillis()
- val tmp = " msg " + msg + " "
- logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
-
- (masterActor ? msg).as[Any] match {
- case Some(true) =>
- logInfo("BlockManager registered successfully @ syncRegisterBlockManager.")
- logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
- return true
- case Some(oops) =>
- logError("Failed @ syncRegisterBlockManager: " + oops)
- return false
- case None =>
- logError("None @ syncRegisterBlockManager.")
- return false
- }
- }
-
- def mustHeartBeat(msg: HeartBeat) {
- while (! syncHeartBeat(msg)) {
- logWarning("Failed to send heartbeat" + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- }
- }
-
- def syncHeartBeat(msg: HeartBeat): Boolean = {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " msg " + msg + " "
- logDebug("Got in syncHeartBeat " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs))
-
- (masterActor ? msg).as[Any] match {
- case Some(true) =>
- logInfo("Heartbeat sent successfully.")
- logDebug("Got in syncHeartBeat " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs))
- return true
- case Some(oops) =>
- logError("Failed: " + oops)
- return false
- case None =>
- logError("None.")
- return false
- }
- }
-
- def mustGetLocations(msg: GetLocations): Array[BlockManagerId] = {
- var res: Array[BlockManagerId] = syncGetLocations(msg)
- while (res == null) {
- logInfo("Failed to get locations " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- res = syncGetLocations(msg)
- }
- return res
- }
-
- def syncGetLocations(msg: GetLocations): Array[BlockManagerId] = {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " msg " + msg + " "
- logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
-
- (masterActor ? msg).as[Seq[BlockManagerId]] match {
- case Some(arr) =>
- logDebug("GetLocations successfully.")
- logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
- val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- for (ele <- arr) {
- res += ele
- }
- logDebug("Got in syncGetLocations 2 " + tmp + Utils.getUsedTimeMs(startTimeMs))
- return res.toArray
- case None =>
- logError("GetLocations call returned None.")
- return null
- }
- }
-
- def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
- Seq[Seq[BlockManagerId]] = {
- var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg)
- while (res == null) {
- logWarning("Failed to GetLocationsMultipleBlockIds " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- res = syncGetLocationsMultipleBlockIds(msg)
- }
- return res
- }
-
- def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
- Seq[Seq[BlockManagerId]] = {
- val startTimeMs = System.currentTimeMillis
- val tmp = " msg " + msg + " "
- logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
-
- (masterActor ? msg).as[Any] match {
- case Some(arr: Seq[Seq[BlockManagerId]]) =>
- logDebug("GetLocationsMultipleBlockIds successfully: " + arr)
- logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
- return arr
- case Some(oops) =>
- logError("Failed: " + oops)
- return null
- case None =>
- logInfo("None.")
- return null
- }
- }
-
- def mustGetPeers(msg: GetPeers): Array[BlockManagerId] = {
- var res: Array[BlockManagerId] = syncGetPeers(msg)
- while ((res == null) || (res.length != msg.size)) {
- logInfo("Failed to get peers " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- res = syncGetPeers(msg)
- }
-
- return res
- }
-
- def syncGetPeers(msg: GetPeers): Array[BlockManagerId] = {
- val startTimeMs = System.currentTimeMillis
- val tmp = " msg " + msg + " "
- logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
-
- (masterActor ? msg).as[Seq[BlockManagerId]] match {
- case Some(arr) =>
- logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
- val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- logInfo("GetPeers successfully: " + arr.length)
- res.appendAll(arr)
- logDebug("Got in syncGetPeers 2 " + tmp + Utils.getUsedTimeMs(startTimeMs))
- return res.toArray
- case None =>
- logError("GetPeers call returned None.")
- return null
- }
- }
-}
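
getPeers_Deterministic above picks replication targets by walking the peer list starting just after the requesting block manager and wrapping around. A standalone version of that rule, which additionally fails cleanly when fewer distinct peers exist than requested (an added check, not in the original loop):

    object PeerSelectionSketch {
      def selectPeers(peers: Array[String], self: String, size: Int): Seq[String] = {
        val selfIndex = peers.indexOf(self)
        if (selfIndex == -1) throw new Exception("Self " + self + " not found")
        val res = new scala.collection.mutable.ArrayBuffer[String]
        var index = selfIndex
        while (res.size < size) {
          index += 1
          // Wrapped all the way back to ourselves: not enough distinct peers.
          if (index % peers.size == selfIndex) throw new Exception("More peers expected than available")
          res += peers(index % peers.size)
        }
        res
      }

      def main(args: Array[String]) {
        val peers = Array("a:1", "b:1", "c:1", "d:1")
        println(selectPeers(peers, "b:1", 2))   // ArrayBuffer(c:1, d:1)
      }
    }
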
diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
deleted file mode 100644
index 3a8574a815..0000000000
--- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala
+++ /dev/null
@@ -1,142 +0,0 @@
-package spark.storage
-
-import java.nio._
-
-import scala.actors._
-import scala.actors.Actor._
-import scala.actors.remote._
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.util.Random
-
-import spark.Logging
-import spark.Utils
-import spark.SparkEnv
-import spark.network._
-
-/**
- * This should be changed to use an event model later.
- */
-class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
- initLogging()
-
- blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
-
- def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
- logDebug("Handling message " + msg)
- msg match {
- case bufferMessage: BufferMessage => {
- try {
- logDebug("Handling as a buffer message " + bufferMessage)
- val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage)
- logDebug("Parsed as a block message array")
- val responseMessages = blockMessages.map(processBlockMessage _).filter(_ != None).map(_.get)
- /*logDebug("Processed block messages")*/
- return Some(new BlockMessageArray(responseMessages).toBufferMessage)
- } catch {
- case e: Exception => logError("Exception handling buffer message: " + e.getMessage)
- return None
- }
- }
- case otherMessage: Any => {
- logError("Unknown type message received: " + otherMessage)
- return None
- }
- }
- }
-
- def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = {
- blockMessage.getType() match {
- case BlockMessage.TYPE_PUT_BLOCK => {
- val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel())
- logInfo("Received [" + pB + "]")
- putBlock(pB.id, pB.data, pB.level)
- return None
- }
- case BlockMessage.TYPE_GET_BLOCK => {
- val gB = new GetBlock(blockMessage.getId())
- logInfo("Received [" + gB + "]")
- val buffer = getBlock(gB.id)
- if (buffer == null) {
- return None
- }
- return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer)))
- }
- case _ => return None
- }
- }
-
- private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) {
- val startTimeMs = System.currentTimeMillis()
- logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes)
- blockManager.putBytes(id, bytes, level)
- logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
- + " with data size: " + bytes.array().length)
- }
-
- private def getBlock(id: String): ByteBuffer = {
- val startTimeMs = System.currentTimeMillis()
- logDebug("Getblock " + id + " started from " + startTimeMs)
- val block = blockManager.getLocal(id)
- val buffer = block match {
- case Some(tValues) => {
- val values = tValues.asInstanceOf[Iterator[Any]]
- val buffer = blockManager.dataSerialize(values)
- buffer
- }
- case None => {
- null
- }
- }
- logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
- + " and got buffer " + buffer)
- return buffer
- }
-}
-
-object BlockManagerWorker extends Logging {
- private var blockManagerWorker: BlockManagerWorker = null
- private val DATA_TRANSFER_TIME_OUT_MS: Long = 500
- private val REQUEST_RETRY_INTERVAL_MS: Long = 1000
-
- initLogging()
-
- def startBlockManagerWorker(manager: BlockManager) {
- blockManagerWorker = new BlockManagerWorker(manager)
- }
-
- def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = {
- val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val serializer = blockManager.serializer
- val blockMessage = BlockMessage.fromPutBlock(msg)
- val blockMessageArray = new BlockMessageArray(blockMessage)
- val resultMessage = connectionManager.sendMessageReliablySync(
- toConnManagerId, blockMessageArray.toBufferMessage())
- return (resultMessage != None)
- }
-
- def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
- val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val serializer = blockManager.serializer
- val blockMessage = BlockMessage.fromGetBlock(msg)
- val blockMessageArray = new BlockMessageArray(blockMessage)
- val responseMessage = connectionManager.sendMessageReliablySync(
- toConnManagerId, blockMessageArray.toBufferMessage())
- responseMessage match {
- case Some(message) => {
- val bufferMessage = message.asInstanceOf[BufferMessage]
- logDebug("Response message received " + bufferMessage)
- BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => {
- logDebug("Found " + blockMessage)
- return blockMessage.getData
- })
- }
- case None => logDebug("No response message received"); return null
- }
- return null
- }
-}
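The removed BlockManagerWorker above dispatches incoming block messages with a simple request/response rule: a put is applied locally and yields no reply, while a get is answered with a GotBlock only when the block exists locally. A minimal sketch of that dispatch, using hypothetical stand-in types (BlockMsg, ToyBlockStore) rather than the deleted classes:

    import java.nio.ByteBuffer
    import scala.collection.mutable

    // Hypothetical stand-ins for the removed message types.
    sealed trait BlockMsg
    case class GetBlockMsg(id: String) extends BlockMsg
    case class GotBlockMsg(id: String, data: ByteBuffer) extends BlockMsg
    case class PutBlockMsg(id: String, data: ByteBuffer) extends BlockMsg

    // A toy in-memory store standing in for the real BlockManager.
    class ToyBlockStore {
      private val blocks = mutable.Map[String, ByteBuffer]()

      // A put is stored and produces no reply; a get is answered with a
      // GotBlock only when the block is present.
      def process(msg: BlockMsg): Option[BlockMsg] = msg match {
        case PutBlockMsg(id, data) => blocks(id) = data; None
        case GetBlockMsg(id)       => blocks.get(id).map(GotBlockMsg(id, _))
        case GotBlockMsg(_, _)     => None // a worker never receives GotBlock directly
      }
    }

    object ToyBlockStoreDemo extends App {
      val store = new ToyBlockStore
      store.process(PutBlockMsg("rdd_0_0", ByteBuffer.wrap(Array[Byte](1, 2, 3))))
      println(store.process(GetBlockMsg("rdd_0_0"))) // Some(GotBlockMsg(rdd_0_0, ...))
    }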
diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala
deleted file mode 100644
index bb128dce7a..0000000000
--- a/core/src/main/scala/spark/storage/BlockMessage.scala
+++ /dev/null
@@ -1,219 +0,0 @@
-package spark.storage
-
-import java.nio._
-
-import scala.collection.mutable.StringBuilder
-import scala.collection.mutable.ArrayBuffer
-
-import spark._
-import spark.network._
-
-case class GetBlock(id: String)
-case class GotBlock(id: String, data: ByteBuffer)
-case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel)
-
-class BlockMessage() extends Logging {
- // Un-initialized: typ = 0
- // GetBlock: typ = 1
- // GotBlock: typ = 2
- // PutBlock: typ = 3
- private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
- private var id: String = null
- private var data: ByteBuffer = null
- private var level: StorageLevel = null
-
- initLogging()
-
- def set(getBlock: GetBlock) {
- typ = BlockMessage.TYPE_GET_BLOCK
- id = getBlock.id
- }
-
- def set(gotBlock: GotBlock) {
- typ = BlockMessage.TYPE_GOT_BLOCK
- id = gotBlock.id
- data = gotBlock.data
- }
-
- def set(putBlock: PutBlock) {
- typ = BlockMessage.TYPE_PUT_BLOCK
- id = putBlock.id
- data = putBlock.data
- level = putBlock.level
- }
-
- def set(buffer: ByteBuffer) {
- val startTime = System.currentTimeMillis
- /*
- println()
- println("BlockMessage: ")
- while(buffer.remaining > 0) {
- print(buffer.get())
- }
- buffer.rewind()
- println()
- println()
- */
- typ = buffer.getInt()
- val idLength = buffer.getInt()
- val idBuilder = new StringBuilder(idLength)
- for (i <- 1 to idLength) {
- idBuilder += buffer.getChar()
- }
- id = idBuilder.toString()
-
- logDebug("Set from buffer Result: " + typ + " " + id)
- logDebug("Buffer position is " + buffer.position)
- if (typ == BlockMessage.TYPE_PUT_BLOCK) {
-
- val booleanInt = buffer.getInt()
- val replication = buffer.getInt()
- level = new StorageLevel(booleanInt, replication)
-
- val dataLength = buffer.getInt()
- data = ByteBuffer.allocate(dataLength)
- if (dataLength != buffer.remaining) {
- throw new Exception("Error parsing buffer")
- }
- data.put(buffer)
- data.flip()
- logDebug("Set from buffer Result 2: " + level + " " + data)
- } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
-
- val dataLength = buffer.getInt()
- logDebug("Data length is "+ dataLength)
- logDebug("Buffer position is " + buffer.position)
- data = ByteBuffer.allocate(dataLength)
- if (dataLength != buffer.remaining) {
- throw new Exception("Error parsing buffer")
- }
- data.put(buffer)
- data.flip()
- logDebug("Set from buffer Result 3: " + data)
- }
-
- val finishTime = System.currentTimeMillis
- logDebug("Converted " + id + " from bytebuffer in " + (finishTime - startTime) / 1000.0 + " s")
- }
-
- def set(bufferMsg: BufferMessage) {
- val buffer = bufferMsg.buffers.apply(0)
- buffer.clear()
- set(buffer)
- }
-
- def getType(): Int = {
- return typ
- }
-
- def getId(): String = {
- return id
- }
-
- def getData(): ByteBuffer = {
- return data
- }
-
- def getLevel(): StorageLevel = {
- return level
- }
-
- def toBufferMessage(): BufferMessage = {
- val startTime = System.currentTimeMillis
- val buffers = new ArrayBuffer[ByteBuffer]()
- var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2)
- buffer.putInt(typ).putInt(id.length())
- id.foreach((x: Char) => buffer.putChar(x))
- buffer.flip()
- buffers += buffer
-
- if (typ == BlockMessage.TYPE_PUT_BLOCK) {
- buffer = ByteBuffer.allocate(8).putInt(level.toInt()).putInt(level.replication)
- buffer.flip()
- buffers += buffer
-
- buffer = ByteBuffer.allocate(4).putInt(data.remaining)
- buffer.flip()
- buffers += buffer
-
- buffers += data
- } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
- buffer = ByteBuffer.allocate(4).putInt(data.remaining)
- buffer.flip()
- buffers += buffer
-
- buffers += data
- }
-
- logDebug("Start to log buffers.")
- buffers.foreach((x: ByteBuffer) => logDebug("" + x))
- /*
- println()
- println("BlockMessage: ")
- buffers.foreach(b => {
- while(b.remaining > 0) {
- print(b.get())
- }
- b.rewind()
- })
- println()
- println()
- */
- val finishTime = System.currentTimeMillis
- logDebug("Converted " + id + " to buffer message in " + (finishTime - startTime) / 1000.0 + " s")
- return Message.createBufferMessage(buffers)
- }
-
- override def toString(): String = {
- "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level +
- ", data = " + (if (data != null) data.remaining.toString else "null") + "]"
- }
-}
-
-object BlockMessage {
- val TYPE_NON_INITIALIZED: Int = 0
- val TYPE_GET_BLOCK: Int = 1
- val TYPE_GOT_BLOCK: Int = 2
- val TYPE_PUT_BLOCK: Int = 3
-
- def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = {
- val newBlockMessage = new BlockMessage()
- newBlockMessage.set(bufferMessage)
- newBlockMessage
- }
-
- def fromByteBuffer(buffer: ByteBuffer): BlockMessage = {
- val newBlockMessage = new BlockMessage()
- newBlockMessage.set(buffer)
- newBlockMessage
- }
-
- def fromGetBlock(getBlock: GetBlock): BlockMessage = {
- val newBlockMessage = new BlockMessage()
- newBlockMessage.set(getBlock)
- newBlockMessage
- }
-
- def fromGotBlock(gotBlock: GotBlock): BlockMessage = {
- val newBlockMessage = new BlockMessage()
- newBlockMessage.set(gotBlock)
- newBlockMessage
- }
-
- def fromPutBlock(putBlock: PutBlock): BlockMessage = {
- val newBlockMessage = new BlockMessage()
- newBlockMessage.set(putBlock)
- newBlockMessage
- }
-
- def main(args: Array[String]) {
- val B = new BlockMessage()
- B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.DISK_AND_MEMORY_2))
- val bMsg = B.toBufferMessage()
- val C = new BlockMessage()
- C.set(bMsg)
-
- println(B.getId() + " " + B.getLevel())
- println(C.getId() + " " + C.getLevel())
- }
-}
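The deleted BlockMessage above serializes to a length-prefixed binary layout: a 4-byte type tag, a 4-byte id length, the id written as 2-byte chars, and, for put/got messages, the storage level followed by a 4-byte payload length and the payload itself. A rough round-trip sketch of just the header part (the BlockHeaderCodec name is illustrative, not part of the removed API):

    import java.nio.ByteBuffer

    object BlockHeaderCodec {
      // Encode a type tag and a block id the way the removed BlockMessage did:
      // Int type, Int id length, then each id character as a 2-byte char.
      def encode(typ: Int, id: String): ByteBuffer = {
        val buf = ByteBuffer.allocate(4 + 4 + id.length * 2)
        buf.putInt(typ).putInt(id.length)
        id.foreach(buf.putChar)
        buf.flip()
        buf
      }

      // Decode the same layout back into (type, id).
      def decode(buf: ByteBuffer): (Int, String) = {
        val typ = buf.getInt()
        val idLength = buf.getInt()
        val sb = new StringBuilder(idLength)
        for (_ <- 1 to idLength) sb += buf.getChar()
        (typ, sb.toString)
      }
    }

    // Round trip: BlockHeaderCodec.decode(BlockHeaderCodec.encode(1, "rdd_1_0")) == (1, "rdd_1_0")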
diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala
deleted file mode 100644
index 5f411d3488..0000000000
--- a/core/src/main/scala/spark/storage/BlockMessageArray.scala
+++ /dev/null
@@ -1,140 +0,0 @@
-package spark.storage
-import java.nio._
-
-import scala.collection.mutable.StringBuilder
-import scala.collection.mutable.ArrayBuffer
-
-import spark._
-import spark.network._
-
-class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging {
-
- def this(bm: BlockMessage) = this(Array(bm))
-
- def this() = this(null.asInstanceOf[Seq[BlockMessage]])
-
- def apply(i: Int) = blockMessages(i)
-
- def iterator = blockMessages.iterator
-
- def length = blockMessages.length
-
- initLogging()
-
- def set(bufferMessage: BufferMessage) {
- val startTime = System.currentTimeMillis
- val newBlockMessages = new ArrayBuffer[BlockMessage]()
- val buffer = bufferMessage.buffers(0)
- buffer.clear()
- /*
- println()
- println("BlockMessageArray: ")
- while(buffer.remaining > 0) {
- print(buffer.get())
- }
- buffer.rewind()
- println()
- println()
- */
- while(buffer.remaining() > 0) {
- val size = buffer.getInt()
- logDebug("Creating block message of size " + size + " bytes")
- val newBuffer = buffer.slice()
- newBuffer.clear()
- newBuffer.limit(size)
- logDebug("Trying to convert buffer " + newBuffer + " to block message")
- val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer)
- logDebug("Created " + newBlockMessage)
- newBlockMessages += newBlockMessage
- buffer.position(buffer.position() + size)
- }
- val finishTime = System.currentTimeMillis
- logDebug("Converted block message array from buffer message in " + (finishTime - startTime) / 1000.0 + " s")
- this.blockMessages = newBlockMessages
- }
-
- def toBufferMessage(): BufferMessage = {
- val buffers = new ArrayBuffer[ByteBuffer]()
-
- blockMessages.foreach(blockMessage => {
- val bufferMessage = blockMessage.toBufferMessage
- logDebug("Adding " + blockMessage)
- val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size)
- sizeBuffer.flip
- buffers += sizeBuffer
- buffers ++= bufferMessage.buffers
- logDebug("Added " + bufferMessage)
- })
-
- logDebug("Buffer list:")
- buffers.foreach((x: ByteBuffer) => logDebug("" + x))
- /*
- println()
- println("BlockMessageArray: ")
- buffers.foreach(b => {
- while(b.remaining > 0) {
- print(b.get())
- }
- b.rewind()
- })
- println()
- println()
- */
- return Message.createBufferMessage(buffers)
- }
-}
-
-object BlockMessageArray {
-
- def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
- val newBlockMessageArray = new BlockMessageArray()
- newBlockMessageArray.set(bufferMessage)
- newBlockMessageArray
- }
-
- def main(args: Array[String]) {
- val blockMessages =
- (0 until 10).map(i => {
- if (i % 2 == 0) {
- val buffer = ByteBuffer.allocate(100)
- buffer.clear
- BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY))
- } else {
- BlockMessage.fromGetBlock(GetBlock(i.toString))
- }
- })
- val blockMessageArray = new BlockMessageArray(blockMessages)
- println("Block message array created")
-
- val bufferMessage = blockMessageArray.toBufferMessage
- println("Converted to buffer message")
-
- val totalSize = bufferMessage.size
- val newBuffer = ByteBuffer.allocate(totalSize)
- newBuffer.clear()
- bufferMessage.buffers.foreach(buffer => {
- newBuffer.put(buffer)
- buffer.rewind()
- })
- newBuffer.flip
- val newBufferMessage = Message.createBufferMessage(newBuffer)
- println("Copied to new buffer message, size = " + newBufferMessage.size)
-
- val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage)
- println("Converted back to block message array")
- newBlockMessageArray.foreach(blockMessage => {
- blockMessage.getType() match {
- case BlockMessage.TYPE_PUT_BLOCK => {
- val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel())
- println(pB)
- }
- case BlockMessage.TYPE_GET_BLOCK => {
- val gB = new GetBlock(blockMessage.getId())
- println(gB)
- }
- }
- })
- }
-}
-
-
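The deleted BlockMessageArray above packs several messages into one buffer by writing each message's 4-byte size before its bytes, then splits them back apart by reading size prefixes until the buffer is drained. A self-contained sketch of that framing over plain byte arrays (Framing is an illustrative name):

    import java.nio.ByteBuffer
    import scala.collection.mutable.ArrayBuffer

    object Framing {
      // Concatenate payloads, each preceded by a 4-byte length.
      def pack(payloads: Seq[Array[Byte]]): ByteBuffer = {
        val total = payloads.map(4 + _.length).sum
        val out = ByteBuffer.allocate(total)
        payloads.foreach { p => out.putInt(p.length); out.put(p) }
        out.flip()
        out
      }

      // Split the combined buffer back into payloads by reading size prefixes.
      def unpack(buf: ByteBuffer): Seq[Array[Byte]] = {
        val out = new ArrayBuffer[Array[Byte]]()
        while (buf.remaining() > 0) {
          val size = buf.getInt()
          val p = new Array[Byte](size)
          buf.get(p)
          out += p
        }
        out.toSeq
      }
    }

    // unpack(pack(Seq(a, b))) yields arrays with the same contents as a and b, in order.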
diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala
deleted file mode 100644
index 8672a5376e..0000000000
--- a/core/src/main/scala/spark/storage/BlockStore.scala
+++ /dev/null
@@ -1,291 +0,0 @@
-package spark.storage
-
-import spark.{Utils, Logging, Serializer, SizeEstimator}
-
-import scala.collection.mutable.ArrayBuffer
-
-import java.io.{File, RandomAccessFile}
-import java.nio.ByteBuffer
-import java.nio.channels.FileChannel.MapMode
-import java.util.{UUID, LinkedHashMap}
-import java.util.concurrent.Executors
-
-import it.unimi.dsi.fastutil.io._
-
-/**
- * Abstract class to store blocks
- */
-abstract class BlockStore(blockManager: BlockManager) extends Logging {
- initLogging()
-
- def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel)
-
- def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer]
-
- def getBytes(blockId: String): Option[ByteBuffer]
-
- def getValues(blockId: String): Option[Iterator[Any]]
-
- def remove(blockId: String)
-
- def dataSerialize(values: Iterator[Any]): ByteBuffer = blockManager.dataSerialize(values)
-
- def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = blockManager.dataDeserialize(bytes)
-
- def clear() { }
-}
-
-/**
- * Class to store blocks in memory
- */
-class MemoryStore(blockManager: BlockManager, maxMemory: Long)
- extends BlockStore(blockManager) {
-
- class Entry(var value: Any, val size: Long, val deserialized: Boolean)
-
- private val memoryStore = new LinkedHashMap[String, Entry](32, 0.75f, true)
- private var currentMemory = 0L
-
- private val blockDropper = Executors.newSingleThreadExecutor()
-
- def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
- if (level.deserialized) {
- bytes.rewind()
- val values = dataDeserialize(bytes)
- val elements = new ArrayBuffer[Any]
- elements ++= values
- val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
- ensureFreeSpace(sizeEstimate)
- val entry = new Entry(elements, sizeEstimate, true)
- memoryStore.synchronized { memoryStore.put(blockId, entry) }
- currentMemory += sizeEstimate
- logDebug("Block " + blockId + " stored as values to memory")
- } else {
- val entry = new Entry(bytes, bytes.array().length, false)
- ensureFreeSpace(bytes.array.length)
- memoryStore.synchronized { memoryStore.put(blockId, entry) }
- currentMemory += bytes.array().length
- logDebug("Block " + blockId + " stored as " + bytes.array().length + " bytes to memory")
- }
- }
-
- def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = {
- if (level.deserialized) {
- val elements = new ArrayBuffer[Any]
- elements ++= values
- val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
- ensureFreeSpace(sizeEstimate)
- val entry = new Entry(elements, sizeEstimate, true)
- memoryStore.synchronized { memoryStore.put(blockId, entry) }
- currentMemory += sizeEstimate
- logDebug("Block " + blockId + " stored as values to memory")
- return Left(elements.iterator)
- } else {
- val bytes = dataSerialize(values)
- ensureFreeSpace(bytes.array().length)
- val entry = new Entry(bytes, bytes.array().length, false)
- memoryStore.synchronized { memoryStore.put(blockId, entry) }
- currentMemory += bytes.array().length
- logDebug("Block " + blockId + " stored as " + bytes.array.length + " bytes to memory")
- return Right(bytes)
- }
- }
-
- def getBytes(blockId: String): Option[ByteBuffer] = {
- throw new UnsupportedOperationException("Not implemented")
- }
-
- def getValues(blockId: String): Option[Iterator[Any]] = {
- val entry = memoryStore.synchronized { memoryStore.get(blockId) }
- if (entry == null) {
- return None
- }
- if (entry.deserialized) {
- return Some(entry.value.asInstanceOf[ArrayBuffer[Any]].toIterator)
- } else {
- return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer]))
- }
- }
-
- def remove(blockId: String) {
- memoryStore.synchronized {
- val entry = memoryStore.get(blockId)
- if (entry != null) {
- memoryStore.remove(blockId)
- currentMemory -= entry.size
- logDebug("Block " + blockId + " of size " + entry.size + " dropped from memory")
- } else {
- logWarning("Block " + blockId + " could not be removed as it doesnt exist")
- }
- }
- }
-
- override def clear() {
- memoryStore.synchronized {
- memoryStore.clear()
- }
- blockDropper.shutdown()
- logInfo("MemoryStore cleared")
- }
-
- private def drop(blockId: String) {
- blockDropper.submit(new Runnable() {
- def run() {
- blockManager.dropFromMemory(blockId)
- }
- })
- }
-
- private def ensureFreeSpace(space: Long) {
- logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
- space, currentMemory, maxMemory))
-
- val droppedBlockIds = new ArrayBuffer[String]()
- var droppedMemory = 0L
-
- memoryStore.synchronized {
- val iter = memoryStore.entrySet().iterator()
- while (maxMemory - (currentMemory - droppedMemory) < space && iter.hasNext) {
- val pair = iter.next()
- val blockId = pair.getKey
- droppedBlockIds += blockId
- droppedMemory += pair.getValue.size
- logDebug("Decided to drop " + blockId)
- }
- }
-
- for (blockId <- droppedBlockIds) {
- drop(blockId)
- }
- droppedBlockIds.clear()
- }
-}
-
-
-/**
- * Class to store blocks in disk
- */
-class DiskStore(blockManager: BlockManager, rootDirs: String)
- extends BlockStore(blockManager) {
-
- val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- val localDirs = createLocalDirs()
- var lastLocalDirUsed = 0
-
- addShutdownHook()
-
- def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
- logDebug("Attempting to put block " + blockId)
- val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- if (file != null) {
- val channel = new RandomAccessFile(file, "rw").getChannel()
- val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.array.length)
- buffer.put(bytes.array)
- channel.close()
- val finishTime = System.currentTimeMillis
- logDebug("Block " + blockId + " stored to file of " + bytes.array.length + " bytes to disk in " + (finishTime - startTime) + " ms")
- } else {
- logError("File not created for block " + blockId)
- }
- }
-
- def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = {
- val bytes = dataSerialize(values)
- logDebug("Converted block " + blockId + " to " + bytes.array.length + " bytes")
- putBytes(blockId, bytes, level)
- return Right(bytes)
- }
-
- def getBytes(blockId: String): Option[ByteBuffer] = {
- val file = getFile(blockId)
- val length = file.length().toInt
- val channel = new RandomAccessFile(file, "r").getChannel()
- val bytes = ByteBuffer.allocate(length)
- bytes.put(channel.map(MapMode.READ_ONLY, 0, length))
- return Some(bytes)
- }
-
- def getValues(blockId: String): Option[Iterator[Any]] = {
- val file = getFile(blockId)
- val length = file.length().toInt
- val channel = new RandomAccessFile(file, "r").getChannel()
- val bytes = channel.map(MapMode.READ_ONLY, 0, length)
- val buffer = dataDeserialize(bytes)
- channel.close()
- return Some(buffer)
- }
-
- def remove(blockId: String) {
- throw new UnsupportedOperationException("Not implemented")
- }
-
- private def createFile(blockId: String): File = {
- val file = getFile(blockId)
- if (file == null) {
- lastLocalDirUsed = (lastLocalDirUsed + 1) % localDirs.size
- val newFile = new File(localDirs(lastLocalDirUsed), blockId)
- newFile.getParentFile.mkdirs()
- return newFile
- } else {
- logError("File for block " + blockId + " already exists on disk, " + file)
- return null
- }
- }
-
- private def getFile(blockId: String): File = {
- logDebug("Getting file for block " + blockId)
- // Search for the file in all the local directories, only one of them should have the file
- val files = localDirs.map(localDir => new File(localDir, blockId)).filter(_.exists)
- if (files.size > 1) {
- throw new Exception("Multiple files for same block " + blockId + " exists: " +
- files.map(_.toString).reduceLeft(_ + ", " + _))
- return null
- } else if (files.size == 0) {
- return null
- } else {
- logDebug("Got file " + files(0) + " of size " + files(0).length + " bytes")
- return files(0)
- }
- }
-
- private def createLocalDirs(): Seq[File] = {
- logDebug("Creating local directories at root dirs '" + rootDirs + "'")
- rootDirs.split("[;,:]").map(rootDir => {
- var foundLocalDir: Boolean = false
- var localDir: File = null
- var localDirUuid: UUID = null
- var tries = 0
- while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
- tries += 1
- try {
- localDirUuid = UUID.randomUUID()
- localDir = new File(rootDir, "spark-local-" + localDirUuid)
- if (!localDir.exists) {
- localDir.mkdirs()
- foundLocalDir = true
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir failed", e)
- }
- }
- if (!foundLocalDir) {
- logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
- " attempts to create local dir in " + rootDir)
- System.exit(1)
- }
- logDebug("Created local directory at " + localDir)
- localDir
- })
- }
-
- private def addShutdownHook() {
- Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
- override def run() {
- logDebug("Shutdown hook called")
- localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
- }
- })
- }
-}
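The removed MemoryStore above relies on a java.util.LinkedHashMap constructed with accessOrder = true, so iteration visits least-recently-used entries first; ensureFreeSpace walks that order and selects entries to drop until the requested space fits (the real code handed the actual drops to a background thread). A stripped-down sketch of the same eviction idea, with drops done inline for brevity (LruBytesStore is an illustrative name):

    import java.util.LinkedHashMap

    class LruBytesStore(maxBytes: Long) {
      private case class Entry(data: Array[Byte]) {
        def size: Long = data.length.toLong
      }

      // accessOrder = true: iteration visits least-recently-used entries first.
      private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true)
      private var current = 0L

      def put(id: String, data: Array[Byte]): Unit = entries.synchronized {
        ensureFreeSpace(data.length)
        entries.put(id, Entry(data))
        current += data.length
      }

      def get(id: String): Option[Array[Byte]] = entries.synchronized {
        Option(entries.get(id)).map(_.data)
      }

      // Evict LRU entries until `space` more bytes fit under maxBytes.
      private def ensureFreeSpace(space: Long): Unit = {
        val it = entries.entrySet().iterator()
        while (maxBytes - current < space && it.hasNext) {
          val dropped = it.next()
          current -= dropped.getValue.size
          it.remove()
        }
      }
    }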
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
deleted file mode 100644
index 693a679c4e..0000000000
--- a/core/src/main/scala/spark/storage/StorageLevel.scala
+++ /dev/null
@@ -1,80 +0,0 @@
-package spark.storage
-
-import java.io._
-
-class StorageLevel(
- var useDisk: Boolean,
- var useMemory: Boolean,
- var deserialized: Boolean,
- var replication: Int = 1)
- extends Externalizable {
-
- // TODO: Also add fields for caching priority, dataset ID, and flushing.
-
- def this(booleanInt: Int, replication: Int) {
- this(((booleanInt & 4) != 0),
- ((booleanInt & 2) != 0),
- ((booleanInt & 1) != 0),
- replication)
- }
-
- def this() = this(false, true, false) // For deserialization
-
- override def clone(): StorageLevel = new StorageLevel(
- this.useDisk, this.useMemory, this.deserialized, this.replication)
-
- override def equals(other: Any): Boolean = other match {
- case s: StorageLevel =>
- s.useDisk == useDisk &&
- s.useMemory == useMemory &&
- s.deserialized == deserialized &&
- s.replication == replication
- case _ =>
- false
- }
-
- def isValid() = ((useMemory || useDisk) && (replication > 0))
-
- def toInt(): Int = {
- var ret = 0
- if (useDisk) {
- ret += 4
- }
- if (useMemory) {
- ret += 2
- }
- if (deserialized) {
- ret += 1
- }
- return ret
- }
-
- override def writeExternal(out: ObjectOutput) {
- out.writeByte(toInt().toByte)
- out.writeByte(replication.toByte)
- }
-
- override def readExternal(in: ObjectInput) {
- val flags = in.readByte()
- useDisk = (flags & 4) != 0
- useMemory = (flags & 2) != 0
- deserialized = (flags & 1) != 0
- replication = in.readByte()
- }
-
- override def toString(): String =
- "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
-}
-
-object StorageLevel {
- val NONE = new StorageLevel(false, false, false)
- val DISK_ONLY = new StorageLevel(true, false, false)
- val MEMORY_ONLY = new StorageLevel(false, true, false)
- val MEMORY_ONLY_2 = new StorageLevel(false, true, false, 2)
- val MEMORY_ONLY_DESER = new StorageLevel(false, true, true)
- val MEMORY_ONLY_DESER_2 = new StorageLevel(false, true, true, 2)
- val DISK_AND_MEMORY = new StorageLevel(true, true, false)
- val DISK_AND_MEMORY_2 = new StorageLevel(true, true, false, 2)
- val DISK_AND_MEMORY_DESER = new StorageLevel(true, true, true)
- val DISK_AND_MEMORY_DESER_2 = new StorageLevel(true, true, true, 2)
-}
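The removed StorageLevel above packs its three flags into one integer (disk = 4, memory = 2, deserialized = 1) and externalizes as two bytes: the flag byte plus the replication count. A small worked sketch of that bit-packing round trip (Level is an illustrative stand-in, not the removed class):

    object LevelDemo extends App {
      case class Level(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) {
        // Pack the flags as bits: disk -> 4, memory -> 2, deserialized -> 1.
        def toInt: Int =
          (if (useDisk) 4 else 0) + (if (useMemory) 2 else 0) + (if (deserialized) 1 else 0)
      }

      // Unpack a flag integer plus a replication count back into a Level.
      def fromInt(flags: Int, replication: Int): Level =
        Level((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)

      val diskAndMemory2 = Level(useDisk = true, useMemory = true, deserialized = false, replication = 2)
      println(diskAndMemory2.toInt)            // 6
      println(fromInt(6, 2) == diskAndMemory2) // true: the round trip preserves the level
    }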
diff --git a/core/src/main/scala/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/spark/util/ByteBufferInputStream.scala
deleted file mode 100644
index abe2d99dd8..0000000000
--- a/core/src/main/scala/spark/util/ByteBufferInputStream.scala
+++ /dev/null
@@ -1,30 +0,0 @@
-package spark.util
-
-import java.io.InputStream
-import java.nio.ByteBuffer
-
-class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream {
- override def read(): Int = {
- if (buffer.remaining() == 0) {
- -1
- } else {
- buffer.get()
- }
- }
-
- override def read(dest: Array[Byte]): Int = {
- read(dest, 0, dest.length)
- }
-
- override def read(dest: Array[Byte], offset: Int, length: Int): Int = {
- val amountToGet = math.min(buffer.remaining(), length)
- buffer.get(dest, offset, amountToGet)
- return amountToGet
- }
-
- override def skip(bytes: Long): Long = {
- val amountToSkip = math.min(bytes, buffer.remaining).toInt
- buffer.position(buffer.position + amountToSkip)
- return amountToSkip
- }
-}
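The deleted ByteBufferInputStream above adapts a ByteBuffer to the InputStream contract so buffer-backed blocks can feed stream-based deserializers, returning -1 once the buffer is exhausted. A compact sketch of the same adapter plus a usage example (BufferInputStream is an illustrative name; read() here masks the byte so values above 127 are not mistaken for end-of-stream):

    import java.io.InputStream
    import java.nio.ByteBuffer

    class BufferInputStream(buffer: ByteBuffer) extends InputStream {
      override def read(): Int =
        if (buffer.remaining() == 0) -1 else buffer.get() & 0xFF

      override def read(dest: Array[Byte], offset: Int, length: Int): Int = {
        if (buffer.remaining() == 0) return -1
        val n = math.min(buffer.remaining(), length)
        buffer.get(dest, offset, n)
        n
      }
    }

    object BufferInputStreamDemo extends App {
      val data = "hello".getBytes("UTF-8")
      val in = new BufferInputStream(ByteBuffer.wrap(data))
      val out = new Array[Byte](data.length)
      in.read(out, 0, out.length)
      println(new String(out, "UTF-8")) // prints "hello"
    }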
diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala
deleted file mode 100644
index efb1ae7529..0000000000
--- a/core/src/main/scala/spark/util/StatCounter.scala
+++ /dev/null
@@ -1,89 +0,0 @@
-package spark.util
-
-/**
- * A class for tracking the statistics of a set of numbers (count, mean and variance) in a
- * numerically robust way. Includes support for merging two StatCounters. Based on Welford and
- * Chan's algorithms described at http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
- */
-class StatCounter(values: TraversableOnce[Double]) {
- private var n: Long = 0 // Running count of our values
- private var mu: Double = 0 // Running mean of our values
- private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2)
-
- merge(values)
-
- def this() = this(Nil)
-
- def merge(value: Double): StatCounter = {
- val delta = value - mu
- n += 1
- mu += delta / n
- m2 += delta * (value - mu)
- this
- }
-
- def merge(values: TraversableOnce[Double]): StatCounter = {
- values.foreach(v => merge(v))
- this
- }
-
- def merge(other: StatCounter): StatCounter = {
- if (other == this) {
- merge(other.copy()) // Avoid overwriting fields in a weird order
- } else {
- val delta = other.mu - mu
- if (other.n * 10 < n) {
- mu = mu + (delta * other.n) / (n + other.n)
- } else if (n * 10 < other.n) {
- mu = other.mu - (delta * n) / (n + other.n)
- } else {
- mu = (mu * n + other.mu * other.n) / (n + other.n)
- }
- m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
- n += other.n
- this
- }
- }
-
- def copy(): StatCounter = {
- val other = new StatCounter
- other.n = n
- other.mu = mu
- other.m2 = m2
- other
- }
-
- def count: Long = n
-
- def mean: Double = mu
-
- def sum: Double = n * mu
-
- def variance: Double = {
- if (n == 0)
- Double.NaN
- else
- m2 / n
- }
-
- def sampleVariance: Double = {
- if (n <= 1)
- Double.NaN
- else
- m2 / (n - 1)
- }
-
- def stdev: Double = math.sqrt(variance)
-
- def sampleStdev: Double = math.sqrt(sampleVariance)
-
- override def toString: String = {
- "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev)
- }
-}
-
-object StatCounter {
- def apply(values: TraversableOnce[Double]) = new StatCounter(values)
-
- def apply(values: Double*) = new StatCounter(values)
-}
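The removed StatCounter above keeps a running count n, mean mu, and sum of squared deviations m2 using Welford's single-pass update, so the variance is simply m2 / n with no second pass over the data. A short worked check of that update rule against the naive two-pass variance (illustrative code, not the removed class):

    object WelfordCheck extends App {
      // One Welford update step: fold a new value into (n, mu, m2).
      def update(state: (Long, Double, Double), x: Double): (Long, Double, Double) = {
        val (n, mu, m2) = state
        val n1 = n + 1
        val delta = x - mu
        val mu1 = mu + delta / n1
        (n1, mu1, m2 + delta * (x - mu1))
      }

      val xs = Seq(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)
      val (n, mu, m2) = xs.foldLeft((0L, 0.0, 0.0))(update)

      // Naive two-pass variance for comparison.
      val mean = xs.sum / xs.size
      val naiveVar = xs.map(x => (x - mean) * (x - mean)).sum / xs.size

      // Both variances come out as 4.0; the mean is 5.0.
      println(s"mean = $mu, variance = ${m2 / n}, naive variance = $naiveVar")
    }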
diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala
index 3d170a6e22..60290d14ca 100644
--- a/core/src/test/scala/spark/CacheTrackerSuite.scala
+++ b/core/src/test/scala/spark/CacheTrackerSuite.scala
@@ -1,103 +1,95 @@
package spark
import org.scalatest.FunSuite
-
-import scala.collection.mutable.HashMap
-
-import akka.actor._
-import akka.actor.Actor
-import akka.actor.Actor._
+import collection.mutable.HashMap
class CacheTrackerSuite extends FunSuite {
test("CacheTrackerActor slave initialization & cache status") {
- //System.setProperty("spark.master.port", "1345")
+ System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
- val tracker = actorOf(new CacheTrackerActor)
+ val tracker = new CacheTrackerActor
tracker.start()
- tracker !! SlaveCacheStarted("host001", initialSize)
+ tracker !? SlaveCacheStarted("host001", initialSize)
- assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 0L)))
+ assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 0L)))
- tracker !! StopCacheTracker
+ tracker !? StopCacheTracker
}
test("RegisterRDD") {
- //System.setProperty("spark.master.port", "1345")
+ System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
- val tracker = actorOf(new CacheTrackerActor)
+ val tracker = new CacheTrackerActor
tracker.start()
- tracker !! SlaveCacheStarted("host001", initialSize)
+ tracker !? SlaveCacheStarted("host001", initialSize)
- tracker !! RegisterRDD(1, 3)
- tracker !! RegisterRDD(2, 1)
+ tracker !? RegisterRDD(1, 3)
+ tracker !? RegisterRDD(2, 1)
- assert(getCacheLocations(tracker) === Map(1 -> List(List(), List(), List()), 2 -> List(List())))
+ assert(getCacheLocations(tracker) == Map(1 -> List(List(), List(), List()), 2 -> List(List())))
- tracker !! StopCacheTracker
+ tracker !? StopCacheTracker
}
test("AddedToCache") {
- //System.setProperty("spark.master.port", "1345")
+ System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
- val tracker = actorOf(new CacheTrackerActor)
+ val tracker = new CacheTrackerActor
tracker.start()
- tracker !! SlaveCacheStarted("host001", initialSize)
+ tracker !? SlaveCacheStarted("host001", initialSize)
- tracker !! RegisterRDD(1, 2)
- tracker !! RegisterRDD(2, 1)
+ tracker !? RegisterRDD(1, 2)
+ tracker !? RegisterRDD(2, 1)
- tracker !! AddedToCache(1, 0, "host001", 2L << 15)
- tracker !! AddedToCache(1, 1, "host001", 2L << 11)
- tracker !! AddedToCache(2, 0, "host001", 3L << 10)
+ tracker !? AddedToCache(1, 0, "host001", 2L << 15)
+ tracker !? AddedToCache(1, 1, "host001", 2L << 11)
+ tracker !? AddedToCache(2, 0, "host001", 3L << 10)
- assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L)))
+ assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
+ assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
- tracker !! StopCacheTracker
+ tracker !? StopCacheTracker
}
test("DroppedFromCache") {
- //System.setProperty("spark.master.port", "1345")
+ System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
- val tracker = actorOf(new CacheTrackerActor)
+ val tracker = new CacheTrackerActor
tracker.start()
- tracker !! SlaveCacheStarted("host001", initialSize)
+ tracker !? SlaveCacheStarted("host001", initialSize)
- tracker !! RegisterRDD(1, 2)
- tracker !! RegisterRDD(2, 1)
+ tracker !? RegisterRDD(1, 2)
+ tracker !? RegisterRDD(2, 1)
- tracker !! AddedToCache(1, 0, "host001", 2L << 15)
- tracker !! AddedToCache(1, 1, "host001", 2L << 11)
- tracker !! AddedToCache(2, 0, "host001", 3L << 10)
+ tracker !? AddedToCache(1, 0, "host001", 2L << 15)
+ tracker !? AddedToCache(1, 1, "host001", 2L << 11)
+ tracker !? AddedToCache(2, 0, "host001", 3L << 10)
- assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
+ assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
+ assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
- tracker !! DroppedFromCache(1, 1, "host001", 2L << 11)
+ tracker !? DroppedFromCache(1, 1, "host001", 2L << 11)
- assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 68608L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
+ assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 68608L)))
+ assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
- tracker !! StopCacheTracker
+ tracker !? StopCacheTracker
}
/**
* Helper function to get cacheLocations from CacheTracker
*/
- def getCacheLocations(tracker: ActorRef) = (tracker ? GetCacheLocations).get match {
+ def getCacheLocations(tracker: CacheTrackerActor) = tracker !? GetCacheLocations match {
case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]].map {
case (i, arr) => (i -> arr.toList)
}
diff --git a/core/src/test/scala/spark/MesosSchedulerSuite.scala b/core/src/test/scala/spark/MesosSchedulerSuite.scala
index 54421225d8..0e6820cbdc 100644
--- a/core/src/test/scala/spark/MesosSchedulerSuite.scala
+++ b/core/src/test/scala/spark/MesosSchedulerSuite.scala
@@ -2,8 +2,6 @@ package spark
import org.scalatest.FunSuite
-import spark.scheduler.mesos.MesosScheduler
-
class MesosSchedulerSuite extends FunSuite {
test("memoryStringToMb"){
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 00b24464a6..c61cb90f82 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -48,7 +48,7 @@ class ShuffleSuite extends FunSuite {
assert(valuesFor2.toList.sorted === List(1))
sc.stop()
}
-
+
test("groupByKey with many output partitions") {
val sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
@@ -189,7 +189,7 @@ class ShuffleSuite extends FunSuite {
))
sc.stop()
}
-
+
test("zero-partition RDD") {
val sc = new SparkContext("local", "test")
val emptyDir = Files.createTempDir()
@@ -199,5 +199,5 @@ class ShuffleSuite extends FunSuite {
// Test that a shuffle on the file works, because this used to be a bug
assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
sc.stop()
- }
+ }
}
diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala
index 1ac4737f04..f31251e509 100644
--- a/core/src/test/scala/spark/UtilsSuite.scala
+++ b/core/src/test/scala/spark/UtilsSuite.scala
@@ -2,7 +2,7 @@ package spark
import org.scalatest.FunSuite
import java.io.{ByteArrayOutputStream, ByteArrayInputStream}
-import scala.util.Random
+import util.Random
class UtilsSuite extends FunSuite {
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
deleted file mode 100644
index 63501f0613..0000000000
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ /dev/null
@@ -1,212 +0,0 @@
-package spark.storage
-
-import spark.KryoSerializer
-
-import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
-
-class BlockManagerSuite extends FunSuite with BeforeAndAfter {
- before {
- BlockManagerMaster.startBlockManagerMaster(true, true)
- }
-
- test("manager-master interaction") {
- val store = new BlockManager(2000, new KryoSerializer)
- val a1 = new Array[Byte](400)
- val a2 = new Array[Byte](400)
- val a3 = new Array[Byte](400)
-
- // Putting a1, a2 and a3 in memory and telling master only about a1 and a2
- store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_DESER)
- store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_DESER)
- store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_DESER, false)
-
- // Checking whether blocks are in memory
- assert(store.getSingle("a1") != None, "a1 was not in store")
- assert(store.getSingle("a2") != None, "a2 was not in store")
- assert(store.getSingle("a3") != None, "a3 was not in store")
-
- // Checking whether master knows about the blocks or not
- assert(BlockManagerMaster.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1")
- assert(BlockManagerMaster.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2")
- assert(BlockManagerMaster.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3")
-
- // Setting storage level of a1 and a2 to invalid; they should be removed from store and master
- store.setLevel("a1", new StorageLevel(false, false, false, 1))
- store.setLevel("a2", new StorageLevel(true, false, false, 0))
- assert(store.getSingle("a1") === None, "a1 not removed from store")
- assert(store.getSingle("a2") === None, "a2 not removed from store")
- assert(BlockManagerMaster.mustGetLocations(GetLocations("a1")).size === 0, "master did not remove a1")
- assert(BlockManagerMaster.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2")
- }
-
- test("in-memory LRU storage") {
- val store = new BlockManager(1000, new KryoSerializer)
- val a1 = new Array[Byte](400)
- val a2 = new Array[Byte](400)
- val a3 = new Array[Byte](400)
- store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_DESER)
- store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_DESER)
- store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_DESER)
- assert(store.getSingle("a2") != None, "a2 was not in store")
- assert(store.getSingle("a3") != None, "a3 was not in store")
- Thread.sleep(100)
- assert(store.getSingle("a1") === None, "a1 was in store")
- assert(store.getSingle("a2") != None, "a2 was not in store")
- // At this point a2 was gotten last, so LRU will get rid of a3
- store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_DESER)
- assert(store.getSingle("a1") != None, "a1 was not in store")
- assert(store.getSingle("a2") != None, "a2 was not in store")
- Thread.sleep(100)
- assert(store.getSingle("a3") === None, "a3 was in store")
- }
-
- test("in-memory LRU storage with serialization") {
- val store = new BlockManager(1000, new KryoSerializer)
- val a1 = new Array[Byte](400)
- val a2 = new Array[Byte](400)
- val a3 = new Array[Byte](400)
- store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
- store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY)
- store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY)
- Thread.sleep(100)
- assert(store.getSingle("a2") != None, "a2 was not in store")
- assert(store.getSingle("a3") != None, "a3 was not in store")
- assert(store.getSingle("a1") === None, "a1 was in store")
- assert(store.getSingle("a2") != None, "a2 was not in store")
- // At this point a2 was gotten last, so LRU will get rid of a3
- store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_DESER)
- Thread.sleep(100)
- assert(store.getSingle("a1") != None, "a1 was not in store")
- assert(store.getSingle("a2") != None, "a2 was not in store")
- assert(store.getSingle("a3") === None, "a1 was in store")
- }
-
- test("on-disk storage") {
- val store = new BlockManager(1000, new KryoSerializer)
- val a1 = new Array[Byte](400)
- val a2 = new Array[Byte](400)
- val a3 = new Array[Byte](400)
- store.putSingle("a1", a1, StorageLevel.DISK_ONLY)
- store.putSingle("a2", a2, StorageLevel.DISK_ONLY)
- store.putSingle("a3", a3, StorageLevel.DISK_ONLY)
- assert(store.getSingle("a2") != None, "a2 was not in store")
- assert(store.getSingle("a3") != None, "a3 was not in store")
- assert(store.getSingle("a1") != None, "a1 was not in store")
- }
-
- test("disk and memory storage") {
- val store = new BlockManager(1000, new KryoSerializer)
- val a1 = new Array[Byte](400)
- val a2 = new Array[Byte](400)
- val a3 = new Array[Byte](400)
- store.putSingle("a1", a1, StorageLevel.DISK_AND_MEMORY_DESER)
- store.putSingle("a2", a2, StorageLevel.DISK_AND_MEMORY_DESER)
- store.putSingle("a3", a3, StorageLevel.DISK_AND_MEMORY_DESER)
- Thread.sleep(100)
- assert(store.getSingle("a2") != None, "a2 was not in store")
- assert(store.getSingle("a3") != None, "a3 was not in store")
- assert(store.getSingle("a1") != None, "a1 was not in store")
- }
-
- test("disk and memory storage with serialization") {
- val store = new BlockManager(1000, new KryoSerializer)
- val a1 = new Array[Byte](400)
- val a2 = new Array[Byte](400)
- val a3 = new Array[Byte](400)
- store.putSingle("a1", a1, StorageLevel.DISK_AND_MEMORY)
- store.putSingle("a2", a2, StorageLevel.DISK_AND_MEMORY)
- store.putSingle("a3", a3, StorageLevel.DISK_AND_MEMORY)
- Thread.sleep(100)
- assert(store.getSingle("a2") != None, "a2 was not in store")
- assert(store.getSingle("a3") != None, "a3 was not in store")
- assert(store.getSingle("a1") != None, "a1 was not in store")
- }
-
- test("LRU with mixed storage levels") {
- val store = new BlockManager(1000, new KryoSerializer)
- val a1 = new Array[Byte](400)
- val a2 = new Array[Byte](400)
- val a3 = new Array[Byte](400)
- val a4 = new Array[Byte](400)
- // First store a1 and a2, both in memory, and a3, on disk only
- store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
- store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY)
- store.putSingle("a3", a3, StorageLevel.DISK_ONLY)
- // At this point LRU should not kick in because a3 is only on disk
- assert(store.getSingle("a1") != None, "a2 was not in store")
- assert(store.getSingle("a2") != None, "a3 was not in store")
- assert(store.getSingle("a3") != None, "a1 was not in store")
- assert(store.getSingle("a1") != None, "a2 was not in store")
- assert(store.getSingle("a2") != None, "a3 was not in store")
- assert(store.getSingle("a3") != None, "a1 was not in store")
- // Now let's add in a4, which uses both disk and memory; a1 should drop out
- store.putSingle("a4", a4, StorageLevel.DISK_AND_MEMORY)
- Thread.sleep(100)
- assert(store.getSingle("a1") == None, "a1 was in store")
- assert(store.getSingle("a2") != None, "a2 was not in store")
- assert(store.getSingle("a3") != None, "a3 was not in store")
- assert(store.getSingle("a4") != None, "a4 was not in store")
- }
-
- test("in-memory LRU with streams") {
- val store = new BlockManager(1000, new KryoSerializer)
- val list1 = List(new Array[Byte](200), new Array[Byte](200))
- val list2 = List(new Array[Byte](200), new Array[Byte](200))
- val list3 = List(new Array[Byte](200), new Array[Byte](200))
- store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_DESER)
- store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_DESER)
- store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY_DESER)
- Thread.sleep(100)
- assert(store.get("list2") != None, "list2 was not in store")
- assert(store.get("list2").get.size == 2)
- assert(store.get("list3") != None, "list3 was not in store")
- assert(store.get("list3").get.size == 2)
- assert(store.get("list1") === None, "list1 was in store")
- assert(store.get("list2") != None, "list2 was not in store")
- assert(store.get("list2").get.size == 2)
- // At this point list2 was gotten last, so LRU will get rid of list3
- store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_DESER)
- Thread.sleep(100)
- assert(store.get("list1") != None, "list1 was not in store")
- assert(store.get("list1").get.size == 2)
- assert(store.get("list2") != None, "list2 was not in store")
- assert(store.get("list2").get.size == 2)
- assert(store.get("list3") === None, "list1 was in store")
- }
-
- test("LRU with mixed storage levels and streams") {
- val store = new BlockManager(1000, new KryoSerializer)
- val list1 = List(new Array[Byte](200), new Array[Byte](200))
- val list2 = List(new Array[Byte](200), new Array[Byte](200))
- val list3 = List(new Array[Byte](200), new Array[Byte](200))
- val list4 = List(new Array[Byte](200), new Array[Byte](200))
- // First store list1 and list2, both in memory, and list3, on disk only
- store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY)
- store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY)
- store.put("list3", list3.iterator, StorageLevel.DISK_ONLY)
- Thread.sleep(100)
- // At this point LRU should not kick in because list3 is only on disk
- assert(store.get("list1") != None, "list2 was not in store")
- assert(store.get("list1").get.size === 2)
- assert(store.get("list2") != None, "list3 was not in store")
- assert(store.get("list2").get.size === 2)
- assert(store.get("list3") != None, "list1 was not in store")
- assert(store.get("list3").get.size === 2)
- assert(store.get("list1") != None, "list2 was not in store")
- assert(store.get("list1").get.size === 2)
- assert(store.get("list2") != None, "list3 was not in store")
- assert(store.get("list2").get.size === 2)
- assert(store.get("list3") != None, "list1 was not in store")
- assert(store.get("list3").get.size === 2)
- // Now let's add in list4, which uses both disk and memory; list1 should drop out
- store.put("list4", list4.iterator, StorageLevel.DISK_AND_MEMORY)
- assert(store.get("list1") === None, "list1 was in store")
- assert(store.get("list2") != None, "list3 was not in store")
- assert(store.get("list2").get.size === 2)
- assert(store.get("list3") != None, "list1 was not in store")
- assert(store.get("list3").get.size === 2)
- assert(store.get("list4") != None, "list4 was not in store")
- assert(store.get("list4").get.size === 2)
- }
-}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 3ce6a086c1..caaf5ebc68 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -33,7 +33,6 @@ object SparkBuild extends Build {
"org.scalatest" %% "scalatest" % "1.6.1" % "test",
"org.scala-tools.testing" %% "scalacheck" % "1.9" % "test"
),
- parallelExecution in Test := false,
/* Workaround for issue #206 (fixed after SBT 0.11.0) */
watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task,
const(std.TaskExtra.constant(Nil)), aggregate = true, includeRoot = true) apply { _.join.map(_.flatten) }
@@ -58,12 +57,8 @@ object SparkBuild extends Build {
"asm" % "asm-all" % "3.3.1",
"com.google.protobuf" % "protobuf-java" % "2.4.1",
"de.javakaffee" % "kryo-serializers" % "0.9",
- "se.scalablesolutions.akka" % "akka-actor" % "1.3.1",
- "se.scalablesolutions.akka" % "akka-remote" % "1.3.1",
- "se.scalablesolutions.akka" % "akka-slf4j" % "1.3.1",
"org.jboss.netty" % "netty" % "3.2.6.Final",
- "it.unimi.dsi" % "fastutil" % "6.4.4",
- "colt" % "colt" % "1.2.0"
+ "it.unimi.dsi" % "fastutil" % "6.4.2"
)
) ++ assemblySettings ++ Seq(test in assembly := {})
@@ -73,7 +68,8 @@ object SparkBuild extends Build {
) ++ assemblySettings ++ Seq(test in assembly := {})
def examplesSettings = sharedSettings ++ Seq(
- name := "spark-examples"
+ name := "spark-examples",
+ libraryDependencies += "colt" % "colt" % "1.2.0"
)
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
diff --git a/sbt/sbt b/sbt/sbt
index fab9967286..714e3d15d7 100755
--- a/sbt/sbt
+++ b/sbt/sbt
@@ -4,4 +4,4 @@ if [ "$MESOS_HOME" != "" ]; then
EXTRA_ARGS="-Djava.library.path=$MESOS_HOME/lib/java"
fi
export SPARK_HOME=$(cd "$(dirname $0)/.."; pwd)
-java -Xmx1200M -XX:MaxPermSize=200m $EXTRA_ARGS -jar $SPARK_HOME/sbt/sbt-launch-*.jar "$@"
+java -Xmx800M -XX:MaxPermSize=150m $EXTRA_ARGS -jar $SPARK_HOME/sbt/sbt-launch-*.jar "$@"
diff --git a/sbt/sbt-launch-0.11.1.jar b/sbt/sbt-launch-0.11.1.jar
new file mode 100644
index 0000000000..59d325ecfe
--- /dev/null
+++ b/sbt/sbt-launch-0.11.1.jar
Binary files differ
diff --git a/sbt/sbt-launch-0.11.3-2.jar b/sbt/sbt-launch-0.11.3-2.jar
deleted file mode 100644
index 23e5c3f311..0000000000
--- a/sbt/sbt-launch-0.11.3-2.jar
+++ /dev/null
Binary files differ