-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala                          167
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala               3
-rw-r--r--  core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala           4
-rw-r--r--  core/src/test/scala/org/apache/spark/SparkContextSuite.scala                       57
-rw-r--r--  docs/programming-guide.md                                                           2
-rw-r--r--  pom.xml                                                                             1
-rw-r--r--  project/SparkBuild.scala                                                            1
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala    186
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala            52
9 files changed, 347 insertions(+), 126 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 65edeeffb8..7cccf74003 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -58,12 +58,26 @@ import org.apache.spark.util._
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
* cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
*
+ * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before
+ * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details.
+ *
* @param config a Spark Config object describing the application configuration. Any settings in
* this config overrides the default configs as well as system properties.
*/
-
class SparkContext(config: SparkConf) extends Logging {
+ // The call site where this SparkContext was constructed.
+ private val creationSite: CallSite = Utils.getCallSite()
+
+ // If true, log warnings instead of throwing exceptions when multiple SparkContexts are active
+ private val allowMultipleContexts: Boolean =
+ config.getBoolean("spark.driver.allowMultipleContexts", false)
+
+ // In order to prevent multiple SparkContexts from being active at the same time, mark this
+ // context as having started construction.
+ // NOTE: this must be placed at the beginning of the SparkContext constructor.
+ SparkContext.markPartiallyConstructed(this, allowMultipleContexts)
+
// This is used only by YARN for now, but should be relevant to other cluster types (Mesos,
// etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It
// contains a map from hostname to a list of input format splits on the host.
@@ -1166,27 +1180,30 @@ class SparkContext(config: SparkConf) extends Logging {
/** Shut down the SparkContext. */
def stop() {
- postApplicationEnd()
- ui.foreach(_.stop())
- // Do this only if not stopped already - best case effort.
- // prevent NPE if stopped more than once.
- val dagSchedulerCopy = dagScheduler
- dagScheduler = null
- if (dagSchedulerCopy != null) {
- env.metricsSystem.report()
- metadataCleaner.cancel()
- env.actorSystem.stop(heartbeatReceiver)
- cleaner.foreach(_.stop())
- dagSchedulerCopy.stop()
- taskScheduler = null
- // TODO: Cache.stop()?
- env.stop()
- SparkEnv.set(null)
- listenerBus.stop()
- eventLogger.foreach(_.stop())
- logInfo("Successfully stopped SparkContext")
- } else {
- logInfo("SparkContext already stopped")
+ SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ postApplicationEnd()
+ ui.foreach(_.stop())
+ // Do this only if not stopped already - best case effort.
+ // prevent NPE if stopped more than once.
+ val dagSchedulerCopy = dagScheduler
+ dagScheduler = null
+ if (dagSchedulerCopy != null) {
+ env.metricsSystem.report()
+ metadataCleaner.cancel()
+ env.actorSystem.stop(heartbeatReceiver)
+ cleaner.foreach(_.stop())
+ dagSchedulerCopy.stop()
+ taskScheduler = null
+ // TODO: Cache.stop()?
+ env.stop()
+ SparkEnv.set(null)
+ listenerBus.stop()
+ eventLogger.foreach(_.stop())
+ logInfo("Successfully stopped SparkContext")
+ SparkContext.clearActiveContext()
+ } else {
+ logInfo("SparkContext already stopped")
+ }
}
}
@@ -1475,6 +1492,11 @@ class SparkContext(config: SparkConf) extends Logging {
private[spark] def cleanup(cleanupTime: Long) {
persistentRdds.clearOldValues(cleanupTime)
}
+
+ // In order to prevent multiple SparkContexts from being active at the same time, mark this
+ // context as having finished construction.
+ // NOTE: this must be placed at the end of the SparkContext constructor.
+ SparkContext.setActiveContext(this, allowMultipleContexts)
}
/**
@@ -1483,6 +1505,107 @@ class SparkContext(config: SparkConf) extends Logging {
*/
object SparkContext extends Logging {
+ /**
+ * Lock that guards access to global variables that track SparkContext construction.
+ */
+ private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object()
+
+ /**
+ * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `None`.
+ *
+ * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK
+ */
+ private var activeContext: Option[SparkContext] = None
+
+ /**
+ * Points to a partially-constructed SparkContext if some thread is in the SparkContext
+ * constructor, or `None` if no SparkContext is being constructed.
+ *
+ * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK
+ */
+ private var contextBeingConstructed: Option[SparkContext] = None
+
+ /**
+ * Called to ensure that no other SparkContext is running in this JVM.
+ *
+ * Throws an exception if a running context is detected and logs a warning if another thread is
+ * constructing a SparkContext. This warning is necessary because the current locking scheme
+ * prevents us from reliably distinguishing between cases where another context is being
+ * constructed and cases where another constructor threw an exception.
+ */
+ private def assertNoOtherContextIsRunning(
+ sc: SparkContext,
+ allowMultipleContexts: Boolean): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ contextBeingConstructed.foreach { otherContext =>
+ if (otherContext ne sc) { // checks for reference equality
+ // Since otherContext might point to a partially-constructed context, guard against
+ // its creationSite field being null:
+ val otherContextCreationSite =
+ Option(otherContext.creationSite).map(_.longForm).getOrElse("unknown location")
+ val warnMsg = "Another SparkContext is being constructed (or threw an exception in its" +
+ " constructor). This may indicate an error, since only one SparkContext may be" +
+ " running in this JVM (see SPARK-2243)." +
+ s" The other SparkContext was created at:\n$otherContextCreationSite"
+ logWarning(warnMsg)
+ }
+
+ activeContext.foreach { ctx =>
+ val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." +
+ " To ignore this error, set spark.driver.allowMultipleContexts = true. " +
+ s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}"
+ val exception = new SparkException(errMsg)
+ if (allowMultipleContexts) {
+ logWarning("Multiple running SparkContexts detected in the same JVM!", exception)
+ } else {
+ throw exception
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is
+ * running. Throws an exception if a running context is detected and logs a warning if another
+ * thread is constructing a SparkContext. This warning is necessary because the current locking
+ * scheme prevents us from reliably distinguishing between cases where another context is being
+ * constructed and cases where another constructor threw an exception.
+ */
+ private[spark] def markPartiallyConstructed(
+ sc: SparkContext,
+ allowMultipleContexts: Boolean): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ assertNoOtherContextIsRunning(sc, allowMultipleContexts)
+ contextBeingConstructed = Some(sc)
+ }
+ }
+
+ /**
+ * Called at the end of the SparkContext constructor to ensure that no other SparkContext has
+ * raced with this constructor and started.
+ */
+ private[spark] def setActiveContext(
+ sc: SparkContext,
+ allowMultipleContexts: Boolean): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ assertNoOtherContextIsRunning(sc, allowMultipleContexts)
+ contextBeingConstructed = None
+ activeContext = Some(sc)
+ }
+ }
+
+ /**
+ * Clears the active SparkContext metadata. This is called by `SparkContext#stop()`. It's
+ * also called in unit tests to prevent a flood of warnings from test suites that don't / can't
+ * properly clean up their SparkContexts.
+ */
+ private[spark] def clearActiveContext(): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ activeContext = None
+ }
+ }
+
private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
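
As a rough illustration of what the guard added above means for driver programs (this sketch is not part of the patch; the app name, master URL, and object name are placeholders), a second constructor call now fails fast with a SparkException unless spark.driver.allowMultipleContexts is set:

    import org.apache.spark.{SparkConf, SparkContext, SparkException}

    object SingleActiveContextSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("single-context-sketch").setMaster("local")
        val sc = new SparkContext(conf)
        try {
          // With spark.driver.allowMultipleContexts at its default (false),
          // constructing a second context throws instead of silently sharing global state.
          new SparkContext(conf)
        } catch {
          case e: SparkException => println("Second context rejected: " + e.getMessage)
        } finally {
          sc.stop() // clears the active-context bookkeeping
        }
        // After stop(), a fresh context can be created again.
        val sc2 = new SparkContext(conf)
        sc2.stop()
      }
    }
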
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index d50ed32ca0..6a6d9bf685 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -42,6 +42,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD}
/**
* A Java-friendly version of [[org.apache.spark.SparkContext]] that returns
* [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones.
+ *
+ * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before
+ * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details.
*/
class JavaSparkContext(val sc: SparkContext)
extends JavaSparkContextVarargsWorkaround with Closeable {
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index 4b27477790..ce804f94f3 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -37,20 +37,24 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext {
.set("spark.dynamicAllocation.enabled", "true")
intercept[SparkException] { new SparkContext(conf) }
SparkEnv.get.stop() // cleanup the created environment
+ SparkContext.clearActiveContext()
// Only min
val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "1")
intercept[SparkException] { new SparkContext(conf1) }
SparkEnv.get.stop()
+ SparkContext.clearActiveContext()
// Only max
val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "2")
intercept[SparkException] { new SparkContext(conf2) }
SparkEnv.get.stop()
+ SparkContext.clearActiveContext()
// Both min and max, but min > max
intercept[SparkException] { createSparkContext(2, 1) }
SparkEnv.get.stop()
+ SparkContext.clearActiveContext()
// Both min and max, and min == max
val sc1 = createSparkContext(1, 1)
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 31edad1c56..9e454ddcc5 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -21,9 +21,62 @@ import org.scalatest.FunSuite
import org.apache.hadoop.io.BytesWritable
-class SparkContextSuite extends FunSuite {
- //Regression test for SPARK-3121
+class SparkContextSuite extends FunSuite with LocalSparkContext {
+
+ /** Allows system properties to be changed in tests */
+ private def withSystemProperty[T](property: String, value: String)(block: => T): T = {
+ val originalValue = System.getProperty(property)
+ try {
+ System.setProperty(property, value)
+ block
+ } finally {
+ if (originalValue == null) {
+ System.clearProperty(property)
+ } else {
+ System.setProperty(property, originalValue)
+ }
+ }
+ }
+
+ test("Only one SparkContext may be active at a time") {
+ // Regression test for SPARK-4180
+ withSystemProperty("spark.driver.allowMultipleContexts", "false") {
+ val conf = new SparkConf().setAppName("test").setMaster("local")
+ sc = new SparkContext(conf)
+ // A SparkContext is already running, so we shouldn't be able to create a second one
+ intercept[SparkException] { new SparkContext(conf) }
+ // After stopping the running context, we should be able to create a new one
+ resetSparkContext()
+ sc = new SparkContext(conf)
+ }
+ }
+
+ test("Can still construct a new SparkContext after failing to construct a previous one") {
+ withSystemProperty("spark.driver.allowMultipleContexts", "false") {
+ // This is an invalid configuration (no app name or master URL)
+ intercept[SparkException] {
+ new SparkContext(new SparkConf())
+ }
+ // Even though those earlier calls failed, we should still be able to create a new context
+ sc = new SparkContext(new SparkConf().setMaster("local").setAppName("test"))
+ }
+ }
+
+ test("Check for multiple SparkContexts can be disabled via undocumented debug option") {
+ withSystemProperty("spark.driver.allowMultipleContexts", "true") {
+ var secondSparkContext: SparkContext = null
+ try {
+ val conf = new SparkConf().setAppName("test").setMaster("local")
+ sc = new SparkContext(conf)
+ secondSparkContext = new SparkContext(conf)
+ } finally {
+ Option(secondSparkContext).foreach(_.stop())
+ }
+ }
+ }
+
test("BytesWritable implicit conversion is correct") {
+ // Regression test for SPARK-3121
val bytesWritable = new BytesWritable()
val inputArray = (1 to 10).map(_.toByte).toArray
bytesWritable.set(inputArray, 0, 10)
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 9de2f914b8..49f319ba77 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -117,6 +117,8 @@ The first thing a Spark program must do is to create a [SparkContext](api/scala/
how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) object
that contains information about your application.
+Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before creating a new one.
+
{% highlight scala %}
val conf = new SparkConf().setAppName(appName).setMaster(master)
new SparkContext(conf)
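
The guide snippet above creates a context but never stops it; a more complete sequence in the same style (appName and master are placeholders, exactly as in the guide) might look like:

{% highlight scala %}
val conf = new SparkConf().setAppName(appName).setMaster(master)
val sc = new SparkContext(conf)
// ... run jobs against sc ...
sc.stop()                         // required before another SparkContext may be created in this JVM
val sc2 = new SparkContext(conf)  // now succeeds
{% endhighlight %}
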
diff --git a/pom.xml b/pom.xml
index 639ea22a1f..cc7bce1757 100644
--- a/pom.xml
+++ b/pom.xml
@@ -978,6 +978,7 @@
<spark.testing>1</spark.testing>
<spark.ui.enabled>false</spark.ui.enabled>
<spark.executor.extraClassPath>${test_classpath}</spark.executor.extraClassPath>
+ <spark.driver.allowMultipleContexts>true</spark.driver.allowMultipleContexts>
</systemProperties>
</configuration>
<executions>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index c96a6c4954..1697b6d4f2 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -377,6 +377,7 @@ object TestSettings {
javaOptions in Test += "-Dspark.testing=1",
javaOptions in Test += "-Dspark.port.maxRetries=100",
javaOptions in Test += "-Dspark.ui.enabled=false",
+ javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
.map { case (k,v) => s"-D$k=$v" }.toSeq,
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 30a359677c..86b96785d7 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -470,32 +470,31 @@ class BasicOperationsSuite extends TestSuiteBase {
}
test("slice") {
- val ssc = new StreamingContext(conf, Seconds(1))
- val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
- val stream = new TestInputStream[Int](ssc, input, 2)
- stream.foreachRDD(_ => {}) // Dummy output stream
- ssc.start()
- Thread.sleep(2000)
- def getInputFromSlice(fromMillis: Long, toMillis: Long) = {
- stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet
- }
+ withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc =>
+ val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
+ val stream = new TestInputStream[Int](ssc, input, 2)
+ stream.foreachRDD(_ => {}) // Dummy output stream
+ ssc.start()
+ Thread.sleep(2000)
+ def getInputFromSlice(fromMillis: Long, toMillis: Long) = {
+ stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet
+ }
- assert(getInputFromSlice(0, 1000) == Set(1))
- assert(getInputFromSlice(0, 2000) == Set(1, 2))
- assert(getInputFromSlice(1000, 2000) == Set(1, 2))
- assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4))
- ssc.stop()
- Thread.sleep(1000)
+ assert(getInputFromSlice(0, 1000) == Set(1))
+ assert(getInputFromSlice(0, 2000) == Set(1, 2))
+ assert(getInputFromSlice(1000, 2000) == Set(1, 2))
+ assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4))
+ }
}
-
test("slice - has not been initialized") {
- val ssc = new StreamingContext(conf, Seconds(1))
- val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
- val stream = new TestInputStream[Int](ssc, input, 2)
- val thrown = intercept[SparkException] {
- stream.slice(new Time(0), new Time(1000))
+ withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc =>
+ val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
+ val stream = new TestInputStream[Int](ssc, input, 2)
+ val thrown = intercept[SparkException] {
+ stream.slice(new Time(0), new Time(1000))
+ }
+ assert(thrown.getMessage.contains("has not been initialized"))
}
- assert(thrown.getMessage.contains("has not been initialized"))
}
val cleanupTestInput = (0 until 10).map(x => Seq(x, x + 1)).toSeq
@@ -555,73 +554,72 @@ class BasicOperationsSuite extends TestSuiteBase {
test("rdd cleanup - input blocks and persisted RDDs") {
// Actually receive data over through receiver to create BlockRDDs
- // Start the server
- val testServer = new TestServer()
- testServer.start()
-
- // Set up the streaming context and input streams
- val ssc = new StreamingContext(conf, batchDuration)
- val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
- val mappedStream = networkStream.map(_ + ".").persist()
- val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
- val outputStream = new TestOutputStream(mappedStream, outputBuffer)
-
- outputStream.register()
- ssc.start()
-
- // Feed data to the server to send to the network receiver
- val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
- val input = Seq(1, 2, 3, 4, 5, 6)
+ withTestServer(new TestServer()) { testServer =>
+ withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+ testServer.start()
+ // Set up the streaming context and input streams
+ val networkStream =
+ ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
+ val mappedStream = networkStream.map(_ + ".").persist()
+ val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
+ val outputStream = new TestOutputStream(mappedStream, outputBuffer)
+
+ outputStream.register()
+ ssc.start()
+
+ // Feed data to the server to send to the network receiver
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val input = Seq(1, 2, 3, 4, 5, 6)
+
+ val blockRdds = new mutable.HashMap[Time, BlockRDD[_]]
+ val persistentRddIds = new mutable.HashMap[Time, Int]
+
+ def collectRddInfo() { // get all RDD info required for verification
+ networkStream.generatedRDDs.foreach { case (time, rdd) =>
+ blockRdds(time) = rdd.asInstanceOf[BlockRDD[_]]
+ }
+ mappedStream.generatedRDDs.foreach { case (time, rdd) =>
+ persistentRddIds(time) = rdd.id
+ }
+ }
- val blockRdds = new mutable.HashMap[Time, BlockRDD[_]]
- val persistentRddIds = new mutable.HashMap[Time, Int]
+ Thread.sleep(200)
+ for (i <- 0 until input.size) {
+ testServer.send(input(i).toString + "\n")
+ Thread.sleep(200)
+ clock.addToTime(batchDuration.milliseconds)
+ collectRddInfo()
+ }
- def collectRddInfo() { // get all RDD info required for verification
- networkStream.generatedRDDs.foreach { case (time, rdd) =>
- blockRdds(time) = rdd.asInstanceOf[BlockRDD[_]]
- }
- mappedStream.generatedRDDs.foreach { case (time, rdd) =>
- persistentRddIds(time) = rdd.id
+ Thread.sleep(200)
+ collectRddInfo()
+ logInfo("Stopping server")
+ testServer.stop()
+
+ // verify data has been received
+ assert(outputBuffer.size > 0)
+ assert(blockRdds.size > 0)
+ assert(persistentRddIds.size > 0)
+
+ import Time._
+
+ val latestPersistedRddId = persistentRddIds(persistentRddIds.keySet.max)
+ val earliestPersistedRddId = persistentRddIds(persistentRddIds.keySet.min)
+ val latestBlockRdd = blockRdds(blockRdds.keySet.max)
+ val earliestBlockRdd = blockRdds(blockRdds.keySet.min)
+ // verify that the latest mapped RDD is persisted but the earliest one has been unpersisted
+ assert(ssc.sparkContext.persistentRdds.contains(latestPersistedRddId))
+ assert(!ssc.sparkContext.persistentRdds.contains(earliestPersistedRddId))
+
+ // verify that the latest input blocks are present but the earliest blocks have been removed
+ assert(latestBlockRdd.isValid)
+ assert(latestBlockRdd.collect != null)
+ assert(!earliestBlockRdd.isValid)
+ earliestBlockRdd.blockIds.foreach { blockId =>
+ assert(!ssc.sparkContext.env.blockManager.master.contains(blockId))
+ }
}
}
-
- Thread.sleep(200)
- for (i <- 0 until input.size) {
- testServer.send(input(i).toString + "\n")
- Thread.sleep(200)
- clock.addToTime(batchDuration.milliseconds)
- collectRddInfo()
- }
-
- Thread.sleep(200)
- collectRddInfo()
- logInfo("Stopping server")
- testServer.stop()
- logInfo("Stopping context")
-
- // verify data has been received
- assert(outputBuffer.size > 0)
- assert(blockRdds.size > 0)
- assert(persistentRddIds.size > 0)
-
- import Time._
-
- val latestPersistedRddId = persistentRddIds(persistentRddIds.keySet.max)
- val earliestPersistedRddId = persistentRddIds(persistentRddIds.keySet.min)
- val latestBlockRdd = blockRdds(blockRdds.keySet.max)
- val earliestBlockRdd = blockRdds(blockRdds.keySet.min)
- // verify that the latest mapped RDD is persisted but the earliest one has been unpersisted
- assert(ssc.sparkContext.persistentRdds.contains(latestPersistedRddId))
- assert(!ssc.sparkContext.persistentRdds.contains(earliestPersistedRddId))
-
- // verify that the latest input blocks are present but the earliest blocks have been removed
- assert(latestBlockRdd.isValid)
- assert(latestBlockRdd.collect != null)
- assert(!earliestBlockRdd.isValid)
- earliestBlockRdd.blockIds.foreach { blockId =>
- assert(!ssc.sparkContext.env.blockManager.master.contains(blockId))
- }
- ssc.stop()
}
/** Test cleanup of RDDs in DStream metadata */
@@ -635,13 +633,15 @@ class BasicOperationsSuite extends TestSuiteBase {
// Setup the stream computation
assert(batchDuration === Seconds(1),
"Batch duration has changed from 1 second, check cleanup tests")
- val ssc = setupStreams(cleanupTestInput, operation)
- val operatedStream = ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]]
- if (rememberDuration != null) ssc.remember(rememberDuration)
- val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput)
- val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
- assert(clock.time === Seconds(10).milliseconds)
- assert(output.size === numExpectedOutput)
- operatedStream
+ withStreamingContext(setupStreams(cleanupTestInput, operation)) { ssc =>
+ val operatedStream =
+ ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]]
+ if (rememberDuration != null) ssc.remember(rememberDuration)
+ val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput)
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ assert(clock.time === Seconds(10).milliseconds)
+ assert(output.size === numExpectedOutput)
+ operatedStream
+ }
}
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 2154c24abd..52972f63c6 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -164,6 +164,40 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
after(afterFunction)
/**
+ * Run a block of code with the given StreamingContext and automatically
+ * stop the context when the block completes or when an exception is thrown.
+ */
+ def withStreamingContext[R](ssc: StreamingContext)(block: StreamingContext => R): R = {
+ try {
+ block(ssc)
+ } finally {
+ try {
+ ssc.stop(stopSparkContext = true)
+ } catch {
+ case e: Exception =>
+ logError("Error stopping StreamingContext", e)
+ }
+ }
+ }
+
+ /**
+ * Run a block of code with the given TestServer and automatically
+ * stop the server when the block completes or when an exception is thrown.
+ */
+ def withTestServer[R](testServer: TestServer)(block: TestServer => R): R = {
+ try {
+ block(testServer)
+ } finally {
+ try {
+ testServer.stop()
+ } catch {
+ case e: Exception =>
+ logError("Error stopping TestServer", e)
+ }
+ }
+ }
+
+ /**
* Set up required DStreams to test the DStream operation using the two sequences
* of input collections.
*/
@@ -282,10 +316,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
Thread.sleep(100) // Give some time for the forgetting old RDDs to complete
- } catch {
- case e: Exception => {e.printStackTrace(); throw e}
} finally {
- ssc.stop()
+ ssc.stop(stopSparkContext = true)
}
output
}
@@ -351,9 +383,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
useSet: Boolean
) {
val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size
- val ssc = setupStreams[U, V](input, operation)
- val output = runStreams[V](ssc, numBatches_, expectedOutput.size)
- verifyOutput[V](output, expectedOutput, useSet)
+ withStreamingContext(setupStreams[U, V](input, operation)) { ssc =>
+ val output = runStreams[V](ssc, numBatches_, expectedOutput.size)
+ verifyOutput[V](output, expectedOutput, useSet)
+ }
}
/**
@@ -389,8 +422,9 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
useSet: Boolean
) {
val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size
- val ssc = setupStreams[U, V, W](input1, input2, operation)
- val output = runStreams[W](ssc, numBatches_, expectedOutput.size)
- verifyOutput[W](output, expectedOutput, useSet)
+ withStreamingContext(setupStreams[U, V, W](input1, input2, operation)) { ssc =>
+ val output = runStreams[W](ssc, numBatches_, expectedOutput.size)
+ verifyOutput[W](output, expectedOutput, useSet)
+ }
}
}
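
withStreamingContext and withTestServer above are instances of the loan pattern: acquire a resource, lend it to the test body, and always attempt cleanup while only logging shutdown failures. A generic, dependency-free sketch of the same idea (the names withResource and LoanPatternSketch are hypothetical, not part of the patch):

    object LoanPatternSketch {
      /** Acquire a resource, run the body, and always attempt to close the resource afterwards. */
      def withResource[A, R](resource: A)(close: A => Unit)(body: A => R): R = {
        try {
          body(resource)
        } finally {
          try {
            close(resource)
          } catch {
            // The suites above use logError; println keeps this sketch dependency-free.
            case e: Exception => println(s"Error while closing resource: ${e.getMessage}")
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val firstChar = withResource(new java.io.StringReader("hello"))(_.close()) { reader =>
          reader.read().toChar
        }
        println(firstChar) // prints 'h'
      }
    }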