-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala  14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala        75
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala    83
3 files changed, 169 insertions, 3 deletions
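
In short: StreamExecution previously stopped its sources only in stop(), so a query that died with an exception never released source resources. This patch (SPARK-19774) moves the cleanup into the run loop's termination path via a new stopSources() helper that stops each source individually and logs, rather than propagates, non-fatal errors. Two regression tests and a new MockSourceProvider test utility cover both the graceful-stop and failure paths.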
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 4bd6431cbe..6e77f354b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -321,6 +321,7 @@ class StreamExecution(
initializationLatch.countDown()
try {
+ stopSources()
state.set(TERMINATED)
currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false)
@@ -558,6 +559,18 @@ class StreamExecution(
sparkSession.streams.postListenerEvent(event)
}
+ /** Stops all streaming sources safely. */
+ private def stopSources(): Unit = {
+ uniqueSources.foreach { source =>
+ try {
+ source.stop()
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Failed to stop streaming source: $source. Resources may have leaked.", e)
+ }
+ }
+ }
+
/**
* Signals to the thread executing micro-batches that it should stop running after the next
* batch. This method blocks until the thread stops running.
@@ -570,7 +583,6 @@ class StreamExecution(
microBatchThread.interrupt()
microBatchThread.join()
}
- uniqueSources.foreach(_.stop())
logInfo(s"Query $prettyIdString was stopped")
}
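
The effect of the hunks above: source shutdown moves out of stop() and into the finally block of the run loop, so sources are stopped on failures as well as on graceful stops, and each stop() is isolated so one misbehaving source cannot block cleanup of the rest. A minimal standalone sketch of the same cleanup pattern, using an illustrative closeAll helper over AutoCloseable rather than the patch's Source type:

import scala.util.control.NonFatal

// Close every resource, logging instead of rethrowing, so a failure in one
// close() does not leak the remaining resources (the patch uses logWarning).
def closeAll(resources: Seq[AutoCloseable]): Unit = {
  resources.foreach { r =>
    try r.close()
    catch {
      case NonFatal(e) => println(s"Failed to close $r: $e")
    }
  }
}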
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 1525ad5fd5..a0a2b2b4c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -20,10 +20,12 @@ package org.apache.spark.sql.streaming
import java.util.concurrent.CountDownLatch
import org.apache.commons.lang3.RandomStringUtils
+import org.mockito.Mockito._
import org.scalactic.TolerantNumerics
import org.scalatest.concurrent.Eventually._
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.mock.MockitoSugar
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -32,11 +34,11 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.util.BlockingSource
+import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider}
import org.apache.spark.util.ManualClock
-class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging {
+class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar {
import AwaitTerminationTester._
import testImplicits._
@@ -481,6 +483,75 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging {
}
}
+ test("StreamExecution should call stop() on sources when a stream is stopped") {
+ var calledStop = false
+ val source = new Source {
+ override def stop(): Unit = {
+ calledStop = true
+ }
+ override def getOffset: Option[Offset] = None
+ override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+ spark.emptyDataFrame
+ }
+ override def schema: StructType = MockSourceProvider.fakeSchema
+ }
+
+ MockSourceProvider.withMockSources(source) {
+ val df = spark.readStream
+ .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+ .load()
+
+ testStream(df)(StopStream)
+
+ assert(calledStop, "Did not call stop on source for stopped stream")
+ }
+ }
+
+ testQuietly("SPARK-19774: StreamExecution should call stop() on sources when a stream fails") {
+ var calledStop = false
+ val source1 = new Source {
+ override def stop(): Unit = {
+ throw new RuntimeException("Oh no!")
+ }
+ override def getOffset: Option[Offset] = Some(LongOffset(1))
+ override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+ spark.range(2).toDF(MockSourceProvider.fakeSchema.fieldNames: _*)
+ }
+ override def schema: StructType = MockSourceProvider.fakeSchema
+ }
+ val source2 = new Source {
+ override def stop(): Unit = {
+ calledStop = true
+ }
+ override def getOffset: Option[Offset] = None
+ override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+ spark.emptyDataFrame
+ }
+ override def schema: StructType = MockSourceProvider.fakeSchema
+ }
+
+ MockSourceProvider.withMockSources(source1, source2) {
+ val df1 = spark.readStream
+ .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+ .load()
+ .as[Int]
+
+ val df2 = spark.readStream
+ .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+ .load()
+ .as[Int]
+
+ testStream(df1.union(df2).map(i => i / 0))(
+ AssertOnQuery { sq =>
+ intercept[StreamingQueryException](sq.processAllAvailable())
+ sq.exception.isDefined && !sq.isActive
+ }
+ )
+
+ assert(calledStop, "Did not call stop on source for failed stream")
+ }
+ }
+
/** Create a streaming DF that only executes one batch in which it returns the given static DF */
private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = {
require(!triggerDF.isStreaming)
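
The suite now mixes in MockitoSugar, so the hand-written anonymous Sources above could also be expressed as mocks. A hypothetical sketch of the first test in that style (not what the patch does; mock[...] comes from the MockitoSugar mixin, when/verify from the org.mockito.Mockito._ import):

val source = mock[Source]
when(source.schema).thenReturn(MockSourceProvider.fakeSchema)
when(source.getOffset).thenReturn(None) // no data, so getBatch is never called

MockSourceProvider.withMockSources(source) {
  val df = spark.readStream
    .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
    .load()
  testStream(df)(StopStream)
  verify(source).stop() // replaces the calledStop flag
}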
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala
new file mode 100644
index 0000000000..0bf05381a7
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming.util
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.execution.streaming.Source
+import org.apache.spark.sql.sources.StreamSourceProvider
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+/**
+ * A StreamSourceProvider that provides mocked Sources for unit testing. Example usage:
+ *
+ * {{{
+ * MockSourceProvider.withMockSources(source1, source2) {
+ * val df1 = spark.readStream
+ * .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+ * .load()
+ *
+ * val df2 = spark.readStream
+ * .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+ * .load()
+ *
+ * df1.union(df2)
+ * ...
+ * }
+ * }}}
+ */
+class MockSourceProvider extends StreamSourceProvider {
+ override def sourceSchema(
+ spark: SQLContext,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): (String, StructType) = {
+ ("dummySource", MockSourceProvider.fakeSchema)
+ }
+
+ override def createSource(
+ spark: SQLContext,
+ metadataPath: String,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): Source = {
+ MockSourceProvider.sourceProviderFunction()
+ }
+}
+
+object MockSourceProvider {
+ // Function that returns the next mock Source. When several sources are registered via
+ // withMockSources, successive createSource() calls cycle through them in round-robin order.
+ private var sourceProviderFunction: () => Source = _
+
+ final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)
+
+ def withMockSources(source: Source, otherSources: Source*)(f: => Unit): Unit = {
+ var i = 0
+ val sources = source +: otherSources
+ sourceProviderFunction = () => {
+ val s = sources(i % sources.length) // avoid shadowing the `source` parameter
+ i += 1
+ s
+ }
+ try {
+ f
+ } finally {
+ sourceProviderFunction = null
+ }
+ }
+}
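
Since sourceProviderFunction cycles through the registered mocks, each MockSourceProvider relation in a query is wired to the next source when the query starts: createSource runs once per relation at query start, not at load(). A usage sketch, assuming s1 and s2 are Source instances like those in the tests above:

MockSourceProvider.withMockSources(s1, s2) {
  val df1 = spark.readStream
    .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
    .load()
  val df2 = spark.readStream
    .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
    .load()
  // When a query over df1.union(df2) starts, one relation is served by s1 and
  // the other by s2, in the order the relations are materialized.
  testStream(df1.union(df2))(StopStream)
}

Note that sourceProviderFunction is a shared mutable field, so tests using withMockSources should not run concurrently.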