aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatei Zaharia <matei@eecs.berkeley.edu>2013-03-03 17:16:22 -0800
committerMatei Zaharia <matei@eecs.berkeley.edu>2013-03-03 17:16:22 -0800
commit6cf4be4b3991d30b2902f895a7082361c736e88d (patch)
tree30c239317de138512c0f462111fa6344427bbe9c
parent6bfc7cad6bf58aa40975c0984b98debb33df080d (diff)
parent0bd1d00c2ab75c6133269611d498cac321d9d389 (diff)
downloadspark-6cf4be4b3991d30b2902f895a7082361c736e88d.tar.gz
spark-6cf4be4b3991d30b2902f895a7082361c736e88d.tar.bz2
spark-6cf4be4b3991d30b2902f895a7082361c736e88d.zip
Merge pull request #462 from squito/stageInfo
Track assorted metrics for each task, report summaries to user at stage completion
-rw-r--r--core/src/main/scala/spark/BlockStoreShuffleFetcher.scala26
-rw-r--r--core/src/main/scala/spark/ShuffleFetcher.scala4
-rw-r--r--core/src/main/scala/spark/SparkContext.scala31
-rw-r--r--core/src/main/scala/spark/TaskContext.scala9
-rw-r--r--core/src/main/scala/spark/executor/Executor.scala12
-rw-r--r--core/src/main/scala/spark/executor/TaskMetrics.scala71
-rw-r--r--core/src/main/scala/spark/rdd/CoGroupedRDD.scala3
-rw-r--r--core/src/main/scala/spark/rdd/ShuffledRDD.scala2
-rw-r--r--core/src/main/scala/spark/rdd/SubtractedRDD.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/DAGScheduler.scala24
-rw-r--r--core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala6
-rw-r--r--core/src/main/scala/spark/scheduler/ResultTask.scala1
-rw-r--r--core/src/main/scala/spark/scheduler/ShuffleMapTask.scala8
-rw-r--r--core/src/main/scala/spark/scheduler/SparkListener.scala146
-rw-r--r--core/src/main/scala/spark/scheduler/StageInfo.scala12
-rw-r--r--core/src/main/scala/spark/scheduler/Task.scala7
-rw-r--r--core/src/main/scala/spark/scheduler/TaskResult.scala7
-rw-r--r--core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala5
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala3
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala9
-rw-r--r--core/src/main/scala/spark/scheduler/local/LocalScheduler.scala11
-rw-r--r--core/src/main/scala/spark/storage/BlockFetchTracker.scala10
-rw-r--r--core/src/main/scala/spark/storage/BlockManager.scala321
-rw-r--r--core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala12
-rw-r--r--core/src/main/scala/spark/util/CompletionIterator.scala25
-rw-r--r--core/src/main/scala/spark/util/Distribution.scala65
-rw-r--r--core/src/main/scala/spark/util/TimedIterator.scala32
-rw-r--r--core/src/test/scala/spark/JavaAPISuite.java2
-rw-r--r--core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala10
-rw-r--r--core/src/test/scala/spark/util/DistributionSuite.scala25
30 files changed, 699 insertions, 202 deletions
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index 86432d0127..53b0389c3a 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -1,20 +1,22 @@
package spark
+import executor.{ShuffleReadMetrics, TaskMetrics}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-import spark.storage.BlockManagerId
+import spark.storage.{DelegateBlockFetchTracker, BlockManagerId}
+import util.{CompletionIterator, TimedIterator}
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
- override def fetch[K, V](shuffleId: Int, reduceId: Int) = {
+ override def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) = {
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
-
+
val startTime = System.currentTimeMillis
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
-
+
val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
for (((address, size), index) <- statuses.zipWithIndex) {
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
@@ -45,6 +47,20 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
}
}
}
- blockManager.getMultiple(blocksByAddress).flatMap(unpackBlock)
+
+ val blockFetcherItr = blockManager.getMultiple(blocksByAddress)
+ val itr = new TimedIterator(blockFetcherItr.flatMap(unpackBlock)) with DelegateBlockFetchTracker
+ itr.setDelegate(blockFetcherItr)
+ CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
+ val shuffleMetrics = new ShuffleReadMetrics
+ shuffleMetrics.shuffleReadMillis = itr.getNetMillis
+ shuffleMetrics.remoteFetchTime = itr.remoteFetchTime
+ shuffleMetrics.remoteFetchWaitTime = itr.remoteFetchWaitTime
+ shuffleMetrics.remoteBytesRead = itr.remoteBytesRead
+ shuffleMetrics.totalBlocksFetched = itr.totalBlocks
+ shuffleMetrics.localBlocksFetched = itr.numLocalBlocks
+ shuffleMetrics.remoteBlocksFetched = itr.numRemoteBlocks
+ metrics.shuffleReadMetrics = Some(shuffleMetrics)
+ })
}
}
diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala
index d9a94d4021..442e9f0269 100644
--- a/core/src/main/scala/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/spark/ShuffleFetcher.scala
@@ -1,11 +1,13 @@
package spark
+import executor.TaskMetrics
+
private[spark] abstract class ShuffleFetcher {
/**
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
- def fetch[K, V](shuffleId: Int, reduceId: Int) : Iterator[(K, V)]
+ def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) : Iterator[(K,V)]
/** Stop the fetcher */
def stop() {}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 7503b1a5ea..4957a54c1b 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -1,19 +1,15 @@
package spark
import java.io._
-import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
-import java.net.{URI, URLClassLoader}
-import java.lang.ref.WeakReference
+import java.net.URI
import scala.collection.Map
import scala.collection.generic.Growable
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
-import akka.actor.Actor
-import akka.actor.Actor._
-import org.apache.hadoop.fs.{FileUtil, Path}
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
@@ -33,20 +29,19 @@ import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
-import org.apache.mesos.{Scheduler, MesosNativeLibrary}
+import org.apache.mesos.MesosNativeLibrary
-import spark.broadcast._
import spark.deploy.LocalSparkCluster
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
-import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
-import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
+import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
+import spark.scheduler._
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import storage.BlockManagerUI
-import util.{MetadataCleaner, TimeStampedHashMap}
-import storage.{StorageStatus, StorageUtils, RDDInfo}
+import spark.storage.BlockManagerUI
+import spark.util.{MetadataCleaner, TimeStampedHashMap}
+import spark.storage.{StorageStatus, StorageUtils, RDDInfo}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -466,6 +461,10 @@ class SparkContext(
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
+ def addSparkListener(listener: SparkListener) {
+ dagScheduler.sparkListeners += listener
+ }
+
/**
* Return a map from the slave to the max memory available for caching and the remaining
* memory available for caching.
@@ -484,6 +483,10 @@ class SparkContext(
StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
}
+ def getStageInfo: Map[Stage,StageInfo] = {
+ dagScheduler.stageToInfos
+ }
+
/**
* Return information about blocks stored in all of the slaves
*/
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
index eab85f85a2..dd0609026a 100644
--- a/core/src/main/scala/spark/TaskContext.scala
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -1,9 +1,14 @@
package spark
+import executor.TaskMetrics
import scala.collection.mutable.ArrayBuffer
-
-class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
+class TaskContext(
+ val stageId: Int,
+ val splitId: Int,
+ val attemptId: Long,
+ val taskMetrics: TaskMetrics = TaskMetrics.empty()
+) extends Serializable {
@transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index 5de09030aa..4474ef4593 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -85,6 +85,7 @@ private[spark] class Executor extends Logging {
extends Runnable {
override def run() {
+ val startTime = System.currentTimeMillis()
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(urlClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
@@ -98,9 +99,18 @@ private[spark] class Executor extends Logging {
val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
logInfo("Its generation is " + task.generation)
env.mapOutputTracker.updateGeneration(task.generation)
+ val taskStart = System.currentTimeMillis()
val value = task.run(taskId.toInt)
+ val taskFinish = System.currentTimeMillis()
+ task.metrics.foreach{ m =>
+ m.executorDeserializeTime = (taskStart - startTime).toInt
+ m.executorRunTime = (taskFinish - taskStart).toInt
+ }
+ //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c
+ // we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could
+ // just change the relevants bytes in the byte buffer
val accumUpdates = Accumulators.values
- val result = new TaskResult(value, accumUpdates)
+ val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null))
val serializedResult = ser.serialize(result)
logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit)
context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala
new file mode 100644
index 0000000000..800305cd6c
--- /dev/null
+++ b/core/src/main/scala/spark/executor/TaskMetrics.scala
@@ -0,0 +1,71 @@
+package spark.executor
+
+class TaskMetrics{
+ /**
+ * Time taken on the executor to deserialize this task
+ */
+ var executorDeserializeTime: Int = _
+ /**
+ * Time the executor spends actually running the task (including fetching shuffle data)
+ */
+ var executorRunTime:Int = _
+ /**
+ * The number of bytes this task transmitted back to the driver as the TaskResult
+ */
+ var resultSize: Long = _
+
+ /**
+ * If this task reads from shuffle output, metrics on getting shuffle data will be collected here
+ */
+ var shuffleReadMetrics: Option[ShuffleReadMetrics] = None
+
+ /**
+ * If this task writes to shuffle output, metrics on the written shuffle data will be collected here
+ */
+ var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None
+
+}
+
+object TaskMetrics {
+ private[spark] def empty() : TaskMetrics = new TaskMetrics
+}
+
+
+class ShuffleReadMetrics {
+ /**
+ * Total number of blocks fetched in a shuffle (remote or local)
+ */
+ var totalBlocksFetched : Int = _
+ /**
+ * Number of remote blocks fetched in a shuffle
+ */
+ var remoteBlocksFetched: Int = _
+ /**
+ * Local blocks fetched in a shuffle
+ */
+ var localBlocksFetched: Int = _
+ /**
+ * Total time to read shuffle data
+ */
+ var shuffleReadMillis: Long = _
+ /**
+ * Total time that is spent blocked waiting for shuffle to fetch remote data
+ */
+ var remoteFetchWaitTime: Long = _
+ /**
+ * The total amount of time for all the shuffle fetches. This adds up time from overlapping
+ * shuffles, so can be longer than task time
+ */
+ var remoteFetchTime: Long = _
+ /**
+ * Total number of remote bytes read from a shuffle
+ */
+ var remoteBytesRead: Long = _
+}
+
+class ShuffleWriteMetrics {
+ /**
+ * Number of bytes written for a shuffle
+ */
+ var shuffleBytesWritten: Long = _
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 5200fb6b65..65b4621b87 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -102,7 +102,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
- for ((k, vs) <- fetcher.fetch[K, Seq[Any]](shuffleId, split.index)) {
+ val fetchItr = fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics)
+ for ((k, vs) <- fetchItr) {
getSeq(k)(depNum) ++= vs
}
}
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index c2f118305f..51f02409b6 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -28,6 +28,6 @@ class ShuffledRDD[K, V](
override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
+ SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics)
}
}
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index daf9cc993c..43ec90cac5 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -89,7 +89,7 @@ private[spark] class SubtractedRDD[T: ClassManifest](
for (k <- rdd.iterator(itsSplit, context))
op(k.asInstanceOf[T])
case ShuffleCoGroupSplitDep(shuffleId) =>
- for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index))
+ for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
op(k.asInstanceOf[T])
}
// the first dep is rdd1; add all keys to the set
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index bf0837c066..1bf5054f4d 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -1,20 +1,19 @@
package spark.scheduler
-import java.net.URI
+import cluster.TaskInfo
import java.util.concurrent.atomic.AtomicInteger
-import java.util.concurrent.Future
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import spark._
+import spark.executor.TaskMetrics
import spark.partial.ApproximateActionListener
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
import spark.storage.BlockManagerMaster
-import spark.storage.BlockManagerId
-import util.{MetadataCleaner, TimeStampedHashMap}
+import spark.util.{MetadataCleaner, TimeStampedHashMap}
/**
* A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
@@ -40,8 +39,10 @@ class DAGScheduler(
task: Task[_],
reason: TaskEndReason,
result: Any,
- accumUpdates: Map[Long, Any]) {
- eventQueue.put(CompletionEvent(task, reason, result, accumUpdates))
+ accumUpdates: Map[Long, Any],
+ taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics) {
+ eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
}
// Called by TaskScheduler when an executor fails.
@@ -73,6 +74,10 @@ class DAGScheduler(
val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
+ private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
+
+ private[spark] val sparkListeners = ArrayBuffer[SparkListener]()
+
var cacheLocs = new HashMap[Int, Array[List[String]]]
// For tracking failed nodes, we use the MapOutputTracker's generation number, which is
@@ -148,6 +153,7 @@ class DAGScheduler(
val id = nextStageId.getAndIncrement()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
idToStage(id) = stage
+ stageToInfos(stage) = StageInfo(stage)
stage
}
@@ -472,6 +478,8 @@ class DAGScheduler(
case _ => "Unkown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.origin, serviceTime))
+ val stageComp = StageCompleted(stageToInfos(stage))
+ sparkListeners.foreach{_.onStageCompleted(stageComp)}
running -= stage
}
event.reason match {
@@ -481,6 +489,7 @@ class DAGScheduler(
Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
}
pendingTasks(stage) -= task
+ stageToInfos(stage).taskInfos += event.taskInfo -> event.taskMetrics
task match {
case rt: ResultTask[_, _] =>
resultStageToJob.get(stage) match {
@@ -501,7 +510,6 @@ class DAGScheduler(
}
case smt: ShuffleMapTask =>
- val stage = idToStage(smt.stageId)
val status = event.result.asInstanceOf[MapStatus]
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
index b34fa78c07..ed0b9bf178 100644
--- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -1,8 +1,10 @@
package spark.scheduler
+import spark.scheduler.cluster.TaskInfo
import scala.collection.mutable.Map
import spark._
+import spark.executor.TaskMetrics
/**
* Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue
@@ -25,7 +27,9 @@ private[spark] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
result: Any,
- accumUpdates: Map[Long, Any])
+ accumUpdates: Map[Long, Any],
+ taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index 1721f78f48..beb21a76fe 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -72,6 +72,7 @@ private[spark] class ResultTask[T, U](
override def run(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId)
+ metrics = Some(context.taskMetrics)
try {
func(context, rdd.iterator(split, context))
} finally {
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 59ee3c0a09..36d087a4d0 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -13,6 +13,7 @@ import com.ning.compress.lzf.LZFInputStream
import com.ning.compress.lzf.LZFOutputStream
import spark._
+import executor.ShuffleWriteMetrics
import spark.storage._
import util.{TimeStampedHashMap, MetadataCleaner}
@@ -119,6 +120,7 @@ private[spark] class ShuffleMapTask(
val numOutputSplits = dep.partitioner.numPartitions
val taskContext = new TaskContext(stageId, partition, attemptId)
+ metrics = Some(taskContext.taskMetrics)
try {
// Partition the map output.
val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
@@ -130,14 +132,20 @@ private[spark] class ShuffleMapTask(
val compressedSizes = new Array[Byte](numOutputSplits)
+ var totalBytes = 0l
+
val blockManager = SparkEnv.get.blockManager
for (i <- 0 until numOutputSplits) {
val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
// Get a Scala iterator from Java map
val iter: Iterator[(Any, Any)] = buckets(i).iterator
val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+ totalBytes += size
compressedSizes(i) = MapOutputTracker.compressSize(size)
}
+ val shuffleMetrics = new ShuffleWriteMetrics
+ shuffleMetrics.shuffleBytesWritten = totalBytes
+ metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
return new MapStatus(blockManager.blockManagerId, compressedSizes)
} finally {
diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala
new file mode 100644
index 0000000000..21185227ab
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/SparkListener.scala
@@ -0,0 +1,146 @@
+package spark.scheduler
+
+import spark.scheduler.cluster.TaskInfo
+import spark.util.Distribution
+import spark.{Utils, Logging}
+import spark.executor.TaskMetrics
+
+trait SparkListener {
+ /**
+ * called when a stage is completed, with information on the completed stage
+ */
+ def onStageCompleted(stageCompleted: StageCompleted)
+}
+
+sealed trait SparkListenerEvents
+
+case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
+
+
+/**
+ * Simple SparkListener that logs a few summary statistics when each stage completes
+ */
+class StatsReportListener extends SparkListener with Logging {
+ def onStageCompleted(stageCompleted: StageCompleted) {
+ import spark.scheduler.StatsReportListener._
+ implicit val sc = stageCompleted
+ this.logInfo("Finished stage: " + stageCompleted.stageInfo)
+ showMillisDistribution("task runtime:", (info, _) => Some(info.duration))
+
+ //shuffle write
+ showBytesDistribution("shuffle bytes written:",(_,metric) => metric.shuffleWriteMetrics.map{_.shuffleBytesWritten})
+
+ //fetch & io
+ showMillisDistribution("fetch wait time:",(_, metric) => metric.shuffleReadMetrics.map{_.remoteFetchWaitTime})
+ showBytesDistribution("remote bytes read:", (_, metric) => metric.shuffleReadMetrics.map{_.remoteBytesRead})
+ showBytesDistribution("task result size:", (_, metric) => Some(metric.resultSize))
+
+ //runtime breakdown
+ val runtimePcts = stageCompleted.stageInfo.taskInfos.map{
+ case (info, metrics) => RuntimePercentage(info.duration, metrics)
+ }
+ showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%")
+ showDistribution("fetch wait time pct: ", Distribution(runtimePcts.flatMap{_.fetchPct.map{_ * 100}}), "%2.0f %%")
+ showDistribution("other time pct: ", Distribution(runtimePcts.map{_.other * 100}), "%2.0f %%")
+ }
+
+}
+
+object StatsReportListener extends Logging {
+
+ //for profiling, the extremes are more interesting
+ val percentiles = Array[Int](0,5,10,25,50,75,90,95,100)
+ val probabilities = percentiles.map{_ / 100.0}
+ val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
+
+ def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = {
+ Distribution(stage.stageInfo.taskInfos.flatMap{
+ case ((info,metric)) => getMetric(info, metric)})
+ }
+
+ //is there some way to setup the types that I can get rid of this completely?
+ def extractLongDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Long]): Option[Distribution] = {
+ extractDoubleDistribution(stage, (info, metric) => getMetric(info,metric).map{_.toDouble})
+ }
+
+ def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
+ val stats = d.statCounter
+ logInfo(heading + stats)
+ val quantiles = d.getQuantiles(probabilities).map{formatNumber}
+ logInfo(percentilesHeader)
+ logInfo("\t" + quantiles.mkString("\t"))
+ }
+
+ def showDistribution(heading: String, dOpt: Option[Distribution], formatNumber: Double => String) {
+ dOpt.foreach { d => showDistribution(heading, d, formatNumber)}
+ }
+
+ def showDistribution(heading: String, dOpt: Option[Distribution], format:String) {
+ def f(d:Double) = format.format(d)
+ showDistribution(heading, dOpt, f _)
+ }
+
+ def showDistribution(heading:String, format: String, getMetric: (TaskInfo,TaskMetrics) => Option[Double])
+ (implicit stage: StageCompleted) {
+ showDistribution(heading, extractDoubleDistribution(stage, getMetric), format)
+ }
+
+ def showBytesDistribution(heading:String, getMetric: (TaskInfo,TaskMetrics) => Option[Long])
+ (implicit stage: StageCompleted) {
+ showBytesDistribution(heading, extractLongDistribution(stage, getMetric))
+ }
+
+ def showBytesDistribution(heading: String, dOpt: Option[Distribution]) {
+ dOpt.foreach{dist => showBytesDistribution(heading, dist)}
+ }
+
+ def showBytesDistribution(heading: String, dist: Distribution) {
+ showDistribution(heading, dist, (d => Utils.memoryBytesToString(d.toLong)): Double => String)
+ }
+
+ def showMillisDistribution(heading: String, dOpt: Option[Distribution]) {
+ showDistribution(heading, dOpt, (d => StatsReportListener.millisToString(d.toLong)): Double => String)
+ }
+
+ def showMillisDistribution(heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long])
+ (implicit stage: StageCompleted) {
+ showMillisDistribution(heading, extractLongDistribution(stage, getMetric))
+ }
+
+
+
+ val seconds = 1000L
+ val minutes = seconds * 60
+ val hours = minutes * 60
+
+ /**
+ * reformat a time interval in milliseconds to a prettier format for output
+ */
+ def millisToString(ms: Long) = {
+ val (size, units) =
+ if (ms > hours) {
+ (ms.toDouble / hours, "hours")
+ } else if (ms > minutes) {
+ (ms.toDouble / minutes, "min")
+ } else if (ms > seconds) {
+ (ms.toDouble / seconds, "s")
+ } else {
+ (ms.toDouble, "ms")
+ }
+ "%.1f %s".format(size, units)
+ }
+}
+
+
+
+case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
+object RuntimePercentage {
+ def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
+ val denom = totalTime.toDouble
+ val fetchTime = metrics.shuffleReadMetrics.map{_.remoteFetchWaitTime}
+ val fetch = fetchTime.map{_ / denom}
+ val exec = (metrics.executorRunTime - fetchTime.getOrElse(0l)) / denom
+ val other = 1.0 - (exec + fetch.getOrElse(0d))
+ RuntimePercentage(exec, fetch, other)
+ }
+}
diff --git a/core/src/main/scala/spark/scheduler/StageInfo.scala b/core/src/main/scala/spark/scheduler/StageInfo.scala
new file mode 100644
index 0000000000..8d83ff10c4
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/StageInfo.scala
@@ -0,0 +1,12 @@
+package spark.scheduler
+
+import spark.scheduler.cluster.TaskInfo
+import scala.collection._
+import spark.executor.TaskMetrics
+
+case class StageInfo(
+ val stage: Stage,
+ val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]()
+) {
+ override def toString = stage.rdd.toString
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
index ef987fdeb6..a6462c6968 100644
--- a/core/src/main/scala/spark/scheduler/Task.scala
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -1,12 +1,12 @@
package spark.scheduler
-import scala.collection.mutable.HashMap
-import spark.serializer.{SerializerInstance, Serializer}
+import spark.serializer.SerializerInstance
import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import spark.util.ByteBufferInputStream
import scala.collection.mutable.HashMap
+import spark.executor.TaskMetrics
/**
* A task to execute on a worker node.
@@ -16,6 +16,9 @@ private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
def preferredLocations: Seq[String] = Nil
var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.
+
+ var metrics: Option[TaskMetrics] = None
+
}
/**
diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala
index 9a54d0e854..6de0aa7adf 100644
--- a/core/src/main/scala/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/spark/scheduler/TaskResult.scala
@@ -3,13 +3,14 @@ package spark.scheduler
import java.io._
import scala.collection.mutable.Map
+import spark.executor.TaskMetrics
// Task result. Also contains updates to accumulator variables.
// TODO: Use of distributed cache to return result is a hack to get around
// what seems to be a bug with messages over 60KB in libprocess; fix it
private[spark]
-class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Externalizable {
- def this() = this(null.asInstanceOf[T], null)
+class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics) extends Externalizable {
+ def this() = this(null.asInstanceOf[T], null, null)
override def writeExternal(out: ObjectOutput) {
out.writeObject(value)
@@ -18,6 +19,7 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Exte
out.writeLong(key)
out.writeObject(value)
}
+ out.writeObject(metrics)
}
override def readExternal(in: ObjectInput) {
@@ -31,5 +33,6 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Exte
accumUpdates(in.readLong()) = in.readObject()
}
}
+ metrics = in.readObject().asInstanceOf[TaskMetrics]
}
}
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
index 9fcef86e46..771518dddf 100644
--- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -1,15 +1,18 @@
package spark.scheduler
+import spark.scheduler.cluster.TaskInfo
import scala.collection.mutable.Map
import spark.TaskEndReason
+import spark.executor.TaskMetrics
/**
* Interface for getting events back from the TaskScheduler.
*/
private[spark] trait TaskSchedulerListener {
// A task has finished or failed.
- def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit
+ def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
// A node was lost from the cluster.
def executorLost(execId: String): Unit
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
index 0f975ce1eb..dfe3c5a85b 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
@@ -9,7 +9,8 @@ class TaskInfo(
val index: Int,
val launchTime: Long,
val executorId: String,
- val host: String) {
+ val host: String,
+ val preferred: Boolean) {
var finishTime: Long = 0
var failed = false
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index 3dabdd76b1..c9f2c48804 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -208,7 +208,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
taskSet.id, index, taskId, execId, host, prefStr))
// Do various bookkeeping
copiesRunning(index) += 1
- val info = new TaskInfo(taskId, index, time, execId, host)
+ val info = new TaskInfo(taskId, index, time, execId, host, preferred)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
if (preferred) {
@@ -259,7 +259,8 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
tid, info.duration, tasksFinished, numTasks))
// Deserialize task result and pass it to the scheduler
val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
- sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
+ result.metrics.resultSize = serializedData.limit()
+ sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
// Mark finished and stop if we've finished all the tasks
finished(index) = true
if (tasksFinished == numTasks) {
@@ -290,7 +291,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
reason match {
case fetchFailed: FetchFailed =>
logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.listener.taskEnded(tasks(index), fetchFailed, null, null)
+ sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
finished(index) = true
tasksFinished += 1
sched.taskSetFinished(this)
@@ -378,7 +379,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
- sched.listener.taskEnded(tasks(index), Resubmitted, null, null)
+ sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 482d1cc853..a76253ea14 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -1,14 +1,13 @@
package spark.scheduler.local
import java.io.File
-import java.net.URLClassLoader
-import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.HashMap
import spark._
-import executor.ExecutorURLClassLoader
+import spark.executor.ExecutorURLClassLoader
import spark.scheduler._
+import spark.scheduler.cluster.TaskInfo
/**
* A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
@@ -54,6 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
logInfo("Running " + task)
+ val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local", true)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
try {
@@ -81,10 +81,11 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
logInfo("Finished " + task)
+ info.markSuccessful()
// If the threadpool has not already been shutdown, notify DAGScheduler
if (!Thread.currentThread().isInterrupted)
- listener.taskEnded(task, Success, resultToReturn, accumUpdates)
+ listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, null)
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
@@ -95,7 +96,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
} else {
// TODO: Do something nicer here to return all the way to the user
if (!Thread.currentThread().isInterrupted)
- listener.taskEnded(task, new ExceptionFailure(t), null, null)
+ listener.taskEnded(task, new ExceptionFailure(t), null, null, info, null)
}
}
}
diff --git a/core/src/main/scala/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
new file mode 100644
index 0000000000..ababb04305
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
@@ -0,0 +1,10 @@
+package spark.storage
+
+private[spark] trait BlockFetchTracker {
+ def totalBlocks : Int
+ def numLocalBlocks: Int
+ def numRemoteBlocks: Int
+ def remoteFetchTime : Long
+ def remoteFetchWaitTime: Long
+ def remoteBytesRead : Long
+}
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 2462721fb8..4964060b1c 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -446,152 +446,8 @@ class BlockManager(
* so that we can control the maxMegabytesInFlight for the fetch.
*/
def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])])
- : Iterator[(String, Option[Iterator[Any]])] = {
-
- if (blocksByAddress == null) {
- throw new IllegalArgumentException("BlocksByAddress is null")
- }
- val totalBlocks = blocksByAddress.map(_._2.size).sum
- logDebug("Getting " + totalBlocks + " blocks")
- var startTime = System.currentTimeMillis
- val localBlockIds = new ArrayBuffer[String]()
- val remoteBlockIds = new HashSet[String]()
-
- // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
- // the block (since we want all deserializaton to happen in the calling thread); can also
- // represent a fetch failure if size == -1.
- class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
- def failed: Boolean = size == -1
- }
-
- // A queue to hold our results.
- val results = new LinkedBlockingQueue[FetchResult]
-
- // A request to fetch one or more blocks, complete with their sizes
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
- val size = blocks.map(_._2).sum
- }
-
- // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
- // the number of bytes in flight is limited to maxBytesInFlight
- val fetchRequests = new Queue[FetchRequest]
-
- // Current bytes in flight from our requests
- var bytesInFlight = 0L
-
- def sendRequest(req: FetchRequest) {
- logDebug("Sending request for %d blocks (%s) from %s".format(
- req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
- val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
- val blockMessageArray = new BlockMessageArray(req.blocks.map {
- case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
- })
- bytesInFlight += req.size
- val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
- val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
- future.onSuccess {
- case Some(message) => {
- val bufferMessage = message.asInstanceOf[BufferMessage]
- val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
- for (blockMessage <- blockMessageArray) {
- if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
- throw new SparkException(
- "Unexpected message " + blockMessage.getType + " received from " + cmId)
- }
- val blockId = blockMessage.getId
- results.put(new FetchResult(
- blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
- logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
- }
- }
- case None => {
- logError("Could not get block(s) from " + cmId)
- for ((blockId, size) <- req.blocks) {
- results.put(new FetchResult(blockId, -1, null))
- }
- }
- }
- }
-
- // Partition local and remote blocks. Remote blocks are further split into FetchRequests of size
- // at most maxBytesInFlight in order to limit the amount of data in flight.
- val remoteRequests = new ArrayBuffer[FetchRequest]
- for ((address, blockInfos) <- blocksByAddress) {
- if (address == blockManagerId) {
- localBlockIds ++= blockInfos.map(_._1)
- } else {
- remoteBlockIds ++= blockInfos.map(_._1)
- // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
- // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
- // nodes, rather than blocking on reading output from one node.
- val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
- val iterator = blockInfos.iterator
- var curRequestSize = 0L
- var curBlocks = new ArrayBuffer[(String, Long)]
- while (iterator.hasNext) {
- val (blockId, size) = iterator.next()
- curBlocks += ((blockId, size))
- curRequestSize += size
- if (curRequestSize >= minRequestSize) {
- // Add this FetchRequest
- remoteRequests += new FetchRequest(address, curBlocks)
- curRequestSize = 0
- curBlocks = new ArrayBuffer[(String, Long)]
- }
- }
- // Add in the final request
- if (!curBlocks.isEmpty) {
- remoteRequests += new FetchRequest(address, curBlocks)
- }
- }
- }
- // Add the remote requests into our queue in a random order
- fetchRequests ++= Utils.randomize(remoteRequests)
-
- // Send out initial requests for blocks, up to our maxBytesInFlight
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
- }
-
- val numGets = remoteBlockIds.size - fetchRequests.size
- logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
-
- // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
- // these all at once because they will just memory-map some files, so they won't consume
- // any memory that might exceed our maxBytesInFlight
- startTime = System.currentTimeMillis
- for (id <- localBlockIds) {
- getLocal(id) match {
- case Some(iter) => {
- results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
- logDebug("Got local block " + id)
- }
- case None => {
- throw new BlockException(id, "Could not get block " + id + " from local machine")
- }
- }
- }
- logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
-
- // Return an iterator that will read fetched blocks off the queue as they arrive.
- return new Iterator[(String, Option[Iterator[Any]])] {
- var resultsGotten = 0
-
- def hasNext: Boolean = resultsGotten < totalBlocks
-
- def next(): (String, Option[Iterator[Any]]) = {
- resultsGotten += 1
- val result = results.take()
- bytesInFlight -= result.size
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
- }
- (result.blockId, if (result.failed) None else Some(result.deserialize()))
- }
- }
+ : BlockFetcherIterator = {
+ return new BlockFetcherIterator(this, blocksByAddress)
}
def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
@@ -986,3 +842,176 @@ object BlockManager extends Logging {
}
}
}
+
+class BlockFetcherIterator(
+ private val blockManager: BlockManager,
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]
+) extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker {
+
+ import blockManager._
+
+ private var _remoteBytesRead = 0l
+ private var _remoteFetchTime = 0l
+ private var _remoteFetchWaitTime = 0l
+
+ if (blocksByAddress == null) {
+ throw new IllegalArgumentException("BlocksByAddress is null")
+ }
+ val totalBlocks = blocksByAddress.map(_._2.size).sum
+ logDebug("Getting " + totalBlocks + " blocks")
+ var startTime = System.currentTimeMillis
+ val localBlockIds = new ArrayBuffer[String]()
+ val remoteBlockIds = new HashSet[String]()
+
+ // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
+ // the block (since we want all deserializaton to happen in the calling thread); can also
+ // represent a fetch failure if size == -1.
+ class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+ def failed: Boolean = size == -1
+ }
+
+ // A queue to hold our results.
+ val results = new LinkedBlockingQueue[FetchResult]
+
+ // A request to fetch one or more blocks, complete with their sizes
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+ val size = blocks.map(_._2).sum
+ }
+
+ // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ // the number of bytes in flight is limited to maxBytesInFlight
+ val fetchRequests = new Queue[FetchRequest]
+
+ // Current bytes in flight from our requests
+ var bytesInFlight = 0L
+
+ def sendRequest(req: FetchRequest) {
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
+ val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
+ val blockMessageArray = new BlockMessageArray(req.blocks.map {
+ case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
+ })
+ bytesInFlight += req.size
+ val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
+ val fetchStart = System.currentTimeMillis()
+ val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
+ future.onSuccess {
+ case Some(message) => {
+ val fetchDone = System.currentTimeMillis()
+ _remoteFetchTime += fetchDone - fetchStart
+ val bufferMessage = message.asInstanceOf[BufferMessage]
+ val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+ for (blockMessage <- blockMessageArray) {
+ if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
+ throw new SparkException(
+ "Unexpected message " + blockMessage.getType + " received from " + cmId)
+ }
+ val blockId = blockMessage.getId
+ results.put(new FetchResult(
+ blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
+ _remoteBytesRead += req.size
+ logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ }
+ }
+ case None => {
+ logError("Could not get block(s) from " + cmId)
+ for ((blockId, size) <- req.blocks) {
+ results.put(new FetchResult(blockId, -1, null))
+ }
+ }
+ }
+ }
+
+ // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // at most maxBytesInFlight in order to limit the amount of data in flight.
+ val remoteRequests = new ArrayBuffer[FetchRequest]
+ for ((address, blockInfos) <- blocksByAddress) {
+ if (address == blockManagerId) {
+ localBlockIds ++= blockInfos.map(_._1)
+ } else {
+ remoteBlockIds ++= blockInfos.map(_._1)
+ // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+ logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[(String, Long)]
+ while (iterator.hasNext) {
+ val (blockId, size) = iterator.next()
+ curBlocks += ((blockId, size))
+ curRequestSize += size
+ if (curRequestSize >= minRequestSize) {
+ // Add this FetchRequest
+ remoteRequests += new FetchRequest(address, curBlocks)
+ curRequestSize = 0
+ curBlocks = new ArrayBuffer[(String, Long)]
+ }
+ }
+ // Add in the final request
+ if (!curBlocks.isEmpty) {
+ remoteRequests += new FetchRequest(address, curBlocks)
+ }
+ }
+ }
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(remoteRequests)
+
+ // Send out initial requests for blocks, up to our maxBytesInFlight
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+
+ val numGets = remoteBlockIds.size - fetchRequests.size
+ logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
+
+ // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
+ // these all at once because they will just memory-map some files, so they won't consume
+ // any memory that might exceed our maxBytesInFlight
+ startTime = System.currentTimeMillis
+ for (id <- localBlockIds) {
+ getLocal(id) match {
+ case Some(iter) => {
+ results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
+ logDebug("Got local block " + id)
+ }
+ case None => {
+ throw new BlockException(id, "Could not get block " + id + " from local machine")
+ }
+ }
+ }
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+
+ //an iterator that will read fetched blocks off the queue as they arrive.
+ var resultsGotten = 0
+
+ def hasNext: Boolean = resultsGotten < totalBlocks
+
+ def next(): (String, Option[Iterator[Any]]) = {
+ resultsGotten += 1
+ val startFetchWait = System.currentTimeMillis()
+ val result = results.take()
+ val stopFetchWait = System.currentTimeMillis()
+ _remoteFetchWaitTime += (stopFetchWait - startFetchWait)
+ bytesInFlight -= result.size
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
+ }
+
+
+ //methods to profile the block fetching
+ def numLocalBlocks = localBlockIds.size
+ def numRemoteBlocks = remoteBlockIds.size
+
+ def remoteFetchTime = _remoteFetchTime
+ def remoteFetchWaitTime = _remoteFetchWaitTime
+
+ def remoteBytesRead = _remoteBytesRead
+
+}
diff --git a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala b/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
new file mode 100644
index 0000000000..5c491877ba
--- /dev/null
+++ b/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
@@ -0,0 +1,12 @@
+package spark.storage
+
+private[spark] trait DelegateBlockFetchTracker extends BlockFetchTracker {
+ var delegate : BlockFetchTracker = _
+ def setDelegate(d: BlockFetchTracker) {delegate = d}
+ def totalBlocks = delegate.totalBlocks
+ def numLocalBlocks = delegate.numLocalBlocks
+ def numRemoteBlocks = delegate.numRemoteBlocks
+ def remoteFetchTime = delegate.remoteFetchTime
+ def remoteFetchWaitTime = delegate.remoteFetchWaitTime
+ def remoteBytesRead = delegate.remoteBytesRead
+}
diff --git a/core/src/main/scala/spark/util/CompletionIterator.scala b/core/src/main/scala/spark/util/CompletionIterator.scala
new file mode 100644
index 0000000000..8139183780
--- /dev/null
+++ b/core/src/main/scala/spark/util/CompletionIterator.scala
@@ -0,0 +1,25 @@
+package spark.util
+
+/**
+ * Wrapper around an iterator which calls a completion method after it successfully iterates through all the elements
+ */
+abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterator[A]{
+ def next = sub.next
+ def hasNext = {
+ val r = sub.hasNext
+ if (!r) {
+ completion
+ }
+ r
+ }
+
+ def completion()
+}
+
+object CompletionIterator {
+ def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] = {
+ new CompletionIterator[A,I](sub) {
+ def completion() = completionFunction
+ }
+ }
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/util/Distribution.scala b/core/src/main/scala/spark/util/Distribution.scala
new file mode 100644
index 0000000000..24738b4307
--- /dev/null
+++ b/core/src/main/scala/spark/util/Distribution.scala
@@ -0,0 +1,65 @@
+package spark.util
+
+import java.io.PrintStream
+
+/**
+ * Util for getting some stats from a small sample of numeric values, with some handy summary functions.
+ *
+ * Entirely in memory, not intended as a good way to compute stats over large data sets.
+ *
+ * Assumes you are giving it a non-empty set of data
+ */
+class Distribution(val data: Array[Double], val startIdx: Int, val endIdx: Int) {
+ require(startIdx < endIdx)
+ def this(data: Traversable[Double]) = this(data.toArray, 0, data.size)
+ java.util.Arrays.sort(data, startIdx, endIdx)
+ val length = endIdx - startIdx
+
+ val defaultProbabilities = Array(0,0.25,0.5,0.75,1.0)
+
+ /**
+ * Get the value of the distribution at the given probabilities. Probabilities should be
+ * given from 0 to 1
+ * @param probabilities
+ */
+ def getQuantiles(probabilities: Traversable[Double] = defaultProbabilities) = {
+ probabilities.toIndexedSeq.map{p:Double => data(closestIndex(p))}
+ }
+
+ private def closestIndex(p: Double) = {
+ math.min((p * length).toInt + startIdx, endIdx - 1)
+ }
+
+ def showQuantiles(out: PrintStream = System.out) = {
+ out.println("min\t25%\t50%\t75%\tmax")
+ getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")}
+ out.println
+ }
+
+ def statCounter = StatCounter(data.slice(startIdx, endIdx))
+
+ /**
+ * print a summary of this distribution to the given PrintStream.
+ * @param out
+ */
+ def summary(out: PrintStream = System.out) {
+ out.println(statCounter)
+ showQuantiles(out)
+ }
+}
+
+object Distribution {
+
+ def apply(data: Traversable[Double]): Option[Distribution] = {
+ if (data.size > 0)
+ Some(new Distribution(data))
+ else
+ None
+ }
+
+ def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]) {
+ out.println("min\t25%\t50%\t75%\tmax")
+ quantiles.foreach{q => out.print(q + "\t")}
+ out.println
+ }
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/util/TimedIterator.scala b/core/src/main/scala/spark/util/TimedIterator.scala
new file mode 100644
index 0000000000..539b01f4ce
--- /dev/null
+++ b/core/src/main/scala/spark/util/TimedIterator.scala
@@ -0,0 +1,32 @@
+package spark.util
+
+/**
+ * A utility for tracking the total time an iterator takes to iterate through its elements.
+ *
+ * In general, this should only be used if you expect it to take a considerable amount of time
+ * (eg. milliseconds) to get each element -- otherwise, the timing won't be very accurate,
+ * and you are probably just adding more overhead
+ */
+class TimedIterator[+A](val sub: Iterator[A]) extends Iterator[A] {
+ private var netMillis = 0l
+ private var nElems = 0
+ def hasNext = {
+ val start = System.currentTimeMillis()
+ val r = sub.hasNext
+ val end = System.currentTimeMillis()
+ netMillis += (end - start)
+ r
+ }
+ def next = {
+ val start = System.currentTimeMillis()
+ val r = sub.next
+ val end = System.currentTimeMillis()
+ netMillis += (end - start)
+ nElems += 1
+ r
+ }
+
+ def getNetMillis = netMillis
+ def getAverageTimePerItem = netMillis / nElems.toDouble
+
+}
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 9ffe7c5f99..26e3ab72c0 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -423,7 +423,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContext(0, 0, 0);
+ TaskContext context = new TaskContext(0, 0, 0, null);
Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
}
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
index 8de490eb86..b3e6ab4c0f 100644
--- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -265,7 +265,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
assert(taskSet.tasks.size >= results.size)
for ((result, i) <- results.zipWithIndex) {
if (i < taskSet.tasks.size) {
- runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()))
+ runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any](), null, null))
}
}
}
@@ -463,14 +463,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
val noAccum = Map[Long, Any]()
// We rely on the event queue being ordered and increasing the generation number by 1
// should be ignored for being too old
- runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
// should work because it's a non-failed host
- runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum))
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null))
// should be ignored for being too old
- runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
taskSet.tasks(1).generation = newGeneration
val secondStage = interceptStage(reduceRdd) {
- runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum))
+ runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null))
}
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
diff --git a/core/src/test/scala/spark/util/DistributionSuite.scala b/core/src/test/scala/spark/util/DistributionSuite.scala
new file mode 100644
index 0000000000..cc6249b1dd
--- /dev/null
+++ b/core/src/test/scala/spark/util/DistributionSuite.scala
@@ -0,0 +1,25 @@
+package spark.util
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+
+/**
+ *
+ */
+
+class DistributionSuite extends FunSuite with ShouldMatchers {
+ test("summary") {
+ val d = new Distribution((1 to 100).toArray.map{_.toDouble})
+ val stats = d.statCounter
+ stats.count should be (100)
+ stats.mean should be (50.5)
+ stats.sum should be (50 * 101)
+
+ val quantiles = d.getQuantiles()
+ quantiles(0) should be (1)
+ quantiles(1) should be (26)
+ quantiles(2) should be (51)
+ quantiles(3) should be (76)
+ quantiles(4) should be (100)
+ }
+}