aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorShixiong Zhu <shixiong@databricks.com>2016-04-05 11:12:05 -0700
committerMichael Armbrust <michael@databricks.com>2016-04-05 11:12:05 -0700
commit463bac001171622538fc93d2e31d1a617ab562e6 (patch)
tree406a34b25c2955a070153db16cca5b3b8f31161a /sql
parentf77f11c67125fdac2e6849a4d45d9286fc872ed9 (diff)
downloadspark-463bac001171622538fc93d2e31d1a617ab562e6.tar.gz
spark-463bac001171622538fc93d2e31d1a617ab562e6.tar.bz2
spark-463bac001171622538fc93d2e31d1a617ab562e6.zip
[SPARK-14257][SQL] Allow multiple continuous queries to be started from the same DataFrame
## What changes were proposed in this pull request? Make StreamingRelation store the closure to create the source in StreamExecution so that we can start multiple continuous queries from the same DataFrame. ## How was this patch tested? `test("DataFrame reuse")` Author: Shixiong Zhu <shixiong@databricks.com> Closes #12049 from zsxwing/df-reuse.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala12
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala27
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala6
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala5
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala6
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala62
10 files changed, 118 insertions, 26 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
index 2306df09b8..d7f71bd4b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.collection.mutable
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution}
+import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
import org.apache.spark.sql.util.ContinuousQueryListener
@@ -178,11 +178,19 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
throw new IllegalArgumentException(
s"Cannot start query with name $name as a query with that name is already active")
}
+ val logicalPlan = df.logicalPlan.transform {
+ case StreamingRelation(dataSource, _, output) =>
+ // Materialize source to avoid creating it in every batch
+ val source = dataSource.createSource()
+ // We still need to use the previous `output` instead of `source.schema` as attributes in
+ // "df.logicalPlan" has already used attributes of the previous `output`.
+ StreamingExecutionRelation(source, output)
+ }
val query = new StreamExecution(
sqlContext,
name,
checkpointLocation,
- df.logicalPlan,
+ logicalPlan,
sink,
trigger)
query.start()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index a5a6e01e99..15f2344df6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -176,7 +176,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap)
- Dataset.ofRows(sqlContext, StreamingRelation(dataSource.createSource()))
+ Dataset.ofRows(sqlContext, StreamingRelation(dataSource))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index db2134b020..f472a5068e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.execution.streaming.StreamingRelation
+import org.apache.spark.sql.execution.streaming.{StreamingExecutionRelation, StreamingRelation}
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -462,7 +462,9 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
@Experimental
- def isStreaming: Boolean = logicalPlan.find(_.isInstanceOf[StreamingRelation]).isDefined
+ def isStreaming: Boolean = logicalPlan.find { n =>
+ n.isInstanceOf[StreamingRelation] || n.isInstanceOf[StreamingExecutionRelation]
+ }.isDefined
/**
* Displays the [[Dataset]] in a tabular form. Strings more than 20 characters will be truncated,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 64f80699ce..3e4acb752a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -43,9 +43,9 @@ import org.apache.spark.util.UninterruptibleThread
* and the results are committed transactionally to the given [[Sink]].
*/
class StreamExecution(
- val sqlContext: SQLContext,
+ override val sqlContext: SQLContext,
override val name: String,
- val checkpointRoot: String,
+ checkpointRoot: String,
private[sql] val logicalPlan: LogicalPlan,
val sink: Sink,
val trigger: Trigger) extends ContinuousQuery with Logging {
@@ -72,7 +72,7 @@ class StreamExecution(
/** All stream sources present the query plan. */
private val sources =
- logicalPlan.collect { case s: StreamingRelation => s.source }
+ logicalPlan.collect { case s: StreamingExecutionRelation => s.source }
/** A list of unique sources in the query plan. */
private val uniqueSources = sources.distinct
@@ -295,7 +295,7 @@ class StreamExecution(
var replacements = new ArrayBuffer[(Attribute, Attribute)]
// Replace sources in the logical plan with data that has arrived since the last batch.
val withNewSources = logicalPlan transform {
- case StreamingRelation(source, output) =>
+ case StreamingExecutionRelation(source, output) =>
newData.get(source).map { data =>
val newPlan = data.logicalPlan
assert(output.size == newPlan.output.size,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
index e35c444348..f951dea735 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -19,16 +19,37 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+import org.apache.spark.sql.execution.datasources.DataSource
object StreamingRelation {
- def apply(source: Source): StreamingRelation =
- StreamingRelation(source, source.schema.toAttributes)
+ def apply(dataSource: DataSource): StreamingRelation = {
+ val source = dataSource.createSource()
+ StreamingRelation(dataSource, source.toString, source.schema.toAttributes)
+ }
+}
+
+/**
+ * Used to link a streaming [[DataSource]] into a
+ * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating
+ * a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]].
+ * It should be used to create [[Source]] and converted to [[StreamingExecutionRelation]] when
+ * passing to [StreamExecution]] to run a query.
+ */
+case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute])
+ extends LeafNode {
+ override def toString: String = sourceName
}
/**
* Used to link a streaming [[Source]] of data into a
* [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]].
*/
-case class StreamingRelation(source: Source, output: Seq[Attribute]) extends LeafNode {
+case class StreamingExecutionRelation(source: Source, output: Seq[Attribute]) extends LeafNode {
override def toString: String = source.toString
}
+
+object StreamingExecutionRelation {
+ def apply(source: Source): StreamingExecutionRelation = {
+ StreamingExecutionRelation(source, source.schema.toAttributes)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 7d97f81b0f..b652530d7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -22,11 +22,9 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
-import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext}
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder}
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.types.StructType
object MemoryStream {
@@ -45,7 +43,7 @@ object MemoryStream {
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
extends Source with Logging {
protected val encoder = encoderFor[A]
- protected val logicalPlan = StreamingRelation(this)
+ protected val logicalPlan = StreamingExecutionRelation(this)
protected val output = logicalPlan.output
protected val batches = new ArrayBuffer[Dataset[A]]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index 3444e56e9e..6ccc99fe17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -36,6 +36,7 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.util.Utils
@@ -66,9 +67,9 @@ import org.apache.spark.util.Utils
trait StreamTest extends QueryTest with Timeouts {
implicit class RichSource(s: Source) {
- def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingRelation(s))
+ def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingExecutionRelation(s))
- def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingRelation(s))
+ def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingExecutionRelation(s))
}
/** How long to wait for an active stream to catch up when checking a result. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
index 29bd3e018e..33787de9da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
@@ -29,7 +29,7 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkException
import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest}
-import org.apache.spark.sql.execution.streaming.{MemorySink, MemoryStream, StreamExecution, StreamingRelation}
+import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
@@ -294,8 +294,8 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
if (withError) {
logDebug(s"Terminating query ${queryToStop.name} with error")
queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect {
- case StreamingRelation(memoryStream, _) =>
- memoryStream.asInstanceOf[MemoryStream[Int]].addData(0)
+ case StreamingExecutionRelation(source, _) =>
+ source.asInstanceOf[MemoryStream[Int]].addData(0)
}
} else {
logDebug(s"Stopping query ${queryToStop.name}")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 054f5c9fa2..09daa7f81a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -71,8 +71,9 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext {
}
reader.stream(path)
.queryExecution.analyzed
- .collect { case StreamingRelation(s: FileStreamSource, _) => s }
- .head
+ .collect { case StreamingRelation(dataSource, _, _) =>
+ dataSource.createSource().asInstanceOf[FileStreamSource]
+ }.head
}
val valueSchema = new StructType().add("value", StringType)
@@ -96,8 +97,9 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
reader.stream()
}
df.queryExecution.analyzed
- .collect { case StreamingRelation(s: FileStreamSource, _) => s }
- .head
+ .collect { case StreamingRelation(dataSource, _, _) =>
+ dataSource.createSource().asInstanceOf[FileStreamSource]
+ }.head
.schema
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index fbb1792596..e4ea555526 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -17,9 +17,13 @@
package org.apache.spark.sql.streaming
-import org.apache.spark.sql.{Row, StreamTest}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.sql.{DataFrame, Row, SQLContext, StreamTest}
import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
class StreamSuite extends StreamTest with SharedSQLContext {
@@ -81,4 +85,60 @@ class StreamSuite extends StreamTest with SharedSQLContext {
AddData(inputData, 1, 2, 3, 4),
CheckAnswer(2, 4))
}
+
+ test("DataFrame reuse") {
+ def assertDF(df: DataFrame) {
+ withTempDir { outputDir =>
+ withTempDir { checkpointDir =>
+ val query = df.write.format("parquet")
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .startStream(outputDir.getAbsolutePath)
+ try {
+ query.processAllAvailable()
+ val outputDf = sqlContext.read.parquet(outputDir.getAbsolutePath).as[Long]
+ checkDataset[Long](outputDf, (0L to 10L).toArray: _*)
+ } finally {
+ query.stop()
+ }
+ }
+ }
+ }
+
+ val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream()
+ assertDF(df)
+ assertDF(df)
+ }
+}
+
+/**
+ * A fake StreamSourceProvider thats creates a fake Source that cannot be reused.
+ */
+class FakeDefaultSource extends StreamSourceProvider {
+
+ override def createSource(
+ sqlContext: SQLContext,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): Source = {
+ // Create a fake Source that emits 0 to 10.
+ new Source {
+ private var offset = -1L
+
+ override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil)
+
+ override def getOffset: Option[Offset] = {
+ if (offset >= 10) {
+ None
+ } else {
+ offset += 1
+ Some(LongOffset(offset))
+ }
+ }
+
+ override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+ val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1
+ sqlContext.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a")
+ }
+ }
+ }
}