Diffstat (limited to 'core/src/main/scala/org')
-rw-r--r--  core/src/main/scala/org/apache/spark/FutureAction.scala            | 164
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala     |  48
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala  |   8
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala     |  48
4 files changed, 112 insertions(+), 156 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 48792a9581..2a8220ff40 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -20,13 +20,15 @@ package org.apache.spark
import java.util.Collections
import java.util.concurrent.TimeUnit
+import scala.concurrent._
+import scala.concurrent.duration.Duration
+import scala.util.Try
+
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaFutureAction
import org.apache.spark.rdd.RDD
-import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter}
+import org.apache.spark.scheduler.JobWaiter
-import scala.concurrent._
-import scala.concurrent.duration.Duration
-import scala.util.{Failure, Try}
/**
* A future for the result of an action to support cancellation. This is an extension of the
@@ -105,6 +107,7 @@ trait FutureAction[T] extends Future[T] {
* A [[FutureAction]] holding the result of an action that triggers a single job. Examples include
* count, collect, reduce.
*/
+@DeveloperApi
class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
extends FutureAction[T] {
@@ -116,142 +119,96 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
}
override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
- if (!atMost.isFinite()) {
- awaitResult()
- } else jobWaiter.synchronized {
- val finishTime = System.currentTimeMillis() + atMost.toMillis
- while (!isCompleted) {
- val time = System.currentTimeMillis()
- if (time >= finishTime) {
- throw new TimeoutException
- } else {
- jobWaiter.wait(finishTime - time)
- }
- }
- }
+ jobWaiter.completionFuture.ready(atMost)
this
}
@throws(classOf[Exception])
override def result(atMost: Duration)(implicit permit: CanAwait): T = {
- ready(atMost)(permit)
- awaitResult() match {
- case scala.util.Success(res) => res
- case scala.util.Failure(e) => throw e
- }
+ jobWaiter.completionFuture.ready(atMost)
+ assert(value.isDefined, "Future has not completed properly")
+ value.get.get
}
override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
- executor.execute(new Runnable {
- override def run() {
- func(awaitResult())
- }
- })
+ jobWaiter.completionFuture onComplete {_ => func(value.get)}
}
override def isCompleted: Boolean = jobWaiter.jobFinished
override def isCancelled: Boolean = _cancelled
- override def value: Option[Try[T]] = {
- if (jobWaiter.jobFinished) {
- Some(awaitResult())
- } else {
- None
- }
- }
-
- private def awaitResult(): Try[T] = {
- jobWaiter.awaitResult() match {
- case JobSucceeded => scala.util.Success(resultFunc)
- case JobFailed(e: Exception) => scala.util.Failure(e)
- }
- }
+ override def value: Option[Try[T]] =
+ jobWaiter.completionFuture.value.map {res => res.map(_ => resultFunc)}
def jobIds: Seq[Int] = Seq(jobWaiter.jobId)
}
/**
+ * Handle via which a "run" function passed to a [[ComplexFutureAction]]
+ * can submit jobs for execution.
+ */
+@DeveloperApi
+trait JobSubmitter {
+ /**
+ * Submit a job for execution and return a FutureAction holding the result.
+ * This is a wrapper around the same functionality provided by SparkContext
+ * to enable cancellation.
+ */
+ def submitJob[T, U, R](
+ rdd: RDD[T],
+ processPartition: Iterator[T] => U,
+ partitions: Seq[Int],
+ resultHandler: (Int, U) => Unit,
+ resultFunc: => R): FutureAction[R]
+}
+
+
+/**
* A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take,
- * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
- * action thread if it is being blocked by a job.
+ * takeSample. Cancellation works by setting the cancelled flag to true and cancelling any pending
+ * jobs.
*/
-class ComplexFutureAction[T] extends FutureAction[T] {
+@DeveloperApi
+class ComplexFutureAction[T](run : JobSubmitter => Future[T])
+ extends FutureAction[T] { self =>
- // Pointer to the thread that is executing the action. It is set when the action is run.
- @volatile private var thread: Thread = _
+ @volatile private var _cancelled = false
- // A flag indicating whether the future has been cancelled. This is used in case the future
- // is cancelled before the action was even run (and thus we have no thread to interrupt).
- @volatile private var _cancelled: Boolean = false
-
- @volatile private var jobs: Seq[Int] = Nil
+ @volatile private var subActions: List[FutureAction[_]] = Nil
// A promise used to signal the future.
- private val p = promise[T]()
+ private val p = Promise[T]().tryCompleteWith(run(jobSubmitter))
- override def cancel(): Unit = this.synchronized {
+ override def cancel(): Unit = synchronized {
_cancelled = true
- if (thread != null) {
- thread.interrupt()
- }
- }
-
- /**
- * Executes some action enclosed in the closure. To properly enable cancellation, the closure
- * should use runJob implementation in this promise. See takeAsync for example.
- */
- def run(func: => T)(implicit executor: ExecutionContext): this.type = {
- scala.concurrent.future {
- thread = Thread.currentThread
- try {
- p.success(func)
- } catch {
- case e: Exception => p.failure(e)
- } finally {
- // This lock guarantees when calling `thread.interrupt()` in `cancel`,
- // thread won't be set to null.
- ComplexFutureAction.this.synchronized {
- thread = null
- }
- }
- }
- this
+ p.tryFailure(new SparkException("Action has been cancelled"))
+ subActions.foreach(_.cancel())
}
- /**
- * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
- * to enable cancellation.
- */
- def runJob[T, U, R](
+ private def jobSubmitter = new JobSubmitter {
+ def submitJob[T, U, R](
rdd: RDD[T],
processPartition: Iterator[T] => U,
partitions: Seq[Int],
resultHandler: (Int, U) => Unit,
- resultFunc: => R) {
- // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
- // command need to be in an atomic block.
- val job = this.synchronized {
+ resultFunc: => R): FutureAction[R] = self.synchronized {
+ // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
+ // command need to be in an atomic block.
if (!isCancelled) {
- rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
+ val job = rdd.context.submitJob(
+ rdd,
+ processPartition,
+ partitions,
+ resultHandler,
+ resultFunc)
+ subActions = job :: subActions
+ job
} else {
throw new SparkException("Action has been cancelled")
}
}
-
- this.jobs = jobs ++ job.jobIds
-
- // Wait for the job to complete. If the action is cancelled (with an interrupt),
- // cancel the job and stop the execution. This is not in a synchronized block because
- // Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
- try {
- Await.ready(job, Duration.Inf)
- } catch {
- case e: InterruptedException =>
- job.cancel()
- throw new SparkException("Action has been cancelled")
- }
}
override def isCancelled: Boolean = _cancelled
@@ -276,10 +233,11 @@ class ComplexFutureAction[T] extends FutureAction[T] {
override def value: Option[Try[T]] = p.future.value
- def jobIds: Seq[Int] = jobs
+ def jobIds: Seq[Int] = subActions.flatMap(_.jobIds)
}
+
private[spark]
class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T)
extends JavaFutureAction[T] {
@@ -303,7 +261,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S
Await.ready(futureAction, timeout)
futureAction.value.get match {
case scala.util.Success(value) => converter(value)
- case Failure(exception) =>
+ case scala.util.Failure(exception) =>
if (isCancelled) {
throw new CancellationException("Job cancelled").initCause(exception)
} else {
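
To make the new contract concrete, here is a minimal sketch (not part of the patch) of how an action could be written against the new JobSubmitter handle and the ComplexFutureAction constructor. The twoPassCountAsync helper and the two-job split are hypothetical; they only illustrate submitting jobs through the handle and chaining their futures instead of blocking on each one.

import scala.concurrent.ExecutionContext
import org.apache.spark.{ComplexFutureAction, FutureAction, JobSubmitter}
import org.apache.spark.rdd.RDD

// Hypothetical action: count the elements of an RDD in two jobs, chained through
// the JobSubmitter handle passed to the ComplexFutureAction's run function.
def twoPassCountAsync[T](rdd: RDD[T])(implicit ec: ExecutionContext): FutureAction[Long] = {
  val counts = new Array[Long](rdd.partitions.length)

  def countPartitions(parts: Seq[Int], submitter: JobSubmitter): FutureAction[Unit] =
    submitter.submitJob(
      rdd,
      (it: Iterator[T]) => it.size.toLong,        // count one partition
      parts,
      (i: Int, c: Long) => counts(parts(i)) = c,  // i indexes into `parts`
      ())                                         // resultFunc: nothing extra to compute

  new ComplexFutureAction[Long]({ submitter =>
    val half = rdd.partitions.length / 2
    countPartitions(0 until half, submitter).flatMap { _ =>
      countPartitions(half until rdd.partitions.length, submitter)
    }.map(_ => counts.sum)
  })
}

Cancellation propagates for free: ComplexFutureAction records every FutureAction returned by submitJob and cancels the pending ones when cancel() is called.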
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index d5e853613b..14f541f937 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -19,13 +19,12 @@ package org.apache.spark.rdd
import java.util.concurrent.atomic.AtomicLong
-import org.apache.spark.util.ThreadUtils
-
import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext
+import scala.concurrent.{Future, ExecutionContext}
import scala.reflect.ClassTag
-import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
+import org.apache.spark.{JobSubmitter, ComplexFutureAction, FutureAction, Logging}
+import org.apache.spark.util.ThreadUtils
/**
* A set of asynchronous RDD actions available through an implicit conversion.
@@ -65,17 +64,23 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
* Returns a future for retrieving the first num elements of the RDD.
*/
def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope {
- val f = new ComplexFutureAction[Seq[T]]
val callSite = self.context.getCallSite
-
- f.run {
- // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which
- // is a cached thread pool.
- val results = new ArrayBuffer[T](num)
- val totalParts = self.partitions.length
- var partsScanned = 0
- self.context.setCallSite(callSite)
- while (results.size < num && partsScanned < totalParts) {
+ val localProperties = self.context.getLocalProperties
+ // Cached thread pool to handle aggregation of subtasks.
+ implicit val executionContext = AsyncRDDActions.futureExecutionContext
+ val results = new ArrayBuffer[T](num)
+ val totalParts = self.partitions.length
+
+ /*
+ Recursively triggers jobs to scan partitions until either the requested
+ number of elements are retrieved, or the partitions to scan are exhausted.
+ This implementation is non-blocking, asynchronously handling the
+ results of each job and triggering the next job using callbacks on futures.
+ */
+ def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
+ if (results.size >= num || partsScanned >= totalParts) {
+ Future.successful(results.toSeq)
+ } else {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
@@ -97,19 +102,20 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
val buf = new Array[Array[T]](p.size)
- f.runJob(self,
+ self.context.setCallSite(callSite)
+ self.context.setLocalProperties(localProperties)
+ val job = jobSubmitter.submitJob(self,
(it: Iterator[T]) => it.take(left).toArray,
p,
(index: Int, data: Array[T]) => buf(index) = data,
Unit)
-
- buf.foreach(results ++= _.take(num - results.size))
- partsScanned += numPartsToTry
+ job.flatMap {_ =>
+ buf.foreach(results ++= _.take(num - results.size))
+ continue(partsScanned + numPartsToTry)
+ }
}
- results.toSeq
- }(AsyncRDDActions.futureExecutionContext)
- f
+ new ComplexFutureAction[Seq[T]](continue(0)(_))
}
/**
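
The recursive continue function above generalizes to any "submit a job, inspect the partial result, maybe submit another" loop. A stripped-down sketch of the same non-blocking control flow outside of Spark, with hypothetical names (fetchBatch, needMore):

import scala.concurrent.{ExecutionContext, Future}

// Minimal sketch of the loop used by takeAsync: each step returns a Future, and the
// next step is scheduled from a callback on that Future instead of blocking a thread.
def collectUntil[A](fetchBatch: Int => Future[Seq[A]],  // hypothetical: fetch one batch
                    needMore: Seq[A] => Boolean,        // hypothetical: stop condition
                    maxBatches: Int)
                   (implicit ec: ExecutionContext): Future[Seq[A]] = {
  def loop(batch: Int, acc: Seq[A]): Future[Seq[A]] =
    if (!needMore(acc) || batch >= maxBatches) {
      Future.successful(acc)
    } else {
      fetchBatch(batch).flatMap(rows => loop(batch + 1, acc ++ rows))
    }
  loop(0, Seq.empty)
}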
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5582720bbc..8d0e0c8624 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.Map
import scala.collection.mutable.{HashMap, HashSet, Stack}
+import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.existentials
import scala.language.postfixOps
@@ -610,11 +611,12 @@ class DAGScheduler(
properties: Properties): Unit = {
val start = System.nanoTime
val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
- waiter.awaitResult() match {
- case JobSucceeded =>
+ Await.ready(waiter.completionFuture, atMost = Duration.Inf)
+ waiter.completionFuture.value.get match {
+ case scala.util.Success(_) =>
logInfo("Job %d finished: %s, took %f s".format
(waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
- case JobFailed(exception: Exception) =>
+ case scala.util.Failure(exception) =>
logInfo("Job %d failed: %s, took %f s".format
(waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
// SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
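
For readers less familiar with the Await.ready plus value idiom used here: it is the Future-based equivalent of the old wait/notify loop on the JobWaiter monitor. Block until the future completes, then read the Try it completed with. A minimal standalone sketch:

import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration.Duration
import scala.util.{Failure, Success}

val promise = Promise[Unit]()
val future: Future[Unit] = promise.future

// Completed elsewhere, e.g. by the last task of the job.
promise.success(())

// Await.ready blocks until completion but never rethrows the job's exception itself;
// the outcome is read from future.value, which is Some(Try[...]) once completed.
Await.ready(future, atMost = Duration.Inf)
future.value.get match {
  case Success(_) => println("job finished")
  case Failure(e) => println(s"job failed: ${e.getMessage}")
}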
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index 382b09422a..4326135186 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -17,6 +17,10 @@
package org.apache.spark.scheduler
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.concurrent.{Future, Promise}
+
/**
* An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
* results to the given handler function.
@@ -28,17 +32,15 @@ private[spark] class JobWaiter[T](
resultHandler: (Int, T) => Unit)
extends JobListener {
- private var finishedTasks = 0
-
- // Is the job as a whole finished (succeeded or failed)?
- @volatile
- private var _jobFinished = totalTasks == 0
-
- def jobFinished: Boolean = _jobFinished
-
+ private val finishedTasks = new AtomicInteger(0)
// If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero
// partition RDDs), we set the jobResult directly to JobSucceeded.
- private var jobResult: JobResult = if (jobFinished) JobSucceeded else null
+ private val jobPromise: Promise[Unit] =
+ if (totalTasks == 0) Promise.successful(()) else Promise()
+
+ def jobFinished: Boolean = jobPromise.isCompleted
+
+ def completionFuture: Future[Unit] = jobPromise.future
/**
* Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled
@@ -49,29 +51,17 @@ private[spark] class JobWaiter[T](
dagScheduler.cancelJob(jobId)
}
- override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
- if (_jobFinished) {
- throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
+ override def taskSucceeded(index: Int, result: Any): Unit = {
+ // resultHandler call must be synchronized in case resultHandler itself is not thread safe.
+ synchronized {
+ resultHandler(index, result.asInstanceOf[T])
}
- resultHandler(index, result.asInstanceOf[T])
- finishedTasks += 1
- if (finishedTasks == totalTasks) {
- _jobFinished = true
- jobResult = JobSucceeded
- this.notifyAll()
+ if (finishedTasks.incrementAndGet() == totalTasks) {
+ jobPromise.success(())
}
}
- override def jobFailed(exception: Exception): Unit = synchronized {
- _jobFinished = true
- jobResult = JobFailed(exception)
- this.notifyAll()
- }
+ override def jobFailed(exception: Exception): Unit =
+ jobPromise.failure(exception)
- def awaitResult(): JobResult = synchronized {
- while (!_jobFinished) {
- this.wait()
- }
- return jobResult
- }
}
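
The new JobWaiter relies on two properties of this Promise-based pattern: the AtomicInteger makes the "last task completes the job" decision race-free without holding a lock, and a Promise can only ever complete its future once. A minimal standalone sketch of the pattern, with hypothetical names (CompletionCounter is not part of the patch, and it uses trySuccess/tryFailure rather than the patch's success/failure):

import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.{Future, Promise}

// Hypothetical waiter: completes `done` once all `totalTasks` results have arrived,
// or fails it if any task reports an error first.
class CompletionCounter(totalTasks: Int) {
  private val finished = new AtomicInteger(0)
  private val done: Promise[Unit] =
    if (totalTasks == 0) Promise.successful(()) else Promise()

  def completionFuture: Future[Unit] = done.future

  def taskSucceeded(): Unit =
    // incrementAndGet is atomic, so exactly one caller observes the final count.
    if (finished.incrementAndGet() == totalTasks) {
      done.trySuccess(())
    }

  def taskFailed(error: Exception): Unit =
    // tryFailure is a no-op if the promise was already completed.
    done.tryFailure(error)
}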