author | Shixiong Zhu <shixiong@databricks.com> | 2017-03-08 13:18:07 -0800 |
---|---|---|
committer | Tathagata Das <tathagata.das1565@gmail.com> | 2017-03-08 13:18:07 -0800 |
commit | 1bf9012380de2aa7bdf39220b55748defde8b700 (patch) | |
tree | 3efc5be6f6506eef72a98c08132d95c0ce9f8fcd | |
parent | e9e2c612d58a19ddcb4b6abfb7389a4b0f7ef6f8 (diff) | |
[SPARK-19858][SS] Add output mode to flatMapGroupsWithState and disallow invalid cases
## What changes were proposed in this pull request?
Add an output mode parameter to `flatMapGroupsWithState` and define `mapGroupsWithState` as `flatMapGroupsWithState(Update)`.
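The relationship between the two methods can be sketched in plain Scala without Spark. This is a toy model under stated assumptions: `GroupsWithStateSketch`, `FuncOutputMode` and `Node` are hypothetical names, and the per-group state argument is omitted for brevity. It only illustrates that the map variant wraps its single result in an `Iterator` and fixes the function's output mode to `Update`.

```scala
// Toy model only: these types are hypothetical stand-ins, not Spark classes.
object GroupsWithStateSketch {
  sealed trait FuncOutputMode
  case object Append extends FuncOutputMode
  case object Update extends FuncOutputMode

  // Records what a call would build: the function's output mode, whether the
  // node came from mapGroupsWithState, and the flat-map style function.
  final case class Node[K, V, U](
      outputMode: FuncOutputMode,
      isMapGroupsWithState: Boolean,
      func: (K, Iterator[V]) => Iterator[U])

  // The flatMap variant takes the output mode of the function explicitly.
  def flatMapGroupsWithState[K, V, U](
      func: (K, Iterator[V]) => Iterator[U],
      outputMode: FuncOutputMode): Node[K, V, U] =
    Node(outputMode = outputMode, isMapGroupsWithState = false, func = func)

  // The map variant wraps the single result in an Iterator and always uses Update.
  def mapGroupsWithState[K, V, U](func: (K, Iterator[V]) => U): Node[K, V, U] =
    Node(
      outputMode = Update,
      isMapGroupsWithState = true,
      func = (k: K, vs: Iterator[V]) => Iterator(func(k, vs)))
}
```

In this sketch, a node built by `mapGroupsWithState` is always tagged `Update`, which mirrors the assertion the patch adds to the logical operator.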
`UnsupportedOperationChecker` is modified to disallow unsupported cases.
- Batch mapGroupsWithState or flatMapGroupsWithState is always allowed.
- For streaming (map/flatMap)GroupsWithState, see the following table:
| Operators | Supported Query Output Mode |
| ------------- | ------------- |
| flatMapGroupsWithState(Update) without aggregation | Update |
| flatMapGroupsWithState(Update) with aggregation | None |
| flatMapGroupsWithState(Append) without aggregation | Append |
| flatMapGroupsWithState(Append) before aggregation | Append, Update, Complete |
| flatMapGroupsWithState(Append) after aggregation | None |
| Multiple flatMapGroupsWithState(Append)s | Append |
| Multiple mapGroupsWithStates | None |
| Mixing mapGroupsWithStates and flatMapGroupsWithStates | None |
| Other cases of multiple flatMapGroupsWithState | None |
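The table above can be read as a small lookup function. The following is a minimal sketch in plain Scala, not the code in `UnsupportedOperationChecker`; `SupportMatrixSketch`, `Aggregation` and the function names are hypothetical and exist only for illustration.

```scala
// Illustrative encoding of the support table; all names here are hypothetical.
object SupportMatrixSketch {
  sealed trait Mode
  case object Append extends Mode
  case object Update extends Mode
  case object Complete extends Mode

  // Where the stateful operator sits relative to a streaming aggregation.
  sealed trait Aggregation
  case object NoAggregation extends Aggregation
  case object BeforeAggregation extends Aggregation // the operator feeds an aggregation
  case object AfterAggregation extends Aggregation  // the operator consumes an aggregation

  // Query output modes accepted for a single streaming flatMapGroupsWithState(funcMode).
  def supportedQueryModes(funcMode: Mode, agg: Aggregation): Set[Mode] =
    (funcMode, agg) match {
      case (Update, NoAggregation)     => Set(Update)
      case (Append, NoAggregation)     => Set(Append)
      case (Append, BeforeAggregation) => Set(Append, Update, Complete)
      // Update with any aggregation, Append after aggregation, or an invalid function mode.
      case _                           => Set.empty
    }

  // Several flatMapGroupsWithState operators in one streaming query are accepted only
  // when every function is in Append mode and the query output mode is Append.
  def multipleSupported(funcModes: Seq[Mode], queryMode: Mode): Boolean =
    funcModes.forall(_ == Append) && queryMode == Append
}
```

For example, `supportedQueryModes(Append, BeforeAggregation)` yields all three query output modes, while `supportedQueryModes(Update, BeforeAggregation)` yields the empty set, matching the "None" rows.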
## How was this patch tested?
The added unit tests. Here are the tests related to (map/flatMap)GroupsWithState:
```
[info] - batch plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation: supported (1 millisecond)
[info] - batch plan - flatMapGroupsWithState - multiple flatMapGroupsWithState(Append)s on batch relation: supported (0 milliseconds)
[info] - batch plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation: supported (0 milliseconds)
[info] - batch plan - flatMapGroupsWithState - multiple flatMapGroupsWithState(Update)s on batch relation: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in update mode: supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in append mode: not supported (7 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in complete mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Append mode: not supported (11 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Update mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Complete mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation without aggregation in append mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation without aggregation in update mode: not supported (6 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Append mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Update mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Complete mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation after aggregation in Append mode: not supported (6 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation after aggregation in Update mode: not supported (4 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation in complete mode: not supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation inside streaming relation in Append output mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation inside streaming relation in Update output mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation inside streaming relation in Append output mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation inside streaming relation in Update output mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are in append mode: supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some are not in append mode: not supported (7 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation without aggregation in append mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation without aggregation in complete mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Append mode: not supported (6 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Update mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Complete mode: not supported (4 milliseconds)
[info] - streaming plan - mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are in append mode: not supported (4 milliseconds)
[info] - streaming plan - mapGroupsWithState - mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation: not supported (4 milliseconds)
```
Author: Shixiong Zhu <shixiong@databricks.com>
Closes #17197 from zsxwing/mapgroups-check.
13 files changed, 485 insertions, 107 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 397f5cfe2a..a9ff61e0e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -51,6 +51,37 @@ object UnsupportedOperationChecker {
       subplan.collect { case a: Aggregate if a.isStreaming => a }
     }

+    val mapGroupsWithStates = plan.collect {
+      case f: FlatMapGroupsWithState if f.isStreaming && f.isMapGroupsWithState => f
+    }
+
+    // Disallow multiple `mapGroupsWithState`s.
+    if (mapGroupsWithStates.size >= 2) {
+      throwError(
+        "Multiple mapGroupsWithStates are not supported on a streaming DataFrames/Datasets")(plan)
+    }
+
+    val flatMapGroupsWithStates = plan.collect {
+      case f: FlatMapGroupsWithState if f.isStreaming && !f.isMapGroupsWithState => f
+    }
+
+    // Disallow mixing `mapGroupsWithState`s and `flatMapGroupsWithState`s
+    if (mapGroupsWithStates.nonEmpty && flatMapGroupsWithStates.nonEmpty) {
+      throwError(
+        "Mixing mapGroupsWithStates and flatMapGroupsWithStates are not supported on a " +
+          "streaming DataFrames/Datasets")(plan)
+    }
+
+    // Only allow multiple `FlatMapGroupsWithState(Append)`s in append mode.
+    if (flatMapGroupsWithStates.size >= 2 && (
+      outputMode != InternalOutputModes.Append ||
+        flatMapGroupsWithStates.exists(_.outputMode != InternalOutputModes.Append)
+      )) {
+      throwError(
+        "Multiple flatMapGroupsWithStates are not supported when they are not all in append mode" +
+          " or the output mode is not append on a streaming DataFrames/Datasets")(plan)
+    }
+
     // Disallow multiple streaming aggregations
     val aggregates = collectStreamingAggregates(plan)

@@ -116,9 +147,49 @@ object UnsupportedOperationChecker {
           throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " +
             "streaming DataFrames/Datasets")

-        case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty =>
-          throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " +
-            "streaming DataFrame/Dataset")
+        // mapGroupsWithState: Allowed only when no aggregation + Update output mode
+        case m: FlatMapGroupsWithState if m.isStreaming && m.isMapGroupsWithState =>
+          if (collectStreamingAggregates(plan).isEmpty) {
+            if (outputMode != InternalOutputModes.Update) {
+              throwError("mapGroupsWithState is not supported with " +
+                s"$outputMode output mode on a streaming DataFrame/Dataset")
+            } else {
+              // Allowed when no aggregation + Update output mode
+            }
+          } else {
+            throwError("mapGroupsWithState is not supported with aggregation " +
+              "on a streaming DataFrame/Dataset")
+          }
+
+        // flatMapGroupsWithState without aggregation
+        case m: FlatMapGroupsWithState
+          if m.isStreaming && collectStreamingAggregates(plan).isEmpty =>
+          m.outputMode match {
+            case InternalOutputModes.Update =>
+              if (outputMode != InternalOutputModes.Update) {
+                throwError("flatMapGroupsWithState in update mode is not supported with " +
+                  s"$outputMode output mode on a streaming DataFrame/Dataset")
+              }
+            case InternalOutputModes.Append =>
+              if (outputMode != InternalOutputModes.Append) {
+                throwError("flatMapGroupsWithState in append mode is not supported with " +
+                  s"$outputMode output mode on a streaming DataFrame/Dataset")
+              }
+          }
+
+        // flatMapGroupsWithState(Update) with aggregation
+        case m: FlatMapGroupsWithState
+          if m.isStreaming && m.outputMode == InternalOutputModes.Update
+            && collectStreamingAggregates(plan).nonEmpty =>
+          throwError("flatMapGroupsWithState in update mode is not supported with " +
+            "aggregation on a streaming DataFrame/Dataset")
+
+        // flatMapGroupsWithState(Append) with aggregation
+        case m: FlatMapGroupsWithState
+          if m.isStreaming && m.outputMode == InternalOutputModes.Append
+            && collectStreamingAggregates(m).nonEmpty =>
+          throwError("flatMapGroupsWithState in append mode is not supported after " +
+            s"aggregation on a streaming DataFrame/Dataset")

         case d: Deduplicate if collectStreamingAggregates(d).nonEmpty =>
           throwError("dropDuplicates is not supported after aggregation on a " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 0be4823bbc..617239f56c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types._

 object CatalystSerde {
@@ -317,13 +318,15 @@ case class MapGroups(
 trait LogicalKeyedState[S]

 /** Factory for constructing new `MapGroupsWithState` nodes. */
-object MapGroupsWithState {
+object FlatMapGroupsWithState {
   def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
       func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
       groupingAttributes: Seq[Attribute],
       dataAttributes: Seq[Attribute],
+      outputMode: OutputMode,
+      isMapGroupsWithState: Boolean,
       child: LogicalPlan): LogicalPlan = {
-    val mapped = new MapGroupsWithState(
+    val mapped = new FlatMapGroupsWithState(
       func,
       UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
       UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
@@ -332,7 +335,9 @@ object MapGroupsWithState {
       CatalystSerde.generateObjAttr[U],
       encoderFor[S].resolveAndBind().deserializer,
       encoderFor[S].namedExpressions,
-      child)
+      outputMode,
+      child,
+      isMapGroupsWithState)
     CatalystSerde.serialize[U](mapped)
   }
 }
@@ -350,8 +355,10 @@ object MapGroupsWithState {
  * @param outputObjAttr used to define the output object
  * @param stateDeserializer used to deserialize state before calling `func`
  * @param stateSerializer used to serialize updated state after calling `func`
+ * @param outputMode the output mode of `func`
+ * @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method
  */
-case class MapGroupsWithState(
+case class FlatMapGroupsWithState(
     func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
     keyDeserializer: Expression,
     valueDeserializer: Expression,
@@ -360,7 +367,14 @@ case class MapGroupsWithState(
     outputObjAttr: Attribute,
     stateDeserializer: Expression,
     stateSerializer: Seq[NamedExpression],
-    child: LogicalPlan) extends UnaryNode with ObjectProducer
+    outputMode: OutputMode,
+    child: LogicalPlan,
+    isMapGroupsWithState: Boolean = false) extends UnaryNode with ObjectProducer {
+
+  if (isMapGroupsWithState) {
+    assert(outputMode == OutputMode.Update)
+  }
+}

 /** Factory for constructing new `FlatMapGroupsInR` nodes. */
 object FlatMapGroupsInR {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala
index 351bd6fff4..bdf2baf736 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala
@@ -44,4 +44,19 @@ private[sql] object InternalOutputModes {
    * aggregations, it will be equivalent to `Append` mode.
    */
   case object Update extends OutputMode
+
+
+  def apply(outputMode: String): OutputMode = {
+    outputMode.toLowerCase match {
+      case "append" =>
+        OutputMode.Append
+      case "complete" =>
+        OutputMode.Complete
+      case "update" =>
+        OutputMode.Update
+      case _ =>
+        throw new IllegalArgumentException(s"Unknown output mode $outputMode. " +
+          "Accepted output modes are 'append', 'complete', 'update'")
+    }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index 82be69a0f7..200c39f43a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _}
+import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder}
@@ -138,29 +138,202 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
     outputMode = Complete,
     expectedMsgs = Seq("distinct aggregation"))

-  // MapGroupsWithState: Not supported after a streaming aggregation
   val att = new AttributeReference(name = "a", dataType = LongType)()
-  assertSupportedInBatchPlan(
-    "mapGroupsWithState - mapGroupsWithState on batch relation",
-    MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation))
+  // FlatMapGroupsWithState: Both function modes equivalent and supported in batch.
+  for (funcMode <- Seq(Append, Update)) {
+    assertSupportedInBatchPlan(
+      s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation",
+      FlatMapGroupsWithState(
+        null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation))
+
+    assertSupportedInBatchPlan(
+      s"flatMapGroupsWithState - multiple flatMapGroupsWithState($funcMode)s on batch relation",
+      FlatMapGroupsWithState(
+        null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode,
+        FlatMapGroupsWithState(
+          null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation)))
+  }
+
+  // FlatMapGroupsWithState(Update) in streaming without aggregation
+  assertSupportedInStreamingPlan(
+    "flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
+      "on streaming relation without aggregation in update mode",
+    FlatMapGroupsWithState(
+      null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation),
+    outputMode = Update)
+
+  assertNotSupportedInStreamingPlan(
+    "flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
+      "on streaming relation without aggregation in append mode",
+    FlatMapGroupsWithState(
+      null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation),
+    outputMode = Append,
+    expectedMsgs = Seq("flatMapGroupsWithState in update mode", "Append"))
+
+  assertNotSupportedInStreamingPlan(
+    "flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
+      "on streaming relation without aggregation in complete mode",
+    FlatMapGroupsWithState(
+      null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation),
+    outputMode = Complete,
+    // Disallowed by the aggregation check but let's still keep this test in case it's broken in
+    // future.
+    expectedMsgs = Seq("Complete"))
+
+  // FlatMapGroupsWithState(Update) in streaming with aggregation
+  for (outputMode <- Seq(Append, Update, Complete)) {
+    assertNotSupportedInStreamingPlan(
+      "flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation " +
+        s"with aggregation in $outputMode mode",
+      FlatMapGroupsWithState(
+        null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+        Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
+      outputMode = outputMode,
+      expectedMsgs = Seq("flatMapGroupsWithState in update mode", "with aggregation"))
+  }

+  // FlatMapGroupsWithState(Append) in streaming without aggregation
   assertSupportedInStreamingPlan(
-    "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation",
-    MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation),
+    "flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
+      "on streaming relation without aggregation in append mode",
+    FlatMapGroupsWithState(
+      null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation),
     outputMode = Append)

   assertNotSupportedInStreamingPlan(
-    "mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation",
-    MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att),
-      Aggregate(Nil, aggExprs("c"), streamRelation)),
+    "flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
+      "on streaming relation without aggregation in update mode",
+    FlatMapGroupsWithState(
+      null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation),
+    outputMode = Update,
+    expectedMsgs = Seq("flatMapGroupsWithState in append mode", "update"))
+
+  // FlatMapGroupsWithState(Append) in streaming with aggregation
+  for (outputMode <- Seq(Append, Update, Complete)) {
+    assertSupportedInStreamingPlan(
+      "flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
+        s"on streaming relation before aggregation in $outputMode mode",
+      Aggregate(
+        Seq(attributeWithWatermark),
+        aggExprs("c"),
+        FlatMapGroupsWithState(
+          null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)),
+      outputMode = outputMode)
+  }
+
+  for (outputMode <- Seq(Append, Update)) {
+    assertNotSupportedInStreamingPlan(
+      "flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
+        s"on streaming relation after aggregation in $outputMode mode",
+      FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append,
+        Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
+      outputMode = outputMode,
+      expectedMsgs = Seq("flatMapGroupsWithState", "after aggregation"))
+  }
+
+  assertNotSupportedInStreamingPlan(
+    "flatMapGroupsWithState - " +
+      "flatMapGroupsWithState(Update) on streaming relation in complete mode",
+    FlatMapGroupsWithState(
+      null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation),
     outputMode = Complete,
-    expectedMsgs = Seq("(map/flatMap)GroupsWithState"))
+    // Disallowed by the aggregation check but let's still keep this test in case it's broken in
+    // future.
+    expectedMsgs = Seq("Complete"))

+  // FlatMapGroupsWithState inside batch relation should always be allowed
+  for (funcMode <- Seq(Append, Update)) {
+    for (outputMode <- Seq(Append, Update)) { // Complete is not supported without aggregation
+      assertSupportedInStreamingPlan(
+        s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation inside " +
+          s"streaming relation in $outputMode output mode",
+        FlatMapGroupsWithState(
+          null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation),
+        outputMode = outputMode
+      )
+    }
+  }
+
+  // multiple FlatMapGroupsWithStates
   assertSupportedInStreamingPlan(
-    "mapGroupsWithState - mapGroupsWithState on batch relation inside streaming relation",
-    MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation),
-    outputMode = Append
-  )
+    "flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are " +
+      "in append mode",
+    FlatMapGroupsWithState(
+      null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append,
+      FlatMapGroupsWithState(
+        null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)),
+    outputMode = Append)
+
+  assertNotSupportedInStreamingPlan(
+    "flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some" +
+      " are not in append mode",
+    FlatMapGroupsWithState(
+      null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+      FlatMapGroupsWithState(
+        null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)),
+    outputMode = Append,
+    expectedMsgs = Seq("multiple flatMapGroupsWithState", "append"))
+
+  // mapGroupsWithState
+  assertNotSupportedInStreamingPlan(
+    "mapGroupsWithState - mapGroupsWithState " +
+      "on streaming relation without aggregation in append mode",
+    FlatMapGroupsWithState(
+      null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
+      isMapGroupsWithState = true),
+    outputMode = Append,
+    // Disallowed by the aggregation check but let's still keep this test in case it's broken in
+    // future.
+    expectedMsgs = Seq("mapGroupsWithState", "append"))
+
+  assertNotSupportedInStreamingPlan(
+    "mapGroupsWithState - mapGroupsWithState " +
+      "on streaming relation without aggregation in complete mode",
+    FlatMapGroupsWithState(
+      null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
+      isMapGroupsWithState = true),
+    outputMode = Complete,
+    // Disallowed by the aggregation check but let's still keep this test in case it's broken in
+    // future.
+    expectedMsgs = Seq("Complete"))
+
+  for (outputMode <- Seq(Append, Update, Complete)) {
+    assertNotSupportedInStreamingPlan(
+      "mapGroupsWithState - mapGroupsWithState on streaming relation " +
+        s"with aggregation in $outputMode mode",
+      FlatMapGroupsWithState(
+        null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+        Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation),
+        isMapGroupsWithState = true),
+      outputMode = outputMode,
+      expectedMsgs = Seq("mapGroupsWithState", "with aggregation"))
+  }
+
+  // multiple mapGroupsWithStates
+  assertNotSupportedInStreamingPlan(
+    "mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are " +
+      "in append mode",
+    FlatMapGroupsWithState(
+      null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+      FlatMapGroupsWithState(
+        null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
+        isMapGroupsWithState = true),
+      isMapGroupsWithState = true),
+    outputMode = Append,
+    expectedMsgs = Seq("multiple mapGroupsWithStates"))
+
+  // mixing mapGroupsWithStates and flatMapGroupsWithStates
+  assertNotSupportedInStreamingPlan(
+    "mapGroupsWithState - " +
+      "mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation",
+    FlatMapGroupsWithState(
+      null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+      FlatMapGroupsWithState(
+        null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
+        isMapGroupsWithState = false),
+      isMapGroupsWithState = true),
+    outputMode = Append,
+    expectedMsgs = Seq("Mixing mapGroupsWithStates and flatMapGroupsWithStates"))

   // Deduplicate
   assertSupportedInStreamingPlan(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala
new file mode 100644
index 0000000000..201dac35ed
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.streaming
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.streaming.OutputMode
+
+class InternalOutputModesSuite extends SparkFunSuite {
+
+  test("supported strings") {
+    def testMode(outputMode: String, expected: OutputMode): Unit = {
+      assert(InternalOutputModes(outputMode) === expected)
+    }
+
+    testMode("append", OutputMode.Append)
+    testMode("Append", OutputMode.Append)
+    testMode("complete", OutputMode.Complete)
+    testMode("Complete", OutputMode.Complete)
+    testMode("update", OutputMode.Update)
+    testMode("Update", OutputMode.Update)
+  }
+
+  test("unsupported strings") {
+    def testMode(outputMode: String): Unit = {
+      val acceptedModes = Seq("append", "update", "complete")
+      val e = intercept[IllegalArgumentException](InternalOutputModes(outputMode))
+      (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s =>
+        assert(e.getMessage.toLowerCase.contains(s.toLowerCase))
+      }
+    }
+    testMode("Xyz")
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 3a548c251f..ab956ffd64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -24,8 +24,10 @@ import org.apache.spark.api.java.function._
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.expressions.ReduceAggregator
+import org.apache.spark.sql.streaming.OutputMode

 /**
  * :: Experimental ::
@@ -238,8 +240,16 @@ class KeyValueGroupedDataset[K, V] private[sql](
   @InterfaceStability.Evolving
   def mapGroupsWithState[S: Encoder, U: Encoder](
       func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
-    flatMapGroupsWithState[S, U](
-      (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)))
+    val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))
+    Dataset[U](
+      sparkSession,
+      FlatMapGroupsWithState[K, V, S, U](
+        flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
+        groupingAttributes,
+        dataAttributes,
+        OutputMode.Update,
+        isMapGroupsWithState = true,
+        child = logicalPlan))
   }

   /**
@@ -267,8 +277,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
       func: MapGroupsWithStateFunction[K, V, S, U],
       stateEncoder: Encoder[S],
       outputEncoder: Encoder[U]): Dataset[U] = {
-    flatMapGroupsWithState[S, U](
-      (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s))
+    mapGroupsWithState[S, U](
+      (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s)
     )(stateEncoder, outputEncoder)
   }

@@ -284,6 +294,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
    *
    * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
    * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func Function to be called on every group.
+   * @param outputMode The output mode of the function.
    *
    * See [[Encoder]] for more details on what types are encodable to Spark SQL.
    * @since 2.1.1
@@ -291,14 +303,44 @@ class KeyValueGroupedDataset[K, V] private[sql](
   @Experimental
   @InterfaceStability.Evolving
   def flatMapGroupsWithState[S: Encoder, U: Encoder](
-      func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = {
+      func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: OutputMode): Dataset[U] = {
+    if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) {
+      throw new IllegalArgumentException("The output mode of function should be append or update")
+    }
     Dataset[U](
       sparkSession,
-      MapGroupsWithState[K, V, S, U](
+      FlatMapGroupsWithState[K, V, S, U](
         func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
         groupingAttributes,
         dataAttributes,
-        logicalPlan))
+        outputMode,
+        isMapGroupsWithState = false,
+        child = logicalPlan))
+  }
+
+  /**
+   * ::Experimental::
+   * (Scala-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[KeyedState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func Function to be called on every group.
+   * @param outputMode The output mode of the function.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.1.1
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def flatMapGroupsWithState[S: Encoder, U: Encoder](
+      func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: String): Dataset[U] = {
+    flatMapGroupsWithState(func, InternalOutputModes(outputMode))
   }

   /**
@@ -314,6 +356,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
    * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
    * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
    * @param func Function to be called on every group.
+   * @param outputMode The output mode of the function.
    * @param stateEncoder Encoder for the state type.
    * @param outputEncoder Encoder for the output type.
    *
@@ -324,14 +367,46 @@ class KeyValueGroupedDataset[K, V] private[sql](
   @InterfaceStability.Evolving
   def flatMapGroupsWithState[S, U](
       func: FlatMapGroupsWithStateFunction[K, V, S, U],
+      outputMode: OutputMode,
       stateEncoder: Encoder[S],
       outputEncoder: Encoder[U]): Dataset[U] = {
     flatMapGroupsWithState[S, U](
-      (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala
+      (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala,
+      outputMode
     )(stateEncoder, outputEncoder)
   }

   /**
+   * ::Experimental::
+   * (Java-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[KeyedState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func Function to be called on every group.
+   * @param outputMode The output mode of the function.
+   * @param stateEncoder Encoder for the state type.
+   * @param outputEncoder Encoder for the output type.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.1.1
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def flatMapGroupsWithState[S, U](
+      func: FlatMapGroupsWithStateFunction[K, V, S, U],
+      outputMode: String,
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U]): Dataset[U] = {
+    flatMapGroupsWithState(func, InternalOutputModes(outputMode), stateEncoder, outputEncoder)
+  }
+
+  /**
    * (Scala-specific)
    * Reduces the elements of each group of data using the specified binary function.
    * The given function must be commutative and associative or the result may be non-deterministic.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 20bf4925db..0f7aa3709c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -326,14 +326,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   }

   /**
-   * Strategy to convert MapGroupsWithState logical operator to physical operator
+   * Strategy to convert [[FlatMapGroupsWithState]] logical operator to physical operator
    * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
*/ object MapGroupsWithStateStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case MapGroupsWithState( - f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) => - val execPlan = MapGroupsWithStateExec( + case FlatMapGroupsWithState( + f, + keyDeser, + valueDeser, + groupAttr, + dataAttr, + outputAttr, + stateDeser, + stateSer, + outputMode, + child, + _) => + val execPlan = FlatMapGroupsWithStateExec( f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer, planLater(child)) execPlan :: Nil @@ -381,7 +391,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil - case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) => + case logical.FlatMapGroupsWithState( + f, key, value, grouping, data, output, _, _, _, child, _) => execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index ffdcd9b19d..610ce5e1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -103,11 +103,11 @@ class IncrementalExecution( child, Some(stateId), Some(currentEventTimeWatermark)) - case MapGroupsWithStateExec( + case FlatMapGroupsWithStateExec( f, kDeser, vDeser, group, data, output, None, stateDeser, 
stateSer, child) => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - MapGroupsWithStateExec( + FlatMapGroupsWithStateExec( f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index cbf656a204..c3075a3eac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -257,8 +257,8 @@ case class StateStoreSaveExec( } -/** Physical operator for executing streaming mapGroupsWithState. */ -case class MapGroupsWithStateExec( +/** Physical operator for executing streaming flatMapGroupsWithState. */ +case class FlatMapGroupsWithStateExec( func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], keyDeserializer: Expression, valueDeserializer: Expression, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 0f7a33723c..c8fda8cd83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.streaming import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, ForeachWriter} -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import 
org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink} @@ -69,17 +69,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * @since 2.0.0 */ def outputMode(outputMode: String): DataStreamWriter[T] = { - this.outputMode = outputMode.toLowerCase match { - case "append" => - OutputMode.Append - case "complete" => - OutputMode.Complete - case "update" => - OutputMode.Update - case _ => - throw new IllegalArgumentException(s"Unknown output mode $outputMode. " + - "Accepted output modes are 'append', 'complete', 'update'") - } + this.outputMode = InternalOutputModes(outputMode) this } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index e3b0e37cca..d06e35bb44 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,6 +23,7 @@ import java.sql.Date; import java.sql.Timestamp; import java.util.*; +import org.apache.spark.sql.streaming.OutputMode; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; @@ -205,6 +206,7 @@ public class JavaDatasetSuite implements Serializable { } return Collections.singletonList(sb.toString()).iterator(); }, + OutputMode.Append(), Encoders.LONG(), Encoders.STRING()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 6cf4d51f99..902b842e97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore /** Class to check custom state types */ case class RunningCount(count: 
Long) -class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -119,9 +119,9 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str) + .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str) - testStream(result, Append)( + testStream(result, Update)( AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), @@ -162,9 +162,9 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str) + .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str) - testStream(result, Append)( + testStream(result, Update)( AddData(inputData, "a", "a", "b"), CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), StopStream, @@ -185,7 +185,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA Iterator((key, values.size)) } checkAnswer( - Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF, + Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc, Update).toDF, Seq(("a", 2), ("b", 1)).toDF) } @@ -210,7 +210,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA .groupByKey(x => x) .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) - testStream(result, Append)( + testStream(result, Update)( AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), @@ -230,7 +230,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA ) } - test("mapGroupsWithState - streaming + aggregation") { + 
test("flatMapGroupsWithState - streaming + aggregation") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { @@ -238,10 +238,10 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { state.remove() - (key, "-1") + Iterator(key -> "-1") } else { state.update(RunningCount(count)) - (key, count.toString) + Iterator(key -> count.toString) } } @@ -249,7 +249,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA val result = inputData.toDS() .groupByKey(x => x) - .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) + .flatMapGroupsWithState(stateFunc, Append) // Types = State: MyState, Out: (Str, Str) .groupByKey(_._1) .count() @@ -290,7 +290,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA testQuietly("StateStore.abort on task failure handling") { val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") + if (FlatMapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") val count = state.getOption.map(_.count).getOrElse(0L) + values.size state.update(RunningCount(count)) (key, count) @@ -303,11 +303,11 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q => - MapGroupsWithStateSuite.failInTask = value + FlatMapGroupsWithStateSuite.failInTask = value true } - testStream(result, Append)( + testStream(result, Update)( setFailInTask(false), AddData(inputData, 
"a"), CheckLastBatch(("a", 1L)), @@ -321,8 +321,24 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count ) } + + test("disallow complete mode") { + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + Iterator[String]() + } + + var e = intercept[IllegalArgumentException] { + MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, Complete) + } + assert(e.getMessage === "The output mode of function should be append or update") + + e = intercept[IllegalArgumentException] { + MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, "complete") + } + assert(e.getMessage === "The output mode of function should be append or update") + } } -object MapGroupsWithStateSuite { +object FlatMapGroupsWithStateSuite { var failInTask = true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 0470411a0f..f61dcdcbcf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -24,8 +24,7 @@ import scala.concurrent.duration._ import org.apache.hadoop.fs.Path import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.scalatest.PrivateMethodTester.PrivateMethod +import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ @@ -107,7 +106,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { } } -class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with PrivateMethodTester { +class DataStreamReaderWriterSuite extends StreamTest with 
BeforeAndAfter { private def newMetadataDir = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -390,42 +389,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with Pr private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath - test("supported strings in outputMode(string)") { - val outputModeMethod = PrivateMethod[OutputMode]('outputMode) - - def testMode(outputMode: String, expected: OutputMode): Unit = { - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - val w = df.writeStream - w.outputMode(outputMode) - val setOutputMode = w invokePrivate outputModeMethod() - assert(setOutputMode === expected) - } - - testMode("append", OutputMode.Append) - testMode("Append", OutputMode.Append) - testMode("complete", OutputMode.Complete) - testMode("Complete", OutputMode.Complete) - testMode("update", OutputMode.Update) - testMode("Update", OutputMode.Update) - } - - test("unsupported strings in outputMode(string)") { - def testMode(outputMode: String): Unit = { - val acceptedModes = Seq("append", "update", "complete") - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - val w = df.writeStream - val e = intercept[IllegalArgumentException](w.outputMode(outputMode)) - (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) - } - } - testMode("Xyz") - } - test("check foreach() catches null writers") { val df = spark.readStream .format("org.apache.spark.sql.streaming.test") |
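The function-level output mode check added to `flatMapGroupsWithState` in the diff above can be sketched in isolation as follows. This is a minimal standalone model, not Spark code: `Mode`, `Mode.parse`, and `FunctionModeCheck` are hypothetical stand-ins for `OutputMode`, `InternalOutputModes(...)`, and the guard inside `flatMapGroupsWithState`. The parsing logic is copied from the `match` that the diff removes from `DataStreamWriter.outputMode`; the sketch assumes `InternalOutputModes(...)` behaves the same way, which this section of the diff does not show.

```scala
import java.util.Locale

// Hypothetical stand-in for Spark's OutputMode; not the real class.
sealed trait Mode
object Mode {
  case object Append extends Mode
  case object Complete extends Mode
  case object Update extends Mode

  // Case-insensitive parsing, modeled on the match removed from
  // DataStreamWriter.outputMode in the diff (assumed to now live in InternalOutputModes).
  def parse(name: String): Mode = name.toLowerCase(Locale.ROOT) match {
    case "append" => Append
    case "complete" => Complete
    case "update" => Update
    case _ =>
      throw new IllegalArgumentException(
        s"Unknown output mode $name. Accepted output modes are 'append', 'complete', 'update'")
  }
}

// Hypothetical helper mirroring the guard added to flatMapGroupsWithState.
object FunctionModeCheck {
  // Same condition and message as in the diff: only Append and Update are accepted.
  def validate(mode: Mode): Mode = {
    if (mode != Mode.Append && mode != Mode.Update) {
      throw new IllegalArgumentException(
        "The output mode of function should be append or update")
    }
    mode
  }

  // String overload: parse first, then apply the same guard, as the String
  // overloads of flatMapGroupsWithState do by delegating to the OutputMode overloads.
  def validate(name: String): Mode = validate(Mode.parse(name))
}
```

Under these assumptions, `FunctionModeCheck.validate("Update")` yields `Mode.Update`, while `FunctionModeCheck.validate("complete")` throws `IllegalArgumentException`, which is the behavior the "disallow complete mode" test in the diff asserts for both the `OutputMode` and `String` overloads.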