diff options
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala | 28 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala | 19 |
2 files changed, 39 insertions, 8 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index d92529748b..cbf656a204 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -68,7 +68,7 @@ trait StateStoreWriter extends StatefulOperator { } /** An operator that supports watermark. */ -trait WatermarkSupport extends SparkPlan { +trait WatermarkSupport extends UnaryExecNode { /** The keys that may have a watermark attribute. */ def keyExpressions: Seq[Attribute] @@ -76,8 +76,8 @@ trait WatermarkSupport extends SparkPlan { /** The watermark value. */ def eventTimeWatermark: Option[Long] - /** Generate a predicate that matches data older than the watermark */ - lazy val watermarkPredicate: Option[Predicate] = { + /** Generate an expression that matches data older than the watermark */ + lazy val watermarkExpression: Option[Expression] = { val optionalWatermarkAttribute = keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) @@ -96,9 +96,19 @@ trait WatermarkSupport extends SparkPlan { } logInfo(s"Filtering state store on: $evictionExpression") - newPredicate(evictionExpression, keyExpressions) + evictionExpression } } + + /** Generate a predicate based on keys that matches data older than the watermark */ + lazy val watermarkPredicateForKeys: Option[Predicate] = + watermarkExpression.map(newPredicate(_, keyExpressions)) + + /** + * Generate a predicate based on the child output that matches data older than the watermark. + */ + lazy val watermarkPredicate: Option[Predicate] = + watermarkExpression.map(newPredicate(_, child.output)) } /** @@ -192,7 +202,7 @@ case class StateStoreSaveExec( } // Assumption: Append mode can be done only when watermark has been specified - store.remove(watermarkPredicate.get.eval _) + store.remove(watermarkPredicateForKeys.get.eval _) store.commit() numTotalStateRows += store.numKeys() @@ -215,7 +225,9 @@ case class StateStoreSaveExec( override def hasNext: Boolean = { if (!baseIterator.hasNext) { // Remove old aggregates if watermark specified - if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval _) + if (watermarkPredicateForKeys.nonEmpty) { + store.remove(watermarkPredicateForKeys.get.eval _) + } store.commit() numTotalStateRows += store.numKeys() false @@ -361,7 +373,7 @@ case class StreamingDeduplicateExec( val numUpdatedStateRows = longMetric("numUpdatedStateRows") val baseIterator = watermarkPredicate match { - case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case Some(predicate) => iter.filter(row => !predicate.eval(row)) case None => iter } @@ -381,7 +393,7 @@ case class StreamingDeduplicateExec( } CompletionIterator[InternalRow, Iterator[InternalRow]](result, { - watermarkPredicate.foreach(f => store.remove(f.eval _)) + watermarkPredicateForKeys.foreach(f => store.remove(f.eval _)) store.commit() numTotalStateRows += store.numKeys() }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala index 7ea716231e..a15c2cff93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -249,4 +249,23 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { } } } + + test("SPARK-19841: watermarkPredicate should filter based on keys") { + val input = MemoryStream[(Int, Int)] + val df = input.toDS.toDF("time", "id") + .withColumn("time", $"time".cast("timestamp")) + .withWatermark("time", "1 second") + .dropDuplicates("id", "time") // Change the column positions + .select($"id") + testStream(df)( + AddData(input, 1 -> 1, 1 -> 1, 1 -> 2), + CheckLastBatch(1, 2), + AddData(input, 1 -> 1, 2 -> 3, 2 -> 4), + CheckLastBatch(3, 4), + AddData(input, 1 -> 0, 1 -> 1, 3 -> 5, 3 -> 6), // Drop (1 -> 0, 1 -> 1) due to watermark + CheckLastBatch(5, 6), + AddData(input, 1 -> 0, 4 -> 7), // Drop (1 -> 0) due to watermark + CheckLastBatch(7) + ) + } } |