about summary refs log tree commit diff
path: root/core/src/main
diff options
context:
space:
mode:
author Stephen Haberman <stephen@exigencecorp.com> 2013-03-20 15:37:10 -0500
committer Stephen Haberman <stephen@exigencecorp.com> 2013-03-20 15:37:10 -0500
commit 4f4215311a4bef65eb705798a0748d270371bee5 (patch)
tree 9c0658787581d642b0b44e99e010467adcbb9ecf /core/src/main
parent 6415c2bb6046b080a040ca9e3f3015079712cb5e (diff)
parent ca4d083ec825aa674fdd7d1dcd52a99ef8dcdf8b (diff)
download spark-4f4215311a4bef65eb705798a0748d270371bee5.tar.gz
spark-4f4215311a4bef65eb705798a0748d270371bee5.tar.bz2
spark-4f4215311a4bef65eb705798a0748d270371bee5.zip
Merge branch 'master' into volatile
Diffstat (limited to 'core/src/main')
-rw-r--r-- core/src/main/scala/spark/BlockStoreShuffleFetcher.scala 26
-rw-r--r-- core/src/main/scala/spark/KryoSerializer.scala 29
-rw-r--r-- core/src/main/scala/spark/MapOutputTracker.scala 18
-rw-r--r-- core/src/main/scala/spark/PairRDDFunctions.scala 44
-rw-r--r-- core/src/main/scala/spark/Partitioner.scala 5
-rw-r--r-- core/src/main/scala/spark/RDD.scala 18
-rw-r--r-- core/src/main/scala/spark/ShuffleFetcher.scala 4
-rw-r--r-- core/src/main/scala/spark/SparkContext.scala 33
-rw-r--r-- core/src/main/scala/spark/SparkEnv.scala 31
-rw-r--r-- core/src/main/scala/spark/TaskContext.scala 9
-rw-r--r-- core/src/main/scala/spark/api/java/JavaPairRDD.scala 24
-rw-r--r-- core/src/main/scala/spark/api/java/JavaSparkContext.scala 4
-rw-r--r-- core/src/main/scala/spark/api/python/PythonRDD.scala 4
-rw-r--r-- core/src/main/scala/spark/deploy/DeployMessage.scala 2
-rw-r--r-- core/src/main/scala/spark/deploy/client/Client.scala 5
-rw-r--r-- core/src/main/scala/spark/deploy/master/Master.scala 21
-rw-r--r-- core/src/main/scala/spark/deploy/worker/Worker.scala 14
-rw-r--r-- core/src/main/scala/spark/executor/Executor.scala 12
-rw-r--r-- core/src/main/scala/spark/executor/TaskMetrics.scala 78
-rw-r--r-- core/src/main/scala/spark/rdd/CoGroupedRDD.scala 3
-rw-r--r-- core/src/main/scala/spark/rdd/CoalescedRDD.scala 4
-rw-r--r-- core/src/main/scala/spark/rdd/HadoopRDD.scala 40
-rw-r--r-- core/src/main/scala/spark/rdd/NewHadoopRDD.scala 8
-rw-r--r-- core/src/main/scala/spark/rdd/ShuffledRDD.scala 2
-rw-r--r-- core/src/main/scala/spark/rdd/SubtractedRDD.scala 68
-rw-r--r-- core/src/main/scala/spark/scheduler/DAGScheduler.scala 61
-rw-r--r-- core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala 6
-rw-r--r-- core/src/main/scala/spark/scheduler/ResultTask.scala 1
-rw-r--r-- core/src/main/scala/spark/scheduler/ShuffleMapTask.scala 8
-rw-r--r-- core/src/main/scala/spark/scheduler/SparkListener.scala 146
-rw-r--r-- core/src/main/scala/spark/scheduler/StageInfo.scala 12
-rw-r--r-- core/src/main/scala/spark/scheduler/Task.scala 7
-rw-r--r-- core/src/main/scala/spark/scheduler/TaskResult.scala 7
-rw-r--r-- core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala 5
-rw-r--r-- core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala 6
-rw-r--r-- core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala 3
-rw-r--r-- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala 9
-rw-r--r-- core/src/main/scala/spark/scheduler/local/LocalScheduler.scala 19
-rw-r--r-- core/src/main/scala/spark/serializer/Serializer.scala 32
-rw-r--r-- core/src/main/scala/spark/storage/BlockFetchTracker.scala 10
-rw-r--r-- core/src/main/scala/spark/storage/BlockManager.scala 476
-rw-r--r-- core/src/main/scala/spark/storage/BlockManagerMaster.scala 24
-rw-r--r-- core/src/main/scala/spark/storage/BlockManagerMessages.scala 8
-rw-r--r-- core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala 12
-rw-r--r-- core/src/main/scala/spark/storage/MemoryStore.scala 4
-rw-r--r-- core/src/main/scala/spark/storage/ThreadingTest.scala 5
-rw-r--r-- core/src/main/scala/spark/util/AkkaUtils.scala 6
-rw-r--r-- core/src/main/scala/spark/util/CompletionIterator.scala 25
-rw-r--r-- core/src/main/scala/spark/util/Distribution.scala 65
-rw-r--r-- core/src/main/scala/spark/util/NextIterator.scala 71
-rw-r--r-- core/src/main/scala/spark/util/TimedIterator.scala 32
51 files changed, 1114 insertions, 452 deletions
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index 86432d0127..c27ed36406 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -1,20 +1,22 @@
package spark
+import executor.{ShuffleReadMetrics, TaskMetrics}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-import spark.storage.BlockManagerId
+import spark.storage.{DelegateBlockFetchTracker, BlockManagerId}
+import util.{CompletionIterator, TimedIterator}
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
- override def fetch[K, V](shuffleId: Int, reduceId: Int) = {
+ override def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) = {
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
-
+
val startTime = System.currentTimeMillis
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
-
+
val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
for (((address, size), index) <- statuses.zipWithIndex) {
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
@@ -45,6 +47,20 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
}
}
}
- blockManager.getMultiple(blocksByAddress).flatMap(unpackBlock)
+
+ val blockFetcherItr = blockManager.getMultiple(blocksByAddress)
+ val itr = new TimedIterator(blockFetcherItr.flatMap(unpackBlock)) with DelegateBlockFetchTracker
+ itr.setDelegate(blockFetcherItr)
+ CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
+ val shuffleMetrics = new ShuffleReadMetrics
+ shuffleMetrics.shuffleReadMillis = itr.getNetMillis
+ shuffleMetrics.remoteFetchTime = itr.remoteFetchTime
+ shuffleMetrics.fetchWaitTime = itr.fetchWaitTime
+ shuffleMetrics.remoteBytesRead = itr.remoteBytesRead
+ shuffleMetrics.totalBlocksFetched = itr.totalBlocks
+ shuffleMetrics.localBlocksFetched = itr.numLocalBlocks
+ shuffleMetrics.remoteBlocksFetched = itr.numRemoteBlocks
+ metrics.shuffleReadMetrics = Some(shuffleMetrics)
+ })
}
}
diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 0bd73e936b..d723ab7b1e 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -157,27 +157,34 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {
// Register maps with a special serializer since they have complex internal structure
class ScalaMapSerializer(buildMap: Array[(Any, Any)] => scala.collection.Map[Any, Any])
- extends KSerializer[Array[(Any, Any)] => scala.collection.Map[Any, Any]] {
+ extends KSerializer[Array[(Any, Any)] => scala.collection.Map[Any, Any]] {
+
+ //hack, look at https://groups.google.com/forum/#!msg/kryo-users/Eu5V4bxCfws/k-8UQ22y59AJ
+ private final val FAKE_REFERENCE = new Object()
override def write(
- kryo: Kryo,
- output: KryoOutput,
- obj: Array[(Any, Any)] => scala.collection.Map[Any, Any]) {
+ kryo: Kryo,
+ output: KryoOutput,
+ obj: Array[(Any, Any)] => scala.collection.Map[Any, Any]) {
val map = obj.asInstanceOf[scala.collection.Map[Any, Any]]
- kryo.writeObject(output, map.size.asInstanceOf[java.lang.Integer])
+ output.writeInt(map.size)
for ((k, v) <- map) {
kryo.writeClassAndObject(output, k)
kryo.writeClassAndObject(output, v)
}
}
override def read (
- kryo: Kryo,
- input: KryoInput,
- cls: Class[Array[(Any, Any)] => scala.collection.Map[Any, Any]])
+ kryo: Kryo,
+ input: KryoInput,
+ cls: Class[Array[(Any, Any)] => scala.collection.Map[Any, Any]])
: Array[(Any, Any)] => scala.collection.Map[Any, Any] = {
- val size = kryo.readObject(input, classOf[java.lang.Integer]).intValue
+ kryo.reference(FAKE_REFERENCE)
+ val size = input.readInt()
val elems = new Array[(Any, Any)](size)
- for (i <- 0 until size)
- elems(i) = (kryo.readClassAndObject(input), kryo.readClassAndObject(input))
+ for (i <- 0 until size) {
+ val k = kryo.readClassAndObject(input)
+ val v = kryo.readClassAndObject(input)
+ elems(i)=(k,v)
+ }
buildMap(elems).asInstanceOf[Array[(Any, Any)] => scala.collection.Map[Any, Any]]
}
}
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 4735207585..866d630a6d 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -38,9 +38,10 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
}
}
-private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolean) extends Logging {
+private[spark] class MapOutputTracker extends Logging {
- val timeout = 10.seconds
+ // Set to the MapOutputTrackerActor living on the driver
+ var trackerActor: ActorRef = _
var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
@@ -53,24 +54,13 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
var cacheGeneration = generation
val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
- val actorName: String = "MapOutputTracker"
- var trackerActor: ActorRef = if (isDriver) {
- val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
- logInfo("Registered MapOutputTrackerActor actor")
- actor
- } else {
- val ip = System.getProperty("spark.driver.host", "localhost")
- val port = System.getProperty("spark.driver.port", "7077").toInt
- val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
- actorSystem.actorFor(url)
- }
-
val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
try {
+ val timeout = 10.seconds
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
} catch {
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index e7408e4352..07efba9e8d 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -89,6 +89,33 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
/**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may
+ * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
+ * list concatenation, 0 for addition, or 1 for multiplication.).
+ */
+ def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
+ combineByKey[V]({v: V => func(zeroValue, v)}, func, func, partitioner)
+ }
+
+ /**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may
+ * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
+ * list concatenation, 0 for addition, or 1 for multiplication.).
+ */
+ def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = {
+ foldByKey(zeroValue, new HashPartitioner(numPartitions))(func)
+ }
+
+ /**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may
+ * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
+ * list concatenation, 0 for addition, or 1 for multiplication.).
+ */
+ def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = {
+ foldByKey(zeroValue, defaultPartitioner(self))(func)
+ }
+
+ /**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce.
@@ -441,6 +468,23 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
/**
+ * Return an RDD with the pairs from `this` whose keys are not in `other`.
+ *
+ * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
+ * RDD will be <= us.
+ */
+ def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] =
+ subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size)))
+
+ /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+ def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
+ subtractByKey(other, new HashPartitioner(numPartitions))
+
+ /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+ def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
+ new SubtractedRDD[K, V, W](self, other, p)
+
+ /**
* Return the list of values in the RDD for key `key`. This operation is done efficiently if the
* RDD has a known partitioner by only searching the partition that the key maps to.
*/
diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index eec0e8dd79..6f8cd17c88 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -10,9 +10,6 @@ abstract class Partitioner extends Serializable {
}
object Partitioner {
-
- private val useDefaultParallelism = System.getProperty("spark.default.parallelism") != null
-
/**
* Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
*
@@ -33,7 +30,7 @@ object Partitioner {
for (r <- bySize if r.partitioner != None) {
return r.partitioner.get
}
- if (useDefaultParallelism) {
+ if (System.getProperty("spark.default.parallelism") != null) {
return new HashPartitioner(rdd.context.defaultParallelism)
} else {
return new HashPartitioner(bySize.head.partitions.size)
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 584efa8adf..9bd8a0f98d 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -420,7 +420,23 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD with the elements from `this` that are not in `other`.
*/
- def subtract(other: RDD[T], p: Partitioner): RDD[T] = new SubtractedRDD[T](this, other, p)
+ def subtract(other: RDD[T], p: Partitioner): RDD[T] = {
+ if (partitioner == Some(p)) {
+ // Our partitioner knows how to handle T (which, since we have a partitioner, is
+ // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples
+ val p2 = new Partitioner() {
+ override def numPartitions = p.numPartitions
+ override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1)
+ }
+ // Unfortunately, since we're making a new p2, we'll get ShuffleDependencies
+ // anyway, and when calling .keys, will not have a partitioner set, even though
+ // the SubtractedRDD will, thanks to p2's de-tupled partitioning, already be
+ // partitioned by the right/real keys (e.g. p).
+ this.map(x => (x, null)).subtractByKey(other.map((_, null)), p2).keys
+ } else {
+ this.map(x => (x, null)).subtractByKey(other.map((_, null)), p).keys
+ }
+ }
/**
* Reduces the elements of this RDD using the specified commutative and associative binary operator.
diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala
index d9a94d4021..442e9f0269 100644
--- a/core/src/main/scala/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/spark/ShuffleFetcher.scala
@@ -1,11 +1,13 @@
package spark
+import executor.TaskMetrics
+
private[spark] abstract class ShuffleFetcher {
/**
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
- def fetch[K, V](shuffleId: Int, reduceId: Int) : Iterator[(K, V)]
+ def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) : Iterator[(K,V)]
/** Stop the fetcher */
def stop() {}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index df23710d46..4957a54c1b 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -1,19 +1,15 @@
package spark
import java.io._
-import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
-import java.net.{URI, URLClassLoader}
-import java.lang.ref.WeakReference
+import java.net.URI
import scala.collection.Map
import scala.collection.generic.Growable
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
-import akka.actor.Actor
-import akka.actor.Actor._
-import org.apache.hadoop.fs.{FileUtil, Path}
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
@@ -33,20 +29,19 @@ import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
-import org.apache.mesos.{Scheduler, MesosNativeLibrary}
+import org.apache.mesos.MesosNativeLibrary
-import spark.broadcast._
import spark.deploy.LocalSparkCluster
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
-import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
-import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
+import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
+import spark.scheduler._
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import storage.BlockManagerUI
-import util.{MetadataCleaner, TimeStampedHashMap}
-import storage.{StorageStatus, StorageUtils, RDDInfo}
+import spark.storage.BlockManagerUI
+import spark.util.{MetadataCleaner, TimeStampedHashMap}
+import spark.storage.{StorageStatus, StorageUtils, RDDInfo}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -64,7 +59,7 @@ class SparkContext(
val appName: String,
val sparkHome: String = null,
val jars: Seq[String] = Nil,
- environment: Map[String, String] = Map())
+ val environment: Map[String, String] = Map())
extends Logging {
// Ensure logging is initialized before we spawn any threads
@@ -466,6 +461,10 @@ class SparkContext(
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
+ def addSparkListener(listener: SparkListener) {
+ dagScheduler.sparkListeners += listener
+ }
+
/**
* Return a map from the slave to the max memory available for caching and the remaining
* memory available for caching.
@@ -484,6 +483,10 @@ class SparkContext(
StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
}
+ def getStageInfo: Map[Stage,StageInfo] = {
+ dagScheduler.stageToInfos
+ }
+
/**
* Return information about blocks stored in all of the slaves
*/
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index d2193ae72b..7157fd2688 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -1,7 +1,6 @@
package spark
-import akka.actor.ActorSystem
-import akka.actor.ActorSystemImpl
+import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
import akka.remote.RemoteActorRefProvider
import serializer.Serializer
@@ -83,11 +82,23 @@ object SparkEnv extends Logging {
}
val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
+
+ def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
+ if (isDriver) {
+ logInfo("Registering " + name)
+ actorSystem.actorOf(Props(newActor), name = name)
+ } else {
+ val driverIp: String = System.getProperty("spark.driver.host", "localhost")
+ val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
+ val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, name)
+ logInfo("Connecting to " + name + ": " + url)
+ actorSystem.actorFor(url)
+ }
+ }
- val driverIp: String = System.getProperty("spark.driver.host", "localhost")
- val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
- val blockManagerMaster = new BlockManagerMaster(
- actorSystem, isDriver, isLocal, driverIp, driverPort)
+ val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
+ "BlockManagerMaster",
+ new spark.storage.BlockManagerMasterActor(isLocal)))
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)
val connectionManager = blockManager.connectionManager
@@ -99,7 +110,12 @@ object SparkEnv extends Logging {
val cacheManager = new CacheManager(blockManager)
- val mapOutputTracker = new MapOutputTracker(actorSystem, isDriver)
+ // Have to assign trackerActor after initialization as MapOutputTrackerActor
+ // requires the MapOutputTracker itself
+ val mapOutputTracker = new MapOutputTracker()
+ mapOutputTracker.trackerActor = registerOrLookup(
+ "MapOutputTracker",
+ new MapOutputTrackerActor(mapOutputTracker))
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
@@ -137,4 +153,5 @@ object SparkEnv extends Logging {
httpFileServer,
sparkFilesDir)
}
+
}
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
index eab85f85a2..dd0609026a 100644
--- a/core/src/main/scala/spark/TaskContext.scala
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -1,9 +1,14 @@
package spark
+import executor.TaskMetrics
import scala.collection.mutable.ArrayBuffer
-
-class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
+class TaskContext(
+ val stageId: Int,
+ val splitId: Int,
+ val attemptId: Long,
+ val taskMetrics: TaskMetrics = TaskMetrics.empty()
+) extends Serializable {
@transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index c1bd13c49a..49aaabf835 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -161,6 +161,30 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
/**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may
+ * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
+ * list concatenation, 0 for addition, or 1 for multiplication.).
+ */
+ def foldByKey(zeroValue: V, partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
+ fromRDD(rdd.foldByKey(zeroValue, partitioner)(func))
+
+ /**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may
+ * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
+ * list concatenation, 0 for addition, or 1 for multiplication.).
+ */
+ def foldByKey(zeroValue: V, numPartitions: Int, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
+ fromRDD(rdd.foldByKey(zeroValue, numPartitions)(func))
+
+ /**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may
+ * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
+ * list concatenation, 0 for addition, or 1 for multiplication.).
+ */
+ def foldByKey(zeroValue: V, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
+ fromRDD(rdd.foldByKey(zeroValue)(func))
+
+ /**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions.
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index f75fc27c7b..5f18b1e15b 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -31,8 +31,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param appName A name for your application, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
- * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
- * system or HDFS, HTTP, HTTPS, or FTP URLs.
+ * @param jarFile JAR file to send to the cluster. This can be a path on the local file system
+ * or an HDFS, HTTP, HTTPS, or FTP URL.
*/
def this(master: String, appName: String, sparkHome: String, jarFile: String) =
this(new SparkContext(master, appName, sparkHome, Seq(jarFile)))
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 8c73477384..9b4d54ab4e 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -51,7 +51,7 @@ private[spark] class PythonRDD[T: ClassManifest](
val env = SparkEnv.get
// Start a thread to print the process's stderr to ours
- new Thread("stderr reader for " + command) {
+ new Thread("stderr reader for " + pythonExec) {
override def run() {
for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
System.err.println(line)
@@ -60,7 +60,7 @@ private[spark] class PythonRDD[T: ClassManifest](
}.start()
// Start a thread to feed the process input from our parent's iterator
- new Thread("stdin writer for " + command) {
+ new Thread("stdin writer for " + pythonExec) {
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala
index 3cbf4fdd98..8a3e64e4c2 100644
--- a/core/src/main/scala/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/spark/deploy/DeployMessage.scala
@@ -65,7 +65,7 @@ case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String
exitStatus: Option[Int])
private[spark]
-case class appKilled(message: String)
+case class ApplicationRemoved(message: String)
// Internal message in Client
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index 1a95524cf9..2fc5e657f9 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -54,6 +54,11 @@ private[spark] class Client(
appId = appId_
listener.connected(appId)
+ case ApplicationRemoved(message) =>
+ logError("Master removed our application: %s; stopping client".format(message))
+ markDisconnected()
+ context.stop(self)
+
case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) =>
val fullId = appId + "/" + id
logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores))
diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala
index b7f167425f..71b9d0801d 100644
--- a/core/src/main/scala/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/spark/deploy/master/Master.scala
@@ -43,7 +43,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
// As a temporary workaround before better ways of configuring memory, we allow users to set
// a flag that will perform round-robin scheduling across the nodes (spreading out each app
// among all the nodes) instead of trying to consolidate each app onto a small # of nodes.
- val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "false").toBoolean
+ val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean
override def preStart() {
logInfo("Starting Spark master at spark://" + ip + ":" + port)
@@ -107,7 +107,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
} else {
logError("Application %s with ID %s failed %d times, removing it".format(
appInfo.desc.name, appInfo.id, appInfo.retryCount))
- removeApplication(appInfo)
+ removeApplication(appInfo, ApplicationState.FAILED)
}
}
}
@@ -129,19 +129,19 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
// The disconnected actor could've been either a worker or an app; remove whichever of
// those we have an entry for in the corresponding actor hashmap
actorToWorker.get(actor).foreach(removeWorker)
- actorToApp.get(actor).foreach(removeApplication)
+ actorToApp.get(actor).foreach(finishApplication)
}
case RemoteClientDisconnected(transport, address) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
- addressToApp.get(address).foreach(removeApplication)
+ addressToApp.get(address).foreach(finishApplication)
}
case RemoteClientShutdown(transport, address) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
- addressToApp.get(address).foreach(removeApplication)
+ addressToApp.get(address).foreach(finishApplication)
}
case RequestMasterState => {
@@ -257,20 +257,25 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
return app
}
- def removeApplication(app: ApplicationInfo) {
+ def finishApplication(app: ApplicationInfo) {
+ removeApplication(app, ApplicationState.FINISHED)
+ }
+
+ def removeApplication(app: ApplicationInfo, state: ApplicationState.Value) {
if (apps.contains(app)) {
logInfo("Removing app " + app.id)
apps -= app
idToApp -= app.id
actorToApp -= app.driver
- addressToWorker -= app.driver.path.address
+ addressToApp -= app.driver.path.address
completedApps += app // Remember it in our history
waitingApps -= app
for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
exec.worker.actor ! KillExecutor(exec.application.id, exec.id)
}
- app.markFinished(ApplicationState.FINISHED) // TODO: Mark it as FAILED if it failed
+ app.markFinished(state)
+ app.driver ! ApplicationRemoved(state.toString)
schedule()
}
}
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 2bbc931316..da3f4f636c 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -74,16 +74,10 @@ private[spark] class Worker(
def connectToMaster() {
logInfo("Connecting to master " + masterUrl)
- try {
- master = context.actorFor(Master.toAkkaUrl(masterUrl))
- master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
- } catch {
- case e: Exception =>
- logError("Failed to connect to master", e)
- System.exit(1)
- }
+ master = context.actorFor(Master.toAkkaUrl(masterUrl))
+ master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.watch(master) // Doesn't work with remote actors, but useful for testing
}
def startWebUi() {
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index b1d1d30283..3e7407b58d 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -80,6 +80,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
extends Runnable {
override def run() {
+ val startTime = System.currentTimeMillis()
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(urlClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
@@ -93,9 +94,18 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
logInfo("Its generation is " + task.generation)
env.mapOutputTracker.updateGeneration(task.generation)
+ val taskStart = System.currentTimeMillis()
val value = task.run(taskId.toInt)
+ val taskFinish = System.currentTimeMillis()
+ task.metrics.foreach{ m =>
+ m.executorDeserializeTime = (taskStart - startTime).toInt
+ m.executorRunTime = (taskFinish - taskStart).toInt
+ }
+ //TODO I'd also like to track the time it takes to serialize the task results, but that is a huge headache, b/c
+ // we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could
+ // just change the relevant bytes in the byte buffer
val accumUpdates = Accumulators.values
- val result = new TaskResult(value, accumUpdates)
+ val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null))
val serializedResult = ser.serialize(result)
logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit)
context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala
new file mode 100644
index 0000000000..93bbb6b458
--- /dev/null
+++ b/core/src/main/scala/spark/executor/TaskMetrics.scala
@@ -0,0 +1,78 @@
+package spark.executor
+
+/**
+ * Mutable, serializable record of the measurements taken for a single task. Numeric fields
+ * start at zero and both shuffle sections start as `None`.
+ */
+class TaskMetrics extends Serializable {
+  /**
+   * Time taken on the executor to deserialize this task, in milliseconds
+   */
+  var executorDeserializeTime: Int = 0
+
+  /**
+   * Time the executor spends actually running the task (including fetching shuffle data),
+   * in milliseconds
+   */
+  var executorRunTime: Int = 0
+
+  /**
+   * The number of bytes this task transmitted back to the driver as the TaskResult
+   */
+  var resultSize: Long = 0L
+
+  /**
+   * If this task reads from shuffle output, metrics on getting shuffle data will be collected here
+   */
+  var shuffleReadMetrics: Option[ShuffleReadMetrics] = None
+
+  /**
+   * If this task writes to shuffle output, metrics on the written shuffle data will be collected here
+   */
+  var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None
+}
+
+object TaskMetrics {
+  /** Creates a fresh TaskMetrics with every field at its initial value. */
+  private[spark] def empty(): TaskMetrics = new TaskMetrics()
+}
+
+
+/**
+ * Mutable, serializable counters describing the shuffle data a task read. Every field starts
+ * at zero.
+ */
+class ShuffleReadMetrics extends Serializable {
+  /**
+   * Total number of blocks fetched in a shuffle (remote or local)
+   */
+  var totalBlocksFetched: Int = 0
+
+  /**
+   * Number of remote blocks fetched in a shuffle
+   */
+  var remoteBlocksFetched: Int = 0
+
+  /**
+   * Local blocks fetched in a shuffle
+   */
+  var localBlocksFetched: Int = 0
+
+  /**
+   * Total time to read shuffle data, in milliseconds
+   */
+  var shuffleReadMillis: Long = 0L
+
+  /**
+   * Total time that is spent blocked waiting for shuffle to fetch data, in milliseconds
+   */
+  var fetchWaitTime: Long = 0L
+
+  /**
+   * The total amount of time for all the shuffle fetches. This adds up time from overlapping
+   * shuffles, so can be longer than task time. Presumably milliseconds like the other timings;
+   * the unit is not established in this file.
+   */
+  var remoteFetchTime: Long = 0L
+
+  /**
+   * Total number of remote bytes read from a shuffle
+   */
+  var remoteBytesRead: Long = 0L
+}
+
+/** Mutable, serializable counters describing the shuffle data a task wrote. */
+class ShuffleWriteMetrics extends Serializable {
+  /**
+   * Number of bytes written for a shuffle; starts at zero
+   */
+  var shuffleBytesWritten: Long = 0L
+}
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 5200fb6b65..65b4621b87 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -102,7 +102,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
- for ((k, vs) <- fetcher.fetch[K, Seq[Any]](shuffleId, split.index)) {
+ val fetchItr = fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics)
+ for ((k, vs) <- fetchItr) {
getSeq(k)(depNum) ++= vs
}
}
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index 0d16cf6e85..6d862c0c28 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -37,8 +37,8 @@ class CoalescedRDD[T: ClassManifest](
prevSplits.map(_.index).map{idx => new CoalescedRDDPartition(idx, prev, Array(idx)) }
} else {
(0 until maxPartitions).map { i =>
- val rangeStart = (i * prevSplits.length) / maxPartitions
- val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions
+ val rangeStart = ((i.toLong * prevSplits.length) / maxPartitions).toInt
+ val rangeEnd = (((i.toLong + 1) * prevSplits.length) / maxPartitions).toInt
new CoalescedRDDPartition(i, prev, (rangeStart until rangeEnd).toArray)
}.toArray
}
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index 78097502bc..cbf5512e24 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -16,6 +16,8 @@ import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, TaskContext}
+import spark.util.NextIterator
+import org.apache.hadoop.conf.Configurable
/**
@@ -49,6 +51,9 @@ class HadoopRDD[K, V](
override def getPartitions: Array[Partition] = {
val inputFormat = createInputFormat(conf)
+ if (inputFormat.isInstanceOf[Configurable]) {
+ inputFormat.asInstanceOf[Configurable].setConf(conf)
+ }
val inputSplits = inputFormat.getSplits(conf, minSplits)
val array = new Array[Partition](inputSplits.size)
for (i <- 0 until inputSplits.size) {
@@ -62,47 +67,34 @@ class HadoopRDD[K, V](
.asInstanceOf[InputFormat[K, V]]
}
- override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
+ override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] {
val split = theSplit.asInstanceOf[HadoopPartition]
var reader: RecordReader[K, V] = null
val conf = confBroadcast.value.value
val fmt = createInputFormat(conf)
+ if (fmt.isInstanceOf[Configurable]) {
+ fmt.asInstanceOf[Configurable].setConf(conf)
+ }
reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback{ () => close() }
+ context.addOnCompleteCallback{ () => closeIfNeeded() }
val key: K = reader.createKey()
val value: V = reader.createValue()
- var gotNext = false
- var finished = false
-
- override def hasNext: Boolean = {
- if (!gotNext) {
- try {
- finished = !reader.next(key, value)
- } catch {
- case eof: EOFException =>
- finished = true
- }
- gotNext = true
- }
- !finished
- }
- override def next: (K, V) = {
- if (!gotNext) {
+ override def getNext() = {
+ try {
finished = !reader.next(key, value)
+ } catch {
+ case eof: EOFException =>
+ finished = true
}
- if (finished) {
- throw new NoSuchElementException("End of stream")
- }
- gotNext = false
(key, value)
}
- private def close() {
+ override def close() {
try {
reader.close()
} catch {
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index df2361025c..bdd974590a 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -3,7 +3,7 @@ package spark.rdd
import java.text.SimpleDateFormat
import java.util.Date
-import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
@@ -42,6 +42,9 @@ class NewHadoopRDD[K, V](
override def getPartitions: Array[Partition] = {
val inputFormat = inputFormatClass.newInstance
+ if (inputFormat.isInstanceOf[Configurable]) {
+ inputFormat.asInstanceOf[Configurable].setConf(conf)
+ }
val jobContext = newJobContext(conf, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray
val result = new Array[Partition](rawSplits.size)
@@ -57,6 +60,9 @@ class NewHadoopRDD[K, V](
val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
+ if (format.isInstanceOf[Configurable]) {
+ format.asInstanceOf[Configurable].setConf(conf)
+ }
val reader = format.createRecordReader(
split.serializableHadoopSplit.value, hadoopAttemptContext)
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index c2f118305f..51f02409b6 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -28,6 +28,6 @@ class ShuffledRDD[K, V](
override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
+ SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics)
}
}
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index daf9cc993c..0a02561062 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -1,7 +1,8 @@
package spark.rdd
-import java.util.{HashSet => JHashSet}
+import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
import spark.RDD
import spark.Partitioner
import spark.Dependency
@@ -27,10 +28,10 @@ import spark.OneToOneDependency
* you can use `rdd1`'s partitioner/partition size and not worry about running
* out of memory because of the size of `rdd2`.
*/
-private[spark] class SubtractedRDD[T: ClassManifest](
- @transient var rdd1: RDD[T],
- @transient var rdd2: RDD[T],
- part: Partitioner) extends RDD[T](rdd1.context, Nil) {
+private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
+ @transient var rdd1: RDD[(K, V)],
+ @transient var rdd2: RDD[(K, W)],
+ part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {
override def getDependencies: Seq[Dependency[_]] = {
Seq(rdd1, rdd2).map { rdd =>
@@ -39,26 +40,7 @@ private[spark] class SubtractedRDD[T: ClassManifest](
new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
- val mapSideCombinedRDD = rdd.mapPartitions(i => {
- val set = new JHashSet[T]()
- while (i.hasNext) {
- set.add(i.next)
- }
- set.iterator
- }, true)
- // ShuffleDependency requires a tuple (k, v), which it will partition by k.
- // We need this to partition to map to the same place as the k for
- // OneToOneDependency, which means:
- // - for already-tupled RDD[(A, B)], into getPartition(a)
- // - for non-tupled RDD[C], into getPartition(c)
- val part2 = new Partitioner() {
- def numPartitions = part.numPartitions
- def getPartition(key: Any) = key match {
- case (k, v) => part.getPartition(k)
- case k => part.getPartition(k)
- }
- }
- new ShuffleDependency(mapSideCombinedRDD.map((_, null)), part2)
+ new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part)
}
}
}
@@ -81,22 +63,32 @@ private[spark] class SubtractedRDD[T: ClassManifest](
override val partitioner = Some(part)
- override def compute(p: Partition, context: TaskContext): Iterator[T] = {
+ override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
- val set = new JHashSet[T]
- def integrate(dep: CoGroupSplitDep, op: T => Unit) = dep match {
+ val map = new JHashMap[K, ArrayBuffer[V]]
+ def getSeq(k: K): ArrayBuffer[V] = {
+ val seq = map.get(k)
+ if (seq != null) {
+ seq
+ } else {
+ val seq = new ArrayBuffer[V]()
+ map.put(k, seq)
+ seq
+ }
+ }
+ def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
- for (k <- rdd.iterator(itsSplit, context))
- op(k.asInstanceOf[T])
+ for (t <- rdd.iterator(itsSplit, context))
+ op(t.asInstanceOf[(K, V)])
case ShuffleCoGroupSplitDep(shuffleId) =>
- for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index))
- op(k.asInstanceOf[T])
+ for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
+ op(t.asInstanceOf[(K, V)])
}
- // the first dep is rdd1; add all keys to the set
- integrate(partition.deps(0), set.add)
- // the second dep is rdd2; remove all of its keys from the set
- integrate(partition.deps(1), set.remove)
- set.iterator
+ // the first dep is rdd1; add all values to the map
+ integrate(partition.deps(0), t => getSeq(t._1) += t._2)
+ // the second dep is rdd2; remove all of its keys
+ integrate(partition.deps(1), t => map.remove(t._1))
+ map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten
}
override def clearDependencies() {
@@ -105,4 +97,4 @@ private[spark] class SubtractedRDD[T: ClassManifest](
rdd2 = null
}
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index bf0837c066..c54dce51d7 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -1,20 +1,19 @@
package spark.scheduler
-import java.net.URI
+import cluster.TaskInfo
import java.util.concurrent.atomic.AtomicInteger
-import java.util.concurrent.Future
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import spark._
+import spark.executor.TaskMetrics
import spark.partial.ApproximateActionListener
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
import spark.storage.BlockManagerMaster
-import spark.storage.BlockManagerId
-import util.{MetadataCleaner, TimeStampedHashMap}
+import spark.util.{MetadataCleaner, TimeStampedHashMap}
/**
* A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
@@ -40,8 +39,10 @@ class DAGScheduler(
task: Task[_],
reason: TaskEndReason,
result: Any,
- accumUpdates: Map[Long, Any]) {
- eventQueue.put(CompletionEvent(task, reason, result, accumUpdates))
+ accumUpdates: Map[Long, Any],
+ taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics) {
+ eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
}
// Called by TaskScheduler when an executor fails.
@@ -73,6 +74,10 @@ class DAGScheduler(
val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
+ private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
+
+ private[spark] val sparkListeners = ArrayBuffer[SparkListener]()
+
var cacheLocs = new HashMap[Int, Array[List[String]]]
// For tracking failed nodes, we use the MapOutputTracker's generation number, which is
@@ -148,6 +153,7 @@ class DAGScheduler(
val id = nextStageId.getAndIncrement()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
idToStage(id) = stage
+ stageToInfos(stage) = StageInfo(stage)
stage
}
@@ -379,29 +385,34 @@ class DAGScheduler(
* We run the operation in a separate thread just in case it takes a bunch of time, so that we
* don't block the DAGScheduler event loop or other concurrent jobs.
*/
- private def runLocally(job: ActiveJob) {
+ protected def runLocally(job: ActiveJob) {
logInfo("Computing the requested partition locally")
new Thread("Local computation of job " + job.runId) {
override def run() {
- try {
- SparkEnv.set(env)
- val rdd = job.finalStage.rdd
- val split = rdd.partitions(job.partitions(0))
- val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
- try {
- val result = job.func(taskContext, rdd.iterator(split, taskContext))
- job.listener.taskSucceeded(0, result)
- } finally {
- taskContext.executeOnCompleteCallbacks()
- }
- } catch {
- case e: Exception =>
- job.listener.jobFailed(e)
- }
+ runLocallyWithinThread(job)
}
}.start()
}
+  /**
+   * Computes the first requested partition of `job` on the calling thread. On success the result
+   * is handed to the job's listener as task 0; any Exception is reported through
+   * `jobFailed`. The task's on-complete callbacks run whether or not the function succeeds.
+   *
+   * Broken out for easier testing in DAGSchedulerSuite.
+   */
+  protected def runLocallyWithinThread(job: ActiveJob): Unit = {
+    try {
+      SparkEnv.set(env)
+      val finalRdd = job.finalStage.rdd
+      val partition = finalRdd.partitions(job.partitions(0))
+      val localContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+      try {
+        val outcome = job.func(localContext, finalRdd.iterator(partition, localContext))
+        job.listener.taskSucceeded(0, outcome)
+      } finally {
+        localContext.executeOnCompleteCallbacks()
+      }
+    } catch {
+      case failure: Exception =>
+        job.listener.jobFailed(failure)
+    }
+  }
+
/** Submits stage, but first recursively submits any missing parents. */
private def submitStage(stage: Stage) {
logDebug("submitStage(" + stage + ")")
@@ -472,6 +483,8 @@ class DAGScheduler(
case _ => "Unkown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.origin, serviceTime))
+ val stageComp = StageCompleted(stageToInfos(stage))
+ sparkListeners.foreach{_.onStageCompleted(stageComp)}
running -= stage
}
event.reason match {
@@ -481,6 +494,7 @@ class DAGScheduler(
Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
}
pendingTasks(stage) -= task
+ stageToInfos(stage).taskInfos += event.taskInfo -> event.taskMetrics
task match {
case rt: ResultTask[_, _] =>
resultStageToJob.get(stage) match {
@@ -501,7 +515,6 @@ class DAGScheduler(
}
case smt: ShuffleMapTask =>
- val stage = idToStage(smt.stageId)
val status = event.result.asInstanceOf[MapStatus]
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
index b34fa78c07..ed0b9bf178 100644
--- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -1,8 +1,10 @@
package spark.scheduler
+import spark.scheduler.cluster.TaskInfo
import scala.collection.mutable.Map
import spark._
+import spark.executor.TaskMetrics
/**
* Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue
@@ -25,7 +27,9 @@ private[spark] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
result: Any,
- accumUpdates: Map[Long, Any])
+ accumUpdates: Map[Long, Any],
+ taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index 1721f78f48..beb21a76fe 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -72,6 +72,7 @@ private[spark] class ResultTask[T, U](
override def run(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId)
+ metrics = Some(context.taskMetrics)
try {
func(context, rdd.iterator(split, context))
} finally {
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 59ee3c0a09..36d087a4d0 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -13,6 +13,7 @@ import com.ning.compress.lzf.LZFInputStream
import com.ning.compress.lzf.LZFOutputStream
import spark._
+import executor.ShuffleWriteMetrics
import spark.storage._
import util.{TimeStampedHashMap, MetadataCleaner}
@@ -119,6 +120,7 @@ private[spark] class ShuffleMapTask(
val numOutputSplits = dep.partitioner.numPartitions
val taskContext = new TaskContext(stageId, partition, attemptId)
+ metrics = Some(taskContext.taskMetrics)
try {
// Partition the map output.
val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
@@ -130,14 +132,20 @@ private[spark] class ShuffleMapTask(
val compressedSizes = new Array[Byte](numOutputSplits)
+ var totalBytes = 0l
+
val blockManager = SparkEnv.get.blockManager
for (i <- 0 until numOutputSplits) {
val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
// Get a Scala iterator from Java map
val iter: Iterator[(Any, Any)] = buckets(i).iterator
val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+ totalBytes += size
compressedSizes(i) = MapOutputTracker.compressSize(size)
}
+ val shuffleMetrics = new ShuffleWriteMetrics
+ shuffleMetrics.shuffleBytesWritten = totalBytes
+ metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
return new MapStatus(blockManager.blockManagerId, compressedSizes)
} finally {
diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala
new file mode 100644
index 0000000000..a65140b145
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/SparkListener.scala
@@ -0,0 +1,146 @@
+package spark.scheduler
+
+import spark.scheduler.cluster.TaskInfo
+import spark.util.Distribution
+import spark.{Utils, Logging}
+import spark.executor.TaskMetrics
+
+/** Callback interface for scheduler events. */
+trait SparkListener {
+  /**
+   * Called when a stage is completed, with information on the completed stage.
+   */
+  def onStageCompleted(stageCompleted: StageCompleted): Unit
+}
+
+/** Closed family of events delivered to a SparkListener. */
+sealed trait SparkListenerEvents
+
+/** Event carrying the StageInfo of a stage that has completed. */
+case class StageCompleted(stageInfo: StageInfo) extends SparkListenerEvents
+
+
+/**
+ * A SparkListener that, each time a stage completes, logs summary statistics for it: task
+ * runtime, shuffle bytes written, fetch wait time, remote bytes read, task result size, and a
+ * percentage breakdown of where task time went.
+ */
+class StatsReportListener extends SparkListener with Logging {
+  def onStageCompleted(stageCompleted: StageCompleted): Unit = {
+    import spark.scheduler.StatsReportListener._
+    // Supplies the implicit stage argument of the show*Distribution helpers called below.
+    implicit val completedStage = stageCompleted
+    this.logInfo("Finished stage: " + stageCompleted.stageInfo)
+    showMillisDistribution("task runtime:", (info, _) => Some(info.duration))
+
+    // shuffle write
+    showBytesDistribution("shuffle bytes written:",(_,metric) => metric.shuffleWriteMetrics.map{_.shuffleBytesWritten})
+
+    // fetch & io
+    showMillisDistribution("fetch wait time:",(_, metric) => metric.shuffleReadMetrics.map{_.fetchWaitTime})
+    showBytesDistribution("remote bytes read:", (_, metric) => metric.shuffleReadMetrics.map{_.remoteBytesRead})
+    showBytesDistribution("task result size:", (_, metric) => Some(metric.resultSize))
+
+    // runtime breakdown: one RuntimePercentage per recorded task, scaled from fractions to percent
+    val breakdowns = stageCompleted.stageInfo.taskInfos.map {
+      case (info, metrics) => RuntimePercentage(info.duration, metrics)
+    }
+    val executorPcts = breakdowns.map { b => b.executorPct * 100 }
+    showDistribution("executor (non-fetch) time pct: ", Distribution(executorPcts), "%2.0f %%")
+    // Tasks without shuffle-read metrics have no fetch share and are left out of this one.
+    val fetchPcts = breakdowns.flatMap { b => b.fetchPct.map { _ * 100 } }
+    showDistribution("fetch wait time pct: ", Distribution(fetchPcts), "%2.0f %%")
+    val otherPcts = breakdowns.map { b => b.other * 100 }
+    showDistribution("other time pct: ", Distribution(otherPcts), "%2.0f %%")
+  }
+
+}
+
+/**
+ * Helpers used by the StatsReportListener class to turn per-task metrics of a completed stage
+ * into distributions and write them to the log.
+ */
+object StatsReportListener extends Logging {
+
+  // for profiling, the extremes are more interesting
+  val percentiles = Array[Int](0,5,10,25,50,75,90,95,100)
+  // The same percentiles expressed as probabilities in [0, 1].
+  val probabilities = percentiles.map{_ / 100.0}
+  // Tab-separated header row: each percentile followed by a '%' sign.
+  val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
+
+  /**
+   * Applies `getMetric` to every (TaskInfo, TaskMetrics) pair recorded for the stage, drops the
+   * tasks for which it returns None, and passes the remaining values to `Distribution(...)`.
+   */
+  def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = {
+    Distribution(stage.stageInfo.taskInfos.flatMap{
+      case ((info,metric)) => getMetric(info, metric)})
+  }
+
+  /**
+   * Long-valued variant of extractDoubleDistribution: converts each extracted value to Double.
+   */
+  // Is there some way to set up the types so that this can be removed completely?
+  def extractLongDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Long]): Option[Distribution] = {
+    extractDoubleDistribution(stage, (info, metric) => getMetric(info,metric).map{_.toDouble})
+  }
+
+  /**
+   * Logs three lines for `d`: the heading followed by its stat counter, the percentile header,
+   * and the quantiles at `probabilities` rendered with `formatNumber`.
+   */
+  def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
+    val stats = d.statCounter
+    logInfo(heading + stats)
+    val quantiles = d.getQuantiles(probabilities).map{formatNumber}
+    logInfo(percentilesHeader)
+    logInfo("\t" + quantiles.mkString("\t"))
+  }
+
+  /** Same as above, but logs nothing when `dOpt` is None. */
+  def showDistribution(heading: String, dOpt: Option[Distribution], formatNumber: Double => String) {
+    dOpt.foreach { d => showDistribution(heading, d, formatNumber)}
+  }
+
+  /** Variant that renders each quantile with the given format string. */
+  def showDistribution(heading: String, dOpt: Option[Distribution], format:String) {
+    def f(d:Double) = format.format(d)
+    showDistribution(heading, dOpt, f _)
+  }
+
+  /** Extracts a Double metric from every task of the implicit stage and shows its distribution. */
+  def showDistribution(heading:String, format: String, getMetric: (TaskInfo,TaskMetrics) => Option[Double])
+      (implicit stage: StageCompleted) {
+    showDistribution(heading, extractDoubleDistribution(stage, getMetric), format)
+  }
+
+  /** Extracts a byte-count metric from every task of the implicit stage and shows its distribution. */
+  def showBytesDistribution(heading:String, getMetric: (TaskInfo,TaskMetrics) => Option[Long])
+      (implicit stage: StageCompleted) {
+    showBytesDistribution(heading, extractLongDistribution(stage, getMetric))
+  }
+
+  /** Shows a byte-count distribution; logs nothing when `dOpt` is None. */
+  def showBytesDistribution(heading: String, dOpt: Option[Distribution]) {
+    dOpt.foreach{dist => showBytesDistribution(heading, dist)}
+  }
+
+  /** Shows a distribution whose quantiles are rendered by Utils.memoryBytesToString. */
+  def showBytesDistribution(heading: String, dist: Distribution) {
+    showDistribution(heading, dist, (d => Utils.memoryBytesToString(d.toLong)): Double => String)
+  }
+
+  /** Shows a distribution whose quantiles are rendered by millisToString. */
+  def showMillisDistribution(heading: String, dOpt: Option[Distribution]) {
+    showDistribution(heading, dOpt, (d => StatsReportListener.millisToString(d.toLong)): Double => String)
+  }
+
+  /** Extracts a millisecond metric from every task of the implicit stage and shows its distribution. */
+  def showMillisDistribution(heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long])
+      (implicit stage: StageCompleted) {
+    showMillisDistribution(heading, extractLongDistribution(stage, getMetric))
+  }
+
+
+
+  // Number of milliseconds in one second, one minute and one hour.
+  val seconds = 1000L
+  val minutes = seconds * 60
+  val hours = minutes * 60
+
+  /**
+   * Reformats a time interval in milliseconds to a prettier format for output, using the largest
+   * unit (hours, min, s, ms) that the interval reaches, with one decimal place.
+   *
+   * The comparisons are inclusive so that an interval of exactly one unit is shown in that unit
+   * (1000 ms is "1.0 s", not "1000.0 ms").
+   */
+  def millisToString(ms: Long): String = {
+    val (size, units) =
+      if (ms >= hours) {
+        (ms.toDouble / hours, "hours")
+      } else if (ms >= minutes) {
+        (ms.toDouble / minutes, "min")
+      } else if (ms >= seconds) {
+        (ms.toDouble / seconds, "s")
+      } else {
+        (ms.toDouble, "ms")
+      }
+    "%.1f %s".format(size, units)
+  }
+}
+
+
+
+/**
+ * Breakdown of one task's total time into shares. The values are fractions of the total, not
+ * percentages.
+ *
+ * @param executorPct share of executor run time that was not spent waiting on shuffle fetches
+ * @param fetchPct share spent waiting on shuffle fetches; None when the task has no shuffle-read
+ *                 metrics
+ * @param other whatever remains after the two shares above
+ */
+case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
+object RuntimePercentage {
+  /**
+   * Computes the breakdown of `totalTime` (milliseconds) from the task's metrics.
+   *
+   * When `totalTime` is not positive there is nothing to apportion, and dividing by it would
+   * produce NaN or Infinity. Every share is then reported as zero.
+   */
+  def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
+    val fetchTime = metrics.shuffleReadMetrics.map{_.fetchWaitTime}
+    if (totalTime <= 0) {
+      RuntimePercentage(0.0, fetchTime.map{_ => 0.0}, 0.0)
+    } else {
+      val denom = totalTime.toDouble
+      val fetch = fetchTime.map{_ / denom}
+      val exec = (metrics.executorRunTime - fetchTime.getOrElse(0l)) / denom
+      val other = 1.0 - (exec + fetch.getOrElse(0d))
+      RuntimePercentage(exec, fetch, other)
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/scheduler/StageInfo.scala b/core/src/main/scala/spark/scheduler/StageInfo.scala
new file mode 100644
index 0000000000..8d83ff10c4
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/StageInfo.scala
@@ -0,0 +1,12 @@
+package spark.scheduler
+
+import spark.scheduler.cluster.TaskInfo
+import scala.collection._
+import spark.executor.TaskMetrics
+
+/**
+ * A stage together with the (TaskInfo, TaskMetrics) pairs recorded for its tasks. The buffer is
+ * mutable and starts out empty unless one is supplied. `toString` is the string form of the
+ * stage's RDD.
+ */
+case class StageInfo(
+    stage: Stage,
+    taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]()) {
+  override def toString: String = stage.rdd.toString
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
index ef987fdeb6..a6462c6968 100644
--- a/core/src/main/scala/spark/scheduler/Task.scala
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -1,12 +1,12 @@
package spark.scheduler
-import scala.collection.mutable.HashMap
-import spark.serializer.{SerializerInstance, Serializer}
+import spark.serializer.SerializerInstance
import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import spark.util.ByteBufferInputStream
import scala.collection.mutable.HashMap
+import spark.executor.TaskMetrics
/**
* A task to execute on a worker node.
@@ -16,6 +16,9 @@ private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
def preferredLocations: Seq[String] = Nil
var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.
+
+ var metrics: Option[TaskMetrics] = None
+
}
/**
diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala
index 9a54d0e854..6de0aa7adf 100644
--- a/core/src/main/scala/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/spark/scheduler/TaskResult.scala
@@ -3,13 +3,14 @@ package spark.scheduler
import java.io._
import scala.collection.mutable.Map
+import spark.executor.TaskMetrics
// Task result. Also contains updates to accumulator variables.
// TODO: Use of distributed cache to return result is a hack to get around
// what seems to be a bug with messages over 60KB in libprocess; fix it
private[spark]
-class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Externalizable {
- def this() = this(null.asInstanceOf[T], null)
+class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics) extends Externalizable {
+ def this() = this(null.asInstanceOf[T], null, null)
override def writeExternal(out: ObjectOutput) {
out.writeObject(value)
@@ -18,6 +19,7 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Exte
out.writeLong(key)
out.writeObject(value)
}
+ out.writeObject(metrics)
}
override def readExternal(in: ObjectInput) {
@@ -31,5 +33,6 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Exte
accumUpdates(in.readLong()) = in.readObject()
}
}
+ metrics = in.readObject().asInstanceOf[TaskMetrics]
}
}
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
index 9fcef86e46..771518dddf 100644
--- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -1,15 +1,18 @@
package spark.scheduler
+import spark.scheduler.cluster.TaskInfo
import scala.collection.mutable.Map
import spark.TaskEndReason
+import spark.executor.TaskMetrics
/**
* Interface for getting events back from the TaskScheduler.
*/
private[spark] trait TaskSchedulerListener {
// A task has finished or failed.
- def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit
+ def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
// A node was lost from the cluster.
def executorLost(execId: String): Unit
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index d9c2f9517b..26fdef101b 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -140,6 +140,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// Mark each slave as alive and remember its hostname
for (o <- offers) {
executorIdToHost(o.executorId) = o.hostname
+ if (!executorsByHost.contains(o.hostname)) {
+ executorsByHost(o.hostname) = new HashSet()
+ }
}
// Build a list of tasks to assign to each slave
val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
@@ -159,9 +162,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
taskSetTaskIds(manager.taskSet.id) += tid
taskIdToExecutorId(tid) = execId
activeExecutorIds += execId
- if (!executorsByHost.contains(host)) {
- executorsByHost(host) = new HashSet()
- }
executorsByHost(host) += execId
availableCpus(i) -= 1
launchedTask = true
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
index 0f975ce1eb..dfe3c5a85b 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
@@ -9,7 +9,8 @@ class TaskInfo(
val index: Int,
val launchTime: Long,
val executorId: String,
- val host: String) {
+ val host: String,
+ val preferred: Boolean) {
var finishTime: Long = 0
var failed = false
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index 3dabdd76b1..c9f2c48804 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -208,7 +208,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
taskSet.id, index, taskId, execId, host, prefStr))
// Do various bookkeeping
copiesRunning(index) += 1
- val info = new TaskInfo(taskId, index, time, execId, host)
+ val info = new TaskInfo(taskId, index, time, execId, host, preferred)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
if (preferred) {
@@ -259,7 +259,8 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
tid, info.duration, tasksFinished, numTasks))
// Deserialize task result and pass it to the scheduler
val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
- sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
+ result.metrics.resultSize = serializedData.limit()
+ sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
// Mark finished and stop if we've finished all the tasks
finished(index) = true
if (tasksFinished == numTasks) {
@@ -290,7 +291,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
reason match {
case fetchFailed: FetchFailed =>
logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.listener.taskEnded(tasks(index), fetchFailed, null, null)
+ sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
finished(index) = true
tasksFinished += 1
sched.taskSetFinished(this)
@@ -378,7 +379,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
- sched.listener.taskEnded(tasks(index), Resubmitted, null, null)
+ sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 482d1cc853..9e1bde3fbe 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -1,14 +1,13 @@
package spark.scheduler.local
import java.io.File
-import java.net.URLClassLoader
-import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.HashMap
import spark._
-import executor.ExecutorURLClassLoader
+import spark.executor.ExecutorURLClassLoader
import spark.scheduler._
+import spark.scheduler.cluster.TaskInfo
/**
* A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
@@ -54,6 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
logInfo("Running " + task)
+ val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local", true)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
try {
@@ -67,8 +67,10 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
updateDependencies(taskFiles, taskJars) // Download any files added with addFile
+ val deserStart = System.currentTimeMillis()
val deserializedTask = ser.deserialize[Task[_]](
taskBytes, Thread.currentThread.getContextClassLoader)
+ val deserTime = System.currentTimeMillis() - deserStart
// Run it
val result: Any = deserializedTask.run(attemptId)
@@ -77,14 +79,19 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
// executor does. This is useful to catch serialization errors early
// on in development (so when users move their local Spark programs
// to the cluster, they don't get surprised by serialization errors).
- val resultToReturn = ser.deserialize[Any](ser.serialize(result))
+ val serResult = ser.serialize(result)
+ deserializedTask.metrics.get.resultSize = serResult.limit()
+ val resultToReturn = ser.deserialize[Any](serResult)
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
logInfo("Finished " + task)
+ info.markSuccessful()
+ deserializedTask.metrics.get.executorRunTime = info.duration.toInt //close enough
+ deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
// If the threadpool has not already been shutdown, notify DAGScheduler
if (!Thread.currentThread().isInterrupted)
- listener.taskEnded(task, Success, resultToReturn, accumUpdates)
+ listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null))
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
@@ -95,7 +102,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
} else {
// TODO: Do something nicer here to return all the way to the user
if (!Thread.currentThread().isInterrupted)
- listener.taskEnded(task, new ExceptionFailure(t), null, null)
+ listener.taskEnded(task, new ExceptionFailure(t), null, null, info, null)
}
}
}
diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala
index 50b086125a..aca86ab6f0 100644
--- a/core/src/main/scala/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/spark/serializer/Serializer.scala
@@ -72,40 +72,18 @@ trait DeserializationStream {
* Read the elements of this stream through an iterator. This can only be called once, as
* reading each element will consume data from the input source.
*/
- def asIterator: Iterator[Any] = new Iterator[Any] {
- var gotNext = false
- var finished = false
- var nextValue: Any = null
-
- private def getNext() {
+ def asIterator: Iterator[Any] = new spark.util.NextIterator[Any] {
+ override protected def getNext() = {
try {
- nextValue = readObject[Any]()
+ readObject[Any]()
} catch {
case eof: EOFException =>
finished = true
}
- gotNext = true
}
- override def hasNext: Boolean = {
- if (!gotNext) {
- getNext()
- }
- if (finished) {
- close()
- }
- !finished
- }
-
- override def next(): Any = {
- if (!gotNext) {
- getNext()
- }
- if (finished) {
- throw new NoSuchElementException("End of stream")
- }
- gotNext = false
- nextValue
+ override protected def close() {
+ DeserializationStream.this.close()
}
}
}
diff --git a/core/src/main/scala/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
new file mode 100644
index 0000000000..993aece1f7
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
@@ -0,0 +1,10 @@
+package spark.storage
+
+private[spark] trait BlockFetchTracker {
+ def totalBlocks : Int
+ def numLocalBlocks: Int
+ def numRemoteBlocks: Int
+ def remoteFetchTime : Long
+ def fetchWaitTime: Long
+ def remoteBytesRead : Long
+}
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 2462721fb8..210061e972 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -40,21 +40,36 @@ class BlockManager(
class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
var pending: Boolean = true
var size: Long = -1L
+ var failed: Boolean = false
- /** Wait for this BlockInfo to be marked as ready (i.e. block is finished writing) */
- def waitForReady() {
+ /**
+ * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
+ * Return true if the block is available, false otherwise.
+ */
+ def waitForReady(): Boolean = {
if (pending) {
synchronized {
while (pending) this.wait()
}
}
+ !failed
}
/** Mark this BlockInfo as ready (i.e. block is finished writing) */
def markReady(sizeInBytes: Long) {
- pending = false
- size = sizeInBytes
synchronized {
+ pending = false
+ failed = false
+ size = sizeInBytes
+ this.notifyAll()
+ }
+ }
+
+ /** Mark this BlockInfo as ready but failed */
+ def markFailure() {
+ synchronized {
+ failed = true
+ pending = false
this.notifyAll()
}
}
@@ -88,7 +103,7 @@ class BlockManager(
val host = System.getProperty("spark.hostname", Utils.localHostName())
- val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
+ val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
// Pending reregistration action being executed asynchronously or null if none
@@ -277,7 +292,14 @@ class BlockManager(
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
- info.waitForReady() // In case the block is still being put() by another thread
+
+ // If another thread is writing the block, wait for it to become ready.
+ if (!info.waitForReady()) {
+ // If we get here, the block write failed.
+ logWarning("Block " + blockId + " was marked as failure.")
+ return None
+ }
+
val level = info.level
logDebug("Level for block " + blockId + " is " + level)
@@ -362,7 +384,14 @@ class BlockManager(
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
- info.waitForReady() // In case the block is still being put() by another thread
+
+ // If another thread is writing the block, wait for it to become ready.
+ if (!info.waitForReady()) {
+ // If we get here, the block write failed.
+ logWarning("Block " + blockId + " was marked as failure.")
+ return None
+ }
+
val level = info.level
logDebug("Level for block " + blockId + " is " + level)
@@ -423,12 +452,11 @@ class BlockManager(
val data = BlockManagerWorker.syncGetBlock(
GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
if (data != null) {
- logDebug("Data is not null: " + data)
return Some(dataDeserialize(blockId, data))
}
- logDebug("Data is null")
+ logDebug("The value of block " + blockId + " is null")
}
- logDebug("Data not found")
+ logDebug("Block " + blockId + " not found")
return None
}
@@ -446,152 +474,8 @@ class BlockManager(
* so that we can control the maxMegabytesInFlight for the fetch.
*/
def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])])
- : Iterator[(String, Option[Iterator[Any]])] = {
-
- if (blocksByAddress == null) {
- throw new IllegalArgumentException("BlocksByAddress is null")
- }
- val totalBlocks = blocksByAddress.map(_._2.size).sum
- logDebug("Getting " + totalBlocks + " blocks")
- var startTime = System.currentTimeMillis
- val localBlockIds = new ArrayBuffer[String]()
- val remoteBlockIds = new HashSet[String]()
-
- // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
- // the block (since we want all deserializaton to happen in the calling thread); can also
- // represent a fetch failure if size == -1.
- class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
- def failed: Boolean = size == -1
- }
-
- // A queue to hold our results.
- val results = new LinkedBlockingQueue[FetchResult]
-
- // A request to fetch one or more blocks, complete with their sizes
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
- val size = blocks.map(_._2).sum
- }
-
- // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
- // the number of bytes in flight is limited to maxBytesInFlight
- val fetchRequests = new Queue[FetchRequest]
-
- // Current bytes in flight from our requests
- var bytesInFlight = 0L
-
- def sendRequest(req: FetchRequest) {
- logDebug("Sending request for %d blocks (%s) from %s".format(
- req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
- val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
- val blockMessageArray = new BlockMessageArray(req.blocks.map {
- case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
- })
- bytesInFlight += req.size
- val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
- val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
- future.onSuccess {
- case Some(message) => {
- val bufferMessage = message.asInstanceOf[BufferMessage]
- val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
- for (blockMessage <- blockMessageArray) {
- if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
- throw new SparkException(
- "Unexpected message " + blockMessage.getType + " received from " + cmId)
- }
- val blockId = blockMessage.getId
- results.put(new FetchResult(
- blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
- logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
- }
- }
- case None => {
- logError("Could not get block(s) from " + cmId)
- for ((blockId, size) <- req.blocks) {
- results.put(new FetchResult(blockId, -1, null))
- }
- }
- }
- }
-
- // Partition local and remote blocks. Remote blocks are further split into FetchRequests of size
- // at most maxBytesInFlight in order to limit the amount of data in flight.
- val remoteRequests = new ArrayBuffer[FetchRequest]
- for ((address, blockInfos) <- blocksByAddress) {
- if (address == blockManagerId) {
- localBlockIds ++= blockInfos.map(_._1)
- } else {
- remoteBlockIds ++= blockInfos.map(_._1)
- // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
- // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
- // nodes, rather than blocking on reading output from one node.
- val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
- val iterator = blockInfos.iterator
- var curRequestSize = 0L
- var curBlocks = new ArrayBuffer[(String, Long)]
- while (iterator.hasNext) {
- val (blockId, size) = iterator.next()
- curBlocks += ((blockId, size))
- curRequestSize += size
- if (curRequestSize >= minRequestSize) {
- // Add this FetchRequest
- remoteRequests += new FetchRequest(address, curBlocks)
- curRequestSize = 0
- curBlocks = new ArrayBuffer[(String, Long)]
- }
- }
- // Add in the final request
- if (!curBlocks.isEmpty) {
- remoteRequests += new FetchRequest(address, curBlocks)
- }
- }
- }
- // Add the remote requests into our queue in a random order
- fetchRequests ++= Utils.randomize(remoteRequests)
-
- // Send out initial requests for blocks, up to our maxBytesInFlight
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
- }
-
- val numGets = remoteBlockIds.size - fetchRequests.size
- logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
-
- // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
- // these all at once because they will just memory-map some files, so they won't consume
- // any memory that might exceed our maxBytesInFlight
- startTime = System.currentTimeMillis
- for (id <- localBlockIds) {
- getLocal(id) match {
- case Some(iter) => {
- results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
- logDebug("Got local block " + id)
- }
- case None => {
- throw new BlockException(id, "Could not get block " + id + " from local machine")
- }
- }
- }
- logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
-
- // Return an iterator that will read fetched blocks off the queue as they arrive.
- return new Iterator[(String, Option[Iterator[Any]])] {
- var resultsGotten = 0
-
- def hasNext: Boolean = resultsGotten < totalBlocks
-
- def next(): (String, Option[Iterator[Any]]) = {
- resultsGotten += 1
- val result = results.take()
- bytesInFlight -= result.size
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
- }
- (result.blockId, if (result.failed) None else Some(result.deserialize()))
- }
- }
+ : BlockFetcherIterator = {
+ return new BlockFetcherIterator(this, blocksByAddress)
}
def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
@@ -618,9 +502,8 @@ class BlockManager(
}
val oldBlock = blockInfo.get(blockId).orNull
- if (oldBlock != null) {
+ if (oldBlock != null && oldBlock.waitForReady()) {
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
- oldBlock.waitForReady()
return oldBlock.size
}
@@ -648,31 +531,45 @@ class BlockManager(
logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
- if (level.useMemory) {
- // Save it just to memory first, even if it also has useDisk set to true; we will later
- // drop it to disk if the memory store can't hold it.
- val res = memoryStore.putValues(blockId, values, level, true)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case Left(newIterator) => valuesAfterPut = newIterator
- }
- } else {
- // Save directly to disk.
- val askForBytes = level.replication > 1 // Don't get back the bytes unless we replicate them
- val res = diskStore.putValues(blockId, values, level, askForBytes)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case _ =>
+ try {
+ if (level.useMemory) {
+ // Save it just to memory first, even if it also has useDisk set to true; we will later
+ // drop it to disk if the memory store can't hold it.
+ val res = memoryStore.putValues(blockId, values, level, true)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case Left(newIterator) => valuesAfterPut = newIterator
+ }
+ } else {
+ // Save directly to disk.
+ // Don't get back the bytes unless we replicate them.
+ val askForBytes = level.replication > 1
+ val res = diskStore.putValues(blockId, values, level, askForBytes)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case _ =>
+ }
}
- }
- // Now that the block is in either the memory or disk store, let other threads read it,
- // and tell the master about it.
- myInfo.markReady(size)
- if (tellMaster) {
- reportBlockStatus(blockId, myInfo)
+ // Now that the block is in either the memory or disk store, let other threads read it,
+ // and tell the master about it.
+ myInfo.markReady(size)
+ if (tellMaster) {
+ reportBlockStatus(blockId, myInfo)
+ }
+ } catch {
+ // If we failed at putting the block to memory/disk, remove it from the block info map
+ // and then notify other possible readers that it has failed.
+ case e: Exception => {
+ // Note that the remove must happen before markFailure otherwise another thread
+ // could've inserted a new BlockInfo before we remove it.
+ blockInfo.remove(blockId)
+ myInfo.markFailure()
+ logWarning("Putting block " + blockId + " failed", e)
+ throw e
+ }
}
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
@@ -742,28 +639,38 @@ class BlockManager(
logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
- if (level.useMemory) {
- // Store it only in memory at first, even if useDisk is also set to true
- bytes.rewind()
- memoryStore.putBytes(blockId, bytes, level)
- } else {
- bytes.rewind()
- diskStore.putBytes(blockId, bytes, level)
- }
+ try {
+ if (level.useMemory) {
+ // Store it only in memory at first, even if useDisk is also set to true
+ bytes.rewind()
+ memoryStore.putBytes(blockId, bytes, level)
+ } else {
+ bytes.rewind()
+ diskStore.putBytes(blockId, bytes, level)
+ }
- // Now that the block is in either the memory or disk store, let other threads read it,
- // and tell the master about it.
- myInfo.markReady(bytes.limit)
- if (tellMaster) {
- reportBlockStatus(blockId, myInfo)
+ // Now that the block is in either the memory or disk store, let other threads read it,
+ // and tell the master about it.
+ myInfo.markReady(bytes.limit)
+ if (tellMaster) {
+ reportBlockStatus(blockId, myInfo)
+ }
+ } catch {
+ // If we failed at putting the block to memory/disk, remove it from the block info map
+ // and then notify other possible readers that it has failed.
+ case e: Exception => {
+ // Note that the remove must happen before markFailure otherwise another thread
+ // could've inserted a new BlockInfo before we remove it.
+ blockInfo.remove(blockId)
+ myInfo.markFailure()
+ logWarning("Putting block " + blockId + " failed", e)
+ throw e
+ }
}
}
// If replication had started, then wait for it to finish
if (level.replication > 1) {
- if (replicationFuture == null) {
- throw new Exception("Unexpected")
- }
Await.ready(replicationFuture, Duration.Inf)
}
@@ -946,7 +853,7 @@ class BlockManager(
heartBeatTask.cancel()
}
connectionManager.stop()
- master.actorSystem.stop(slaveActor)
+ actorSystem.stop(slaveActor)
blockInfo.clear()
memoryStore.clear()
diskStore.clear()
@@ -986,3 +893,176 @@ object BlockManager extends Logging {
}
}
}
+
+class BlockFetcherIterator(
+ private val blockManager: BlockManager,
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]
+) extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker {
+
+ import blockManager._
+
+ private var _remoteBytesRead = 0l
+ private var _remoteFetchTime = 0l
+ private var _fetchWaitTime = 0l
+
+ if (blocksByAddress == null) {
+ throw new IllegalArgumentException("BlocksByAddress is null")
+ }
+ val totalBlocks = blocksByAddress.map(_._2.size).sum
+ logDebug("Getting " + totalBlocks + " blocks")
+ var startTime = System.currentTimeMillis
+ val localBlockIds = new ArrayBuffer[String]()
+ val remoteBlockIds = new HashSet[String]()
+
+ // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
+ // the block (since we want all deserialization to happen in the calling thread); can also
+ // represent a fetch failure if size == -1.
+ class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+ def failed: Boolean = size == -1
+ }
+
+ // A queue to hold our results.
+ val results = new LinkedBlockingQueue[FetchResult]
+
+ // A request to fetch one or more blocks, complete with their sizes
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+ val size = blocks.map(_._2).sum
+ }
+
+ // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ // the number of bytes in flight is limited to maxBytesInFlight
+ val fetchRequests = new Queue[FetchRequest]
+
+ // Current bytes in flight from our requests
+ var bytesInFlight = 0L
+
+ def sendRequest(req: FetchRequest) { // Issue one asynchronous fetch; each reply block (or failure) is queued on `results`
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
+ val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
+ val blockMessageArray = new BlockMessageArray(req.blocks.map {
+ case (blockId, _) => BlockMessage.fromGetBlock(GetBlock(blockId))
+ })
+ bytesInFlight += req.size // released again in next() as each result is taken off the queue
+ val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
+ val fetchStart = System.currentTimeMillis()
+ val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
+ future.onSuccess {
+ case Some(message) => {
+ val fetchDone = System.currentTimeMillis()
+ _remoteFetchTime += fetchDone - fetchStart
+ val bufferMessage = message.asInstanceOf[BufferMessage]
+ val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+ for (blockMessage <- blockMessageArray) {
+ if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
+ throw new SparkException(
+ "Unexpected message " + blockMessage.getType + " received from " + cmId)
+ }
+ val blockId = blockMessage.getId
+ results.put(new FetchResult(
+ blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
+ _remoteBytesRead += sizeMap(blockId) // this block only; adding req.size here would recount the whole request per block
+ logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ }
+ }
+ case None => { // no reply: report every requested block as failed (size == -1)
+ logError("Could not get block(s) from " + cmId)
+ for ((blockId, _) <- req.blocks) {
+ results.put(new FetchResult(blockId, -1, null))
+ }
+ }
+ }
+ }
+
+ // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // at most maxBytesInFlight in order to limit the amount of data in flight.
+ val remoteRequests = new ArrayBuffer[FetchRequest]
+ for ((address, blockInfos) <- blocksByAddress) {
+ if (address == blockManagerId) {
+ localBlockIds ++= blockInfos.map(_._1)
+ } else {
+ remoteBlockIds ++= blockInfos.map(_._1)
+ // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+ logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[(String, Long)]
+ while (iterator.hasNext) {
+ val (blockId, size) = iterator.next()
+ curBlocks += ((blockId, size))
+ curRequestSize += size
+ if (curRequestSize >= minRequestSize) {
+ // Add this FetchRequest
+ remoteRequests += new FetchRequest(address, curBlocks)
+ curRequestSize = 0
+ curBlocks = new ArrayBuffer[(String, Long)]
+ }
+ }
+ // Add in the final request
+ if (!curBlocks.isEmpty) {
+ remoteRequests += new FetchRequest(address, curBlocks)
+ }
+ }
+ }
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(remoteRequests)
+
+ // Send out initial requests for blocks, up to our maxBytesInFlight
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+
+ val numGets = remoteBlockIds.size - fetchRequests.size
+ logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
+
+ // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
+ // these all at once because they will just memory-map some files, so they won't consume
+ // any memory that might exceed our maxBytesInFlight
+ startTime = System.currentTimeMillis
+ for (id <- localBlockIds) {
+ getLocal(id) match {
+ case Some(iter) => {
+ results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
+ logDebug("Got local block " + id)
+ }
+ case None => {
+ throw new BlockException(id, "Could not get block " + id + " from local machine")
+ }
+ }
+ }
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+
+ // Iterator implementation: reads fetched blocks off the queue as they arrive.
+ var resultsGotten = 0
+
+ def hasNext: Boolean = resultsGotten < totalBlocks
+
+ def next(): (String, Option[Iterator[Any]]) = {
+ resultsGotten += 1
+ val startFetchWait = System.currentTimeMillis()
+ val result = results.take()
+ val stopFetchWait = System.currentTimeMillis()
+ _fetchWaitTime += (stopFetchWait - startFetchWait)
+ bytesInFlight -= result.size
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
+ }
+
+
+ // Methods to profile the block fetching
+ def numLocalBlocks = localBlockIds.size
+ def numRemoteBlocks = remoteBlockIds.size
+
+ def remoteFetchTime = _remoteFetchTime
+ def fetchWaitTime = _fetchWaitTime
+
+ def remoteBytesRead = _remoteBytesRead
+
+}
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index 7389bee150..036fdc3480 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -15,32 +15,14 @@ import akka.util.duration._
import spark.{Logging, SparkException, Utils}
-private[spark] class BlockManagerMaster(
- val actorSystem: ActorSystem,
- isDriver: Boolean,
- isLocal: Boolean,
- driverIp: String,
- driverPort: Int)
- extends Logging {
+private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging {
val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
- val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager"
+ val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster"
val timeout = 10.seconds
- var driverActor: ActorRef = {
- if (isDriver) {
- val driverActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)),
- name = DRIVER_AKKA_ACTOR_NAME)
- logInfo("Registered BlockManagerMaster Actor")
- driverActor
- } else {
- val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, DRIVER_AKKA_ACTOR_NAME)
- logInfo("Connecting to BlockManagerMaster: " + url)
- actorSystem.actorFor(url)
- }
- }
/** Remove a dead executor from the driver actor. This is only called on the driver side. */
def removeExecutor(execId: String) {
@@ -59,7 +41,7 @@ private[spark] class BlockManagerMaster(
/** Register the BlockManager's id with the driver. */
def registerBlockManager(
- blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+ blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
logInfo("Trying to register BlockManager")
tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor))
logInfo("Registered BlockManager")
diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
index 1494f90103..cff48d9909 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
@@ -49,16 +49,16 @@ class UpdateBlockInfo(
blockManagerId.writeExternal(out)
out.writeUTF(blockId)
storageLevel.writeExternal(out)
- out.writeInt(memSize.toInt)
- out.writeInt(diskSize.toInt)
+ out.writeLong(memSize)
+ out.writeLong(diskSize)
}
override def readExternal(in: ObjectInput) {
blockManagerId = BlockManagerId(in)
blockId = in.readUTF()
storageLevel = StorageLevel(in)
- memSize = in.readInt()
- diskSize = in.readInt()
+ memSize = in.readLong()
+ diskSize = in.readLong()
}
}
diff --git a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala b/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
new file mode 100644
index 0000000000..f6c28dce52
--- /dev/null
+++ b/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
@@ -0,0 +1,12 @@
+package spark.storage
+
+/**
+ * A [[BlockFetchTracker]] whose metrics are all read from another tracker.
+ *
+ * `delegate` starts out unset (null), so `setDelegate` has to be called before any of the
+ * metric accessors below is read.
+ */
+private[spark] trait DelegateBlockFetchTracker extends BlockFetchTracker {
+  var delegate : BlockFetchTracker = _
+
+  /** Installs the tracker that every accessor below forwards to. */
+  def setDelegate(d: BlockFetchTracker): Unit = {
+    delegate = d
+  }
+
+  // Block counts, forwarded unchanged.
+  def totalBlocks = delegate.totalBlocks
+  def numLocalBlocks = delegate.numLocalBlocks
+  def numRemoteBlocks = delegate.numRemoteBlocks
+
+  // Timing and volume figures, forwarded unchanged.
+  def remoteFetchTime = delegate.remoteFetchTime
+  def fetchWaitTime = delegate.fetchWaitTime
+  def remoteBytesRead = delegate.remoteBytesRead
+}
diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala
index ae88ff0bb1..949588476c 100644
--- a/core/src/main/scala/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/spark/storage/MemoryStore.scala
@@ -32,8 +32,8 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ bytes.rewind()
if (level.deserialized) {
- bytes.rewind()
val values = blockManager.dataDeserialize(blockId, bytes)
val elements = new ArrayBuffer[Any]
elements ++= values
@@ -58,7 +58,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
} else {
val bytes = blockManager.dataSerialize(blockId, values.iterator)
tryToPut(blockId, bytes, bytes.limit, false)
- PutResult(bytes.limit(), Right(bytes))
+ PutResult(bytes.limit(), Right(bytes.duplicate()))
}
}
diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala
index a70d1c8e78..5c406e68cb 100644
--- a/core/src/main/scala/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/spark/storage/ThreadingTest.scala
@@ -75,9 +75,8 @@ private[spark] object ThreadingTest {
System.setProperty("spark.kryoserializer.buffer.mb", "1")
val actorSystem = ActorSystem("test")
val serializer = new KryoSerializer
- val driverIp: String = System.getProperty("spark.driver.host", "localhost")
- val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
- val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, driverIp, driverPort)
+ val blockManagerMaster = new BlockManagerMaster(
+ actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
val blockManager = new BlockManager(
"<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024)
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index 30aec5a663..3e805b7831 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -31,20 +31,22 @@ private[spark] object AkkaUtils {
val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt
val akkaTimeout = System.getProperty("spark.akka.timeout", "20").toInt
val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt
+ val lifecycleEvents = System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean
val akkaConf = ConfigFactory.parseString("""
akka.daemonic = on
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
akka.stdout-loglevel = "ERROR"
akka.actor.provider = "akka.remote.RemoteActorRefProvider"
akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
- akka.remote.log-remote-lifecycle-events = on
akka.remote.netty.hostname = "%s"
akka.remote.netty.port = %d
akka.remote.netty.connection-timeout = %ds
akka.remote.netty.message-frame-size = %d MiB
akka.remote.netty.execution-pool-size = %d
akka.actor.default-dispatcher.throughput = %d
- """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize))
+ akka.remote.log-remote-lifecycle-events = %s
+ """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize,
+ if (lifecycleEvents) "on" else "off"))
val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)
diff --git a/core/src/main/scala/spark/util/CompletionIterator.scala b/core/src/main/scala/spark/util/CompletionIterator.scala
new file mode 100644
index 0000000000..8139183780
--- /dev/null
+++ b/core/src/main/scala/spark/util/CompletionIterator.scala
@@ -0,0 +1,25 @@
+package spark.util
+
+/**
+ * Wrapper around an iterator which calls a completion method once the wrapped iterator has
+ * been exhausted, i.e. the first time `hasNext` returns false.
+ *
+ * The completion method runs at most once, no matter how often `hasNext` is polled after the
+ * end has been reached. It is only triggered from `hasNext`, so it does not run when the
+ * consumer abandons iteration early or when the wrapped iterator throws.
+ *
+ * @param sub the iterator being wrapped
+ */
+abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterator[A]{
+  // Flipped before `completion()` is invoked so that a second `hasNext` call on an exhausted
+  // iterator (or a callback that throws) can never run the callback again.
+  private[this] var completed = false
+
+  def next = sub.next
+
+  def hasNext = {
+    val r = sub.hasNext
+    if (!r && !completed) {
+      completed = true
+      completion()
+    }
+    r
+  }
+
+  /** Callback run once, when the wrapped iterator first reports that it has no more elements. */
+  def completion()
+}
+
+object CompletionIterator {
+  /**
+   * Wraps `sub` in a [[CompletionIterator]] whose completion callback evaluates
+   * `completionFunction`.
+   *
+   * @param sub the iterator to wrap
+   * @param completionFunction by-name expression evaluated each time the callback is invoked
+   */
+  def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] =
+    new CompletionIterator[A,I](sub) {
+      def completion() = completionFunction
+    }
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/util/Distribution.scala b/core/src/main/scala/spark/util/Distribution.scala
new file mode 100644
index 0000000000..24738b4307
--- /dev/null
+++ b/core/src/main/scala/spark/util/Distribution.scala
@@ -0,0 +1,65 @@
+package spark.util
+
+import java.io.PrintStream
+
+/**
+ * Helper for pulling a few summary statistics out of a small, in-memory sample of numbers.
+ *
+ * Only the window `[startIdx, endIdx)` of `data` is considered. That window is sorted in place
+ * when the instance is constructed, so the array handed in is modified.
+ *
+ * Everything is held in memory; this is not a good way to compute stats over large data sets.
+ *
+ * The window must be non-empty (`startIdx < endIdx`).
+ */
+class Distribution(val data: Array[Double], val startIdx: Int, val endIdx: Int) {
+  require(startIdx < endIdx)
+
+  /** Builds a distribution over every element of `data`. */
+  def this(data: Traversable[Double]) = this(data.toArray, 0, data.size)
+
+  // Quantile lookups below index straight into the sorted window.
+  java.util.Arrays.sort(data, startIdx, endIdx)
+
+  /** Number of elements in the window. */
+  val length = endIdx - startIdx
+
+  /** Probabilities for min, the three quartiles and max. */
+  val defaultProbabilities = Array(0,0.25,0.5,0.75,1.0)
+
+  /**
+   * Get the value of the distribution at the given probabilities.
+   *
+   * @param probabilities values expected to lie between 0 and 1; anything that would land past
+   *                      the end of the window is mapped to the largest element
+   * @return one value per probability, in the order the probabilities were given
+   */
+  def getQuantiles(probabilities: Traversable[Double] = defaultProbabilities) = {
+    for (p <- probabilities.toIndexedSeq) yield data(closestIndex(p))
+  }
+
+  /** Index into `data` for probability `p`, capped at the last element of the window. */
+  private def closestIndex(p: Double) = {
+    val candidate = startIdx + (p * length).toInt
+    math.min(candidate, endIdx - 1)
+  }
+
+  /**
+   * Prints a header line followed by the default quantiles, tab separated.
+   *
+   * @param out stream to print to
+   */
+  def showQuantiles(out: PrintStream = System.out) = {
+    out.println("min\t25%\t50%\t75%\tmax")
+    for (q <- getQuantiles(defaultProbabilities)) {
+      out.print(q + "\t")
+    }
+    out.println()
+  }
+
+  /** A StatCounter built from a copy of the window. */
+  def statCounter = {
+    val window = data.slice(startIdx, endIdx)
+    StatCounter(window)
+  }
+
+  /**
+   * Print a summary of this distribution (the stat counter, then the default quantiles).
+   *
+   * @param out stream to print to
+   */
+  def summary(out: PrintStream = System.out): Unit = {
+    out.println(statCounter)
+    showQuantiles(out)
+  }
+}
+
+object Distribution {
+
+  /**
+   * Builds a [[Distribution]] over `data`.
+   *
+   * @return `None` when `data` has no elements, otherwise the distribution wrapped in `Some`
+   */
+  def apply(data: Traversable[Double]): Option[Distribution] = {
+    if (data.size <= 0) {
+      None
+    } else {
+      Some(new Distribution(data))
+    }
+  }
+
+  /**
+   * Prints a min/quartiles/max header line followed by the given values, tab separated.
+   *
+   * @param out stream to print to
+   * @param quantiles values to print, in the order they should appear
+   */
+  def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]): Unit = {
+    out.println("min\t25%\t50%\t75%\tmax")
+    for (q <- quantiles) {
+      out.print(q + "\t")
+    }
+    out.println()
+  }
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/util/NextIterator.scala b/core/src/main/scala/spark/util/NextIterator.scala
new file mode 100644
index 0000000000..48b5018ddd
--- /dev/null
+++ b/core/src/main/scala/spark/util/NextIterator.scala
@@ -0,0 +1,71 @@
+package spark.util
+
+/**
+ * Skeleton Iterator: subclasses supply `getNext` and `close`, and this class takes care of the
+ * one-element look-ahead needed to implement `hasNext`/`next`.
+ */
+private[spark] abstract class NextIterator[U] extends Iterator[U] {
+
+  // True while `nextValue` holds an element that has been fetched but not yet handed out.
+  private var gotNext = false
+  private var nextValue: U = _
+  // True once `close` has returned normally.
+  private var closed = false
+  // Set by the subclass from within `getNext` to signal the end of the stream.
+  protected var finished = false
+
+  /**
+   * Produces the next element. Implemented by subclasses.
+   *
+   * When there is nothing left, the implementation sets `finished` to `true`; whatever it
+   * returns in that case is ignored.
+   *
+   * A flag is used rather than a sentinel or an `Option` because `null` may be a legitimate
+   * element, and wrapping every element in Some/None seemed like needless allocation for
+   * iterators that may be driven in a tight loop.
+   *
+   * @return U, or set 'finished' when done
+   */
+  protected def getNext(): U
+
+  /**
+   * Releases whatever the subclass holds. Invoked once all elements have been successfully
+   * iterated.
+   *
+   * <b>Note:</b> `NextIterator` cannot guarantee that `close` will be called, because it has
+   * no control over what happens when the code calling hasNext/next hits an exception.
+   *
+   * Ideally there is a further try/catch around the iteration, as in HadoopRDD, that makes
+   * sure resources are closed if iteration fails.
+   */
+  protected def close(): Unit
+
+  /**
+   * Runs the subclass-defined `close`, but only if it has not already completed.
+   *
+   * Closing more than once is usually harmless, but historically some InputFormats have
+   * thrown exceptions when that happened.
+   */
+  def closeIfNeeded(): Unit = {
+    if (!closed) {
+      close()
+      // Only recorded after close() returned normally.
+      closed = true
+    }
+  }
+
+  override def hasNext: Boolean = {
+    // Fetch a new element only when the stream is still open and nothing is buffered.
+    if (!finished && !gotNext) {
+      nextValue = getNext()
+      if (finished) {
+        closeIfNeeded()
+      }
+      gotNext = true
+    }
+    !finished
+  }
+
+  override def next(): U = {
+    if (hasNext) {
+      // Hand out the buffered element and mark the buffer as consumed.
+      gotNext = false
+      nextValue
+    } else {
+      throw new NoSuchElementException("End of stream")
+    }
+  }
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/util/TimedIterator.scala b/core/src/main/scala/spark/util/TimedIterator.scala
new file mode 100644
index 0000000000..539b01f4ce
--- /dev/null
+++ b/core/src/main/scala/spark/util/TimedIterator.scala
@@ -0,0 +1,32 @@
+package spark.util
+
+/**
+ * Iterator wrapper that accumulates the wall-clock time spent inside the wrapped iterator's
+ * `hasNext` and `next` calls.
+ *
+ * Time is measured with millisecond resolution, so this is only worthwhile when fetching an
+ * element is expected to take a considerable amount of time (eg. milliseconds). Otherwise the
+ * figures won't be very accurate and the wrapper mostly adds overhead.
+ */
+class TimedIterator[+A](val sub: Iterator[A]) extends Iterator[A] {
+  private var netMillis = 0L
+  private var nElems = 0
+
+  /** Evaluates `op` and adds the elapsed milliseconds to the running total. */
+  private def timed[T](op: => T): T = {
+    val startedAt = System.currentTimeMillis()
+    val result = op
+    val finishedAt = System.currentTimeMillis()
+    netMillis += (finishedAt - startedAt)
+    result
+  }
+
+  def hasNext = timed(sub.hasNext)
+
+  def next = {
+    val elem = timed(sub.next)
+    nElems += 1
+    elem
+  }
+
+  /** Total milliseconds spent so far in the wrapped iterator's `hasNext` and `next`. */
+  def getNetMillis = netMillis
+
+  /** Total milliseconds divided by the number of elements returned by `next` so far. */
+  def getAverageTimePerItem = netMillis / nElems.toDouble
+
+}