-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskEndReason.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala | 26
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Pool.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala | 15
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala | 55
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala | 162
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala | 124
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala | 21
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 27
-rw-r--r--  core/src/test/scala/org/apache/spark/DistributedSuite.scala | 1
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala | 15
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala | 58
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala | 113
-rw-r--r--  docs/running-on-yarn.md | 1
-rw-r--r--  ec2/README | 2
-rwxr-xr-x  make-distribution.sh | 2
-rw-r--r--  project/SparkBuild.scala | 3
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 2
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala | 6
25 files changed, 510 insertions(+), 157 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index e67390cfd1..17c6f9c955 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -147,7 +147,7 @@ class SparkContext(
}
// Create and start the scheduler
- private var taskScheduler: TaskScheduler = {
+ private[spark] var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 03bf268863..8466c2a004 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -46,6 +46,10 @@ private[spark] case class ExceptionFailure(
metrics: Option[TaskMetrics])
extends TaskEndReason
-private[spark] case class OtherFailure(message: String) extends TaskEndReason
+/**
+ * The task finished successfully, but the result was lost from the executor's block manager before
+ * it was fetched.
+ */
+private[spark] case object TaskResultLost extends TaskEndReason
-private[spark] case class TaskResultTooBigFailure() extends TaskEndReason
+private[spark] case class OtherFailure(message: String) extends TaskEndReason
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 99a4a95e82..b4153f3533 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -17,7 +17,7 @@
package org.apache.spark.executor
-import java.io.{File}
+import java.io.File
import java.lang.management.ManagementFactory
import java.nio.ByteBuffer
import java.util.concurrent._
@@ -27,11 +27,11 @@ import scala.collection.mutable.HashMap
import org.apache.spark.scheduler._
import org.apache.spark._
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-
/**
- * The Mesos executor for Spark.
+ * Spark executor used with Mesos and the standalone scheduler.
*/
private[spark] class Executor(
executorId: String,
@@ -167,12 +167,20 @@ private[spark] class Executor(
// we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could
// just change the relevant bytes in the byte buffer
val accumUpdates = Accumulators.values
- val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null))
- val serializedResult = ser.serialize(result)
- logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit)
- if (serializedResult.limit >= (akkaFrameSize - 1024)) {
- context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(TaskResultTooBigFailure()))
- return
+ val directResult = new DirectTaskResult(value, accumUpdates, task.metrics.getOrElse(null))
+ val serializedDirectResult = ser.serialize(directResult)
+ logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
+ val serializedResult = {
+ if (serializedDirectResult.limit >= akkaFrameSize - 1024) {
+ logInfo("Storing result for " + taskId + " in local BlockManager")
+ val blockId = "taskresult_" + taskId
+ env.blockManager.putBytes(
+ blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
+ ser.serialize(new IndirectTaskResult[Any](blockId))
+ } else {
+ logInfo("Sending result for " + taskId + " directly to driver")
+ serializedDirectResult
+ }
}
context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
logInfo("Finished task ID " + taskId)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 693d8a7c5d..e79b67579e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -553,7 +553,7 @@ class DAGScheduler(
SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
} catch {
case e: NotSerializableException =>
- abortStage(stage, e.toString)
+ abortStage(stage, "Task not serializable: " + e.toString)
running -= stage
return
}
@@ -705,6 +705,9 @@ class DAGScheduler(
case ExceptionFailure(className, description, stackTrace, metrics) =>
// Do nothing here, left up to the TaskScheduler to decide how to handle user failures
+ case TaskResultLost =>
+ // Do nothing here; the TaskScheduler handles these failures and resubmits the task.
+
case other =>
// Unrecognized failure - abort all jobs depending on this stage
abortStage(stageIdToStage(task.stageId), task + " failed: " + other)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index c9a66b3a75..9eb8d48501 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -45,7 +45,7 @@ private[spark] class Pool(
var priority = 0
var stageId = 0
var name = poolName
- var parent:Schedulable = null
+ var parent: Pool = null
var taskSetSchedulingAlgorithm: SchedulingAlgorithm = {
schedulingMode match {
@@ -101,14 +101,14 @@ private[spark] class Pool(
return sortedTaskSetQueue
}
- override def increaseRunningTasks(taskNum: Int) {
+ def increaseRunningTasks(taskNum: Int) {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
- override def decreaseRunningTasks(taskNum: Int) {
+ def decreaseRunningTasks(taskNum: Int) {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
index 857adaef5a..1c7ea2dccc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer
* there are two type of Schedulable entities(Pools and TaskSetManagers)
*/
private[spark] trait Schedulable {
- var parent: Schedulable
+ var parent: Pool
// child queues
def schedulableQueue: ArrayBuffer[Schedulable]
def schedulingMode: SchedulingMode
@@ -36,8 +36,6 @@ private[spark] trait Schedulable {
def stageId: Int
def name: String
- def increaseRunningTasks(taskNum: Int): Unit
- def decreaseRunningTasks(taskNum: Int): Unit
def addSchedulable(schedulable: Schedulable): Unit
def removeSchedulable(schedulable: Schedulable): Unit
def getSchedulableByName(name: String): Schedulable
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 5c7e5bb977..db3954a9d3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -26,12 +26,17 @@ import java.nio.ByteBuffer
import org.apache.spark.util.Utils
// Task result. Also contains updates to accumulator variables.
-// TODO: Use of distributed cache to return result is a hack to get around
-// what seems to be a bug with messages over 60KB in libprocess; fix it
+private[spark] sealed trait TaskResult[T]
+
+/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
+private[spark]
+case class IndirectTaskResult[T](val blockId: String) extends TaskResult[T] with Serializable
+
+/** A TaskResult that contains the task's return value and accumulator updates. */
private[spark]
-class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics)
- extends Externalizable
-{
+class DirectTaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics)
+ extends TaskResult[T] with Externalizable {
+
def this() = this(null.asInstanceOf[T], null, null)
override def writeExternal(out: ObjectOutput) {
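Downstream, the driver has to undo this split: a DirectTaskResult is used as-is, while an IndirectTaskResult requires a remote fetch. A hedged sketch of that resolution, mirroring the new TaskResultGetter and BlockManager.getRemoteBytes further down in this diff (the Option return and the SparkEnv parameter are conveniences of the sketch; None is the situation reported as TaskResultLost):

import java.nio.ByteBuffer

import org.apache.spark.SparkEnv
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}

// Resolve either flavour of TaskResult back to a DirectTaskResult.
def resolve(env: SparkEnv, serializedData: ByteBuffer): Option[DirectTaskResult[_]] = {
  val ser = env.closureSerializer.newInstance()
  ser.deserialize[TaskResult[_]](serializedData) match {
    case direct: DirectTaskResult[_] =>
      Some(direct)
    case IndirectTaskResult(blockId) =>
      // The block may be gone if the executor died or its BlockManager evicted it.
      env.blockManager.getRemoteBytes(blockId).map { bytes =>
        env.blockManager.master.removeBlock(blockId)
        ser.deserialize[DirectTaskResult[_]](bytes)
      }
  }
}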
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index f192b0b7a4..90f6bcefac 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -44,7 +44,5 @@ private[spark] trait TaskSetManager extends Schedulable {
maxLocality: TaskLocality.TaskLocality)
: Option[TaskDescription]
- def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer)
-
def error(message: String)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index a6dee604b7..1a844b7e7e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -18,6 +18,9 @@
package org.apache.spark.scheduler.cluster
import java.lang.{Boolean => JBoolean}
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicLong
+import java.util.{TimerTask, Timer}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -27,9 +30,6 @@ import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-import java.nio.ByteBuffer
-import java.util.concurrent.atomic.AtomicLong
-import java.util.{TimerTask, Timer}
/**
* The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call
@@ -55,7 +55,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// Threshold above which we warn user initial TaskSet may be starved
val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
- val activeTaskSets = new HashMap[String, TaskSetManager]
+ // ClusterTaskSetManagers are not thread safe, so any access to one should be synchronized
+ // on this class.
+ val activeTaskSets = new HashMap[String, ClusterTaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskIdToExecutorId = new HashMap[Long, String]
@@ -65,7 +67,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
@volatile private var hasLaunchedTask = false
private val starvationTimer = new Timer(true)
- // Incrementing Mesos task IDs
+ // Incrementing task IDs
val nextTaskId = new AtomicLong(0)
// Which executor IDs we have executors on
@@ -96,6 +98,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val schedulingMode: SchedulingMode = SchedulingMode.withName(
System.getProperty("spark.scheduler.mode", "FIFO"))
+ // This is a var so that we can reset it for testing purposes.
+ private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
+
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}
@@ -234,7 +239,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- var taskSetToUpdate: Option[TaskSetManager] = None
var failedExecutor: Option[String] = None
var taskFailed = false
synchronized {
@@ -249,9 +253,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
taskIdToTaskSetId.get(tid) match {
case Some(taskSetId) =>
- if (activeTaskSets.contains(taskSetId)) {
- taskSetToUpdate = Some(activeTaskSets(taskSetId))
- }
if (TaskState.isFinished(state)) {
taskIdToTaskSetId.remove(tid)
if (taskSetTaskIds.contains(taskSetId)) {
@@ -262,6 +263,15 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (state == TaskState.FAILED) {
taskFailed = true
}
+ activeTaskSets.get(taskSetId).foreach { taskSet =>
+ if (state == TaskState.FINISHED) {
+ taskSet.removeRunningTask(tid)
+ taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+ } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+ taskSet.removeRunningTask(tid)
+ taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
+ }
+ }
case None =>
logInfo("Ignoring update from TID " + tid + " because its task set is gone")
}
@@ -269,10 +279,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
case e: Exception => logError("Exception in statusUpdate", e)
}
}
- // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock
- if (taskSetToUpdate != None) {
- taskSetToUpdate.get.statusUpdate(tid, state, serializedData)
- }
+ // Update the DAGScheduler without holding a lock on this, since that can deadlock
if (failedExecutor != None) {
listener.executorLost(failedExecutor.get)
backend.reviveOffers()
@@ -283,6 +290,25 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
+ def handleSuccessfulTask(
+ taskSetManager: ClusterTaskSetManager,
+ tid: Long,
+ taskResult: DirectTaskResult[_]) = synchronized {
+ taskSetManager.handleSuccessfulTask(tid, taskResult)
+ }
+
+ def handleFailedTask(
+ taskSetManager: ClusterTaskSetManager,
+ tid: Long,
+ taskState: TaskState,
+ reason: Option[TaskEndReason]) = synchronized {
+ taskSetManager.handleFailedTask(tid, taskState, reason)
+ if (taskState == TaskState.FINISHED) {
+ // The task finished successfully but the result was lost, so we should revive offers.
+ backend.reviveOffers()
+ }
+ }
+
def error(message: String) {
synchronized {
if (activeTaskSets.size > 0) {
@@ -311,6 +337,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (jarServer != null) {
jarServer.stop()
}
+ if (taskResultGetter != null) {
+ taskResultGetter.stop()
+ }
// sleeping for an arbitrary 5 seconds : to ensure that messages are sent out.
// TODO: Do something better !
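Note that taskResultGetter is deliberately a private[spark] var: tests can swap in their own getter, which is how the new TaskResultGetterSuite simulates a lost result. A small sketch of that pattern, assuming a SparkContext sc whose scheduler is a ClusterScheduler (the suite pattern matches instead of casting):

// ResultDeletingTaskResultGetter is defined in TaskResultGetterSuite below.
val scheduler = sc.taskScheduler.asInstanceOf[ClusterScheduler]
scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)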
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 411e49b021..194ab55102 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -28,7 +28,7 @@ import scala.math.min
import scala.Some
import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
- SparkException, Success, TaskEndReason, TaskResultTooBigFailure, TaskState}
+ SparkException, Success, TaskEndReason, TaskResultLost, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler._
import org.apache.spark.util.{SystemClock, Clock}
@@ -68,18 +68,20 @@ private[spark] class ClusterTaskSetManager(
val tasks = taskSet.tasks
val numTasks = tasks.length
val copiesRunning = new Array[Int](numTasks)
- val finished = new Array[Boolean](numTasks)
+ val successful = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
- var tasksFinished = 0
+ var tasksSuccessful = 0
var weight = 1
var minShare = 0
- var runningTasks = 0
var priority = taskSet.priority
var stageId = taskSet.stageId
var name = "TaskSet_"+taskSet.stageId.toString
- var parent: Schedulable = null
+ var parent: Pool = null
+
+ var runningTasks = 0
+ private val runningTasksSet = new HashSet[Long]
// Set of pending tasks for each executor. These collections are actually
// treated as stacks, in which new tasks are added to the end of the
@@ -220,7 +222,7 @@ private[spark] class ClusterTaskSetManager(
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
- if (copiesRunning(index) == 0 && !finished(index)) {
+ if (copiesRunning(index) == 0 && !successful(index)) {
return Some(index)
}
}
@@ -240,7 +242,7 @@ private[spark] class ClusterTaskSetManager(
private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value)] =
{
- speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
+ speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
if (!speculatableTasks.isEmpty) {
// Check for process-local or preference-less tasks; note that tasks can be process-local
@@ -341,7 +343,7 @@ private[spark] class ClusterTaskSetManager(
maxLocality: TaskLocality.TaskLocality)
: Option[TaskDescription] =
{
- if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
+ if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
val curTime = clock.getTime()
var allowedLocality = getAllowedLocalityLevel(curTime)
@@ -372,7 +374,7 @@ private[spark] class ClusterTaskSetManager(
val serializedTask = Task.serializeWithDependencies(
task, sched.sc.addedFiles, sched.sc.addedJars, ser)
val timeTaken = clock.getTime() - startTime
- increaseRunningTasks(1)
+ addRunningTask(taskId)
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
@@ -414,94 +416,61 @@ private[spark] class ClusterTaskSetManager(
index
}
- /** Called by cluster scheduler when one of our tasks changes state */
- override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- SparkEnv.set(env)
- state match {
- case TaskState.FINISHED =>
- taskFinished(tid, state, serializedData)
- case TaskState.LOST =>
- taskLost(tid, state, serializedData)
- case TaskState.FAILED =>
- taskLost(tid, state, serializedData)
- case TaskState.KILLED =>
- taskLost(tid, state, serializedData)
- case _ =>
- }
- }
-
- def taskStarted(task: Task[_], info: TaskInfo) {
+ private def taskStarted(task: Task[_], info: TaskInfo) {
sched.listener.taskStarted(task, info)
}
- def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ /**
+ * Marks the task as successful and notifies the listener that a task has ended.
+ */
+ def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
val info = taskInfos(tid)
- if (info.failed) {
- // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
- // or even from Mesos itself when acks get delayed.
- return
- }
val index = info.index
info.markSuccessful()
- decreaseRunningTasks(1)
- if (!finished(index)) {
- tasksFinished += 1
+ removeRunningTask(tid)
+ if (!successful(index)) {
logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
- tid, info.duration, info.host, tasksFinished, numTasks))
- // Deserialize task result and pass it to the scheduler
- try {
- val result = ser.deserialize[TaskResult[_]](serializedData)
- result.metrics.resultSize = serializedData.limit()
- sched.listener.taskEnded(
- tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
- } catch {
- case cnf: ClassNotFoundException =>
- val loader = Thread.currentThread().getContextClassLoader
- throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
- case ex => throw ex
- }
- // Mark finished and stop if we've finished all the tasks
- finished(index) = true
- if (tasksFinished == numTasks) {
+ tid, info.duration, info.host, tasksSuccessful, numTasks))
+ sched.listener.taskEnded(
+ tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+
+ // Mark successful and stop if all the tasks have succeeded.
+ tasksSuccessful += 1
+ successful(index) = true
+ if (tasksSuccessful == numTasks) {
sched.taskSetFinished(this)
}
} else {
- logInfo("Ignoring task-finished event for TID " + tid +
- " because task " + index + " is already finished")
+        logInfo("Ignoring task-finished event for TID " + tid + " because task " +
+ index + " has already completed successfully")
}
}
- def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ /**
+ * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the listener.
+ */
+ def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
val info = taskInfos(tid)
if (info.failed) {
- // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
- // or even from Mesos itself when acks get delayed.
return
}
+ removeRunningTask(tid)
val index = info.index
info.markFailed()
- decreaseRunningTasks(1)
- if (!finished(index)) {
+ if (!successful(index)) {
logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
copiesRunning(index) -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
- if (serializedData != null && serializedData.limit() > 0) {
- val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
- reason match {
+ reason.foreach {
+ _ match {
case fetchFailed: FetchFailed =>
logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
- finished(index) = true
- tasksFinished += 1
+ successful(index) = true
+ tasksSuccessful += 1
sched.taskSetFinished(this)
- decreaseRunningTasks(runningTasks)
- return
-
- case taskResultTooBig: TaskResultTooBigFailure =>
- logInfo("Loss was due to task %s result exceeding Akka frame size; aborting job".format(
- tid))
- abort("Task %s result exceeded Akka frame size".format(tid))
+ removeAllRunningTasks()
return
case ef: ExceptionFailure =>
@@ -531,13 +500,16 @@ private[spark] class ClusterTaskSetManager(
logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
}
+ case TaskResultLost =>
+ logInfo("Lost result for TID %s on host %s".format(tid, info.host))
+ sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+
case _ => {}
}
}
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
- // Count failed attempts only on FAILED and LOST state (not on KILLED)
- if (state == TaskState.FAILED || state == TaskState.LOST) {
+ if (state != TaskState.KILLED) {
numFailures(index) += 1
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
@@ -561,22 +533,36 @@ private[spark] class ClusterTaskSetManager(
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
sched.listener.taskSetFailed(taskSet, message)
- decreaseRunningTasks(runningTasks)
+ removeAllRunningTasks()
sched.taskSetFinished(this)
}
- override def increaseRunningTasks(taskNum: Int) {
- runningTasks += taskNum
- if (parent != null) {
- parent.increaseRunningTasks(taskNum)
+ /** If the given task ID is not in the set of running tasks, adds it.
+ *
+ * Used to keep track of the number of running tasks, for enforcing scheduling policies.
+ */
+ def addRunningTask(tid: Long) {
+ if (runningTasksSet.add(tid) && parent != null) {
+ parent.increaseRunningTasks(1)
+ }
+ runningTasks = runningTasksSet.size
+ }
+
+ /** If the given task ID is in the set of running tasks, removes it. */
+ def removeRunningTask(tid: Long) {
+ if (runningTasksSet.remove(tid) && parent != null) {
+ parent.decreaseRunningTasks(1)
}
+ runningTasks = runningTasksSet.size
}
- override def decreaseRunningTasks(taskNum: Int) {
- runningTasks -= taskNum
+ private def removeAllRunningTasks() {
+ val numRunningTasks = runningTasksSet.size
+ runningTasksSet.clear()
if (parent != null) {
- parent.decreaseRunningTasks(taskNum)
+ parent.decreaseRunningTasks(numRunningTasks)
}
+ runningTasks = 0
}
override def getSchedulableByName(name: String): Schedulable = {
@@ -612,10 +598,10 @@ private[spark] class ClusterTaskSetManager(
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
- if (finished(index)) {
- finished(index) = false
+ if (successful(index)) {
+ successful(index) = false
copiesRunning(index) -= 1
- tasksFinished -= 1
+ tasksSuccessful -= 1
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
@@ -625,7 +611,7 @@ private[spark] class ClusterTaskSetManager(
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
- taskLost(tid, TaskState.KILLED, null)
+ handleFailedTask(tid, TaskState.KILLED, None)
}
}
@@ -638,13 +624,13 @@ private[spark] class ClusterTaskSetManager(
*/
override def checkSpeculatableTasks(): Boolean = {
// Can't speculate if we only have one task, or if all tasks have finished.
- if (numTasks == 1 || tasksFinished == numTasks) {
+ if (numTasks == 1 || tasksSuccessful == numTasks) {
return false
}
var foundTasks = false
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
- if (tasksFinished >= minFinishedForSpeculation) {
+ if (tasksSuccessful >= minFinishedForSpeculation) {
val time = clock.getTime()
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
Arrays.sort(durations)
@@ -655,7 +641,7 @@ private[spark] class ClusterTaskSetManager(
logDebug("Task length threshold for speculation: " + threshold)
for ((tid, info) <- taskInfos) {
val index = info.index
- if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+ if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
!speculatableTasks.contains(index)) {
logInfo(
"Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
@@ -669,7 +655,7 @@ private[spark] class ClusterTaskSetManager(
}
override def hasPendingTasks(): Boolean = {
- numTasks > 0 && tasksFinished < numTasks
+ numTasks > 0 && tasksSuccessful < numTasks
}
private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
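The running-task accounting above moves from raw increment/decrement calls to a HashSet of task ids, so a duplicate status update for the same task can no longer skew the counts that the parent Pool uses for scheduling. A self-contained illustration of just that idea (not Spark's API):

import scala.collection.mutable.HashSet

// Idempotent running-task bookkeeping: adding or removing the same task id twice
// propagates only a single +1/-1 to the parent's counter.
class RunningTaskTracker(notifyParent: Int => Unit) {
  private val running = new HashSet[Long]
  def add(tid: Long): Unit = if (running.add(tid)) notifyParent(1)
  def remove(tid: Long): Unit = if (running.remove(tid)) notifyParent(-1)
  def size: Int = running.size
}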
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
new file mode 100644
index 0000000000..feec8ecfe4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.nio.ByteBuffer
+import java.util.concurrent.{LinkedBlockingDeque, ThreadFactory, ThreadPoolExecutor, TimeUnit}
+
+import org.apache.spark._
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
+import org.apache.spark.serializer.SerializerInstance
+
+/**
+ * Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
+ */
+private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+ extends Logging {
+ private val MIN_THREADS = System.getProperty("spark.resultGetter.minThreads", "4").toInt
+ private val MAX_THREADS = System.getProperty("spark.resultGetter.maxThreads", "4").toInt
+ private val getTaskResultExecutor = new ThreadPoolExecutor(
+ MIN_THREADS,
+ MAX_THREADS,
+ 0L,
+ TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable],
+ new ResultResolverThreadFactory)
+
+ class ResultResolverThreadFactory extends ThreadFactory {
+ private var counter = 0
+ private var PREFIX = "Result resolver thread"
+
+ override def newThread(r: Runnable): Thread = {
+ val thread = new Thread(r, "%s-%s".format(PREFIX, counter))
+ counter += 1
+ thread.setDaemon(true)
+ return thread
+ }
+ }
+
+ protected val serializer = new ThreadLocal[SerializerInstance] {
+ override def initialValue(): SerializerInstance = {
+ return sparkEnv.closureSerializer.newInstance()
+ }
+ }
+
+ def enqueueSuccessfulTask(
+ taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) {
+ getTaskResultExecutor.execute(new Runnable {
+ override def run() {
+ try {
+ val result = serializer.get().deserialize[TaskResult[_]](serializedData) match {
+ case directResult: DirectTaskResult[_] => directResult
+ case IndirectTaskResult(blockId) =>
+ logDebug("Fetching indirect task result for TID %s".format(tid))
+ val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
+ if (!serializedTaskResult.isDefined) {
+ /* We won't be able to get the task result if the machine that ran the task failed
+ * between when the task ended and when we tried to fetch the result, or if the
+ * block manager had to flush the result. */
+ scheduler.handleFailedTask(
+ taskSetManager, tid, TaskState.FINISHED, Some(TaskResultLost))
+ return
+ }
+ val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
+ serializedTaskResult.get)
+ sparkEnv.blockManager.master.removeBlock(blockId)
+ deserializedResult
+ }
+ result.metrics.resultSize = serializedData.limit()
+ scheduler.handleSuccessfulTask(taskSetManager, tid, result)
+ } catch {
+ case cnf: ClassNotFoundException =>
+ val loader = Thread.currentThread.getContextClassLoader
+ taskSetManager.abort("ClassNotFound with classloader: " + loader)
+ case ex =>
+ taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex))
+ }
+ }
+ })
+ }
+
+ def enqueueFailedTask(taskSetManager: ClusterTaskSetManager, tid: Long, taskState: TaskState,
+ serializedData: ByteBuffer) {
+ var reason: Option[TaskEndReason] = None
+ getTaskResultExecutor.execute(new Runnable {
+ override def run() {
+ try {
+ if (serializedData != null && serializedData.limit() > 0) {
+ reason = Some(serializer.get().deserialize[TaskEndReason](
+ serializedData, getClass.getClassLoader))
+ }
+ } catch {
+ case cnd: ClassNotFoundException =>
+          // Log an error but keep going here -- the task failed, so not catastrophic if we can't
+ // deserialize the reason.
+ val loader = Thread.currentThread.getContextClassLoader
+ logError(
+ "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
+ case ex => {}
+ }
+ scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
+ }
+ })
+ }
+
+ def stop() {
+ getTaskResultExecutor.shutdownNow()
+ }
+}
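The fetcher pool is sized by spark.resultGetter.minThreads and spark.resultGetter.maxThreads (both default to "4" above). A hedged example of widening it before the SparkContext, and hence the ClusterScheduler, is created; the master URL and app name are placeholders:

// Only worth doing if many tasks return results larger than the Akka frame size.
System.setProperty("spark.resultGetter.minThreads", "8")
System.setProperty("spark.resultGetter.maxThreads", "8")
val sc = new org.apache.spark.SparkContext("spark://master:7077", "result-getter-demo")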
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index e29438f4ed..4d1bb1c639 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -91,7 +91,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
var rootPool: Pool = null
val schedulingMode: SchedulingMode = SchedulingMode.withName(
System.getProperty("spark.scheduler.mode", "FIFO"))
- val activeTaskSets = new HashMap[String, TaskSetManager]
+ val activeTaskSets = new HashMap[String, LocalTaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
@@ -210,7 +210,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
deserializedTask.metrics.get.executorRunTime = serviceTime.toInt
deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime
deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
- val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null))
+ val taskResult = new DirectTaskResult(
+ result, accumUpdates, deserializedTask.metrics.getOrElse(null))
val serializedResult = ser.serialize(taskResult)
localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
index a2fda4c124..c2e2399ccb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
@@ -21,16 +21,16 @@ import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, Success, TaskState}
+import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, SparkException, Success, TaskState}
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.{Schedulable, Task, TaskDescription, TaskInfo, TaskLocality,
- TaskResult, TaskSet, TaskSetManager}
+import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Pool, Schedulable, Task,
+ TaskDescription, TaskInfo, TaskLocality, TaskResult, TaskSet, TaskSetManager}
private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet)
extends TaskSetManager with Logging {
- var parent: Schedulable = null
+ var parent: Pool = null
var weight: Int = 1
var minShare: Int = 0
var runningTasks: Int = 0
@@ -49,14 +49,14 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
val numFailures = new Array[Int](numTasks)
val MAX_TASK_FAILURES = sched.maxFailures
- override def increaseRunningTasks(taskNum: Int): Unit = {
+ def increaseRunningTasks(taskNum: Int): Unit = {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
- override def decreaseRunningTasks(taskNum: Int): Unit = {
+ def decreaseRunningTasks(taskNum: Int): Unit = {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
@@ -132,7 +132,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
return None
}
- override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
SparkEnv.set(env)
state match {
case TaskState.FINISHED =>
@@ -152,7 +152,12 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
val index = info.index
val task = taskSet.tasks(index)
info.markSuccessful()
- val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
+ val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) match {
+ case directResult: DirectTaskResult[_] => directResult
+ case IndirectTaskResult(blockId) => {
+ throw new SparkException("Expect only DirectTaskResults when using LocalScheduler")
+ }
+ }
result.metrics.resultSize = serializedData.limit()
sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
numFinished += 1
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 13b98a51a1..a5e792d896 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -484,7 +484,7 @@ private[spark] class BlockManager(
for (loc <- locations) {
logDebug("Getting remote block " + blockId + " from " + loc)
val data = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
+ GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
if (data != null) {
return Some(dataDeserialize(blockId, data))
}
@@ -495,6 +495,31 @@ private[spark] class BlockManager(
}
/**
+ * Get block from remote block managers as serialized bytes.
+ */
+ def getRemoteBytes(blockId: String): Option[ByteBuffer] = {
+ // TODO: As with getLocalBytes, this is very similar to getRemote and perhaps should be
+ // refactored.
+ if (blockId == null) {
+ throw new IllegalArgumentException("Block Id is null")
+ }
+ logDebug("Getting remote block " + blockId + " as bytes")
+
+ val locations = master.getLocations(blockId)
+ for (loc <- locations) {
+ logDebug("Getting remote block " + blockId + " from " + loc)
+ val data = BlockManagerWorker.syncGetBlock(
+ GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
+ if (data != null) {
+ return Some(data)
+ }
+ logDebug("The value of block " + blockId + " is null")
+ }
+ logDebug("Block " + blockId + " not found")
+ return None
+ }
+
+ /**
* Get a block from the block manager (either local or remote).
*/
def get(blockId: String): Option[Iterator[Any]] = {
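Unlike the existing getRemote, which deserializes a block into an iterator of values, getRemoteBytes hands back the raw serialized buffer so the caller chooses how to deserialize it. A short hedged usage sketch, assuming a live SparkEnv and the block id naming used by the Executor change above:

import org.apache.spark.SparkEnv
import org.apache.spark.scheduler.DirectTaskResult

val bm = SparkEnv.get.blockManager
// None means the block was evicted or its executor is gone -- the case surfaced as TaskResultLost.
val maybeResult = bm.getRemoteBytes("taskresult_0").map { bytes =>
  SparkEnv.get.closureSerializer.newInstance().deserialize[DirectTaskResult[_]](bytes)
}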
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index c719a54a61..a31988a729 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -332,6 +332,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
}
exception.getMessage should endWith("result exceeded Akka frame size")
}
+
}
object DistributedSuite {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 9ed591e494..2f933246b0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -32,8 +32,6 @@ import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency}
import org.apache.spark.{FetchFailed, Success, TaskEndReason}
import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster}
-import org.apache.spark.scheduler.Pool
-import org.apache.spark.scheduler.SchedulingMode
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
index 1b50ce06b3..95d3553d91 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
@@ -43,16 +43,16 @@ class FakeTaskSetManager(
stageId = initStageId
name = "TaskSet_"+stageId
override val numTasks = initNumTasks
- tasksFinished = 0
+ tasksSuccessful = 0
- override def increaseRunningTasks(taskNum: Int) {
+ def increaseRunningTasks(taskNum: Int) {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
- override def decreaseRunningTasks(taskNum: Int) {
+ def decreaseRunningTasks(taskNum: Int) {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
@@ -79,7 +79,7 @@ class FakeTaskSetManager(
maxLocality: TaskLocality.TaskLocality)
: Option[TaskDescription] =
{
- if (tasksFinished + runningTasks < numTasks) {
+ if (tasksSuccessful + runningTasks < numTasks) {
increaseRunningTasks(1)
return Some(new TaskDescription(0, execId, "task 0:0", 0, null))
}
@@ -92,8 +92,8 @@ class FakeTaskSetManager(
def taskFinished() {
decreaseRunningTasks(1)
- tasksFinished +=1
- if (tasksFinished == numTasks) {
+ tasksSuccessful +=1
+ if (tasksSuccessful == numTasks) {
parent.removeSchedulable(this)
}
}
@@ -114,7 +114,8 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
val taskSetQueue = rootPool.getSortedTaskSetQueue()
/* Just for Test*/
for (manager <- taskSetQueue) {
- logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
+ logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(
+ manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
}
for (taskSet <- taskSetQueue) {
taskSet.resourceOffer("execId_1", "hostname_1", 1, TaskLocality.ANY) match {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
index ff70a2cdf0..80d0c5a5e9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
@@ -40,6 +40,7 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /*
val startedTasks = new ArrayBuffer[Long]
val endedTasks = new mutable.HashMap[Long, TaskEndReason]
val finishedManagers = new ArrayBuffer[TaskSetManager]
+ val taskSetsFailed = new ArrayBuffer[String]
val executors = new mutable.HashMap[String, String] ++ liveExecutors
@@ -63,7 +64,9 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /*
def executorLost(execId: String) {}
- def taskSetFailed(taskSet: TaskSet, reason: String) {}
+ def taskSetFailed(taskSet: TaskSet, reason: String) {
+ taskSetsFailed += taskSet.id
+ }
}
def removeExecutor(execId: String): Unit = executors -= execId
@@ -101,7 +104,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
assert(manager.resourceOffer("exec1", "host1", 2, PROCESS_LOCAL) === None)
// Tell it the task has finished
- manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0))
+ manager.handleSuccessfulTask(0, createTaskResult(0))
assert(sched.endedTasks(0) === Success)
assert(sched.finishedManagers.contains(manager))
}
@@ -125,14 +128,14 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None)
// Finish the first two tasks
- manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0))
- manager.statusUpdate(1, TaskState.FINISHED, createTaskResult(1))
+ manager.handleSuccessfulTask(0, createTaskResult(0))
+ manager.handleSuccessfulTask(1, createTaskResult(1))
assert(sched.endedTasks(0) === Success)
assert(sched.endedTasks(1) === Success)
assert(!sched.finishedManagers.contains(manager))
// Finish the last task
- manager.statusUpdate(2, TaskState.FINISHED, createTaskResult(2))
+ manager.handleSuccessfulTask(2, createTaskResult(2))
assert(sched.endedTasks(2) === Success)
assert(sched.finishedManagers.contains(manager))
}
@@ -253,6 +256,47 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None)
}
+ test("task result lost") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+ val taskSet = createTaskSet(1)
+ val clock = new FakeClock
+ val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+
+ // Tell it the task has finished but the result was lost.
+ manager.handleFailedTask(0, TaskState.FINISHED, Some(TaskResultLost))
+ assert(sched.endedTasks(0) === TaskResultLost)
+
+ // Re-offer the host -- now we should get task 0 again.
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+ }
+
+ test("repeated failures lead to task set abortion") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+ val taskSet = createTaskSet(1)
+ val clock = new FakeClock
+ val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+ // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted
+ // after the last failure.
+ (0 until manager.MAX_TASK_FAILURES).foreach { index =>
+ val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY)
+ assert(offerResult != None,
+ "Expect resource offer on iteration %s to return a task".format(index))
+ assert(offerResult.get.index === 0)
+ manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost))
+ if (index < manager.MAX_TASK_FAILURES) {
+ assert(!sched.taskSetsFailed.contains(taskSet.id))
+ } else {
+ assert(sched.taskSetsFailed.contains(taskSet.id))
+ }
+ }
+ }
+
+
/**
* Utility method to create a TaskSet, potentially setting a particular sequence of preferred
* locations for each task (given as varargs) if this sequence is not empty.
@@ -267,7 +311,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
new TaskSet(tasks, 0, 0, 0, null)
}
- def createTaskResult(id: Int): ByteBuffer = {
- ByteBuffer.wrap(Utils.serialize(new TaskResult[Int](id, mutable.Map.empty, new TaskMetrics)))
+ def createTaskResult(id: Int): DirectTaskResult[Int] = {
+ new DirectTaskResult[Int](id, mutable.Map.empty, new TaskMetrics)
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
new file mode 100644
index 0000000000..119ba30090
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.nio.ByteBuffer
+
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv}
+import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
+
+/**
+ * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
+ *
+ * Used to test the case where a BlockManager evicts the task result (or dies) before the
+ * TaskResult is retrieved.
+ */
+class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+ extends TaskResultGetter(sparkEnv, scheduler) {
+ var removedResult = false
+
+ override def enqueueSuccessfulTask(
+ taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) {
+ if (!removedResult) {
+ // Only remove the result once, since we'd like to test the case where the task eventually
+ // succeeds.
+ serializer.get().deserialize[TaskResult[_]](serializedData) match {
+ case IndirectTaskResult(blockId) =>
+ sparkEnv.blockManager.master.removeBlock(blockId)
+ case directResult: DirectTaskResult[_] =>
+ taskSetManager.abort("Internal error: expect only indirect results")
+ }
+ serializedData.rewind()
+ removedResult = true
+ }
+ super.enqueueSuccessfulTask(taskSetManager, tid, serializedData)
+ }
+}
+
+/**
+ * Tests related to handling task results (both direct and indirect).
+ */
+class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndAfterAll
+ with LocalSparkContext {
+
+ override def beforeAll {
+ // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small
+ // as we can make it) so the tests don't take too long.
+ System.setProperty("spark.akka.frameSize", "1")
+ }
+
+ before {
+ // Use local-cluster mode because results are returned differently when running with the
+ // LocalScheduler.
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+ }
+
+ override def afterAll {
+ System.clearProperty("spark.akka.frameSize")
+ }
+
+ test("handling results smaller than Akka frame size") {
+ val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
+ assert(result === 2)
+ }
+
+ test("handling results larger than Akka frame size") {
+ val akkaFrameSize =
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
+ assert(result === 1.to(akkaFrameSize).toArray)
+
+ val RESULT_BLOCK_ID = "taskresult_0"
+ assert(sc.env.blockManager.master.getLocations(RESULT_BLOCK_ID).size === 0,
+ "Expect result to be removed from the block manager.")
+ }
+
+ test("task retried if result missing from block manager") {
+ // If this test hangs, it's probably because no resource offers were made after the task
+ // failed.
+ val scheduler: ClusterScheduler = sc.taskScheduler match {
+ case clusterScheduler: ClusterScheduler =>
+ clusterScheduler
+ case _ =>
+ assert(false, "Expect local cluster to use ClusterScheduler")
+ throw new ClassCastException
+ }
+ scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)
+ val akkaFrameSize =
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
+ assert(result === 1.to(akkaFrameSize).toArray)
+
+ // Make sure two tasks were run (one failed one, and a second retried one).
+ assert(scheduler.nextTaskId.get() === 2)
+ }
+}
+
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index c611db0af4..30128ec45d 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -50,6 +50,7 @@ The command to launch the YARN Client is as follows:
--master-memory <MEMORY_FOR_MASTER> \
--worker-memory <MEMORY_PER_WORKER> \
--worker-cores <CORES_PER_WORKER> \
+ --name <application_name> \
--queue <queue_name>
For example:
diff --git a/ec2/README b/ec2/README
index 0add81312c..433da37b4c 100644
--- a/ec2/README
+++ b/ec2/README
@@ -1,4 +1,4 @@
This folder contains a script, spark-ec2, for launching Spark clusters on
Amazon EC2. Usage instructions are available online at:
-http://spark-project.org/docs/latest/ec2-scripts.html
+http://spark.incubator.apache.org/docs/latest/ec2-scripts.html
diff --git a/make-distribution.sh b/make-distribution.sh
index bffb19843c..32bbdb90a5 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -95,7 +95,7 @@ cp $FWDIR/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/jars/"
# Copy other things
mkdir "$DISTDIR"/conf
-cp "$FWDIR/conf/*.template" "$DISTDIR"/conf
+cp "$FWDIR"/conf/*.template "$DISTDIR"/conf
cp -r "$FWDIR/bin" "$DISTDIR"
cp -r "$FWDIR/python" "$DISTDIR"
cp "$FWDIR/spark-class" "$DISTDIR"
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 67d03f987f..19d3aa23ad 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -97,6 +97,9 @@ object SparkBuild extends Build {
// Only allow one test at a time, even across projects, since they run in the same JVM
concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
+ // also check the local Maven repository ~/.m2
+ resolvers ++= Seq(Resolver.file("Local Maven Repo", file(Path.userHome + "/.m2/repository"))),
+
// For Sonatype publishing
resolvers ++= Seq("sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
"sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/"),
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 3362010106..076dd3c9b0 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -106,7 +106,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
logInfo("Setting up application submission context for ASM")
val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
appContext.setApplicationId(appId)
- appContext.setApplicationName("Spark")
+ appContext.setApplicationName(args.appName)
return appContext
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index cd651904d2..c56dbd99ba 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -32,6 +32,7 @@ class ClientArguments(val args: Array[String]) {
var numWorkers = 2
var amQueue = System.getProperty("QUEUE", "default")
var amMemory: Int = 512
+ var appName: String = "Spark"
// TODO
var inputFormatInfo: List[InputFormatInfo] = null
@@ -78,6 +79,10 @@ class ClientArguments(val args: Array[String]) {
amQueue = value
args = tail
+ case ("--name") :: value :: tail =>
+ appName = value
+ args = tail
+
case Nil =>
if (userJar == null || userClass == null) {
printUsageAndExit(1)
@@ -108,6 +113,7 @@ class ClientArguments(val args: Array[String]) {
" --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
" --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
" --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
+ " --name NAME The name of your application (Default: Spark)\n" +
" --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')"
)
System.exit(exitCode)