-rw-r--r--   core/src/main/scala/org/apache/spark/FutureAction.scala                         232
-rw-r--r--   core/src/main/scala/org/apache/spark/SparkContext.scala                          10
-rw-r--r--   core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala                   91
-rw-r--r--   core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala    14
-rw-r--r--   core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala             260
5 files changed, 474 insertions, 133 deletions
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
new file mode 100644
index 0000000000..465cc1fa7d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.concurrent._
+import scala.concurrent.duration.Duration
+import scala.util.Try
+
+import org.apache.spark.scheduler.{JobSucceeded, JobWaiter}
+import org.apache.spark.scheduler.JobFailed
+
+
+/**
+ * A future for the result of an action. This is an extension of the Scala Future interface to
+ * support cancellation.
+ */
+trait FutureAction[T] extends Future[T] {
+
+ /**
+ * Cancels the execution of this action.
+ */
+ def cancel()
+
+ /**
+ * Blocks until this action completes.
+ * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
+ * for unbounded waiting, or a finite positive duration
+ * @return this FutureAction
+ */
+ override def ready(atMost: Duration)(implicit permit: CanAwait): FutureAction.this.type
+
+ /**
+ * Await and return the result (of type T) of this action.
+ * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
+ * for unbounded waiting, or a finite positive duration
+ * @throws Exception exception during action execution
+ * @return the result value if the action is completed within the specified maximum wait time
+ */
+ @throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T
+
+ /**
+ * When this action is completed, either through an exception, or a value, apply the provided
+ * function.
+ */
+ def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext)
+
+ /**
+ * Returns whether the action has already been completed with a value or an exception.
+ */
+ override def isCompleted: Boolean
+
+ /**
+ * The value of this Future.
+ *
+ * If the future is not completed the returned value will be None. If the future is completed
+ * the value will be Some(Success(t)) if it contains a valid result, or Some(Failure(error)) if
+ * it contains an exception.
+ */
+ override def value: Option[Try[T]]
+
+ /**
+ * Block and return the result of this job.
+ */
+ @throws(classOf[Exception])
+ def get(): T = Await.result(this, Duration.Inf)
+}
+
+
+/**
+ * The future holding the result of an action that triggers a single job. Examples include
+ * count, collect, reduce.
+ */
+class FutureJob[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
+ extends FutureAction[T] {
+
+ override def cancel() {
+ jobWaiter.kill()
+ }
+
+ override def ready(atMost: Duration)(implicit permit: CanAwait): FutureJob.this.type = {
+ if (!atMost.isFinite()) {
+ awaitResult()
+ } else {
+ val finishTime = System.currentTimeMillis() + atMost.toMillis
+ while (!isCompleted) {
+ val time = System.currentTimeMillis()
+ if (time >= finishTime) {
+ throw new TimeoutException
+ } else {
+ jobWaiter.wait(finishTime - time)
+ }
+ }
+ }
+ this
+ }
+
+ @throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T = {
+ ready(atMost)(permit)
+ awaitResult() match {
+ case scala.util.Success(res) => res
+ case scala.util.Failure(e) => throw e
+ }
+ }
+
+ override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
+ executor.execute(new Runnable {
+ override def run() {
+ func(awaitResult())
+ }
+ })
+ }
+
+ override def isCompleted: Boolean = jobWaiter.jobFinished
+
+ override def value: Option[Try[T]] = {
+ if (jobWaiter.jobFinished) {
+ Some(awaitResult())
+ } else {
+ None
+ }
+ }
+
+ private def awaitResult(): Try[T] = {
+ jobWaiter.awaitResult() match {
+ case JobSucceeded => scala.util.Success(resultFunc)
+ case JobFailed(e: Exception, _) => scala.util.Failure(e)
+ }
+ }
+}
+
+
+/**
+ * A FutureAction for actions that could trigger multiple Spark jobs. Examples include take,
+ * takeSample.
+ *
+ * This is implemented as a Scala Promise that can be cancelled. Note that the promise itself is
+ * also its own Future (i.e. this.future returns this). See the implementation of takeAsync for
+ * usage.
+ */
+class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
+ // Cancellation works by setting the cancelled flag to true and interrupting the action thread
+ // if it is in progress. Before executing the action, the execution thread needs to check the
+ // cancelled flag in case cancel() is called before the thread even starts to execute. Because
+ // cancel() and the execution thread are synchronized on the same promise object (this), the
+ // actual cancellation/interrupt event can only be triggered when the execution thread is
+ // waiting for the result of a job.
+
+ override def cancel(): Unit = this.synchronized {
+ _cancelled = true
+ if (thread != null) {
+ thread.interrupt()
+ }
+ }
+
+ /**
+ * Executes some action enclosed in the closure. The execution of func is wrapped in a
+ * synchronized block to guarantee that this promise can only be cancelled when the task is
+ * waiting for the result of a job.
+ */
+ def run(func: => T)(implicit executor: ExecutionContext): Unit = scala.concurrent.future {
+ thread = Thread.currentThread
+ try {
+ this.success(this.synchronized {
+ if (cancelled) {
+ // This action has been cancelled before this thread even started running.
+ throw new InterruptedException
+ }
+ func
+ })
+ } catch {
+ case e: Exception => this.failure(e)
+ } finally {
+ thread = null
+ }
+ }
+
+ /**
+ * Returns whether the promise has been cancelled.
+ */
+ def cancelled: Boolean = _cancelled
+
+ // Pointer to the thread that is executing the action. It is set when the action is run.
+ @volatile private var thread: Thread = _
+
+ // A flag indicating whether the future has been cancelled. This is used in case the future
+ // is cancelled before the action was even run (and thus we have no thread to interrupt).
+ @volatile private var _cancelled: Boolean = false
+
+ // Internally, we delegate most functionality to this promise.
+ private val p = promise[T]()
+
+ override def future: this.type = this
+
+ override def tryComplete(result: Try[T]): Boolean = p.tryComplete(result)
+
+ @scala.throws(classOf[InterruptedException])
+ @scala.throws(classOf[scala.concurrent.TimeoutException])
+ override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = {
+ p.future.ready(atMost)(permit)
+ this
+ }
+
+ @scala.throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T = {
+ p.future.result(atMost)(permit)
+ }
+
+ override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = {
+ p.future.onComplete(func)(executor)
+ }
+
+ override def isCompleted: Boolean = p.isCompleted
+
+ override def value: Option[Try[T]] = p.future.value
+}
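
For context, here is a minimal driver-side sketch of how the FutureAction API added above is meant to be consumed (the master URL, app name, and object name are illustrative; countAsync is the async action defined in AsyncRDDActions later in this patch):

    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.util.{Failure, Success}

    import org.apache.spark.{FutureAction, SparkContext}
    import org.apache.spark.SparkContext._   // brings in the implicit conversion to AsyncRDDActions

    object FutureActionSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local[2]", "future-action-sketch")
        val rdd = sc.parallelize(1 to 10000, 4)

        // Submit the job without blocking the calling thread.
        val f: FutureAction[Long] = rdd.countAsync()

        // Register a callback invoked once the job completes, with either a value or an exception.
        f.onComplete {
          case Success(count) => println("count = " + count)
          case Failure(e) => println("job failed: " + e)
        }

        // f.cancel() would kill the underlying job instead of waiting for it.
        // Alternatively, block until the result is available.
        println("blocking count = " + f.get())

        sc.stop()
      }
    }
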
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 3012453a45..5c2946db4e 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -817,17 +817,23 @@ class SparkContext(
result
}
+ /**
+ * Submit a job for execution and return a FutureJob holding the result. Note that the
+ * processPartition closure will be "cleaned" so the caller doesn't have to clean the closure
+ * explicitly.
+ */
def submitJob[T, U, R](
rdd: RDD[T],
processPartition: Iterator[T] => U,
partitions: Seq[Int],
partitionResultHandler: (Int, U) => Unit,
- resultFunc: () => R): FutureJob[R] =
+ resultFunc: => R): FutureJob[R] =
{
+ val cleanF = clean(processPartition)
val callSite = Utils.formatSparkCallSite
val waiter = dagScheduler.submitJob(
rdd,
- (context: TaskContext, iter: Iterator[T]) => processPartition(iter),
+ (context: TaskContext, iter: Iterator[T]) => cleanF(iter),
partitions,
callSite,
allowLocal = false,
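
To illustrate the shape of this submitJob signature, the sketch below builds a hypothetical maxAsync helper on top of it (maxAsync and SubmitJobSketch are not part of this patch; the pattern simply mirrors how countAsync in AsyncRDDActions uses submitJob):

    import java.util.concurrent.atomic.AtomicLong

    import org.apache.spark.FutureJob
    import org.apache.spark.rdd.RDD

    object SubmitJobSketch {
      // Hypothetical helper: compute the max of an RDD[Long] asynchronously by submitting
      // one job over all partitions and merging per-partition maxima on the driver.
      def maxAsync(rdd: RDD[Long]): FutureJob[Long] = {
        val max = new AtomicLong(Long.MinValue)
        rdd.context.submitJob(
          rdd,
          // processPartition: runs on the executors, producing one value per partition
          (iter: Iterator[Long]) => if (iter.hasNext) iter.max else Long.MinValue,
          Range(0, rdd.partitions.size),
          // partitionResultHandler: runs on the driver as each partition result arrives
          (index: Int, partMax: Long) => {
            var done = false
            while (!done) {
              val current = max.get()
              done = partMax <= current || max.compareAndSet(current, partMax)
            }
          },
          // resultFunc: by-name expression evaluated once the whole job has succeeded
          max.get())
      }
    }
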
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 6b810f753e..6806b8730b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -17,22 +17,27 @@
package org.apache.spark.rdd
+import java.util.concurrent.atomic.AtomicLong
+
import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+import scala.concurrent.ExecutionContext.Implicits.global
-import org.apache.spark.FutureJob
+import org.apache.spark.{Logging, CancellablePromise, FutureAction}
/**
* A set of asynchronous RDD actions available through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
-class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable {
+class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with Logging {
/**
* Return a future for counting the number of elements in the RDD.
*/
- def countAsync(): FutureJob[Long] = {
- var totalCount: java.lang.Long = 0L
- self.context.submitJob[T, Long, Long](
+ def countAsync(): FutureAction[Long] = {
+ val totalCount = new AtomicLong
+ self.context.submitJob(
self,
(iter: Iterator[T]) => {
var result = 0L
@@ -43,39 +48,85 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable {
result
},
Range(0, self.partitions.size),
- (index, data) => totalCount += data,
- () => totalCount)
+ (index: Int, data: Long) => totalCount.addAndGet(data),
+ totalCount.get())
}
/**
* Return a future for retrieving all elements of this RDD.
*/
- def collectAsync(): FutureJob[Seq[T]] = {
+ def collectAsync(): FutureAction[Seq[T]] = {
val results = new ArrayBuffer[T]
self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size),
- (index, data) => results ++= data, () => results)
+ (index, data) => results ++= data, results)
}
- def takeAsync(num: Int): FutureJob[Seq[T]] = {
- // TODO: Implement this.
- null
+ /**
+ * The async version of take that returns a FutureAction.
+ */
+ def takeAsync(num: Int): FutureAction[Seq[T]] = {
+ val promise = new CancellablePromise[Seq[T]]
+
+ promise.run {
+ val buf = new ArrayBuffer[T](num)
+ val totalParts = self.partitions.length
+ var partsScanned = 0
+ while (buf.size < num && partsScanned < totalParts) {
+ // The number of partitions to try in this iteration. It is ok for this number to be
+ // greater than totalParts because we actually cap it at totalParts in runJob.
+ var numPartsToTry = 1
+ if (partsScanned > 0) {
+ // If we didn't find any rows after the first iteration, just try all partitions next.
+ // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+ // by 50%.
+ if (buf.size == 0) {
+ numPartsToTry = totalParts - 1
+ } else {
+ numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+ }
+ }
+ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
+
+ val left = num - buf.size
+ val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+
+ val job = self.context.submitJob(
+ self,
+ (it: Iterator[T]) => it.take(left).toArray,
+ p,
+ (index: Int, data: Array[T]) => buf ++= data.take(num - buf.size),
+ Unit)
+
+ // Wait for the job to complete. If the action is cancelled (with an interrupt),
+ // cancel the job and stop the execution.
+ try {
+ Await.result(job, Duration.Inf)
+ } catch {
+ case e: InterruptedException =>
+ job.cancel()
+ throw e
+ }
+ partsScanned += numPartsToTry
+ }
+ buf.toSeq
+ }
+
+ promise.future
}
/**
* Applies a function f to all elements of this RDD.
*/
- def foreachAsync(f: T => Unit): FutureJob[Unit] = {
- val cleanF = self.context.clean(f)
- self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size),
- (index, data) => Unit, () => Unit)
+ def foreachAsync(f: T => Unit): FutureAction[Unit] = {
+ self.context.submitJob[T, Unit, Unit](self, _.foreach(f), Range(0, self.partitions.size),
+ (index, data) => Unit, Unit)
}
/**
* Applies a function f to each partition of this RDD.
*/
- def foreachPartitionAsync(f: Iterator[T] => Unit): FutureJob[Unit] = {
- val cleanF = self.context.clean(f)
- self.context.submitJob[T, Unit, Unit](self, cleanF, Range(0, self.partitions.size),
- (index, data) => Unit, () => Unit)
+ def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = {
+ self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size),
+ (index, data) => Unit, Unit)
}
}
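
To make the partition-scan heuristic in takeAsync concrete, here is the arithmetic on illustrative numbers (the values are chosen only for this example):

    // Suppose takeAsync(100) has scanned partsScanned = 4 partitions so far and collected
    // buf.size = 20 rows. At the observed density, 100 rows need roughly 20 partitions,
    // and the heuristic overestimates that by 50%:
    val num = 100
    val partsScanned = 4
    val bufSize = 20
    val numPartsToTry = (1.5 * num * partsScanned / bufSize).toInt   // = 30
    // The next round submits a job over partitions 4 until min(4 + 30, totalParts), asking
    // each partition for at most num - bufSize = 80 remaining elements.
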
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 3961466fdf..be0dabf4b9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -185,12 +185,14 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
def taskSetFinished(manager: TaskSetManager) {
this.synchronized {
- activeTaskSets -= manager.taskSet.id
- manager.parent.removeSchedulable(manager)
- logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
- taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
- taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
- taskSetTaskIds.remove(manager.taskSet.id)
+ if (activeTaskSets.contains(manager.taskSet.id)) {
+ activeTaskSets -= manager.taskSet.id
+ manager.parent.removeSchedulable(manager)
+ logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
+ taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+ taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
+ taskSetTaskIds.remove(manager.taskSet.id)
+ }
}
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 3a65b7dabc..0fd96ed3b1 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -20,135 +20,185 @@ package org.apache.spark.rdd
import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.AtomicInteger
+import scala.concurrent.Await
+import scala.concurrent.future
+import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global
-import org.scalatest.FunSuite
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.spark.SparkContext._
-import org.apache.spark.{SparkException, SharedSparkContext}
+import org.apache.spark.{SparkContext, SparkException, LocalSparkContext}
-class AsyncRDDActionsSuite extends FunSuite with SharedSparkContext {
+class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
- lazy val zeroPartRdd = new EmptyRDD[Int](sc)
-
- test("countAsync") {
- assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000)
- }
-
- test("countAsync zero partition") {
- assert(zeroPartRdd.countAsync().get() === 0)
- }
-
- test("collectAsync") {
- assert(sc.parallelize(1 to 1000, 3).collectAsync().get() === (1 to 1000))
- }
+ @transient private var sc: SparkContext = _
- test("collectAsync zero partition") {
- assert(zeroPartRdd.collectAsync().get() === Seq.empty)
+ override def beforeAll() {
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
}
- test("foreachAsync") {
- AsyncRDDActionsSuite.foreachCounter = 0
- sc.parallelize(1 to 1000, 3).foreachAsync { i =>
- AsyncRDDActionsSuite.foreachCounter += 1
- }.get()
- assert(AsyncRDDActionsSuite.foreachCounter === 1000)
+ override def afterAll() {
+ LocalSparkContext.stop(sc)
+ sc = null
}
- test("foreachAsync zero partition") {
- zeroPartRdd.foreachAsync(i => Unit).get()
- }
-
- test("foreachPartitionAsync") {
- AsyncRDDActionsSuite.foreachPartitionCounter = 0
- sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
- AsyncRDDActionsSuite.foreachPartitionCounter += 1
- }.get()
- assert(AsyncRDDActionsSuite.foreachPartitionCounter === 9)
- }
-
- test("foreachPartitionAsync zero partition") {
- zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
- }
+ lazy val zeroPartRdd = new EmptyRDD[Int](sc)
- /**
- * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
- * of a successful job execution.
- */
- test("async success handling") {
- val f = sc.parallelize(1 to 10, 2).countAsync()
+ test("job cancellation") {
+ val f = sc.parallelize(1 to 1000, 2).map { i => Thread.sleep(1000); i }.countAsync()
- // This semaphore is used to make sure our final assert waits until onComplete / onSuccess
- // finishes execution.
val sem = new Semaphore(0)
-
- AsyncRDDActionsSuite.asyncSuccessHappened = new AtomicInteger
- f.onComplete {
- case scala.util.Success(res) =>
- AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
- sem.release()
- case scala.util.Failure(e) =>
- throw new Exception("Task should succeed")
- sem.release()
- }
- f.onSuccess { case a: Any =>
- AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
- sem.release()
- }
- f.onFailure { case t =>
- throw new Exception("Task should succeed")
+ future {
+ //sem.acquire()
+ Thread.sleep(1000)
+ f.cancel()
+ println("killing previous job")
}
- assert(f.get() === 10)
- sem.acquire(2)
- assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2)
- }
-
- /**
- * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
- * of a failed job execution.
- */
- test("async failure handling") {
- val f = sc.parallelize(1 to 10, 2).map { i =>
- throw new Exception("intentional"); i
- }.countAsync()
-
- // This semaphore is used to make sure our final assert waits until onComplete / onFailure
- // finishes execution.
- val sem = new Semaphore(0)
- AsyncRDDActionsSuite.asyncFailureHappend = new AtomicInteger
- f.onComplete {
- case scala.util.Success(res) =>
- throw new Exception("Task should fail")
- sem.release()
- case scala.util.Failure(e) =>
- AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
- sem.release()
- }
- f.onSuccess { case a: Any =>
- throw new Exception("Task should fail")
- }
- f.onFailure { case t =>
- AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
- sem.release()
- }
intercept[SparkException] {
- f.get()
+ println("lalalalalala")
+ println(f.get())
+ println("hahahahah")
}
- sem.acquire(2)
- assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2)
+
}
+//
+// test("countAsync") {
+// assert(zeroPartRdd.countAsync().get() === 0)
+// assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000)
+// }
+//
+// test("collectAsync") {
+// assert(zeroPartRdd.collectAsync().get() === Seq.empty)
+//
+// // Note that we sort the collected output because the order is indeterministic.
+// val collected = sc.parallelize(1 to 1000, 3).collectAsync().get().sorted
+// assert(collected === (1 to 1000))
+// }
+//
+// test("foreachAsync") {
+// zeroPartRdd.foreachAsync(i => Unit).get()
+//
+// val accum = sc.accumulator(0)
+// sc.parallelize(1 to 1000, 3).foreachAsync { i =>
+// accum += 1
+// }.get()
+// assert(accum.value === 1000)
+// }
+//
+// test("foreachPartitionAsync") {
+// zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
+//
+// val accum = sc.accumulator(0)
+// sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
+// accum += 1
+// }.get()
+// assert(accum.value === 9)
+// }
+//
+// test("takeAsync") {
+// def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
+// // Note that we sort the collected output because the order is indeterministic.
+// assert(rdd.takeAsync(num).get().size === input.take(num).size)
+// }
+// val input = Range(1, 1000)
+//
+// var nums = sc.parallelize(input, 1)
+// for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+// testTake(nums, input, num)
+// }
+//
+// nums = sc.parallelize(input, 2)
+// for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+// testTake(nums, input, num)
+// }
+//
+// nums = sc.parallelize(input, 100)
+// for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+// testTake(nums, input, num)
+// }
+//
+// nums = sc.parallelize(input, 1000)
+// for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+// testTake(nums, input, num)
+// }
+// }
+//
+// /**
+// * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
+// * of a successful job execution.
+// */
+// test("async success handling") {
+// val f = sc.parallelize(1 to 10, 2).countAsync()
+//
+// // This semaphore is used to make sure our final assert waits until onComplete / onSuccess
+// // finishes execution.
+// val sem = new Semaphore(0)
+//
+// AsyncRDDActionsSuite.asyncSuccessHappened.set(0)
+// f.onComplete {
+// case scala.util.Success(res) =>
+// AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
+// sem.release()
+// case scala.util.Failure(e) =>
+// throw new Exception("Task should succeed")
+// sem.release()
+// }
+// f.onSuccess { case a: Any =>
+// AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
+// sem.release()
+// }
+// f.onFailure { case t =>
+// throw new Exception("Task should succeed")
+// }
+// assert(f.get() === 10)
+// sem.acquire(2)
+// assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2)
+// }
+//
+// /**
+// * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
+// * of a failed job execution.
+// */
+// test("async failure handling") {
+// val f = sc.parallelize(1 to 10, 2).map { i =>
+// throw new Exception("intentional"); i
+// }.countAsync()
+//
+// // This semaphore is used to make sure our final assert waits until onComplete / onFailure
+// // finishes execution.
+// val sem = new Semaphore(0)
+//
+// AsyncRDDActionsSuite.asyncFailureHappend.set(0)
+// f.onComplete {
+// case scala.util.Success(res) =>
+// throw new Exception("Task should fail")
+// sem.release()
+// case scala.util.Failure(e) =>
+// AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
+// sem.release()
+// }
+// f.onSuccess { case a: Any =>
+// throw new Exception("Task should fail")
+// }
+// f.onFailure { case t =>
+// AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
+// sem.release()
+// }
+// intercept[SparkException] {
+// f.get()
+// }
+// sem.acquire(2)
+// assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2)
+// }
}
object AsyncRDDActionsSuite {
// Some counters used in the test cases above.
- var foreachCounter = 0
-
- var foreachPartitionCounter = 0
-
- var asyncSuccessHappened: AtomicInteger = _
+ var asyncSuccessHappened = new AtomicInteger
- var asyncFailureHappend: AtomicInteger = _
+ var asyncFailureHappend = new AtomicInteger
}