author     Reynold Xin <rxin@apache.org>  2013-10-12 15:53:31 -0700
committer  Reynold Xin <rxin@apache.org>  2013-10-12 15:53:31 -0700
commit     6b288b75d4c05f42ad3612813dc77ff824bb6203 (patch)
tree       0f356878c3b729689b81d55c8479e1e7d7e1eca2
parent     ab0940f0c258085bbf930d43be0b9034aad039cf (diff)
Job cancellation: address Matei's code review feedback.
-rw-r--r--  core/src/main/scala/org/apache/spark/FutureAction.scala | 78
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 17
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala | 18
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala | 36
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala (renamed from core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala) | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala | 19
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala | 98
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala | 27
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala | 16
-rw-r--r--  core/src/test/scala/org/apache/spark/CheckpointSuite.scala | 4
-rw-r--r--  core/src/test/scala/org/apache/spark/JobCancellationSuite.scala | 39
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala | 47
17 files changed, 248 insertions, 216 deletions
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 85018cb046..1ad9240cfa 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -31,6 +31,8 @@ import org.apache.spark.rdd.RDD
* support cancellation.
*/
trait FutureAction[T] extends Future[T] {
+ // Note that we redefine methods of the Future trait here explicitly so we can specify a different
+ // documentation (with reference to the word "action").
/**
* Cancels the execution of this action.
@@ -87,14 +89,14 @@ trait FutureAction[T] extends Future[T] {
* The future holding the result of an action that triggers a single job. Examples include
* count, collect, reduce.
*/
-class FutureJob[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
+class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
extends FutureAction[T] {
override def cancel() {
jobWaiter.cancel()
}
- override def ready(atMost: Duration)(implicit permit: CanAwait): FutureJob.this.type = {
+ override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
if (!atMost.isFinite()) {
awaitResult()
} else {
@@ -149,19 +151,20 @@ class FutureJob[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
/**
* A FutureAction for actions that could trigger multiple Spark jobs. Examples include take,
- * takeSample.
- *
- * This is implemented as a Scala Promise that can be cancelled. Note that the promise itself is
- * also its own Future (i.e. this.future returns this). See the implementation of takeAsync for
- * usage.
+ * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
+ * action thread if it is being blocked by a job.
*/
-class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
- // Cancellation works by setting the cancelled flag to true and interrupt the action thread
- // if it is in progress. Before executing the action, the execution thread needs to check the
- // cancelled flag in case cancel() is called before the thread even starts to execute. Because
- // this and the execution thread is synchronized on the same promise object (this), the actual
- // cancellation/interrupt event can only be triggered when the execution thread is waiting for
- // the result of a job.
+class ComplexFutureAction[T] extends FutureAction[T] {
+
+ // Pointer to the thread that is executing the action. It is set when the action is run.
+ @volatile private var thread: Thread = _
+
+ // A flag indicating whether the future has been cancelled. This is used in case the future
+ // is cancelled before the action was even run (and thus we have no thread to interrupt).
+ @volatile private var _cancelled: Boolean = false
+
+ // A promise used to signal the future.
+ private val p = promise[T]()
override def cancel(): Unit = this.synchronized {
_cancelled = true
@@ -174,15 +177,18 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
* Executes some action enclosed in the closure. To properly enable cancellation, the closure
* should use runJob implementation in this promise. See takeAsync for example.
*/
- def run(func: => T)(implicit executor: ExecutionContext): Unit = scala.concurrent.future {
- thread = Thread.currentThread
- try {
- this.success(func)
- } catch {
- case e: Exception => this.failure(e)
- } finally {
- thread = null
+ def run(func: => T)(implicit executor: ExecutionContext): this.type = {
+ scala.concurrent.future {
+ thread = Thread.currentThread
+ try {
+ p.success(func)
+ } catch {
+ case e: Exception => p.failure(e)
+ } finally {
+ thread = null
+ }
}
+ this
}
/**
@@ -193,15 +199,15 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
rdd: RDD[T],
processPartition: Iterator[T] => U,
partitions: Seq[Int],
- partitionResultHandler: (Int, U) => Unit,
+ resultHandler: (Int, U) => Unit,
resultFunc: => R) {
// If the action hasn't been cancelled yet, submit the job. The check and the submitJob
// command need to be in an atomic block.
val job = this.synchronized {
if (!cancelled) {
- rdd.context.submitJob(rdd, processPartition, partitions, partitionResultHandler, resultFunc)
+ rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
} else {
- throw new SparkException("action has been cancelled")
+ throw new SparkException("Action has been cancelled")
}
}
@@ -213,7 +219,7 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
} catch {
case e: InterruptedException =>
job.cancel()
- throw new SparkException("action has been cancelled")
+ throw new SparkException("Action has been cancelled")
}
}
@@ -222,28 +228,14 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
*/
def cancelled: Boolean = _cancelled
- // Pointer to the thread that is executing the action. It is set when the action is run.
- @volatile private var thread: Thread = _
-
- // A flag indicating whether the future has been cancelled. This is used in case the future
- // is cancelled before the action was even run (and thus we have no thread to interrupt).
- @volatile private var _cancelled: Boolean = false
-
- // Internally, we delegate most functionality to this promise.
- private val p = promise[T]()
-
- override def future: this.type = this
-
- override def tryComplete(result: Try[T]): Boolean = p.tryComplete(result)
-
- @scala.throws(classOf[InterruptedException])
- @scala.throws(classOf[scala.concurrent.TimeoutException])
+ @throws(classOf[InterruptedException])
+ @throws(classOf[scala.concurrent.TimeoutException])
override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = {
p.future.ready(atMost)(permit)
this
}
- @scala.throws(classOf[Exception])
+ @throws(classOf[Exception])
override def result(atMost: Duration)(implicit permit: CanAwait): T = {
p.future.result(atMost)(permit)
}
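
The renamed SimpleFutureAction and ComplexFutureAction both implement the standard scala.concurrent.Future interface, so callers can register callbacks, block, or cancel. A minimal usage sketch (not part of this patch), assuming a working SparkContext named sc and the implicit conversion to the async actions from org.apache.spark.SparkContext._:

    import scala.concurrent.ExecutionContext.Implicits.global
    import org.apache.spark.SparkContext._   // implicit conversion to AsyncRDDActions

    val action = sc.parallelize(1 to 1000000, 10).countAsync()   // a FutureAction[Long]
    action.onComplete {
      case scala.util.Success(n) => println("counted " + n)
      case scala.util.Failure(e) => println("failed or cancelled: " + e)
    }
    action.cancel()   // asks the JobWaiter to cancel the job; waiters then see a SparkException
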
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 52fc4dd869..96a2f1fed3 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -760,10 +760,11 @@ class SparkContext(
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
val callSite = Utils.formatSparkCallSite
+ val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler,
- localProperties.get)
+ val result = dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
+ resultHandler, localProperties.get)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
@@ -853,16 +854,14 @@ class SparkContext(
}
/**
- * Submit a job for execution and return a FutureJob holding the result. Note that the
- * processPartition closure will be "cleaned" so the caller doesn't have to clean the closure
- * explicitly.
+ * Submit a job for execution and return a FutureJob holding the result.
*/
def submitJob[T, U, R](
rdd: RDD[T],
processPartition: Iterator[T] => U,
partitions: Seq[Int],
- partitionResultHandler: (Int, U) => Unit,
- resultFunc: => R): FutureJob[R] =
+ resultHandler: (Int, U) => Unit,
+ resultFunc: => R): SimpleFutureAction[R] =
{
val cleanF = clean(processPartition)
val callSite = Utils.formatSparkCallSite
@@ -872,9 +871,9 @@ class SparkContext(
partitions,
callSite,
allowLocal = false,
- partitionResultHandler,
+ resultHandler,
null)
- new FutureJob(waiter, resultFunc)
+ new SimpleFutureAction(waiter, resultFunc)
}
/**
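
submitJob is the primitive the async actions build on: it hands the job to the DAGScheduler and wraps the returned JobWaiter in a SimpleFutureAction. A hedged sketch of calling it directly, assuming sc: SparkContext and rdd: RDD[Int] already exist (the per-partition sum is purely illustrative):

    val partialSums = new Array[Int](rdd.partitions.size)
    val action = sc.submitJob(
      rdd,
      (iter: Iterator[Int]) => iter.sum,               // processPartition, run on each partition
      0 until rdd.partitions.size,                     // partitions to compute
      (index: Int, s: Int) => partialSums(index) = s,  // resultHandler, called as partitions finish
      partialSums.sum)                                 // resultFunc, evaluated once the job succeeds
    // action is a SimpleFutureAction[Int]; action.cancel() cancels through the JobWaiter.
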
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 86370d5d91..51584d686d 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -22,14 +22,17 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.executor.TaskMetrics
class TaskContext(
- val stageId: Int,
- val splitId: Int,
+ private[spark] val stageId: Int,
+ val partitionId: Int,
val attemptId: Long,
val runningLocally: Boolean = false,
@volatile var interrupted: Boolean = false,
- val taskMetrics: TaskMetrics = TaskMetrics.empty()
+ private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty()
) extends Serializable {
+ @deprecated("use partitionId", "0.8.1")
+ def splitId = partitionId
+
// List of callback functions to execute when the task completes.
@transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
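
Closures that read the partition index off the TaskContext keep compiling through the deprecated splitId alias, but should move to partitionId. A hedged migration sketch, with lines standing in for an existing RDD[String]:

    lines.mapPartitionsWithContext((ctx, iter) =>
      iter.map(line => (ctx.partitionId, line)))   // ctx.splitId still works, with a deprecation warning
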
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index da62091980..b56d8c9912 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -18,14 +18,18 @@
package org.apache.spark.executor
import java.nio.ByteBuffer
-import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
-import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _}
-import org.apache.spark.TaskState.TaskState
+
import com.google.protobuf.ByteString
-import org.apache.spark.{Logging}
+
+import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
+import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
+
+import org.apache.spark.Logging
import org.apache.spark.TaskState
+import org.apache.spark.TaskState.TaskState
import org.apache.spark.util.Utils
+
private[spark] class MesosExecutorBackend
extends MesosExecutor
with ExecutorBackend
@@ -71,7 +75,11 @@ private[spark] class MesosExecutorBackend
}
override def killTask(d: ExecutorDriver, t: TaskID) {
- logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)")
+ if (executor == null) {
+ logError("Received KillTask but executor was null")
+ } else {
+ executor.killTask(t.getValue.toLong)
+ }
}
override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {}
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 1f24ee8cd3..faaf837be0 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext.Implicits.global
-import org.apache.spark.{Logging, CancellablePromise, FutureAction}
+import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
/**
* A set of asynchronous RDD actions available through an implicit conversion.
@@ -63,9 +63,9 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
* Returns a future for retrieving the first num elements of the RDD.
*/
def takeAsync(num: Int): FutureAction[Seq[T]] = {
- val promise = new CancellablePromise[Seq[T]]
+ val f = new ComplexFutureAction[Seq[T]]
- promise.run {
+ f.run {
val results = new ArrayBuffer[T](num)
val totalParts = self.partitions.length
var partsScanned = 0
@@ -89,7 +89,7 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
val buf = new Array[Array[T]](p.size)
- promise.runJob(self,
+ f.runJob(self,
(it: Iterator[T]) => it.take(left).toArray,
p,
(index: Int, data: Array[T]) => buf(index) = data,
@@ -101,7 +101,7 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
results.toSeq
}
- promise.future
+ f
}
/**
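
Because take may have to run several jobs (it keeps scanning more partitions until num elements have been collected), takeAsync is backed by a ComplexFutureAction rather than a SimpleFutureAction. A hedged sketch of the caller-side behaviour, assuming sc and the SparkContext._ implicits as above:

    val firstFifty = sc.parallelize(1 to 100000, 20).takeAsync(50)   // a FutureAction[Seq[Int]]
    firstFifty.cancel()   // sets the cancelled flag and interrupts the thread blocked in runJob
    // A later firstFifty.get() is then expected to throw SparkException("Action has been cancelled").
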
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 3311757189..ccaaecb85b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -85,7 +85,7 @@ private[spark] object CheckpointRDD extends Logging {
val outputDir = new Path(path)
val fs = outputDir.getFileSystem(env.hadoop.newConfiguration())
- val finalOutputName = splitIdToFile(ctx.splitId)
+ val finalOutputName = splitIdToFile(ctx.partitionId)
val finalOutputPath = new Path(outputDir, finalOutputName)
val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
diff --git a/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala b/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala
deleted file mode 100644
index e731deb2e5..0000000000
--- a/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import org.apache.spark.{InterruptibleIterator, Partition, TaskContext}
-
-
-/**
- * Wraps around an existing RDD to make it interruptible (can be killed).
- */
-private[spark]
-class InterruptibleRDD[T: ClassManifest](prev: RDD[T]) extends RDD[T](prev) {
-
- override def getPartitions: Array[Partition] = firstParent[T].partitions
-
- override val partitioner = prev.partitioner
-
- override def compute(split: Partition, context: TaskContext) = {
- new InterruptibleIterator(context, firstParent[T].iterator(split, context))
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
index 3ed8339010..aea08ff81b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
@@ -21,14 +21,14 @@ import org.apache.spark.{Partition, TaskContext}
/**
- * A variant of the MapPartitionsRDD that passes the partition index into the
- * closure. This can be used to generate or collect partition specific
- * information such as the number of tuples in a partition.
+ * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
+ * TaskContext, the closure can either get access to the interruptible flag or get the index
+ * of the partition in the RDD.
*/
private[spark]
-class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest](
+class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
- f: (Int, Iterator[T]) => Iterator[U],
+ f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean
) extends RDD[U](prev) {
@@ -37,5 +37,5 @@ class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest](
override val partitioner = if (preservesPartitioning) prev.partitioner else None
override def compute(split: Partition, context: TaskContext) =
- f(split.index, firstParent[T].iterator(split, context))
+ f(context, firstParent[T].iterator(split, context))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index ee17794499..93b78e1232 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -84,21 +84,24 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
}
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
- self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
- .interruptible()
+ self.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
+ }, preservesPartitioning = true)
} else if (mapSideCombine) {
val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
.setSerializer(serializerClass)
- partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true)
- .interruptible()
+ partitioned.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter))
+ }, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
- values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
- .interruptible()
+ values.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
+ }, preservesPartitioning = true)
}
}
@@ -567,7 +570,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
- val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber)
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.partitionId, attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outputFormatClass.newInstance
val committer = format.getOutputCommitter(hadoopContext)
@@ -666,7 +669,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
- writer.setup(context.stageId, context.splitId, attemptNumber)
+ writer.setup(context.stageId, context.partitionId, attemptNumber)
writer.open()
var count = 0
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 4be506ba3a..0355618e43 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -418,26 +418,39 @@ abstract class RDD[T: ClassManifest](
command: Seq[String],
env: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
- printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
+ printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = {
new PipedRDD(this, command, env,
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
if (printRDDElement ne null) sc.clean(printRDDElement) else null)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
+ def mapPartitions[U: ClassManifest](
+ f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
def mapPartitionsWithIndex[U: ClassManifest](
- f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
- new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+ f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
+ val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
+ new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
+ }
+
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD. This is a variant of
+ * mapPartitions that also passes the TaskContext into the closure.
+ */
+ def mapPartitionsWithContext[U: ClassManifest](
+ f: (TaskContext, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] = {
+ new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
@@ -445,22 +458,23 @@ abstract class RDD[T: ClassManifest](
*/
@deprecated("use mapPartitionsWithIndex", "0.7.0")
def mapPartitionsWithSplit[U: ClassManifest](
- f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
- new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+ f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
+ mapPartitionsWithIndex(f, preservesPartitioning)
+ }
/**
* Maps f over this RDD, where f takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
- (f:(T, A) => U): RDD[U] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(index)
- iter.map(t => f(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+ def mapWith[A: ClassManifest, U: ClassManifest]
+ (constructA: Int => A, preservesPartitioning: Boolean = false)
+ (f: (T, A) => U): RDD[U] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
+ val a = constructA(context.partitionId)
+ iter.map(t => f(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
@@ -468,13 +482,14 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
- (f:(T, A) => Seq[U]): RDD[U] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(index)
- iter.flatMap(t => f(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+ def flatMapWith[A: ClassManifest, U: ClassManifest]
+ (constructA: Int => A, preservesPartitioning: Boolean = false)
+ (f: (T, A) => Seq[U]): RDD[U] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
+ val a = constructA(context.partitionId)
+ iter.flatMap(t => f(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
@@ -482,13 +497,12 @@ abstract class RDD[T: ClassManifest](
* This additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def foreachWith[A: ClassManifest](constructA: Int => A)
- (f:(T, A) => Unit) {
- def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(index)
- iter.map(t => {f(t, a); t})
- }
- (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
+ def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
+ val a = constructA(context.partitionId)
+ iter.map(t => {f(t, a); t})
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
}
/**
@@ -496,13 +510,12 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def filterWith[A: ClassManifest](constructA: Int => A)
- (p:(T, A) => Boolean): RDD[T] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(index)
- iter.filter(t => p(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
+ def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
+ val a = constructA(context.partitionId)
+ iter.filter(t => p(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
}
/**
@@ -541,16 +554,14 @@ abstract class RDD[T: ClassManifest](
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
- val cleanF = sc.clean(f)
- sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
+ sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
}
/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
- val cleanF = sc.clean(f)
- sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
+ sc.runJob(this, (iter: Iterator[T]) => f(iter))
}
/**
@@ -862,11 +873,6 @@ abstract class RDD[T: ClassManifest](
map(x => (f(x), x))
}
- /**
- * Creates an interruptible version of this RDD.
- */
- def interruptible(): RDD[T] = new InterruptibleRDD(this)
-
/** A private method for tests, to look at the contents of each partition */
private[spark] def collectPartitions(): Array[Array[T]] = {
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
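
With interruptible() and InterruptibleRDD gone, a closure that wants to cooperate with job cancellation can read the interrupted flag from the TaskContext passed by the new mapPartitionsWithContext. A hedged sketch (rdd is a hypothetical RDD[Int]; the per-element check mirrors what Spark's InterruptibleIterator does internally):

    val doubled = rdd.mapPartitionsWithContext((context, iter) =>
      iter.map { x =>
        // Bail out promptly if this task has been killed instead of draining the partition.
        if (context.interrupted) throw new InterruptedException("task killed")
        x * 2
      })
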
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index c5b28b8286..2a8fbe8d09 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -377,7 +377,7 @@ class DAGScheduler(
case JobCancelled(jobId) =>
// Cancel a job: find all the running stages that are linked to this job, and cancel them.
- running.find(_.jobId == jobId).foreach { stage =>
+ running.filter(_.jobId == jobId).foreach { stage =>
taskSched.cancelTasks(stage.id)
}
@@ -658,7 +658,7 @@ class DAGScheduler(
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
} else {
- stage.addOutputLoc(smt.partition, status)
+ stage.addOutputLoc(smt.partitionId, status)
}
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
markStageAsFinished(stage)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index c084059859..625c84f572 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -71,18 +71,31 @@ private[spark] object ResultTask {
}
+/**
+ * A task that sends back the output to the driver application.
+ *
+ * See [[org.apache.spark.scheduler.Task]] for more information.
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param rdd input to func
+ * @param func a function to apply on a partition of the RDD
+ * @param _partitionId index of the partition in the RDD
+ * @param locs preferred task execution locations for locality scheduling
+ * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
+ * input RDD's partitions).
+ */
private[spark] class ResultTask[T, U](
stageId: Int,
var rdd: RDD[T],
var func: (TaskContext, Iterator[T]) => U,
- _partition: Int,
+ _partitionId: Int,
@transient locs: Seq[TaskLocation],
var outputId: Int)
- extends Task[U](stageId, _partition) with Externalizable {
+ extends Task[U](stageId, _partitionId) with Externalizable {
def this() = this(0, null, null, 0, null, 0)
- var split = if (rdd == null) null else rdd.partitions(partition)
+ var split = if (rdd == null) null else rdd.partitions(partitionId)
@transient private val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
@@ -99,17 +112,17 @@ private[spark] class ResultTask[T, U](
override def preferredLocations: Seq[TaskLocation] = preferredLocs
- override def toString = "ResultTask(" + stageId + ", " + partition + ")"
+ override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
- split = rdd.partitions(partition)
+ split = rdd.partitions(partitionId)
out.writeInt(stageId)
val bytes = ResultTask.serializeInfo(
stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
out.writeInt(bytes.length)
out.write(bytes)
- out.writeInt(partition)
+ out.writeInt(partitionId)
out.writeInt(outputId)
out.writeLong(epoch)
out.writeObject(split)
@@ -124,7 +137,7 @@ private[spark] class ResultTask[T, U](
val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
rdd = rdd_.asInstanceOf[RDD[T]]
func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
- partition = in.readInt()
+ partitionId = in.readInt()
outputId = in.readInt()
epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 1904ee89c6..66c1eae703 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -85,13 +85,25 @@ private[spark] object ShuffleMapTask {
}
}
+/**
+ * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
+ * specified in the ShuffleDependency).
+ *
+ * See [[org.apache.spark.scheduler.Task]] for more information.
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param rdd the final RDD in this stage
+ * @param dep the ShuffleDependency
+ * @param _partitionId index of the partition in the RDD
+ * @param locs preferred task execution locations for locality scheduling
+ */
private[spark] class ShuffleMapTask(
stageId: Int,
var rdd: RDD[_],
var dep: ShuffleDependency[_,_],
- _partition: Int,
+ _partitionId: Int,
@transient private var locs: Seq[TaskLocation])
- extends Task[MapStatus](stageId, _partition)
+ extends Task[MapStatus](stageId, _partitionId)
with Externalizable
with Logging {
@@ -101,16 +113,16 @@ private[spark] class ShuffleMapTask(
if (locs == null) Nil else locs.toSet.toSeq
}
- var split = if (rdd == null) null else rdd.partitions(partition)
+ var split = if (rdd == null) null else rdd.partitions(partitionId)
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
- split = rdd.partitions(partition)
+ split = rdd.partitions(partitionId)
out.writeInt(stageId)
val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
out.writeInt(bytes.length)
out.write(bytes)
- out.writeInt(partition)
+ out.writeInt(partitionId)
out.writeLong(epoch)
out.writeObject(split)
}
@@ -124,7 +136,7 @@ private[spark] class ShuffleMapTask(
val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
rdd = rdd_
dep = dep_
- partition = in.readInt()
+ partitionId = in.readInt()
epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
}
@@ -141,7 +153,7 @@ private[spark] class ShuffleMapTask(
// Obtain all the block writers for shuffle blocks.
val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
- buckets = shuffle.acquireWriters(partition)
+ buckets = shuffle.acquireWriters(partitionId)
// Write the map output to its associated buckets.
for (elem <- rdd.iterator(split, context)) {
@@ -185,5 +197,5 @@ private[spark] class ShuffleMapTask(
override def preferredLocations: Seq[TaskLocation] = preferredLocs
- override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
+ override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 2c65d8211f..1fe0d0e4e2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -31,12 +31,22 @@ import org.apache.spark.util.ByteBufferInputStream
/**
- * A task to execute on a worker node.
+ * A unit of execution. We have two kinds of Task's in Spark:
+ * - [[org.apache.spark.scheduler.ShuffleMapTask]]
+ * - [[org.apache.spark.scheduler.ResultTask]]
+ *
+ * A Spark job consists of one or more stages. The very last stage in a job consists of multiple
+ * ResultTask's, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task
+ * and sends the task output back to the driver application. A ShuffleMapTask executes the task
+ * and divides the task output to multiple buckets (based on the task's partitioner).
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param partitionId index of the partition in the RDD
*/
-private[spark] abstract class Task[T](val stageId: Int, var partition: Int) extends Serializable {
+private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
def run(attemptId: Long): T = {
- context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
+ context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
if (_killed) {
kill()
}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index d9103aebb7..70c1acca17 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -62,8 +62,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testCheckpointing(_.sample(false, 0.5, 0))
testCheckpointing(_.glom())
testCheckpointing(_.mapPartitions(_.map(_.toString)))
- testCheckpointing(r => new MapPartitionsWithIndexRDD(r,
- (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false ))
+ testCheckpointing(r => new MapPartitionsWithContextRDD(r,
+ (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false ))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
testCheckpointing(_.pipe(Seq("cat")))
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 53c225391c..a192651491 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -39,6 +39,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
override def afterEach() {
super.afterEach()
+ resetSparkContext()
System.clearProperty("spark.scheduler.mode")
}
@@ -49,7 +50,6 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
testTake()
// Make sure we can still launch tasks.
assert(sc.parallelize(1 to 10, 2).count === 10)
- resetSparkContext()
}
test("local mode, fair scheduler") {
@@ -61,7 +61,6 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
testTake()
// Make sure we can still launch tasks.
assert(sc.parallelize(1 to 10, 2).count === 10)
- resetSparkContext()
}
test("cluster mode, FIFO scheduler") {
@@ -71,7 +70,6 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
testTake()
// Make sure we can still launch tasks.
assert(sc.parallelize(1 to 10, 2).count === 10)
- resetSparkContext()
}
test("cluster mode, fair scheduler") {
@@ -83,7 +81,40 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
testTake()
// Make sure we can still launch tasks.
assert(sc.parallelize(1 to 10, 2).count === 10)
- resetSparkContext()
+ }
+
+ test("two jobs sharing the same stage") {
+ // sem1: make sure cancel is issued after some tasks are launched
+ // sem2: make sure the first stage is not finished until cancel is issued
+ val sem1 = new Semaphore(0)
+ val sem2 = new Semaphore(0)
+
+ sc = new SparkContext("local[2]", "test")
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem1.release()
+ }
+ })
+
+ // Create two actions that would share the same stages.
+ val rdd = sc.parallelize(1 to 10, 2).map { i =>
+ sem2.acquire()
+ (i, i)
+ }.reduceByKey(_+_)
+ val f1 = rdd.collectAsync()
+ val f2 = rdd.countAsync()
+
+ // Kill one of the actions.
+ future {
+ sem1.acquire()
+ f1.cancel()
+ sem2.release(10)
+ }
+
+ // Expect both to fail now.
+ // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2.
+ intercept[SparkException] { f1.get() }
+ intercept[SparkException] { f2.get() }
}
def testCount() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 3ef000da4a..da032b17d9 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -18,19 +18,18 @@
package org.apache.spark.rdd
import java.util.concurrent.Semaphore
-import java.util.concurrent.atomic.AtomicInteger
-import scala.concurrent.future
import scala.concurrent.ExecutionContext.Implicits.global
import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkContext._
import org.apache.spark.{SparkContext, SparkException, LocalSparkContext}
-import org.apache.spark.scheduler._
-class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
+class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts {
@transient private var sc: SparkContext = _
@@ -114,29 +113,29 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
test("async success handling") {
val f = sc.parallelize(1 to 10, 2).countAsync()
- // This semaphore is used to make sure our final assert waits until onComplete / onSuccess
- // finishes execution.
+ // Use a semaphore to make sure onSuccess and onComplete's success path will be called.
+ // If not, the test will hang.
val sem = new Semaphore(0)
- AsyncRDDActionsSuite.asyncSuccessHappened.set(0)
f.onComplete {
case scala.util.Success(res) =>
- AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
sem.release()
case scala.util.Failure(e) =>
+ info("Should not have reached this code path (onComplete matching Failure)")
throw new Exception("Task should succeed")
- sem.release()
}
f.onSuccess { case a: Any =>
- AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
sem.release()
}
f.onFailure { case t =>
+ info("Should not have reached this code path (onFailure)")
throw new Exception("Task should succeed")
}
assert(f.get() === 10)
- sem.acquire(2)
- assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2)
+
+ failAfter(10 seconds) {
+ sem.acquire(2)
+ }
}
/**
@@ -148,38 +147,30 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
throw new Exception("intentional"); i
}.countAsync()
- // This semaphore is used to make sure our final assert waits until onComplete / onFailure
- // finishes execution.
+ // Use a semaphore to make sure onFailure and onComplete's failure path will be called.
+ // If not, the test will hang.
val sem = new Semaphore(0)
- AsyncRDDActionsSuite.asyncFailureHappend.set(0)
f.onComplete {
case scala.util.Success(res) =>
+ info("Should not have reached this code path (onComplete matching Success)")
throw new Exception("Task should fail")
- sem.release()
case scala.util.Failure(e) =>
- AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
sem.release()
}
f.onSuccess { case a: Any =>
+ info("Should not have reached this code path (onSuccess)")
throw new Exception("Task should fail")
}
f.onFailure { case t =>
- AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
sem.release()
}
intercept[SparkException] {
f.get()
}
- sem.acquire(2)
- assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2)
- }
-}
-
-object AsyncRDDActionsSuite {
- // Some counters used in the test cases above.
- var asyncSuccessHappened = new AtomicInteger
- var asyncFailureHappend = new AtomicInteger
+ failAfter(10 seconds) {
+ sem.acquire(2)
+ }
+ }
}
-