5 files changed, 474 insertions, 133 deletions
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
new file mode 100644
index 0000000000..465cc1fa7d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.concurrent._
+import scala.concurrent.duration.Duration
+import scala.util.Try
+
+import org.apache.spark.scheduler.{JobSucceeded, JobWaiter}
+import org.apache.spark.scheduler.JobFailed
+
+
+/**
+ * A future for the result of an action. This is an extension of the Scala Future interface to
+ * support cancellation.
+ */
+trait FutureAction[T] extends Future[T] {
+
+  /**
+   * Cancels the execution of this action.
+   */
+  def cancel()
+
+  /**
+   * Blocks until this action completes.
+   * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
+   *               for unbounded waiting, or a finite positive duration
+   * @return this FutureAction
+   */
+  override def ready(atMost: Duration)(implicit permit: CanAwait): FutureAction.this.type
+
+  /**
+   * Awaits and returns the result (of type T) of this action.
+   * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
+   *               for unbounded waiting, or a finite positive duration
+   * @throws Exception exception during action execution
+   * @return the result value if the action is completed within the specified maximum wait time
+   */
+  @throws(classOf[Exception])
+  override def result(atMost: Duration)(implicit permit: CanAwait): T
+
+  /**
+   * When this action is completed, either through an exception or a value, applies the provided
+   * function.
+   */
+  def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext)
+
+  /**
+   * Returns whether the action has already been completed with a value or an exception.
+   */
+  override def isCompleted: Boolean
+
+  /**
+   * The value of this Future.
+   *
+   * If the future is not completed the returned value will be None. If the future is completed
+   * the value will be Some(Success(t)) if it contains a valid result, or Some(Failure(error)) if
+   * it contains an exception.
+   */
+  override def value: Option[Try[T]]
+
+  /**
+   * Blocks and returns the result of this job.
+   */
+  @throws(classOf[Exception])
+  def get(): T = Await.result(this, Duration.Inf)
+}
+
+
+/**
+ * The future holding the result of an action that triggers a single job. Examples include
+ * count, collect, reduce.
+ */
+class FutureJob[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
+  extends FutureAction[T] {
+
+  override def cancel() {
+    jobWaiter.kill()
+  }
+
+  override def ready(atMost: Duration)(implicit permit: CanAwait): FutureJob.this.type = {
+    if (!atMost.isFinite()) {
+      awaitResult()
+    } else {
+      val finishTime = System.currentTimeMillis() + atMost.toMillis
+      while (!isCompleted) {
+        val time = System.currentTimeMillis()
+        if (time >= finishTime) {
+          throw new TimeoutException
+        } else {
+          jobWaiter.wait(finishTime - time)
+        }
+      }
+    }
+    this
+  }
+
+  @throws(classOf[Exception])
+  override def result(atMost: Duration)(implicit permit: CanAwait): T = {
+    ready(atMost)(permit)
+    awaitResult() match {
+      case scala.util.Success(res) => res
+      case scala.util.Failure(e) => throw e
+    }
+  }
+
+  override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
+    executor.execute(new Runnable {
+      override def run() {
+        func(awaitResult())
+      }
+    })
+  }
+
+  override def isCompleted: Boolean = jobWaiter.jobFinished
+
+  override def value: Option[Try[T]] = {
+    if (jobWaiter.jobFinished) {
+      Some(awaitResult())
+    } else {
+      None
+    }
+  }
+
+  private def awaitResult(): Try[T] = {
+    jobWaiter.awaitResult() match {
+      case JobSucceeded => scala.util.Success(resultFunc)
+      case JobFailed(e: Exception, _) => scala.util.Failure(e)
+    }
+  }
+}
+
+
+/**
+ * A FutureAction for actions that could trigger multiple Spark jobs. Examples include take,
+ * takeSample.
+ *
+ * This is implemented as a Scala Promise that can be cancelled. Note that the promise itself is
+ * also its own Future (i.e. this.future returns this). See the implementation of takeAsync for
+ * usage.
+ */
+class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
+  // Cancellation works by setting the cancelled flag to true and interrupting the action thread
+  // if it is in progress. Before executing the action, the execution thread needs to check the
+  // cancelled flag in case cancel() is called before the thread even starts to execute. Because
+  // this and the execution thread are synchronized on the same promise object (this), the actual
+  // cancellation/interrupt event can only be triggered when the execution thread is waiting for
+  // the result of a job.
+
+  override def cancel(): Unit = this.synchronized {
+    _cancelled = true
+    if (thread != null) {
+      thread.interrupt()
+    }
+  }
+
+  /**
+   * Executes some action enclosed in the closure. The execution of func is wrapped in a
+   * synchronized block to guarantee that this promise can only be cancelled when the task is
+   * waiting for the result of a job.
+   */
+  def run(func: => T)(implicit executor: ExecutionContext): Unit = scala.concurrent.future {
+    thread = Thread.currentThread
+    try {
+      this.success(this.synchronized {
+        if (cancelled) {
+          // This action has been cancelled before this thread even started running.
+          throw new InterruptedException
+        }
+        func
+      })
+    } catch {
+      case e: Exception => this.failure(e)
+    } finally {
+      thread = null
+    }
+  }
+
+  /**
+   * Returns whether the promise has been cancelled.
+   */
+  def cancelled: Boolean = _cancelled
+
+  // Pointer to the thread that is executing the action. It is set when the action is run.
+  @volatile private var thread: Thread = _
+
+  // A flag indicating whether the future has been cancelled. This is used in case the future
+  // is cancelled before the action was even run (and thus we have no thread to interrupt).
+  @volatile private var _cancelled: Boolean = false
+
+  // Internally, we delegate most functionality to this promise.
+  private val p = promise[T]()
+
+  override def future: this.type = this
+
+  override def tryComplete(result: Try[T]): Boolean = p.tryComplete(result)
+
+  @scala.throws(classOf[InterruptedException])
+  @scala.throws(classOf[scala.concurrent.TimeoutException])
+  override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = {
+    p.future.ready(atMost)(permit)
+    this
+  }
+
+  @scala.throws(classOf[Exception])
+  override def result(atMost: Duration)(implicit permit: CanAwait): T = {
+    p.future.result(atMost)(permit)
+  }
+
+  override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = {
+    p.future.onComplete(func)(executor)
+  }
+
+  override def isCompleted: Boolean = p.isCompleted
+
+  override def value: Option[Try[T]] = p.future.value
+}
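A minimal usage sketch of the two concrete FutureAction implementations above (not part of the patch): it assumes an active SparkContext named sc and uses the countAsync/takeAsync methods this patch adds to AsyncRDDActions further down; the callback bodies are purely illustrative.

    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.util.{Failure, Success}

    // FutureJob backs single-job actions. get() blocks (Await.result with
    // Duration.Inf); onComplete registers a non-blocking callback; cancel()
    // kills the underlying job via the JobWaiter.
    val count = sc.parallelize(1 to 1000, 4).countAsync()
    count.onComplete {
      case Success(n) => println("counted " + n + " elements")
      case Failure(e) => println("job failed: " + e)
    }

    // CancellablePromise backs multi-job actions such as takeAsync. Cancelling
    // interrupts the thread running the action, which cancels the in-flight job.
    val firstTen = sc.parallelize(1 to 1000, 4).takeAsync(10)
    firstTen.cancel()
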
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 3012453a45..5c2946db4e 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -817,17 +817,23 @@ class SparkContext(
     result
   }

+  /**
+   * Submit a job for execution and return a FutureJob holding the result. Note that the
+   * processPartition closure will be "cleaned" so the caller doesn't have to clean the closure
+   * explicitly.
+   */
   def submitJob[T, U, R](
       rdd: RDD[T],
       processPartition: Iterator[T] => U,
       partitions: Seq[Int],
       partitionResultHandler: (Int, U) => Unit,
-      resultFunc: () => R): FutureJob[R] =
+      resultFunc: => R): FutureJob[R] =
   {
+    val cleanF = clean(processPartition)
     val callSite = Utils.formatSparkCallSite
     val waiter = dagScheduler.submitJob(
       rdd,
-      (context: TaskContext, iter: Iterator[T]) => processPartition(iter),
+      (context: TaskContext, iter: Iterator[T]) => cleanF(iter),
       partitions,
       callSite,
       allowLocal = false,
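For orientation, a hypothetical direct call to the new submitJob signature, sketched under the assumption of an RDD[String] named lines (countAsync below is the canonical in-tree use of the same pattern):

    import java.util.concurrent.atomic.AtomicLong

    // Count characters across all partitions: each task returns a per-partition
    // total, partitionResultHandler folds it into an accumulator, and the by-name
    // resultFunc is evaluated once to produce the future's final value.
    val total = new AtomicLong
    val chars = sc.submitJob(
      lines,
      (iter: Iterator[String]) => iter.map(_.length.toLong).sum,  // processPartition
      Range(0, lines.partitions.size),                            // all partitions
      (index: Int, n: Long) => { total.addAndGet(n); () },        // partitionResultHandler
      total.get())                                                // resultFunc (by-name)
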
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 6b810f753e..6806b8730b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -17,22 +17,27 @@

 package org.apache.spark.rdd

+import java.util.concurrent.atomic.AtomicLong
+
 import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+import scala.concurrent.ExecutionContext.Implicits.global

-import org.apache.spark.FutureJob
+import org.apache.spark.{Logging, CancellablePromise, FutureAction}

 /**
  * A set of asynchronous RDD actions available through an implicit conversion.
  * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
  */
-class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable {
+class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with Logging {

   /**
    * Return a future for counting the number of elements in the RDD.
    */
-  def countAsync(): FutureJob[Long] = {
-    var totalCount: java.lang.Long = 0L
-    self.context.submitJob[T, Long, Long](
+  def countAsync(): FutureAction[Long] = {
+    val totalCount = new AtomicLong
+    self.context.submitJob(
       self,
       (iter: Iterator[T]) => {
         var result = 0L
@@ -43,39 +48,85 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable {
         result
       },
       Range(0, self.partitions.size),
-      (index, data) => totalCount += data,
-      () => totalCount)
+      (index: Int, data: Long) => totalCount.addAndGet(data),
+      totalCount.get())
   }

   /**
    * Return a future for retrieving all elements of this RDD.
    */
-  def collectAsync(): FutureJob[Seq[T]] = {
+  def collectAsync(): FutureAction[Seq[T]] = {
     val results = new ArrayBuffer[T]
     self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size),
-      (index, data) => results ++= data, () => results)
+      (index, data) => results ++= data, results)
   }

-  def takeAsync(num: Int): FutureJob[Seq[T]] = {
-    // TODO: Implement this.
-    null
+  /**
+   * The async version of take that returns a FutureAction.
+   */
+  def takeAsync(num: Int): FutureAction[Seq[T]] = {
+    val promise = new CancellablePromise[Seq[T]]
+
+    promise.run {
+      val buf = new ArrayBuffer[T](num)
+      val totalParts = self.partitions.length
+      var partsScanned = 0
+      while (buf.size < num && partsScanned < totalParts) {
+        // The number of partitions to try in this iteration. It is ok for this number to be
+        // greater than totalParts because we actually cap it at totalParts in runJob.
+        var numPartsToTry = 1
+        if (partsScanned > 0) {
+          // If we didn't find any rows after the first iteration, just try all partitions next.
+          // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+          // by 50%.
+          if (buf.size == 0) {
+            numPartsToTry = totalParts - 1
+          } else {
+            numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+          }
+        }
+        numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions
+
+        val left = num - buf.size
+        val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+
+        val job = self.context.submitJob(
+          self,
+          (it: Iterator[T]) => it.take(left).toArray,
+          p,
+          (index: Int, data: Array[T]) => buf ++= data.take(num - buf.size),
+          Unit)
+
+        // Wait for the job to complete. If the action is cancelled (with an interrupt),
+        // cancel the job and stop the execution.
+        try {
+          Await.result(job, Duration.Inf)
+        } catch {
+          case e: InterruptedException =>
+            job.cancel()
+            throw e
+        }
+        partsScanned += numPartsToTry
+      }
+      buf.toSeq
+    }
+
+    promise.future
   }

   /**
    * Applies a function f to all elements of this RDD.
    */
-  def foreachAsync(f: T => Unit): FutureJob[Unit] = {
-    val cleanF = self.context.clean(f)
-    self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size),
-      (index, data) => Unit, () => Unit)
+  def foreachAsync(f: T => Unit): FutureAction[Unit] = {
+    self.context.submitJob[T, Unit, Unit](self, _.foreach(f), Range(0, self.partitions.size),
+      (index, data) => Unit, Unit)
   }

   /**
    * Applies a function f to each partition of this RDD.
    */
-  def foreachPartitionAsync(f: Iterator[T] => Unit): FutureJob[Unit] = {
-    val cleanF = self.context.clean(f)
-    self.context.submitJob[T, Unit, Unit](self, cleanF, Range(0, self.partitions.size),
-      (index, data) => Unit, () => Unit)
+  def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = {
+    self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size),
+      (index, data) => Unit, Unit)
   }
 }
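The escalation heuristic in takeAsync is worth a worked example: suppose takeAsync(100) finds only 10 rows in the first partition. With num = 100, partsScanned = 1 and buf.size = 10, the next round tries (1.5 * 100 * 1 / 10).toInt = 15 partitions; had the first round found nothing, it would jump straight to all remaining partitions. A usage sketch, assuming an active SparkContext named sc:

    import scala.concurrent.Await
    import scala.concurrent.duration._

    // Kick off the scan without blocking; bound the wait and cancel on timeout.
    val take = sc.parallelize(1 to 100000, 100).takeAsync(10)
    try {
      val rows = Await.result(take, 10.seconds)
      println(rows.mkString(", "))
    } catch {
      case _: java.util.concurrent.TimeoutException => take.cancel()
    }
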
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 3961466fdf..be0dabf4b9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -185,12 +185,14 @@ private[spark] class ClusterScheduler(val sc: SparkContext)

   def taskSetFinished(manager: TaskSetManager) {
     this.synchronized {
-      activeTaskSets -= manager.taskSet.id
-      manager.parent.removeSchedulable(manager)
-      logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
-      taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
-      taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
-      taskSetTaskIds.remove(manager.taskSet.id)
+      if (activeTaskSets.contains(manager.taskSet.id)) {
+        activeTaskSets -= manager.taskSet.id
+        manager.parent.removeSchedulable(manager)
+        logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
+        taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+        taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
+        taskSetTaskIds.remove(manager.taskSet.id)
+      }
     }
   }
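The contains check makes taskSetFinished a no-op when the task set has already been cleaned up, which plausibly matters once jobs can be killed: without it, a repeated invocation would hit taskSetTaskIds(manager.taskSet.id) on a missing key and throw. A simplified, hypothetical model of the two shapes (real types elided):

    import scala.collection.mutable

    val taskSetTaskIds = mutable.HashMap("ts1" -> mutable.HashSet(1L, 2L))

    // Old shape: looks up the id unconditionally, so a second call throws
    // NoSuchElementException once the entry is gone.
    def finishUnguarded(id: String): Unit = {
      val taskIds = taskSetTaskIds(id)
      println("removing " + taskIds.size + " tasks for " + id)
      taskSetTaskIds.remove(id)
    }

    // New shape: idempotent, safe to call after the task set is already removed.
    def finishGuarded(id: String): Unit = {
      if (taskSetTaskIds.contains(id)) {
        taskSetTaskIds.remove(id)
      }
    }

    finishGuarded("ts1")
    finishGuarded("ts1")  // no-op
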
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 3a65b7dabc..0fd96ed3b1 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -20,135 +20,185 @@ package org.apache.spark.rdd

 import java.util.concurrent.Semaphore
 import java.util.concurrent.atomic.AtomicInteger

+import scala.concurrent.Await
+import scala.concurrent.future
+import scala.concurrent.duration._
 import scala.concurrent.ExecutionContext.Implicits.global

-import org.scalatest.FunSuite
+import org.scalatest.{BeforeAndAfterAll, FunSuite}

 import org.apache.spark.SparkContext._
-import org.apache.spark.{SparkException, SharedSparkContext}
+import org.apache.spark.{SparkContext, SparkException, LocalSparkContext}

-class AsyncRDDActionsSuite extends FunSuite with SharedSparkContext {
+class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {

-  lazy val zeroPartRdd = new EmptyRDD[Int](sc)
-
-  test("countAsync") {
-    assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000)
-  }
-
-  test("countAsync zero partition") {
-    assert(zeroPartRdd.countAsync().get() === 0)
-  }
-
-  test("collectAsync") {
-    assert(sc.parallelize(1 to 1000, 3).collectAsync().get() === (1 to 1000))
-  }
+  @transient private var sc: SparkContext = _

-  test("collectAsync zero partition") {
-    assert(zeroPartRdd.collectAsync().get() === Seq.empty)
+  override def beforeAll() {
+    sc = new SparkContext("local-cluster[2,1,512]", "test")
   }

-  test("foreachAsync") {
-    AsyncRDDActionsSuite.foreachCounter = 0
-    sc.parallelize(1 to 1000, 3).foreachAsync { i =>
-      AsyncRDDActionsSuite.foreachCounter += 1
-    }.get()
-    assert(AsyncRDDActionsSuite.foreachCounter === 1000)
+  override def afterAll() {
+    LocalSparkContext.stop(sc)
+    sc = null
   }

-  test("foreachAsync zero partition") {
-    zeroPartRdd.foreachAsync(i => Unit).get()
-  }
-
-  test("foreachPartitionAsync") {
-    AsyncRDDActionsSuite.foreachPartitionCounter = 0
-    sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
-      AsyncRDDActionsSuite.foreachPartitionCounter += 1
-    }.get()
-    assert(AsyncRDDActionsSuite.foreachPartitionCounter === 9)
-  }
-
-  test("foreachPartitionAsync zero partition") {
-    zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
-  }
+  lazy val zeroPartRdd = new EmptyRDD[Int](sc)

-  /**
-   * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
-   * of a successful job execution.
-   */
-  test("async success handling") {
-    val f = sc.parallelize(1 to 10, 2).countAsync()
+  test("job cancellation") {
+    val f = sc.parallelize(1 to 1000, 2).map { i => Thread.sleep(1000); i }.countAsync()

-    // This semaphore is used to make sure our final assert waits until onComplete / onSuccess
-    // finishes execution.
     val sem = new Semaphore(0)
-
-    AsyncRDDActionsSuite.asyncSuccessHappened = new AtomicInteger
-    f.onComplete {
-      case scala.util.Success(res) =>
-        AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
-        sem.release()
-      case scala.util.Failure(e) =>
-        throw new Exception("Task should succeed")
-        sem.release()
-    }
-    f.onSuccess { case a: Any =>
-      AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
-      sem.release()
-    }
-    f.onFailure { case t =>
-      throw new Exception("Task should succeed")
+    future {
+      //sem.acquire()
+      Thread.sleep(1000)
+      f.cancel()
+      println("killing previous job")
     }
-    assert(f.get() === 10)
-    sem.acquire(2)
-    assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2)
-  }
-
-  /**
-   * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
-   * of a failed job execution.
-   */
-  test("async failure handling") {
-    val f = sc.parallelize(1 to 10, 2).map { i =>
-      throw new Exception("intentional"); i
-    }.countAsync()
-
-    // This semaphore is used to make sure our final assert waits until onComplete / onFailure
-    // finishes execution.
-    val sem = new Semaphore(0)
-    AsyncRDDActionsSuite.asyncFailureHappend = new AtomicInteger
-    f.onComplete {
-      case scala.util.Success(res) =>
-        throw new Exception("Task should fail")
-        sem.release()
-      case scala.util.Failure(e) =>
-        AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
-        sem.release()
-    }
-    f.onSuccess { case a: Any =>
-      throw new Exception("Task should fail")
-    }
-    f.onFailure { case t =>
-      AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
-      sem.release()
-    }
     intercept[SparkException] {
-      f.get()
+      println("lalalalalala")
+      println(f.get())
+      println("hahahahah")
     }
-    sem.acquire(2)
-    assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2)
+  }
+//
+//  test("countAsync") {
+//    assert(zeroPartRdd.countAsync().get() === 0)
+//    assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000)
+//  }
+//
+//  test("collectAsync") {
+//    assert(zeroPartRdd.collectAsync().get() === Seq.empty)
+//
+//    // Note that we sort the collected output because the order is nondeterministic.
+//    val collected = sc.parallelize(1 to 1000, 3).collectAsync().get().sorted
+//    assert(collected === (1 to 1000))
+//  }
+//
+//  test("foreachAsync") {
+//    zeroPartRdd.foreachAsync(i => Unit).get()
+//
+//    val accum = sc.accumulator(0)
+//    sc.parallelize(1 to 1000, 3).foreachAsync { i =>
+//      accum += 1
+//    }.get()
+//    assert(accum.value === 1000)
+//  }
+//
+//  test("foreachPartitionAsync") {
+//    zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
+//
+//    val accum = sc.accumulator(0)
+//    sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
+//      accum += 1
+//    }.get()
+//    assert(accum.value === 9)
+//  }
+//
+//  test("takeAsync") {
+//    def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
+//      // Note that we sort the collected output because the order is nondeterministic.
+//      assert(rdd.takeAsync(num).get().size === input.take(num).size)
+//    }
+//    val input = Range(1, 1000)
+//
+//    var nums = sc.parallelize(input, 1)
+//    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+//      testTake(nums, input, num)
+//    }
+//
+//    nums = sc.parallelize(input, 2)
+//    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+//      testTake(nums, input, num)
+//    }
+//
+//    nums = sc.parallelize(input, 100)
+//    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+//      testTake(nums, input, num)
+//    }
+//
+//    nums = sc.parallelize(input, 1000)
+//    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+//      testTake(nums, input, num)
+//    }
+//  }
+//
+//  /**
+//   * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
+//   * of a successful job execution.
+//   */
+//  test("async success handling") {
+//    val f = sc.parallelize(1 to 10, 2).countAsync()
+//
+//    // This semaphore is used to make sure our final assert waits until onComplete / onSuccess
+//    // finishes execution.
+//    val sem = new Semaphore(0)
+//
+//    AsyncRDDActionsSuite.asyncSuccessHappened.set(0)
+//    f.onComplete {
+//      case scala.util.Success(res) =>
+//        AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
+//        sem.release()
+//      case scala.util.Failure(e) =>
+//        throw new Exception("Task should succeed")
+//        sem.release()
+//    }
+//    f.onSuccess { case a: Any =>
+//      AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
+//      sem.release()
+//    }
+//    f.onFailure { case t =>
+//      throw new Exception("Task should succeed")
+//    }
+//    assert(f.get() === 10)
+//    sem.acquire(2)
+//    assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2)
+//  }
+//
+//  /**
+//   * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
+//   * of a failed job execution.
+//   */
+//  test("async failure handling") {
+//    val f = sc.parallelize(1 to 10, 2).map { i =>
+//      throw new Exception("intentional"); i
+//    }.countAsync()
+//
+//    // This semaphore is used to make sure our final assert waits until onComplete / onFailure
+//    // finishes execution.
+//    val sem = new Semaphore(0)
+//
+//    AsyncRDDActionsSuite.asyncFailureHappend.set(0)
+//    f.onComplete {
+//      case scala.util.Success(res) =>
+//        throw new Exception("Task should fail")
+//        sem.release()
+//      case scala.util.Failure(e) =>
+//        AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
+//        sem.release()
+//    }
+//    f.onSuccess { case a: Any =>
+//      throw new Exception("Task should fail")
+//    }
+//    f.onFailure { case t =>
+//      AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
+//      sem.release()
+//    }
+//    intercept[SparkException] {
+//      f.get()
+//    }
+//    sem.acquire(2)
+//    assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2)
+//  }
 }

 object AsyncRDDActionsSuite {
   // Some counters used in the test cases above.
-  var foreachCounter = 0
-
-  var foreachPartitionCounter = 0
-
-  var asyncSuccessHappened: AtomicInteger = _
+  var asyncSuccessHappened = new AtomicInteger

-  var asyncFailureHappend: AtomicInteger = _
+  var asyncFailureHappend = new AtomicInteger
 }