-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala | 12
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 2
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala | 8
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala | 27
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala | 5
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala | 10
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala | 62
10 files changed, 118 insertions(+), 26 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
index 2306df09b8..d7f71bd4b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.collection.mutable
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution}
+import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
import org.apache.spark.sql.util.ContinuousQueryListener
@@ -178,11 +178,19 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
throw new IllegalArgumentException(
s"Cannot start query with name $name as a query with that name is already active")
}
+ val logicalPlan = df.logicalPlan.transform {
+ case StreamingRelation(dataSource, _, output) =>
+      // Materialize the source once here to avoid creating it in every batch
+      val source = dataSource.createSource()
+      // We must keep the previous `output` instead of `source.schema.toAttributes`, because
+      // attributes in `df.logicalPlan` already reference the previous `output`.
+ StreamingExecutionRelation(source, output)
+ }
val query = new StreamExecution(
sqlContext,
name,
checkpointLocation,
- df.logicalPlan,
+ logicalPlan,
sink,
trigger)
query.start()
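
Net effect of this hunk: the `Source` is now materialized once per started query, rather than once per `stream()` call, so one streaming DataFrame can back several queries. A minimal sketch of the reuse this enables (paths are illustrative; `FakeDefaultSource` is the test provider added in StreamSuite below):

    val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream()
    // Each startStream call goes through the transform above, creating an
    // independent Source for its own StreamingExecutionRelation.
    val q1 = df.write.format("parquet")
      .option("checkpointLocation", "/tmp/cp1")
      .startStream("/tmp/out1")
    val q2 = df.write.format("parquet")
      .option("checkpointLocation", "/tmp/cp2")
      .startStream("/tmp/out2")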
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index a5a6e01e99..15f2344df6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -176,7 +176,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap)
- Dataset.ofRows(sqlContext, StreamingRelation(dataSource.createSource()))
+ Dataset.ofRows(sqlContext, StreamingRelation(dataSource))
}
/**
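
With this one-line change, the reader returns a plan built on the lazy `StreamingRelation` instead of an eagerly created `Source` (the `Source` built inside `StreamingRelation.apply` is only consulted for its name and schema). A sketch of inspecting the analyzed plan, mirroring the pattern used in the FileStreamSourceSuite changes below (format and path are illustrative):

    val df = sqlContext.read.format("text").stream("/tmp/in")
    val sourceNames = df.queryExecution.analyzed.collect {
      case StreamingRelation(dataSource, sourceName, _) => sourceName
    }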
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index db2134b020..f472a5068e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.execution.streaming.StreamingRelation
+import org.apache.spark.sql.execution.streaming.{StreamingExecutionRelation, StreamingRelation}
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -462,7 +462,9 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
@Experimental
- def isStreaming: Boolean = logicalPlan.find(_.isInstanceOf[StreamingRelation]).isDefined
+ def isStreaming: Boolean = logicalPlan.find { n =>
+ n.isInstanceOf[StreamingRelation] || n.isInstanceOf[StreamingExecutionRelation]
+ }.isDefined
/**
* Displays the [[Dataset]] in a tabular form. Strings more than 20 characters will be truncated,
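
The broadened check keeps `isStreaming` true across the whole lifecycle: before a query starts, the plan contains `StreamingRelation`; inside a running `StreamExecution`, it contains `StreamingExecutionRelation`. A small sketch, assuming a `SQLContext` named `sqlContext`:

    val batchDF = sqlContext.range(10).toDF()
    assert(!batchDF.isStreaming)          // no streaming leaf in the plan

    val streamingDF = sqlContext.read.format("text").stream("/tmp/in") // illustrative
    assert(streamingDF.isStreaming)       // StreamingRelation present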
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 64f80699ce..3e4acb752a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -43,9 +43,9 @@ import org.apache.spark.util.UninterruptibleThread
* and the results are committed transactionally to the given [[Sink]].
*/
class StreamExecution(
- val sqlContext: SQLContext,
+ override val sqlContext: SQLContext,
override val name: String,
- val checkpointRoot: String,
+ checkpointRoot: String,
private[sql] val logicalPlan: LogicalPlan,
val sink: Sink,
val trigger: Trigger) extends ContinuousQuery with Logging {
@@ -72,7 +72,7 @@ class StreamExecution(
/** All stream sources present in the query plan. */
private val sources =
- logicalPlan.collect { case s: StreamingRelation => s.source }
+ logicalPlan.collect { case s: StreamingExecutionRelation => s.source }
/** A list of unique sources in the query plan. */
private val uniqueSources = sources.distinct
@@ -295,7 +295,7 @@ class StreamExecution(
var replacements = new ArrayBuffer[(Attribute, Attribute)]
// Replace sources in the logical plan with data that has arrived since the last batch.
val withNewSources = logicalPlan transform {
- case StreamingRelation(source, output) =>
+ case StreamingExecutionRelation(source, output) =>
newData.get(source).map { data =>
val newPlan = data.logicalPlan
assert(output.size == newPlan.output.size,
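
For orientation, the full transform this hunk sits in (condensed here, with the size assertion and attribute rewiring omitted) swaps each execution relation for that source's new batch of data, or an empty relation when nothing arrived:

    // Condensed sketch of StreamExecution's per-batch substitution.
    val withNewSources = logicalPlan transform {
      case StreamingExecutionRelation(source, output) =>
        newData.get(source)
          .map(_.logicalPlan)                // rows that arrived since last batch
          .getOrElse(LocalRelation(output))  // empty placeholder, same schema
    }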
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
index e35c444348..f951dea735 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -19,16 +19,37 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+import org.apache.spark.sql.execution.datasources.DataSource
object StreamingRelation {
- def apply(source: Source): StreamingRelation =
- StreamingRelation(source, source.schema.toAttributes)
+ def apply(dataSource: DataSource): StreamingRelation = {
+ val source = dataSource.createSource()
+ StreamingRelation(dataSource, source.toString, source.schema.toAttributes)
+ }
+}
+
+/**
+ * Used to link a streaming [[DataSource]] into a
+ * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating
+ * a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]].
+ * It is used to create a [[Source]] and is converted to a [[StreamingExecutionRelation]]
+ * when the plan is passed to [[StreamExecution]] to run a query.
+ */
+case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute])
+ extends LeafNode {
+ override def toString: String = sourceName
}
/**
* Used to link a streaming [[Source]] of data into a
* [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]].
*/
-case class StreamingRelation(source: Source, output: Seq[Attribute]) extends LeafNode {
+case class StreamingExecutionRelation(source: Source, output: Seq[Attribute]) extends LeafNode {
override def toString: String = source.toString
}
+
+object StreamingExecutionRelation {
+ def apply(source: Source): StreamingExecutionRelation = {
+ StreamingExecutionRelation(source, source.schema.toAttributes)
+ }
+}
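
The two node types split the lifecycle cleanly: `StreamingRelation` is the reusable, reader-side handle; `StreamingExecutionRelation` is the per-query node bound to a live `Source`. A sketch of the conversion (the same one `ContinuousQueryManager` performs above, with a hypothetical `dataSource` in scope):

    val lazyRel = StreamingRelation(dataSource)   // captures source name + schema
    val source = dataSource.createSource()        // one fresh Source per query
    // Keep lazyRel.output rather than source.schema.toAttributes so attribute
    // references (exprIds) already present in the plan remain valid.
    val execRel = StreamingExecutionRelation(source, lazyRel.output)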
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 7d97f81b0f..b652530d7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -22,11 +22,9 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
-import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext}
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder}
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.types.StructType
object MemoryStream {
@@ -45,7 +43,7 @@ object MemoryStream {
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
extends Source with Logging {
protected val encoder = encoderFor[A]
- protected val logicalPlan = StreamingRelation(this)
+ protected val logicalPlan = StreamingExecutionRelation(this)
protected val output = logicalPlan.output
protected val batches = new ArrayBuffer[Dataset[A]]
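
`MemoryStream` has no `DataSource` wrapper, so it plugs directly into the execution-side node; reuse is not a concern because the stream instance itself holds the data. A usage sketch, assuming an implicit `SQLContext` and the `toDS` helper available in these suites:

    val input = MemoryStream[Int]           // companion apply, implicit SQLContext
    input.addData(1, 2, 3)                  // buffer a batch
    val doubled = input.toDS().map(_ * 2)   // streaming Dataset[Int]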
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index 3444e56e9e..6ccc99fe17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -36,6 +36,7 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.util.Utils
@@ -66,9 +67,9 @@ import org.apache.spark.util.Utils
trait StreamTest extends QueryTest with Timeouts {
implicit class RichSource(s: Source) {
- def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingRelation(s))
+ def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingExecutionRelation(s))
- def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingRelation(s))
+ def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingExecutionRelation(s))
}
/** How long to wait for an active stream to catch up when checking a result. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
index 29bd3e018e..33787de9da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
@@ -29,7 +29,7 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkException
import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest}
-import org.apache.spark.sql.execution.streaming.{MemorySink, MemoryStream, StreamExecution, StreamingRelation}
+import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
@@ -294,8 +294,8 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
if (withError) {
logDebug(s"Terminating query ${queryToStop.name} with error")
queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect {
- case StreamingRelation(memoryStream, _) =>
- memoryStream.asInstanceOf[MemoryStream[Int]].addData(0)
+ case StreamingExecutionRelation(source, _) =>
+ source.asInstanceOf[MemoryStream[Int]].addData(0)
}
} else {
logDebug(s"Stopping query ${queryToStop.name}")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 054f5c9fa2..09daa7f81a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -71,8 +71,9 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext {
}
reader.stream(path)
.queryExecution.analyzed
- .collect { case StreamingRelation(s: FileStreamSource, _) => s }
- .head
+ .collect { case StreamingRelation(dataSource, _, _) =>
+ dataSource.createSource().asInstanceOf[FileStreamSource]
+ }.head
}
val valueSchema = new StructType().add("value", StringType)
@@ -96,8 +97,9 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
reader.stream()
}
df.queryExecution.analyzed
- .collect { case StreamingRelation(s: FileStreamSource, _) => s }
- .head
+ .collect { case StreamingRelation(dataSource, _, _) =>
+ dataSource.createSource().asInstanceOf[FileStreamSource]
+ }.head
.schema
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index fbb1792596..e4ea555526 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -17,9 +17,13 @@
package org.apache.spark.sql.streaming
-import org.apache.spark.sql.{Row, StreamTest}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.sql.{DataFrame, Row, SQLContext, StreamTest}
import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
class StreamSuite extends StreamTest with SharedSQLContext {
@@ -81,4 +85,60 @@ class StreamSuite extends StreamTest with SharedSQLContext {
AddData(inputData, 1, 2, 3, 4),
CheckAnswer(2, 4))
}
+
+ test("DataFrame reuse") {
+  def assertDF(df: DataFrame): Unit = {
+ withTempDir { outputDir =>
+ withTempDir { checkpointDir =>
+ val query = df.write.format("parquet")
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .startStream(outputDir.getAbsolutePath)
+ try {
+ query.processAllAvailable()
+ val outputDf = sqlContext.read.parquet(outputDir.getAbsolutePath).as[Long]
+ checkDataset[Long](outputDf, (0L to 10L).toArray: _*)
+ } finally {
+ query.stop()
+ }
+ }
+ }
+ }
+
+ val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream()
+ assertDF(df)
+ assertDF(df)
+ }
+}
+
+/**
+ * A fake StreamSourceProvider that creates a fake Source that cannot be reused.
+ */
+class FakeDefaultSource extends StreamSourceProvider {
+
+ override def createSource(
+ sqlContext: SQLContext,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): Source = {
+ // Create a fake Source that emits 0 to 10.
+ new Source {
+ private var offset = -1L
+
+      override def schema: StructType = StructType(StructField("a", LongType) :: Nil)
+
+ override def getOffset: Option[Offset] = {
+ if (offset >= 10) {
+ None
+ } else {
+ offset += 1
+ Some(LongOffset(offset))
+ }
+ }
+
+ override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+ val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1
+ sqlContext.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a")
+ }
+ }
+ }
}
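
The fake provider makes the old bug observable: each `createSource` call returns a fresh instance with its own private `offset`, so a reused DataFrame only behaves correctly if every query start creates its own `Source`. A quick sanity sketch (names from this file):

    val provider = new FakeDefaultSource
    val s1 = provider.createSource(sqlContext, None, "fake", Map.empty)
    s1.getOffset                                  // advances s1 to LongOffset(0)
    val s2 = provider.createSource(sqlContext, None, "fake", Map.empty)
    assert(s2.getOffset == Some(LongOffset(0)))   // s2 starts fresh, unaffected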