author     Patrick Wendell <pwendell@gmail.com>   2013-10-14 22:25:47 -0700
committer  Patrick Wendell <pwendell@gmail.com>   2013-10-14 22:25:47 -0700
commit     e33b1839e27249e232a2126cec67a38109e03243 (patch)
tree       c25eb6810167f7a0a19fd210712ca4c209bfb48a
parent     3b11f43e36e2aca2346db7542c52fcbbeee70da2 (diff)
parent     9cd8786e4a6aebe29dabe802cc177e4338e140e6 (diff)
Merge pull request #29 from rxin/kill
Job killing

Moving https://github.com/mesos/spark/pull/935 here.

The high-level idea is to have an "interrupted" field in TaskContext, and a task should check that flag to determine whether its execution should continue. For convenience, I provide an InterruptibleIterator that wraps around a normal iterator but checks the interrupted flag. I also provide an InterruptibleRDD that wraps around an existing RDD.

As part of this pull request, I added an AsyncRDDActions class that provides a number of RDD actions returning a FutureJob (extending scala.concurrent.Future). The FutureJob can be used to kill the job execution or to wait until the job finishes.

This is NOT ready for merging yet. Remaining TODOs:
1. Add unit tests
2. Add job killing functionality for the local scheduler (the current job killing functionality only works in the cluster scheduler)

-------------

Update on Oct 10, 2013: This is ready!

Related future work:
- Figure out how to handle the job triggered by RangePartitioner (this one is tough; might become future work)
- Java API
- Python API
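For orientation before the diff itself, a minimal usage sketch of the cancellation API this patch introduces. It assumes a live SparkContext named sc and the rddToAsyncRDDActions implicit conversion added below (imported via SparkContext._); the data size is illustrative only.

import org.apache.spark.SparkContext._  // brings rddToAsyncRDDActions into scope

val rdd = sc.parallelize(1 to 1000000, 100)

// Submit the job without blocking the calling thread.
val future = rdd.countAsync()

// Either block for the result with future.get(), or kill the whole job:
future.cancel()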
-rw-r--r--  core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala  12
-rw-r--r--  core/src/main/scala/org/apache/spark/CacheManager.scala  4
-rw-r--r--  core/src/main/scala/org/apache/spark/FutureAction.scala  250
-rw-r--r--  core/src/main/scala/org/apache/spark/InterruptibleIterator.scala  30
-rw-r--r--  core/src/main/scala/org/apache/spark/ShuffleFetcher.scala  5
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala  37
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala  23
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskEndReason.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala  118
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala  18
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala  10
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala  122
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala  6
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala  64
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala (renamed from core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala)  12
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala  79
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala  16
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala  5
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala  95
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala  118
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala  24
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala  62
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Pool.scala  5
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala  44
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala  3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala  46
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala  63
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala  3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala  4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala  37
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala  81
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala  6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala  7
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala  190
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala  13
-rw-r--r--  core/src/test/scala/org/apache/spark/CacheManagerSuite.scala  9
-rw-r--r--  core/src/test/scala/org/apache/spark/CheckpointSuite.scala  4
-rw-r--r--  core/src/test/scala/org/apache/spark/JavaAPISuite.java  2
-rw-r--r--  core/src/test/scala/org/apache/spark/JobCancellationSuite.scala  177
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala  176
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala  2
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala  14
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala  5
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala  28
50 files changed, 1528 insertions(+), 515 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
index f8af6b0fbe..d9ed572da6 100644
--- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
@@ -28,7 +28,11 @@ import org.apache.spark.util.CompletionIterator
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
- override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer)
+ override def fetch[T](
+ shuffleId: Int,
+ reduceId: Int,
+ context: TaskContext,
+ serializer: Serializer)
: Iterator[T] =
{
@@ -73,7 +77,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val itr = blockFetcherItr.flatMap(unpackBlock)
- CompletionIterator[T, Iterator[T]](itr, {
+ val completionIter = CompletionIterator[T, Iterator[T]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
@@ -82,7 +86,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
- metrics.shuffleReadMetrics = Some(shuffleMetrics)
+ context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics)
})
+
+ new InterruptibleIterator[T](context, completionIter)
}
}
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 221bb68c61..519ecde50a 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -38,7 +38,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
blockManager.get(key) match {
case Some(values) =>
// Partition is already materialized, so just return its values
- return values.asInstanceOf[Iterator[T]]
+ return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
// Mark the split as loading (unless someone else marks it first)
@@ -56,7 +56,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
- return values.asInstanceOf[Iterator[T]]
+ return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
logInfo("Whoever was loading %s failed; we'll try it ourselves".format(key))
loading.add(key)
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
new file mode 100644
index 0000000000..1ad9240cfa
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.concurrent._
+import scala.concurrent.duration.Duration
+import scala.util.Try
+
+import org.apache.spark.scheduler.{JobSucceeded, JobWaiter}
+import org.apache.spark.scheduler.JobFailed
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * A future for the result of an action. This is an extension of the Scala Future interface to
+ * support cancellation.
+ */
+trait FutureAction[T] extends Future[T] {
+ // Note that we redefine methods of the Future trait here explicitly so we can provide
+ // different documentation (with reference to the word "action").
+
+ /**
+ * Cancels the execution of this action.
+ */
+ def cancel()
+
+ /**
+ * Blocks until this action completes.
+ * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
+ * for unbounded waiting, or a finite positive duration
+ * @return this FutureAction
+ */
+ override def ready(atMost: Duration)(implicit permit: CanAwait): FutureAction.this.type
+
+ /**
+ * Awaits and returns the result (of type T) of this action.
+ * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
+ * for unbounded waiting, or a finite positive duration
+ * @throws Exception exception during action execution
+ * @return the result value if the action is completed within the specific maximum wait time
+ */
+ @throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T
+
+ /**
+ * When this action is completed, either through an exception, or a value, applies the provided
+ * function.
+ */
+ def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext)
+
+ /**
+ * Returns whether the action has already been completed with a value or an exception.
+ */
+ override def isCompleted: Boolean
+
+ /**
+ * The value of this Future.
+ *
+ * If the future is not completed the returned value will be None. If the future is completed
+ * the value will be Some(Success(t)) if it contains a valid result, or Some(Failure(error)) if
+ * it contains an exception.
+ */
+ override def value: Option[Try[T]]
+
+ /**
+ * Blocks and returns the result of this job.
+ */
+ @throws(classOf[Exception])
+ def get(): T = Await.result(this, Duration.Inf)
+}
+
+
+/**
+ * The future holding the result of an action that triggers a single job. Examples include
+ * count, collect, reduce.
+ */
+class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
+ extends FutureAction[T] {
+
+ override def cancel() {
+ jobWaiter.cancel()
+ }
+
+ override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
+ if (!atMost.isFinite()) {
+ awaitResult()
+ } else {
+ val finishTime = System.currentTimeMillis() + atMost.toMillis
+ while (!isCompleted) {
+ val time = System.currentTimeMillis()
+ if (time >= finishTime) {
+ throw new TimeoutException
+ } else {
+ jobWaiter.wait(finishTime - time)
+ }
+ }
+ }
+ this
+ }
+
+ @throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T = {
+ ready(atMost)(permit)
+ awaitResult() match {
+ case scala.util.Success(res) => res
+ case scala.util.Failure(e) => throw e
+ }
+ }
+
+ override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
+ executor.execute(new Runnable {
+ override def run() {
+ func(awaitResult())
+ }
+ })
+ }
+
+ override def isCompleted: Boolean = jobWaiter.jobFinished
+
+ override def value: Option[Try[T]] = {
+ if (jobWaiter.jobFinished) {
+ Some(awaitResult())
+ } else {
+ None
+ }
+ }
+
+ private def awaitResult(): Try[T] = {
+ jobWaiter.awaitResult() match {
+ case JobSucceeded => scala.util.Success(resultFunc)
+ case JobFailed(e: Exception, _) => scala.util.Failure(e)
+ }
+ }
+}
+
+
+/**
+ * A FutureAction for actions that could trigger multiple Spark jobs. Examples include take,
+ * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
+ * action thread if it is being blocked by a job.
+ */
+class ComplexFutureAction[T] extends FutureAction[T] {
+
+ // Pointer to the thread that is executing the action. It is set when the action is run.
+ @volatile private var thread: Thread = _
+
+ // A flag indicating whether the future has been cancelled. This is used in case the future
+ // is cancelled before the action was even run (and thus we have no thread to interrupt).
+ @volatile private var _cancelled: Boolean = false
+
+ // A promise used to signal the future.
+ private val p = promise[T]()
+
+ override def cancel(): Unit = this.synchronized {
+ _cancelled = true
+ if (thread != null) {
+ thread.interrupt()
+ }
+ }
+
+ /**
+ * Executes some action enclosed in the closure. To properly enable cancellation, the closure
+ * should use the runJob implementation in this promise. See takeAsync for an example.
+ */
+ def run(func: => T)(implicit executor: ExecutionContext): this.type = {
+ scala.concurrent.future {
+ thread = Thread.currentThread
+ try {
+ p.success(func)
+ } catch {
+ case e: Exception => p.failure(e)
+ } finally {
+ thread = null
+ }
+ }
+ this
+ }
+
+ /**
+ * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
+ * to enable cancellation.
+ */
+ def runJob[T, U, R](
+ rdd: RDD[T],
+ processPartition: Iterator[T] => U,
+ partitions: Seq[Int],
+ resultHandler: (Int, U) => Unit,
+ resultFunc: => R) {
+ // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
+ // command need to be in an atomic block.
+ val job = this.synchronized {
+ if (!cancelled) {
+ rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
+ } else {
+ throw new SparkException("Action has been cancelled")
+ }
+ }
+
+ // Wait for the job to complete. If the action is cancelled (with an interrupt),
+ // cancel the job and stop the execution. This is not in a synchronized block because
+ // Await.ready eventually waits on the monitor in SimpleFutureAction.jobWaiter.
+ try {
+ Await.ready(job, Duration.Inf)
+ } catch {
+ case e: InterruptedException =>
+ job.cancel()
+ throw new SparkException("Action has been cancelled")
+ }
+ }
+
+ /**
+ * Returns whether the promise has been cancelled.
+ */
+ def cancelled: Boolean = _cancelled
+
+ @throws(classOf[InterruptedException])
+ @throws(classOf[scala.concurrent.TimeoutException])
+ override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = {
+ p.future.ready(atMost)(permit)
+ this
+ }
+
+ @throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T = {
+ p.future.result(atMost)(permit)
+ }
+
+ override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = {
+ p.future.onComplete(func)(executor)
+ }
+
+ override def isCompleted: Boolean = p.isCompleted
+
+ override def value: Option[Try[T]] = p.future.value
+}
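Since FutureAction extends scala.concurrent.Future, the standard Future machinery applies to it unchanged. A brief sketch, assuming a SparkContext named sc and the SparkContext._ implicits; the printed messages are illustrative only:

import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success}

val f = sc.parallelize(1 to 100).countAsync()
f.onComplete {
  case Success(count) => println("counted " + count + " elements")
  case Failure(e)     => println("job failed or was cancelled: " + e)
}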
diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
new file mode 100644
index 0000000000..56e0b8d2c0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * An iterator that wraps around an existing iterator to provide task killing functionality.
+ * It works by checking the interrupted flag in TaskContext.
+ */
+class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
+ extends Iterator[T] {
+
+ def hasNext: Boolean = !context.interrupted && delegate.hasNext
+
+ def next(): T = delegate.next()
+}
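To make the wrapper's contract concrete, a standalone sketch. In real tasks the interrupted flag is set by the executor's kill path rather than by user code; constructing a TaskContext by hand here is purely illustrative and follows the constructor shown later in this diff:

import org.apache.spark.{InterruptibleIterator, TaskContext}

val context = new TaskContext(0, 0, 0L)  // stageId, partitionId, attemptId
val it = new InterruptibleIterator(context, Iterator.from(1))

println(it.next())     // 1 -- the interrupted flag is only consulted in hasNext
context.interrupted = true
println(it.hasNext)    // false, even though the delegate iterator is infinite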
diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
index 307c383a89..a85aa50a9b 100644
--- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
@@ -27,7 +27,10 @@ private[spark] abstract class ShuffleFetcher {
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
- def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
+ def fetch[T](
+ shuffleId: Int,
+ reduceId: Int,
+ context: TaskContext,
serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
/** Stop the fetcher */
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index b7a8179786..f674cb397f 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -755,10 +755,11 @@ class SparkContext(
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
val callSite = Utils.formatSparkCallSite
+ val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler,
- localProperties.get)
+ val result = dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
+ resultHandler, localProperties.get)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
@@ -848,6 +849,36 @@ class SparkContext(
}
/**
+ * Submit a job for execution and return a FutureAction holding the result.
+ */
+ def submitJob[T, U, R](
+ rdd: RDD[T],
+ processPartition: Iterator[T] => U,
+ partitions: Seq[Int],
+ resultHandler: (Int, U) => Unit,
+ resultFunc: => R): SimpleFutureAction[R] =
+ {
+ val cleanF = clean(processPartition)
+ val callSite = Utils.formatSparkCallSite
+ val waiter = dagScheduler.submitJob(
+ rdd,
+ (context: TaskContext, iter: Iterator[T]) => cleanF(iter),
+ partitions,
+ callSite,
+ allowLocal = false,
+ resultHandler,
+ null)
+ new SimpleFutureAction(waiter, resultFunc)
+ }
+
+ /**
+ * Cancel all jobs that have been scheduled or are running.
+ */
+ def cancelAllJobs() {
+ dagScheduler.cancelAllJobs()
+ }
+
+ /**
* Clean a closure to make it ready to be serialized and sent to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
*/
@@ -930,6 +961,8 @@ object SparkContext {
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
+ implicit def rddToAsyncRDDActions[T: ClassManifest](rdd: RDD[T]) = new AsyncRDDActions(rdd)
+
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index c2c358c7ad..51584d686d 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -17,21 +17,30 @@
package org.apache.spark
-import executor.TaskMetrics
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.executor.TaskMetrics
+
class TaskContext(
- val stageId: Int,
- val splitId: Int,
+ private[spark] val stageId: Int,
+ val partitionId: Int,
val attemptId: Long,
val runningLocally: Boolean = false,
- val taskMetrics: TaskMetrics = TaskMetrics.empty()
+ @volatile var interrupted: Boolean = false,
+ private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty()
) extends Serializable {
- @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+ @deprecated("use partitionId", "0.8.1")
+ def splitId = partitionId
+
+ // List of callback functions to execute when the task completes.
+ @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
- // Add a callback function to be executed on task completion. An example use
- // is for HadoopRDD to register a callback to close the input stream.
+ /**
+ * Add a callback function to be executed on task completion. An example use
+ * is for HadoopRDD to register a callback to close the input stream.
+ * @param f Callback function.
+ */
def addOnCompleteCallback(f: () => Unit) {
onCompleteCallbacks += f
}
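As a quick illustration of the completion-callback hook and the new partitionId field, a sketch using the mapPartitionsWithContext method added to RDD later in this patch (assumes an RDD[Int] named rdd; the println stands in for real cleanup such as closing an input stream):

rdd.mapPartitionsWithContext { (context, iter) =>
  context.addOnCompleteCallback { () =>
    println("partition " + context.partitionId + " finished")
  }
  iter.map(_ * 2)
}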
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 8466c2a004..c1e5e04b31 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -52,4 +52,6 @@ private[spark] case class ExceptionFailure(
*/
private[spark] case object TaskResultLost extends TaskEndReason
+private[spark] case object TaskKilled extends TaskEndReason
+
private[spark] case class OtherFailure(message: String) extends TaskEndReason
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index eff0c0f274..20323ea038 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -36,7 +36,8 @@ import org.apache.spark.util.Utils
private[spark] class Executor(
executorId: String,
slaveHostname: String,
- properties: Seq[(String, String)])
+ properties: Seq[(String, String)],
+ isLocal: Boolean = false)
extends Logging
{
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -101,18 +102,47 @@ private[spark] class Executor(
val executorSource = new ExecutorSource(this, executorId)
// Initialize Spark environment (using system properties read above)
- val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
- SparkEnv.set(env)
- env.metricsSystem.registerSource(executorSource)
+ private val env = {
+ if (!isLocal) {
+ val _env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
+ isDriver = false, isLocal = false)
+ SparkEnv.set(_env)
+ _env.metricsSystem.registerSource(executorSource)
+ _env
+ } else {
+ SparkEnv.get
+ }
+ }
- private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
+ // Akka's message frame size. If the task result is bigger than this, we use the block manager
+ // to send the result back.
+ private val akkaFrameSize = {
+ env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
+ }
// Start worker thread pool
val threadPool = new ThreadPoolExecutor(
- 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
+ 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable], Utils.daemonThreadFactory)
+
+ // Maintains the list of running tasks.
+ private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
- threadPool.execute(new TaskRunner(context, taskId, serializedTask))
+ val tr = new TaskRunner(context, taskId, serializedTask)
+ runningTasks.put(taskId, tr)
+ threadPool.execute(tr)
+ }
+
+ def killTask(taskId: Long) {
+ val tr = runningTasks.get(taskId)
+ if (tr != null) {
+ tr.kill()
+ // We remove the task also in the finally block in TaskRunner.run.
+ // The reason we need to remove it here is that killTask might be called before the task
+ // is even launched, in which case we would never reach that finally block. ConcurrentHashMap's
+ // remove is idempotent.
+ runningTasks.remove(taskId)
+ }
}
/** Get the Yarn approved local directories. */
@@ -124,49 +154,80 @@ private[spark] class Executor(
.getOrElse(Option(System.getenv("LOCAL_DIRS"))
.getOrElse(""))
- if (localDirs.isEmpty()) {
+ if (localDirs.isEmpty) {
throw new Exception("Yarn Local dirs can't be empty")
}
- return localDirs
+ localDirs
}
- class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
+ class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
extends Runnable {
+ @volatile private var killed = false
+ @volatile private var task: Task[Any] = _
+
+ def kill() {
+ logInfo("Executor is trying to kill task " + taskId)
+ killed = true
+ if (task != null) {
+ task.kill()
+ }
+ }
+
override def run() {
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
- context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
+ execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var attemptedTask: Option[Task[Any]] = None
var taskStart: Long = 0
- def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum
- val startGCTime = getTotalGCTime
+ def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
+ val startGCTime = gcTime
try {
SparkEnv.set(env)
Accumulators.clear()
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
- val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+ task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+
+ // If this task has been killed before we deserialized it, let's quit now. Otherwise,
+ // continue executing the task.
+ if (killed) {
+ logInfo("Executor killed task " + taskId)
+ execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+ return
+ }
+
attemptedTask = Some(task)
- logInfo("Its epoch is " + task.epoch)
+ logDebug("Task " + taskId +"'s epoch is " + task.epoch)
env.mapOutputTracker.updateEpoch(task.epoch)
+
+ // Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
val value = task.run(taskId.toInt)
val taskFinish = System.currentTimeMillis()
+
+ // If the task has been killed, let's fail it.
+ if (task.killed) {
+ logInfo("Executor killed task " + taskId)
+ execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+ return
+ }
+
for (m <- task.metrics) {
- m.hostname = Utils.localHostName
+ m.hostname = Utils.localHostName()
m.executorDeserializeTime = (taskStart - startTime).toInt
m.executorRunTime = (taskFinish - taskStart).toInt
- m.jvmGCTime = getTotalGCTime - startGCTime
+ m.jvmGCTime = gcTime - startGCTime
}
- //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c
- // we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could
- // just change the relevants bytes in the byte buffer
+ // TODO I'd also like to track the time it takes to serialize the task results, but that is a
+ // huge headache, b/c we need to serialize the task metrics first. If TaskMetrics had a
+ // custom serialized format, we could just change the relevant bytes in the byte buffer.
val accumUpdates = Accumulators.values
+
val directResult = new DirectTaskResult(value, accumUpdates, task.metrics.getOrElse(null))
val serializedDirectResult = ser.serialize(directResult)
logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
@@ -182,12 +243,13 @@ private[spark] class Executor(
serializedDirectResult
}
}
- context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
+
+ execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
logInfo("Finished task ID " + taskId)
} catch {
case ffe: FetchFailedException => {
val reason = ffe.toTaskEndReason
- context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+ execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}
case t: Throwable => {
@@ -195,10 +257,10 @@ private[spark] class Executor(
val metrics = attemptedTask.flatMap(t => t.metrics)
for (m <- metrics) {
m.executorRunTime = serviceTime
- m.jvmGCTime = getTotalGCTime - startGCTime
+ m.jvmGCTime = gcTime - startGCTime
}
val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
- context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+ execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// TODO: Should we exit the whole executor here? On the one hand, the failed task may
// have left some weird state around depending on when the exception was thrown, but on
@@ -206,6 +268,8 @@ private[spark] class Executor(
logError("Exception in task ID " + taskId, t)
//System.exit(1)
}
+ } finally {
+ runningTasks.remove(taskId)
}
}
}
@@ -215,7 +279,7 @@ private[spark] class Executor(
* created by the interpreter to the search path
*/
private def createClassLoader(): ExecutorURLClassLoader = {
- var loader = this.getClass.getClassLoader
+ val loader = this.getClass.getClassLoader
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
@@ -237,7 +301,7 @@ private[spark] class Executor(
val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
- return constructor.newInstance(classUri, parent)
+ constructor.newInstance(classUri, parent)
} catch {
case _: ClassNotFoundException =>
logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!")
@@ -245,7 +309,7 @@ private[spark] class Executor(
null
}
} else {
- return parent
+ parent
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index da62091980..b56d8c9912 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -18,14 +18,18 @@
package org.apache.spark.executor
import java.nio.ByteBuffer
-import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
-import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _}
-import org.apache.spark.TaskState.TaskState
+
import com.google.protobuf.ByteString
-import org.apache.spark.{Logging}
+
+import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
+import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
+
+import org.apache.spark.Logging
import org.apache.spark.TaskState
+import org.apache.spark.TaskState.TaskState
import org.apache.spark.util.Utils
+
private[spark] class MesosExecutorBackend
extends MesosExecutor
with ExecutorBackend
@@ -71,7 +75,11 @@ private[spark] class MesosExecutorBackend
}
override def killTask(d: ExecutorDriver, t: TaskID) {
- logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)")
+ if (executor == null) {
+ logError("Received KillTask but executor was null")
+ } else {
+ executor.killTask(t.getValue.toLong)
+ }
}
override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {}
diff --git a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala
index 7839023868..db0bea0472 100644
--- a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala
@@ -63,12 +63,20 @@ private[spark] class StandaloneExecutorBackend(
case LaunchTask(taskDesc) =>
logInfo("Got assigned task " + taskDesc.taskId)
if (executor == null) {
- logError("Received launchTask but executor was null")
+ logError("Received LaunchTask command but executor was null")
System.exit(1)
} else {
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
}
+ case KillTask(taskId, _) =>
+ if (executor == null) {
+ logError("Received KillTask command but executor was null")
+ System.exit(1)
+ } else {
+ executor.killTask(taskId)
+ }
+
case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
logError("Driver terminated or disconnected! Shutting down.")
System.exit(1)
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
new file mode 100644
index 0000000000..faaf837be0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
+
+/**
+ * A set of asynchronous RDD actions available through an implicit conversion.
+ * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
+ */
+class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with Logging {
+
+ /**
+ * Returns a future for counting the number of elements in the RDD.
+ */
+ def countAsync(): FutureAction[Long] = {
+ val totalCount = new AtomicLong
+ self.context.submitJob(
+ self,
+ (iter: Iterator[T]) => {
+ var result = 0L
+ while (iter.hasNext) {
+ result += 1L
+ iter.next()
+ }
+ result
+ },
+ Range(0, self.partitions.size),
+ (index: Int, data: Long) => totalCount.addAndGet(data),
+ totalCount.get())
+ }
+
+ /**
+ * Returns a future for retrieving all elements of this RDD.
+ */
+ def collectAsync(): FutureAction[Seq[T]] = {
+ val results = new Array[Array[T]](self.partitions.size)
+ self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size),
+ (index, data) => results(index) = data, results.flatten.toSeq)
+ }
+
+ /**
+ * Returns a future for retrieving the first num elements of the RDD.
+ */
+ def takeAsync(num: Int): FutureAction[Seq[T]] = {
+ val f = new ComplexFutureAction[Seq[T]]
+
+ f.run {
+ val results = new ArrayBuffer[T](num)
+ val totalParts = self.partitions.length
+ var partsScanned = 0
+ while (results.size < num && partsScanned < totalParts) {
+ // The number of partitions to try in this iteration. It is ok for this number to be
+ // greater than totalParts because we actually cap it at totalParts in runJob.
+ var numPartsToTry = 1
+ if (partsScanned > 0) {
+ // If we didn't find any rows after the first iteration, just try all partitions next.
+ // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+ // by 50%.
+ if (results.size == 0) {
+ numPartsToTry = totalParts - 1
+ } else {
+ numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
+ }
+ }
+ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
+
+ val left = num - results.size
+ val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+
+ val buf = new Array[Array[T]](p.size)
+ f.runJob(self,
+ (it: Iterator[T]) => it.take(left).toArray,
+ p,
+ (index: Int, data: Array[T]) => buf(index) = data,
+ Unit)
+
+ buf.foreach(results ++= _.take(num - results.size))
+ partsScanned += numPartsToTry
+ }
+ results.toSeq
+ }
+
+ f
+ }
+
+ /**
+ * Applies a function f to all elements of this RDD.
+ */
+ def foreachAsync(f: T => Unit): FutureAction[Unit] = {
+ self.context.submitJob[T, Unit, Unit](self, _.foreach(f), Range(0, self.partitions.size),
+ (index, data) => Unit, Unit)
+ }
+
+ /**
+ * Applies a function f to each partition of this RDD.
+ */
+ def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = {
+ self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size),
+ (index, data) => Unit, Unit)
+ }
+}
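An end-to-end sketch of killing one of these async actions from a second thread, in the spirit of the new JobCancellationSuite (assumes sc and the SparkContext._ implicits; the sleeps are illustrative pacing only):

import org.apache.spark.SparkContext._

val f = sc.parallelize(1 to 10000, 100).map { x => Thread.sleep(10); x }.collectAsync()

new Thread() {
  override def run() {
    Thread.sleep(100)  // give the job a moment to start, then kill it
    f.cancel()
  }
}.start()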
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 3311757189..ccaaecb85b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -85,7 +85,7 @@ private[spark] object CheckpointRDD extends Logging {
val outputDir = new Path(path)
val fs = outputDir.getFileSystem(env.hadoop.newConfiguration())
- val finalOutputName = splitIdToFile(ctx.splitId)
+ val finalOutputName = splitIdToFile(ctx.partitionId)
val finalOutputPath = new Path(outputDir, finalOutputName)
val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index d237797aa6..911a002884 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -21,7 +21,7 @@ import java.io.{ObjectOutputStream, IOException}
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.util.AppendOnlyMap
@@ -125,12 +125,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
- fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
+ fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach {
kv => getSeq(kv._1)(depNum) += kv._2
}
}
}
- map.iterator
+ new InterruptibleIterator(context, map.iterator)
}
override def clearDependencies() {
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 2d394abfd9..fad042c7ae 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -27,8 +27,7 @@ import org.apache.hadoop.mapred.RecordReader
import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
-import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv,
- TaskContext}
+import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.util.NextIterator
import org.apache.hadoop.conf.{Configuration, Configurable}
@@ -39,7 +38,7 @@ import org.apache.hadoop.conf.{Configuration, Configurable}
*/
private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit)
extends Partition {
-
+
val inputSplit = new SerializableWritable[InputSplit](s)
override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
@@ -144,38 +143,41 @@ class HadoopRDD[K, V](
array
}
- override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] {
- val split = theSplit.asInstanceOf[HadoopPartition]
- logInfo("Input split: " + split.inputSplit)
- var reader: RecordReader[K, V] = null
-
- val jobConf = getJobConf()
- val inputFormat = getInputFormat(jobConf)
- reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
-
- // Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback{ () => closeIfNeeded() }
-
- val key: K = reader.createKey()
- val value: V = reader.createValue()
-
- override def getNext() = {
- try {
- finished = !reader.next(key, value)
- } catch {
- case eof: EOFException =>
- finished = true
+ override def compute(theSplit: Partition, context: TaskContext) = {
+ val iter = new NextIterator[(K, V)] {
+ val split = theSplit.asInstanceOf[HadoopPartition]
+ logInfo("Input split: " + split.inputSplit)
+ var reader: RecordReader[K, V] = null
+
+ val jobConf = getJobConf()
+ val inputFormat = getInputFormat(jobConf)
+ reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback{ () => closeIfNeeded() }
+
+ val key: K = reader.createKey()
+ val value: V = reader.createValue()
+
+ override def getNext() = {
+ try {
+ finished = !reader.next(key, value)
+ } catch {
+ case eof: EOFException =>
+ finished = true
+ }
+ (key, value)
}
- (key, value)
- }
- override def close() {
- try {
- reader.close()
- } catch {
- case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ override def close() {
+ try {
+ reader.close()
+ } catch {
+ case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ }
}
}
+ new InterruptibleIterator[(K, V)](context, iter)
}
override def getPreferredLocations(split: Partition): Seq[String] = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
index 3ed8339010..aea08ff81b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
@@ -21,14 +21,14 @@ import org.apache.spark.{Partition, TaskContext}
/**
- * A variant of the MapPartitionsRDD that passes the partition index into the
- * closure. This can be used to generate or collect partition specific
- * information such as the number of tuples in a partition.
+ * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
+ * TaskContext, the closure can either get access to the interrupted flag or get the index
+ * of the partition in the RDD.
*/
private[spark]
-class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest](
+class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
- f: (Int, Iterator[T]) => Iterator[U],
+ f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean
) extends RDD[U](prev) {
@@ -37,5 +37,5 @@ class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest](
override val partitioner = if (preservesPartitioning) prev.partitioner else None
override def compute(split: Partition, context: TaskContext) =
- f(split.index, firstParent[T].iterator(split, context))
+ f(context, firstParent[T].iterator(split, context))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 7b3a89f7e0..2662d48c84 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
-import org.apache.spark.{Dependency, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
private[spark]
@@ -71,49 +71,52 @@ class NewHadoopRDD[K, V](
result
}
- override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
- val split = theSplit.asInstanceOf[NewHadoopPartition]
- logInfo("Input split: " + split.serializableHadoopSplit)
- val conf = confBroadcast.value.value
- val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
- val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
- val format = inputFormatClass.newInstance
- if (format.isInstanceOf[Configurable]) {
- format.asInstanceOf[Configurable].setConf(conf)
- }
- val reader = format.createRecordReader(
- split.serializableHadoopSplit.value, hadoopAttemptContext)
- reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
-
- // Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback(() => close())
-
- var havePair = false
- var finished = false
-
- override def hasNext: Boolean = {
- if (!finished && !havePair) {
- finished = !reader.nextKeyValue
- havePair = !finished
+ override def compute(theSplit: Partition, context: TaskContext) = {
+ val iter = new Iterator[(K, V)] {
+ val split = theSplit.asInstanceOf[NewHadoopPartition]
+ logInfo("Input split: " + split.serializableHadoopSplit)
+ val conf = confBroadcast.value.value
+ val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
+ val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
+ val format = inputFormatClass.newInstance
+ if (format.isInstanceOf[Configurable]) {
+ format.asInstanceOf[Configurable].setConf(conf)
+ }
+ val reader = format.createRecordReader(
+ split.serializableHadoopSplit.value, hadoopAttemptContext)
+ reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback(() => close())
+
+ var havePair = false
+ var finished = false
+
+ override def hasNext: Boolean = {
+ if (!finished && !havePair) {
+ finished = !reader.nextKeyValue
+ havePair = !finished
+ }
+ !finished
}
- !finished
- }
- override def next: (K, V) = {
- if (!hasNext) {
- throw new java.util.NoSuchElementException("End of stream")
+ override def next(): (K, V) = {
+ if (!hasNext) {
+ throw new java.util.NoSuchElementException("End of stream")
+ }
+ havePair = false
+ (reader.getCurrentKey, reader.getCurrentValue)
}
- havePair = false
- return (reader.getCurrentKey, reader.getCurrentValue)
- }
- private def close() {
- try {
- reader.close()
- } catch {
- case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ private def close() {
+ try {
+ reader.close()
+ } catch {
+ case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ }
}
}
+ new InterruptibleIterator(context, iter)
}
override def getPreferredLocations(split: Partition): Seq[String] = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index a47c512275..93b78e1232 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -84,18 +84,24 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
}
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
- self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+ self.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
+ }, preservesPartitioning = true)
} else if (mapSideCombine) {
val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
.setSerializer(serializerClass)
- partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true)
+ partitioned.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter))
+ }, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
- values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+ values.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
+ }, preservesPartitioning = true)
}
}
@@ -564,7 +570,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
- val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber)
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.partitionId, attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outputFormatClass.newInstance
val committer = format.getOutputCommitter(hadoopContext)
@@ -663,7 +669,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
- writer.setup(context.stageId, context.splitId, attemptNumber)
+ writer.setup(context.stageId, context.partitionId, attemptNumber)
writer.open()
var count = 0
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 6dbd4309aa..cd96250389 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -94,8 +94,9 @@ private[spark] class ParallelCollectionRDD[T: ClassManifest](
slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}
- override def compute(s: Partition, context: TaskContext) =
- s.asInstanceOf[ParallelCollectionPartition[T]].iterator
+ override def compute(s: Partition, context: TaskContext) = {
+ new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
+ }
override def getPreferredLocations(s: Partition): Seq[String] = {
locationPrefs.getOrElse(s.index, Nil)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 1893627ee2..0355618e43 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -418,26 +418,39 @@ abstract class RDD[T: ClassManifest](
command: Seq[String],
env: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
- printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
+ printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = {
new PipedRDD(this, command, env,
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
if (printRDDElement ne null) sc.clean(printRDDElement) else null)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
+ def mapPartitions[U: ClassManifest](
+ f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
def mapPartitionsWithIndex[U: ClassManifest](
- f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
- new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+ f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
+ val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
+ new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
+ }
+
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD. This is a variant of
+ * mapPartitions that also passes the TaskContext into the closure.
+ */
+ def mapPartitionsWithContext[U: ClassManifest](
+ f: (TaskContext, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] = {
+ new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
@@ -445,22 +458,23 @@ abstract class RDD[T: ClassManifest](
*/
@deprecated("use mapPartitionsWithIndex", "0.7.0")
def mapPartitionsWithSplit[U: ClassManifest](
- f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
- new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+ f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
+ mapPartitionsWithIndex(f, preservesPartitioning)
+ }
/**
* Maps f over this RDD, where f takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
- (f:(T, A) => U): RDD[U] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(index)
- iter.map(t => f(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+ def mapWith[A: ClassManifest, U: ClassManifest]
+ (constructA: Int => A, preservesPartitioning: Boolean = false)
+ (f: (T, A) => U): RDD[U] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
+ val a = constructA(context.partitionId)
+ iter.map(t => f(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
@@ -468,13 +482,14 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
- (f:(T, A) => Seq[U]): RDD[U] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(index)
- iter.flatMap(t => f(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+ def flatMapWith[A: ClassManifest, U: ClassManifest]
+ (constructA: Int => A, preservesPartitioning: Boolean = false)
+ (f: (T, A) => Seq[U]): RDD[U] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
+ val a = constructA(context.partitionId)
+ iter.flatMap(t => f(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
@@ -482,13 +497,12 @@ abstract class RDD[T: ClassManifest](
* This additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def foreachWith[A: ClassManifest](constructA: Int => A)
- (f:(T, A) => Unit) {
- def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(index)
- iter.map(t => {f(t, a); t})
- }
- (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
+ def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
+ val a = constructA(context.partitionId)
+ iter.map(t => {f(t, a); t})
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
}
/**
@@ -496,13 +510,12 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def filterWith[A: ClassManifest](constructA: Int => A)
- (p:(T, A) => Boolean): RDD[T] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(index)
- iter.filter(t => p(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
+ def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
+ val a = constructA(context.partitionId)
+ iter.filter(t => p(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
}
/**
@@ -541,16 +554,14 @@ abstract class RDD[T: ClassManifest](
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
- val cleanF = sc.clean(f)
- sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
+ sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
}
/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
- val cleanF = sc.clean(f)
- sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
+ sc.runJob(this, (iter: Iterator[T]) => f(iter))
}
/**
@@ -675,6 +686,8 @@ abstract class RDD[T: ClassManifest](
*/
def count(): Long = {
sc.runJob(this, (iter: Iterator[T]) => {
+ // Use a while loop to count the number of elements rather than iter.size because
+ // iter.size uses a for loop, which is slightly slower in the current version of Scala.
var result = 0L
while (iter.hasNext) {
result += 1L
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 9537152335..a5d751a7bd 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -56,7 +56,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest](
override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics,
+ SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
SparkEnv.get.serializerManager.get(serializerClass))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 8c1a29dfff..7af4d803e7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -108,7 +108,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
}
case ShuffleCoGroupSplitDep(shuffleId) => {
val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
- context.taskMetrics, serializer)
+ context, serializer)
iter.foreach(op)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5c51852985..7fb614402b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -41,11 +41,11 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
* locations to run each task on, based on the current cache status, and passes these to the
* low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being
* lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are
- * not caused by shuffie file loss are handled by the TaskScheduler, which will retry each task
+ * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task
* a small number of times before cancelling the whole stage.
*
* THREADING: This class runs all its logic in a single thread executing the run() method, to which
- * events are submitted using a synchonized queue (eventQueue). The public API methods, such as
+ * events are submitted using a synchronized queue (eventQueue). The public API methods, such as
* runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods
* should be private.
*/
@@ -88,7 +88,8 @@ class DAGScheduler(
eventQueue.put(ExecutorGained(execId, host))
}
- // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
+ // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
+ // cancellation of the job itself.
override def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
}
@@ -104,13 +105,15 @@ class DAGScheduler(
private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
- val nextJobId = new AtomicInteger(0)
+ private[scheduler] val nextJobId = new AtomicInteger(0)
- val nextStageId = new AtomicInteger(0)
+ def numTotalJobs: Int = nextJobId.get()
- val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+ private val nextStageId = new AtomicInteger(0)
- val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
+ private val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+
+ private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
@@ -127,6 +130,7 @@ class DAGScheduler(
// stray messages to detect.
val failedEpoch = new HashMap[String, Long]
+ // Map from job id to the active job
val idToActiveJob = new HashMap[Int, ActiveJob]
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
@@ -261,32 +265,41 @@ class DAGScheduler(
}
/**
- * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
- * JobWaiter whose getResult() method will return the result of the job when it is complete.
- *
- * The job is assumed to have at least one partition; zero partition jobs should be handled
- * without a JobSubmitted event.
+ * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
+ * can be used to block until the job finishes executing or to cancel the job.
*/
- private[scheduler] def prepareJob[T, U: ClassManifest](
- finalRdd: RDD[T],
+ def submitJob[T, U](
+ rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
resultHandler: (Int, U) => Unit,
- properties: Properties = null)
- : (JobSubmitted, JobWaiter[U]) =
+ properties: Properties = null): JobWaiter[U] =
{
+ val jobId = nextJobId.getAndIncrement()
+ if (partitions.size == 0) {
+ return new JobWaiter[U](this, jobId, 0, resultHandler)
+ }
+
+ // Check to make sure we are not launching a task on a partition that does not exist.
+ val maxPartitions = rdd.partitions.length
+ partitions.find(p => p >= maxPartitions).foreach { p =>
+ throw new IllegalArgumentException(
+ "Attempting to access a non-existent partition: " + p + ". " +
+ "Total number of partitions: " + maxPartitions)
+ }
+
assert(partitions.size > 0)
- val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
- properties)
- (toSubmit, waiter)
+ val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
+ eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite,
+ waiter, properties))
+ waiter
}
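
A minimal sketch of driving submitJob directly (runJob below is the blocking wrapper around it); rdd is assumed to be an RDD[Int] and dagScheduler a reference to this scheduler:

    // Hypothetical sketch: submit asynchronously, then cancel instead of waiting.
    val waiter = dagScheduler.submitJob(
      rdd,
      (context: TaskContext, iter: Iterator[Int]) => iter.sum,   // func
      0 until rdd.partitions.size,                               // partitions
      "example call site",                                       // callSite
      false,                                                     // allowLocal
      (index: Int, partialSum: Int) => println(index + " -> " + partialSum))
    waiter.cancel()        // asks the DAGScheduler to cancel the job asynchronously
    waiter.awaitResult()   // returns JobFailed once the underlying tasks are killed
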
def runJob[T, U: ClassManifest](
- finalRdd: RDD[T],
+ rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
@@ -294,21 +307,7 @@ class DAGScheduler(
resultHandler: (Int, U) => Unit,
properties: Properties = null)
{
- if (partitions.size == 0) {
- return
- }
-
- // Check to make sure we are not launching a task on a partition that does not exist.
- val maxPartitions = finalRdd.partitions.length
- partitions.find(p => p >= maxPartitions).foreach { p =>
- throw new IllegalArgumentException(
- "Attempting to access a non-existent partition: " + p + ". " +
- "Total number of partitions: " + maxPartitions)
- }
-
- val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
- finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
- eventQueue.put(toSubmit)
+ val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
waiter.awaitResult() match {
case JobSucceeded => {}
case JobFailed(exception: Exception, _) =>
@@ -329,19 +328,35 @@ class DAGScheduler(
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
- eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
+ val jobId = nextJobId.getAndIncrement()
+ eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite,
+ listener, properties))
listener.awaitResult() // Will throw an exception if the job fails
}
/**
+ * Cancel a job that is running or waiting in the queue.
+ */
+ def cancelJob(jobId: Int) {
+ logInfo("Asked to cancel job " + jobId)
+ eventQueue.put(JobCancelled(jobId))
+ }
+
+ /**
+ * Cancel all jobs that are running or waiting in the queue.
+ */
+ def cancelAllJobs() {
+ eventQueue.put(AllJobsCancelled)
+ }
+
+ /**
* Process one event retrieved from the event queue.
* Returns true if we should stop the event loop.
*/
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
- case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
- val jobId = nextJobId.getAndIncrement()
- val finalStage = newStage(finalRDD, None, jobId, Some(callSite))
+ case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
+ val finalStage = newStage(rdd, None, jobId, Some(callSite))
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@@ -360,6 +375,18 @@ class DAGScheduler(
submitStage(finalStage)
}
+ case JobCancelled(jobId) =>
+ // Cancel a job: find all the running stages that are linked to this job, and cancel them.
+ running.filter(_.jobId == jobId).foreach { stage =>
+ taskSched.cancelTasks(stage.id)
+ }
+
+ case AllJobsCancelled =>
+ // Cancel all running jobs.
+ running.foreach { stage =>
+ taskSched.cancelTasks(stage.id)
+ }
+
case ExecutorGained(execId, host) =>
handleExecutorGained(execId, host)
@@ -578,6 +605,11 @@ class DAGScheduler(
*/
private def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
+
+ if (!stageIdToStage.contains(task.stageId)) {
+ // Skip all the actions if the stage has been cancelled.
+ return
+ }
val stage = stageIdToStage(task.stageId)
def markStageAsFinished(stage: Stage) = {
@@ -626,7 +658,7 @@ class DAGScheduler(
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
} else {
- stage.addOutputLoc(smt.partition, status)
+ stage.addOutputLoc(smt.partitionId, status)
}
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
markStageAsFinished(stage)
@@ -752,14 +784,14 @@ class DAGScheduler(
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
- * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
+ * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
failedStage.completionTime = Some(System.currentTimeMillis())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
- val error = new SparkException("Job failed: " + reason)
+ val error = new SparkException("Job aborted: " + reason)
job.listener.jobFailed(error)
listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
idToActiveJob -= resultStage.jobId
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 10ff1b4376..ee89bfb38d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -31,9 +31,10 @@ import org.apache.spark.executor.TaskMetrics
* submitted) but there is a single "logic" thread that reads these events and takes decisions.
* This greatly simplifies synchronization.
*/
-private[spark] sealed trait DAGSchedulerEvent
+private[scheduler] sealed trait DAGSchedulerEvent
-private[spark] case class JobSubmitted(
+private[scheduler] case class JobSubmitted(
+ jobId: Int,
finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
@@ -43,9 +44,14 @@ private[spark] case class JobSubmitted(
properties: Properties = null)
extends DAGSchedulerEvent
-private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
-private[spark] case class CompletionEvent(
+private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
+
+private[scheduler]
+case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+
+private[scheduler] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
result: Any,
@@ -54,10 +60,12 @@ private[spark] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
-private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
+private[scheduler]
+case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
-private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
+private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
-private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+private[scheduler]
+case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
-private[spark] case object StopDAGScheduler extends DAGSchedulerEvent
+private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
index 151514896f..7b5c0e29ad 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
@@ -40,7 +40,7 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar
})
metricRegistry.register(MetricRegistry.name("job", "allJobs"), new Gauge[Int] {
- override def getValue: Int = dagScheduler.nextJobId.get()
+ override def getValue: Int = dagScheduler.numTotalJobs
})
metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index 200d881799..58f238d8cf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -17,48 +17,58 @@
package org.apache.spark.scheduler
-import scala.collection.mutable.ArrayBuffer
-
/**
* An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
* results to the given handler function.
*/
-private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
+private[spark] class JobWaiter[T](
+ dagScheduler: DAGScheduler,
+ jobId: Int,
+ totalTasks: Int,
+ resultHandler: (Int, T) => Unit)
extends JobListener {
private var finishedTasks = 0
- private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
- private var jobResult: JobResult = null // If the job is finished, this will be its result
+ // Is the job as a whole finished (succeeded or failed)?
+ private var _jobFinished = totalTasks == 0
- override def taskSucceeded(index: Int, result: Any) {
- synchronized {
- if (jobFinished) {
- throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
- }
- resultHandler(index, result.asInstanceOf[T])
- finishedTasks += 1
- if (finishedTasks == totalTasks) {
- jobFinished = true
- jobResult = JobSucceeded
- this.notifyAll()
- }
- }
+ def jobFinished = _jobFinished
+
+ // If the job is finished, this will be its result. In the case of zero-task jobs (e.g. over
+ // zero-partition RDDs), we set the jobResult directly to JobSucceeded.
+ private var jobResult: JobResult = if (jobFinished) JobSucceeded else null
+
+ /**
+ * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled
+ * asynchronously. After the low-level scheduler cancels all the tasks belonging to this job, it
+ * will fail this job with a SparkException.
+ */
+ def cancel() {
+ dagScheduler.cancelJob(jobId)
}
- override def jobFailed(exception: Exception) {
- synchronized {
- if (jobFinished) {
- throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter")
- }
- jobFinished = true
- jobResult = JobFailed(exception, None)
+ override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
+ if (_jobFinished) {
+ throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
+ }
+ resultHandler(index, result.asInstanceOf[T])
+ finishedTasks += 1
+ if (finishedTasks == totalTasks) {
+ _jobFinished = true
+ jobResult = JobSucceeded
this.notifyAll()
}
}
+ override def jobFailed(exception: Exception): Unit = synchronized {
+ _jobFinished = true
+ jobResult = JobFailed(exception, None)
+ this.notifyAll()
+ }
+
def awaitResult(): JobResult = synchronized {
- while (!jobFinished) {
+ while (!_jobFinished) {
this.wait()
}
return jobResult
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 9eb8d48501..596f9adde9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -43,7 +43,10 @@ private[spark] class Pool(
var runningTasks = 0
var priority = 0
- var stageId = 0
+
+ // A pool's stage id is used to break ties in scheduling.
+ var stageId = -1
+
var name = poolName
var parent: Pool = null
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 6dd422bbf6..310ec62ca8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -38,17 +38,17 @@ private[spark] object ResultTask {
synchronized {
val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
- return old
+ old
} else {
val out = new ByteArrayOutputStream
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objOut = ser.serializeStream(new GZIPOutputStream(out))
objOut.writeObject(rdd)
objOut.writeObject(func)
objOut.close()
val bytes = out.toByteArray
serializedInfoCache.put(stageId, bytes)
- return bytes
+ bytes
}
}
}
@@ -56,11 +56,11 @@ private[spark] object ResultTask {
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
- return (rdd, func)
+ (rdd, func)
}
def clearCache() {
@@ -71,29 +71,37 @@ private[spark] object ResultTask {
}
+/**
+ * A task that sends back the output to the driver application.
+ *
+ * See [[org.apache.spark.scheduler.Task]] for more information.
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param rdd input to func
+ * @param func a function to apply on a partition of the RDD
+ * @param _partitionId index of the partition in the RDD that this task computes
+ * @param locs preferred task execution locations for locality scheduling
+ * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
+ * input RDD's partitions).
+ */
private[spark] class ResultTask[T, U](
stageId: Int,
var rdd: RDD[T],
var func: (TaskContext, Iterator[T]) => U,
- var partition: Int,
+ _partitionId: Int,
@transient locs: Seq[TaskLocation],
var outputId: Int)
- extends Task[U](stageId) with Externalizable {
+ extends Task[U](stageId, _partitionId) with Externalizable {
def this() = this(0, null, null, 0, null, 0)
- var split = if (rdd == null) {
- null
- } else {
- rdd.partitions(partition)
- }
+ var split = if (rdd == null) null else rdd.partitions(partitionId)
@transient private val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
}
- override def run(attemptId: Long): U = {
- val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
+ override def runTask(context: TaskContext): U = {
metrics = Some(context.taskMetrics)
try {
func(context, rdd.iterator(split, context))
@@ -104,17 +112,17 @@ private[spark] class ResultTask[T, U](
override def preferredLocations: Seq[TaskLocation] = preferredLocs
- override def toString = "ResultTask(" + stageId + ", " + partition + ")"
+ override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
- split = rdd.partitions(partition)
+ split = rdd.partitions(partitionId)
out.writeInt(stageId)
val bytes = ResultTask.serializeInfo(
stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
out.writeInt(bytes.length)
out.write(bytes)
- out.writeInt(partition)
+ out.writeInt(partitionId)
out.writeInt(outputId)
out.writeLong(epoch)
out.writeObject(split)
@@ -129,7 +137,7 @@ private[spark] class ResultTask[T, U](
val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
rdd = rdd_.asInstanceOf[RDD[T]]
func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
- partition = in.readInt()
+ partitionId = in.readInt()
outputId = in.readInt()
epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
index 4e25086ec9..356fe56bf3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -30,7 +30,10 @@ import scala.xml.XML
* addTaskSetManager: build the leaf nodes(TaskSetManagers)
*/
private[spark] trait SchedulableBuilder {
+ def rootPool: Pool
+
def buildPools()
+
def addTaskSetManager(manager: Schedulable, properties: Properties)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 3b9d5679fb..802791797a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -53,7 +53,7 @@ private[spark] object ShuffleMapTask {
objOut.close()
val bytes = out.toByteArray
serializedInfoCache.put(stageId, bytes)
- return bytes
+ bytes
}
}
}
@@ -66,7 +66,7 @@ private[spark] object ShuffleMapTask {
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
- return (rdd, dep)
+ (rdd, dep)
}
}
@@ -75,7 +75,7 @@ private[spark] object ShuffleMapTask {
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val objIn = new ObjectInputStream(in)
val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
- return (HashMap(set.toSeq: _*))
+ HashMap(set.toSeq: _*)
}
def clearCache() {
@@ -85,13 +85,25 @@ private[spark] object ShuffleMapTask {
}
}
+/**
+ * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
+ * specified in the ShuffleDependency).
+ *
+ * See [[org.apache.spark.scheduler.Task]] for more information.
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param rdd the final RDD in this stage
+ * @param dep the ShuffleDependency
+ * @param _partitionId index of the partition in the RDD that this task computes
+ * @param locs preferred task execution locations for locality scheduling
+ */
private[spark] class ShuffleMapTask(
stageId: Int,
var rdd: RDD[_],
var dep: ShuffleDependency[_,_],
- var partition: Int,
+ _partitionId: Int,
@transient private var locs: Seq[TaskLocation])
- extends Task[MapStatus](stageId)
+ extends Task[MapStatus](stageId, _partitionId)
with Externalizable
with Logging {
@@ -101,16 +113,16 @@ private[spark] class ShuffleMapTask(
if (locs == null) Nil else locs.toSet.toSeq
}
- var split = if (rdd == null) null else rdd.partitions(partition)
+ var split = if (rdd == null) null else rdd.partitions(partitionId)
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
- split = rdd.partitions(partition)
+ split = rdd.partitions(partitionId)
out.writeInt(stageId)
val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
out.writeInt(bytes.length)
out.write(bytes)
- out.writeInt(partition)
+ out.writeInt(partitionId)
out.writeLong(epoch)
out.writeObject(split)
}
@@ -124,16 +136,14 @@ private[spark] class ShuffleMapTask(
val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
rdd = rdd_
dep = dep_
- partition = in.readInt()
+ partitionId = in.readInt()
epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
}
- override def run(attemptId: Long): MapStatus = {
+ override def runTask(context: TaskContext): MapStatus = {
val numOutputSplits = dep.partitioner.numPartitions
-
- val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
- metrics = Some(taskContext.taskMetrics)
+ metrics = Some(context.taskMetrics)
val blockManager = SparkEnv.get.blockManager
var shuffle: ShuffleBlocks = null
@@ -143,10 +153,10 @@ private[spark] class ShuffleMapTask(
// Obtain all the block writers for shuffle blocks.
val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
- buckets = shuffle.acquireWriters(partition)
+ buckets = shuffle.acquireWriters(partitionId)
// Write the map output to its associated buckets.
- for (elem <- rdd.iterator(split, taskContext)) {
+ for (elem <- rdd.iterator(split, context)) {
val pair = elem.asInstanceOf[Product2[Any, Any]]
val bucketId = dep.partitioner.getPartition(pair._1)
buckets.writers(bucketId).write(pair)
@@ -167,7 +177,7 @@ private[spark] class ShuffleMapTask(
shuffleMetrics.shuffleBytesWritten = totalBytes
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
- return new MapStatus(blockManager.blockManagerId, compressedSizes)
+ new MapStatus(blockManager.blockManagerId, compressedSizes)
} catch { case e: Exception =>
// If there is an exception from running the task, revert the partial writes
// and throw the exception upstream to Spark.
@@ -181,11 +191,11 @@ private[spark] class ShuffleMapTask(
shuffle.releaseWriters(buckets)
}
// Execute the callbacks on task completion.
- taskContext.executeOnCompleteCallbacks()
+ context.executeOnCompleteCallbacks()
}
}
override def preferredLocations: Seq[TaskLocation] = preferredLocs
- override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
+ override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 62b521ad45..466baf9913 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -54,7 +54,7 @@ trait SparkListener {
/**
* Called when a task starts
*/
- def onTaskStart(taskEnd: SparkListenerTaskStart) { }
+ def onTaskStart(taskStart: SparkListenerTaskStart) { }
/**
* Called when a task ends
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 598d91752a..1fe0d0e4e2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,25 +17,74 @@
package org.apache.spark.scheduler
-import org.apache.spark.serializer.SerializerInstance
import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import org.apache.spark.util.ByteBufferInputStream
+
import scala.collection.mutable.HashMap
+
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
+import org.apache.spark.TaskContext
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.util.ByteBufferInputStream
+
/**
- * A task to execute on a worker node.
+ * A unit of execution. We have two kinds of Tasks in Spark:
+ * - [[org.apache.spark.scheduler.ShuffleMapTask]]
+ * - [[org.apache.spark.scheduler.ResultTask]]
+ *
+ * A Spark job consists of one or more stages. The very last stage in a job consists of multiple
+ * ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task
+ * and sends the task output back to the driver application. A ShuffleMapTask executes the task
+ * and divides the task output into multiple buckets (based on the task's partitioner).
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param partitionId index of the partition in the RDD that this task computes
*/
-private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
- def run(attemptId: Long): T
+private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
+
+ def run(attemptId: Long): T = {
+ context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
+ if (_killed) {
+ kill()
+ }
+ runTask(context)
+ }
+
+ def runTask(context: TaskContext): T
+
def preferredLocations: Seq[TaskLocation] = Nil
- var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler.
+ // Map output tracker epoch. Will be set by TaskScheduler.
+ var epoch: Long = -1
var metrics: Option[TaskMetrics] = None
+ // Task context, to be initialized in run().
+ @transient protected var context: TaskContext = _
+
+ // A flag to indicate whether the task is killed. This is used in case context is not yet
+ // initialized when kill() is invoked.
+ @volatile @transient private var _killed = false
+
+ /**
+ * Whether the task has been killed.
+ */
+ def killed: Boolean = _killed
+
+ /**
+ * Kills a task by setting the interrupted flag to true. This relies on the upper-level Spark
+ * code and user code to properly handle the flag. This function should be idempotent so it can
+ * be called multiple times.
+ */
+ def kill() {
+ _killed = true
+ if (context != null) {
+ context.interrupted = true
+ }
+ }
}
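
The kill() contract above leaves it to upper-level Spark code and user code to observe context.interrupted. A minimal sketch (not part of this diff) of an iterator wrapper that honors the flag:

    // Hypothetical sketch: stop consuming a partition once the task has been killed.
    class CancellableIterator[T](context: TaskContext, delegate: Iterator[T]) extends Iterator[T] {
      def hasNext: Boolean = {
        if (context.interrupted) {
          // Failing the task here lets the scheduler report it as killed.
          throw new InterruptedException("task for partition " + context.partitionId + " was killed")
        }
        delegate.hasNext
      }
      def next(): T = delegate.next()
    }
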
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 7c2a9f03d7..6a51efe8d6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -45,6 +45,9 @@ private[spark] trait TaskScheduler {
// Submit a sequence of tasks to run.
def submitTasks(taskSet: TaskSet): Unit
+ // Cancel a stage.
+ def cancelTasks(stageId: Int)
+
// Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
def setListener(listener: TaskSchedulerListener): Unit
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
index c3ad325156..03bf760837 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -31,5 +31,9 @@ private[spark] class TaskSet(
val properties: Properties) {
val id: String = stageId + "." + attempt
+ def kill() {
+ tasks.foreach(_.kill())
+ }
+
override def toString: String = "TaskSet " + id
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 1a844b7e7e..7a72ff0474 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -17,7 +17,6 @@
package org.apache.spark.scheduler.cluster
-import java.lang.{Boolean => JBoolean}
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicLong
import java.util.{TimerTask, Timer}
@@ -79,12 +78,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
private val executorIdToHost = new HashMap[String, String]
- // JAR server, if any JARs were added by the user to the SparkContext
- var jarServer: HttpServer = null
-
- // URIs of JARs to pass to executor
- var jarUris: String = ""
-
// Listener object to pass upcalls into
var listener: TaskSchedulerListener = null
@@ -171,8 +164,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
backend.reviveOffers()
}
- def taskSetFinished(manager: TaskSetManager) {
- this.synchronized {
+ override def cancelTasks(stageId: Int): Unit = synchronized {
+ logInfo("Cancelling stage " + stageId)
+ activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
+ // There are two possible cases here:
+ // 1. The task set manager has been created and some tasks have been scheduled.
+ // In this case, send a kill signal to the executors to kill the task and then abort
+ // the stage.
+ // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+ // simply abort the stage.
+ val taskIds = taskSetTaskIds(tsm.taskSet.id)
+ if (taskIds.size > 0) {
+ taskIds.foreach { tid =>
+ val execId = taskIdToExecutorId(tid)
+ backend.killTask(tid, execId)
+ }
+ }
+ tsm.error("Stage %d was cancelled".format(stageId))
+ }
+ }
+
+ def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
+ // Check to see if the given task set has been removed. This is possible in the case of
+ // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has
+ // more than one running task).
+ if (activeTaskSets.contains(manager.taskSet.id)) {
activeTaskSets -= manager.taskSet.id
manager.parent.removeSchedulable(manager)
logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
@@ -334,9 +350,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (backend != null) {
backend.stop()
}
- if (jarServer != null) {
- jarServer.stop()
- }
if (taskResultGetter != null) {
taskResultGetter.stop()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 936167c13f..e1366e0627 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -17,18 +17,16 @@
package org.apache.spark.scheduler.cluster
-import java.nio.ByteBuffer
-import java.util.{Arrays, NoSuchElementException}
+import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
-import scala.Some
import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
- SparkException, Success, TaskEndReason, TaskResultLost, TaskState}
+ Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler._
import org.apache.spark.util.{SystemClock, Clock}
@@ -463,49 +461,52 @@ private[spark] class ClusterTaskSetManager(
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
reason.foreach {
- _ match {
- case fetchFailed: FetchFailed =>
- logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
- successful(index) = true
- tasksSuccessful += 1
- sched.taskSetFinished(this)
- removeAllRunningTasks()
- return
-
- case ef: ExceptionFailure =>
- sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
- val key = ef.description
- val now = clock.getTime()
- val (printFull, dupCount) = {
- if (recentExceptions.contains(key)) {
- val (dupCount, printTime) = recentExceptions(key)
- if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
- recentExceptions(key) = (0, now)
- (true, 0)
- } else {
- recentExceptions(key) = (dupCount + 1, printTime)
- (false, dupCount + 1)
- }
- } else {
+ case fetchFailed: FetchFailed =>
+ logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+ sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+ successful(index) = true
+ tasksSuccessful += 1
+ sched.taskSetFinished(this)
+ removeAllRunningTasks()
+ return
+
+ case TaskKilled =>
+ logInfo("Task %d was killed.".format(tid))
+ sched.listener.taskEnded(tasks(index), reason.get, null, null, info, null)
+ return
+
+ case ef: ExceptionFailure =>
+ sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+ val key = ef.description
+ val now = clock.getTime()
+ val (printFull, dupCount) = {
+ if (recentExceptions.contains(key)) {
+ val (dupCount, printTime) = recentExceptions(key)
+ if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
recentExceptions(key) = (0, now)
(true, 0)
+ } else {
+ recentExceptions(key) = (dupCount + 1, printTime)
+ (false, dupCount + 1)
}
- }
- if (printFull) {
- val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
- logInfo("Loss was due to %s\n%s\n%s".format(
- ef.className, ef.description, locs.mkString("\n")))
} else {
- logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ recentExceptions(key) = (0, now)
+ (true, 0)
}
+ }
+ if (printFull) {
+ val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logInfo("Loss was due to %s\n%s\n%s".format(
+ ef.className, ef.description, locs.mkString("\n")))
+ } else {
+ logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ }
- case TaskResultLost =>
- logInfo("Lost result for TID %s on host %s".format(tid, info.host))
- sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+ case TaskResultLost =>
+ logInfo("Lost result for TID %s on host %s".format(tid, info.host))
+ sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
- case _ => {}
- }
+ case _ => {}
}
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
index d57eb3276f..5367218faa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
@@ -17,7 +17,7 @@
package org.apache.spark.scheduler.cluster
-import org.apache.spark.{SparkContext}
+import org.apache.spark.SparkContext
/**
* A backend interface for cluster scheduling systems that allows plugging in different ones under
@@ -30,8 +30,8 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int
+ def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException
+
// Memory used by each executor (in megabytes)
protected val executorMemory: Int = SparkContext.executorMemoryRequested
-
- // TODO: Probably want to add a killTask too
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala
index c0b836bf1a..12b2fd01c0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala
@@ -31,6 +31,8 @@ private[spark] object StandaloneClusterMessages {
// Driver to executors
case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage
+ case class KillTask(taskId: Long, executor: String) extends StandaloneClusterMessage
+
case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
extends StandaloneClusterMessage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index f3aeea43d5..08ee2182a2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -91,6 +91,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
case ReviveOffers =>
makeOffers()
+ case KillTask(taskId, executorId) =>
+ executorActor(executorId) ! KillTask(taskId, executorId)
+
case StopDriver =>
sender ! true
context.stop(self)
@@ -180,6 +183,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
driverActor ! ReviveOffers
}
+ override def killTask(taskId: Long, executorId: String) {
+ driverActor ! KillTask(taskId, executorId)
+ }
+
override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism"))
.map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index 4d1bb1c639..b445260d1b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -17,23 +17,19 @@
package org.apache.spark.scheduler.local
-import java.io.File
-import java.lang.management.ManagementFactory
-import java.util.concurrent.atomic.AtomicInteger
import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
-import scala.collection.JavaConversions._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
+import akka.actor._
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.ExecutorURLClassLoader
+import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-import akka.actor._
-import org.apache.spark.util.Utils
+
/**
* A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
@@ -41,43 +37,50 @@ import org.apache.spark.util.Utils
* testing fault recovery.
*/
-private[spark]
+private[local]
case class LocalReviveOffers()
-private[spark]
+private[local]
case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
+private[local]
+case class KillTask(taskId: Long)
+
private[spark]
-class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
+class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
+ extends Actor with Logging {
+
+ val executor = new Executor("localhost", "localhost", Seq.empty, isLocal = true)
def receive = {
case LocalReviveOffers =>
launchTask(localScheduler.resourceOffer(freeCores))
+
case LocalStatusUpdate(taskId, state, serializeData) =>
- freeCores += 1
- localScheduler.statusUpdate(taskId, state, serializeData)
- launchTask(localScheduler.resourceOffer(freeCores))
+ if (TaskState.isFinished(state)) {
+ freeCores += 1
+ launchTask(localScheduler.resourceOffer(freeCores))
+ }
+
+ case KillTask(taskId) =>
+ executor.killTask(taskId)
}
- def launchTask(tasks : Seq[TaskDescription]) {
+ private def launchTask(tasks: Seq[TaskDescription]) {
for (task <- tasks) {
freeCores -= 1
- localScheduler.threadPool.submit(new Runnable {
- def run() {
- localScheduler.runTask(task.taskId, task.serializedTask)
- }
- })
+ executor.launchTask(localScheduler, task.taskId, task.serializedTask)
}
}
}
private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
extends TaskScheduler
+ with ExecutorBackend
with Logging {
- var attemptId = new AtomicInteger(0)
- var threadPool = Utils.newDaemonFixedThreadPool(threads)
val env = SparkEnv.get
+ val attemptId = new AtomicInteger
var listener: TaskSchedulerListener = null
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -85,8 +88,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
- val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
-
var schedulableBuilder: SchedulableBuilder = null
var rootPool: Pool = null
val schedulingMode: SchedulingMode = SchedulingMode.withName(
@@ -127,6 +128,26 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
}
+ override def cancelTasks(stageId: Int): Unit = synchronized {
+ logInfo("Cancelling stage " + stageId)
+ logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId))
+ activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
+ // There are two possible cases here:
+ // 1. The task set manager has been created and some tasks have been scheduled.
+ // In this case, send a kill signal to the executors to kill the task and then abort
+ // the stage.
+ // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+ // simply abort the stage.
+ val taskIds = taskSetTaskIds(tsm.taskSet.id)
+ if (taskIds.size > 0) {
+ taskIds.foreach { tid =>
+ localActor ! KillTask(tid)
+ }
+ }
+ tsm.error("Stage %d was cancelled".format(stageId))
+ }
+ }
+
def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
synchronized {
var freeCpuCores = freeCores
@@ -166,107 +187,32 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
}
- def runTask(taskId: Long, bytes: ByteBuffer) {
- logInfo("Running " + taskId)
- val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
- // Set the Spark execution environment for the worker thread
- SparkEnv.set(env)
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val objectSer = SparkEnv.get.serializer.newInstance()
- var attemptedTask: Option[Task[_]] = None
- val start = System.currentTimeMillis()
- var taskStart: Long = 0
- def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum
- val startGCTime = getTotalGCTime
-
- try {
- Accumulators.clear()
- Thread.currentThread().setContextClassLoader(classLoader)
-
- // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
- // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
- val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
- updateDependencies(taskFiles, taskJars) // Download any files added with addFile
- val deserializedTask = ser.deserialize[Task[_]](
- taskBytes, Thread.currentThread.getContextClassLoader)
- attemptedTask = Some(deserializedTask)
- val deserTime = System.currentTimeMillis() - start
- taskStart = System.currentTimeMillis()
-
- // Run it
- val result: Any = deserializedTask.run(taskId)
-
- // Serialize and deserialize the result to emulate what the Mesos
- // executor does. This is useful to catch serialization errors early
- // on in development (so when users move their local Spark programs
- // to the cluster, they don't get surprised by serialization errors).
- val serResult = objectSer.serialize(result)
- deserializedTask.metrics.get.resultSize = serResult.limit()
- val resultToReturn = objectSer.deserialize[Any](serResult)
- val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
- ser.serialize(Accumulators.values))
- val serviceTime = System.currentTimeMillis() - taskStart
- logInfo("Finished " + taskId)
- deserializedTask.metrics.get.executorRunTime = serviceTime.toInt
- deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime
- deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
- val taskResult = new DirectTaskResult(
- result, accumUpdates, deserializedTask.metrics.getOrElse(null))
- val serializedResult = ser.serialize(taskResult)
- localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
- } catch {
- case t: Throwable => {
- val serviceTime = System.currentTimeMillis() - taskStart
- val metrics = attemptedTask.flatMap(t => t.metrics)
- for (m <- metrics) {
- m.executorRunTime = serviceTime.toInt
- m.jvmGCTime = getTotalGCTime - startGCTime
- }
- val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
- localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
- }
- }
- }
-
- /**
- * Download any missing dependencies if we receive a new set of files and JARs from the
- * SparkContext. Also adds any new JARs we fetched to the class loader.
- */
- private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
- synchronized {
- // Fetch missing dependencies
- for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
- currentFiles(name) = timestamp
- }
-
- for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
- currentJars(name) = timestamp
- // Add it to our class loader
- val localName = name.split("/").last
- val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
- if (!classLoader.getURLs.contains(url)) {
- logInfo("Adding " + url + " to class loader")
- classLoader.addURL(url)
+ override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
+ if (TaskState.isFinished(state)) {
+ synchronized {
+ taskIdToTaskSetId.get(taskId) match {
+ case Some(taskSetId) =>
+ val taskSetManager = activeTaskSets(taskSetId)
+ taskSetTaskIds(taskSetId) -= taskId
+
+ state match {
+ case TaskState.FINISHED =>
+ taskSetManager.taskEnded(taskId, state, serializedData)
+ case TaskState.FAILED =>
+ taskSetManager.taskFailed(taskId, state, serializedData)
+ case TaskState.KILLED =>
+ taskSetManager.error("Task %d was killed".format(taskId))
+ case _ => {}
+ }
+ case None =>
+ logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
}
}
+ localActor ! LocalStatusUpdate(taskId, state, serializedData)
}
}
- def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
- synchronized {
- val taskSetId = taskIdToTaskSetId(taskId)
- val taskSetManager = activeTaskSets(taskSetId)
- taskSetTaskIds(taskSetId) -= taskId
- taskSetManager.statusUpdate(taskId, state, serializedData)
- }
- }
-
- override def stop() {
- threadPool.shutdownNow()
+ override def stop() {
}
override def defaultParallelism() = threads
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
index c2e2399ccb..f72e77d40f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
@@ -132,17 +132,6 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
return None
}
- def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- SparkEnv.set(env)
- state match {
- case TaskState.FINISHED =>
- taskEnded(tid, state, serializedData)
- case TaskState.FAILED =>
- taskFailed(tid, state, serializedData)
- case _ => {}
- }
- }
-
def taskStarted(task: Task[_], info: TaskInfo) {
sched.listener.taskStarted(task, info)
}
@@ -195,5 +184,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}
override def error(message: String) {
+ sched.listener.taskSetFailed(taskSet, message)
+ sched.taskSetFinished(this)
}
}
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index ced036c58d..ea936e815b 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -58,7 +58,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+ val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+ taskMetrics = null)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
@@ -70,7 +71,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+ val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+ taskMetrics = null)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -83,7 +85,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, runningLocally = true, null)
+ val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false,
+ taskMetrics = null)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index 7ca5f16202..f26c44d3e7 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -62,8 +62,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testCheckpointing(_.sample(false, 0.5, 0))
testCheckpointing(_.glom())
testCheckpointing(_.mapPartitions(_.map(_.toString)))
- testCheckpointing(r => new MapPartitionsWithIndexRDD(r,
- (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false ))
+ testCheckpointing(r => new MapPartitionsWithContextRDD(r,
+ (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
testCheckpointing(_.pipe(Seq("cat")))
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 591c1d498d..7b0bb89ab2 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContext(0, 0, 0, false, null);
+ TaskContext context = new TaskContext(0, 0, 0, false, false, null);
Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
}
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
new file mode 100644
index 0000000000..a192651491
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.util.concurrent.Semaphore
+
+import scala.concurrent.future
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.scheduler.{SparkListenerTaskStart, SparkListener}
+
+
+/**
+ * Test suite for cancelling running jobs. We run the cancellation tests for a single-job action
+ * (e.g. count) as well as a multi-job action (e.g. take). We test the local and cluster schedulers
+ * in both FIFO and fair scheduling modes.
+ */
+class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
+ with LocalSparkContext {
+
+ override def afterEach() {
+ super.afterEach()
+ resetSparkContext()
+ System.clearProperty("spark.scheduler.mode")
+ }
+
+ test("local mode, FIFO scheduler") {
+ System.setProperty("spark.scheduler.mode", "FIFO")
+ sc = new SparkContext("local[2]", "test")
+ testCount()
+ testTake()
+ // Make sure we can still launch tasks.
+ assert(sc.parallelize(1 to 10, 2).count === 10)
+ }
+
+ test("local mode, fair scheduler") {
+ System.setProperty("spark.scheduler.mode", "FAIR")
+ val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+ System.setProperty("spark.scheduler.allocation.file", xmlPath)
+ sc = new SparkContext("local[2]", "test")
+ testCount()
+ testTake()
+ // Make sure we can still launch tasks.
+ assert(sc.parallelize(1 to 10, 2).count === 10)
+ }
+
+ test("cluster mode, FIFO scheduler") {
+ System.setProperty("spark.scheduler.mode", "FIFO")
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ testCount()
+ testTake()
+ // Make sure we can still launch tasks.
+ assert(sc.parallelize(1 to 10, 2).count === 10)
+ }
+
+ test("cluster mode, fair scheduler") {
+ System.setProperty("spark.scheduler.mode", "FAIR")
+ val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+ System.setProperty("spark.scheduler.allocation.file", xmlPath)
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ testCount()
+ testTake()
+ // Make sure we can still launch tasks.
+ assert(sc.parallelize(1 to 10, 2).count === 10)
+ }
+
+ test("two jobs sharing the same stage") {
+ // sem1: make sure cancel is issued after some tasks are launched
+ // sem2: make sure the first stage is not finished until cancel is issued
+ val sem1 = new Semaphore(0)
+ val sem2 = new Semaphore(0)
+
+ sc = new SparkContext("local[2]", "test")
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem1.release()
+ }
+ })
+
+ // Create two actions that share the same stages.
+ val rdd = sc.parallelize(1 to 10, 2).map { i =>
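+ // Block each map closure on sem2 so the first stage cannot finish before the cancel is issued.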
+ sem2.acquire()
+ (i, i)
+ }.reduceByKey(_+_)
+ val f1 = rdd.collectAsync()
+ val f2 = rdd.countAsync()
+
+ // Kill one of the actions.
+ future {
+ sem1.acquire()
+ f1.cancel()
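+ // Unblock the map closures so the test does not hang once the cancellation has been issued.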
+ sem2.release(10)
+ }
+
+ // Expect both to fail now.
+ // TODO: update this test when we change Spark so that cancelling f1 does not affect f2.
+ intercept[SparkException] { f1.get() }
+ intercept[SparkException] { f2.get() }
+ }
+
+ def testCount() {
+ // Cancel before launching any tasks
+ {
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
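+ // The cancel races with job submission; either way f.get() should throw a SparkException.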
+ future { f.cancel() }
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+ }
+
+ // Cancel after some tasks have been launched
+ {
+ // Add a listener to release the semaphore once any tasks are launched.
+ val sem = new Semaphore(0)
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem.release()
+ }
+ })
+
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
+ future {
+ // Wait until some tasks have been launched before we cancel the job.
+ sem.acquire()
+ f.cancel()
+ }
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+ }
+ }
+
+ def testTake() {
+ // Cancel before launching any tasks
+ {
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
+ future { f.cancel() }
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+ }
+
+ // Cancel after some tasks have been launched
+ {
+ // Add a listener to release the semaphore once any tasks are launched.
+ val sem = new Semaphore(0)
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem.release()
+ }
+ })
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
+ future {
+ sem.acquire()
+ f.cancel()
+ }
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
new file mode 100644
index 0000000000..da032b17d9
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.util.concurrent.Semaphore
+
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.{SparkContext, SparkException, LocalSparkContext}
+
+
+class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts {
+
+ @transient private var sc: SparkContext = _
+
+ override def beforeAll() {
+ sc = new SparkContext("local[2]", "test")
+ }
+
+ override def afterAll() {
+ LocalSparkContext.stop(sc)
+ sc = null
+ }
+
+ lazy val zeroPartRdd = new EmptyRDD[Int](sc)
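+ // An RDD with zero partitions, used to exercise the empty-job code path of each async action.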
+
+ test("countAsync") {
+ assert(zeroPartRdd.countAsync().get() === 0)
+ assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000)
+ }
+
+ test("collectAsync") {
+ assert(zeroPartRdd.collectAsync().get() === Seq.empty)
+
+ val collected = sc.parallelize(1 to 1000, 3).collectAsync().get()
+ assert(collected === (1 to 1000))
+ }
+
+ test("foreachAsync") {
+ zeroPartRdd.foreachAsync(i => Unit).get()
+
+ val accum = sc.accumulator(0)
+ sc.parallelize(1 to 1000, 3).foreachAsync { i =>
+ accum += 1
+ }.get()
+ assert(accum.value === 1000)
+ }
+
+ test("foreachPartitionAsync") {
+ zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
+
+ val accum = sc.accumulator(0)
+ sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
+ accum += 1
+ }.get()
+ assert(accum.value === 9)
+ }
+
+ test("takeAsync") {
+ def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
+ val expected = input.take(num)
+ val saw = rdd.takeAsync(num).get()
+ assert(saw == expected, "incorrect result for rdd with %d partitions (expected %s, saw %s)"
+ .format(rdd.partitions.size, expected, saw))
+ }
+ val input = Range(1, 1000)
+
+ var rdd = sc.parallelize(input, 1)
+ for (num <- Seq(0, 1, 999, 1000)) {
+ testTake(rdd, input, num)
+ }
+
+ rdd = sc.parallelize(input, 2)
+ for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+ testTake(rdd, input, num)
+ }
+
+ rdd = sc.parallelize(input, 100)
+ for (num <- Seq(0, 1, 500, 501, 999, 1000)) {
+ testTake(rdd, input, num)
+ }
+
+ rdd = sc.parallelize(input, 1000)
+ for (num <- Seq(0, 1, 3, 999, 1000)) {
+ testTake(rdd, input, num)
+ }
+ }
+
+ /**
+ * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
+ * of a successful job execution.
+ */
+ test("async success handling") {
+ val f = sc.parallelize(1 to 10, 2).countAsync()
+
+ // Use a semaphore to make sure onSuccess and onComplete's success path will be called.
+ // If not, the test will hang.
+ val sem = new Semaphore(0)
+
+ f.onComplete {
+ case scala.util.Success(res) =>
+ sem.release()
+ case scala.util.Failure(e) =>
+ info("Should not have reached this code path (onComplete matching Failure)")
+ throw new Exception("Task should succeed")
+ }
+ f.onSuccess { case a: Any =>
+ sem.release()
+ }
+ f.onFailure { case t =>
+ info("Should not have reached this code path (onFailure)")
+ throw new Exception("Task should succeed")
+ }
+ assert(f.get() === 10)
+
+ failAfter(10 seconds) {
+ sem.acquire(2)
+ }
+ }
+
+ /**
+ * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
+ * of a failed job execution.
+ */
+ test("async failure handling") {
+ val f = sc.parallelize(1 to 10, 2).map { i =>
+ throw new Exception("intentional"); i
+ }.countAsync()
+
+ // Use a semaphore to make sure onFailure and onComplete's failure path will be called.
+ // If not, the test will hang.
+ val sem = new Semaphore(0)
+
+ f.onComplete {
+ case scala.util.Success(res) =>
+ info("Should not have reached this code path (onComplete matching Success)")
+ throw new Exception("Task should fail")
+ case scala.util.Failure(e) =>
+ sem.release()
+ }
+ f.onSuccess { case a: Any =>
+ info("Should not have reached this code path (onSuccess)")
+ throw new Exception("Task should fail")
+ }
+ f.onFailure { case t =>
+ sem.release()
+ }
+ intercept[SparkException] {
+ f.get()
+ }
+
+ failAfter(10 seconds) {
+ sem.acquire(2)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 31f97fc139..57d3382ed0 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -106,7 +106,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
}
}
visit(sums)
- assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+ assert(deps.size === 2) // ShuffledRDD, ParallelCollection.
}
test("join") {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 3952ee9264..838179c6b5 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -24,15 +24,14 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.LocalSparkContext
import org.apache.spark.MapOutputTracker
-import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.Partition
import org.apache.spark.TaskContext
import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency}
import org.apache.spark.{FetchFailed, Success, TaskEndReason}
-import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
-
+import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
/**
* Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
@@ -60,6 +59,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
taskSets += taskSet
}
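+ // cancelTasks is new on the scheduler interface; the fake scheduler can safely make it a no-op.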
+ override def cancelTasks(stageId: Int) {}
override def setListener(listener: TaskSchedulerListener) = {}
override def defaultParallelism() = 2
}
@@ -181,7 +181,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
allowLocal: Boolean = false,
listener: JobListener = listener) {
- runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener))
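+ // JobSubmitted now carries an explicit job id, so allocate one from the scheduler's counter.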
+ val jobId = scheduler.nextJobId.getAndIncrement()
+ runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, null, listener))
}
/** Sends TaskSetFailed to the scheduler. */
@@ -215,7 +216,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
override def getPreferredLocations(split: Partition) = Nil
override def toString = "DAGSchedulerSuite Local RDD"
}
- runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener))
+ val jobId = scheduler.nextJobId.getAndIncrement()
+ runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener))
assert(results === Map(0 -> 42))
}
@@ -242,7 +244,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
test("trivial job failure") {
submit(makeRdd(1, Nil), Array(0))
failed(taskSets(0), "some failure")
- assert(failure.getMessage === "Job failed: some failure")
+ assert(failure.getMessage === "Job aborted: some failure")
}
test("run trivial shuffle") {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
index 2f12aaed18..0f01515179 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
@@ -17,10 +17,11 @@
package org.apache.spark.scheduler.cluster
+import org.apache.spark.TaskContext
import org.apache.spark.scheduler.{TaskLocation, Task}
-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) {
- override def run(attemptId: Long): Int = 0
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
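+ // Task's constructor now takes a partition id (0 here), and subclasses implement
+ // runTask(context) instead of run(attemptId).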
+ override def runTask(context: TaskContext): Int = 0
override def preferredLocations: Seq[TaskLocation] = prefLocs
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
index af76c843e8..1e676c1719 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
@@ -17,17 +17,15 @@
package org.apache.spark.scheduler.local
-import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark._
-import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.{ConcurrentMap, HashMap}
import java.util.concurrent.Semaphore
import java.util.concurrent.CountDownLatch
-import java.util.Properties
+
+import scala.collection.mutable.HashMap
+
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+
+import org.apache.spark._
+
class Lock() {
var finished = false
@@ -63,7 +61,12 @@ object TaskThreadInfo {
* 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue,
* thus it will be scheduled later when cluster has free cpu cores.
*/
-class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
+class LocalSchedulerSuite extends FunSuite with LocalSparkContext with BeforeAndAfterEach {
+
+ override def afterEach() {
+ super.afterEach()
+ System.clearProperty("spark.scheduler.mode")
+ }
def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
@@ -148,12 +151,13 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
}
test("Local fair scheduler end-to-end test") {
- sc = new SparkContext("local[8]", "LocalSchedulerSuite")
- val sem = new Semaphore(0)
System.setProperty("spark.scheduler.mode", "FAIR")
val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
System.setProperty("spark.scheduler.allocation.file", xmlPath)
+ sc = new SparkContext("local[8]", "LocalSchedulerSuite")
+ val sem = new Semaphore(0)
+
createThread(10,"1",sc,sem)
TaskThreadInfo.threadToStarted(10).await()
createThread(20,"2",sc,sem)