author    Shixiong Zhu <shixiong@databricks.com>  2017-02-23 11:25:39 -0800
committer Tathagata Das <tathagata.das1565@gmail.com>  2017-02-23 11:25:39 -0800
commit    9bf4e2baad0e2851da554d85223ffaa029cfa490
tree      a08773e6a82e7d5fa78ca2f71d707e74be36a9cc
parent    7bf09433f5c5e08154ba106be21fe24f17cd282b
[SPARK-19497][SS] Implement streaming deduplication
## What changes were proposed in this pull request?

This PR adds a special streaming deduplication operator to support `dropDuplicates` on streaming Datasets, including together with aggregation and a watermark. It reuses the `dropDuplicates` API but introduces a new logical plan `Deduplicate` and a new physical plan `StreamingDeduplicateExec`.

Supported cases:
- one or multiple `dropDuplicates()` without aggregation (with or without watermark)
- `dropDuplicates` before aggregation

Not supported:
- `dropDuplicates` after aggregation

Breaking changes:
- `dropDuplicates` without aggregation doesn't work with `complete` or `update` mode.

## How was this patch tested?

The new unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16970 from zsxwing/dedup.
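For context, a hedged sketch of the headline supported shape, assembled from the same pieces the new `DeduplicateSuite` exercises: a `MemoryStream` source, an event-time watermark bounding the deduplication state, and `dropDuplicates` running before a windowed aggregation in append mode. The session setup and console sink are illustrative assumptions, not part of this patch.

```scala
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.execution.streaming.MemoryStream

val spark = SparkSession.builder().master("local[2]").appName("dedup-sketch").getOrCreate()
import spark.implicits._
implicit val sqlContext: SQLContext = spark.sqlContext

// MemoryStream is the same test-only source the new suite uses.
val input = MemoryStream[Int]

val deduped = input.toDS()
  .withColumn("eventTime", $"value".cast("timestamp"))
  .withWatermark("eventTime", "10 seconds") // bounds the deduplication state
  .dropDuplicates()                         // planned as the new StreamingDeduplicateExec

// Supported shape: dropDuplicates *before* aggregation, append output mode.
val counts = deduped
  .groupBy(window($"eventTime", "5 seconds"))
  .count()

val query = counts.writeStream
  .outputMode("append")
  .format("console")
  .start()

input.addData(10, 10, 25)   // the duplicate 10 is dropped before it reaches the aggregation
query.processAllAvailable()
```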
-rw-r--r--python/pyspark/sql/dataframe.py6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala21
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala9
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala56
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala33
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala39
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala15
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala10
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala140
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala252
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala9
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala36
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala2
15 files changed, 578 insertions, 58 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 70efeaf016..bb6df22682 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1158,6 +1158,12 @@ class DataFrame(object):
"""Return a new :class:`DataFrame` with duplicate rows removed,
optionally only considering certain columns.
+ For a static batch :class:`DataFrame`, it just drops duplicate rows. For a streaming
+ :class:`DataFrame`, it will keep all data across triggers as intermediate state to drop
+ duplicate rows. You can use :func:`withWatermark` to limit how late the duplicate data can
+ be and the system will accordingly limit the state. In addition, data older than the
+ watermark will be dropped to avoid any possibility of duplicates.
+
:func:`drop_duplicates` is an alias for :func:`dropDuplicates`.
>>> from pyspark.sql import Row
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 07b3558ee2..397f5cfe2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -75,7 +75,7 @@ object UnsupportedOperationChecker {
if (watermarkAttributes.isEmpty) {
throwError(
s"$outputMode output mode not supported when there are streaming aggregations on " +
- s"streaming DataFrames/DataSets")(plan)
+ s"streaming DataFrames/DataSets without watermark")(plan)
}
case InternalOutputModes.Complete if aggregates.isEmpty =>
@@ -120,6 +120,10 @@ object UnsupportedOperationChecker {
throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " +
"streaming DataFrame/Dataset")
+ case d: Deduplicate if collectStreamingAggregates(d).nonEmpty =>
+ throwError("dropDuplicates is not supported after aggregation on a " +
+ "streaming DataFrame/Dataset")
+
case Join(left, right, joinType, _) =>
joinType match {
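To make the new check concrete, here is a hedged sketch of the rejected shape next to the supported one, assuming `events` is a streaming Dataset with an `id` column (the names are illustrative; the error surfaces when the streaming query is started):

```scala
import org.apache.spark.sql.functions.count

// Rejected by the check above: Deduplicate sits on top of a streaming aggregation, so starting
// the query fails with "dropDuplicates is not supported after aggregation on a streaming
// DataFrame/Dataset".
val afterAgg = events.groupBy("id").agg(count("*").as("c")).dropDuplicates("id")

// Supported: deduplicate first, then aggregate.
val beforeAgg = events.dropDuplicates("id").groupBy("id").agg(count("*").as("c"))
```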
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index af846a09a8..036da3ad20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -56,7 +56,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
ReplaceExpressions,
ComputeCurrentTime,
GetCurrentDatabase(sessionCatalog),
- RewriteDistinctAggregates) ::
+ RewriteDistinctAggregates,
+ ReplaceDeduplicateWithAggregate) ::
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
@@ -1143,6 +1144,24 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
}
/**
+ * Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator.
+ */
+object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Deduplicate(keys, child, streaming) if !streaming =>
+ val keyExprIds = keys.map(_.exprId)
+ val aggCols = child.output.map { attr =>
+ if (keyExprIds.contains(attr.exprId)) {
+ attr
+ } else {
+ Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
+ }
+ }
+ Aggregate(keys, aggCols, child)
+ }
+}
+
+/**
* Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator.
* {{{
* SELECT a1, a2 FROM Tab1 INTERSECT SELECT b1, b2 FROM Tab2
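At the DataFrame level, the batch rewrite performed by `ReplaceDeduplicateWithAggregate` above amounts to grouping on the key columns and keeping `first()` of every other column. A hedged illustration of that equivalence (the column names and the `spark` session are assumptions for the example, not part of the rule):

```scala
import org.apache.spark.sql.functions.first

val df = spark.range(0, 10).selectExpr("id % 3 AS a", "id AS b")

// dropDuplicates("a") on a batch Dataset ...
val viaRule = df.dropDuplicates("a")

// ... is planned, after this rule, like grouping on the key and taking first() of the rest.
val viaAggregate = df.groupBy("a").agg(first("b").as("b"))

// Both keep one arbitrary row per distinct value of `a`.
```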
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index d17d12cd83..ce1c55dc08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -864,3 +864,12 @@ case object OneRowRelation extends LeafNode {
override def output: Seq[Attribute] = Nil
override def computeStats(conf: CatalystConf): Statistics = Statistics(sizeInBytes = 1)
}
+
+/** A logical plan for `dropDuplicates`. */
+case class Deduplicate(
+ keys: Seq[Attribute],
+ child: LogicalPlan,
+ streaming: Boolean) extends UnaryNode {
+
+ override def output: Seq[Attribute] = child.output
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index 3b756e89d9..82be69a0f7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.{IntegerType, LongType}
+import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder}
+import org.apache.spark.unsafe.types.CalendarInterval
/** A dummy command for testing unsupported operations. */
case class DummyCommand() extends Command
@@ -36,6 +37,11 @@ case class DummyCommand() extends Command
class UnsupportedOperationsSuite extends SparkFunSuite {
val attribute = AttributeReference("a", IntegerType, nullable = true)()
+ val watermarkMetadata = new MetadataBuilder()
+ .withMetadata(attribute.metadata)
+ .putLong(EventTimeWatermark.delayKey, 1000L)
+ .build()
+ val attributeWithWatermark = attribute.withMetadata(watermarkMetadata)
val batchRelation = LocalRelation(attribute)
val streamRelation = new TestStreamingRelation(attribute)
@@ -98,6 +104,27 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
outputMode = Update,
expectedMsgs = Seq("multiple streaming aggregations"))
+ assertSupportedInStreamingPlan(
+ "aggregate - streaming aggregations in update mode",
+ Aggregate(Nil, aggExprs("d"), streamRelation),
+ outputMode = Update)
+
+ assertSupportedInStreamingPlan(
+ "aggregate - streaming aggregations in complete mode",
+ Aggregate(Nil, aggExprs("d"), streamRelation),
+ outputMode = Complete)
+
+ assertSupportedInStreamingPlan(
+ "aggregate - streaming aggregations with watermark in append mode",
+ Aggregate(Seq(attributeWithWatermark), aggExprs("d"), streamRelation),
+ outputMode = Append)
+
+ assertNotSupportedInStreamingPlan(
+ "aggregate - streaming aggregations without watermark in append mode",
+ Aggregate(Nil, aggExprs("d"), streamRelation),
+ outputMode = Append,
+ expectedMsgs = Seq("streaming aggregations", "without watermark"))
+
// Aggregation: Distinct aggregates not supported on streaming relation
val distinctAggExprs = Seq(Count("*").toAggregateExpression(isDistinct = true).as("c"))
assertSupportedInStreamingPlan(
@@ -129,6 +156,33 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
outputMode = Complete,
expectedMsgs = Seq("(map/flatMap)GroupsWithState"))
+ assertSupportedInStreamingPlan(
+ "mapGroupsWithState - mapGroupsWithState on batch relation inside streaming relation",
+ MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation),
+ outputMode = Append
+ )
+
+ // Deduplicate
+ assertSupportedInStreamingPlan(
+ "Deduplicate - Deduplicate on streaming relation before aggregation",
+ Aggregate(
+ Seq(attributeWithWatermark),
+ aggExprs("c"),
+ Deduplicate(Seq(att), streamRelation, streaming = true)),
+ outputMode = Append)
+
+ assertNotSupportedInStreamingPlan(
+ "Deduplicate - Deduplicate on streaming relation after aggregation",
+ Deduplicate(Seq(att), Aggregate(Nil, aggExprs("c"), streamRelation), streaming = true),
+ outputMode = Complete,
+ expectedMsgs = Seq("dropDuplicates"))
+
+ assertSupportedInStreamingPlan(
+ "Deduplicate - Deduplicate on batch relation inside a streaming query",
+ Deduplicate(Seq(att), batchRelation, streaming = false),
+ outputMode = Append
+ )
+
// Inner joins: Stream-stream not supported
testBinaryOperationInStreamingPlan(
"inner join",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index f23e262f28..e68423f85c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -30,7 +32,8 @@ class ReplaceOperatorSuite extends PlanTest {
Batch("Replace Operators", FixedPoint(100),
ReplaceDistinctWithAggregate,
ReplaceExceptWithAntiJoin,
- ReplaceIntersectWithSemiJoin) :: Nil
+ ReplaceIntersectWithSemiJoin,
+ ReplaceDeduplicateWithAggregate) :: Nil
}
test("replace Intersect with Left-semi Join") {
@@ -71,4 +74,32 @@ class ReplaceOperatorSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("replace batch Deduplicate with Aggregate") {
+ val input = LocalRelation('a.int, 'b.int)
+ val attrA = input.output(0)
+ val attrB = input.output(1)
+ val query = Deduplicate(Seq(attrA), input, streaming = false) // dropDuplicates("a")
+ val optimized = Optimize.execute(query.analyze)
+
+ val correctAnswer =
+ Aggregate(
+ Seq(attrA),
+ Seq(
+ attrA,
+ Alias(new First(attrB).toAggregateExpression(), attrB.name)(attrB.exprId)
+ ),
+ input)
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("don't replace streaming Deduplicate") {
+ val input = LocalRelation('a.int, 'b.int)
+ val attrA = input.output(0)
+ val query = Deduplicate(Seq(attrA), input, streaming = true) // dropDuplicates("a")
+ val optimized = Optimize.execute(query.analyze)
+
+ comparePlans(optimized, query)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1ebc53d2bb..3c212d656e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -557,7 +557,8 @@ class Dataset[T] private[sql](
* Spark will use this watermark for several purposes:
* - To know when a given time window aggregation can be finalized and thus can be emitted when
* using output modes that do not allow updates.
- * - To minimize the amount of state that we need to keep for on-going aggregations.
+ * - To minimize the amount of state that we need to keep for on-going aggregations,
+ * `mapGroupsWithState` and `dropDuplicates` operators.
*
* The current watermark is computed by looking at the `MAX(eventTime)` seen across
* all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost
@@ -1981,6 +1982,12 @@ class Dataset[T] private[sql](
* Returns a new Dataset that contains only the unique rows from this Dataset.
* This is an alias for `distinct`.
*
+ * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
+ * will keep all data across triggers as intermediate state to drop duplicate rows. You can use
+ * [[withWatermark]] to limit how late the duplicate data can be and the system will accordingly
+ * limit the state. In addition, data older than the watermark will be dropped to avoid any
+ * possibility of duplicates.
+ *
* @group typedrel
* @since 2.0.0
*/
@@ -1990,13 +1997,19 @@ class Dataset[T] private[sql](
* (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only
* the subset of columns.
*
+ * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
+ * will keep all data across triggers as intermediate state to drop duplicate rows. You can use
+ * [[withWatermark]] to limit how late the duplicate data can be and the system will accordingly
+ * limit the state. In addition, data older than the watermark will be dropped to avoid any
+ * possibility of duplicates.
+ *
* @group typedrel
* @since 2.0.0
*/
def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan {
val resolver = sparkSession.sessionState.analyzer.resolver
val allColumns = queryExecution.analyzed.output
- val groupCols = colNames.flatMap { colName =>
+ val groupCols = colNames.toSet.toSeq.flatMap { (colName: String) =>
// It is possible that there is more than one column with the same name,
// so we call filter instead of find.
val cols = allColumns.filter(col => resolver(col.name, colName))
@@ -2006,21 +2019,19 @@ class Dataset[T] private[sql](
}
cols
}
- val groupColExprIds = groupCols.map(_.exprId)
- val aggCols = logicalPlan.output.map { attr =>
- if (groupColExprIds.contains(attr.exprId)) {
- attr
- } else {
- Alias(new First(attr).toAggregateExpression(), attr.name)()
- }
- }
- Aggregate(groupCols, aggCols, logicalPlan)
+ Deduplicate(groupCols, logicalPlan, isStreaming)
}
/**
* Returns a new Dataset with duplicate rows removed, considering only
* the subset of columns.
*
+ * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
+ * will keep all data across triggers as intermediate state to drop duplicate rows. You can use
+ * [[withWatermark]] to limit how late the duplicate data can be and the system will accordingly
+ * limit the state. In addition, data older than the watermark will be dropped to avoid any
+ * possibility of duplicates.
+ *
* @group typedrel
* @since 2.0.0
*/
@@ -2030,6 +2041,12 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] with duplicate rows removed, considering only
* the subset of columns.
*
+ * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
+ * will keep all data across triggers as intermediate state to drop duplicate rows. You can use
+ * [[withWatermark]] to limit how late the duplicate data can be and the system will accordingly
+ * limit the state. In addition, data older than the watermark will be dropped to avoid any
+ * possibility of duplicates.
+ *
* @group typedrel
* @since 2.0.0
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 0e3d5595df..027b1481af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -22,9 +22,10 @@ import org.apache.spark.sql.{SaveMode, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
@@ -245,6 +246,18 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
/**
+ * Used to plan the streaming deduplicate operator.
+ */
+ object StreamingDeduplicationStrategy extends Strategy {
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case Deduplicate(keys, child, true) =>
+ StreamingDeduplicateExec(keys, planLater(child)) :: Nil
+
+ case _ => Nil
+ }
+ }
+
+ /**
* Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
*/
object Aggregation extends Strategy {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index a3e108b29e..ffdcd9b19d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -45,6 +45,7 @@ class IncrementalExecution(
sparkSession.sessionState.planner.StatefulAggregationStrategy +:
sparkSession.sessionState.planner.MapGroupsWithStateStrategy +:
sparkSession.sessionState.planner.StreamingRelationStrategy +:
+ sparkSession.sessionState.planner.StreamingDeduplicationStrategy +:
sparkSession.sessionState.experimentalMethods.extraStrategies
// Modified planner with stateful operations.
@@ -93,6 +94,15 @@ class IncrementalExecution(
keys,
Some(stateId),
child) :: Nil))
+ case StreamingDeduplicateExec(keys, child, None, None) =>
+ val stateId =
+ OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
+
+ StreamingDeduplicateExec(
+ keys,
+ child,
+ Some(stateId),
+ Some(currentEventTimeWatermark))
case MapGroupsWithStateExec(
f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) =>
val stateId =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 1292452574..d92529748b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -25,12 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjecti
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, NullType, StructType}
import org.apache.spark.util.CompletionIterator
@@ -68,6 +67,40 @@ trait StateStoreWriter extends StatefulOperator {
"numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"))
}
+/** An operator that supports watermark. */
+trait WatermarkSupport extends SparkPlan {
+
+ /** The keys that may have a watermark attribute. */
+ def keyExpressions: Seq[Attribute]
+
+ /** The watermark value. */
+ def eventTimeWatermark: Option[Long]
+
+ /** Generate a predicate that matches data older than the watermark */
+ lazy val watermarkPredicate: Option[Predicate] = {
+ val optionalWatermarkAttribute =
+ keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey))
+
+ optionalWatermarkAttribute.map { watermarkAttribute =>
+ // If we are evicting based on a window, use the end of the window. Otherwise just
+ // use the attribute itself.
+ val evictionExpression =
+ if (watermarkAttribute.dataType.isInstanceOf[StructType]) {
+ LessThanOrEqual(
+ GetStructField(watermarkAttribute, 1),
+ Literal(eventTimeWatermark.get * 1000))
+ } else {
+ LessThanOrEqual(
+ watermarkAttribute,
+ Literal(eventTimeWatermark.get * 1000))
+ }
+
+ logInfo(s"Filtering state store on: $evictionExpression")
+ newPredicate(evictionExpression, keyExpressions)
+ }
+ }
+}
+
/**
* For each input tuple, the key is calculated and the value from the [[StateStore]] is added
* to the stream (in addition to the input tuple) if present.
@@ -76,7 +109,7 @@ case class StateStoreRestoreExec(
keyExpressions: Seq[Attribute],
stateId: Option[OperatorStateId],
child: SparkPlan)
- extends execution.UnaryExecNode with StateStoreReader {
+ extends UnaryExecNode with StateStoreReader {
override protected def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
@@ -113,31 +146,7 @@ case class StateStoreSaveExec(
outputMode: Option[OutputMode] = None,
eventTimeWatermark: Option[Long] = None,
child: SparkPlan)
- extends execution.UnaryExecNode with StateStoreWriter {
-
- /** Generate a predicate that matches data older than the watermark */
- private lazy val watermarkPredicate: Option[Predicate] = {
- val optionalWatermarkAttribute =
- keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey))
-
- optionalWatermarkAttribute.map { watermarkAttribute =>
- // If we are evicting based on a window, use the end of the window. Otherwise just
- // use the attribute itself.
- val evictionExpression =
- if (watermarkAttribute.dataType.isInstanceOf[StructType]) {
- LessThanOrEqual(
- GetStructField(watermarkAttribute, 1),
- Literal(eventTimeWatermark.get * 1000))
- } else {
- LessThanOrEqual(
- watermarkAttribute,
- Literal(eventTimeWatermark.get * 1000))
- }
-
- logInfo(s"Filtering state store on: $evictionExpression")
- newPredicate(evictionExpression, keyExpressions)
- }
- }
+ extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
@@ -146,8 +155,8 @@ case class StateStoreSaveExec(
child.execute().mapPartitionsWithStateStore(
getStateId.checkpointLocation,
- operatorId = getStateId.operatorId,
- storeVersion = getStateId.batchId,
+ getStateId.operatorId,
+ getStateId.batchId,
keyExpressions.toStructType,
child.output.toStructType,
sqlContext.sessionState,
@@ -262,8 +271,8 @@ case class MapGroupsWithStateExec(
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateId.checkpointLocation,
- operatorId = getStateId.operatorId,
- storeVersion = getStateId.batchId,
+ getStateId.operatorId,
+ getStateId.batchId,
groupingAttributes.toStructType,
child.output.toStructType,
sqlContext.sessionState,
@@ -321,3 +330,70 @@ case class MapGroupsWithStateExec(
}
}
}
+
+
+/** Physical operator for executing streaming Deduplicate. */
+case class StreamingDeduplicateExec(
+ keyExpressions: Seq[Attribute],
+ child: SparkPlan,
+ stateId: Option[OperatorStateId] = None,
+ eventTimeWatermark: Option[Long] = None)
+ extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
+
+ /** Distribute by grouping attributes */
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(keyExpressions) :: Nil
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ metrics // force lazy init at driver
+
+ child.execute().mapPartitionsWithStateStore(
+ getStateId.checkpointLocation,
+ getStateId.operatorId,
+ getStateId.batchId,
+ keyExpressions.toStructType,
+ child.output.toStructType,
+ sqlContext.sessionState,
+ Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
+ val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+ val numOutputRows = longMetric("numOutputRows")
+ val numTotalStateRows = longMetric("numTotalStateRows")
+ val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+
+ val baseIterator = watermarkPredicate match {
+ case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
+ case None => iter
+ }
+
+ val result = baseIterator.filter { r =>
+ val row = r.asInstanceOf[UnsafeRow]
+ val key = getKey(row)
+ val value = store.get(key)
+ if (value.isEmpty) {
+ store.put(key.copy(), StreamingDeduplicateExec.EMPTY_ROW)
+ numUpdatedStateRows += 1
+ numOutputRows += 1
+ true
+ } else {
+ // Drop duplicated rows
+ false
+ }
+ }
+
+ CompletionIterator[InternalRow, Iterator[InternalRow]](result, {
+ watermarkPredicate.foreach(f => store.remove(f.eval _))
+ store.commit()
+ numTotalStateRows += store.numKeys()
+ })
+ }
+ }
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+}
+
+object StreamingDeduplicateExec {
+ private val EMPTY_ROW =
+ UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
+}
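Read end to end, the state handling in `StreamingDeduplicateExec` boils down to three steps: filter out rows already older than the watermark, emit a row only the first time its key is seen, and evict expired keys when the batch completes. A hedged, self-contained sketch of that per-batch logic, with a plain `mutable.Map` standing in for the `StateStore` and an event time tracked per key as a simplification (the real operator stores an empty row and evicts through the watermark predicate on the key columns):

```scala
import scala.collection.mutable

final case class Event(key: String, eventTimeMs: Long)

// Illustrative per-batch dedup logic; not the actual operator.
def dedupBatch(
    batch: Seq[Event],
    state: mutable.Map[String, Long], // key -> event time of its first occurrence
    watermarkMs: Long): Seq[Event] = {

  // 1. Rows at or below the watermark are filtered out up front (the watermarkPredicate above).
  val onTime = batch.filter(_.eventTimeMs > watermarkMs)

  // 2. A row is emitted only the first time its key is seen; later copies are duplicates.
  val output = onTime.filter { e =>
    val isNew = !state.contains(e.key)
    if (isNew) state.put(e.key, e.eventTimeMs)
    isNew
  }

  // 3. When the batch completes, state at or below the watermark is evicted,
  //    which is what bounds the memory used for deduplication.
  val expired = state.collect { case (k, ts) if ts <= watermarkMs => k }.toList
  expired.foreach(state.remove)

  output
}
```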
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
new file mode 100644
index 0000000000..7ea716231e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.functions._
+
+class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
+
+ import testImplicits._
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ StateStore.stop()
+ }
+
+ test("deduplicate with all columns") {
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS().dropDuplicates()
+
+ testStream(result, Append)(
+ AddData(inputData, "a"),
+ CheckLastBatch("a"),
+ assertNumStateRows(total = 1, updated = 1),
+ AddData(inputData, "a"),
+ CheckLastBatch(),
+ assertNumStateRows(total = 1, updated = 0),
+ AddData(inputData, "b"),
+ CheckLastBatch("b"),
+ assertNumStateRows(total = 2, updated = 1)
+ )
+ }
+
+ test("deduplicate with some columns") {
+ val inputData = MemoryStream[(String, Int)]
+ val result = inputData.toDS().dropDuplicates("_1")
+
+ testStream(result, Append)(
+ AddData(inputData, "a" -> 1),
+ CheckLastBatch("a" -> 1),
+ assertNumStateRows(total = 1, updated = 1),
+ AddData(inputData, "a" -> 2), // Dropped
+ CheckLastBatch(),
+ assertNumStateRows(total = 1, updated = 0),
+ AddData(inputData, "b" -> 1),
+ CheckLastBatch("b" -> 1),
+ assertNumStateRows(total = 2, updated = 1)
+ )
+ }
+
+ test("multiple deduplicates") {
+ val inputData = MemoryStream[(String, Int)]
+ val result = inputData.toDS().dropDuplicates().dropDuplicates("_1")
+
+ testStream(result, Append)(
+ AddData(inputData, "a" -> 1),
+ CheckLastBatch("a" -> 1),
+ assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)),
+
+ AddData(inputData, "a" -> 2), // Dropped from the second `dropDuplicates`
+ CheckLastBatch(),
+ assertNumStateRows(total = Seq(1L, 2L), updated = Seq(0L, 1L)),
+
+ AddData(inputData, "b" -> 1),
+ CheckLastBatch("b" -> 1),
+ assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L))
+ )
+ }
+
+ test("deduplicate with watermark") {
+ val inputData = MemoryStream[Int]
+ val result = inputData.toDS()
+ .withColumn("eventTime", $"value".cast("timestamp"))
+ .withWatermark("eventTime", "10 seconds")
+ .dropDuplicates()
+ .select($"eventTime".cast("long").as[Long])
+
+ testStream(result, Append)(
+ AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*),
+ CheckLastBatch(10 to 15: _*),
+ assertNumStateRows(total = 6, updated = 6),
+
+ AddData(inputData, 25), // Advance watermark to 15 seconds
+ CheckLastBatch(25),
+ assertNumStateRows(total = 7, updated = 1),
+
+ AddData(inputData, 25), // Drop states less than watermark
+ CheckLastBatch(),
+ assertNumStateRows(total = 1, updated = 0),
+
+ AddData(inputData, 10), // Should not emit anything as data less than watermark
+ CheckLastBatch(),
+ assertNumStateRows(total = 1, updated = 0),
+
+ AddData(inputData, 45), // Advance watermark to 35 seconds
+ CheckLastBatch(45),
+ assertNumStateRows(total = 2, updated = 1),
+
+ AddData(inputData, 45), // Drop states less than watermark
+ CheckLastBatch(),
+ assertNumStateRows(total = 1, updated = 0)
+ )
+ }
+
+ test("deduplicate with aggregate - append mode") {
+ val inputData = MemoryStream[Int]
+ val windowedaggregate = inputData.toDS()
+ .withColumn("eventTime", $"value".cast("timestamp"))
+ .withWatermark("eventTime", "10 seconds")
+ .dropDuplicates()
+ .withWatermark("eventTime", "10 seconds")
+ .groupBy(window($"eventTime", "5 seconds") as 'window)
+ .agg(count("*") as 'count)
+ .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
+
+ testStream(windowedaggregate)(
+ AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*),
+ CheckLastBatch(),
+ // states in aggregate in [10, 15), [15, 20) (2 windows)
+ // states in deduplicate is 10 to 15
+ assertNumStateRows(total = Seq(2L, 6L), updated = Seq(2L, 6L)),
+
+ AddData(inputData, 25), // Advance watermark to 15 seconds
+ CheckLastBatch(),
+ // states in aggregate in [10, 15), [15, 20) and [25, 30) (3 windows)
+ // states in deduplicate is 10 to 15 and 25
+ assertNumStateRows(total = Seq(3L, 7L), updated = Seq(1L, 1L)),
+
+ AddData(inputData, 25), // Emit items less than watermark and drop their state
+ CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate
+ // states in aggregate in [15, 20) and [25, 30) (2 windows, note aggregate uses the end of
+ // window to evict items, so [15, 20) is still in the state store)
+ // states in deduplicate is 25
+ assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)),
+
+ AddData(inputData, 10), // Should not emit anything as data less than watermark
+ CheckLastBatch(),
+ assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)),
+
+ AddData(inputData, 40), // Advance watermark to 30 seconds
+ CheckLastBatch(),
+ // states in aggregate in [15, 20), [25, 30) and [40, 45)
+ // states in deduplicate is 25 and 40,
+ assertNumStateRows(total = Seq(3L, 2L), updated = Seq(1L, 1L)),
+
+ AddData(inputData, 40), // Emit items less than watermark and drop their state
+ CheckLastBatch((15 -> 1), (25 -> 1)),
+ // states in aggregate in [40, 45)
+ // states in deduplicate is 40,
+ assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L))
+ )
+ }
+
+ test("deduplicate with aggregate - update mode") {
+ val inputData = MemoryStream[(String, Int)]
+ val result = inputData.toDS()
+ .select($"_1" as "str", $"_2" as "num")
+ .dropDuplicates()
+ .groupBy("str")
+ .agg(sum("num"))
+ .as[(String, Long)]
+
+ testStream(result, Update)(
+ AddData(inputData, "a" -> 1),
+ CheckLastBatch("a" -> 1L),
+ assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)),
+ AddData(inputData, "a" -> 1), // Dropped
+ CheckLastBatch(),
+ assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)),
+ AddData(inputData, "a" -> 2),
+ CheckLastBatch("a" -> 3L),
+ assertNumStateRows(total = Seq(1L, 2L), updated = Seq(1L, 1L)),
+ AddData(inputData, "b" -> 1),
+ CheckLastBatch("b" -> 1L),
+ assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L))
+ )
+ }
+
+ test("deduplicate with aggregate - complete mode") {
+ val inputData = MemoryStream[(String, Int)]
+ val result = inputData.toDS()
+ .select($"_1" as "str", $"_2" as "num")
+ .dropDuplicates()
+ .groupBy("str")
+ .agg(sum("num"))
+ .as[(String, Long)]
+
+ testStream(result, Complete)(
+ AddData(inputData, "a" -> 1),
+ CheckLastBatch("a" -> 1L),
+ assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)),
+ AddData(inputData, "a" -> 1), // Dropped
+ CheckLastBatch("a" -> 1L),
+ assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)),
+ AddData(inputData, "a" -> 2),
+ CheckLastBatch("a" -> 3L),
+ assertNumStateRows(total = Seq(1L, 2L), updated = Seq(1L, 1L)),
+ AddData(inputData, "b" -> 1),
+ CheckLastBatch("a" -> 3L, "b" -> 1L),
+ assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L))
+ )
+ }
+
+ test("deduplicate with file sink") {
+ withTempDir { output =>
+ withTempDir { checkpointDir =>
+ val outputPath = output.getAbsolutePath
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS().dropDuplicates()
+ val q = result.writeStream
+ .format("parquet")
+ .outputMode(Append)
+ .option("checkpointLocation", checkpointDir.getPath)
+ .start(outputPath)
+ try {
+ inputData.addData("a")
+ q.processAllAvailable()
+ checkDataset(spark.read.parquet(outputPath).as[String], "a")
+
+ inputData.addData("a") // Dropped
+ q.processAllAvailable()
+ checkDataset(spark.read.parquet(outputPath).as[String], "a")
+
+ inputData.addData("b")
+ q.processAllAvailable()
+ checkDataset(spark.read.parquet(outputPath).as[String], "a", "b")
+ } finally {
+ q.stop()
+ }
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
index 0524898b15..6cf4d51f99 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore
/** Class to check custom state types */
case class RunningCount(count: Long)
-class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
+class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
import testImplicits._
@@ -321,13 +321,6 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count
)
}
-
- private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q =>
- val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
- assert(progressWithData.stateOperators(0).numRowsTotal === total, "incorrect total rows")
- assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updates rows")
- true
- }
}
object MapGroupsWithStateSuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
new file mode 100644
index 0000000000..894786c50e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+trait StateStoreMetricsTest extends StreamTest {
+
+ def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery =
+ AssertOnQuery { q =>
+ val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
+ assert(
+ progressWithData.stateOperators.map(_.numRowsTotal) === total,
+ "incorrect total rows")
+ assert(
+ progressWithData.stateOperators.map(_.numRowsUpdated) === updated,
+ "incorrect updates rows")
+ true
+ }
+
+ def assertNumStateRows(total: Long, updated: Long): AssertOnQuery =
+ assertNumStateRows(Seq(total), Seq(updated))
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 0296a2ade3..f44cfada29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -338,7 +338,7 @@ class StreamSuite extends StreamTest {
.writeStream
.format("memory")
.queryName("testquery")
- .outputMode("complete")
+ .outputMode("append")
.start()
try {
query.processAllAvailable()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index eca2647dea..0c8015672b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -35,7 +35,7 @@ object FailureSinglton {
var firstTime = true
}
-class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll {
+class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
override def afterAll(): Unit = {
super.afterAll()