aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShixiong Zhu <shixiong@databricks.com>2017-03-08 13:18:07 -0800
committerTathagata Das <tathagata.das1565@gmail.com>2017-03-08 13:18:07 -0800
commit1bf9012380de2aa7bdf39220b55748defde8b700 (patch)
tree3efc5be6f6506eef72a98c08132d95c0ce9f8fcd
parente9e2c612d58a19ddcb4b6abfb7389a4b0f7ef6f8 (diff)
downloadspark-1bf9012380de2aa7bdf39220b55748defde8b700.tar.gz
spark-1bf9012380de2aa7bdf39220b55748defde8b700.tar.bz2
spark-1bf9012380de2aa7bdf39220b55748defde8b700.zip
[SPARK-19858][SS] Add output mode to flatMapGroupsWithState and disallow invalid cases
## What changes were proposed in this pull request? Add an output mode parameter to `flatMapGroupsWithState` and just define `mapGroupsWithState` as `flatMapGroupsWithState(Update)`. `UnsupportedOperationChecker` is modified to disallow unsupported cases. - Batch mapGroupsWithState or flatMapGroupsWithState is always allowed. - For streaming (map/flatMap)GroupsWithState, see the following table: | Operators | Supported Query Output Mode | | ------------- | ------------- | | flatMapGroupsWithState(Update) without aggregation | Update | | flatMapGroupsWithState(Update) with aggregation | None | | flatMapGroupsWithState(Append) without aggregation | Append | | flatMapGroupsWithState(Append) before aggregation | Append, Update, Complete | | flatMapGroupsWithState(Append) after aggregation | None | | Multiple flatMapGroupsWithState(Append)s | Append | | Multiple mapGroupsWithStates | None | | Mixing mapGroupsWithStates and flatMapGroupsWithStates | None | | Other cases of multiple flatMapGroupsWithState | None | ## How was this patch tested? The added unit tests. 
Here are the tests related to (map/flatMap)GroupsWithState: ``` [info] - batch plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation: supported (1 millisecond) [info] - batch plan - flatMapGroupsWithState - multiple flatMapGroupsWithState(Append)s on batch relation: supported (0 milliseconds) [info] - batch plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation: supported (0 milliseconds) [info] - batch plan - flatMapGroupsWithState - multiple flatMapGroupsWithState(Update)s on batch relation: supported (0 milliseconds) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in update mode: supported (2 milliseconds) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in append mode: not supported (7 milliseconds) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in complete mode: not supported (5 milliseconds) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Append mode: not supported (11 milliseconds) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Update mode: not supported (5 milliseconds) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Complete mode: not supported (5 milliseconds) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation without aggregation in append mode: supported (1 millisecond) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation without aggregation in update mode: not supported (6 milliseconds) [info] - streaming plan - flatMapGroupsWithState - 
flatMapGroupsWithState(Append) on streaming relation before aggregation in Append mode: supported (1 millisecond) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Update mode: supported (0 milliseconds) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Complete mode: supported (1 millisecond) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation after aggregation in Append mode: not supported (6 milliseconds) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation after aggregation in Update mode: not supported (4 milliseconds) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation in complete mode: not supported (2 milliseconds) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation inside streaming relation in Append output mode: supported (1 millisecond) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation inside streaming relation in Update output mode: supported (1 millisecond) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation inside streaming relation in Append output mode: supported (0 milliseconds) [info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation inside streaming relation in Update output mode: supported (0 milliseconds) [info] - streaming plan - flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are in append mode: supported (2 milliseconds) [info] - streaming plan - flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some are not in append mode: not supported (7 milliseconds) [info] - streaming plan - 
mapGroupsWithState - mapGroupsWithState on streaming relation without aggregation in append mode: not supported (3 milliseconds) [info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation without aggregation in complete mode: not supported (3 milliseconds) [info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Append mode: not supported (6 milliseconds) [info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Update mode: not supported (3 milliseconds) [info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Complete mode: not supported (4 milliseconds) [info] - streaming plan - mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are in append mode: not supported (4 milliseconds) [info] - streaming plan - mapGroupsWithState - mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation: not supported (4 milliseconds) ``` Author: Shixiong Zhu <shixiong@databricks.com> Closes #17197 from zsxwing/mapgroups-check.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala77
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala24
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala15
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala203
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala48
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala91
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala21
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala16
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala (renamed from sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala)46
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala41
13 files changed, 485 insertions, 107 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 397f5cfe2a..a9ff61e0e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -51,6 +51,37 @@ object UnsupportedOperationChecker {
subplan.collect { case a: Aggregate if a.isStreaming => a }
}
+ val mapGroupsWithStates = plan.collect {
+ case f: FlatMapGroupsWithState if f.isStreaming && f.isMapGroupsWithState => f
+ }
+
+ // Disallow multiple `mapGroupsWithState`s.
+ if (mapGroupsWithStates.size >= 2) {
+ throwError(
+ "Multiple mapGroupsWithStates are not supported on a streaming DataFrames/Datasets")(plan)
+ }
+
+ val flatMapGroupsWithStates = plan.collect {
+ case f: FlatMapGroupsWithState if f.isStreaming && !f.isMapGroupsWithState => f
+ }
+
+ // Disallow mixing `mapGroupsWithState`s and `flatMapGroupsWithState`s
+ if (mapGroupsWithStates.nonEmpty && flatMapGroupsWithStates.nonEmpty) {
+ throwError(
+ "Mixing mapGroupsWithStates and flatMapGroupsWithStates are not supported on a " +
+ "streaming DataFrames/Datasets")(plan)
+ }
+
+ // Only allow multiple `FlatMapGroupsWithState(Append)`s in append mode.
+ if (flatMapGroupsWithStates.size >= 2 && (
+ outputMode != InternalOutputModes.Append ||
+ flatMapGroupsWithStates.exists(_.outputMode != InternalOutputModes.Append)
+ )) {
+ throwError(
+ "Multiple flatMapGroupsWithStates are not supported when they are not all in append mode" +
+ " or the output mode is not append on a streaming DataFrames/Datasets")(plan)
+ }
+
// Disallow multiple streaming aggregations
val aggregates = collectStreamingAggregates(plan)
@@ -116,9 +147,49 @@ object UnsupportedOperationChecker {
throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " +
"streaming DataFrames/Datasets")
- case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty =>
- throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " +
- "streaming DataFrame/Dataset")
+ // mapGroupsWithState: Allowed only when no aggregation + Update output mode
+ case m: FlatMapGroupsWithState if m.isStreaming && m.isMapGroupsWithState =>
+ if (collectStreamingAggregates(plan).isEmpty) {
+ if (outputMode != InternalOutputModes.Update) {
+ throwError("mapGroupsWithState is not supported with " +
+ s"$outputMode output mode on a streaming DataFrame/Dataset")
+ } else {
+ // Allowed when no aggregation + Update output mode
+ }
+ } else {
+ throwError("mapGroupsWithState is not supported with aggregation " +
+ "on a streaming DataFrame/Dataset")
+ }
+
+ // flatMapGroupsWithState without aggregation
+ case m: FlatMapGroupsWithState
+ if m.isStreaming && collectStreamingAggregates(plan).isEmpty =>
+ m.outputMode match {
+ case InternalOutputModes.Update =>
+ if (outputMode != InternalOutputModes.Update) {
+ throwError("flatMapGroupsWithState in update mode is not supported with " +
+ s"$outputMode output mode on a streaming DataFrame/Dataset")
+ }
+ case InternalOutputModes.Append =>
+ if (outputMode != InternalOutputModes.Append) {
+ throwError("flatMapGroupsWithState in append mode is not supported with " +
+ s"$outputMode output mode on a streaming DataFrame/Dataset")
+ }
+ }
+
+ // flatMapGroupsWithState(Update) with aggregation
+ case m: FlatMapGroupsWithState
+ if m.isStreaming && m.outputMode == InternalOutputModes.Update
+ && collectStreamingAggregates(plan).nonEmpty =>
+ throwError("flatMapGroupsWithState in update mode is not supported with " +
+ "aggregation on a streaming DataFrame/Dataset")
+
+ // flatMapGroupsWithState(Append) with aggregation
+ case m: FlatMapGroupsWithState
+ if m.isStreaming && m.outputMode == InternalOutputModes.Append
+ && collectStreamingAggregates(m).nonEmpty =>
+ throwError("flatMapGroupsWithState in append mode is not supported after " +
+ s"aggregation on a streaming DataFrame/Dataset")
case d: Deduplicate if collectStreamingAggregates(d).nonEmpty =>
throwError("dropDuplicates is not supported after aggregation on a " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 0be4823bbc..617239f56c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types._
object CatalystSerde {
@@ -317,13 +318,15 @@ case class MapGroups(
trait LogicalKeyedState[S]
/** Factory for constructing new `MapGroupsWithState` nodes. */
-object MapGroupsWithState {
+object FlatMapGroupsWithState {
def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
+ outputMode: OutputMode,
+ isMapGroupsWithState: Boolean,
child: LogicalPlan): LogicalPlan = {
- val mapped = new MapGroupsWithState(
+ val mapped = new FlatMapGroupsWithState(
func,
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
@@ -332,7 +335,9 @@ object MapGroupsWithState {
CatalystSerde.generateObjAttr[U],
encoderFor[S].resolveAndBind().deserializer,
encoderFor[S].namedExpressions,
- child)
+ outputMode,
+ child,
+ isMapGroupsWithState)
CatalystSerde.serialize[U](mapped)
}
}
@@ -350,8 +355,10 @@ object MapGroupsWithState {
* @param outputObjAttr used to define the output object
* @param stateDeserializer used to deserialize state before calling `func`
* @param stateSerializer used to serialize updated state after calling `func`
+ * @param outputMode the output mode of `func`
+ * @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method
*/
-case class MapGroupsWithState(
+case class FlatMapGroupsWithState(
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
@@ -360,7 +367,14 @@ case class MapGroupsWithState(
outputObjAttr: Attribute,
stateDeserializer: Expression,
stateSerializer: Seq[NamedExpression],
- child: LogicalPlan) extends UnaryNode with ObjectProducer
+ outputMode: OutputMode,
+ child: LogicalPlan,
+ isMapGroupsWithState: Boolean = false) extends UnaryNode with ObjectProducer {
+
+ if (isMapGroupsWithState) {
+ assert(outputMode == OutputMode.Update)
+ }
+}
/** Factory for constructing new `FlatMapGroupsInR` nodes. */
object FlatMapGroupsInR {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala
index 351bd6fff4..bdf2baf736 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala
@@ -44,4 +44,19 @@ private[sql] object InternalOutputModes {
* aggregations, it will be equivalent to `Append` mode.
*/
case object Update extends OutputMode
+
+ /** Locale-independent, case-insensitive parse; throws IllegalArgumentException if unknown. */
+ def apply(outputMode: String): OutputMode = {
+ outputMode.toLowerCase(java.util.Locale.ROOT) match {
+ case "append" =>
+ OutputMode.Append
+ case "complete" =>
+ OutputMode.Complete
+ case "update" =>
+ OutputMode.Update
+ case _ =>
+ throw new IllegalArgumentException(s"Unknown output mode $outputMode. " +
+ "Accepted output modes are 'append', 'complete', 'update'")
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index 82be69a0f7..200c39f43a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _}
+import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder}
@@ -138,29 +138,202 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
outputMode = Complete,
expectedMsgs = Seq("distinct aggregation"))
- // MapGroupsWithState: Not supported after a streaming aggregation
val att = new AttributeReference(name = "a", dataType = LongType)()
- assertSupportedInBatchPlan(
- "mapGroupsWithState - mapGroupsWithState on batch relation",
- MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation))
+ // FlatMapGroupsWithState: Both function modes equivalent and supported in batch.
+ for (funcMode <- Seq(Append, Update)) {
+ assertSupportedInBatchPlan(
+ s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation))
+
+ assertSupportedInBatchPlan(
+ s"flatMapGroupsWithState - multiple flatMapGroupsWithState($funcMode)s on batch relation",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode,
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation)))
+ }
+
+ // FlatMapGroupsWithState(Update) in streaming without aggregation
+ assertSupportedInStreamingPlan(
+ "flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
+ "on streaming relation without aggregation in update mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation),
+ outputMode = Update)
+
+ assertNotSupportedInStreamingPlan(
+ "flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
+ "on streaming relation without aggregation in append mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation),
+ outputMode = Append,
+ expectedMsgs = Seq("flatMapGroupsWithState in update mode", "Append"))
+
+ assertNotSupportedInStreamingPlan(
+ "flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
+ "on streaming relation without aggregation in complete mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation),
+ outputMode = Complete,
+ // Disallowed by the aggregation check but let's still keep this test in case it's broken in
+ // future.
+ expectedMsgs = Seq("Complete"))
+
+ // FlatMapGroupsWithState(Update) in streaming with aggregation
+ for (outputMode <- Seq(Append, Update, Complete)) {
+ assertNotSupportedInStreamingPlan(
+ "flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation " +
+ s"with aggregation in $outputMode mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+ Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
+ outputMode = outputMode,
+ expectedMsgs = Seq("flatMapGroupsWithState in update mode", "with aggregation"))
+ }
+ // FlatMapGroupsWithState(Append) in streaming without aggregation
assertSupportedInStreamingPlan(
- "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation",
- MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation),
+ "flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
+ "on streaming relation without aggregation in append mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation),
outputMode = Append)
assertNotSupportedInStreamingPlan(
- "mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation",
- MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att),
- Aggregate(Nil, aggExprs("c"), streamRelation)),
+ "flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
+ "on streaming relation without aggregation in update mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation),
+ outputMode = Update,
+ expectedMsgs = Seq("flatMapGroupsWithState in append mode", "update"))
+
+ // FlatMapGroupsWithState(Append) in streaming with aggregation
+ for (outputMode <- Seq(Append, Update, Complete)) {
+ assertSupportedInStreamingPlan(
+ "flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
+ s"on streaming relation before aggregation in $outputMode mode",
+ Aggregate(
+ Seq(attributeWithWatermark),
+ aggExprs("c"),
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)),
+ outputMode = outputMode)
+ }
+
+ for (outputMode <- Seq(Append, Update)) {
+ assertNotSupportedInStreamingPlan(
+ "flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
+ s"on streaming relation after aggregation in $outputMode mode",
+ FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append,
+ Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
+ outputMode = outputMode,
+ expectedMsgs = Seq("flatMapGroupsWithState", "after aggregation"))
+ }
+
+ assertNotSupportedInStreamingPlan(
+ "flatMapGroupsWithState - " +
+ "flatMapGroupsWithState(Append) on streaming relation in complete mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation),
outputMode = Complete,
- expectedMsgs = Seq("(map/flatMap)GroupsWithState"))
+ // Disallowed by the aggregation check but let's still keep this test in case it's broken in
+ // future.
+ expectedMsgs = Seq("Complete"))
+ // FlatMapGroupsWithState inside batch relation should always be allowed
+ for (funcMode <- Seq(Append, Update)) {
+ for (outputMode <- Seq(Append, Update)) { // Complete is not supported without aggregation
+ assertSupportedInStreamingPlan(
+ s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation inside " +
+ s"streaming relation in $outputMode output mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation),
+ outputMode = outputMode
+ )
+ }
+ }
+
+ // multiple FlatMapGroupsWithStates
assertSupportedInStreamingPlan(
- "mapGroupsWithState - mapGroupsWithState on batch relation inside streaming relation",
- MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation),
- outputMode = Append
- )
+ "flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are " +
+ "in append mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append,
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)),
+ outputMode = Append)
+
+ assertNotSupportedInStreamingPlan(
+ "flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some" +
+ " are not in append mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)),
+ outputMode = Append,
+ expectedMsgs = Seq("multiple flatMapGroupsWithState", "append"))
+
+ // mapGroupsWithState
+ assertNotSupportedInStreamingPlan(
+ "mapGroupsWithState - mapGroupsWithState " +
+ "on streaming relation without aggregation in append mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
+ isMapGroupsWithState = true),
+ outputMode = Append,
+ // Disallowed by the aggregation check but let's still keep this test in case it's broken in
+ // future.
+ expectedMsgs = Seq("mapGroupsWithState", "append"))
+
+ assertNotSupportedInStreamingPlan(
+ "mapGroupsWithState - mapGroupsWithState " +
+ "on streaming relation without aggregation in complete mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
+ isMapGroupsWithState = true),
+ outputMode = Complete,
+ // Disallowed by the aggregation check but let's still keep this test in case it's broken in
+ // future.
+ expectedMsgs = Seq("Complete"))
+
+ for (outputMode <- Seq(Append, Update, Complete)) {
+ assertNotSupportedInStreamingPlan(
+ "mapGroupsWithState - mapGroupsWithState on streaming relation " +
+ s"with aggregation in $outputMode mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+ Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation),
+ isMapGroupsWithState = true),
+ outputMode = outputMode,
+ expectedMsgs = Seq("mapGroupsWithState", "with aggregation"))
+ }
+
+ // multiple mapGroupsWithStates
+ assertNotSupportedInStreamingPlan(
+ "mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are " +
+ "in append mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
+ isMapGroupsWithState = true),
+ isMapGroupsWithState = true),
+ outputMode = Append,
+ expectedMsgs = Seq("multiple mapGroupsWithStates"))
+
+ // mixing mapGroupsWithStates and flatMapGroupsWithStates
+ assertNotSupportedInStreamingPlan(
+ "mapGroupsWithState - " +
+ "mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
+ isMapGroupsWithState = false),
+ isMapGroupsWithState = true),
+ outputMode = Append,
+ expectedMsgs = Seq("Mixing mapGroupsWithStates and flatMapGroupsWithStates"))
// Deduplicate
assertSupportedInStreamingPlan(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala
new file mode 100644
index 0000000000..201dac35ed
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.streaming
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.streaming.OutputMode
+
+class InternalOutputModesSuite extends SparkFunSuite {
+
+ // Every accepted name must parse in both lower-case and capitalized spelling.
+ test("supported strings") {
+ val expectedByName = Seq(
+ "append" -> OutputMode.Append,
+ "Append" -> OutputMode.Append,
+ "complete" -> OutputMode.Complete,
+ "Complete" -> OutputMode.Complete,
+ "update" -> OutputMode.Update,
+ "Update" -> OutputMode.Update)
+
+ expectedByName.foreach { case (name, mode) =>
+ assert(InternalOutputModes(name) === mode)
+ }
+ }
+
+ // An unknown name is rejected, and the error text names the input and every accepted mode.
+ test("unsupported strings") {
+ val badName = "Xyz"
+ val error = intercept[IllegalArgumentException](InternalOutputModes(badName))
+ val message = error.getMessage.toLowerCase
+ val fragments = Seq("output mode", "unknown", badName, "append", "update", "complete")
+ fragments.foreach(fragment => assert(message.contains(fragment.toLowerCase)))
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 3a548c251f..ab956ffd64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -24,8 +24,10 @@ import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.ReduceAggregator
+import org.apache.spark.sql.streaming.OutputMode
/**
* :: Experimental ::
@@ -238,8 +240,16 @@ class KeyValueGroupedDataset[K, V] private[sql](
@InterfaceStability.Evolving
def mapGroupsWithState[S: Encoder, U: Encoder](
func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
- flatMapGroupsWithState[S, U](
- (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)))
+ val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))
+ Dataset[U](
+ sparkSession,
+ FlatMapGroupsWithState[K, V, S, U](
+ flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
+ groupingAttributes,
+ dataAttributes,
+ OutputMode.Update,
+ isMapGroupsWithState = true,
+ child = logicalPlan))
}
/**
@@ -267,8 +277,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
func: MapGroupsWithStateFunction[K, V, S, U],
stateEncoder: Encoder[S],
outputEncoder: Encoder[U]): Dataset[U] = {
- flatMapGroupsWithState[S, U](
- (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s))
+ mapGroupsWithState[S, U](
+ (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s)
)(stateEncoder, outputEncoder)
}
@@ -284,6 +294,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func Function to be called on every group.
+ * @param outputMode The output mode of the function.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 2.1.1
@@ -291,14 +303,44 @@ class KeyValueGroupedDataset[K, V] private[sql](
@Experimental
@InterfaceStability.Evolving
def flatMapGroupsWithState[S: Encoder, U: Encoder](
- func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = {
+ func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: OutputMode): Dataset[U] = {
+ if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) {
+ throw new IllegalArgumentException("The output mode of function should be append or update")
+ }
Dataset[U](
sparkSession,
- MapGroupsWithState[K, V, S, U](
+ FlatMapGroupsWithState[K, V, S, U](
func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
groupingAttributes,
dataAttributes,
- logicalPlan))
+ outputMode,
+ isMapGroupsWithState = false,
+ child = logicalPlan))
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Applies the given function to each group of data, while maintaining a user-defined per-group
+ * state. The result Dataset will represent the objects returned by the function.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+ * updates to each group's state will be saved across invocations.
+ * See [[KeyedState]] for more details.
+ *
+ * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func Function to be called on every group.
+ * @param outputMode The output mode of the function; must be "append" or "update".
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def flatMapGroupsWithState[S: Encoder, U: Encoder](
+ func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: String): Dataset[U] = {
+ flatMapGroupsWithState(func, InternalOutputModes(outputMode))
}
/**
@@ -314,6 +356,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @param func Function to be called on every group.
+ * @param outputMode The output mode of the function; must be Append or Update.
* @param stateEncoder Encoder for the state type.
* @param outputEncoder Encoder for the output type.
*
@@ -324,14 +367,46 @@ class KeyValueGroupedDataset[K, V] private[sql](
@InterfaceStability.Evolving
def flatMapGroupsWithState[S, U](
func: FlatMapGroupsWithStateFunction[K, V, S, U],
+ outputMode: OutputMode,
stateEncoder: Encoder[S],
outputEncoder: Encoder[U]): Dataset[U] = {
flatMapGroupsWithState[S, U](
- (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala
+ (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala,
+ outputMode
)(stateEncoder, outputEncoder)
}
/**
+ * :: Experimental ::
+ * (Java-specific)
+ * Applies the given function to each group of data, while maintaining a user-defined per-group
+ * state. The result Dataset will represent the objects returned by the function.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+ * updates to each group's state will be saved across invocations.
+ * See [[KeyedState]] for more details.
+ *
+ * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func Function to be called on every group.
+ * @param outputMode The output mode of the function; must be "append" or "update".
+ * @param stateEncoder Encoder for the state type.
+ * @param outputEncoder Encoder for the output type.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def flatMapGroupsWithState[S, U](
+ func: FlatMapGroupsWithStateFunction[K, V, S, U],
+ outputMode: String,
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U]): Dataset[U] = {
+ flatMapGroupsWithState(func, InternalOutputModes(outputMode), stateEncoder, outputEncoder)
+ }
+
+ /**
* (Scala-specific)
* Reduces the elements of each group of data using the specified binary function.
* The given function must be commutative and associative or the result may be non-deterministic.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 20bf4925db..0f7aa3709c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -326,14 +326,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
/**
- * Strategy to convert MapGroupsWithState logical operator to physical operator
+ * Strategy to convert [[FlatMapGroupsWithState]] logical operator to physical operator
* in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
*/
object MapGroupsWithStateStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case MapGroupsWithState(
- f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) =>
- val execPlan = MapGroupsWithStateExec(
+ case FlatMapGroupsWithState(
+ f,
+ keyDeser,
+ valueDeser,
+ groupAttr,
+ dataAttr,
+ outputAttr,
+ stateDeser,
+ stateSer,
+ outputMode,
+ child,
+ _) =>
+ val execPlan = FlatMapGroupsWithStateExec(
f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer,
planLater(child))
execPlan :: Nil
@@ -381,7 +391,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
- case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) =>
+ case logical.FlatMapGroupsWithState(
+ f, key, value, grouping, data, output, _, _, _, child, _) =>
execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
execution.CoGroupExec(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index ffdcd9b19d..610ce5e1eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -103,11 +103,11 @@ class IncrementalExecution(
child,
Some(stateId),
Some(currentEventTimeWatermark))
- case MapGroupsWithStateExec(
+ case FlatMapGroupsWithStateExec(
f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) =>
val stateId =
OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
- MapGroupsWithStateExec(
+ FlatMapGroupsWithStateExec(
f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index cbf656a204..c3075a3eac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -257,8 +257,8 @@ case class StateStoreSaveExec(
}
-/** Physical operator for executing streaming mapGroupsWithState. */
-case class MapGroupsWithStateExec(
+/** Physical operator for executing streaming flatMapGroupsWithState. */
+case class FlatMapGroupsWithStateExec(
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 0f7a33723c..c8fda8cd83 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.streaming
import scala.collection.JavaConverters._
import org.apache.spark.annotation.{Experimental, InterfaceStability}
-import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, ForeachWriter}
-import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink}
@@ -69,17 +69,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
* @since 2.0.0
*/
def outputMode(outputMode: String): DataStreamWriter[T] = {
- this.outputMode = outputMode.toLowerCase match {
- case "append" =>
- OutputMode.Append
- case "complete" =>
- OutputMode.Complete
- case "update" =>
- OutputMode.Update
- case _ =>
- throw new IllegalArgumentException(s"Unknown output mode $outputMode. " +
- "Accepted output modes are 'append', 'complete', 'update'")
- }
+ this.outputMode = InternalOutputModes(outputMode)
this
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index e3b0e37cca..d06e35bb44 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -23,6 +23,7 @@ import java.sql.Date;
import java.sql.Timestamp;
import java.util.*;
+import org.apache.spark.sql.streaming.OutputMode;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
@@ -205,6 +206,7 @@ public class JavaDatasetSuite implements Serializable {
}
return Collections.singletonList(sb.toString()).iterator();
},
+ OutputMode.Append(),
Encoders.LONG(),
Encoders.STRING());
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 6cf4d51f99..902b842e97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore
/** Class to check custom state types */
case class RunningCount(count: Long)
-class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
+class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
import testImplicits._
@@ -119,9 +119,9 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
val result =
inputData.toDS()
.groupByKey(x => x)
- .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str)
+ .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str)
- testStream(result, Append)(
+ testStream(result, Update)(
AddData(inputData, "a"),
CheckLastBatch(("a", "1")),
assertNumStateRows(total = 1, updated = 1),
@@ -162,9 +162,9 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
val result =
inputData.toDS()
.groupByKey(x => x)
- .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str)
+ .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str)
- testStream(result, Append)(
+ testStream(result, Update)(
AddData(inputData, "a", "a", "b"),
CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
StopStream,
@@ -185,7 +185,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
Iterator((key, values.size))
}
checkAnswer(
- Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF,
+ Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc, Update).toDF,
Seq(("a", 2), ("b", 1)).toDF)
}
@@ -210,7 +210,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
.groupByKey(x => x)
.mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
- testStream(result, Append)(
+ testStream(result, Update)(
AddData(inputData, "a"),
CheckLastBatch(("a", "1")),
assertNumStateRows(total = 1, updated = 1),
@@ -230,7 +230,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
)
}
- test("mapGroupsWithState - streaming + aggregation") {
+ test("flatMapGroupsWithState - streaming + aggregation") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
@@ -238,10 +238,10 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
if (count == 3) {
state.remove()
- (key, "-1")
+ Iterator(key -> "-1")
} else {
state.update(RunningCount(count))
- (key, count.toString)
+ Iterator(key -> count.toString)
}
}
@@ -249,7 +249,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
val result =
inputData.toDS()
.groupByKey(x => x)
- .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
+ .flatMapGroupsWithState(stateFunc, Append) // Types = State: RunningCount, Out: (Str, Str)
.groupByKey(_._1)
.count()
@@ -290,7 +290,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
testQuietly("StateStore.abort on task failure handling") {
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
- if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
+ if (FlatMapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
state.update(RunningCount(count))
(key, count)
@@ -303,11 +303,11 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
.mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q =>
- MapGroupsWithStateSuite.failInTask = value
+ FlatMapGroupsWithStateSuite.failInTask = value
true
}
- testStream(result, Append)(
+ testStream(result, Update)(
setFailInTask(false),
AddData(inputData, "a"),
CheckLastBatch(("a", 1L)),
@@ -321,8 +321,24 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count
)
}
+
+ test("disallow complete mode") {
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+ Iterator[String]()
+ }
+
+ var e = intercept[IllegalArgumentException] {
+ MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, Complete)
+ }
+ assert(e.getMessage === "The output mode of function should be append or update")
+
+ e = intercept[IllegalArgumentException] {
+ MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, "complete")
+ }
+ assert(e.getMessage === "The output mode of function should be append or update")
+ }
}
-object MapGroupsWithStateSuite {
+object FlatMapGroupsWithStateSuite {
var failInTask = true
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
index 0470411a0f..f61dcdcbcf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
@@ -24,8 +24,7 @@ import scala.concurrent.duration._
import org.apache.hadoop.fs.Path
import org.mockito.Mockito._
-import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
-import org.scalatest.PrivateMethodTester.PrivateMethod
+import org.scalatest.BeforeAndAfter
import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming._
@@ -107,7 +106,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
}
}
-class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with PrivateMethodTester {
+class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
private def newMetadataDir =
Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -390,42 +389,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with Pr
private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath
- test("supported strings in outputMode(string)") {
- val outputModeMethod = PrivateMethod[OutputMode]('outputMode)
-
- def testMode(outputMode: String, expected: OutputMode): Unit = {
- val df = spark.readStream
- .format("org.apache.spark.sql.streaming.test")
- .load()
- val w = df.writeStream
- w.outputMode(outputMode)
- val setOutputMode = w invokePrivate outputModeMethod()
- assert(setOutputMode === expected)
- }
-
- testMode("append", OutputMode.Append)
- testMode("Append", OutputMode.Append)
- testMode("complete", OutputMode.Complete)
- testMode("Complete", OutputMode.Complete)
- testMode("update", OutputMode.Update)
- testMode("Update", OutputMode.Update)
- }
-
- test("unsupported strings in outputMode(string)") {
- def testMode(outputMode: String): Unit = {
- val acceptedModes = Seq("append", "update", "complete")
- val df = spark.readStream
- .format("org.apache.spark.sql.streaming.test")
- .load()
- val w = df.writeStream
- val e = intercept[IllegalArgumentException](w.outputMode(outputMode))
- (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s =>
- assert(e.getMessage.toLowerCase.contains(s.toLowerCase))
- }
- }
- testMode("Xyz")
- }
-
test("check foreach() catches null writers") {
val df = spark.readStream
.format("org.apache.spark.sql.streaming.test")