author    Reynold Xin <rxin@databricks.com>    2015-10-21 15:33:13 -0700
committer Kay Ousterhout <kayousterhout@gmail.com>    2015-10-21 15:33:13 -0700
commit    555b2086a1ee432067de77032f1e3c64735481f0 (patch)
tree      32a568413dae918a88145d1b847489f850af3978 /core
parent    f481090a71940f06602a73f5bbd004980dea026f (diff)
Minor cleanup of ShuffleMapStage.outputLocs code.
I was looking at this code and found the documentation to be insufficient. I added more documentation and refactored some relevant code paths slightly to improve encapsulation. There is more that I want to do, but I want to get these changes in before doing more work.

My goal is to reduce the direct exposure of internal fields in ShuffleMapStage in order to improve encapsulation. After this change, DAGScheduler no longer writes outputLocs directly. There are still 3 places that read outputLocs directly, but we can change those later.

Author: Reynold Xin <rxin@databricks.com>

Closes #9175 from rxin/stage-cleanup.
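As a rough illustration of the encapsulation described above, the following is a simplified, standalone sketch of the new ShuffleMapStage surface (addOutputLoc, isAvailable, findMissingPartitions) that the diff below introduces. MiniShuffleMapStage and the stubbed MapStatus are hypothetical stand-ins for illustration only, not the real Spark classes:

    object StageEncapsulationSketch {
      // Stand-in for org.apache.spark.scheduler.MapStatus; used only as a marker here.
      case class MapStatus(location: String)

      // Hypothetical, trimmed-down model of ShuffleMapStage: callers register map
      // outputs through addOutputLoc and query readiness through isAvailable /
      // findMissingPartitions instead of mutating outputLocs directly.
      class MiniShuffleMapStage(val numPartitions: Int) {
        private val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
        private var numAvailableOutputs: Int = 0

        def addOutputLoc(partition: Int, status: MapStatus): Unit = {
          val prevList = outputLocs(partition)
          outputLocs(partition) = status :: prevList
          if (prevList == Nil) {
            numAvailableOutputs += 1  // first output registered for this partition
          }
        }

        def isAvailable: Boolean = numAvailableOutputs == numPartitions

        def findMissingPartitions(): Seq[Int] =
          (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
      }

      def main(args: Array[String]): Unit = {
        val stage = new MiniShuffleMapStage(numPartitions = 3)
        stage.addOutputLoc(0, MapStatus("host-a"))
        stage.addOutputLoc(2, MapStatus("host-b"))
        println(stage.findMissingPartitions())  // partition 1 is still missing
        println(stage.isAvailable)              // false until all 3 partitions report
      }
    }

The DAGScheduler hunks below follow the same pattern: they call addOutputLoc when replaying deserialized map output statuses and check isAvailable instead of scanning outputLocs for Nil.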
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala     28
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala       5
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala  21
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Stage.scala             5
4 files changed, 39 insertions, 20 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index ade372be09..995862ece5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -353,10 +353,12 @@ class DAGScheduler(
if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
- for (i <- 0 until locs.length) {
- stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing
+ (0 until locs.length).foreach { i =>
+ if (locs(i) ne null) {
+ // locs(i) will be null if missing
+ stage.addOutputLoc(i, locs(i))
+ }
}
- stage.numAvailableOutputs = locs.count(_ != null)
} else {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of partitions is unknown
@@ -894,7 +896,7 @@ class DAGScheduler(
submitStage(finalStage)
// If the whole stage has already finished, tell the listener and remove it
- if (!finalStage.outputLocs.contains(Nil)) {
+ if (finalStage.isAvailable) {
markMapStageJobAsFinished(job, mapOutputTracker.getStatistics(dependency))
}
@@ -931,24 +933,12 @@ class DAGScheduler(
stage.pendingPartitions.clear()
// First figure out the indexes of partition ids to compute.
- val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = {
- stage match {
- case stage: ShuffleMapStage =>
- val allPartitions = 0 until stage.numPartitions
- val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty }
- (allPartitions, filteredPartitions)
- case stage: ResultStage =>
- val job = stage.resultOfJob.get
- val allPartitions = 0 until job.numPartitions
- val filteredPartitions = allPartitions.filter { id => !job.finished(id) }
- (allPartitions, filteredPartitions)
- }
- }
+ val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
// Create internal accumulators if the stage has no accumulators initialized.
// Reset internal accumulators only if this stage is not partially submitted
// Otherwise, we may override existing accumulator values from some tasks
- if (stage.internalAccumulators.isEmpty || allPartitions == partitionsToCompute) {
+ if (stage.internalAccumulators.isEmpty || stage.numPartitions == partitionsToCompute.size) {
stage.resetInternalAccumulators()
}
@@ -1202,7 +1192,7 @@ class DAGScheduler(
clearCacheLocs()
- if (shuffleStage.outputLocs.contains(Nil)) {
+ if (!shuffleStage.isAvailable) {
// Some tasks had failed; let's resubmit this shuffleStage
// TODO: Lower-level scheduler should also deal with this
logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
index c0451da1f0..c1d86af7e8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
@@ -43,5 +43,10 @@ private[spark] class ResultStage(
*/
var resultOfJob: Option[ActiveJob] = None
+ override def findMissingPartitions(): Seq[Int] = {
+ val job = resultOfJob.get
+ (0 until job.numPartitions).filter(id => !job.finished(id))
+ }
+
override def toString: String = "ResultStage " + id
}
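The ResultStage override above derives the missing partitions from the active job's per-partition finished flags. A minimal, hypothetical sketch of that bookkeeping (MiniActiveJob is a stand-in, not the real ActiveJob class):

    object ResultStageSketch {
      // Hypothetical stand-in for ActiveJob: only the fields findMissingPartitions reads.
      class MiniActiveJob(val numPartitions: Int) {
        val finished: Array[Boolean] = Array.fill(numPartitions)(false)
      }

      // Mirrors the override above: a partition is missing until its result is marked finished.
      def findMissingResultPartitions(job: MiniActiveJob): Seq[Int] =
        (0 until job.numPartitions).filter(id => !job.finished(id))

      def main(args: Array[String]): Unit = {
        val job = new MiniActiveJob(numPartitions = 3)
        job.finished(1) = true
        println(findMissingResultPartitions(job))  // partitions 0 and 2 are still missing
      }
    }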
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
index 7d92960876..3832d99edd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
@@ -48,12 +48,33 @@ private[spark] class ShuffleMapStage(
/** Running map-stage jobs that were submitted to execute this stage independently (if any) */
var mapStageJobs: List[ActiveJob] = Nil
+ /**
+ * Number of partitions that have shuffle outputs.
+ * When this reaches [[numPartitions]], this map stage is ready.
+ * This should be kept consistent with `outputLocs.count(_.nonEmpty)`.
+ */
var numAvailableOutputs: Int = 0
+ /**
+ * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs.
+ * This should be the same as `!outputLocs.contains(Nil)`.
+ */
def isAvailable: Boolean = numAvailableOutputs == numPartitions
+ /**
+ * List of [[MapStatus]] for each partition. The index of the array is the map partition id,
+ * and each value in the array is the list of possible [[MapStatus]] for a partition
+ * (a single task might run multiple times).
+ */
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
+ override def findMissingPartitions(): Seq[Int] = {
+ val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
+ assert(missing.size == numPartitions - numAvailableOutputs,
+ s"${missing.size} missing, expected ${numPartitions - numAvailableOutputs}")
+ missing
+ }
+
def addOutputLoc(partition: Int, status: MapStatus): Unit = {
val prevList = outputLocs(partition)
outputLocs(partition) = status :: prevList
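The comments above note that a single map task may run multiple times, which is why outputLocs holds a list of MapStatus per partition while numAvailableOutputs counts each partition at most once. Continuing the hypothetical MiniShuffleMapStage sketch from earlier (again, not the real Spark classes), the duplicate-output case looks like this:

    object DuplicateOutputSketch {
      import StageEncapsulationSketch.{MapStatus, MiniShuffleMapStage}

      def main(args: Array[String]): Unit = {
        val stage = new MiniShuffleMapStage(numPartitions = 2)
        stage.addOutputLoc(0, MapStatus("host-a"))
        stage.addOutputLoc(0, MapStatus("host-b"))  // the same map task ran twice
        // Partition 0 now holds two MapStatuses in its list, but it still counts only
        // once toward numAvailableOutputs, so the stage is not yet available.
        println(stage.findMissingPartitions())  // partition 1 is still missing
        println(stage.isAvailable)              // false
      }
    }

This one-output-per-partition counting is also the invariant that the assert in the new findMissingPartitions checks: missing.size must equal numPartitions - numAvailableOutputs.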
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index a3829c319c..5ce4a48434 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -61,7 +61,7 @@ private[scheduler] abstract class Stage(
val callSite: CallSite)
extends Logging {
- val numPartitions = rdd.partitions.size
+ val numPartitions = rdd.partitions.length
/** Set of jobs that this stage belongs to. */
val jobIds = new HashSet[Int]
@@ -138,6 +138,9 @@ private[scheduler] abstract class Stage(
case stage: Stage => stage != null && stage.id == id
case _ => false
}
+
+ /** Returns the sequence of partition ids that are missing (i.e. need to be computed). */
+ def findMissingPartitions(): Seq[Int]
}
private[scheduler] object Stage {