author | Joseph E. Gonzalez <joseph.e.gonzalez@gmail.com> | 2013-10-15 16:15:19 -0700 |
---|---|---|
committer | Joseph E. Gonzalez <joseph.e.gonzalez@gmail.com> | 2013-10-15 16:15:19 -0700 |
commit | 1b22eef7449352d6097ae8025e6adbe767398344 (patch) | |
tree | fb667b566f4305143bdd9b22b97cf928c828c461 /core | |
parent | 194bb03d1637f731535f964e1d1661d218380162 (diff) | |
parent | 3249e0e90dd9a7b422f561c42407b6a2b3feab17 (diff) | |
Merge branch 'master' of https://github.com/apache/incubator-spark into indexedrdd_graphx
Diffstat (limited to 'core')
50 files changed, 1529 insertions, 516 deletions
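Among other changes, the patch below adds asynchronous RDD actions (`AsyncRDDActions`, `FutureAction`, `SparkContext.submitJob`) and job cancellation (`SparkContext.cancelAllJobs`). The following is a minimal usage sketch — not part of the patch — of how that API surface might be driven from user code; the context name, master URL, and data are illustrative assumptions.

```scala
// Hedged usage sketch of the asynchronous actions and cancellation support
// introduced by this patch. The SparkContext settings below are illustrative.
import org.apache.spark.{FutureAction, SparkContext}
import org.apache.spark.SparkContext._ // brings the rddToAsyncRDDActions implicit into scope

object AsyncActionsExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "async-actions-example")
    val rdd = sc.parallelize(1 to 1000000, 8)

    // countAsync() returns a FutureAction instead of blocking the caller.
    val countFuture: FutureAction[Long] = rdd.countAsync()

    // The caller may cancel the action (and its underlying job) at any point:
    // countFuture.cancel()

    // ...or block for the result, as with a regular action.
    println("count = " + countFuture.get())

    // cancelAllJobs() cancels every job scheduled or running on this context.
    sc.cancelAllJobs()
    sc.stop()
  }
}
```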
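The patch also threads `TaskContext` through the iterator code paths: `mapPartitionsWithContext` gives a closure access to `partitionId`, the `interrupted` flag, and completion callbacks. Below is a hedged sketch of how such a closure might use the context; everything in it is illustrative rather than taken from the diff.

```scala
// Hedged sketch of mapPartitionsWithContext: the closure receives the TaskContext,
// so it can read the partition id and stop early once the task is marked interrupted.
import org.apache.spark.{SparkContext, TaskContext}

object ContextAwarePartitionsExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "task-context-example")
    val rdd = sc.parallelize(1 to 100, 4)

    val tagged = rdd.mapPartitionsWithContext((context: TaskContext, iter: Iterator[Int]) => {
      // Stop producing elements if the task has been killed (interrupted flag set).
      iter.takeWhile(_ => !context.interrupted)
          .map(x => (context.partitionId, x))
    }, preservesPartitioning = true)

    tagged.take(5).foreach(println)
    sc.stop()
  }
}
```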
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala index f8af6b0fbe..d9ed572da6 100644 --- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala @@ -28,7 +28,11 @@ import org.apache.spark.util.CompletionIterator private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { - override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) + override def fetch[T]( + shuffleId: Int, + reduceId: Int, + context: TaskContext, + serializer: Serializer) : Iterator[T] = { @@ -73,7 +77,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer) val itr = blockFetcherItr.flatMap(unpackBlock) - CompletionIterator[T, Iterator[T]](itr, { + val completionIter = CompletionIterator[T, Iterator[T]](itr, { val shuffleMetrics = new ShuffleReadMetrics shuffleMetrics.shuffleFinishTime = System.currentTimeMillis shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime @@ -82,7 +86,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks - metrics.shuffleReadMetrics = Some(shuffleMetrics) + context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics) }) + + new InterruptibleIterator[T](context, completionIter) } } diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 221bb68c61..519ecde50a 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -38,7 +38,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { blockManager.get(key) match { case Some(values) => // Partition is already materialized, so just return its values - return values.asInstanceOf[Iterator[T]] + return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]]) case None => // Mark the split as loading (unless someone else marks it first) @@ -56,7 +56,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { // downside of the current code is that threads wait serially if this does happen. blockManager.get(key) match { case Some(values) => - return values.asInstanceOf[Iterator[T]] + return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]]) case None => logInfo("Whoever was loading %s failed; we'll try it ourselves".format(key)) loading.add(key) diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala new file mode 100644 index 0000000000..1ad9240cfa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.concurrent._ +import scala.concurrent.duration.Duration +import scala.util.Try + +import org.apache.spark.scheduler.{JobSucceeded, JobWaiter} +import org.apache.spark.scheduler.JobFailed +import org.apache.spark.rdd.RDD + + +/** + * A future for the result of an action. This is an extension of the Scala Future interface to + * support cancellation. + */ +trait FutureAction[T] extends Future[T] { + // Note that we redefine methods of the Future trait here explicitly so we can specify a different + // documentation (with reference to the word "action"). + + /** + * Cancels the execution of this action. + */ + def cancel() + + /** + * Blocks until this action completes. + * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf + * for unbounded waiting, or a finite positive duration + * @return this FutureAction + */ + override def ready(atMost: Duration)(implicit permit: CanAwait): FutureAction.this.type + + /** + * Awaits and returns the result (of type T) of this action. + * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf + * for unbounded waiting, or a finite positive duration + * @throws Exception exception during action execution + * @return the result value if the action is completed within the specific maximum wait time + */ + @throws(classOf[Exception]) + override def result(atMost: Duration)(implicit permit: CanAwait): T + + /** + * When this action is completed, either through an exception, or a value, applies the provided + * function. + */ + def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) + + /** + * Returns whether the action has already been completed with a value or an exception. + */ + override def isCompleted: Boolean + + /** + * The value of this Future. + * + * If the future is not completed the returned value will be None. If the future is completed + * the value will be Some(Success(t)) if it contains a valid result, or Some(Failure(error)) if + * it contains an exception. + */ + override def value: Option[Try[T]] + + /** + * Blocks and returns the result of this job. + */ + @throws(classOf[Exception]) + def get(): T = Await.result(this, Duration.Inf) +} + + +/** + * The future holding the result of an action that triggers a single job. Examples include + * count, collect, reduce. 
+ */ +class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T) + extends FutureAction[T] { + + override def cancel() { + jobWaiter.cancel() + } + + override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = { + if (!atMost.isFinite()) { + awaitResult() + } else { + val finishTime = System.currentTimeMillis() + atMost.toMillis + while (!isCompleted) { + val time = System.currentTimeMillis() + if (time >= finishTime) { + throw new TimeoutException + } else { + jobWaiter.wait(finishTime - time) + } + } + } + this + } + + @throws(classOf[Exception]) + override def result(atMost: Duration)(implicit permit: CanAwait): T = { + ready(atMost)(permit) + awaitResult() match { + case scala.util.Success(res) => res + case scala.util.Failure(e) => throw e + } + } + + override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) { + executor.execute(new Runnable { + override def run() { + func(awaitResult()) + } + }) + } + + override def isCompleted: Boolean = jobWaiter.jobFinished + + override def value: Option[Try[T]] = { + if (jobWaiter.jobFinished) { + Some(awaitResult()) + } else { + None + } + } + + private def awaitResult(): Try[T] = { + jobWaiter.awaitResult() match { + case JobSucceeded => scala.util.Success(resultFunc) + case JobFailed(e: Exception, _) => scala.util.Failure(e) + } + } +} + + +/** + * A FutureAction for actions that could trigger multiple Spark jobs. Examples include take, + * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the + * action thread if it is being blocked by a job. + */ +class ComplexFutureAction[T] extends FutureAction[T] { + + // Pointer to the thread that is executing the action. It is set when the action is run. + @volatile private var thread: Thread = _ + + // A flag indicating whether the future has been cancelled. This is used in case the future + // is cancelled before the action was even run (and thus we have no thread to interrupt). + @volatile private var _cancelled: Boolean = false + + // A promise used to signal the future. + private val p = promise[T]() + + override def cancel(): Unit = this.synchronized { + _cancelled = true + if (thread != null) { + thread.interrupt() + } + } + + /** + * Executes some action enclosed in the closure. To properly enable cancellation, the closure + * should use runJob implementation in this promise. See takeAsync for example. + */ + def run(func: => T)(implicit executor: ExecutionContext): this.type = { + scala.concurrent.future { + thread = Thread.currentThread + try { + p.success(func) + } catch { + case e: Exception => p.failure(e) + } finally { + thread = null + } + } + this + } + + /** + * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext + * to enable cancellation. + */ + def runJob[T, U, R]( + rdd: RDD[T], + processPartition: Iterator[T] => U, + partitions: Seq[Int], + resultHandler: (Int, U) => Unit, + resultFunc: => R) { + // If the action hasn't been cancelled yet, submit the job. The check and the submitJob + // command need to be in an atomic block. + val job = this.synchronized { + if (!cancelled) { + rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc) + } else { + throw new SparkException("Action has been cancelled") + } + } + + // Wait for the job to complete. If the action is cancelled (with an interrupt), + // cancel the job and stop the execution. 
This is not in a synchronized block because + // Await.ready eventually waits on the monitor in FutureJob.jobWaiter. + try { + Await.ready(job, Duration.Inf) + } catch { + case e: InterruptedException => + job.cancel() + throw new SparkException("Action has been cancelled") + } + } + + /** + * Returns whether the promise has been cancelled. + */ + def cancelled: Boolean = _cancelled + + @throws(classOf[InterruptedException]) + @throws(classOf[scala.concurrent.TimeoutException]) + override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = { + p.future.ready(atMost)(permit) + this + } + + @throws(classOf[Exception]) + override def result(atMost: Duration)(implicit permit: CanAwait): T = { + p.future.result(atMost)(permit) + } + + override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = { + p.future.onComplete(func)(executor) + } + + override def isCompleted: Boolean = p.isCompleted + + override def value: Option[Try[T]] = p.future.value +} diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala new file mode 100644 index 0000000000..56e0b8d2c0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * An iterator that wraps around an existing iterator to provide task killing functionality. + * It works by checking the interrupted flag in TaskContext. + */ +class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T]) + extends Iterator[T] { + + def hasNext: Boolean = !context.interrupted && delegate.hasNext + + def next(): T = delegate.next() +} diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala index 307c383a89..a85aa50a9b 100644 --- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala @@ -27,7 +27,10 @@ private[spark] abstract class ShuffleFetcher { * Fetch the shuffle outputs for a given ShuffleDependency. * @return An iterator over the elements of the fetched shuffle outputs. 
*/ - def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, + def fetch[T]( + shuffleId: Int, + reduceId: Int, + context: TaskContext, serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T] /** Stop the fetcher */ diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f3723a4f9d..7db6b6b8bc 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -755,10 +755,11 @@ class SparkContext( allowLocal: Boolean, resultHandler: (Int, U) => Unit) { val callSite = Utils.formatSparkCallSite + val cleanedFunc = clean(func) logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, - localProperties.get) + val result = dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, + resultHandler, localProperties.get) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() result @@ -848,6 +849,36 @@ class SparkContext( } /** + * Submit a job for execution and return a FutureJob holding the result. + */ + def submitJob[T, U, R]( + rdd: RDD[T], + processPartition: Iterator[T] => U, + partitions: Seq[Int], + resultHandler: (Int, U) => Unit, + resultFunc: => R): SimpleFutureAction[R] = + { + val cleanF = clean(processPartition) + val callSite = Utils.formatSparkCallSite + val waiter = dagScheduler.submitJob( + rdd, + (context: TaskContext, iter: Iterator[T]) => cleanF(iter), + partitions, + callSite, + allowLocal = false, + resultHandler, + null) + new SimpleFutureAction(waiter, resultFunc) + } + + /** + * Cancel all jobs that have been scheduled or are running. + */ + def cancelAllJobs() { + dagScheduler.cancelAllJobs() + } + + /** * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) */ @@ -930,6 +961,8 @@ object SparkContext { implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) = rdd.pairRDDFunctions + implicit def rddToAsyncRDDActions[T: ClassManifest](rdd: RDD[T]) = new AsyncRDDActions(rdd) + implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest]( rdd: RDD[(K, V)]) = new SequenceFileRDDFunctions(rdd) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index c2c358c7ad..51584d686d 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -17,21 +17,30 @@ package org.apache.spark -import executor.TaskMetrics import scala.collection.mutable.ArrayBuffer +import org.apache.spark.executor.TaskMetrics + class TaskContext( - val stageId: Int, - val splitId: Int, + private[spark] val stageId: Int, + val partitionId: Int, val attemptId: Long, val runningLocally: Boolean = false, - val taskMetrics: TaskMetrics = TaskMetrics.empty() + @volatile var interrupted: Boolean = false, + private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty() ) extends Serializable { - @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit] + @deprecated("use partitionId", "0.8.1") + def splitId = partitionId + + // List of callback functions to execute when the task completes. 
+ @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit] - // Add a callback function to be executed on task completion. An example use - // is for HadoopRDD to register a callback to close the input stream. + /** + * Add a callback function to be executed on task completion. An example use + * is for HadoopRDD to register a callback to close the input stream. + * @param f Callback function. + */ def addOnCompleteCallback(f: () => Unit) { onCompleteCallbacks += f } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 8466c2a004..c1e5e04b31 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -52,4 +52,6 @@ private[spark] case class ExceptionFailure( */ private[spark] case object TaskResultLost extends TaskEndReason +private[spark] case object TaskKilled extends TaskEndReason + private[spark] case class OtherFailure(message: String) extends TaskEndReason diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index eff0c0f274..20323ea038 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -36,7 +36,8 @@ import org.apache.spark.util.Utils private[spark] class Executor( executorId: String, slaveHostname: String, - properties: Seq[(String, String)]) + properties: Seq[(String, String)], + isLocal: Boolean = false) extends Logging { // Application dependencies (added through SparkContext) that we've fetched so far on this node. @@ -101,18 +102,47 @@ private[spark] class Executor( val executorSource = new ExecutorSource(this, executorId) // Initialize Spark environment (using system properties read above) - val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false) - SparkEnv.set(env) - env.metricsSystem.registerSource(executorSource) + private val env = { + if (!isLocal) { + val _env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, + isDriver = false, isLocal = false) + SparkEnv.set(_env) + _env.metricsSystem.registerSource(executorSource) + _env + } else { + SparkEnv.get + } + } - private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size") + // Akka's message frame size. If task result is bigger than this, we use the block manager + // to send the result back. + private val akkaFrameSize = { + env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size") + } // Start worker thread pool val threadPool = new ThreadPoolExecutor( - 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) + 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable], Utils.daemonThreadFactory) + + // Maintains the list of running tasks. + private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) { - threadPool.execute(new TaskRunner(context, taskId, serializedTask)) + val tr = new TaskRunner(context, taskId, serializedTask) + runningTasks.put(taskId, tr) + threadPool.execute(tr) + } + + def killTask(taskId: Long) { + val tr = runningTasks.get(taskId) + if (tr != null) { + tr.kill() + // We remove the task also in the finally block in TaskRunner.run. 
+ // The reason we need to remove it here is because killTask might be called before the task + // is even launched, and never reaching that finally block. ConcurrentHashMap's remove is + // idempotent. + runningTasks.remove(taskId) + } } /** Get the Yarn approved local directories. */ @@ -124,49 +154,80 @@ private[spark] class Executor( .getOrElse(Option(System.getenv("LOCAL_DIRS")) .getOrElse("")) - if (localDirs.isEmpty()) { + if (localDirs.isEmpty) { throw new Exception("Yarn Local dirs can't be empty") } - return localDirs + localDirs } - class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) + class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) extends Runnable { + @volatile private var killed = false + @volatile private var task: Task[Any] = _ + + def kill() { + logInfo("Executor is trying to kill task " + taskId) + killed = true + if (task != null) { + task.kill() + } + } + override def run() { val startTime = System.currentTimeMillis() SparkEnv.set(env) Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo("Running task ID " + taskId) - context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) + execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) var attemptedTask: Option[Task[Any]] = None var taskStart: Long = 0 - def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum - val startGCTime = getTotalGCTime + def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum + val startGCTime = gcTime try { SparkEnv.set(env) Accumulators.clear() val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) updateDependencies(taskFiles, taskJars) - val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + + // If this task has been killed before we deserialized it, let's quit now. Otherwise, + // continue executing the task. + if (killed) { + logInfo("Executor killed task " + taskId) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + return + } + attemptedTask = Some(task) - logInfo("Its epoch is " + task.epoch) + logDebug("Task " + taskId +"'s epoch is " + task.epoch) env.mapOutputTracker.updateEpoch(task.epoch) + + // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() val value = task.run(taskId.toInt) val taskFinish = System.currentTimeMillis() + + // If the task has been killed, let's fail it. + if (task.killed) { + logInfo("Executor killed task " + taskId) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + return + } + for (m <- task.metrics) { - m.hostname = Utils.localHostName + m.hostname = Utils.localHostName() m.executorDeserializeTime = (taskStart - startTime).toInt m.executorRunTime = (taskFinish - taskStart).toInt - m.jvmGCTime = getTotalGCTime - startGCTime + m.jvmGCTime = gcTime - startGCTime } - //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c - // we need to serialize the task metrics first. 
If TaskMetrics had a custom serialized format, we could - // just change the relevants bytes in the byte buffer + // TODO I'd also like to track the time it takes to serialize the task results, but that is + // huge headache, b/c we need to serialize the task metrics first. If TaskMetrics had a + // custom serialized format, we could just change the relevants bytes in the byte buffer val accumUpdates = Accumulators.values + val directResult = new DirectTaskResult(value, accumUpdates, task.metrics.getOrElse(null)) val serializedDirectResult = ser.serialize(directResult) logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit) @@ -182,12 +243,13 @@ private[spark] class Executor( serializedDirectResult } } - context.statusUpdate(taskId, TaskState.FINISHED, serializedResult) + + execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) logInfo("Finished task ID " + taskId) } catch { case ffe: FetchFailedException => { val reason = ffe.toTaskEndReason - context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) } case t: Throwable => { @@ -195,10 +257,10 @@ private[spark] class Executor( val metrics = attemptedTask.flatMap(t => t.metrics) for (m <- metrics) { m.executorRunTime = serviceTime - m.jvmGCTime = getTotalGCTime - startGCTime + m.jvmGCTime = gcTime - startGCTime } val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics) - context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) // TODO: Should we exit the whole executor here? On the one hand, the failed task may // have left some weird state around depending on when the exception was thrown, but on @@ -206,6 +268,8 @@ private[spark] class Executor( logError("Exception in task ID " + taskId, t) //System.exit(1) } + } finally { + runningTasks.remove(taskId) } } } @@ -215,7 +279,7 @@ private[spark] class Executor( * created by the interpreter to the search path */ private def createClassLoader(): ExecutorURLClassLoader = { - var loader = this.getClass.getClassLoader + val loader = this.getClass.getClassLoader // For each of the jars in the jarSet, add them to the class loader. // We assume each of the files has already been fetched. 
@@ -237,7 +301,7 @@ private[spark] class Executor( val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader]) - return constructor.newInstance(classUri, parent) + constructor.newInstance(classUri, parent) } catch { case _: ClassNotFoundException => logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") @@ -245,7 +309,7 @@ private[spark] class Executor( null } } else { - return parent + parent } } diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index da62091980..b56d8c9912 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -18,14 +18,18 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver} -import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _} -import org.apache.spark.TaskState.TaskState + import com.google.protobuf.ByteString -import org.apache.spark.{Logging} + +import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver} +import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} + +import org.apache.spark.Logging import org.apache.spark.TaskState +import org.apache.spark.TaskState.TaskState import org.apache.spark.util.Utils + private[spark] class MesosExecutorBackend extends MesosExecutor with ExecutorBackend @@ -71,7 +75,11 @@ private[spark] class MesosExecutorBackend } override def killTask(d: ExecutorDriver, t: TaskID) { - logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)") + if (executor == null) { + logError("Received KillTask but executor was null") + } else { + executor.killTask(t.getValue.toLong) + } } override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {} diff --git a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala index 7839023868..db0bea0472 100644 --- a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala @@ -63,12 +63,20 @@ private[spark] class StandaloneExecutorBackend( case LaunchTask(taskDesc) => logInfo("Got assigned task " + taskDesc.taskId) if (executor == null) { - logError("Received launchTask but executor was null") + logError("Received LaunchTask command but executor was null") System.exit(1) } else { executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) } + case KillTask(taskId, _) => + if (executor == null) { + logError("Received KillTask command but executor was null") + System.exit(1) + } else { + executor.killTask(taskId) + } + case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => logError("Driver terminated or disconnected! 
Shutting down.") System.exit(1) diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala new file mode 100644 index 0000000000..faaf837be0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.util.concurrent.atomic.AtomicLong + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.ExecutionContext.Implicits.global + +import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} + +/** + * A set of asynchronous RDD actions available through an implicit conversion. + * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. + */ +class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with Logging { + + /** + * Returns a future for counting the number of elements in the RDD. + */ + def countAsync(): FutureAction[Long] = { + val totalCount = new AtomicLong + self.context.submitJob( + self, + (iter: Iterator[T]) => { + var result = 0L + while (iter.hasNext) { + result += 1L + iter.next() + } + result + }, + Range(0, self.partitions.size), + (index: Int, data: Long) => totalCount.addAndGet(data), + totalCount.get()) + } + + /** + * Returns a future for retrieving all elements of this RDD. + */ + def collectAsync(): FutureAction[Seq[T]] = { + val results = new Array[Array[T]](self.partitions.size) + self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size), + (index, data) => results(index) = data, results.flatten.toSeq) + } + + /** + * Returns a future for retrieving the first num elements of the RDD. + */ + def takeAsync(num: Int): FutureAction[Seq[T]] = { + val f = new ComplexFutureAction[Seq[T]] + + f.run { + val results = new ArrayBuffer[T](num) + val totalParts = self.partitions.length + var partsScanned = 0 + while (results.size < num && partsScanned < totalParts) { + // The number of partitions to try in this iteration. It is ok for this number to be + // greater than totalParts because we actually cap it at totalParts in runJob. + var numPartsToTry = 1 + if (partsScanned > 0) { + // If we didn't find any rows after the first iteration, just try all partitions next. + // Otherwise, interpolate the number of partitions we need to try, but overestimate it + // by 50%. 
+ if (results.size == 0) { + numPartsToTry = totalParts - 1 + } else { + numPartsToTry = (1.5 * num * partsScanned / results.size).toInt + } + } + numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions + + val left = num - results.size + val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + + val buf = new Array[Array[T]](p.size) + f.runJob(self, + (it: Iterator[T]) => it.take(left).toArray, + p, + (index: Int, data: Array[T]) => buf(index) = data, + Unit) + + buf.foreach(results ++= _.take(num - results.size)) + partsScanned += numPartsToTry + } + results.toSeq + } + + f + } + + /** + * Applies a function f to all elements of this RDD. + */ + def foreachAsync(f: T => Unit): FutureAction[Unit] = { + self.context.submitJob[T, Unit, Unit](self, _.foreach(f), Range(0, self.partitions.size), + (index, data) => Unit, Unit) + } + + /** + * Applies a function f to each partition of this RDD. + */ + def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { + self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size), + (index, data) => Unit, Unit) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 3311757189..ccaaecb85b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -85,7 +85,7 @@ private[spark] object CheckpointRDD extends Logging { val outputDir = new Path(path) val fs = outputDir.getFileSystem(env.hadoop.newConfiguration()) - val finalOutputName = splitIdToFile(ctx.splitId) + val finalOutputName = splitIdToFile(ctx.partitionId) val finalOutputPath = new Path(outputDir, finalOutputName) val tempOutputPath = new Path(outputDir, "." 
+ finalOutputName + "-attempt-" + ctx.attemptId) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index d237797aa6..911a002884 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -21,7 +21,7 @@ import java.io.{ObjectOutputStream, IOException} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext} +import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} import org.apache.spark.util.AppendOnlyMap @@ -125,12 +125,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: case ShuffleCoGroupSplitDep(shuffleId) => { // Read map outputs of shuffle val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach { + fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach { kv => getSeq(kv._1)(depNum) += kv._2 } } } - map.iterator + new InterruptibleIterator(context, map.iterator) } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 2d394abfd9..fad042c7ae 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -27,8 +27,7 @@ import org.apache.hadoop.mapred.RecordReader import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.util.ReflectionUtils -import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv, - TaskContext} +import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.util.NextIterator import org.apache.hadoop.conf.{Configuration, Configurable} @@ -39,7 +38,7 @@ import org.apache.hadoop.conf.{Configuration, Configurable} */ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit) extends Partition { - + val inputSplit = new SerializableWritable[InputSplit](s) override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt @@ -144,38 +143,41 @@ class HadoopRDD[K, V]( array } - override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] { - val split = theSplit.asInstanceOf[HadoopPartition] - logInfo("Input split: " + split.inputSplit) - var reader: RecordReader[K, V] = null - - val jobConf = getJobConf() - val inputFormat = getInputFormat(jobConf) - reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) - - // Register an on-task-completion callback to close the input stream. 
- context.addOnCompleteCallback{ () => closeIfNeeded() } - - val key: K = reader.createKey() - val value: V = reader.createValue() - - override def getNext() = { - try { - finished = !reader.next(key, value) - } catch { - case eof: EOFException => - finished = true + override def compute(theSplit: Partition, context: TaskContext) = { + val iter = new NextIterator[(K, V)] { + val split = theSplit.asInstanceOf[HadoopPartition] + logInfo("Input split: " + split.inputSplit) + var reader: RecordReader[K, V] = null + + val jobConf = getJobConf() + val inputFormat = getInputFormat(jobConf) + reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + + // Register an on-task-completion callback to close the input stream. + context.addOnCompleteCallback{ () => closeIfNeeded() } + + val key: K = reader.createKey() + val value: V = reader.createValue() + + override def getNext() = { + try { + finished = !reader.next(key, value) + } catch { + case eof: EOFException => + finished = true + } + (key, value) } - (key, value) - } - override def close() { - try { - reader.close() - } catch { - case e: Exception => logWarning("Exception in RecordReader.close()", e) + override def close() { + try { + reader.close() + } catch { + case e: Exception => logWarning("Exception in RecordReader.close()", e) + } } } + new InterruptibleIterator[(K, V)](context, iter) } override def getPreferredLocations(split: Partition): Seq[String] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala index 3ed8339010..aea08ff81b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala @@ -21,14 +21,14 @@ import org.apache.spark.{Partition, TaskContext} /** - * A variant of the MapPartitionsRDD that passes the partition index into the - * closure. This can be used to generate or collect partition specific - * information such as the number of tuples in a partition. + * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the + * TaskContext, the closure can either get access to the interruptible flag or get the index + * of the partition in the RDD. 
*/ private[spark] -class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest]( +class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], - f: (Int, Iterator[T]) => Iterator[U], + f: (TaskContext, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean ) extends RDD[U](prev) { @@ -37,5 +37,5 @@ class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest]( override val partitioner = if (preservesPartitioning) prev.partitioner else None override def compute(split: Partition, context: TaskContext) = - f(split.index, firstParent[T].iterator(split, context)) + f(context, firstParent[T].iterator(split, context)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7b3a89f7e0..2662d48c84 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ -import org.apache.spark.{Dependency, Logging, Partition, SerializableWritable, SparkContext, TaskContext} +import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext} private[spark] @@ -71,49 +71,52 @@ class NewHadoopRDD[K, V]( result } - override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] { - val split = theSplit.asInstanceOf[NewHadoopPartition] - logInfo("Input split: " + split.serializableHadoopSplit) - val conf = confBroadcast.value.value - val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) - val format = inputFormatClass.newInstance - if (format.isInstanceOf[Configurable]) { - format.asInstanceOf[Configurable].setConf(conf) - } - val reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - - // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => close()) - - var havePair = false - var finished = false - - override def hasNext: Boolean = { - if (!finished && !havePair) { - finished = !reader.nextKeyValue - havePair = !finished + override def compute(theSplit: Partition, context: TaskContext) = { + val iter = new Iterator[(K, V)] { + val split = theSplit.asInstanceOf[NewHadoopPartition] + logInfo("Input split: " + split.serializableHadoopSplit) + val conf = confBroadcast.value.value + val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val format = inputFormatClass.newInstance + if (format.isInstanceOf[Configurable]) { + format.asInstanceOf[Configurable].setConf(conf) + } + val reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + + // Register an on-task-completion callback to close the input stream. 
+ context.addOnCompleteCallback(() => close()) + + var havePair = false + var finished = false + + override def hasNext: Boolean = { + if (!finished && !havePair) { + finished = !reader.nextKeyValue + havePair = !finished + } + !finished } - !finished - } - override def next: (K, V) = { - if (!hasNext) { - throw new java.util.NoSuchElementException("End of stream") + override def next(): (K, V) = { + if (!hasNext) { + throw new java.util.NoSuchElementException("End of stream") + } + havePair = false + (reader.getCurrentKey, reader.getCurrentValue) } - havePair = false - return (reader.getCurrentKey, reader.getCurrentValue) - } - private def close() { - try { - reader.close() - } catch { - case e: Exception => logWarning("Exception in RecordReader.close()", e) + private def close() { + try { + reader.close() + } catch { + case e: Exception => logWarning("Exception in RecordReader.close()", e) + } } } + new InterruptibleIterator(context, iter) } override def getPreferredLocations(split: Partition): Seq[String] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 569d74ae7a..7fadbcf4ec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -84,18 +84,24 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) } val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (self.partitioner == Some(partitioner)) { - self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) + self.mapPartitionsWithContext((context, iter) => { + new InterruptibleIterator(context, aggregator.combineValuesByKey(iter)) + }, preservesPartitioning = true) } else if (mapSideCombine) { val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner) .setSerializer(serializerClass) - partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true) + partitioned.mapPartitionsWithContext((context, iter) => { + new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter)) + }, preservesPartitioning = true) } else { // Don't apply map-side combiner. // A sanity check to make sure mergeCombiners is not defined. assert(mergeCombiners == null) val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass) - values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) + values.mapPartitionsWithContext((context, iter) => { + new InterruptibleIterator(context, aggregator.combineValuesByKey(iter)) + }, preservesPartitioning = true) } } @@ -575,7 +581,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. 
val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" <split #> <attempt # = spark task #> */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber) + val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.partitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = outputFormatClass.newInstance val committer = format.getOutputCommitter(hadoopContext) @@ -674,7 +680,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt - writer.setup(context.stageId, context.splitId, attemptNumber) + writer.setup(context.stageId, context.partitionId, attemptNumber) writer.open() var count = 0 diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 6dbd4309aa..cd96250389 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -94,8 +94,9 @@ private[spark] class ParallelCollectionRDD[T: ClassManifest]( slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray } - override def compute(s: Partition, context: TaskContext) = - s.asInstanceOf[ParallelCollectionPartition[T]].iterator + override def compute(s: Partition, context: TaskContext) = { + new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator) + } override def getPreferredLocations(s: Partition): Seq[String] = { locationPrefs.getOrElse(s.index, Nil) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 16eaffb0fc..d14b4c60c7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -418,26 +418,39 @@ abstract class RDD[T: ClassManifest]( command: Seq[String], env: Map[String, String] = Map(), printPipeContext: (String => Unit) => Unit = null, - printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = + printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = { new PipedRDD(this, command, env, if (printPipeContext ne null) sc.clean(printPipeContext) else null, if (printRDDElement ne null) sc.clean(printRDDElement) else null) + } /** * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = + def mapPartitions[U: ClassManifest]( + f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) + } /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. 
*/ def mapPartitionsWithIndex[U: ClassManifest]( - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) + f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { + val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter) + new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning) + } + + /** + * Return a new RDD by applying a function to each partition of this RDD. This is a variant of + * mapPartitions that also passes the TaskContext into the closure. + */ + def mapPartitionsWithContext[U: ClassManifest]( + f: (TaskContext, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = { + new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning) + } /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index @@ -445,22 +458,23 @@ abstract class RDD[T: ClassManifest]( */ @deprecated("use mapPartitionsWithIndex", "0.7.0") def mapPartitionsWithSplit[U: ClassManifest]( - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) + f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { + mapPartitionsWithIndex(f, preservesPartitioning) + } /** * Maps f over this RDD, where f takes an additional parameter of type A. This * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) - (f:(T, A) => U): RDD[U] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { - val a = constructA(index) - iter.map(t => f(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) + def mapWith[A: ClassManifest, U: ClassManifest] + (constructA: Int => A, preservesPartitioning: Boolean = false) + (f: (T, A) => U): RDD[U] = { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { + val a = constructA(context.partitionId) + iter.map(t => f(t, a)) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) } /** @@ -468,13 +482,14 @@ abstract class RDD[T: ClassManifest]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) - (f:(T, A) => Seq[U]): RDD[U] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { - val a = constructA(index) - iter.flatMap(t => f(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) + def flatMapWith[A: ClassManifest, U: ClassManifest] + (constructA: Int => A, preservesPartitioning: Boolean = false) + (f: (T, A) => Seq[U]): RDD[U] = { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { + val a = constructA(context.partitionId) + iter.flatMap(t => f(t, a)) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) } /** @@ -482,13 +497,12 @@ abstract class RDD[T: ClassManifest]( * This additional parameter is produced by constructA, which is called in each * partition with the index of that partition. 
*/ - def foreachWith[A: ClassManifest](constructA: Int => A) - (f:(T, A) => Unit) { - def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { - val a = constructA(index) - iter.map(t => {f(t, a); t}) - } - (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {}) + def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { + val a = constructA(context.partitionId) + iter.map(t => {f(t, a); t}) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {}) } /** @@ -496,13 +510,12 @@ abstract class RDD[T: ClassManifest]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def filterWith[A: ClassManifest](constructA: Int => A) - (p:(T, A) => Boolean): RDD[T] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { - val a = constructA(index) - iter.filter(t => p(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true) + def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { + val a = constructA(context.partitionId) + iter.filter(t => p(t, a)) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true) } /** @@ -541,16 +554,14 @@ abstract class RDD[T: ClassManifest]( * Applies a function f to all elements of this RDD. */ def foreach(f: T => Unit) { - val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) + sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f)) } /** * Applies a function f to each partition of this RDD. */ def foreachPartition(f: Iterator[T] => Unit) { - val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) + sc.runJob(this, (iter: Iterator[T]) => f(iter)) } /** @@ -675,6 +686,8 @@ abstract class RDD[T: ClassManifest]( */ def count(): Long = { sc.runJob(this, (iter: Iterator[T]) => { + // Use a while loop to count the number of elements rather than iter.size because + // iter.size uses a for loop, which is slightly slower in current version of Scala. 
var result = 0L while (iter.hasNext) { result += 1L diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 9537152335..a5d751a7bd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -56,7 +56,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest]( override def compute(split: Partition, context: TaskContext): Iterator[P] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics, + SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, SparkEnv.get.serializerManager.get(serializerClass)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 8c1a29dfff..7af4d803e7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -108,7 +108,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM } case ShuffleCoGroupSplitDep(shuffleId) => { val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index, - context.taskMetrics, serializer) + context, serializer) iter.foreach(op) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5c51852985..7fb614402b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -41,11 +41,11 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH * locations to run each task on, based on the current cache status, and passes these to the * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are - * not caused by shuffie file loss are handled by the TaskScheduler, which will retry each task + * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * * THREADING: This class runs all its logic in a single thread executing the run() method, to which - * events are submitted using a synchonized queue (eventQueue). The public API methods, such as + * events are submitted using a synchronized queue (eventQueue). The public API methods, such as * runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods * should be private. */ @@ -88,7 +88,8 @@ class DAGScheduler( eventQueue.put(ExecutorGained(execId, host)) } - // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. + // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or + // cancellation of the job itself. 
override def taskSetFailed(taskSet: TaskSet, reason: String) { eventQueue.put(TaskSetFailed(taskSet, reason)) } @@ -104,13 +105,15 @@ class DAGScheduler( private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] - val nextJobId = new AtomicInteger(0) + private[scheduler] val nextJobId = new AtomicInteger(0) - val nextStageId = new AtomicInteger(0) + def numTotalJobs: Int = nextJobId.get() - val stageIdToStage = new TimeStampedHashMap[Int, Stage] + private val nextStageId = new AtomicInteger(0) - val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] + private val stageIdToStage = new TimeStampedHashMap[Int, Stage] + + private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] @@ -127,6 +130,7 @@ class DAGScheduler( // stray messages to detect. val failedEpoch = new HashMap[String, Long] + // stage id to the active job val idToActiveJob = new HashMap[Int, ActiveJob] val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done @@ -261,32 +265,41 @@ class DAGScheduler( } /** - * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a - * JobWaiter whose getResult() method will return the result of the job when it is complete. - * - * The job is assumed to have at least one partition; zero partition jobs should be handled - * without a JobSubmitted event. + * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object + * can be used to block until the the job finishes executing or can be used to cancel the job. */ - private[scheduler] def prepareJob[T, U: ClassManifest]( - finalRdd: RDD[T], + def submitJob[T, U]( + rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: String, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null) - : (JobSubmitted, JobWaiter[U]) = + properties: Properties = null): JobWaiter[U] = { + val jobId = nextJobId.getAndIncrement() + if (partitions.size == 0) { + return new JobWaiter[U](this, jobId, 0, resultHandler) + } + + // Check to make sure we are not launching a task on a partition that does not exist. + val maxPartitions = rdd.partitions.length + partitions.find(p => p >= maxPartitions).foreach { p => + throw new IllegalArgumentException( + "Attempting to access a non-existent partition: " + p + ". " + + "Total number of partitions: " + maxPartitions) + } + assert(partitions.size > 0) - val waiter = new JobWaiter(partitions.size, resultHandler) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, - properties) - (toSubmit, waiter) + val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) + eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, + waiter, properties)) + waiter } def runJob[T, U: ClassManifest]( - finalRdd: RDD[T], + rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: String, @@ -294,21 +307,7 @@ class DAGScheduler( resultHandler: (Int, U) => Unit, properties: Properties = null) { - if (partitions.size == 0) { - return - } - - // Check to make sure we are not launching a task on a partition that does not exist. - val maxPartitions = finalRdd.partitions.length - partitions.find(p => p >= maxPartitions).foreach { p => - throw new IllegalArgumentException( - "Attempting to access a non-existent partition: " + p + ". 
" + - "Total number of partitions: " + maxPartitions) - } - - val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob( - finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties) - eventQueue.put(toSubmit) + val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { case JobSucceeded => {} case JobFailed(exception: Exception, _) => @@ -329,19 +328,35 @@ class DAGScheduler( val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray - eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties)) + val jobId = nextJobId.getAndIncrement() + eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, + listener, properties)) listener.awaitResult() // Will throw an exception if the job fails } /** + * Cancel a job that is running or waiting in the queue. + */ + def cancelJob(jobId: Int) { + logInfo("Asked to cancel job " + jobId) + eventQueue.put(JobCancelled(jobId)) + } + + /** + * Cancel all jobs that are running or waiting in the queue. + */ + def cancelAllJobs() { + eventQueue.put(AllJobsCancelled) + } + + /** * Process one event retrieved from the event queue. * Returns true if we should stop the event loop. */ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { event match { - case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) => - val jobId = nextJobId.getAndIncrement() - val finalStage = newStage(finalRDD, None, jobId, Some(callSite)) + case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => + val finalStage = newStage(rdd, None, jobId, Some(callSite)) val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length + @@ -360,6 +375,18 @@ class DAGScheduler( submitStage(finalStage) } + case JobCancelled(jobId) => + // Cancel a job: find all the running stages that are linked to this job, and cancel them. + running.filter(_.jobId == jobId).foreach { stage => + taskSched.cancelTasks(stage.id) + } + + case AllJobsCancelled => + // Cancel all running jobs. + running.foreach { stage => + taskSched.cancelTasks(stage.id) + } + case ExecutorGained(execId, host) => handleExecutorGained(execId, host) @@ -578,6 +605,11 @@ class DAGScheduler( */ private def handleTaskCompletion(event: CompletionEvent) { val task = event.task + + if (!stageIdToStage.contains(task.stageId)) { + // Skip all the actions if the stage has been cancelled. + return + } val stage = stageIdToStage(task.stageId) def markStageAsFinished(stage: Stage) = { @@ -626,7 +658,7 @@ class DAGScheduler( if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) } else { - stage.addOutputLoc(smt.partition, status) + stage.addOutputLoc(smt.partitionId, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { markStageAsFinished(stage) @@ -752,14 +784,14 @@ class DAGScheduler( /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set - * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. + * being canceled by the TaskScheduler. 
Use taskSetFailed() to inject this event from outside. */ private def abortStage(failedStage: Stage, reason: String) { val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq failedStage.completionTime = Some(System.currentTimeMillis()) for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) - val error = new SparkException("Job failed: " + reason) + val error = new SparkException("Job aborted: " + reason) job.listener.jobFailed(error) listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage)))) idToActiveJob -= resultStage.jobId diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 10ff1b4376..ee89bfb38d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -31,9 +31,10 @@ import org.apache.spark.executor.TaskMetrics * submitted) but there is a single "logic" thread that reads these events and takes decisions. * This greatly simplifies synchronization. */ -private[spark] sealed trait DAGSchedulerEvent +private[scheduler] sealed trait DAGSchedulerEvent -private[spark] case class JobSubmitted( +private[scheduler] case class JobSubmitted( + jobId: Int, finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], @@ -43,9 +44,14 @@ private[spark] case class JobSubmitted( properties: Properties = null) extends DAGSchedulerEvent -private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent +private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent -private[spark] case class CompletionEvent( +private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent + +private[scheduler] +case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent + +private[scheduler] case class CompletionEvent( task: Task[_], reason: TaskEndReason, result: Any, @@ -54,10 +60,12 @@ private[spark] case class CompletionEvent( taskMetrics: TaskMetrics) extends DAGSchedulerEvent -private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent +private[scheduler] +case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent -private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent +private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent -private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +private[scheduler] +case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent -private[spark] case object StopDAGScheduler extends DAGSchedulerEvent +private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 151514896f..7b5c0e29ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -40,7 +40,7 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar }) metricRegistry.register(MetricRegistry.name("job", "allJobs"), new Gauge[Int] { - override def getValue: Int = dagScheduler.nextJobId.get() + override def getValue: Int = 
dagScheduler.numTotalJobs }) metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] { diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 200d881799..58f238d8cf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -17,48 +17,58 @@ package org.apache.spark.scheduler -import scala.collection.mutable.ArrayBuffer - /** * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their * results to the given handler function. */ -private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) +private[spark] class JobWaiter[T]( + dagScheduler: DAGScheduler, + jobId: Int, + totalTasks: Int, + resultHandler: (Int, T) => Unit) extends JobListener { private var finishedTasks = 0 - private var jobFinished = false // Is the job as a whole finished (succeeded or failed)? - private var jobResult: JobResult = null // If the job is finished, this will be its result + // Is the job as a whole finished (succeeded or failed)? + private var _jobFinished = totalTasks == 0 - override def taskSucceeded(index: Int, result: Any) { - synchronized { - if (jobFinished) { - throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") - } - resultHandler(index, result.asInstanceOf[T]) - finishedTasks += 1 - if (finishedTasks == totalTasks) { - jobFinished = true - jobResult = JobSucceeded - this.notifyAll() - } - } + def jobFinished = _jobFinished + + // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero + // partition RDDs), we set the jobResult directly to JobSucceeded. + private var jobResult: JobResult = if (jobFinished) JobSucceeded else null + + /** + * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled + * asynchronously. After the low level scheduler cancels all the tasks belonging to this job, it + * will fail this job with a SparkException. + */ + def cancel() { + dagScheduler.cancelJob(jobId) } - override def jobFailed(exception: Exception) { - synchronized { - if (jobFinished) { - throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter") - } - jobFinished = true - jobResult = JobFailed(exception, None) + override def taskSucceeded(index: Int, result: Any): Unit = synchronized { + if (_jobFinished) { + throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") + } + resultHandler(index, result.asInstanceOf[T]) + finishedTasks += 1 + if (finishedTasks == totalTasks) { + _jobFinished = true + jobResult = JobSucceeded this.notifyAll() } } + override def jobFailed(exception: Exception): Unit = synchronized { + _jobFinished = true + jobResult = JobFailed(exception, None) + this.notifyAll() + } + def awaitResult(): JobResult = synchronized { - while (!jobFinished) { + while (!_jobFinished) { this.wait() } return jobResult diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 9eb8d48501..596f9adde9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -43,7 +43,10 @@ private[spark] class Pool( var runningTasks = 0 var priority = 0 - var stageId = 0 + + // A pool's stage id is used to break the tie in scheduling. 
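// Minimal sketch, illustrative only, of the submit/cancel flow added above: submitJob returns a
// JobWaiter right away, JobWaiter.cancel() asks the DAGScheduler to cancel the job, and
// awaitResult() then reports the failure once the tasks have been killed. Assumes a DAGScheduler
// instance named dagScheduler and an RDD[Int] named rdd.
val waiter = dagScheduler.submitJob(
  rdd,
  (context: TaskContext, iter: Iterator[Int]) => iter.sum,
  0 until rdd.partitions.size,
  callSite = "example",
  allowLocal = false,
  resultHandler = (index: Int, partialSum: Int) => println("partition " + index + " -> " + partialSum))
waiter.cancel()
waiter.awaitResult() match {
  case JobSucceeded => println("job finished before the cancellation took effect")
  case JobFailed(exception, _) => println("job cancelled: " + exception.getMessage)
}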
+ var stageId = -1 + var name = poolName var parent: Pool = null diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 6dd422bbf6..310ec62ca8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -38,17 +38,17 @@ private[spark] object ResultTask { synchronized { val old = serializedInfoCache.get(stageId).orNull if (old != null) { - return old + old } else { val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance + val ser = SparkEnv.get.closureSerializer.newInstance() val objOut = ser.serializeStream(new GZIPOutputStream(out)) objOut.writeObject(rdd) objOut.writeObject(func) objOut.close() val bytes = out.toByteArray serializedInfoCache.put(stageId, bytes) - return bytes + bytes } } } @@ -56,11 +56,11 @@ private[spark] object ResultTask { def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = { val loader = Thread.currentThread.getContextClassLoader val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance + val ser = SparkEnv.get.closureSerializer.newInstance() val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] - return (rdd, func) + (rdd, func) } def clearCache() { @@ -71,29 +71,37 @@ private[spark] object ResultTask { } +/** + * A task that sends back the output to the driver application. + * + * See [[org.apache.spark.scheduler.Task]] for more information. + * + * @param stageId id of the stage this task belongs to + * @param rdd input to func + * @param func a function to apply on a partition of the RDD + * @param _partitionId index of the number in the RDD + * @param locs preferred task execution locations for locality scheduling + * @param outputId index of the task in this job (a job can launch tasks on only a subset of the + * input RDD's partitions). 
+ */ private[spark] class ResultTask[T, U]( stageId: Int, var rdd: RDD[T], var func: (TaskContext, Iterator[T]) => U, - var partition: Int, + _partitionId: Int, @transient locs: Seq[TaskLocation], var outputId: Int) - extends Task[U](stageId) with Externalizable { + extends Task[U](stageId, _partitionId) with Externalizable { def this() = this(0, null, null, 0, null, 0) - var split = if (rdd == null) { - null - } else { - rdd.partitions(partition) - } + var split = if (rdd == null) null else rdd.partitions(partitionId) @transient private val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } - override def run(attemptId: Long): U = { - val context = new TaskContext(stageId, partition, attemptId, runningLocally = false) + override def runTask(context: TaskContext): U = { metrics = Some(context.taskMetrics) try { func(context, rdd.iterator(split, context)) @@ -104,17 +112,17 @@ private[spark] class ResultTask[T, U]( override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ResultTask(" + stageId + ", " + partition + ")" + override def toString = "ResultTask(" + stageId + ", " + partitionId + ")" override def writeExternal(out: ObjectOutput) { RDDCheckpointData.synchronized { - split = rdd.partitions(partition) + split = rdd.partitions(partitionId) out.writeInt(stageId) val bytes = ResultTask.serializeInfo( stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) out.writeInt(bytes.length) out.write(bytes) - out.writeInt(partition) + out.writeInt(partitionId) out.writeInt(outputId) out.writeLong(epoch) out.writeObject(split) @@ -129,7 +137,7 @@ private[spark] class ResultTask[T, U]( val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) rdd = rdd_.asInstanceOf[RDD[T]] func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] - partition = in.readInt() + partitionId = in.readInt() outputId = in.readInt() epoch = in.readLong() split = in.readObject().asInstanceOf[Partition] diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 4e25086ec9..356fe56bf3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -30,7 +30,10 @@ import scala.xml.XML * addTaskSetManager: build the leaf nodes(TaskSetManagers) */ private[spark] trait SchedulableBuilder { + def rootPool: Pool + def buildPools() + def addTaskSetManager(manager: Schedulable, properties: Properties) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 3b9d5679fb..802791797a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -53,7 +53,7 @@ private[spark] object ShuffleMapTask { objOut.close() val bytes = out.toByteArray serializedInfoCache.put(stageId, bytes) - return bytes + bytes } } } @@ -66,7 +66,7 @@ private[spark] object ShuffleMapTask { val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] - return (rdd, dep) + (rdd, dep) } } @@ -75,7 +75,7 @@ private[spark] object ShuffleMapTask { val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val objIn = new ObjectInputStream(in) val set = 
objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap - return (HashMap(set.toSeq: _*)) + HashMap(set.toSeq: _*) } def clearCache() { @@ -85,13 +85,25 @@ private[spark] object ShuffleMapTask { } } +/** + * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner + * specified in the ShuffleDependency). + * + * See [[org.apache.spark.scheduler.Task]] for more information. + * + * @param stageId id of the stage this task belongs to + * @param rdd the final RDD in this stage + * @param dep the ShuffleDependency + * @param _partitionId index of the number in the RDD + * @param locs preferred task execution locations for locality scheduling + */ private[spark] class ShuffleMapTask( stageId: Int, var rdd: RDD[_], var dep: ShuffleDependency[_,_], - var partition: Int, + _partitionId: Int, @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId) + extends Task[MapStatus](stageId, _partitionId) with Externalizable with Logging { @@ -101,16 +113,16 @@ private[spark] class ShuffleMapTask( if (locs == null) Nil else locs.toSet.toSeq } - var split = if (rdd == null) null else rdd.partitions(partition) + var split = if (rdd == null) null else rdd.partitions(partitionId) override def writeExternal(out: ObjectOutput) { RDDCheckpointData.synchronized { - split = rdd.partitions(partition) + split = rdd.partitions(partitionId) out.writeInt(stageId) val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) out.writeInt(bytes.length) out.write(bytes) - out.writeInt(partition) + out.writeInt(partitionId) out.writeLong(epoch) out.writeObject(split) } @@ -124,16 +136,14 @@ private[spark] class ShuffleMapTask( val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) rdd = rdd_ dep = dep_ - partition = in.readInt() + partitionId = in.readInt() epoch = in.readLong() split = in.readObject().asInstanceOf[Partition] } - override def run(attemptId: Long): MapStatus = { + override def runTask(context: TaskContext): MapStatus = { val numOutputSplits = dep.partitioner.numPartitions - - val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false) - metrics = Some(taskContext.taskMetrics) + metrics = Some(context.taskMetrics) val blockManager = SparkEnv.get.blockManager var shuffle: ShuffleBlocks = null @@ -143,10 +153,10 @@ private[spark] class ShuffleMapTask( // Obtain all the block writers for shuffle blocks. val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser) - buckets = shuffle.acquireWriters(partition) + buckets = shuffle.acquireWriters(partitionId) // Write the map output to its associated buckets. - for (elem <- rdd.iterator(split, taskContext)) { + for (elem <- rdd.iterator(split, context)) { val pair = elem.asInstanceOf[Product2[Any, Any]] val bucketId = dep.partitioner.getPartition(pair._1) buckets.writers(bucketId).write(pair) @@ -167,7 +177,7 @@ private[spark] class ShuffleMapTask( shuffleMetrics.shuffleBytesWritten = totalBytes metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) - return new MapStatus(blockManager.blockManagerId, compressedSizes) + new MapStatus(blockManager.blockManagerId, compressedSizes) } catch { case e: Exception => // If there is an exception from running the task, revert the partial writes // and throw the exception upstream to Spark. @@ -181,11 +191,11 @@ private[spark] class ShuffleMapTask( shuffle.releaseWriters(buckets) } // Execute the callbacks on task completion. 
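// Illustrative sketch of the completion hook used just below; addOnCompleteCallback is assumed
// to be the TaskContext method that registers such callbacks (it is not shown in this patch).
// Code that opens a resource inside a task can register a cleanup that runs when
// executeOnCompleteCallbacks() fires, whether the task succeeds or fails.
def firstLine(context: TaskContext, path: String): String = {
  val source = scala.io.Source.fromFile(path)
  context.addOnCompleteCallback(() => source.close())
  source.getLines().next()
}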
- taskContext.executeOnCompleteCallbacks() + context.executeOnCompleteCallbacks() } } override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) + override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 62b521ad45..466baf9913 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -54,7 +54,7 @@ trait SparkListener { /** * Called when a task starts */ - def onTaskStart(taskEnd: SparkListenerTaskStart) { } + def onTaskStart(taskStart: SparkListenerTaskStart) { } /** * Called when a task ends diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 598d91752a..1fe0d0e4e2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -17,25 +17,74 @@ package org.apache.spark.scheduler -import org.apache.spark.serializer.SerializerInstance import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import org.apache.spark.util.ByteBufferInputStream + import scala.collection.mutable.HashMap + +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream + +import org.apache.spark.TaskContext import org.apache.spark.executor.TaskMetrics +import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.util.ByteBufferInputStream + /** - * A task to execute on a worker node. + * A unit of execution. We have two kinds of Task's in Spark: + * - [[org.apache.spark.scheduler.ShuffleMapTask]] + * - [[org.apache.spark.scheduler.ResultTask]] + * + * A Spark job consists of one or more stages. The very last stage in a job consists of multiple + * ResultTask's, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task + * and sends the task output back to the driver application. A ShuffleMapTask executes the task + * and divides the task output to multiple buckets (based on the task's partitioner). + * + * @param stageId id of the stage this task belongs to + * @param partitionId index of the number in the RDD */ -private[spark] abstract class Task[T](val stageId: Int) extends Serializable { - def run(attemptId: Long): T +private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { + + def run(attemptId: Long): T = { + context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) + if (_killed) { + kill() + } + runTask(context) + } + + def runTask(context: TaskContext): T + def preferredLocations: Seq[TaskLocation] = Nil - var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler. + // Map output tracker epoch. Will be set by TaskScheduler. + var epoch: Long = -1 var metrics: Option[TaskMetrics] = None + // Task context, to be initialized in run(). + @transient protected var context: TaskContext = _ + + // A flag to indicate whether the task is killed. This is used in case context is not yet + // initialized when kill() is invoked. + @volatile @transient private var _killed = false + + /** + * Whether the task has been killed. 
+ */ + def killed: Boolean = _killed + + /** + * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark + * code and user code to properly handle the flag. This function should be idempotent so it can + * be called multiple times. + */ + def kill() { + _killed = true + if (context != null) { + context.interrupted = true + } + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 7c2a9f03d7..6a51efe8d6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -45,6 +45,9 @@ private[spark] trait TaskScheduler { // Submit a sequence of tasks to run. def submitTasks(taskSet: TaskSet): Unit + // Cancel a stage. + def cancelTasks(stageId: Int) + // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called. def setListener(listener: TaskSchedulerListener): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index c3ad325156..03bf760837 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -31,5 +31,9 @@ private[spark] class TaskSet( val properties: Properties) { val id: String = stageId + "." + attempt + def kill() { + tasks.foreach(_.kill()) + } + override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index 1a844b7e7e..7a72ff0474 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -17,7 +17,6 @@ package org.apache.spark.scheduler.cluster -import java.lang.{Boolean => JBoolean} import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong import java.util.{TimerTask, Timer} @@ -79,12 +78,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) private val executorIdToHost = new HashMap[String, String] - // JAR server, if any JARs were added by the user to the SparkContext - var jarServer: HttpServer = null - - // URIs of JARs to pass to executor - var jarUris: String = "" - // Listener object to pass upcalls into var listener: TaskSchedulerListener = null @@ -171,8 +164,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext) backend.reviveOffers() } - def taskSetFinished(manager: TaskSetManager) { - this.synchronized { + override def cancelTasks(stageId: Int): Unit = synchronized { + logInfo("Cancelling stage " + stageId) + activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the task and then abort + // the stage. + // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // simply abort the stage. 
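// Cooperative cancellation sketch, illustrative only: kill() merely sets context.interrupted,
// so a task stops at the next point that checks the flag (the new InterruptibleIterator is
// meant to do this between records); long-running user code can check it explicitly as well.
def countUntilKilled(context: TaskContext, iter: Iterator[Int]): Long = {
  var n = 0L
  while (iter.hasNext) {
    if (context.interrupted) {
      throw new InterruptedException("task for partition " + context.partitionId + " was killed")
    }
    iter.next()
    n += 1L
  }
  n
}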
+ val taskIds = taskSetTaskIds(tsm.taskSet.id) + if (taskIds.size > 0) { + taskIds.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId) + } + } + tsm.error("Stage %d was cancelled".format(stageId)) + } + } + + def taskSetFinished(manager: TaskSetManager): Unit = synchronized { + // Check to see if the given task set has been removed. This is possible in the case of + // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has + // more than one running tasks). + if (activeTaskSets.contains(manager.taskSet.id)) { activeTaskSets -= manager.taskSet.id manager.parent.removeSchedulable(manager) logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) @@ -334,9 +350,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) if (backend != null) { backend.stop() } - if (jarServer != null) { - jarServer.stop() - } if (taskResultGetter != null) { taskResultGetter.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 936167c13f..7bd3499300 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -17,18 +17,16 @@ package org.apache.spark.scheduler.cluster -import java.nio.ByteBuffer -import java.util.{Arrays, NoSuchElementException} +import java.util.Arrays import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet import scala.math.max import scala.math.min -import scala.Some import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv, - SparkException, Success, TaskEndReason, TaskResultLost, TaskState} + Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler._ import org.apache.spark.util.{SystemClock, Clock} @@ -458,54 +456,57 @@ private[spark] class ClusterTaskSetManager( val index = info.index info.markFailed() if (!successful(index)) { - logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) + logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) copiesRunning(index) -= 1 // Check if the problem is a map output fetch failure. In that case, this // task will never succeed on any node, so tell the scheduler about it. 
reason.foreach { - _ match { - case fetchFailed: FetchFailed => - logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) - successful(index) = true - tasksSuccessful += 1 - sched.taskSetFinished(this) - removeAllRunningTasks() - return - - case ef: ExceptionFailure => - sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) - val key = ef.description - val now = clock.getTime() - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { + case fetchFailed: FetchFailed => + logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) + sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) + successful(index) = true + tasksSuccessful += 1 + sched.taskSetFinished(this) + removeAllRunningTasks() + return + + case TaskKilled => + logWarning("Task %d was killed.".format(tid)) + sched.listener.taskEnded(tasks(index), reason.get, null, null, info, null) + return + + case ef: ExceptionFailure => + sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) + val key = ef.description + val now = clock.getTime() + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { recentExceptions(key) = (0, now) (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) } - } - if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + recentExceptions(key) = (0, now) + (true, 0) } + } + if (printFull) { + val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logWarning("Loss was due to %s\n%s\n%s".format( + ef.className, ef.description, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + } - case TaskResultLost => - logInfo("Lost result for TID %s on host %s".format(tid, info.host)) - sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null) + case TaskResultLost => + logWarning("Lost result for TID %s on host %s".format(tid, info.host)) + sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null) - case _ => {} - } + case _ => {} } // On non-fetch failures, re-enqueue the task as pending for a max number of retries addPendingTask(index) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala index d57eb3276f..5367218faa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.{SparkContext} +import org.apache.spark.SparkContext /** * A backend interface for cluster scheduling systems that allows plugging in different ones under @@ -30,8 +30,8 @@ private[spark] 
trait SchedulerBackend { def reviveOffers(): Unit def defaultParallelism(): Int + def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException + // Memory used by each executor (in megabytes) protected val executorMemory: Int = SparkContext.executorMemoryRequested - - // TODO: Probably want to add a killTask too } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala index c0b836bf1a..12b2fd01c0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -31,6 +31,8 @@ private[spark] object StandaloneClusterMessages { // Driver to executors case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage + case class KillTask(taskId: Long, executor: String) extends StandaloneClusterMessage + case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) extends StandaloneClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index f3aeea43d5..08ee2182a2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -91,6 +91,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor case ReviveOffers => makeOffers() + case KillTask(taskId, executorId) => + executorActor(executorId) ! KillTask(taskId, executorId) + case StopDriver => sender ! true context.stop(self) @@ -180,6 +183,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor driverActor ! ReviveOffers } + override def killTask(taskId: Long, executorId: String) { + driverActor ! KillTask(taskId, executorId) + } + override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism")) .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index 4d1bb1c639..b445260d1b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -17,23 +17,19 @@ package org.apache.spark.scheduler.local -import java.io.File -import java.lang.management.ManagementFactory -import java.util.concurrent.atomic.AtomicInteger import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicInteger -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} + +import akka.actor._ import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.ExecutorURLClassLoader +import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode -import akka.actor._ -import org.apache.spark.util.Utils + /** * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. 
Optionally @@ -41,43 +37,50 @@ import org.apache.spark.util.Utils * testing fault recovery. */ -private[spark] +private[local] case class LocalReviveOffers() -private[spark] +private[local] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) +private[local] +case class KillTask(taskId: Long) + private[spark] -class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging { +class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int) + extends Actor with Logging { + + val executor = new Executor("localhost", "localhost", Seq.empty, isLocal = true) def receive = { case LocalReviveOffers => launchTask(localScheduler.resourceOffer(freeCores)) + case LocalStatusUpdate(taskId, state, serializeData) => - freeCores += 1 - localScheduler.statusUpdate(taskId, state, serializeData) - launchTask(localScheduler.resourceOffer(freeCores)) + if (TaskState.isFinished(state)) { + freeCores += 1 + launchTask(localScheduler.resourceOffer(freeCores)) + } + + case KillTask(taskId) => + executor.killTask(taskId) } - def launchTask(tasks : Seq[TaskDescription]) { + private def launchTask(tasks: Seq[TaskDescription]) { for (task <- tasks) { freeCores -= 1 - localScheduler.threadPool.submit(new Runnable { - def run() { - localScheduler.runTask(task.taskId, task.serializedTask) - } - }) + executor.launchTask(localScheduler, task.taskId, task.serializedTask) } } } private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) extends TaskScheduler + with ExecutorBackend with Logging { - var attemptId = new AtomicInteger(0) - var threadPool = Utils.newDaemonFixedThreadPool(threads) val env = SparkEnv.get + val attemptId = new AtomicInteger var listener: TaskSchedulerListener = null // Application dependencies (added through SparkContext) that we've fetched so far on this node. @@ -85,8 +88,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() val currentJars: HashMap[String, Long] = new HashMap[String, Long]() - val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader) - var schedulableBuilder: SchedulableBuilder = null var rootPool: Pool = null val schedulingMode: SchedulingMode = SchedulingMode.withName( @@ -127,6 +128,26 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } } + override def cancelTasks(stageId: Int): Unit = synchronized { + logInfo("Cancelling stage " + stageId) + logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId)) + activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the task and then abort + // the stage. + // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // simply abort the stage. + val taskIds = taskSetTaskIds(tsm.taskSet.id) + if (taskIds.size > 0) { + taskIds.foreach { tid => + localActor ! 
KillTask(tid) + } + } + tsm.error("Stage %d was cancelled".format(stageId)) + } + } + def resourceOffer(freeCores: Int): Seq[TaskDescription] = { synchronized { var freeCpuCores = freeCores @@ -166,107 +187,32 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } } - def runTask(taskId: Long, bytes: ByteBuffer) { - logInfo("Running " + taskId) - val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) - // Set the Spark execution environment for the worker thread - SparkEnv.set(env) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objectSer = SparkEnv.get.serializer.newInstance() - var attemptedTask: Option[Task[_]] = None - val start = System.currentTimeMillis() - var taskStart: Long = 0 - def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum - val startGCTime = getTotalGCTime - - try { - Accumulators.clear() - Thread.currentThread().setContextClassLoader(classLoader) - - // Serialize and deserialize the task so that accumulators are changed to thread-local ones; - // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) - updateDependencies(taskFiles, taskJars) // Download any files added with addFile - val deserializedTask = ser.deserialize[Task[_]]( - taskBytes, Thread.currentThread.getContextClassLoader) - attemptedTask = Some(deserializedTask) - val deserTime = System.currentTimeMillis() - start - taskStart = System.currentTimeMillis() - - // Run it - val result: Any = deserializedTask.run(taskId) - - // Serialize and deserialize the result to emulate what the Mesos - // executor does. This is useful to catch serialization errors early - // on in development (so when users move their local Spark programs - // to the cluster, they don't get surprised by serialization errors). - val serResult = objectSer.serialize(result) - deserializedTask.metrics.get.resultSize = serResult.limit() - val resultToReturn = objectSer.deserialize[Any](serResult) - val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( - ser.serialize(Accumulators.values)) - val serviceTime = System.currentTimeMillis() - taskStart - logInfo("Finished " + taskId) - deserializedTask.metrics.get.executorRunTime = serviceTime.toInt - deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime - deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt - val taskResult = new DirectTaskResult( - result, accumUpdates, deserializedTask.metrics.getOrElse(null)) - val serializedResult = ser.serialize(taskResult) - localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult) - } catch { - case t: Throwable => { - val serviceTime = System.currentTimeMillis() - taskStart - val metrics = attemptedTask.flatMap(t => t.metrics) - for (m <- metrics) { - m.executorRunTime = serviceTime.toInt - m.jvmGCTime = getTotalGCTime - startGCTime - } - val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics) - localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure)) - } - } - } - - /** - * Download any missing dependencies if we receive a new set of files and JARs from the - * SparkContext. Also adds any new JARs we fetched to the class loader. 
- */ - private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - synchronized { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentFiles(name) = timestamp - } - - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!classLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - classLoader.addURL(url) + override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { + if (TaskState.isFinished(state)) { + synchronized { + taskIdToTaskSetId.get(taskId) match { + case Some(taskSetId) => + val taskSetManager = activeTaskSets(taskSetId) + taskSetTaskIds(taskSetId) -= taskId + + state match { + case TaskState.FINISHED => + taskSetManager.taskEnded(taskId, state, serializedData) + case TaskState.FAILED => + taskSetManager.taskFailed(taskId, state, serializedData) + case TaskState.KILLED => + taskSetManager.error("Task %d was killed".format(taskId)) + case _ => {} + } + case None => + logInfo("Ignoring update from TID " + taskId + " because its task set is gone") } } + localActor ! LocalStatusUpdate(taskId, state, serializedData) } } - def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) { - synchronized { - val taskSetId = taskIdToTaskSetId(taskId) - val taskSetManager = activeTaskSets(taskSetId) - taskSetTaskIds(taskSetId) -= taskId - taskSetManager.statusUpdate(taskId, state, serializedData) - } - } - - override def stop() { - threadPool.shutdownNow() + override def stop() { } override def defaultParallelism() = threads diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala index c2e2399ccb..f72e77d40f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala @@ -132,17 +132,6 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas return None } - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - SparkEnv.set(env) - state match { - case TaskState.FINISHED => - taskEnded(tid, state, serializedData) - case TaskState.FAILED => - taskFailed(tid, state, serializedData) - case _ => {} - } - } - def taskStarted(task: Task[_], info: TaskInfo) { sched.listener.taskStarted(task, info) } @@ -195,5 +184,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas } override def error(message: String) { + sched.listener.taskSetFailed(taskSet, message) + sched.taskSetFinished(this) } } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index ced036c58d..ea936e815b 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -58,7 +58,8 @@ class CacheManagerSuite extends FunSuite with 
BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = false, null) + val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false, + taskMetrics = null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } @@ -70,7 +71,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = false, null) + val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false, + taskMetrics = null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -83,7 +85,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = true, null) + val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false, + taskMetrics = null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 7ca5f16202..f26c44d3e7 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -62,8 +62,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testCheckpointing(_.sample(false, 0.5, 0)) testCheckpointing(_.glom()) testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(r => new MapPartitionsWithIndexRDD(r, - (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false )) + testCheckpointing(r => new MapPartitionsWithContextRDD(r, + (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false )) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) testCheckpointing(_.pipe(Seq("cat"))) diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java index 591c1d498d..7b0bb89ab2 100644 --- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java @@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable { @Test public void iterator() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContext(0, 0, 0, false, null); + TaskContext context = new TaskContext(0, 0, 0, false, false, null); Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala new file mode 100644 index 0000000000..a192651491 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.concurrent.Semaphore + +import scala.concurrent.future +import scala.concurrent.ExecutionContext.Implicits.global + +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.matchers.ShouldMatchers + +import org.apache.spark.SparkContext._ +import org.apache.spark.scheduler.{SparkListenerTaskStart, SparkListener} + + +/** + * Test suite for cancelling running jobs. We run the cancellation tasks for single job action + * (e.g. count) as well as multi-job action (e.g. take). We test the local and cluster schedulers + * in both FIFO and fair scheduling modes. + */ +class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAfter + with LocalSparkContext { + + override def afterEach() { + super.afterEach() + resetSparkContext() + System.clearProperty("spark.scheduler.mode") + } + + test("local mode, FIFO scheduler") { + System.setProperty("spark.scheduler.mode", "FIFO") + sc = new SparkContext("local[2]", "test") + testCount() + testTake() + // Make sure we can still launch tasks. + assert(sc.parallelize(1 to 10, 2).count === 10) + } + + test("local mode, fair scheduler") { + System.setProperty("spark.scheduler.mode", "FAIR") + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.scheduler.allocation.file", xmlPath) + sc = new SparkContext("local[2]", "test") + testCount() + testTake() + // Make sure we can still launch tasks. + assert(sc.parallelize(1 to 10, 2).count === 10) + } + + test("cluster mode, FIFO scheduler") { + System.setProperty("spark.scheduler.mode", "FIFO") + sc = new SparkContext("local-cluster[2,1,512]", "test") + testCount() + testTake() + // Make sure we can still launch tasks. + assert(sc.parallelize(1 to 10, 2).count === 10) + } + + test("cluster mode, fair scheduler") { + System.setProperty("spark.scheduler.mode", "FAIR") + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.scheduler.allocation.file", xmlPath) + sc = new SparkContext("local-cluster[2,1,512]", "test") + testCount() + testTake() + // Make sure we can still launch tasks. + assert(sc.parallelize(1 to 10, 2).count === 10) + } + + test("two jobs sharing the same stage") { + // sem1: make sure cancel is issued after some tasks are launched + // sem2: make sure the first stage is not finished until cancel is issued + val sem1 = new Semaphore(0) + val sem2 = new Semaphore(0) + + sc = new SparkContext("local[2]", "test") + sc.dagScheduler.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem1.release() + } + }) + + // Create two actions that would share the some stages. + val rdd = sc.parallelize(1 to 10, 2).map { i => + sem2.acquire() + (i, i) + }.reduceByKey(_+_) + val f1 = rdd.collectAsync() + val f2 = rdd.countAsync() + + // Kill one of the action. 
+ future { + sem1.acquire() + f1.cancel() + sem2.release(10) + } + + // Expect both to fail now. + // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2. + intercept[SparkException] { f1.get() } + intercept[SparkException] { f2.get() } + } + + def testCount() { + // Cancel before launching any tasks + { + val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync() + future { f.cancel() } + val e = intercept[SparkException] { f.get() } + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + } + + // Cancel after some tasks have been launched + { + // Add a listener to release the semaphore once any tasks are launched. + val sem = new Semaphore(0) + sc.dagScheduler.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + + val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync() + future { + // Wait until some tasks were launched before we cancel the job. + sem.acquire() + f.cancel() + } + val e = intercept[SparkException] { f.get() } + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + } + } + + def testTake() { + // Cancel before launching any tasks + { + val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000) + future { f.cancel() } + val e = intercept[SparkException] { f.get() } + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + } + + // Cancel after some tasks have been launched + { + // Add a listener to release the semaphore once any tasks are launched. + val sem = new Semaphore(0) + sc.dagScheduler.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000) + future { + sem.acquire() + f.cancel() + } + val e = intercept[SparkException] { f.get() } + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala new file mode 100644 index 0000000000..da032b17d9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import java.util.concurrent.Semaphore + +import scala.concurrent.ExecutionContext.Implicits.global + +import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkContext._ +import org.apache.spark.{SparkContext, SparkException, LocalSparkContext} + + +class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts { + + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local[2]", "test") + } + + override def afterAll() { + LocalSparkContext.stop(sc) + sc = null + } + + lazy val zeroPartRdd = new EmptyRDD[Int](sc) + + test("countAsync") { + assert(zeroPartRdd.countAsync().get() === 0) + assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000) + } + + test("collectAsync") { + assert(zeroPartRdd.collectAsync().get() === Seq.empty) + + val collected = sc.parallelize(1 to 1000, 3).collectAsync().get() + assert(collected === (1 to 1000)) + } + + test("foreachAsync") { + zeroPartRdd.foreachAsync(i => Unit).get() + + val accum = sc.accumulator(0) + sc.parallelize(1 to 1000, 3).foreachAsync { i => + accum += 1 + }.get() + assert(accum.value === 1000) + } + + test("foreachPartitionAsync") { + zeroPartRdd.foreachPartitionAsync(iter => Unit).get() + + val accum = sc.accumulator(0) + sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter => + accum += 1 + }.get() + assert(accum.value === 9) + } + + test("takeAsync") { + def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) { + val expected = input.take(num) + val saw = rdd.takeAsync(num).get() + assert(saw == expected, "incorrect result for rdd with %d partitions (expected %s, saw %s)" + .format(rdd.partitions.size, expected, saw)) + } + val input = Range(1, 1000) + + var rdd = sc.parallelize(input, 1) + for (num <- Seq(0, 1, 999, 1000)) { + testTake(rdd, input, num) + } + + rdd = sc.parallelize(input, 2) + for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) { + testTake(rdd, input, num) + } + + rdd = sc.parallelize(input, 100) + for (num <- Seq(0, 1, 500, 501, 999, 1000)) { + testTake(rdd, input, num) + } + + rdd = sc.parallelize(input, 1000) + for (num <- Seq(0, 1, 3, 999, 1000)) { + testTake(rdd, input, num) + } + } + + /** + * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case + * of a successful job execution. + */ + test("async success handling") { + val f = sc.parallelize(1 to 10, 2).countAsync() + + // Use a semaphore to make sure onSuccess and onComplete's success path will be called. + // If not, the test will hang. + val sem = new Semaphore(0) + + f.onComplete { + case scala.util.Success(res) => + sem.release() + case scala.util.Failure(e) => + info("Should not have reached this code path (onComplete matching Failure)") + throw new Exception("Task should succeed") + } + f.onSuccess { case a: Any => + sem.release() + } + f.onFailure { case t => + info("Should not have reached this code path (onFailure)") + throw new Exception("Task should succeed") + } + assert(f.get() === 10) + + failAfter(10 seconds) { + sem.acquire(2) + } + } + + /** + * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case + * of a failed job execution. + */ + test("async failure handling") { + val f = sc.parallelize(1 to 10, 2).map { i => + throw new Exception("intentional"); i + }.countAsync() + + // Use a semaphore to make sure onFailure and onComplete's failure path will be called. 
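// Each of the two failure-path callbacks releases one permit and the failAfter block
// acquires two, so both callbacks must fire for the test to pass.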
+ // If not, the test will hang. + val sem = new Semaphore(0) + + f.onComplete { + case scala.util.Success(res) => + info("Should not have reached this code path (onComplete matching Success)") + throw new Exception("Task should fail") + case scala.util.Failure(e) => + sem.release() + } + f.onSuccess { case a: Any => + info("Should not have reached this code path (onSuccess)") + throw new Exception("Task should fail") + } + f.onFailure { case t => + sem.release() + } + intercept[SparkException] { + f.get() + } + + failAfter(10 seconds) { + sem.acquire(2) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 31f97fc139..57d3382ed0 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -106,7 +106,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } } visit(sums) - assert(deps.size === 2) // ShuffledRDD, ParallelCollection + assert(deps.size === 2) // ShuffledRDD, ParallelCollection. } test("join") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 3952ee9264..838179c6b5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -24,15 +24,14 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.LocalSparkContext import org.apache.spark.MapOutputTracker -import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext import org.apache.spark.Partition import org.apache.spark.TaskContext import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency} import org.apache.spark.{FetchFailed, Success, TaskEndReason} -import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} - +import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} /** * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler @@ -60,6 +59,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) taskSets += taskSet } + override def cancelTasks(stageId: Int) {} override def setListener(listener: TaskSchedulerListener) = {} override def defaultParallelism() = 2 } @@ -181,7 +181,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, allowLocal: Boolean = false, listener: JobListener = listener) { - runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener)) + val jobId = scheduler.nextJobId.getAndIncrement() + runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, null, listener)) } /** Sends TaskSetFailed to the scheduler. 
*/ @@ -215,7 +216,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont override def getPreferredLocations(split: Partition) = Nil override def toString = "DAGSchedulerSuite Local RDD" } - runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener)) + val jobId = scheduler.nextJobId.getAndIncrement() + runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener)) assert(results === Map(0 -> 42)) } @@ -242,7 +244,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont test("trivial job failure") { submit(makeRdd(1, Nil), Array(0)) failed(taskSets(0), "some failure") - assert(failure.getMessage === "Job failed: some failure") + assert(failure.getMessage === "Job aborted: some failure") } test("run trivial shuffle") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala index 2f12aaed18..0f01515179 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala @@ -17,10 +17,11 @@ package org.apache.spark.scheduler.cluster +import org.apache.spark.TaskContext import org.apache.spark.scheduler.{TaskLocation, Task} -class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) { - override def run(attemptId: Long): Int = 0 +class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) { + override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs } diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala index af76c843e8..1e676c1719 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.scheduler.local -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import org.apache.spark._ -import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.{ConcurrentMap, HashMap} import java.util.concurrent.Semaphore import java.util.concurrent.CountDownLatch -import java.util.Properties + +import scala.collection.mutable.HashMap + +import org.scalatest.{BeforeAndAfterEach, FunSuite} + +import org.apache.spark._ + class Lock() { var finished = false @@ -63,7 +61,12 @@ object TaskThreadInfo { * 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue, * thus it will be scheduled later when cluster has free cpu cores. 
*/ -class LocalSchedulerSuite extends FunSuite with LocalSparkContext { +class LocalSchedulerSuite extends FunSuite with LocalSparkContext with BeforeAndAfterEach { + + override def afterEach() { + super.afterEach() + System.clearProperty("spark.scheduler.mode") + } def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) { @@ -148,12 +151,13 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { } test("Local fair scheduler end-to-end test") { - sc = new SparkContext("local[8]", "LocalSchedulerSuite") - val sem = new Semaphore(0) System.setProperty("spark.scheduler.mode", "FAIR") val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() System.setProperty("spark.scheduler.allocation.file", xmlPath) + sc = new SparkContext("local[8]", "LocalSchedulerSuite") + val sem = new Semaphore(0) + createThread(10,"1",sc,sem) TaskThreadInfo.threadToStarted(10).await() createThread(20,"2",sc,sem) |
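As a usage illustration of the async actions exercised by the new suites above, the following is a minimal sketch (hypothetical example code, not part of this commit), assuming a local SparkContext and the implicit conversions imported from SparkContext._ exactly as the suites use them:

```scala
import scala.concurrent.future
import scala.concurrent.ExecutionContext.Implicits.global

import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.SparkContext._

object AsyncCancellationExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "async-example")

    // Submit a job without blocking the calling thread; countAsync returns a FutureAction.
    val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()

    // Cancel it from another thread once the result is no longer needed.
    future { f.cancel() }

    // get() blocks for the result and surfaces the cancellation as a SparkException.
    try {
      println("count = " + f.get())
    } catch {
      case e: SparkException => println("job was cancelled: " + e.getMessage)
    }

    sc.stop()
  }
}
```

As in JobCancellationSuite, cancellation surfaces at the blocking get() call as a SparkException whose message mentions the job being cancelled or its tasks being killed.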