author     Tathagata Das <tathagata.das1565@gmail.com>   2016-02-10 16:45:06 -0800
committer  Shixiong Zhu <shixiong@databricks.com>        2016-02-10 16:45:06 -0800
commit     0902e20288366db6270f3a444e66114b1b63a3e2 (patch)
tree       d65416bdfba1304bf1f1d2c2ff4ac6d0a90ff153
parent     29c547303f886b96b74b411ac70f0fa81113f086 (diff)
[SPARK-13146][SQL] Management API for continuous queries
### Management API for Continuous Queries

**API for getting status of each query**
- Whether active or not
- Unique name of each query
- Status of the sources and sinks
- Exceptions

**API for managing each query**
- Immediately stop an active query
- Waiting for a query to be terminated, correctly or with error

**API for managing multiple queries**
- Listing all active queries
- Getting an active query by name
- Waiting for any one of the active queries to be terminated

**API for listening to query life cycle events**
- ContinuousQueryListener API for query start, progress and termination events.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #11030 from tdas/streaming-df-management-api.
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala                                 |  72
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala                        |  54
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala                          | 193
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala                                 |  14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                                      |  12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala                                      |  34
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala                                    |  34
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala  |  82
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala             | 215
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala              |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala                      |  20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala                    |  67
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala                                       | 252
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala           | 306
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala                  | 139
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala            |  69
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala               | 222
17 files changed, 1680 insertions(+), 109 deletions(-)
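To make the new surface concrete, here is a minimal usage sketch (not part of the patch). It assumes an existing `SQLContext` named `sqlContext` and a streaming `DataFrame` named `df`; the no-argument `stream()` call follows the `DataFrameWriter` documentation changed below, and the rest of the setup is illustrative only.

```scala
// Usage sketch for the management API introduced by this patch (illustrative only).
// Assumes a streaming DataFrame `df` created on an existing SQLContext.
import org.apache.spark.sql.{ContinuousQuery, ContinuousQueryException, DataFrame}

def runAndMonitor(df: DataFrame): Unit = {
  // Start a named continuous query; queryName() and stream() come from DataFrameWriter.
  val query: ContinuousQuery = df.write.queryName("my-query").stream()

  // Status API: name, activity, per-source and sink offsets.
  println(s"name=${query.name} active=${query.isActive}")
  query.sourceStatuses.foreach(s => println(s"source ${s.description} at ${s.offset}"))
  println(s"sink ${query.sinkStatus.description} at ${query.sinkStatus.offset}")

  // Management API: wait up to 10s for termination, otherwise stop the query explicitly.
  try {
    if (!query.awaitTermination(10000L)) {
      query.stop() // blocks until the micro-batch thread has stopped
    }
  } catch {
    case e: ContinuousQueryException =>
      println(s"terminated with error: ${e.message}")
  }
}
```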
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala
index 1c2c0290fc..eb69804c39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala
@@ -17,14 +17,84 @@
package org.apache.spark.sql
+import org.apache.spark.annotation.Experimental
+
/**
+ * :: Experimental ::
* A handle to a query that is executing continuously in the background as new data arrives.
+ * All these methods are thread-safe.
+ * @since 2.0.0
*/
+@Experimental
trait ContinuousQuery {
/**
- * Stops the execution of this query if it is running. This method blocks until the threads
+ * Returns the name of the query.
+ * @since 2.0.0
+ */
+ def name: String
+
+ /**
+ * Returns the SQLContext associated with `this` query
+ * @since 2.0.0
+ */
+ def sqlContext: SQLContext
+
+ /**
+ * Whether the query is currently active or not
+ * @since 2.0.0
+ */
+ def isActive: Boolean
+
+ /**
+ * Returns the [[ContinuousQueryException]] if the query was terminated by an exception.
+ * @since 2.0.0
+ */
+ def exception: Option[ContinuousQueryException]
+
+ /**
+ * Returns current status of all the sources.
+ * @since 2.0.0
+ */
+ def sourceStatuses: Array[SourceStatus]
+
+ /** Returns current status of the sink. */
+ def sinkStatus: SinkStatus
+
+ /**
+ * Waits for the termination of `this` query, either by `query.stop()` or by an exception.
+ * If the query has terminated with an exception, then the exception will be thrown.
+ *
+ * If the query has terminated, then all subsequent calls to this method will either return
+ * immediately (if the query was terminated by `stop()`), or throw the exception
+ * immediately (if the query has terminated with exception).
+ *
+ * @throws ContinuousQueryException, if `this` query has terminated with an exception.
+ *
+ * @since 2.0.0
+ */
+ def awaitTermination(): Unit
+
+ /**
+ * Waits for the termination of `this` query, either by `query.stop()` or by an exception.
+ * If the query has terminated with an exception, then the exception will be thrown.
+ * Otherwise, it returns whether the query has terminated or not within the `timeoutMs`
+ * milliseconds.
+ *
+ * If the query has terminated, then all subsequent calls to this method will either return
+ * `true` immediately (if the query was terminated by `stop()`), or throw the exception
+ * immediately (if the query has terminated with exception).
+ *
+ * @throws ContinuousQueryException, if `this` query has terminated with an exception
+ *
+ * @since 2.0.0
+ */
+ def awaitTermination(timeoutMs: Long): Boolean
+
+ /**
+ * Stops the execution of this query if it is running. This method blocks until the threads
* performing execution have stopped.
+ * @since 2.0.0
*/
def stop(): Unit
}
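For reference, a minimal sketch (not part of the diff) of how a caller might block on a single query and handle a failure through the rethrown `ContinuousQueryException`; `query` is assumed to be an already started `ContinuousQuery`.

```scala
// Illustrative only: block until `query` terminates, then inspect the failure, if any.
import org.apache.spark.sql.{ContinuousQuery, ContinuousQueryException}

def awaitAndReport(query: ContinuousQuery): Unit = {
  try {
    query.awaitTermination() // returns when stop() is called, throws if the query failed
  } catch {
    case e: ContinuousQueryException =>
      // startOffset/endOffset bracket the range of data being processed when the error occurred.
      println(s"${e.query.name} failed between ${e.startOffset} and ${e.endOffset}: ${e.message}")
  }
}
```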
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala
new file mode 100644
index 0000000000..67dd9dbe23
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.execution.streaming.{Offset, StreamExecution}
+
+/**
+ * :: Experimental ::
+ * Exception that stopped a [[ContinuousQuery]].
+ * @param query Query that caused the exception
+ * @param message Message of this exception
+ * @param cause Internal cause of this exception
+ * @param startOffset Starting offset (if known) of the range of data in which exception occurred
+ * @param endOffset Ending offset (if known) of the range of data in which exception occurred
+ * @since 2.0.0
+ */
+@Experimental
+class ContinuousQueryException private[sql](
+ val query: ContinuousQuery,
+ val message: String,
+ val cause: Throwable,
+ val startOffset: Option[Offset] = None,
+ val endOffset: Option[Offset] = None
+ ) extends Exception(message, cause) {
+
+ /** Time when the exception occurred */
+ val time: Long = System.currentTimeMillis
+
+ override def toString(): String = {
+ val causeStr =
+ s"${cause.getMessage} ${cause.getStackTrace.take(10).mkString("", "\n|\t", "\n")}"
+ s"""
+ |$causeStr
+ |
+ |${query.asInstanceOf[StreamExecution].toDebugString}
+ """.stripMargin
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
new file mode 100644
index 0000000000..13142d0e61
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution}
+import org.apache.spark.sql.util.ContinuousQueryListener
+
+/**
+ * :: Experimental ::
+ * A class to manage all the [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active
+ * on a [[SQLContext]].
+ *
+ * @since 2.0.0
+ */
+@Experimental
+class ContinuousQueryManager(sqlContext: SQLContext) {
+
+ private val listenerBus = new ContinuousQueryListenerBus(sqlContext.sparkContext.listenerBus)
+ private val activeQueries = new mutable.HashMap[String, ContinuousQuery]
+ private val activeQueriesLock = new Object
+ private val awaitTerminationLock = new Object
+
+ private var lastTerminatedQuery: ContinuousQuery = null
+
+ /**
+ * Returns a list of active queries associated with this SQLContext
+ *
+ * @since 2.0.0
+ */
+ def active: Array[ContinuousQuery] = activeQueriesLock.synchronized {
+ activeQueries.values.toArray
+ }
+
+ /**
+ * Returns the active query in this SQLContext with the given name, or throws an exception if no such active query exists
+ *
+ * @since 2.0.0
+ */
+ def get(name: String): ContinuousQuery = activeQueriesLock.synchronized {
+ activeQueries.get(name).getOrElse {
+ throw new IllegalArgumentException(s"There is no active query with name $name")
+ }
+ }
+
+ /**
+ * Wait until any of the queries on the associated SQLContext has terminated since the
+ * creation of the context, or since `resetTerminated()` was called. If any query was terminated
+ * with an exception, then the exception will be thrown.
+ *
+ * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either
+ * return immediately (if the query was terminated by `query.stop()`),
+ * or throw the exception immediately (if the query was terminated with exception). Use
+ * `resetTerminated()` to clear past terminations and wait for new terminations.
+ *
+ * In the case where multiple queries have terminated since `resetTerminated()` was called,
+ * if any query has terminated with an exception, then `awaitAnyTermination()` will
+ * throw any one of those exceptions. To correctly report exceptions across multiple queries,
+ * users need to stop all of them after any of them terminates with an exception, and then check
+ * `query.exception()` for each query.
+ *
+ * @throws ContinuousQueryException, if any query has terminated with an exception
+ *
+ * @since 2.0.0
+ */
+ def awaitAnyTermination(): Unit = {
+ awaitTerminationLock.synchronized {
+ while (lastTerminatedQuery == null) {
+ awaitTerminationLock.wait(10)
+ }
+ if (lastTerminatedQuery != null && lastTerminatedQuery.exception.nonEmpty) {
+ throw lastTerminatedQuery.exception.get
+ }
+ }
+ }
+
+ /**
+ * Wait until any of the queries on the associated SQLContext has terminated since the
+ * creation of the context, or since `resetTerminated()` was called. Returns whether any query
+ * has terminated or not (multiple may have terminated). If any query has terminated with an
+ * exception, then the exception will be thrown.
+ *
+ * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either
+ * return `true` immediately (if the query was terminated by `query.stop()`),
+ * or throw the exception immediately (if the query was terminated with exception). Use
+ * `resetTerminated()` to clear past terminations and wait for new terminations.
+ *
+ * In the case where multiple queries have terminated since `resetTerminated()` was called,
+ * if any query has terminated with an exception, then `awaitAnyTermination()` will
+ * throw any one of those exceptions. To correctly report exceptions across multiple queries,
+ * users need to stop all of them after any of them terminates with an exception, and then check
+ * `query.exception()` for each query.
+ *
+ * @throws ContinuousQueryException, if any query has terminated with an exception
+ *
+ * @since 2.0.0
+ */
+ def awaitAnyTermination(timeoutMs: Long): Boolean = {
+
+ val startTime = System.currentTimeMillis
+ def isTimedout = System.currentTimeMillis - startTime >= timeoutMs
+
+ awaitTerminationLock.synchronized {
+ while (!isTimedout && lastTerminatedQuery == null) {
+ awaitTerminationLock.wait(10)
+ }
+ if (lastTerminatedQuery != null && lastTerminatedQuery.exception.nonEmpty) {
+ throw lastTerminatedQuery.exception.get
+ }
+ lastTerminatedQuery != null
+ }
+ }
+
+ /**
+ * Forget about past terminated queries so that `awaitAnyTermination()` can be used again to
+ * wait for new terminations.
+ *
+ * @since 2.0.0
+ */
+ def resetTerminated(): Unit = {
+ awaitTerminationLock.synchronized {
+ lastTerminatedQuery = null
+ }
+ }
+
+ /**
+ * Register a [[ContinuousQueryListener]] to receive up-calls for life cycle events of
+ * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]].
+ *
+ * @since 2.0.0
+ */
+ def addListener(listener: ContinuousQueryListener): Unit = {
+ listenerBus.addListener(listener)
+ }
+
+ /**
+ * Deregister a [[ContinuousQueryListener]].
+ *
+ * @since 2.0.0
+ */
+ def removeListener(listener: ContinuousQueryListener): Unit = {
+ listenerBus.removeListener(listener)
+ }
+
+ /** Post a listener event */
+ private[sql] def postListenerEvent(event: ContinuousQueryListener.Event): Unit = {
+ listenerBus.post(event)
+ }
+
+ /** Start a query */
+ private[sql] def startQuery(name: String, df: DataFrame, sink: Sink): ContinuousQuery = {
+ activeQueriesLock.synchronized {
+ if (activeQueries.contains(name)) {
+ throw new IllegalArgumentException(
+ s"Cannot start query with name $name as a query with that name is already active")
+ }
+ val query = new StreamExecution(sqlContext, name, df.logicalPlan, sink)
+ query.start()
+ activeQueries.put(name, query)
+ query
+ }
+ }
+
+ /** Notify (by the ContinuousQuery) that the query has been terminated */
+ private[sql] def notifyQueryTermination(terminatedQuery: ContinuousQuery): Unit = {
+ activeQueriesLock.synchronized {
+ activeQueries -= terminatedQuery.name
+ }
+ awaitTerminationLock.synchronized {
+ if (lastTerminatedQuery == null || terminatedQuery.exception.nonEmpty) {
+ lastTerminatedQuery = terminatedQuery
+ }
+ awaitTerminationLock.notifyAll()
+ }
+ }
+}
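A minimal sketch (not part of the diff) of the multi-query side of this manager; it assumes several queries are already running and uses the `sqlContext.streams` accessor added to SQLContext later in this patch.

```scala
// Illustrative only: manage all active queries on one SQLContext.
import org.apache.spark.sql.{ContinuousQueryException, SQLContext}

def waitForAny(sqlContext: SQLContext): Unit = {
  val manager = sqlContext.streams                 // ContinuousQueryManager
  manager.active.foreach(q => println(s"active: ${q.name}"))

  try {
    manager.awaitAnyTermination()                  // blocks until some query terminates
  } catch {
    case _: ContinuousQueryException =>
      // One query failed; stop the rest and inspect each query's exception individually.
      val queries = manager.active
      queries.foreach(_.stop())
      queries.foreach(q => println(s"${q.name}: ${q.exception.map(_.message)}"))
  } finally {
    manager.resetTerminated()                      // clear past terminations before waiting again
  }
}
```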
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 8060198968..d6bdd3d825 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -206,6 +206,17 @@ final class DataFrameWriter private[sql](df: DataFrame) {
}
/**
+ * Specifies the name of the [[ContinuousQuery]] that can be started with `stream()`.
+ * This name must be unique among all the currently active queries in the associated SQLContext.
+ *
+ * @since 2.0.0
+ */
+ def queryName(queryName: String): DataFrameWriter = {
+ this.extraOptions += ("queryName" -> queryName)
+ this
+ }
+
+ /**
* Starts the execution of the streaming query, which will continually output results to the given
* path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with
* the stream.
@@ -230,7 +241,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
extraOptions.toMap,
normalizedParCols.getOrElse(Nil))
- new StreamExecution(df.sqlContext, df.logicalPlan, sink)
+ df.sqlContext.continuousQueryManager.startQuery(
+ extraOptions.getOrElse("queryName", StreamExecution.nextName), df, sink)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1661fdbec5..050a1031c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -181,6 +181,8 @@ class SQLContext private[sql](
@transient
lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager
+ protected[sql] lazy val continuousQueryManager = new ContinuousQueryManager(this)
+
@transient
protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf)
@@ -836,6 +838,16 @@ class SQLContext private[sql](
}
/**
+ * Returns a [[ContinuousQueryManager]] that allows managing all the
+ * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active on `this` context.
+ *
+ * @since 2.0.0
+ */
+ def streams: ContinuousQueryManager = {
+ continuousQueryManager
+ }
+
+ /**
* Returns the names of tables in the current database as an array.
*
* @group ddl_ops
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala
new file mode 100644
index 0000000000..ce21451b2c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.execution.streaming.{Offset, Sink}
+
+/**
+ * :: Experimental ::
+ * Status and metrics of a streaming [[Sink]].
+ *
+ * @param description Description of the sink corresponding to this status
+ * @param offset Current offset up to which data has been written by the sink
+ * @since 2.0.0
+ */
+@Experimental
+class SinkStatus private[sql](
+ val description: String,
+ val offset: Option[Offset])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala
new file mode 100644
index 0000000000..2479e67e36
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.execution.streaming.{Offset, Source}
+
+/**
+ * :: Experimental ::
+ * Status and metrics of a streaming [[Source]].
+ *
+ * @param description Description of the source corresponding to this status
+ * @param offset Current offset of the source, if known
+ * @since 2.0.0
+ */
+@Experimental
+class SourceStatus private[sql] (
+ val description: String,
+ val offset: Option[Offset])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala
new file mode 100644
index 0000000000..b1d24b6cfc
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent}
+import org.apache.spark.sql.util.ContinuousQueryListener
+import org.apache.spark.sql.util.ContinuousQueryListener._
+import org.apache.spark.util.ListenerBus
+
+/**
+ * A bus to forward events to [[ContinuousQueryListener]]s. This one will wrap received
+ * [[ContinuousQueryListener.Event]]s as WrappedContinuousQueryListenerEvents and send them to the
+ * Spark listener bus. It also registers itself with the Spark listener bus, so that it can receive
+ * WrappedContinuousQueryListenerEvents, unwrap them as ContinuousQueryListener.Events and
+ * dispatch them to the registered ContinuousQueryListeners.
+ */
+class ContinuousQueryListenerBus(sparkListenerBus: LiveListenerBus)
+ extends SparkListener with ListenerBus[ContinuousQueryListener, ContinuousQueryListener.Event] {
+
+ sparkListenerBus.addListener(this)
+
+ /**
+ * Post a ContinuousQueryListener event to the Spark listener bus asynchronously. This event will
+ * be dispatched to all ContinuousQueryListeners in the thread of the Spark listener bus.
+ */
+ def post(event: ContinuousQueryListener.Event) {
+ event match {
+ case s: QueryStarted =>
+ postToAll(s)
+ case _ =>
+ sparkListenerBus.post(new WrappedContinuousQueryListenerEvent(event))
+ }
+ }
+
+ override def onOtherEvent(event: SparkListenerEvent): Unit = {
+ event match {
+ case WrappedContinuousQueryListenerEvent(e) =>
+ postToAll(e)
+ case _ =>
+ }
+ }
+
+ override protected def doPostEvent(
+ listener: ContinuousQueryListener,
+ event: ContinuousQueryListener.Event): Unit = {
+ event match {
+ case queryStarted: QueryStarted =>
+ listener.onQueryStarted(queryStarted)
+ case queryProgress: QueryProgress =>
+ listener.onQueryProgress(queryProgress)
+ case queryTerminated: QueryTerminated =>
+ listener.onQueryTerminated(queryTerminated)
+ case _ =>
+ }
+ }
+
+ /**
+ * Wrapper for ContinuousQueryListener.Event as SparkListenerEvent so that it can be posted to the Spark
+ * listener bus.
+ */
+ private case class WrappedContinuousQueryListenerEvent(
+ streamingListenerEvent: ContinuousQueryListener.Event) extends SparkListenerEvent {
+
+ // Do not log streaming events in event log as history server does not support these events.
+ protected[spark] override def logEvent: Boolean = false
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index ebebb82971..bc7c520930 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -17,16 +17,20 @@
package org.apache.spark.sql.execution.streaming
-import java.lang.Thread.UncaughtExceptionHandler
+import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.ArrayBuffer
+import scala.util.control.NonFatal
import org.apache.spark.Logging
-import org.apache.spark.sql.{ContinuousQuery, DataFrame, SQLContext}
+import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.ContinuousQueryListener
+import org.apache.spark.sql.util.ContinuousQueryListener._
/**
* Manages the execution of a streaming Spark SQL query that is occurring in a separate thread.
@@ -35,15 +39,15 @@ import org.apache.spark.sql.execution.QueryExecution
* and the results are committed transactionally to the given [[Sink]].
*/
class StreamExecution(
- sqlContext: SQLContext,
+ val sqlContext: SQLContext,
+ override val name: String,
private[sql] val logicalPlan: LogicalPlan,
val sink: Sink) extends ContinuousQuery with Logging {
/** A monitor used to wait/notify when batches complete. */
private val awaitBatchLock = new Object
-
- @volatile
- private var batchRun = false
+ private val startLatch = new CountDownLatch(1)
+ private val terminationLatch = new CountDownLatch(1)
/** Minimum amount of time in between the start of each batch. */
private val minBatchTime = 10
@@ -55,9 +59,92 @@ class StreamExecution(
private val sources =
logicalPlan.collect { case s: StreamingRelation => s.source }
- // Start the execution at the current offsets stored in the sink. (i.e. avoid reprocessing data
- // that we have already processed).
- {
+ /** Defines the internal state of execution */
+ @volatile
+ private var state: State = INITIALIZED
+
+ @volatile
+ private[sql] var lastExecution: QueryExecution = null
+
+ @volatile
+ private[sql] var streamDeathCause: ContinuousQueryException = null
+
+ /** The thread that runs the micro-batches of this stream. */
+ private[sql] val microBatchThread = new Thread(s"stream execution thread for $name") {
+ override def run(): Unit = { runBatches() }
+ }
+
+ /** Whether the query is currently active or not */
+ override def isActive: Boolean = state == ACTIVE
+
+ /** Returns current status of all the sources. */
+ override def sourceStatuses: Array[SourceStatus] = {
+ sources.map(s => new SourceStatus(s.toString, streamProgress.get(s))).toArray
+ }
+
+ /** Returns current status of the sink. */
+ override def sinkStatus: SinkStatus = new SinkStatus(sink.toString, sink.currentOffset)
+
+ /** Returns the [[ContinuousQueryException]] if the query was terminated by an exception. */
+ override def exception: Option[ContinuousQueryException] = Option(streamDeathCause)
+
+ /**
+ * Starts the execution. This returns only after the thread has started and the [[QueryStarted]] event
+ * has been posted to all the listeners.
+ */
+ private[sql] def start(): Unit = {
+ microBatchThread.setDaemon(true)
+ microBatchThread.start()
+ startLatch.await() // Wait until thread started and QueryStart event has been posted
+ }
+
+ /**
+ * Repeatedly attempts to run batches as data arrives.
+ *
+ * Note that this method ensures that [[QueryStarted]] and [[QueryTerminated]] events are posted
+ * so that listeners are guaranteed to get the former event before the latter. Furthermore, this
+ * method also ensures that [[QueryStarted]] event is posted before the `start()` method returns.
+ */
+ private def runBatches(): Unit = {
+ try {
+ // Mark ACTIVE and then post the event. QueryStarted event is synchronously sent to listeners,
+ // so must mark this as ACTIVE first.
+ state = ACTIVE
+ postEvent(new QueryStarted(this)) // Assumption: Does not throw exception.
+
+ // Unblock starting thread
+ startLatch.countDown()
+
+ // While active, repeatedly attempt to run batches.
+ SQLContext.setActive(sqlContext)
+ populateStartOffsets()
+ logInfo(s"Stream running at $streamProgress")
+ while (isActive) {
+ attemptBatch()
+ Thread.sleep(minBatchTime) // TODO: Could be tighter
+ }
+ } catch {
+ case _: InterruptedException if state == TERMINATED => // interrupted by stop()
+ case NonFatal(e) =>
+ streamDeathCause = new ContinuousQueryException(
+ this,
+ s"Query $name terminated with exception: ${e.getMessage}",
+ e,
+ Some(streamProgress.toCompositeOffset(sources)))
+ logError(s"Query $name terminated with error", e)
+ } finally {
+ state = TERMINATED
+ sqlContext.streams.notifyQueryTermination(StreamExecution.this)
+ postEvent(new QueryTerminated(this))
+ terminationLatch.countDown()
+ }
+ }
+
+ /**
+ * Populate the start offsets to start the execution at the current offsets stored in the sink
+ * (i.e. avoid reprocessing data that we have already processed).
+ */
+ private def populateStartOffsets(): Unit = {
sink.currentOffset match {
case Some(c: CompositeOffset) =>
val storedProgress = c.offsets
@@ -74,37 +161,8 @@ class StreamExecution(
}
}
- logInfo(s"Stream running at $streamProgress")
-
- /** When false, signals to the microBatchThread that it should stop running. */
- @volatile private var shouldRun = true
-
- /** The thread that runs the micro-batches of this stream. */
- private[sql] val microBatchThread = new Thread("stream execution thread") {
- override def run(): Unit = {
- SQLContext.setActive(sqlContext)
- while (shouldRun) {
- attemptBatch()
- Thread.sleep(minBatchTime) // TODO: Could be tighter
- }
- }
- }
- microBatchThread.setDaemon(true)
- microBatchThread.setUncaughtExceptionHandler(
- new UncaughtExceptionHandler {
- override def uncaughtException(t: Thread, e: Throwable): Unit = {
- streamDeathCause = e
- }
- })
- microBatchThread.start()
-
- @volatile
- private[sql] var lastExecution: QueryExecution = null
- @volatile
- private[sql] var streamDeathCause: Throwable = null
-
/**
- * Checks to see if any new data is present in any of the sources. When new data is available,
+ * Checks to see if any new data is present in any of the sources. When new data is available,
* a batch is executed and passed to the sink, updating the currentOffsets.
*/
private def attemptBatch(): Unit = {
@@ -150,36 +208,43 @@ class StreamExecution(
streamProgress.synchronized {
// Update the offsets and calculate a new composite offset
newOffsets.foreach(streamProgress.update)
- val newStreamProgress = logicalPlan.collect {
- case StreamingRelation(source, _) => streamProgress.get(source)
- }
- val batchOffset = CompositeOffset(newStreamProgress)
// Construct the batch and send it to the sink.
+ val batchOffset = streamProgress.toCompositeOffset(sources)
val nextBatch = new Batch(batchOffset, new DataFrame(sqlContext, newPlan))
sink.addBatch(nextBatch)
}
- batchRun = true
awaitBatchLock.synchronized {
// Wake up any threads that are waiting for the stream to progress.
awaitBatchLock.notifyAll()
}
val batchTime = (System.nanoTime() - startTime).toDouble / 1000000
- logInfo(s"Compete up to $newOffsets in ${batchTime}ms")
+ logInfo(s"Completed up to $newOffsets in ${batchTime}ms")
+ postEvent(new QueryProgress(this))
}
logDebug(s"Waiting for data, current: $streamProgress")
}
+ private def postEvent(event: ContinuousQueryListener.Event) {
+ sqlContext.streams.postListenerEvent(event)
+ }
+
/**
* Signals to the thread executing micro-batches that it should stop running after the next
* batch. This method blocks until the thread stops running.
*/
- def stop(): Unit = {
- shouldRun = false
- if (microBatchThread.isAlive) { microBatchThread.join() }
+ override def stop(): Unit = {
+ // Set the state to TERMINATED so that the batching thread knows that it was interrupted
+ // intentionally
+ state = TERMINATED
+ if (microBatchThread.isAlive) {
+ microBatchThread.interrupt()
+ microBatchThread.join()
+ }
+ logInfo(s"Query $name was stopped")
}
/**
@@ -198,14 +263,60 @@ class StreamExecution(
logDebug(s"Unblocked at $newOffset for $source")
}
- override def toString: String =
+ override def awaitTermination(): Unit = {
+ if (state == INITIALIZED) {
+ throw new IllegalStateException("Cannot wait for termination on a query that has not started")
+ }
+ terminationLatch.await()
+ if (streamDeathCause != null) {
+ throw streamDeathCause
+ }
+ }
+
+ override def awaitTermination(timeoutMs: Long): Boolean = {
+ if (state == INITIALIZED) {
+ throw new IllegalStateException("Cannot wait for termination on a query that has not started")
+ }
+ require(timeoutMs > 0, "Timeout has to be positive")
+ terminationLatch.await(timeoutMs, TimeUnit.MILLISECONDS)
+ if (streamDeathCause != null) {
+ throw streamDeathCause
+ } else {
+ !isActive
+ }
+ }
+
+ override def toString: String = {
+ s"Continuous Query - $name [state = $state]"
+ }
+
+ def toDebugString: String = {
+ val deathCauseStr = if (streamDeathCause != null) {
+ "Error:\n" + stackTraceToString(streamDeathCause.cause)
+ } else ""
s"""
- |=== Streaming Query ===
- |CurrentOffsets: $streamProgress
+ |=== Continuous Query ===
+ |Name: $name
+ |Current Offsets: $streamProgress
+ |
+ |Current State: $state
|Thread State: ${microBatchThread.getState}
- |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""}
|
+ |Logical Plan:
|$logicalPlan
+ |
+ |$deathCauseStr
""".stripMargin
+ }
+
+ trait State
+ case object INITIALIZED extends State
+ case object ACTIVE extends State
+ case object TERMINATED extends State
}
+private[sql] object StreamExecution {
+ private val nextId = new AtomicInteger()
+
+ def nextName: String = s"query-${nextId.getAndIncrement}"
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
index 0ded1d7152..d45b9bd983 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
@@ -55,6 +55,10 @@ class StreamProgress {
copied
}
+ private[sql] def toCompositeOffset(source: Seq[Source]): CompositeOffset = {
+ CompositeOffset(source.map(get))
+ }
+
override def toString: String =
currentOffsets.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index e6a0842936..8124df15af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.execution.streaming
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.ArrayBuffer
+import scala.util.control.NonFatal
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext}
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
object MemoryStream {
@@ -46,14 +47,13 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
protected val logicalPlan = StreamingRelation(this)
protected val output = logicalPlan.output
protected val batches = new ArrayBuffer[Dataset[A]]
+
protected var currentOffset: LongOffset = new LongOffset(-1)
protected def blockManager = SparkEnv.get.blockManager
def schema: StructType = encoder.schema
- def getCurrentOffset: Offset = currentOffset
-
def toDS()(implicit sqlContext: SQLContext): Dataset[A] = {
new Dataset(sqlContext, logicalPlan)
}
@@ -62,6 +62,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
new DataFrame(sqlContext, logicalPlan)
}
+ def addData(data: A*): Offset = {
+ addData(data.toTraversable)
+ }
+
def addData(data: TraversableOnce[A]): Offset = {
import sqlContext.implicits._
this.synchronized {
@@ -110,6 +114,7 @@ class MemorySink(schema: StructType) extends Sink with Logging {
}
override def addBatch(nextBatch: Batch): Unit = synchronized {
+ nextBatch.data.collect() // 'compute' the batch's data and record the batch
batches.append(nextBatch)
}
@@ -131,8 +136,13 @@ class MemorySink(schema: StructType) extends Sink with Logging {
batches.dropRight(num)
}
- override def toString: String = synchronized {
- batches.map(b => s"${b.end}: ${b.data.collect().mkString(" ")}").mkString("\n")
+ def toDebugString: String = synchronized {
+ batches.map { b =>
+ val dataStr = try b.data.collect().mkString(" ") catch {
+ case NonFatal(e) => "[Error converting to string]"
+ }
+ s"${b.end}: $dataStr"
+ }.mkString("\n")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala
new file mode 100644
index 0000000000..73c78d1b62
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.ContinuousQuery
+import org.apache.spark.sql.util.ContinuousQueryListener._
+
+/**
+ * :: Experimental ::
+ * Interface for listening to events related to [[ContinuousQuery ContinuousQueries]].
+ * @note The methods are not thread-safe as they may be called from different threads.
+ */
+@Experimental
+abstract class ContinuousQueryListener {
+
+ /**
+ * Called when a query is started.
+ * @note This is called synchronously with
+ * [[org.apache.spark.sql.DataFrameWriter `DataFrameWriter.stream()`]],
+ * that is, `onQueryStarted` will be called on all listeners before `DataFrameWriter.stream()`
+ * returns the corresponding [[ContinuousQuery]].
+ */
+ def onQueryStarted(queryStarted: QueryStarted)
+
+ /** Called when there is some status update (ingestion rate updated, etc.) */
+ def onQueryProgress(queryProgress: QueryProgress)
+
+ /** Called when a query is stopped, with or without error */
+ def onQueryTerminated(queryTerminated: QueryTerminated)
+}
+
+
+/**
+ * :: Experimental ::
+ * Companion object of [[ContinuousQueryListener]] that defines the listener events.
+ */
+@Experimental
+object ContinuousQueryListener {
+
+ /** Base type of [[ContinuousQueryListener]] events */
+ trait Event
+
+ /** Event representing the start of a query */
+ class QueryStarted private[sql](val query: ContinuousQuery) extends Event
+
+ /** Event representing any progress updates in a query */
+ class QueryProgress private[sql](val query: ContinuousQuery) extends Event
+
+ /** Event representing the termination of a query */
+ class QueryTerminated private[sql](val query: ContinuousQuery) extends Event
+}
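A minimal listener sketch (not part of the diff), registered through `ContinuousQueryManager.addListener()`; the class name and log output are illustrative.

```scala
// Illustrative only: log query life cycle events.
import org.apache.spark.sql.util.ContinuousQueryListener
import org.apache.spark.sql.util.ContinuousQueryListener.{QueryProgress, QueryStarted, QueryTerminated}

class LoggingListener extends ContinuousQueryListener {
  override def onQueryStarted(queryStarted: QueryStarted): Unit =
    println(s"started: ${queryStarted.query.name}")

  override def onQueryProgress(queryProgress: QueryProgress): Unit =
    println(s"progress: ${queryProgress.query.sourceStatuses.map(_.offset).mkString(", ")}")

  override def onQueryTerminated(queryTerminated: QueryTerminated): Unit =
    println(s"terminated: ${queryTerminated.query.name}, " +
      s"error=${queryTerminated.query.exception.map(_.message)}")
}

// Registration (assuming `sqlContext: SQLContext` is in scope):
//   sqlContext.streams.addListener(new LoggingListener)
```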
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index 7e388ea602..62710e72fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -21,9 +21,16 @@ import java.lang.Thread.UncaughtExceptionHandler
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import scala.language.experimental.macros
+import scala.reflect.ClassTag
import scala.util.Random
+import scala.util.control.NonFatal
-import org.scalatest.concurrent.Timeouts
+import org.scalatest.Assertions
+import org.scalatest.concurrent.{Eventually, Timeouts}
+import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.exceptions.TestFailedDueToTimeoutException
+import org.scalatest.time.Span
import org.scalatest.time.SpanSugar._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
@@ -64,7 +71,7 @@ trait StreamTest extends QueryTest with Timeouts {
}
/** How long to wait for an active stream to catch up when checking a result. */
- val streamingTimout = 10.seconds
+ val streamingTimeout = 10.seconds
/** A trait for actions that can be performed while testing a streaming DataFrame. */
trait StreamAction
@@ -128,7 +135,38 @@ trait StreamTest extends QueryTest with Timeouts {
case object StartStream extends StreamAction
/** Signals that a failure is expected and should not kill the test. */
- case object ExpectFailure extends StreamAction
+ case class ExpectFailure[T <: Throwable : ClassTag]() extends StreamAction {
+ val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]
+ override def toString(): String = s"ExpectFailure[${causeClass.getCanonicalName}]"
+ }
+
+ /** Assert that a body is true */
+ class Assert(condition: => Boolean, val message: String = "") extends StreamAction {
+ def run(): Unit = { Assertions.assert(condition) }
+ override def toString: String = s"Assert(<condition>, $message)"
+ }
+
+ object Assert {
+ def apply(condition: => Boolean, message: String = ""): Assert = new Assert(condition, message)
+ def apply(message: String)(body: => Unit): Assert = new Assert( { body; true }, message)
+ def apply(body: => Unit): Assert = new Assert( { body; true }, "")
+ }
+
+ /** Assert that a condition on the active query is true */
+ class AssertOnQuery(val condition: StreamExecution => Boolean, val message: String)
+ extends StreamAction {
+ override def toString: String = s"AssertOnQuery(<condition>, $message)"
+ }
+
+ object AssertOnQuery {
+ def apply(condition: StreamExecution => Boolean, message: String = ""): AssertOnQuery = {
+ new AssertOnQuery(condition, message)
+ }
+
+ def apply(message: String)(condition: StreamExecution => Boolean): AssertOnQuery = {
+ new AssertOnQuery(condition, message)
+ }
+ }
/** A helper for running actions on a Streaming Dataset. See `checkAnswer(DataFrame)`. */
def testStream(stream: Dataset[_])(actions: StreamAction*): Unit =
@@ -145,6 +183,7 @@ trait StreamTest extends QueryTest with Timeouts {
var pos = 0
var currentPlan: LogicalPlan = stream.logicalPlan
var currentStream: StreamExecution = null
+ var lastStream: StreamExecution = null
val awaiting = new mutable.HashMap[Source, Offset]()
val sink = new MemorySink(stream.schema)
@@ -170,6 +209,7 @@ trait StreamTest extends QueryTest with Timeouts {
def threadState =
if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead"
+
def testState =
s"""
|== Progress ==
@@ -181,16 +221,49 @@ trait StreamTest extends QueryTest with Timeouts {
|${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""}
|
|== Sink ==
- |$sink
+ |${sink.toDebugString}
|
|== Plan ==
|${if (currentStream != null) currentStream.lastExecution else ""}
- """
+ """.stripMargin
+
+ def verify(condition: => Boolean, message: String): Unit = {
+ try {
+ Assertions.assert(condition)
+ } catch {
+ case NonFatal(e) =>
+ failTest(message, e)
+ }
+ }
+
+ def eventually[T](message: String)(func: => T): T = {
+ try {
+ Eventually.eventually(Timeout(streamingTimeout)) {
+ func
+ }
+ } catch {
+ case NonFatal(e) =>
+ failTest(message, e)
+ }
+ }
+
+ def failTest(message: String, cause: Throwable = null) = {
- def checkState(check: Boolean, error: String) = if (!check) {
+ // Recursively pretty print an exception with truncated stacktrace and internal cause
+ def exceptionToString(e: Throwable, prefix: String = ""): String = {
+ val base = s"$prefix${e.getMessage}" +
+ e.getStackTrace.take(10).mkString(s"\n$prefix", s"\n$prefix\t", "\n")
+ if (e.getCause != null) {
+ base + s"\n$prefix\tCaused by: " + exceptionToString(e.getCause, s"$prefix\t")
+ } else {
+ base
+ }
+ }
+ val c = Option(cause).map(exceptionToString(_))
+ val m = if (message != null && message.size > 0) Some(message) else None
fail(
s"""
- |Invalid State: $error
+ |${(m ++ c).mkString(": ")}
|$testState
""".stripMargin)
}
@@ -201,9 +274,13 @@ trait StreamTest extends QueryTest with Timeouts {
startedTest.foreach { action =>
action match {
case StartStream =>
- checkState(currentStream == null, "stream already running")
-
- currentStream = new StreamExecution(sqlContext, stream.logicalPlan, sink)
+ verify(currentStream == null, "stream already running")
+ lastStream = currentStream
+ currentStream =
+ sqlContext
+ .streams
+ .startQuery(StreamExecution.nextName, stream, sink)
+ .asInstanceOf[StreamExecution]
currentStream.microBatchThread.setUncaughtExceptionHandler(
new UncaughtExceptionHandler {
override def uncaughtException(t: Thread, e: Throwable): Unit = {
@@ -213,77 +290,100 @@ trait StreamTest extends QueryTest with Timeouts {
})
case StopStream =>
- checkState(currentStream != null, "can not stop a stream that is not running")
- currentStream.stop()
- currentStream = null
+ verify(currentStream != null, "can not stop a stream that is not running")
+ try failAfter(streamingTimeout) {
+ currentStream.stop()
+ verify(!currentStream.microBatchThread.isAlive,
+ s"microbatch thread not stopped")
+ verify(!currentStream.isActive,
+ "query.isActive() is false even after stopping")
+ verify(currentStream.exception.isEmpty,
+ s"query.exception() is not empty after clean stop: " +
+ currentStream.exception.map(_.toString()).getOrElse(""))
+ } catch {
+ case _: InterruptedException =>
+ case _: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
+ failTest("Timed out while stopping and waiting for microbatchthread to terminate.")
+ case t: Throwable =>
+ failTest("Error while stopping stream", t)
+ } finally {
+ lastStream = currentStream
+ currentStream = null
+ }
case DropBatches(num) =>
- checkState(currentStream == null, "dropping batches while running leads to corruption")
+ verify(currentStream == null, "dropping batches while running leads to corruption")
sink.dropBatches(num)
- case ExpectFailure =>
- try failAfter(streamingTimout) {
- while (streamDeathCause == null) {
- Thread.sleep(100)
+ case ef: ExpectFailure[_] =>
+ verify(currentStream != null, "can not expect failure when stream is not running")
+ try failAfter(streamingTimeout) {
+ val thrownException = intercept[ContinuousQueryException] {
+ currentStream.awaitTermination()
}
+ eventually("microbatch thread not stopped after termination with failure") {
+ assert(!currentStream.microBatchThread.isAlive)
+ }
+ verify(thrownException.query.eq(currentStream),
+ s"incorrect query reference in exception")
+ verify(currentStream.exception === Some(thrownException),
+ s"incorrect exception returned by query.exception()")
+
+ val exception = currentStream.exception.get
+ verify(exception.cause.getClass === ef.causeClass,
+ "incorrect cause in exception returned by query.exception()\n" +
+ s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}")
} catch {
case _: InterruptedException =>
case _: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
- fail(
- s"""
- |Timed out while waiting for failure.
- |$testState
- """.stripMargin)
+ failTest("Timed out while waiting for failure")
+ case t: Throwable =>
+ failTest("Error while checking stream failure", t)
+ } finally {
+ lastStream = currentStream
+ currentStream = null
+ streamDeathCause = null
}
- currentStream = null
- streamDeathCause = null
+ case a: AssertOnQuery =>
+ verify(currentStream != null || lastStream != null,
+ "cannot assert when not stream has been started")
+ val streamToAssert = Option(currentStream).getOrElse(lastStream)
+ verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}")
+
+ case a: Assert =>
+ val streamToAssert = Option(currentStream).getOrElse(lastStream)
+ verify({ a.run(); true }, s"Assert failed: ${a.message}")
case a: AddData =>
awaiting.put(a.source, a.addData())
case CheckAnswerRows(expectedAnswer) =>
- checkState(currentStream != null, "stream not running")
+ verify(currentStream != null, "stream not running")
// Block until all data added has been processed
awaiting.foreach { case (source, offset) =>
- failAfter(streamingTimout) {
+ failAfter(streamingTimeout) {
currentStream.awaitOffset(source, offset)
}
}
val allData = try sink.allData catch {
case e: Exception =>
- fail(
- s"""
- |Exception while getting data from sink $e
- |$testState
- """.stripMargin)
+ failTest("Exception while getting data from sink", e)
}
QueryTest.sameRows(expectedAnswer, allData).foreach {
- error => fail(
- s"""
- |$error
- |$testState
- """.stripMargin)
+ error => failTest(error)
}
}
pos += 1
}
} catch {
case _: InterruptedException if streamDeathCause != null =>
- fail(
- s"""
- |Stream Thread Died
- |$testState
- """.stripMargin)
+ failTest("Stream Thread Died")
case _: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
- fail(
- s"""
- |Timed out waiting for stream
- |$testState
- """.stripMargin)
+ failTest("Timed out waiting for stream")
} finally {
if (currentStream != null && currentStream.microBatchThread.isAlive) {
currentStream.stop()
@@ -335,7 +435,8 @@ trait StreamTest extends QueryTest with Timeouts {
case r if r < 0.7 => // AddData
addRandomData()
- case _ => // StartStream
+ case _ => // StopStream
+ addCheck()
actions += StopStream
running = false
}
@@ -345,4 +446,59 @@ trait StreamTest extends QueryTest with Timeouts {
addCheck()
testStream(ds)(actions: _*)
}
+
+
+ object AwaitTerminationTester {
+
+ trait ExpectedBehavior
+
+ /** Expect awaitTermination to not be blocked */
+ case object ExpectNotBlocked extends ExpectedBehavior
+
+ /** Expect awaitTermination to get blocked */
+ case object ExpectBlocked extends ExpectedBehavior
+
+ /** Expect awaitTermination to throw an exception */
+ case class ExpectException[E <: Exception]()(implicit val t: ClassTag[E])
+ extends ExpectedBehavior
+
+ private val DEFAULT_TEST_TIMEOUT = 1 second
+
+ def test(
+ expectedBehavior: ExpectedBehavior,
+ awaitTermFunc: () => Unit,
+ testTimeout: Span = DEFAULT_TEST_TIMEOUT
+ ): Unit = {
+
+ expectedBehavior match {
+ case ExpectNotBlocked =>
+ withClue("Got blocked when expected non-blocking.") {
+ failAfter(testTimeout) {
+ awaitTermFunc()
+ }
+ }
+
+ case ExpectBlocked =>
+ withClue("Was not blocked when expected.") {
+ intercept[TestFailedDueToTimeoutException] {
+ failAfter(testTimeout) {
+ awaitTermFunc()
+ }
+ }
+ }
+
+ case e: ExpectException[_] =>
+ val thrownException =
+ withClue(s"Did not throw ${e.t.runtimeClass.getSimpleName} when expected.") {
+ intercept[ContinuousQueryException] {
+ failAfter(testTimeout) {
+ awaitTermFunc()
+ }
+ }
+ }
+ assert(thrownException.cause.getClass === e.t.runtimeClass,
+ "exception of incorrect type was throw")
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
new file mode 100644
index 0000000000..daf08efca4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import scala.concurrent.Future
+import scala.util.Random
+import scala.util.control.NonFatal
+
+import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.time.Span
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest}
+import org.apache.spark.sql.execution.streaming.{MemorySink, MemoryStream, StreamExecution, StreamingRelation}
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
+
+ import AwaitTerminationTester._
+ import testImplicits._
+
+ override val streamingTimeout = 20.seconds
+
+ before {
+ assert(sqlContext.streams.active.isEmpty)
+ sqlContext.streams.resetTerminated()
+ }
+
+ after {
+ assert(sqlContext.streams.active.isEmpty)
+ sqlContext.streams.resetTerminated()
+ }
+
+ test("listing") {
+ val (m1, ds1) = makeDataset
+ val (m2, ds2) = makeDataset
+ val (m3, ds3) = makeDataset
+
+ withQueriesOn(ds1, ds2, ds3) { queries =>
+ require(queries.size === 3)
+ assert(sqlContext.streams.active.toSet === queries.toSet)
+ val (q1, q2, q3) = (queries(0), queries(1), queries(2))
+
+ assert(sqlContext.streams.get(q1.name).eq(q1))
+ assert(sqlContext.streams.get(q2.name).eq(q2))
+ assert(sqlContext.streams.get(q3.name).eq(q3))
+ intercept[IllegalArgumentException] {
+ sqlContext.streams.get("non-existent-name")
+ }
+
+ q1.stop()
+
+ assert(sqlContext.streams.active.toSet === Set(q2, q3))
+ val ex1 = withClue("no error while getting non-active query") {
+ intercept[IllegalArgumentException] {
+ sqlContext.streams.get(q1.name)
+ }
+ }
+ assert(ex1.getMessage.contains(q1.name), "error does not contain name of query to be fetched")
+ assert(sqlContext.streams.get(q2.name).eq(q2))
+
+ m2.addData(0) // q2 should terminate with error
+
+ eventually(Timeout(streamingTimeout)) {
+ require(!q2.isActive)
+ require(q2.exception.isDefined)
+ }
+ val ex2 = withClue("no error while getting non-active query") {
+ intercept[IllegalArgumentException] {
+ sqlContext.streams.get(q2.name).eq(q2)
+ }
+ }
+
+ assert(sqlContext.streams.active.toSet === Set(q3))
+ }
+ }
+
+ test("awaitAnyTermination without timeout and resetTerminated") {
+ val datasets = Seq.fill(5)(makeDataset._2)
+ withQueriesOn(datasets: _*) { queries =>
+ require(queries.size === datasets.size)
+ assert(sqlContext.streams.active.toSet === queries.toSet)
+
+ // awaitAnyTermination should be blocking
+ testAwaitAnyTermination(ExpectBlocked)
+
+ // Stop a query asynchronously and see if it is reported through awaitAnyTermination
+ val q1 = stopRandomQueryAsync(stopAfter = 100 milliseconds, withError = false)
+ testAwaitAnyTermination(ExpectNotBlocked)
+ require(!q1.isActive) // should be inactive by the time the prev awaitAnyTerm returned
+
+ // All subsequent calls to awaitAnyTermination should be non-blocking
+ testAwaitAnyTermination(ExpectNotBlocked)
+
+ // Resetting termination should make awaitAnyTermination() blocking again
+ sqlContext.streams.resetTerminated()
+ testAwaitAnyTermination(ExpectBlocked)
+
+ // Terminate a query asynchronously with exception and see that awaitAnyTermination throws
+ // the exception
+ val q2 = stopRandomQueryAsync(100 milliseconds, withError = true)
+ testAwaitAnyTermination(ExpectException[SparkException])
+ require(!q2.isActive) // should be inactive by the time the prev awaitAnyTerm returned
+
+ // All subsequent calls to awaitAnyTermination should throw the exception
+ testAwaitAnyTermination(ExpectException[SparkException])
+
+ // Resetting termination should make awaitAnyTermination() blocking again
+ sqlContext.streams.resetTerminated()
+ testAwaitAnyTermination(ExpectBlocked)
+
+      // Terminate multiple queries, one of them with a failure, and verify that
+      // awaitAnyTermination throws the exception
+ val q3 = stopRandomQueryAsync(10 milliseconds, withError = false)
+ testAwaitAnyTermination(ExpectNotBlocked)
+ require(!q3.isActive)
+ val q4 = stopRandomQueryAsync(10 milliseconds, withError = true)
+ eventually(Timeout(streamingTimeout)) { require(!q4.isActive) }
+      // After q4 terminates with an exception, awaitAnyTerm should start throwing that exception
+ testAwaitAnyTermination(ExpectException[SparkException])
+ }
+ }
+
+ test("awaitAnyTermination with timeout and resetTerminated") {
+ val datasets = Seq.fill(6)(makeDataset._2)
+ withQueriesOn(datasets: _*) { queries =>
+ require(queries.size === datasets.size)
+ assert(sqlContext.streams.active.toSet === queries.toSet)
+
+ // awaitAnyTermination should be blocking or non-blocking depending on timeout values
+ testAwaitAnyTermination(
+ ExpectBlocked,
+ awaitTimeout = 2 seconds,
+ expectedReturnedValue = false,
+ testBehaviorFor = 1 second)
+
+ testAwaitAnyTermination(
+ ExpectNotBlocked,
+ awaitTimeout = 50 milliseconds,
+ expectedReturnedValue = false,
+ testBehaviorFor = 1 second)
+
+ // Stop a query asynchronously within timeout and awaitAnyTerm should be unblocked
+ val q1 = stopRandomQueryAsync(stopAfter = 100 milliseconds, withError = false)
+ testAwaitAnyTermination(
+ ExpectNotBlocked,
+ awaitTimeout = 1 second,
+ expectedReturnedValue = true,
+ testBehaviorFor = 2 seconds)
+ require(!q1.isActive) // should be inactive by the time the prev awaitAnyTerm returned
+
+ // All subsequent calls to awaitAnyTermination should be non-blocking even if timeout is high
+ testAwaitAnyTermination(
+ ExpectNotBlocked, awaitTimeout = 2 seconds, expectedReturnedValue = true)
+
+ // Resetting termination should make awaitAnyTermination() blocking again
+ sqlContext.streams.resetTerminated()
+ testAwaitAnyTermination(
+ ExpectBlocked,
+ awaitTimeout = 2 seconds,
+ expectedReturnedValue = false,
+ testBehaviorFor = 1 second)
+
+      // Terminate a query asynchronously with an exception within the timeout; awaitAnyTermination
+      // should throw the exception
+ val q2 = stopRandomQueryAsync(100 milliseconds, withError = true)
+ testAwaitAnyTermination(
+ ExpectException[SparkException],
+ awaitTimeout = 1 second,
+ testBehaviorFor = 2 seconds)
+ require(!q2.isActive) // should be inactive by the time the prev awaitAnyTerm returned
+
+ // All subsequent calls to awaitAnyTermination should throw the exception
+ testAwaitAnyTermination(
+ ExpectException[SparkException],
+ awaitTimeout = 1 second,
+ testBehaviorFor = 2 seconds)
+
+      // Terminate a query asynchronously only after the await timeout expires; awaitAnyTerm
+      // should time out and return false
+ sqlContext.streams.resetTerminated()
+ val q3 = stopRandomQueryAsync(1 second, withError = true)
+ testAwaitAnyTermination(
+ ExpectNotBlocked,
+ awaitTimeout = 100 milliseconds,
+ expectedReturnedValue = false,
+ testBehaviorFor = 2 seconds)
+
+      // After that query has terminated with the error, awaitAnyTerm should throw the exception
+ eventually(Timeout(streamingTimeout)) { require(!q3.isActive) } // wait for query to stop
+ testAwaitAnyTermination(
+ ExpectException[SparkException],
+ awaitTimeout = 100 milliseconds,
+ testBehaviorFor = 2 seconds)
+
+      // Terminate multiple queries, one of them with a failure, and verify that
+      // awaitAnyTermination throws the exception
+ sqlContext.streams.resetTerminated()
+
+ val q4 = stopRandomQueryAsync(10 milliseconds, withError = false)
+ testAwaitAnyTermination(
+ ExpectNotBlocked, awaitTimeout = 1 second, expectedReturnedValue = true)
+ require(!q4.isActive)
+ val q5 = stopRandomQueryAsync(10 milliseconds, withError = true)
+ eventually(Timeout(streamingTimeout)) { require(!q5.isActive) }
+      // After q5 terminates with an exception, awaitAnyTerm should start throwing that exception
+ testAwaitAnyTermination(ExpectException[SparkException], awaitTimeout = 100 milliseconds)
+ }
+ }
+
+
+  /** Start one query on each of the given datasets, run the body, then stop all the queries */
+ private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[ContinuousQuery] => Unit): Unit = {
+ failAfter(streamingTimeout) {
+ val queries = withClue("Error starting queries") {
+ datasets.map { ds =>
+ @volatile var query: StreamExecution = null
+ try {
+ val df = ds.toDF
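+          // Start the query on a fresh MemorySink, with a name obtained from StreamExecution.nextName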
+ query = sqlContext
+ .streams
+ .startQuery(StreamExecution.nextName, df, new MemorySink(df.schema))
+ .asInstanceOf[StreamExecution]
+ } catch {
+ case NonFatal(e) =>
+ if (query != null) query.stop()
+ throw e
+ }
+ query
+ }
+ }
+ try {
+ body(queries)
+ } finally {
+ queries.foreach(_.stop())
+ }
+ }
+ }
+
+ /** Test the behavior of awaitAnyTermination */
+ private def testAwaitAnyTermination(
+ expectedBehavior: ExpectedBehavior,
+ expectedReturnedValue: Boolean = false,
+ awaitTimeout: Span = null,
+ testBehaviorFor: Span = 2 seconds
+ ): Unit = {
+
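+    // Use the timed awaitAnyTermination only when a positive timeout is given; otherwise
+    // exercise the blocking variant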
+ def awaitTermFunc(): Unit = {
+ if (awaitTimeout != null && awaitTimeout.toMillis > 0) {
+ val returnedValue = sqlContext.streams.awaitAnyTermination(awaitTimeout.toMillis)
+ assert(returnedValue === expectedReturnedValue, "Returned value does not match expected")
+ } else {
+ sqlContext.streams.awaitAnyTermination()
+ }
+ }
+
+ AwaitTerminationTester.test(expectedBehavior, awaitTermFunc, testBehaviorFor)
+ }
+
+ /** Stop a random active query either with `stop()` or with an error */
+ private def stopRandomQueryAsync(stopAfter: Span, withError: Boolean): ContinuousQuery = {
+
+ import scala.concurrent.ExecutionContext.Implicits.global
+
+ val activeQueries = sqlContext.streams.active
+ val queryToStop = activeQueries(Random.nextInt(activeQueries.length))
+ Future {
+ Thread.sleep(stopAfter.toMillis)
+ if (withError) {
+ logDebug(s"Terminating query ${queryToStop.name} with error")
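+        // Every dataset in this suite comes from makeDataset (map 6 / _), so adding 0 makes the query fail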
+ queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect {
+ case StreamingRelation(memoryStream, _) =>
+ memoryStream.asInstanceOf[MemoryStream[Int]].addData(0)
+ }
+ } else {
+ logDebug(s"Stopping query ${queryToStop.name}")
+ queryToStop.stop()
+ }
+ }
+ queryToStop
+ }
+
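+  /** Make a MemoryStream and a Dataset that maps 6 / _, so that adding 0 later fails the query */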
+ private def makeDataset: (MemoryStream[Int], Dataset[Int]) = {
+ val inputData = MemoryStream[Int]
+ val mapped = inputData.toDS.map(6 / _)
+ (inputData, mapped)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala
new file mode 100644
index 0000000000..dac1a398ff
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.StreamTest
+import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset, MemoryStream, StreamExecution}
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ContinuousQuerySuite extends StreamTest with SharedSQLContext {
+
+ import AwaitTerminationTester._
+ import testImplicits._
+
+ test("lifecycle states and awaitTermination") {
+ val inputData = MemoryStream[Int]
+    val mapped = inputData.toDS().map { 6 / _ }
+
+ testStream(mapped)(
+ AssertOnQuery(_.isActive === true),
+ AssertOnQuery(_.exception.isEmpty),
+ AddData(inputData, 1, 2),
+ CheckAnswer(6, 3),
+ TestAwaitTermination(ExpectBlocked),
+ TestAwaitTermination(ExpectBlocked, timeoutMs = 2000),
+ TestAwaitTermination(ExpectNotBlocked, timeoutMs = 10, expectedReturnValue = false),
+ StopStream,
+ AssertOnQuery(_.isActive === false),
+ AssertOnQuery(_.exception.isEmpty),
+ TestAwaitTermination(ExpectNotBlocked),
+ TestAwaitTermination(ExpectNotBlocked, timeoutMs = 2000, expectedReturnValue = true),
+ TestAwaitTermination(ExpectNotBlocked, timeoutMs = 10, expectedReturnValue = true),
+ StartStream,
+ AssertOnQuery(_.isActive === true),
+ AddData(inputData, 0),
+ ExpectFailure[SparkException],
+ AssertOnQuery(_.isActive === false),
+ TestAwaitTermination(ExpectException[SparkException]),
+ TestAwaitTermination(ExpectException[SparkException], timeoutMs = 2000),
+ TestAwaitTermination(ExpectException[SparkException], timeoutMs = 10),
+ AssertOnQuery(
+ q => q.exception.get.startOffset.get === q.streamProgress.toCompositeOffset(Seq(inputData)),
+ "incorrect start offset on exception")
+ )
+ }
+
+ test("source and sink statuses") {
+ val inputData = MemoryStream[Int]
+ val mapped = inputData.toDS().map(6 / _)
+
+ testStream(mapped)(
+ AssertOnQuery(_.sourceStatuses.length === 1),
+ AssertOnQuery(_.sourceStatuses(0).description.contains("Memory")),
+ AssertOnQuery(_.sourceStatuses(0).offset === None),
+ AssertOnQuery(_.sinkStatus.description.contains("Memory")),
+ AssertOnQuery(_.sinkStatus.offset === None),
+ AddData(inputData, 1, 2),
+ CheckAnswer(6, 3),
+ AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(0))),
+ AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0)))),
+ AddData(inputData, 1, 2),
+ CheckAnswer(6, 3, 6, 3),
+ AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(1))),
+ AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(1)))),
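+      // The next batch fails (6 / 0), so the source offset advances while the sink offset does not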
+ AddData(inputData, 0),
+ ExpectFailure[SparkException],
+ AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(2))),
+ AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(1))))
+ )
+ }
+
+ /**
+ * A [[StreamAction]] to test the behavior of `ContinuousQuery.awaitTermination()`.
+ *
+ * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown)
+ * @param timeoutMs Timeout in milliseconds
+ * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout)
+ * When timeoutMs > 0, awaitTermination(timeoutMs) is tested
+ * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used
+ */
+ case class TestAwaitTermination(
+ expectedBehavior: ExpectedBehavior,
+ timeoutMs: Int = -1,
+ expectedReturnValue: Boolean = false
+ ) extends AssertOnQuery(
+ TestAwaitTermination.assertOnQueryCondition(expectedBehavior, timeoutMs, expectedReturnValue),
+ "Error testing awaitTermination behavior"
+ ) {
+ override def toString(): String = {
+ s"TestAwaitTermination($expectedBehavior, timeoutMs = $timeoutMs, " +
+ s"expectedReturnValue = $expectedReturnValue)"
+ }
+ }
+
+ object TestAwaitTermination {
+
+ /**
+ * Tests the behavior of `ContinuousQuery.awaitTermination`.
+ *
+ * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown)
+ * @param timeoutMs Timeout in milliseconds
+ * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout)
+ * When timeoutMs > 0, awaitTermination(timeoutMs) is tested
+ * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used
+ */
+ def assertOnQueryCondition(
+ expectedBehavior: ExpectedBehavior,
+ timeoutMs: Int,
+ expectedReturnValue: Boolean
+ )(q: StreamExecution): Boolean = {
+
+ def awaitTermFunc(): Unit = {
+ if (timeoutMs <= 0) {
+ q.awaitTermination()
+ } else {
+ val returnedValue = q.awaitTermination(timeoutMs)
+ assert(returnedValue === expectedReturnValue, "Returned value does not match expected")
+ }
+ }
+ AwaitTerminationTester.test(expectedBehavior, awaitTermFunc)
+ true // If the control reached here, then everything worked as expected
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
index b762f9b90e..f060c6f623 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.streaming.test
-import org.apache.spark.sql.{AnalysisException, SQLContext, StreamTest}
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.{AnalysisException, ContinuousQuery, SQLContext, StreamTest}
import org.apache.spark.sql.execution.streaming.{Batch, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.test.SharedSQLContext
@@ -57,9 +59,13 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
}
}
-class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext {
+class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
import testImplicits._
+ after {
+ sqlContext.streams.active.foreach(_.stop())
+ }
+
test("resolve default source") {
sqlContext.read
.format("org.apache.spark.sql.streaming.test")
@@ -188,4 +194,63 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext {
assert(LastOptions.parameters("boolOpt") == "false")
assert(LastOptions.parameters("doubleOpt") == "6.7")
}
+
+ test("unique query names") {
+
+ /** Start a query with a specific name */
+ def startQueryWithName(name: String = ""): ContinuousQuery = {
+ sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream("/test")
+ .write
+ .format("org.apache.spark.sql.streaming.test")
+ .queryName(name)
+ .stream()
+ }
+
+ /** Start a query without specifying a name */
+ def startQueryWithoutName(): ContinuousQuery = {
+ sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream("/test")
+ .write
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ }
+
+ /** Get the names of active streams */
+ def activeStreamNames: Set[String] = {
+ val streams = sqlContext.streams.active
+ val names = streams.map(_.name).toSet
+ assert(streams.length === names.size, s"names of active queries are not unique: $names")
+ names
+ }
+
+ val q1 = startQueryWithName("name")
+
+ // Should not be able to start another query with the same name
+ intercept[IllegalArgumentException] {
+ startQueryWithName("name")
+ }
+ assert(activeStreamNames === Set("name"))
+
+ // Should be able to start queries with other names
+ val q3 = startQueryWithName("another-name")
+ assert(activeStreamNames === Set("name", "another-name"))
+
+ // Should be able to start queries with auto-generated names
+ val q4 = startQueryWithoutName()
+ assert(activeStreamNames.contains(q4.name))
+
+ // Should not be able to start a query with same auto-generated name
+ intercept[IllegalArgumentException] {
+ startQueryWithName(q4.name)
+ }
+
+ // Should be able to start query with that name after stopping the previous query
+ q1.stop()
+ val q5 = startQueryWithName("name")
+ assert(activeStreamNames.contains("name"))
+ sqlContext.streams.active.foreach(_.stop())
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala
new file mode 100644
index 0000000000..d6cc6ad86b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util
+
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import scala.util.control.NonFatal
+
+import org.scalatest.BeforeAndAfter
+import org.scalatest.PrivateMethodTester._
+import org.scalatest.concurrent.AsyncAssertions.Waiter
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.util.ContinuousQueryListener.{QueryProgress, QueryStarted, QueryTerminated}
+
+class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
+
+ import testImplicits._
+
+ after {
+ sqlContext.streams.active.foreach(_.stop())
+ assert(sqlContext.streams.active.isEmpty)
+ assert(addedListeners.isEmpty)
+ }
+
+ test("single listener") {
+ val listener = new QueryStatusCollector
+ val input = MemoryStream[Int]
+ withListenerAdded(listener) {
+ testStream(input.toDS)(
+ StartStream,
+ Assert("Incorrect query status in onQueryStarted") {
+ val status = listener.startStatus
+ assert(status != null)
+ assert(status.active == true)
+ assert(status.sourceStatuses.size === 1)
+ assert(status.sourceStatuses(0).description.contains("Memory"))
+
+ // The source and sink offsets must be None as this must be called before the
+ // batches have started
+ assert(status.sourceStatuses(0).offset === None)
+ assert(status.sinkStatus.offset === None)
+
+ // No progress events or termination events
+ assert(listener.progressStatuses.isEmpty)
+ assert(listener.terminationStatus === null)
+ },
+ AddDataMemory(input, Seq(1, 2, 3)),
+ CheckAnswer(1, 2, 3),
+ Assert("Incorrect query status in onQueryProgress") {
+ eventually(Timeout(streamingTimeout)) {
+
+            // There should be only one progress event, as the batch has been processed
+ assert(listener.progressStatuses.size === 1)
+ val status = listener.progressStatuses.peek()
+ assert(status != null)
+ assert(status.active == true)
+ assert(status.sourceStatuses(0).offset === Some(LongOffset(0)))
+ assert(status.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0))))
+
+ // No termination events
+ assert(listener.terminationStatus === null)
+ }
+ },
+ StopStream,
+ Assert("Incorrect query status in onQueryTerminated") {
+ eventually(Timeout(streamingTimeout)) {
+ val status = listener.terminationStatus
+ assert(status != null)
+
+ assert(status.active === false) // must be inactive by the time onQueryTerm is called
+ assert(status.sourceStatuses(0).offset === Some(LongOffset(0)))
+ assert(status.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0))))
+ }
+ listener.checkAsyncErrors()
+ }
+ )
+ }
+ }
+
+ test("adding and removing listener") {
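+    // Run a trivial query and report whether the given listener observed its onQueryStarted event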
+ def isListenerActive(listener: QueryStatusCollector): Boolean = {
+ listener.reset()
+ testStream(MemoryStream[Int].toDS)(
+ StartStream,
+ StopStream
+ )
+ listener.startStatus != null
+ }
+
+ try {
+ val listener1 = new QueryStatusCollector
+ val listener2 = new QueryStatusCollector
+
+ sqlContext.streams.addListener(listener1)
+ assert(isListenerActive(listener1) === true)
+ assert(isListenerActive(listener2) === false)
+ sqlContext.streams.addListener(listener2)
+ assert(isListenerActive(listener1) === true)
+ assert(isListenerActive(listener2) === true)
+ sqlContext.streams.removeListener(listener1)
+ assert(isListenerActive(listener1) === false)
+ assert(isListenerActive(listener2) === true)
+ } finally {
+ addedListeners.foreach(sqlContext.streams.removeListener)
+ }
+ }
+
+ test("event ordering") {
+ val listener = new QueryStatusCollector
+ withListenerAdded(listener) {
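+      // Repeat many times to make ordering races between start, progress and
+      // termination events more likely to surface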
+ for (i <- 1 to 100) {
+ listener.reset()
+ require(listener.startStatus === null)
+ testStream(MemoryStream[Int].toDS)(
+ StartStream,
+ Assert(listener.startStatus !== null, "onQueryStarted not called before query returned"),
+ StopStream,
+ Assert { listener.checkAsyncErrors() }
+ )
+ }
+ }
+ }
+
+
+ private def withListenerAdded(listener: ContinuousQueryListener)(body: => Unit): Unit = {
+ try {
+ failAfter(1 minute) {
+ sqlContext.streams.addListener(listener)
+ body
+ }
+ } finally {
+ sqlContext.streams.removeListener(listener)
+ }
+ }
+
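+  /** Use ScalaTest's PrivateMethodTester to read the manager's private listenerBus and list its listeners */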
+ private def addedListeners(): Array[ContinuousQueryListener] = {
+ val listenerBusMethod =
+ PrivateMethod[ContinuousQueryListenerBus]('listenerBus)
+ val listenerBus = sqlContext.streams invokePrivate listenerBusMethod()
+ listenerBus.listeners.toArray.map(_.asInstanceOf[ContinuousQueryListener])
+ }
+
+ class QueryStatusCollector extends ContinuousQueryListener {
+
+ private val asyncTestWaiter = new Waiter // to catch errors in the async listener events
+
+ @volatile var startStatus: QueryStatus = null
+ @volatile var terminationStatus: QueryStatus = null
+ val progressStatuses = new ConcurrentLinkedQueue[QueryStatus]
+
+ def reset(): Unit = {
+ startStatus = null
+ terminationStatus = null
+ progressStatuses.clear()
+
+      // Reset the waiter: await with a tiny timeout and ignore the expected timeout failure
+ try asyncTestWaiter.await(timeout(1 milliseconds)) catch {
+ case NonFatal(e) =>
+ }
+ }
+
+ def checkAsyncErrors(): Unit = {
+ asyncTestWaiter.await(timeout(streamingTimeout))
+ }
+
+
+ override def onQueryStarted(queryStarted: QueryStarted): Unit = {
+ asyncTestWaiter {
+ startStatus = QueryStatus(queryStarted.query)
+ }
+ }
+
+ override def onQueryProgress(queryProgress: QueryProgress): Unit = {
+ asyncTestWaiter {
+ assert(startStatus != null, "onQueryProgress called before onQueryStarted")
+ progressStatuses.add(QueryStatus(queryProgress.query))
+ }
+ }
+
+ override def onQueryTerminated(queryTerminated: QueryTerminated): Unit = {
+ asyncTestWaiter {
+ assert(startStatus != null, "onQueryTerminated called before onQueryStarted")
+ terminationStatus = QueryStatus(queryTerminated.query)
+ }
+ asyncTestWaiter.dismiss()
+ }
+ }
+
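+  /** Snapshot of a query's state, captured inside the listener callbacks for later assertions */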
+ case class QueryStatus(
+ active: Boolean,
+      exception: Option[Exception],
+ sourceStatuses: Array[SourceStatus],
+ sinkStatus: SinkStatus)
+
+ object QueryStatus {
+ def apply(query: ContinuousQuery): QueryStatus = {
+ QueryStatus(query.isActive, query.exception, query.sourceStatuses, query.sinkStatus)
+ }
+ }
+}