Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala                           |  11
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala                               |  22
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala                      |  43
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskEndReason.scala                          |  14
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala         |  35
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala                      |   5
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala                 |  15
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala      | 172
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala            |   9
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala               |   8
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala            |  25
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala | 213
12 files changed, 549 insertions, 23 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 53fce6b0de..24a316e40e 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -249,7 +249,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER)
// Create the Spark execution environment (cache, map output tracker, etc)
- private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
+
+ // This function allows components created by SparkEnv to be mocked in unit tests:
+ private[spark] def createSparkEnv(
+ conf: SparkConf,
+ isLocal: Boolean,
+ listenerBus: LiveListenerBus): SparkEnv = {
+ SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
+ }
+
+ private[spark] val env = createSparkEnv(conf, isLocal, listenerBus)
SparkEnv.set(env)
// Used to store a URL for each static file/jar together with the file's local timestamp
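
The createSparkEnv hook above exists purely so that tests can build the driver SparkEnv around a stubbed OutputCommitCoordinator. A minimal sketch of how a test might use it, mirroring the OutputCommitCoordinatorSuite added further down; the object name is hypothetical, Mockito's spy is used, and the file must live under the org.apache.spark package because the hook is private[spark]:

    package org.apache.spark                      // required: createSparkEnv is private[spark]

    import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
    import org.mockito.Mockito.spy

    object MockedCoordinatorSketch {              // hypothetical object name
      var coordinator: OutputCommitCoordinator = null

      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[2]").setAppName("mocked-coordinator-sketch")
        val sc = new SparkContext(conf) {
          // Build the driver env around a spied coordinator so that its
          // handleAskPermissionToCommit responses can later be stubbed.
          override private[spark] def createSparkEnv(
              conf: SparkConf,
              isLocal: Boolean,
              listenerBus: LiveListenerBus): SparkEnv = {
            coordinator = spy(new OutputCommitCoordinator(conf))
            SparkEnv.createDriverEnv(conf, isLocal, listenerBus, Some(coordinator))
          }
        }
        sc.stop()
      }
    }
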
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index b63bea5b10..2a0c7e756d 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -34,7 +34,8 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.network.nio.NioBlockTransferService
-import org.apache.spark.scheduler.LiveListenerBus
+import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus}
+import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorActor
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
@@ -67,6 +68,7 @@ class SparkEnv (
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
val shuffleMemoryManager: ShuffleMemoryManager,
+ val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {
private[spark] var isStopped = false
@@ -88,6 +90,7 @@ class SparkEnv (
blockManager.stop()
blockManager.master.stop()
metricsSystem.stop()
+ outputCommitCoordinator.stop()
actorSystem.shutdown()
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
// down, but let's call it anyway in case it gets fixed in a later release
@@ -169,7 +172,8 @@ object SparkEnv extends Logging {
private[spark] def createDriverEnv(
conf: SparkConf,
isLocal: Boolean,
- listenerBus: LiveListenerBus): SparkEnv = {
+ listenerBus: LiveListenerBus,
+ mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!")
assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
val hostname = conf.get("spark.driver.host")
@@ -181,7 +185,8 @@ object SparkEnv extends Logging {
port,
isDriver = true,
isLocal = isLocal,
- listenerBus = listenerBus
+ listenerBus = listenerBus,
+ mockOutputCommitCoordinator = mockOutputCommitCoordinator
)
}
@@ -220,7 +225,8 @@ object SparkEnv extends Logging {
isDriver: Boolean,
isLocal: Boolean,
listenerBus: LiveListenerBus = null,
- numUsableCores: Int = 0): SparkEnv = {
+ numUsableCores: Int = 0,
+ mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
// Listener bus is only used on the driver
if (isDriver) {
@@ -368,6 +374,13 @@ object SparkEnv extends Logging {
"levels using the RDD.persist() method instead.")
}
+ val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse {
+ new OutputCommitCoordinator(conf)
+ }
+ val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator",
+ new OutputCommitCoordinatorActor(outputCommitCoordinator))
+ outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor)
+
val envInstance = new SparkEnv(
executorId,
actorSystem,
@@ -384,6 +397,7 @@ object SparkEnv extends Logging {
sparkFilesDir,
metricsSystem,
shuffleMemoryManager,
+ outputCommitCoordinator,
conf)
// Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 4023759657..6eb4537d10 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.mapred._
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
+import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.rdd.HadoopRDD
@@ -105,24 +106,56 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
def commit() {
val taCtxt = getTaskContext()
val cmtr = getOutputCommitter()
- if (cmtr.needsTaskCommit(taCtxt)) {
+
+ // Called after we have decided to commit
+ def performCommit(): Unit = {
try {
cmtr.commitTask(taCtxt)
- logInfo (taID + ": Committed")
+ logInfo (s"$taID: Committed")
} catch {
- case e: IOException => {
+ case e: IOException =>
logError("Error committing the output of task: " + taID.value, e)
cmtr.abortTask(taCtxt)
throw e
+ }
+ }
+
+ // First, check whether the task's output has already been committed by some other attempt
+ if (cmtr.needsTaskCommit(taCtxt)) {
+ // The task output needs to be committed, but we don't know whether some other task attempt
+ // might be racing to commit the same output partition. Therefore, coordinate with the driver
+ // in order to determine whether this attempt can commit (see SPARK-4879).
+ val shouldCoordinateWithDriver: Boolean = {
+ val sparkConf = SparkEnv.get.conf
+ // We only need to coordinate with the driver if there are multiple concurrent task
+ // attempts, which should only occur if speculation is enabled
+ val speculationEnabled = sparkConf.getBoolean("spark.speculation", false)
+ // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs
+ sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled)
+ }
+ if (shouldCoordinateWithDriver) {
+ val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
+ val canCommit = outputCommitCoordinator.canCommit(jobID, splitID, attemptID)
+ if (canCommit) {
+ performCommit()
+ } else {
+ val msg = s"$taID: Not committed because the driver did not authorize commit"
+ logInfo(msg)
+ // We need to abort the task so that the driver can reschedule new attempts, if necessary
+ cmtr.abortTask(taCtxt)
+ throw new CommitDeniedException(msg, jobID, splitID, attemptID)
}
+ } else {
+ // Speculation is disabled or a user has chosen to manually bypass the commit coordination
+ performCommit()
}
} else {
- logInfo ("No need to commit output of task: " + taID.value)
+ // Some other attempt committed the output, so we do nothing and signal success
+ logInfo(s"No need to commit output of task because needsTaskCommit=false: ${taID.value}")
}
}
def commitJob() {
- // always ? Or if cmtr.needsTaskCommit ?
val cmtr = getOutputCommitter()
cmtr.commitJob(getJobContext())
}
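
The new commit() path only asks the driver for permission when concurrent attempts can race for the same output partition, which by default happens only with speculation enabled; the undocumented spark.hadoop.outputCommitCoordination.enabled setting can force or bypass that coordination. A small usage sketch (object name, master, app name, and output path are placeholders):

    import org.apache.spark.{SparkConf, SparkContext}

    object CommitCoordinationSketch {             // hypothetical object name
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setMaster("local[4]")                  // placeholder master
          .setAppName("commit-coordination-sketch")
          .set("spark.speculation", "true")       // speculative attempts may race to commit
          // Defaults to the value of spark.speculation; "false" bypasses the coordinator entirely.
          .set("spark.hadoop.outputCommitCoordination.enabled", "true")

        val sc = new SparkContext(conf)
        // saveAsTextFile goes through SparkHadoopWriter.commit(), so each task attempt asks
        // SparkEnv.get.outputCommitCoordinator.canCommit(...) before committing its output.
        sc.parallelize(1 to 100).saveAsTextFile("/tmp/commit-coordination-sketch")   // placeholder path
        sc.stop()
      }
    }
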
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index af5fd8e0ac..29a5cd5fda 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -148,6 +148,20 @@ case object TaskKilled extends TaskFailedReason {
/**
* :: DeveloperApi ::
+ * Task requested the driver to commit, but was denied.
+ */
+@DeveloperApi
+case class TaskCommitDenied(
+ jobID: Int,
+ partitionID: Int,
+ attemptID: Int)
+ extends TaskFailedReason {
+ override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" +
+ s" for job: $jobID, partition: $partitionID, attempt: $attemptID"
+}
+
+/**
+ * :: DeveloperApi ::
* The task failed because the executor that it was running on was lost. This may happen because
* the task crashed the JVM.
*/
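
For reference, a tiny sketch of what the new failure reason renders as in logs and the UI; the object name and the ids are arbitrary:

    import org.apache.spark.TaskCommitDenied

    object TaskCommitDeniedSketch {               // hypothetical object name
      def main(args: Array[String]): Unit = {
        val reason = TaskCommitDenied(jobID = 2, partitionID = 7, attemptID = 1)
        // Prints: TaskCommitDenied (Driver denied task commit) for job: 2, partition: 7, attempt: 1
        println(reason.toErrorString)
      }
    }
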
diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
new file mode 100644
index 0000000000..f7604a321f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import org.apache.spark.{TaskCommitDenied, TaskEndReason}
+
+/**
+ * Exception thrown when a task attempts to commit output to HDFS but is denied by the driver.
+ */
+class CommitDeniedException(
+ msg: String,
+ jobID: Int,
+ splitID: Int,
+ attemptID: Int)
+ extends Exception(msg) {
+
+ def toTaskEndReason: TaskEndReason = new TaskCommitDenied(jobID, splitID, attemptID)
+
+}
+
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 6b22dcd6f5..b684fb7049 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -253,6 +253,11 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
}
+ case cDE: CommitDeniedException => {
+ val reason = cDE.toTaskEndReason
+ execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+ }
+
case t: Throwable => {
// Attempt to exit cleanly by informing the driver of our failure.
// If anything goes wrong (or this was a fatal exception), we will delegate to
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 1cfe986737..79035571ad 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -38,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
-import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils}
+import org.apache.spark.util._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
/**
@@ -63,7 +63,7 @@ class DAGScheduler(
mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv,
- clock: Clock = SystemClock)
+ clock: org.apache.spark.util.Clock = SystemClock)
extends Logging {
def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
@@ -126,6 +126,8 @@ class DAGScheduler(
private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
taskScheduler.setDAGScheduler(this)
+ private val outputCommitCoordinator = env.outputCommitCoordinator
+
// Called by TaskScheduler to report task's starting.
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventProcessLoop.post(BeginEvent(task, taskInfo))
@@ -808,6 +810,7 @@ class DAGScheduler(
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size))
+ outputCommitCoordinator.stageStart(stage.id)
listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
// TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
@@ -865,6 +868,7 @@ class DAGScheduler(
} else {
// Because we posted SparkListenerStageSubmitted earlier, we should post
// SparkListenerStageCompleted here in case there are no tasks to run.
+ outputCommitCoordinator.stageEnd(stage.id)
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -909,6 +913,9 @@ class DAGScheduler(
val stageId = task.stageId
val taskType = Utils.getFormattedClassName(task)
+ outputCommitCoordinator.taskCompleted(stageId, task.partitionId,
+ event.taskInfo.attempt, event.reason)
+
// The success case is dealt with separately below, since we need to compute accumulator
// updates before posting.
if (event.reason != Success) {
@@ -921,6 +928,7 @@ class DAGScheduler(
// Skip all the actions if the stage has been cancelled.
return
}
+
val stage = stageIdToStage(task.stageId)
def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = {
@@ -1073,6 +1081,9 @@ class DAGScheduler(
handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
}
+ case commitDenied: TaskCommitDenied =>
+ // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
+
case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
// Do nothing here, left up to the TaskScheduler to decide how to handle user failures
diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
new file mode 100644
index 0000000000..759df023a6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import scala.collection.mutable
+
+import akka.actor.{ActorRef, Actor}
+
+import org.apache.spark._
+import org.apache.spark.util.{AkkaUtils, ActorLogReceive}
+
+private sealed trait OutputCommitCoordinationMessage extends Serializable
+
+private case object StopCoordinator extends OutputCommitCoordinationMessage
+private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long)
+
+/**
+ * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins"
+ * policy.
+ *
+ * OutputCommitCoordinator is instantiated in both the drivers and executors. On executors, it is
+ * configured with a reference to the driver's OutputCommitCoordinatorActor, so requests to commit
+ * output will be forwarded to the driver's OutputCommitCoordinator.
+ *
+ * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests)
+ * for an extensive design discussion.
+ */
+private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging {
+
+ // Initialized by SparkEnv
+ var coordinatorActor: Option[ActorRef] = None
+ private val timeout = AkkaUtils.askTimeout(conf)
+ private val maxAttempts = AkkaUtils.numRetries(conf)
+ private val retryInterval = AkkaUtils.retryWaitMs(conf)
+
+ private type StageId = Int
+ private type PartitionId = Long
+ private type TaskAttemptId = Long
+
+ /**
+ * Map from an active stage's id => partition id => task attempt with exclusive lock on committing
+ * output for that partition.
+ *
+ * Entries are added to the top-level map when stages start and are removed when they finish
+ * (either successfully or unsuccessfully).
+ *
+ * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
+ */
+ private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map()
+ private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]]
+
+ /**
+ * Called by tasks to ask whether they can commit their output to HDFS.
+ *
+ * If a task attempt has been authorized to commit, then all other attempts to commit the same
+ * task will be denied. If the authorized task attempt fails (e.g. due to its executor being
+ * lost), then a subsequent task attempt may be authorized to commit its output.
+ *
+ * @param stage the stage number
+ * @param partition the partition number
+ * @param attempt a unique identifier for this task attempt
+ * @return true if this task is authorized to commit, false otherwise
+ */
+ def canCommit(
+ stage: StageId,
+ partition: PartitionId,
+ attempt: TaskAttemptId): Boolean = {
+ val msg = AskPermissionToCommitOutput(stage, partition, attempt)
+ coordinatorActor match {
+ case Some(actor) =>
+ AkkaUtils.askWithReply[Boolean](msg, actor, maxAttempts, retryInterval, timeout)
+ case None =>
+ logError(
+ "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?")
+ false
+ }
+ }
+
+ // Called by DAGScheduler
+ private[scheduler] def stageStart(stage: StageId): Unit = synchronized {
+ authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]()
+ }
+
+ // Called by DAGScheduler
+ private[scheduler] def stageEnd(stage: StageId): Unit = synchronized {
+ authorizedCommittersByStage.remove(stage)
+ }
+
+ // Called by DAGScheduler
+ private[scheduler] def taskCompleted(
+ stage: StageId,
+ partition: PartitionId,
+ attempt: TaskAttemptId,
+ reason: TaskEndReason): Unit = synchronized {
+ val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, {
+ logDebug(s"Ignoring task completion for completed stage")
+ return
+ })
+ reason match {
+ case Success =>
+ // The task output has been committed successfully
+ case denied: TaskCommitDenied =>
+ logInfo(
+ s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt")
+ case otherReason =>
+ logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" +
+ s" clearing lock")
+ authorizedCommitters.remove(partition)
+ }
+ }
+
+ def stop(): Unit = synchronized {
+ coordinatorActor.foreach(_ ! StopCoordinator)
+ coordinatorActor = None
+ authorizedCommittersByStage.clear()
+ }
+
+ // Marked private[scheduler] instead of private so this can be mocked in tests
+ private[scheduler] def handleAskPermissionToCommit(
+ stage: StageId,
+ partition: PartitionId,
+ attempt: TaskAttemptId): Boolean = synchronized {
+ authorizedCommittersByStage.get(stage) match {
+ case Some(authorizedCommitters) =>
+ authorizedCommitters.get(partition) match {
+ case Some(existingCommitter) =>
+ logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " +
+ s"existingCommitter = $existingCommitter")
+ false
+ case None =>
+ logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition")
+ authorizedCommitters(partition) = attempt
+ true
+ }
+ case None =>
+ logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit")
+ false
+ }
+ }
+}
+
+private[spark] object OutputCommitCoordinator {
+
+ // This actor is used only for RPC
+ class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator)
+ extends Actor with ActorLogReceive with Logging {
+
+ override def receiveWithLogging = {
+ case AskPermissionToCommitOutput(stage, partition, taskAttempt) =>
+ sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)
+ case StopCoordinator =>
+ logInfo("OutputCommitCoordinator stopped!")
+ context.stop(self)
+ sender ! true
+ }
+ }
+}
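
The driver-side handler makes the "first committer wins" policy easy to trace: the first attempt asking for a given (stage, partition) takes the commit lock, later attempts are denied, and a non-success, non-denied completion of the authorized attempt releases the lock. A minimal sketch that exercises the handler directly, bypassing the actor; it assumes the file lives in the org.apache.spark.scheduler package (the lifecycle methods are private[scheduler]) and uses arbitrary ids:

    package org.apache.spark.scheduler            // required: the lifecycle methods are private[scheduler]

    import org.apache.spark.{ExceptionFailure, SparkConf}

    object FirstCommitterWinsSketch {             // hypothetical object name
      def main(args: Array[String]): Unit = {
        val coordinator = new OutputCommitCoordinator(new SparkConf())
        coordinator.stageStart(1)

        // The first attempt to ask for (stage=1, partition=0) is authorized; the second is denied.
        println(coordinator.handleAskPermissionToCommit(1, 0L, 100L))  // true
        println(coordinator.handleAskPermissionToCommit(1, 0L, 101L))  // false

        // A non-success, non-denied completion of the authorized attempt clears the lock,
        // so a later attempt can be authorized instead.
        coordinator.taskCompleted(1, 0L, 100L,
          ExceptionFailure("ex", "simulated", Array.empty[StackTraceElement], "simulated", None))
        println(coordinator.handleAskPermissionToCommit(1, 0L, 101L))  // true

        coordinator.stageEnd(1)                   // removes all per-partition state for the stage
      }
    }
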
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 79f84e70df..54f8fcfc41 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -158,7 +158,7 @@ private[spark] class TaskSchedulerImpl(
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
- val manager = new TaskSetManager(this, taskSet, maxTaskFailures)
+ val manager = createTaskSetManager(taskSet, maxTaskFailures)
activeTaskSets(taskSet.id) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
@@ -180,6 +180,13 @@ private[spark] class TaskSchedulerImpl(
backend.reviveOffers()
}
+ // Labeled private[scheduler] to allow tests to swap in different task set managers if necessary
+ private[scheduler] def createTaskSetManager(
+ taskSet: TaskSet,
+ maxTaskFailures: Int): TaskSetManager = {
+ new TaskSetManager(this, taskSet, maxTaskFailures)
+ }
+
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 55024ecd55..99a5f71177 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -292,7 +292,8 @@ private[spark] class TaskSetManager(
* an attempt running on this host, in case the host is slow. In addition, the task should meet
* the given locality constraint.
*/
- private def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
+ // Marked protected so that tests can override how speculative tasks are provided, if necessary
+ protected def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value)] =
{
speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
@@ -708,7 +709,10 @@ private[spark] class TaskSetManager(
put(info.executorId, clock.getTime())
sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics)
addPendingTask(index)
- if (!isZombie && state != TaskState.KILLED) {
+ if (!isZombie && state != TaskState.KILLED && !reason.isInstanceOf[TaskCommitDenied]) {
+ // If a task failed because its attempt to commit was denied, do not count this failure
+ // towards failing the stage. This is intended to prevent spurious stage failures in cases
+ // where many speculative tasks are launched and denied to commit.
assert (null != failureReason)
numFailures(index) += 1
if (numFailures(index) >= maxTaskFailures) {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index eb116213f6..9d0c127369 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -208,7 +208,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
assert(taskSet.tasks.size >= results.size)
for ((result, i) <- results.zipWithIndex) {
if (i < taskSet.tasks.size) {
- runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, null, null))
+ runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null))
}
}
}
@@ -219,7 +219,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
for ((result, i) <- results.zipWithIndex) {
if (i < taskSet.tasks.size) {
runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2,
- Map[Long, Any]((accumId, 1)), null, null))
+ Map[Long, Any]((accumId, 1)), createFakeTaskInfo(), null))
}
}
}
@@ -476,7 +476,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
null,
Map[Long, Any](),
- null,
+ createFakeTaskInfo(),
null))
assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
assert(sparkListener.failedStages.contains(1))
@@ -487,7 +487,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"),
null,
Map[Long, Any](),
- null,
+ createFakeTaskInfo(),
null))
// The SparkListener should not receive redundant failure events.
assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
@@ -507,14 +507,14 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
assert(newEpoch > oldEpoch)
val taskSet = taskSets(0)
// should be ignored for being too old
- runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null))
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null))
// should work because it's a non-failed host
- runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, null, null))
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null))
// should be ignored for being too old
- runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null))
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null))
// should work because it's a new epoch
taskSet.tasks(1).epoch = newEpoch
- runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, null, null))
+ runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null))
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
complete(taskSets(1), Seq((Success, 42), (Success, 43)))
@@ -766,5 +766,14 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
assert(scheduler.shuffleToMapStage.isEmpty)
assert(scheduler.waitingStages.isEmpty)
}
+
+ // Nothing in this test should break if the task info's fields are null, but
+ // OutputCommitCoordinator requires the task info itself to not be null.
+ private def createFakeTaskInfo(): TaskInfo = {
+ val info = new TaskInfo(0, 0, 0, 0L, "", "", TaskLocality.ANY, false)
+ info.finishTime = 1 // to prevent spurious errors in JobProgressListener
+ info
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
new file mode 100644
index 0000000000..3cc860caa1
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.File
+import java.util.concurrent.TimeoutException
+
+import org.mockito.Matchers
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter}
+
+import org.apache.spark._
+import org.apache.spark.rdd.{RDD, FakeOutputCommitter}
+import org.apache.spark.util.Utils
+
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+/**
+ * Unit tests for the output commit coordination functionality.
+ *
+ * The unit test makes both the original task and the speculated task
+ * attempt to commit, where committing is emulated by creating a
+ * directory. If both tasks create directories then the end result is
+ * a failure.
+ *
+ * Note that there are some aspects of this test that are less than ideal.
+ * In particular, the test mocks the speculation-dequeuing logic to always
+ * dequeue a task and consider it as speculated. Immediately after initially
+ * submitting the tasks and calling reviveOffers(), reviveOffers() is invoked
+ * again to pick up the speculated task. This may be hacking the original
+ * behavior in too much of an unrealistic fashion.
+ *
+ * Also, the validation is done by checking the number of files in a directory.
+ * Ideally, an accumulator would be used for this, where we could increment
+ * the accumulator in the output committer's commitTask() call. If the call to
+ * commitTask() was called twice erroneously then the test would ideally fail because
+ * the accumulator would be incremented twice.
+ *
+ * The problem with this test implementation is that when both a speculated task and
+ * its original counterpart complete, only one of the accumulator's increments is
+ * captured. This results in a paradox where if the OutputCommitCoordinator logic
+ * was not in SparkHadoopWriter, the tests would still pass because only one of the
+ * increments would be captured even though the commit in both tasks was executed
+ * erroneously.
+ */
+class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter {
+
+ var outputCommitCoordinator: OutputCommitCoordinator = null
+ var tempDir: File = null
+ var sc: SparkContext = null
+
+ before {
+ tempDir = Utils.createTempDir()
+ val conf = new SparkConf()
+ .setMaster("local[4]")
+ .setAppName(classOf[OutputCommitCoordinatorSuite].getSimpleName)
+ .set("spark.speculation", "true")
+ sc = new SparkContext(conf) {
+ override private[spark] def createSparkEnv(
+ conf: SparkConf,
+ isLocal: Boolean,
+ listenerBus: LiveListenerBus): SparkEnv = {
+ outputCommitCoordinator = spy(new OutputCommitCoordinator(conf))
+ // Use Mockito.spy() to maintain the default infrastructure everywhere else.
+ // This mocking allows us to control the coordinator responses in test cases.
+ SparkEnv.createDriverEnv(conf, isLocal, listenerBus, Some(outputCommitCoordinator))
+ }
+ }
+ // Use Mockito.spy() to maintain the default infrastructure everywhere else
+ val mockTaskScheduler = spy(sc.taskScheduler.asInstanceOf[TaskSchedulerImpl])
+
+ doAnswer(new Answer[Unit]() {
+ override def answer(invoke: InvocationOnMock): Unit = {
+ // Submit the tasks, then force the task scheduler to dequeue the
+ // speculated task
+ invoke.callRealMethod()
+ mockTaskScheduler.backend.reviveOffers()
+ }
+ }).when(mockTaskScheduler).submitTasks(Matchers.any())
+
+ doAnswer(new Answer[TaskSetManager]() {
+ override def answer(invoke: InvocationOnMock): TaskSetManager = {
+ val taskSet = invoke.getArguments()(0).asInstanceOf[TaskSet]
+ new TaskSetManager(mockTaskScheduler, taskSet, 4) {
+ var hasDequeuedSpeculatedTask = false
+ override def dequeueSpeculativeTask(
+ execId: String,
+ host: String,
+ locality: TaskLocality.Value): Option[(Int, TaskLocality.Value)] = {
+ if (!hasDequeuedSpeculatedTask) {
+ hasDequeuedSpeculatedTask = true
+ Some(0, TaskLocality.PROCESS_LOCAL)
+ } else {
+ None
+ }
+ }
+ }
+ }
+ }).when(mockTaskScheduler).createTaskSetManager(Matchers.any(), Matchers.any())
+
+ sc.taskScheduler = mockTaskScheduler
+ val dagSchedulerWithMockTaskScheduler = new DAGScheduler(sc, mockTaskScheduler)
+ sc.taskScheduler.setDAGScheduler(dagSchedulerWithMockTaskScheduler)
+ sc.dagScheduler = dagSchedulerWithMockTaskScheduler
+ }
+
+ after {
+ sc.stop()
+ tempDir.delete()
+ outputCommitCoordinator = null
+ }
+
+ test("Only one of two duplicate commit tasks should commit") {
+ val rdd = sc.parallelize(Seq(1), 1)
+ sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully _,
+ 0 until rdd.partitions.size, allowLocal = false)
+ assert(tempDir.list().size === 1)
+ }
+
+ test("If commit fails, if task is retried it should not be locked, and will succeed.") {
+ val rdd = sc.parallelize(Seq(1), 1)
+ sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).failFirstCommitAttempt _,
+ 0 until rdd.partitions.size, allowLocal = false)
+ assert(tempDir.list().size === 1)
+ }
+
+ test("Job should not complete if all commits are denied") {
+ // Create a mock OutputCommitCoordinator that denies all attempts to commit
+ doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit(
+ Matchers.any(), Matchers.any(), Matchers.any())
+ val rdd: RDD[Int] = sc.parallelize(Seq(1), 1)
+ def resultHandler(x: Int, y: Unit): Unit = {}
+ val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd,
+ OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully,
+ 0 until rdd.partitions.size, resultHandler, 0)
+ // It's an error if the job completes successfully even though no committer was authorized,
+ // so throw an exception if the job was allowed to complete.
+ intercept[TimeoutException] {
+ Await.result(futureAction, 5 seconds)
+ }
+ assert(tempDir.list().size === 0)
+ }
+}
+
+/**
+ * Class with methods that can be passed to runJob to test commits with a mock committer.
+ */
+private case class OutputCommitFunctions(tempDirPath: String) {
+
+ // Mock output committer that simulates a successful commit (after commit is authorized)
+ private def successfulOutputCommitter = new FakeOutputCommitter {
+ override def commitTask(context: TaskAttemptContext): Unit = {
+ Utils.createDirectory(tempDirPath)
+ }
+ }
+
+ // Mock output committer that simulates a failed commit (after commit is authorized)
+ private def failingOutputCommitter = new FakeOutputCommitter {
+ override def commitTask(taskAttemptContext: TaskAttemptContext) {
+ throw new RuntimeException
+ }
+ }
+
+ def commitSuccessfully(iter: Iterator[Int]): Unit = {
+ val ctx = TaskContext.get()
+ runCommitWithProvidedCommitter(ctx, iter, successfulOutputCommitter)
+ }
+
+ def failFirstCommitAttempt(iter: Iterator[Int]): Unit = {
+ val ctx = TaskContext.get()
+ runCommitWithProvidedCommitter(ctx, iter,
+ if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter)
+ }
+
+ private def runCommitWithProvidedCommitter(
+ ctx: TaskContext,
+ iter: Iterator[Int],
+ outputCommitter: OutputCommitter): Unit = {
+ def jobConf = new JobConf {
+ override def getOutputCommitter(): OutputCommitter = outputCommitter
+ }
+ val sparkHadoopWriter = new SparkHadoopWriter(jobConf) {
+ override def newTaskAttemptContext(
+ conf: JobConf,
+ attemptId: TaskAttemptID): TaskAttemptContext = {
+ mock(classOf[TaskAttemptContext])
+ }
+ }
+ sparkHadoopWriter.setup(ctx.stageId, ctx.partitionId, ctx.attemptNumber)
+ sparkHadoopWriter.commit()
+ }
+}