diff options
author | Burak Yavuz <brkyvz@gmail.com> | 2017-03-03 10:35:15 -0800 |
---|---|---|
committer | Shixiong Zhu <shixiong@databricks.com> | 2017-03-03 10:35:15 -0800 |
commit | 9314c08377cc8da88f4e31d1a9d41376e96a81b3 (patch) | |
tree | 04bfe9cc2d2aa9601113a9b04a2095cf27cf5913 /sql | |
parent | 37a1c0e461737d4a4bbb03d397b651ec5ba00e96 (diff) | |
download | spark-9314c08377cc8da88f4e31d1a9d41376e96a81b3.tar.gz spark-9314c08377cc8da88f4e31d1a9d41376e96a81b3.tar.bz2 spark-9314c08377cc8da88f4e31d1a9d41376e96a81b3.zip |
[SPARK-19774] StreamExecution should call stop() on sources when a stream fails
## What changes were proposed in this pull request?
We currently call stop() on a Structured Streaming Source only when the stream is shut down by a user calling streamingQuery.stop(). We should also stop all sources when the stream fails; otherwise we may leak resources, e.g. connections to Kafka.
## How was this patch tested?
Unit tests in `StreamingQuerySuite`.
Author: Burak Yavuz <brkyvz@gmail.com>
Closes #17107 from brkyvz/close-source.
Diffstat (limited to 'sql')
3 files changed, 169 insertions, 3 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 4bd6431cbe..6e77f354b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -321,6 +321,7 @@ class StreamExecution( initializationLatch.countDown() try { + stopSources() state.set(TERMINATED) currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false) @@ -558,6 +559,18 @@ class StreamExecution( sparkSession.streams.postListenerEvent(event) } + /** Stops all streaming sources safely. */ + private def stopSources(): Unit = { + uniqueSources.foreach { source => + try { + source.stop() + } catch { + case NonFatal(e) => + logWarning(s"Failed to stop streaming source: $source. Resources may have leaked.", e) + } + } + } + /** * Signals to the thread executing micro-batches that it should stop running after the next * batch. This method blocks until the thread stops running. 
@@ -570,7 +583,6 @@ class StreamExecution( microBatchThread.interrupt() microBatchThread.join() } - uniqueSources.foreach(_.stop()) logInfo(s"Query $prettyIdString was stopped") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 1525ad5fd5..a0a2b2b4c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.streaming import java.util.concurrent.CountDownLatch import org.apache.commons.lang3.RandomStringUtils +import org.mockito.Mockito._ import org.scalactic.TolerantNumerics import org.scalatest.concurrent.Eventually._ import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.mock.MockitoSugar import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} @@ -32,11 +34,11 @@ import org.apache.spark.SparkException import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.util.BlockingSource +import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider} import org.apache.spark.util.ManualClock -class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { +class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar { import AwaitTerminationTester._ import testImplicits._ @@ -481,6 +483,75 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { } } + test("StreamExecution should call stop() on sources when a stream is stopped") { + var calledStop = false + val source = new Source { + override def stop(): Unit = { + calledStop = true + } + override def getOffset: 
Option[Offset] = None + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + spark.emptyDataFrame + } + override def schema: StructType = MockSourceProvider.fakeSchema + } + + MockSourceProvider.withMockSources(source) { + val df = spark.readStream + .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + .load() + + testStream(df)(StopStream) + + assert(calledStop, "Did not call stop on source for stopped stream") + } + } + + testQuietly("SPARK-19774: StreamExecution should call stop() on sources when a stream fails") { + var calledStop = false + val source1 = new Source { + override def stop(): Unit = { + throw new RuntimeException("Oh no!") + } + override def getOffset: Option[Offset] = Some(LongOffset(1)) + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + spark.range(2).toDF(MockSourceProvider.fakeSchema.fieldNames: _*) + } + override def schema: StructType = MockSourceProvider.fakeSchema + } + val source2 = new Source { + override def stop(): Unit = { + calledStop = true + } + override def getOffset: Option[Offset] = None + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + spark.emptyDataFrame + } + override def schema: StructType = MockSourceProvider.fakeSchema + } + + MockSourceProvider.withMockSources(source1, source2) { + val df1 = spark.readStream + .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + .load() + .as[Int] + + val df2 = spark.readStream + .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + .load() + .as[Int] + + testStream(df1.union(df2).map(i => i / 0))( + AssertOnQuery { sq => + intercept[StreamingQueryException](sq.processAllAvailable()) + sq.exception.isDefined && !sq.isActive + } + ) + + assert(calledStop, "Did not call stop on source for stopped stream") + } + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: 
DataFrame): DataFrame = { require(!triggerDF.isStreaming) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala new file mode 100644 index 0000000000..0bf05381a7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.util + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +/** + * A StreamSourceProvider that provides mocked Sources for unit testing. Example usage: + * + * {{{ + * MockSourceProvider.withMockSources(source1, source2) { + * val df1 = spark.readStream + * .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + * .load() + * + * val df2 = spark.readStream + * .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + * .load() + * + * df1.union(df2) + * ... 
+ * } + * }}} + */ +class MockSourceProvider extends StreamSourceProvider { + override def sourceSchema( + spark: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + ("dummySource", MockSourceProvider.fakeSchema) + } + + override def createSource( + spark: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + MockSourceProvider.sourceProviderFunction() + } +} + +object MockSourceProvider { + // Function to generate sources. May provide multiple sources if the user implements such a + // function. + private var sourceProviderFunction: () => Source = _ + + final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) + + def withMockSources(source: Source, otherSources: Source*)(f: => Unit): Unit = { + var i = 0 + val sources = source +: otherSources + sourceProviderFunction = () => { + val source = sources(i % sources.length) + i += 1 + source + } + try { + f + } finally { + sourceProviderFunction = null + } + } +} |