 core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala    | 28
 core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala     |  5
 core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala | 21
 core/src/main/scala/org/apache/spark/scheduler/Stage.scala           |  5
 4 files changed, 39 insertions(+), 20 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index ade372be09..995862ece5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -353,10 +353,12 @@ class DAGScheduler(
if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
- for (i <- 0 until locs.length) {
- stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing
+ (0 until locs.length).foreach { i =>
+ if (locs(i) ne null) {
+ // locs(i) will be null if missing
+ stage.addOutputLoc(i, locs(i))
+ }
}
- stage.numAvailableOutputs = locs.count(_ != null)
} else {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of partitions is unknown
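The replaced loop folded null-handling into `Option(locs(i)).toList`, which yields `Nil` for a missing status and a one-element list otherwise; the new version makes the null check explicit so each present status can flow through `addOutputLoc` and keep `numAvailableOutputs` in sync. A small illustrative check of that equivalence (plain Scala, not Spark code):

// Option(x).toList is Nil when x is null, List(x) otherwise.
val present: String = "map-status"
val missing: String = null
assert(Option(present).toList == List(present))
assert(Option(missing).toList == Nil)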
@@ -894,7 +896,7 @@ class DAGScheduler(
submitStage(finalStage)
// If the whole stage has already finished, tell the listener and remove it
- if (!finalStage.outputLocs.contains(Nil)) {
+ if (finalStage.isAvailable) {
markMapStageJobAsFinished(job, mapOutputTracker.getStatistics(dependency))
}
@@ -931,24 +933,12 @@ class DAGScheduler(
stage.pendingPartitions.clear()
// First figure out the indexes of partition ids to compute.
- val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = {
- stage match {
- case stage: ShuffleMapStage =>
- val allPartitions = 0 until stage.numPartitions
- val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty }
- (allPartitions, filteredPartitions)
- case stage: ResultStage =>
- val job = stage.resultOfJob.get
- val allPartitions = 0 until job.numPartitions
- val filteredPartitions = allPartitions.filter { id => !job.finished(id) }
- (allPartitions, filteredPartitions)
- }
- }
+ val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
// Create internal accumulators if the stage has no accumulators initialized.
// Reset internal accumulators only if this stage is not partially submitted
// Otherwise, we may override existing accumulator values from some tasks
- if (stage.internalAccumulators.isEmpty || allPartitions == partitionsToCompute) {
+ if (stage.internalAccumulators.isEmpty || stage.numPartitions == partitionsToCompute.size) {
stage.resetInternalAccumulators()
}
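Comparing sizes is sufficient in the rewritten condition because `findMissingPartitions()` filters an increasing range, so it returns a duplicate-free, ordered subset of `0 until numPartitions` that equals the full range exactly when the counts match. An illustrative check under that assumption:

// Size equality implies range equality for an increasing, duplicate-free
// subset of 0 until numPartitions (assumption stated above).
val numPartitions = 4
val partitionsToCompute: Seq[Int] = Seq(0, 1, 2, 3) // every partition missing
assert((partitionsToCompute.size == numPartitions) ==
  (partitionsToCompute == (0 until numPartitions)))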
@@ -1202,7 +1192,7 @@ class DAGScheduler(
clearCacheLocs()
- if (shuffleStage.outputLocs.contains(Nil)) {
+ if (!shuffleStage.isAvailable) {
// Some tasks had failed; let's resubmit this shuffleStage
// TODO: Lower-level scheduler should also deal with this
logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
index c0451da1f0..c1d86af7e8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
@@ -43,5 +43,10 @@ private[spark] class ResultStage(
*/
var resultOfJob: Option[ActiveJob] = None
+ override def findMissingPartitions(): Seq[Int] = {
+ val job = resultOfJob.get
+ (0 until job.numPartitions).filter(id => !job.finished(id))
+ }
+
override def toString: String = "ResultStage " + id
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
index 7d92960876..3832d99edd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
@@ -48,12 +48,33 @@ private[spark] class ShuffleMapStage(
/** Running map-stage jobs that were submitted to execute this stage independently (if any) */
var mapStageJobs: List[ActiveJob] = Nil
+ /**
+ * Number of partitions that have shuffle outputs.
+ * When this reaches [[numPartitions]], this map stage is ready.
+ * This should be kept consistent with `outputLocs.count(_.nonEmpty)`.
+ */
var numAvailableOutputs: Int = 0
+ /**
+ * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs.
+ * This should be the same as `!outputLocs.contains(Nil)`.
+ */
def isAvailable: Boolean = numAvailableOutputs == numPartitions
+ /**
+ * List of [[MapStatus]] for each partition. The index of the array is the map partition id,
+ * and each value in the array is the list of possible [[MapStatus]] for a partition
+ * (a single task might run multiple times).
+ */
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
+ override def findMissingPartitions(): Seq[Int] = {
+ val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
+ assert(missing.size == numPartitions - numAvailableOutputs,
+ s"${missing.size} missing, expected ${numPartitions - numAvailableOutputs}")
+ missing
+ }
+
def addOutputLoc(partition: Int, status: MapStatus): Unit = {
val prevList = outputLocs(partition)
outputLocs(partition) = status :: prevList
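The new doc comments pin down the invariant that `numAvailableOutputs` mirrors the number of non-empty entries in `outputLocs`, which is what lets `findMissingPartitions()` assert its result size. A hedged sketch of the bookkeeping that maintains this invariant (the increment-on-first-output step is inferred from the comments, not shown in this hunk):

// Stand-in tracker illustrating the documented invariant:
// numAvailableOutputs == outputLocs.count(_.nonEmpty) at all times.
class MiniOutputTracker(val numPartitions: Int) {
  val outputLocs = Array.fill[List[String]](numPartitions)(Nil)
  var numAvailableOutputs = 0

  def addOutputLoc(partition: Int, status: String): Unit = {
    val prevList = outputLocs(partition)
    outputLocs(partition) = status :: prevList
    // A partition becomes available only on its first registered output.
    if (prevList == Nil) numAvailableOutputs += 1
  }

  def isAvailable: Boolean = numAvailableOutputs == numPartitions

  def findMissingPartitions(): Seq[Int] = {
    val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
    assert(missing.size == numPartitions - numAvailableOutputs)
    missing
  }
}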
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index a3829c319c..5ce4a48434 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -61,7 +61,7 @@ private[scheduler] abstract class Stage(
val callSite: CallSite)
extends Logging {
- val numPartitions = rdd.partitions.size
+ val numPartitions = rdd.partitions.length
/** Set of jobs that this stage belongs to. */
val jobIds = new HashSet[Int]
@@ -138,6 +138,9 @@ private[scheduler] abstract class Stage(
case stage: Stage => stage != null && stage.id == id
case _ => false
}
+
+ /** Returns the sequence of partition ids that are missing (i.e. need to be computed). */
+ def findMissingPartitions(): Seq[Int]
}
private[scheduler] object Stage {