Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala              |  8
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala   | 40
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala | 2
3 files changed, 34 insertions, 16 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5673fbf2c8..a1f0fd05f6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -947,7 +947,13 @@ class DAGScheduler(
// serializable. If tasks are not serializable, a SparkListenerStageCompleted event
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
- outputCommitCoordinator.stageStart(stage.id)
+ stage match {
+ case s: ShuffleMapStage =>
+ outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1)
+ case s: ResultStage =>
+ outputCommitCoordinator.stageStart(
+ stage = s.id, maxPartitionId = s.rdd.partitions.length - 1)
+ }
val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
stage match {
case s: ShuffleMapStage =>
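
For context, here is a minimal sketch of how the two stage types above determine the bound passed to stageStart: a shuffle map stage writes one output per shuffle partition, while a result stage runs one task per partition of its target RDD. The case classes below are simplified stand-ins for illustration, not the real Spark scheduler classes.

// Illustrative stand-ins only; the real ShuffleMapStage and ResultStage
// live in org.apache.spark.scheduler and carry far more state.
sealed trait StageSketch { def id: Int }
case class ShuffleMapStageSketch(id: Int, numPartitions: Int) extends StageSketch
case class ResultStageSketch(id: Int, rddPartitionCount: Int) extends StageSketch

def maxPartitionId(stage: StageSketch): Int = stage match {
  // One output per shuffle partition.
  case s: ShuffleMapStageSketch => s.numPartitions - 1
  // One task per partition of the stage's RDD.
  case s: ResultStageSketch => s.rddPartitionCount - 1
}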
diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
index add0dedc03..4d14667817 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
@@ -47,6 +47,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
private type PartitionId = Int
private type TaskAttemptNumber = Int
+ private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1
+
/**
* Map from an active stage's id => partition id => task attempt with exclusive lock on committing
* output for that partition.
@@ -56,9 +58,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
*
* Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
*/
- private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map()
- private type CommittersByStageMap =
- mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptNumber]]
+ private val authorizedCommittersByStage = mutable.Map[StageId, Array[TaskAttemptNumber]]()
/**
* Returns whether the OutputCommitCoordinator's internal data structures are all empty.
@@ -95,9 +95,21 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
}
}
- // Called by DAGScheduler
- private[scheduler] def stageStart(stage: StageId): Unit = synchronized {
- authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptNumber]()
+ /**
+ * Called by the DAGScheduler when a stage starts.
+ *
+ * @param stage the stage id.
+ * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e.
+ * the maximum possible value of `context.partitionId`).
+ */
+ private[scheduler] def stageStart(
+ stage: StageId,
+ maxPartitionId: Int): Unit = {
+ val arr = new Array[TaskAttemptNumber](maxPartitionId + 1)
+ java.util.Arrays.fill(arr, NO_AUTHORIZED_COMMITTER)
+ synchronized {
+ authorizedCommittersByStage(stage) = arr
+ }
}
// Called by DAGScheduler
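
The new stageStart replaces the per-stage hash map with a pre-sized array filled with the NO_AUTHORIZED_COMMITTER sentinel, and it fills the array before taking the coordinator's lock, so the synchronized block covers only the map insertion. A standalone sketch of that table follows; the names are illustrative, and only java.util.Arrays.fill and the -1 sentinel come from the patch.

object SentinelTableSketch {
  final val NoAuthorizedCommitter: Int = -1 // mirrors NO_AUTHORIZED_COMMITTER

  // One slot per possible partition id; -1 means no attempt holds the lock.
  def newCommitterTable(maxPartitionId: Int): Array[Int] = {
    val arr = new Array[Int](maxPartitionId + 1)
    java.util.Arrays.fill(arr, NoAuthorizedCommitter)
    arr
  }

  def main(args: Array[String]): Unit = {
    val table = newCommitterTable(maxPartitionId = 3)
    assert(table.forall(_ == NoAuthorizedCommitter))
    table(2) = 7 // attempt 7 now holds the commit lock on partition 2
    println(table.mkString(", ")) // -1, -1, 7, -1
  }
}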
@@ -122,10 +134,10 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " +
s"attempt: $attemptNumber")
case otherReason =>
- if (authorizedCommitters.get(partition).exists(_ == attemptNumber)) {
+ if (authorizedCommitters(partition) == attemptNumber) {
logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " +
s"partition=$partition) failed; clearing lock")
- authorizedCommitters.remove(partition)
+ authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER
}
}
}
@@ -145,16 +157,16 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
attemptNumber: TaskAttemptNumber): Boolean = synchronized {
authorizedCommittersByStage.get(stage) match {
case Some(authorizedCommitters) =>
- authorizedCommitters.get(partition) match {
- case Some(existingCommitter) =>
- logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " +
- s"partition=$partition; existingCommitter = $existingCommitter")
- false
- case None =>
+ authorizedCommitters(partition) match {
+ case NO_AUTHORIZED_COMMITTER =>
logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " +
s"partition=$partition")
authorizedCommitters(partition) = attemptNumber
true
+ case existingCommitter =>
+ logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " +
+ s"partition=$partition; existingCommitter = $existingCommitter")
+ false
}
case None =>
logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" +
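
The rewritten canCommit makes the protocol first-wins: a sentinel slot authorizes the asking attempt and records it, while any other value denies. Below is a self-contained model of the authorize/deny/release logic under the assumption of one arbiter per stage; it is a sketch, not the private OutputCommitCoordinator API.

class CommitArbiterSketch(maxPartitionId: Int) {
  private val NoCommitter: Int = -1
  private val committers: Array[Int] = Array.fill(maxPartitionId + 1)(NoCommitter)

  def canCommit(partition: Int, attemptNumber: Int): Boolean = synchronized {
    committers(partition) match {
      case NoCommitter =>
        // No committer yet: the first attempt to ask wins the lock.
        committers(partition) = attemptNumber
        true
      case _ =>
        // Another attempt already holds the lock: deny.
        false
    }
  }

  def attemptFailed(partition: Int, attemptNumber: Int): Unit = synchronized {
    // Only the authorized committer's failure releases the lock,
    // mirroring the taskCompleted branch earlier in this file.
    if (committers(partition) == attemptNumber) {
      committers(partition) = NoCommitter
    }
  }
}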
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
index 48456a9cd6..7345508bfe 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
@@ -171,7 +171,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
val partition: Int = 2
val authorizedCommitter: Int = 3
val nonAuthorizedCommitter: Int = 100
- outputCommitCoordinator.stageStart(stage)
+ outputCommitCoordinator.stageStart(stage, maxPartitionId = 2)
assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter))
assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter))
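
Using the CommitArbiterSketch model above, the suite's scenario, extended with the lock-release path from taskCompleted, would play out as:

val arbiter = new CommitArbiterSketch(maxPartitionId = 2)
assert(arbiter.canCommit(partition = 2, attemptNumber = 3))    // authorized committer wins
assert(!arbiter.canCommit(partition = 2, attemptNumber = 100)) // second asker is denied
arbiter.attemptFailed(partition = 2, attemptNumber = 3)        // authorized committer fails
assert(arbiter.canCommit(partition = 2, attemptNumber = 100))  // lock released; retry wins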