diff options
author | Reynold Xin <rxin@apache.org> | 2013-10-12 15:53:31 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2013-10-12 15:53:31 -0700 |
commit | 6b288b75d4c05f42ad3612813dc77ff824bb6203 (patch) | |
tree | 0f356878c3b729689b81d55c8479e1e7d7e1eca2 | |
parent | ab0940f0c258085bbf930d43be0b9034aad039cf (diff) | |
download | spark-6b288b75d4c05f42ad3612813dc77ff824bb6203.tar.gz spark-6b288b75d4c05f42ad3612813dc77ff824bb6203.tar.bz2 spark-6b288b75d4c05f42ad3612813dc77ff824bb6203.zip |
Job cancellation: address Matei's code review feedback.
17 files changed, 248 insertions, 216 deletions
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 85018cb046..1ad9240cfa 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -31,6 +31,8 @@ import org.apache.spark.rdd.RDD * support cancellation. */ trait FutureAction[T] extends Future[T] { + // Note that we redefine methods of the Future trait here explicitly so we can specify a different + // documentation (with reference to the word "action"). /** * Cancels the execution of this action. @@ -87,14 +89,14 @@ trait FutureAction[T] extends Future[T] { * The future holding the result of an action that triggers a single job. Examples include * count, collect, reduce. */ -class FutureJob[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T) +class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T) extends FutureAction[T] { override def cancel() { jobWaiter.cancel() } - override def ready(atMost: Duration)(implicit permit: CanAwait): FutureJob.this.type = { + override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = { if (!atMost.isFinite()) { awaitResult() } else { @@ -149,19 +151,20 @@ class FutureJob[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T) /** * A FutureAction for actions that could trigger multiple Spark jobs. Examples include take, - * takeSample. - * - * This is implemented as a Scala Promise that can be cancelled. Note that the promise itself is - * also its own Future (i.e. this.future returns this). See the implementation of takeAsync for - * usage. + * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the + * action thread if it is being blocked by a job. */ -class CancellablePromise[T] extends FutureAction[T] with Promise[T] { - // Cancellation works by setting the cancelled flag to true and interrupt the action thread - // if it is in progress. Before executing the action, the execution thread needs to check the - // cancelled flag in case cancel() is called before the thread even starts to execute. Because - // this and the execution thread is synchronized on the same promise object (this), the actual - // cancellation/interrupt event can only be triggered when the execution thread is waiting for - // the result of a job. +class ComplexFutureAction[T] extends FutureAction[T] { + + // Pointer to the thread that is executing the action. It is set when the action is run. + @volatile private var thread: Thread = _ + + // A flag indicating whether the future has been cancelled. This is used in case the future + // is cancelled before the action was even run (and thus we have no thread to interrupt). + @volatile private var _cancelled: Boolean = false + + // A promise used to signal the future. + private val p = promise[T]() override def cancel(): Unit = this.synchronized { _cancelled = true @@ -174,15 +177,18 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] { * Executes some action enclosed in the closure. To properly enable cancellation, the closure * should use runJob implementation in this promise. See takeAsync for example. */ - def run(func: => T)(implicit executor: ExecutionContext): Unit = scala.concurrent.future { - thread = Thread.currentThread - try { - this.success(func) - } catch { - case e: Exception => this.failure(e) - } finally { - thread = null + def run(func: => T)(implicit executor: ExecutionContext): this.type = { + scala.concurrent.future { + thread = Thread.currentThread + try { + p.success(func) + } catch { + case e: Exception => p.failure(e) + } finally { + thread = null + } } + this } /** @@ -193,15 +199,15 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] { rdd: RDD[T], processPartition: Iterator[T] => U, partitions: Seq[Int], - partitionResultHandler: (Int, U) => Unit, + resultHandler: (Int, U) => Unit, resultFunc: => R) { // If the action hasn't been cancelled yet, submit the job. The check and the submitJob // command need to be in an atomic block. val job = this.synchronized { if (!cancelled) { - rdd.context.submitJob(rdd, processPartition, partitions, partitionResultHandler, resultFunc) + rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc) } else { - throw new SparkException("action has been cancelled") + throw new SparkException("Action has been cancelled") } } @@ -213,7 +219,7 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] { } catch { case e: InterruptedException => job.cancel() - throw new SparkException("action has been cancelled") + throw new SparkException("Action has been cancelled") } } @@ -222,28 +228,14 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] { */ def cancelled: Boolean = _cancelled - // Pointer to the thread that is executing the action. It is set when the action is run. - @volatile private var thread: Thread = _ - - // A flag indicating whether the future has been cancelled. This is used in case the future - // is cancelled before the action was even run (and thus we have no thread to interrupt). - @volatile private var _cancelled: Boolean = false - - // Internally, we delegate most functionality to this promise. - private val p = promise[T]() - - override def future: this.type = this - - override def tryComplete(result: Try[T]): Boolean = p.tryComplete(result) - - @scala.throws(classOf[InterruptedException]) - @scala.throws(classOf[scala.concurrent.TimeoutException]) + @throws(classOf[InterruptedException]) + @throws(classOf[scala.concurrent.TimeoutException]) override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = { p.future.ready(atMost)(permit) this } - @scala.throws(classOf[Exception]) + @throws(classOf[Exception]) override def result(atMost: Duration)(implicit permit: CanAwait): T = { p.future.result(atMost)(permit) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 52fc4dd869..96a2f1fed3 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -760,10 +760,11 @@ class SparkContext( allowLocal: Boolean, resultHandler: (Int, U) => Unit) { val callSite = Utils.formatSparkCallSite + val cleanedFunc = clean(func) logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, - localProperties.get) + val result = dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, + resultHandler, localProperties.get) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() result @@ -853,16 +854,14 @@ class SparkContext( } /** - * Submit a job for execution and return a FutureJob holding the result. Note that the - * processPartition closure will be "cleaned" so the caller doesn't have to clean the closure - * explicitly. + * Submit a job for execution and return a FutureJob holding the result. */ def submitJob[T, U, R]( rdd: RDD[T], processPartition: Iterator[T] => U, partitions: Seq[Int], - partitionResultHandler: (Int, U) => Unit, - resultFunc: => R): FutureJob[R] = + resultHandler: (Int, U) => Unit, + resultFunc: => R): SimpleFutureAction[R] = { val cleanF = clean(processPartition) val callSite = Utils.formatSparkCallSite @@ -872,9 +871,9 @@ class SparkContext( partitions, callSite, allowLocal = false, - partitionResultHandler, + resultHandler, null) - new FutureJob(waiter, resultFunc) + new SimpleFutureAction(waiter, resultFunc) } /** diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 86370d5d91..51584d686d 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -22,14 +22,17 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.executor.TaskMetrics class TaskContext( - val stageId: Int, - val splitId: Int, + private[spark] val stageId: Int, + val partitionId: Int, val attemptId: Long, val runningLocally: Boolean = false, @volatile var interrupted: Boolean = false, - val taskMetrics: TaskMetrics = TaskMetrics.empty() + private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty() ) extends Serializable { + @deprecated("use partitionId", "0.8.1") + def splitId = partitionId + // List of callback functions to execute when the task completes. @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit] diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index da62091980..b56d8c9912 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -18,14 +18,18 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver} -import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _} -import org.apache.spark.TaskState.TaskState + import com.google.protobuf.ByteString -import org.apache.spark.{Logging} + +import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver} +import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} + +import org.apache.spark.Logging import org.apache.spark.TaskState +import org.apache.spark.TaskState.TaskState import org.apache.spark.util.Utils + private[spark] class MesosExecutorBackend extends MesosExecutor with ExecutorBackend @@ -71,7 +75,11 @@ private[spark] class MesosExecutorBackend } override def killTask(d: ExecutorDriver, t: TaskID) { - logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)") + if (executor == null) { + logError("Received KillTask but executor was null") + } else { + executor.killTask(t.getValue.toLong) + } } override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {} diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 1f24ee8cd3..faaf837be0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -22,7 +22,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext.Implicits.global -import org.apache.spark.{Logging, CancellablePromise, FutureAction} +import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} /** * A set of asynchronous RDD actions available through an implicit conversion. @@ -63,9 +63,9 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with * Returns a future for retrieving the first num elements of the RDD. */ def takeAsync(num: Int): FutureAction[Seq[T]] = { - val promise = new CancellablePromise[Seq[T]] + val f = new ComplexFutureAction[Seq[T]] - promise.run { + f.run { val results = new ArrayBuffer[T](num) val totalParts = self.partitions.length var partsScanned = 0 @@ -89,7 +89,7 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) val buf = new Array[Array[T]](p.size) - promise.runJob(self, + f.runJob(self, (it: Iterator[T]) => it.take(left).toArray, p, (index: Int, data: Array[T]) => buf(index) = data, @@ -101,7 +101,7 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with results.toSeq } - promise.future + f } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 3311757189..ccaaecb85b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -85,7 +85,7 @@ private[spark] object CheckpointRDD extends Logging { val outputDir = new Path(path) val fs = outputDir.getFileSystem(env.hadoop.newConfiguration()) - val finalOutputName = splitIdToFile(ctx.splitId) + val finalOutputName = splitIdToFile(ctx.partitionId) val finalOutputPath = new Path(outputDir, finalOutputName) val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId) diff --git a/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala b/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala deleted file mode 100644 index e731deb2e5..0000000000 --- a/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import org.apache.spark.{InterruptibleIterator, Partition, TaskContext} - - -/** - * Wraps around an existing RDD to make it interruptible (can be killed). - */ -private[spark] -class InterruptibleRDD[T: ClassManifest](prev: RDD[T]) extends RDD[T](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override val partitioner = prev.partitioner - - override def compute(split: Partition, context: TaskContext) = { - new InterruptibleIterator(context, firstParent[T].iterator(split, context)) - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala index 3ed8339010..aea08ff81b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala @@ -21,14 +21,14 @@ import org.apache.spark.{Partition, TaskContext} /** - * A variant of the MapPartitionsRDD that passes the partition index into the - * closure. This can be used to generate or collect partition specific - * information such as the number of tuples in a partition. + * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the + * TaskContext, the closure can either get access to the interruptible flag or get the index + * of the partition in the RDD. */ private[spark] -class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest]( +class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], - f: (Int, Iterator[T]) => Iterator[U], + f: (TaskContext, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean ) extends RDD[U](prev) { @@ -37,5 +37,5 @@ class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest]( override val partitioner = if (preservesPartitioning) prev.partitioner else None override def compute(split: Partition, context: TaskContext) = - f(split.index, firstParent[T].iterator(split, context)) + f(context, firstParent[T].iterator(split, context)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index ee17794499..93b78e1232 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -84,21 +84,24 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) } val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (self.partitioner == Some(partitioner)) { - self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) - .interruptible() + self.mapPartitionsWithContext((context, iter) => { + new InterruptibleIterator(context, aggregator.combineValuesByKey(iter)) + }, preservesPartitioning = true) } else if (mapSideCombine) { val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner) .setSerializer(serializerClass) - partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true) - .interruptible() + partitioned.mapPartitionsWithContext((context, iter) => { + new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter)) + }, preservesPartitioning = true) } else { // Don't apply map-side combiner. // A sanity check to make sure mergeCombiners is not defined. assert(mergeCombiners == null) val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass) - values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) - .interruptible() + values.mapPartitionsWithContext((context, iter) => { + new InterruptibleIterator(context, aggregator.combineValuesByKey(iter)) + }, preservesPartitioning = true) } } @@ -567,7 +570,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" <split #> <attempt # = spark task #> */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber) + val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.partitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = outputFormatClass.newInstance val committer = format.getOutputCommitter(hadoopContext) @@ -666,7 +669,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt - writer.setup(context.stageId, context.splitId, attemptNumber) + writer.setup(context.stageId, context.partitionId, attemptNumber) writer.open() var count = 0 diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 4be506ba3a..0355618e43 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -418,26 +418,39 @@ abstract class RDD[T: ClassManifest]( command: Seq[String], env: Map[String, String] = Map(), printPipeContext: (String => Unit) => Unit = null, - printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = + printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = { new PipedRDD(this, command, env, if (printPipeContext ne null) sc.clean(printPipeContext) else null, if (printRDDElement ne null) sc.clean(printRDDElement) else null) + } /** * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = + def mapPartitions[U: ClassManifest]( + f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) + } /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. */ def mapPartitionsWithIndex[U: ClassManifest]( - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) + f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { + val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter) + new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning) + } + + /** + * Return a new RDD by applying a function to each partition of this RDD. This is a variant of + * mapPartitions that also passes the TaskContext into the closure. + */ + def mapPartitionsWithContext[U: ClassManifest]( + f: (TaskContext, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = { + new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning) + } /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index @@ -445,22 +458,23 @@ abstract class RDD[T: ClassManifest]( */ @deprecated("use mapPartitionsWithIndex", "0.7.0") def mapPartitionsWithSplit[U: ClassManifest]( - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) + f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { + mapPartitionsWithIndex(f, preservesPartitioning) + } /** * Maps f over this RDD, where f takes an additional parameter of type A. This * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) - (f:(T, A) => U): RDD[U] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { - val a = constructA(index) - iter.map(t => f(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) + def mapWith[A: ClassManifest, U: ClassManifest] + (constructA: Int => A, preservesPartitioning: Boolean = false) + (f: (T, A) => U): RDD[U] = { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { + val a = constructA(context.partitionId) + iter.map(t => f(t, a)) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) } /** @@ -468,13 +482,14 @@ abstract class RDD[T: ClassManifest]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) - (f:(T, A) => Seq[U]): RDD[U] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { - val a = constructA(index) - iter.flatMap(t => f(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) + def flatMapWith[A: ClassManifest, U: ClassManifest] + (constructA: Int => A, preservesPartitioning: Boolean = false) + (f: (T, A) => Seq[U]): RDD[U] = { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { + val a = constructA(context.partitionId) + iter.flatMap(t => f(t, a)) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) } /** @@ -482,13 +497,12 @@ abstract class RDD[T: ClassManifest]( * This additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def foreachWith[A: ClassManifest](constructA: Int => A) - (f:(T, A) => Unit) { - def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { - val a = constructA(index) - iter.map(t => {f(t, a); t}) - } - (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {}) + def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { + val a = constructA(context.partitionId) + iter.map(t => {f(t, a); t}) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {}) } /** @@ -496,13 +510,12 @@ abstract class RDD[T: ClassManifest]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def filterWith[A: ClassManifest](constructA: Int => A) - (p:(T, A) => Boolean): RDD[T] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { - val a = constructA(index) - iter.filter(t => p(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true) + def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { + val a = constructA(context.partitionId) + iter.filter(t => p(t, a)) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true) } /** @@ -541,16 +554,14 @@ abstract class RDD[T: ClassManifest]( * Applies a function f to all elements of this RDD. */ def foreach(f: T => Unit) { - val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) + sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f)) } /** * Applies a function f to each partition of this RDD. */ def foreachPartition(f: Iterator[T] => Unit) { - val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) + sc.runJob(this, (iter: Iterator[T]) => f(iter)) } /** @@ -862,11 +873,6 @@ abstract class RDD[T: ClassManifest]( map(x => (f(x), x)) } - /** - * Creates an interruptible version of this RDD. - */ - def interruptible(): RDD[T] = new InterruptibleRDD(this) - /** A private method for tests, to look at the contents of each partition */ private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c5b28b8286..2a8fbe8d09 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -377,7 +377,7 @@ class DAGScheduler( case JobCancelled(jobId) => // Cancel a job: find all the running stages that are linked to this job, and cancel them. - running.find(_.jobId == jobId).foreach { stage => + running.filter(_.jobId == jobId).foreach { stage => taskSched.cancelTasks(stage.id) } @@ -658,7 +658,7 @@ class DAGScheduler( if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) } else { - stage.addOutputLoc(smt.partition, status) + stage.addOutputLoc(smt.partitionId, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { markStageAsFinished(stage) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index c084059859..625c84f572 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -71,18 +71,31 @@ private[spark] object ResultTask { } +/** + * A task that sends back the output to the driver application. + * + * See [[org.apache.spark.scheduler.Task]] for more information. + * + * @param stageId id of the stage this task belongs to + * @param rdd input to func + * @param func a function to apply on a partition of the RDD + * @param _partitionId index of the number in the RDD + * @param locs preferred task execution locations for locality scheduling + * @param outputId index of the task in this job (a job can launch tasks on only a subset of the + * input RDD's partitions). + */ private[spark] class ResultTask[T, U]( stageId: Int, var rdd: RDD[T], var func: (TaskContext, Iterator[T]) => U, - _partition: Int, + _partitionId: Int, @transient locs: Seq[TaskLocation], var outputId: Int) - extends Task[U](stageId, _partition) with Externalizable { + extends Task[U](stageId, _partitionId) with Externalizable { def this() = this(0, null, null, 0, null, 0) - var split = if (rdd == null) null else rdd.partitions(partition) + var split = if (rdd == null) null else rdd.partitions(partitionId) @transient private val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq @@ -99,17 +112,17 @@ private[spark] class ResultTask[T, U]( override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ResultTask(" + stageId + ", " + partition + ")" + override def toString = "ResultTask(" + stageId + ", " + partitionId + ")" override def writeExternal(out: ObjectOutput) { RDDCheckpointData.synchronized { - split = rdd.partitions(partition) + split = rdd.partitions(partitionId) out.writeInt(stageId) val bytes = ResultTask.serializeInfo( stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) out.writeInt(bytes.length) out.write(bytes) - out.writeInt(partition) + out.writeInt(partitionId) out.writeInt(outputId) out.writeLong(epoch) out.writeObject(split) @@ -124,7 +137,7 @@ private[spark] class ResultTask[T, U]( val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) rdd = rdd_.asInstanceOf[RDD[T]] func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] - partition = in.readInt() + partitionId = in.readInt() outputId = in.readInt() epoch = in.readLong() split = in.readObject().asInstanceOf[Partition] diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 1904ee89c6..66c1eae703 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -85,13 +85,25 @@ private[spark] object ShuffleMapTask { } } +/** + * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner + * specified in the ShuffleDependency). + * + * See [[org.apache.spark.scheduler.Task]] for more information. + * + * @param stageId id of the stage this task belongs to + * @param rdd the final RDD in this stage + * @param dep the ShuffleDependency + * @param _partitionId index of the number in the RDD + * @param locs preferred task execution locations for locality scheduling + */ private[spark] class ShuffleMapTask( stageId: Int, var rdd: RDD[_], var dep: ShuffleDependency[_,_], - _partition: Int, + _partitionId: Int, @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, _partition) + extends Task[MapStatus](stageId, _partitionId) with Externalizable with Logging { @@ -101,16 +113,16 @@ private[spark] class ShuffleMapTask( if (locs == null) Nil else locs.toSet.toSeq } - var split = if (rdd == null) null else rdd.partitions(partition) + var split = if (rdd == null) null else rdd.partitions(partitionId) override def writeExternal(out: ObjectOutput) { RDDCheckpointData.synchronized { - split = rdd.partitions(partition) + split = rdd.partitions(partitionId) out.writeInt(stageId) val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) out.writeInt(bytes.length) out.write(bytes) - out.writeInt(partition) + out.writeInt(partitionId) out.writeLong(epoch) out.writeObject(split) } @@ -124,7 +136,7 @@ private[spark] class ShuffleMapTask( val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) rdd = rdd_ dep = dep_ - partition = in.readInt() + partitionId = in.readInt() epoch = in.readLong() split = in.readObject().asInstanceOf[Partition] } @@ -141,7 +153,7 @@ private[spark] class ShuffleMapTask( // Obtain all the block writers for shuffle blocks. val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser) - buckets = shuffle.acquireWriters(partition) + buckets = shuffle.acquireWriters(partitionId) // Write the map output to its associated buckets. for (elem <- rdd.iterator(split, context)) { @@ -185,5 +197,5 @@ private[spark] class ShuffleMapTask( override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) + override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 2c65d8211f..1fe0d0e4e2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -31,12 +31,22 @@ import org.apache.spark.util.ByteBufferInputStream /** - * A task to execute on a worker node. + * A unit of execution. We have two kinds of Task's in Spark: + * - [[org.apache.spark.scheduler.ShuffleMapTask]] + * - [[org.apache.spark.scheduler.ResultTask]] + * + * A Spark job consists of one or more stages. The very last stage in a job consists of multiple + * ResultTask's, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task + * and sends the task output back to the driver application. A ShuffleMapTask executes the task + * and divides the task output to multiple buckets (based on the task's partitioner). + * + * @param stageId id of the stage this task belongs to + * @param partitionId index of the number in the RDD */ -private[spark] abstract class Task[T](val stageId: Int, var partition: Int) extends Serializable { +private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { def run(attemptId: Long): T = { - context = new TaskContext(stageId, partition, attemptId, runningLocally = false) + context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) if (_killed) { kill() } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index d9103aebb7..70c1acca17 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -62,8 +62,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testCheckpointing(_.sample(false, 0.5, 0)) testCheckpointing(_.glom()) testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(r => new MapPartitionsWithIndexRDD(r, - (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false )) + testCheckpointing(r => new MapPartitionsWithContextRDD(r, + (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false )) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) testCheckpointing(_.pipe(Seq("cat"))) diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 53c225391c..a192651491 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -39,6 +39,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf override def afterEach() { super.afterEach() + resetSparkContext() System.clearProperty("spark.scheduler.mode") } @@ -49,7 +50,6 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf testTake() // Make sure we can still launch tasks. assert(sc.parallelize(1 to 10, 2).count === 10) - resetSparkContext() } test("local mode, fair scheduler") { @@ -61,7 +61,6 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf testTake() // Make sure we can still launch tasks. assert(sc.parallelize(1 to 10, 2).count === 10) - resetSparkContext() } test("cluster mode, FIFO scheduler") { @@ -71,7 +70,6 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf testTake() // Make sure we can still launch tasks. assert(sc.parallelize(1 to 10, 2).count === 10) - resetSparkContext() } test("cluster mode, fair scheduler") { @@ -83,7 +81,40 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf testTake() // Make sure we can still launch tasks. assert(sc.parallelize(1 to 10, 2).count === 10) - resetSparkContext() + } + + test("two jobs sharing the same stage") { + // sem1: make sure cancel is issued after some tasks are launched + // sem2: make sure the first stage is not finished until cancel is issued + val sem1 = new Semaphore(0) + val sem2 = new Semaphore(0) + + sc = new SparkContext("local[2]", "test") + sc.dagScheduler.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem1.release() + } + }) + + // Create two actions that would share the some stages. + val rdd = sc.parallelize(1 to 10, 2).map { i => + sem2.acquire() + (i, i) + }.reduceByKey(_+_) + val f1 = rdd.collectAsync() + val f2 = rdd.countAsync() + + // Kill one of the action. + future { + sem1.acquire() + f1.cancel() + sem2.release(10) + } + + // Expect both to fail now. + // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2. + intercept[SparkException] { f1.get() } + intercept[SparkException] { f2.get() } } def testCount() { diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index 3ef000da4a..da032b17d9 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -18,19 +18,18 @@ package org.apache.spark.rdd import java.util.concurrent.Semaphore -import java.util.concurrent.atomic.AtomicInteger -import scala.concurrent.future import scala.concurrent.ExecutionContext.Implicits.global import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ import org.apache.spark.{SparkContext, SparkException, LocalSparkContext} -import org.apache.spark.scheduler._ -class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll { +class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts { @transient private var sc: SparkContext = _ @@ -114,29 +113,29 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll { test("async success handling") { val f = sc.parallelize(1 to 10, 2).countAsync() - // This semaphore is used to make sure our final assert waits until onComplete / onSuccess - // finishes execution. + // Use a semaphore to make sure onSuccess and onComplete's success path will be called. + // If not, the test will hang. val sem = new Semaphore(0) - AsyncRDDActionsSuite.asyncSuccessHappened.set(0) f.onComplete { case scala.util.Success(res) => - AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet() sem.release() case scala.util.Failure(e) => + info("Should not have reached this code path (onComplete matching Failure)") throw new Exception("Task should succeed") - sem.release() } f.onSuccess { case a: Any => - AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet() sem.release() } f.onFailure { case t => + info("Should not have reached this code path (onFailure)") throw new Exception("Task should succeed") } assert(f.get() === 10) - sem.acquire(2) - assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2) + + failAfter(10 seconds) { + sem.acquire(2) + } } /** @@ -148,38 +147,30 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll { throw new Exception("intentional"); i }.countAsync() - // This semaphore is used to make sure our final assert waits until onComplete / onFailure - // finishes execution. + // Use a semaphore to make sure onFailure and onComplete's failure path will be called. + // If not, the test will hang. val sem = new Semaphore(0) - AsyncRDDActionsSuite.asyncFailureHappend.set(0) f.onComplete { case scala.util.Success(res) => + info("Should not have reached this code path (onComplete matching Success)") throw new Exception("Task should fail") - sem.release() case scala.util.Failure(e) => - AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet() sem.release() } f.onSuccess { case a: Any => + info("Should not have reached this code path (onSuccess)") throw new Exception("Task should fail") } f.onFailure { case t => - AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet() sem.release() } intercept[SparkException] { f.get() } - sem.acquire(2) - assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2) - } -} - -object AsyncRDDActionsSuite { - // Some counters used in the test cases above. - var asyncSuccessHappened = new AtomicInteger - var asyncFailureHappend = new AtomicInteger + failAfter(10 seconds) { + sem.acquire(2) + } + } } - |