author    Richard W. Eggert II <richard.eggert@gmail.com>  2015-12-15 18:22:58 -0800
committer Andrew Or <andrew@databricks.com>                2015-12-15 18:22:58 -0800
commit    765a488494dac0ed38d2b81742c06467b79d96b2 (patch)
tree      afccd68b4da804f2a84a3995669370e083b772c2 /core/src/test/scala/org/apache
parent    a63d9edcfb8a714a17492517927aa114dea8fea0 (diff)
[SPARK-9026][SPARK-4514] Modifications to JobWaiter, FutureAction, and AsyncRDDActions to support non-blocking operation
These changes rework the implementations of `SimpleFutureAction`, `ComplexFutureAction`, `JobWaiter`, and `AsyncRDDActions` so that asynchronous callbacks on the generated `Future`s never block waiting for a job to complete. A small amount of mutex synchronization is necessary to protect the internal fields that manage cancellation, but these locks are held only briefly and in practice should almost never cause blocking. The existing blocking APIs of these classes are retained, but they simply delegate to the underlying non-blocking API and `Await` the results with indefinite timeouts.

Associated JIRA ticket: https://issues.apache.org/jira/browse/SPARK-9026
Also fixes: https://issues.apache.org/jira/browse/SPARK-4514

This pull request contains all my own original work, which I release to the Spark project under its open source license.

Author: Richard W. Eggert II <richard.eggert@gmail.com>

Closes #9264 from reggert/fix-futureaction.
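As an illustration of the delegation pattern the message describes, here is a minimal sketch (invented names, not the actual Spark implementation): the retained blocking call is just an indefinite Await on the non-blocking Future, so registering callbacks never parks a thread.

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.util.Try

// Sketch of a FutureAction-like wrapper; NonBlockingAction and blockingGet
// are illustrative names, not Spark API.
class NonBlockingAction[T](underlying: Future[T]) {
  // Non-blocking: the callback is registered and runs only once the job
  // completes; no thread waits on its behalf in the meantime.
  def onComplete[U](f: Try[T] => U)(implicit ec: ExecutionContext): Unit =
    underlying.onComplete(f)

  // Blocking API retained for compatibility: it delegates to the
  // non-blocking Future and awaits the result with an indefinite timeout.
  def blockingGet(): T = Await.result(underlying, Duration.Inf)
}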
Diffstat (limited to 'core/src/test/scala/org/apache')
-rw-r--r--  core/src/test/scala/org/apache/spark/Smuggle.scala                  |  82
-rw-r--r--  core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala       |  26
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala |  33
3 files changed, 139 insertions(+), 2 deletions(-)
diff --git a/core/src/test/scala/org/apache/spark/Smuggle.scala b/core/src/test/scala/org/apache/spark/Smuggle.scala
new file mode 100644
index 0000000000..01694a6e6f
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/Smuggle.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.util.UUID
+import java.util.concurrent.locks.ReentrantReadWriteLock
+
+import scala.collection.mutable
+
+/**
+ * Utility wrapper to "smuggle" objects into tasks while bypassing serialization.
+ * This is intended for testing purposes, primarily to make locks, semaphores, and
+ * other constructs that would not survive serialization available from within tasks.
+ * A Smuggle reference is itself serializable, but after being serialized and
+ * deserialized, it still refers to the same underlying "smuggled" object, as long
+ * as it was deserialized within the same JVM. This is useful for tests that
+ * need task-completion timing to be deterministic: a lock or semaphore can be
+ * "smuggled" into a task, and the task can then block until the test, via
+ * that lock, gives it the go-ahead to proceed.
+ */
+class Smuggle[T] private(val key: Symbol) extends Serializable {
+ def smuggledObject: T = Smuggle.get(key)
+}
+
+
+object Smuggle {
+ /**
+ * Wraps the specified object to be smuggled into a serialized task without
+ * being serialized itself.
+ *
+ * @param smuggledObject the object to make available to tasks without serializing it.
+ * @tparam T the type of the smuggled object.
+ * @return a Smuggle wrapper around smuggledObject.
+ */
+ def apply[T](smuggledObject: T): Smuggle[T] = {
+ val key = Symbol(UUID.randomUUID().toString)
+ lock.writeLock().lock()
+ try {
+ smuggledObjects += key -> smuggledObject
+ } finally {
+ lock.writeLock().unlock()
+ }
+ new Smuggle(key)
+ }
+
+ private val lock = new ReentrantReadWriteLock
+ private val smuggledObjects = mutable.WeakHashMap.empty[Symbol, Any]
+
+ private def get[T](key: Symbol): T = {
+ lock.readLock().lock()
+ try {
+ smuggledObjects(key).asInstanceOf[T]
+ } finally {
+ lock.readLock().unlock()
+ }
+ }
+
+ /**
+ * Implicit conversion of a Smuggle wrapper to the object being smuggled.
+ *
+ * @param smuggle the wrapper to unpack.
+ * @tparam T the type of the smuggled object.
+ * @return the smuggled object represented by the wrapper.
+ */
+ implicit def unpackSmuggledObject[T](smuggle: Smuggle[T]): T = smuggle.smuggledObject
+
+}
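A hypothetical usage sketch (not part of this commit; assumes a live SparkContext `sc`) showing how a test can gate task execution with a smuggled semaphore:

import java.util.concurrent.Semaphore

// The semaphore would not survive serialization, but the Smuggle wrapper does.
val gate = Smuggle(new Semaphore(0))
val rdd = sc.parallelize(1 to 10, 2).mapPartitions { iter =>
  gate.acquire() // via unpackSmuggledObject; the task blocks until released
  iter
}
val count = rdd.countAsync()
// ... assert on the not-yet-completed job here ...
gate.release(rdd.partitions.length) // let every partition's task proceed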
diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
index 46516e8d25..5483f2b843 100644
--- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
@@ -86,4 +86,30 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont
Set(firstJobId, secondJobId))
}
}
+
+ test("getJobIdsForGroup() with takeAsync()") {
+ sc = new SparkContext("local", "test", new SparkConf(false))
+ sc.setJobGroup("my-job-group2", "description")
+ sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty
+ val firstJobFuture = sc.parallelize(1 to 1000, 1).takeAsync(1)
+ val firstJobId = eventually(timeout(10 seconds)) {
+ firstJobFuture.jobIds.head
+ }
+ eventually(timeout(10 seconds)) {
+ sc.statusTracker.getJobIdsForGroup("my-job-group2") should be (Seq(firstJobId))
+ }
+ }
+
+ test("getJobIdsForGroup() with takeAsync() across multiple partitions") {
+ sc = new SparkContext("local", "test", new SparkConf(false))
+ sc.setJobGroup("my-job-group2", "description")
+ sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty
+ val firstJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999)
+ val firstJobId = eventually(timeout(10 seconds)) {
+ firstJobFuture.jobIds.head
+ }
+ eventually(timeout(10 seconds)) {
+ sc.statusTracker.getJobIdsForGroup("my-job-group2") should have size 2
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index ec99f2a1ba..de015ebd5d 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.rdd
import java.util.concurrent.Semaphore
-import scala.concurrent.{Await, TimeoutException}
+import scala.concurrent._
import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global
@@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
-import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark._
class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Timeouts {
@@ -197,4 +197,33 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim
Await.result(f, Duration(20, "milliseconds"))
}
}
+
+ private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = {
+ val executionContextInvoked = Promise[Unit]
+ val fakeExecutionContext = new ExecutionContext {
+ override def execute(runnable: Runnable): Unit = {
+ executionContextInvoked.success(())
+ }
+ override def reportFailure(t: Throwable): Unit = ()
+ }
+ val starter = Smuggle(new Semaphore(0))
+ starter.drainPermits()
+ val rdd = sc.parallelize(1 to 100, 4).mapPartitions { itr => starter.acquire(1); itr }
+ val f = action(rdd)
+ f.onComplete(_ => ())(fakeExecutionContext)
+ // Here we verify that registering the callback didn't cause a thread to be consumed.
+ assert(!executionContextInvoked.isCompleted)
+ // Now allow the executors to proceed with task processing.
+ starter.release(rdd.partitions.length)
+ // Waiting for the result verifies that the tasks were successfully processed.
+ Await.result(executionContextInvoked.future, atMost = 15.seconds)
+ }
+
+ test("SimpleFutureAction callback must not consume a thread while waiting") {
+ testAsyncAction(_.countAsync())
+ }
+
+ test("ComplexFutureAction callback must not consume a thread while waiting") {
+ testAsyncAction(_.takeAsync(100))
+ }
}
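For reference, the probe technique used by testAsyncAction above can be shown in isolation (a standalone sketch, independent of Spark): an ExecutionContext whose only job is to record, by completing a Promise, that a callback was scheduled on it.

import scala.concurrent.{ExecutionContext, Future, Promise}

object CallbackProbeSketch {
  def main(args: Array[String]): Unit = {
    val invoked = Promise[Unit]()
    val probe = new ExecutionContext {
      // Completing the promise records that the Future handed us a callback.
      override def execute(runnable: Runnable): Unit = invoked.success(())
      override def reportFailure(t: Throwable): Unit = ()
    }
    // An already-completed Future schedules its callback immediately,
    // so the probe fires without any blocking or polling.
    Future.successful(42).onComplete(_ => ())(probe)
    assert(invoked.isCompleted)
  }
}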