author | Shixiong Zhu <shixiong@databricks.com> | 2017-03-08 13:18:07 -0800 |
---|---|---|
committer | Tathagata Das <tathagata.das1565@gmail.com> | 2017-03-08 13:18:07 -0800 |
commit | 1bf9012380de2aa7bdf39220b55748defde8b700 (patch) | |
tree | 3efc5be6f6506eef72a98c08132d95c0ce9f8fcd /sql/core/src | |
parent | e9e2c612d58a19ddcb4b6abfb7389a4b0f7ef6f8 (diff) | |
[SPARK-19858][SS] Add output mode to flatMapGroupsWithState and disallow invalid cases
## What changes were proposed in this pull request?
Add an output mode parameter to `flatMapGroupsWithState` and define `mapGroupsWithState` as `flatMapGroupsWithState(Update)`.
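The relationship between the two operators can be sketched without Spark. The block below is a toy model over an in-memory `Seq`, not the code in this patch: `GroupsWithStateSketch`, `State`, and the local `Mode` objects are hypothetical stand-ins for `KeyValueGroupedDataset`, `KeyedState`, and `OutputMode`. It only illustrates the idea that the map variant wraps its single result in an `Iterator` and fixes the function output mode to Update, while the flatMap variant rejects anything other than Append or Update.

```scala
// Toy model, NOT Spark code: every name in this object is hypothetical.
object GroupsWithStateSketch {
  sealed trait Mode
  case object Append extends Mode
  case object Update extends Mode
  case object Complete extends Mode

  // Minimal stand-in for per-group state (plays the role of KeyedState).
  final class State[S] {
    private var value: Option[S] = None
    def getOption: Option[S] = value
    def update(s: S): Unit = { value = Some(s) }
    def remove(): Unit = { value = None }
  }

  // flatMap variant: the function may emit zero or more rows per group,
  // and only Append or Update are accepted as the function's output mode.
  def flatMapGroupsWithState[K, V, S, U](data: Seq[(K, V)], mode: Mode)(
      func: (K, Iterator[V], State[S]) => Iterator[U]): Seq[U] = {
    require(mode == Append || mode == Update,
      "The output mode of function should be append or update")
    data.groupBy(_._1).toSeq.flatMap { case (key, rows) =>
      // Each group gets a fresh state here; a real stream would persist it.
      func(key, rows.iterator.map(_._2), new State[S])
    }
  }

  // map variant: exactly one output row per group, always in Update mode.
  def mapGroupsWithState[K, V, S, U](data: Seq[(K, V)])(
      func: (K, Iterator[V], State[S]) => U): Seq[U] =
    flatMapGroupsWithState[K, V, S, U](data, Update) { (k, it, s) =>
      Iterator(func(k, it, s))
    }
}
```

In the diff itself, `mapGroupsWithState` builds the `FlatMapGroupsWithState` logical node directly with `OutputMode.Update` and `isMapGroupsWithState = true` rather than calling the public method, so the checker can still tell the two operators apart.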
`UnsupportedOperationChecker` is modified to disallow unsupported cases.
- Batch mapGroupsWithState or flatMapGroupsWithState is always allowed.
- For streaming (map/flatMap)GroupsWithState, see the following table:
| Operators | Supported Query Output Mode |
| ------------- | ------------- |
| flatMapGroupsWithState(Update) without aggregation | Update |
| flatMapGroupsWithState(Update) with aggregation | None |
| flatMapGroupsWithState(Append) without aggregation | Append |
| flatMapGroupsWithState(Append) before aggregation | Append, Update, Complete |
| flatMapGroupsWithState(Append) after aggregation | None |
| Multiple flatMapGroupsWithState(Append)s | Append |
| Multiple mapGroupsWithStates | None |
| Mixing mapGroupsWithStates and flatMapGroupsWithStates | None |
| Other cases of multiple flatMapGroupsWithState | None |
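The table above can be read as a lookup from (function output mode, position relative to aggregation) to the set of allowed query output modes. The following is a minimal sketch of that reading in plain Scala. It is not the `UnsupportedOperationChecker` implementation; `SupportTableSketch`, `Placement`, and both function names are hypothetical, and `mapGroupsWithState` is modeled as a function in Update mode, as the description states.

```scala
// Hypothetical model of the support table; not the checker's actual code.
object SupportTableSketch {
  sealed trait Mode
  case object Append extends Mode
  case object Update extends Mode
  case object Complete extends Mode

  // Where a single streaming flatMapGroupsWithState sits relative to an aggregation.
  sealed trait Placement
  case object NoAggregation extends Placement
  case object BeforeAggregation extends Placement
  case object AfterAggregation extends Placement

  // Query output modes allowed for ONE streaming flatMapGroupsWithState(funcMode).
  def supportedQueryModes(funcMode: Mode, placement: Placement): Set[Mode] =
    (funcMode, placement) match {
      case (Update, NoAggregation)     => Set(Update)
      case (Append, NoAggregation)     => Set(Append)
      case (Append, BeforeAggregation) => Set(Append, Update, Complete)
      // Update with any aggregation, Append after aggregation,
      // and a Complete function mode are all rejected.
      case _                           => Set.empty
    }

  // Query output modes allowed when a streaming plan has SEVERAL such operators:
  // only the "all functions are in Append mode" case is supported.
  def supportedForMultiple(funcModes: Seq[Mode]): Set[Mode] =
    if (funcModes.size > 1 && funcModes.forall(_ == Append)) Set[Mode](Append)
    else Set.empty[Mode]
}
```

For example, `supportedQueryModes(Append, BeforeAggregation)` yields all three query modes, while `supportedForMultiple(Seq(Append, Update))` yields the empty set, matching the "Mixing" and "Other cases" rows.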
## How was this patch tested?
The added unit tests. Here are the tests related to (map/flatMap)GroupsWithState:
```
[info] - batch plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation: supported (1 millisecond)
[info] - batch plan - flatMapGroupsWithState - multiple flatMapGroupsWithState(Append)s on batch relation: supported (0 milliseconds)
[info] - batch plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation: supported (0 milliseconds)
[info] - batch plan - flatMapGroupsWithState - multiple flatMapGroupsWithState(Update)s on batch relation: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in update mode: supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in append mode: not supported (7 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in complete mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Append mode: not supported (11 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Update mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Complete mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation without aggregation in append mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation without aggregation in update mode: not supported (6 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Append mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Update mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Complete mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation after aggregation in Append mode: not supported (6 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation after aggregation in Update mode: not supported (4 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation in complete mode: not supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation inside streaming relation in Append output mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation inside streaming relation in Update output mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation inside streaming relation in Append output mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation inside streaming relation in Update output mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are in append mode: supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some are not in append mode: not supported (7 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation without aggregation in append mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation without aggregation in complete mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Append mode: not supported (6 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Update mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Complete mode: not supported (4 milliseconds)
[info] - streaming plan - mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are in append mode: not supported (4 milliseconds)
[info] - streaming plan - mapGroupsWithState - mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation: not supported (4 milliseconds)
```
Author: Shixiong Zhu <shixiong@databricks.com>
Closes #17197 from zsxwing/mapgroups-check.
Diffstat (limited to 'sql/core/src')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala | 91 |
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 21 |
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala | 4 |
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala | 4 |
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala | 16 |
-rw-r--r-- | sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 2 |
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala (renamed from sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala) | 46 |
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala | 41 |
8 files changed, 141 insertions, 84 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 3a548c251f..ab956ffd64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -24,8 +24,10 @@ import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator +import org.apache.spark.sql.streaming.OutputMode /** * :: Experimental :: @@ -238,8 +240,16 @@ class KeyValueGroupedDataset[K, V] private[sql]( @InterfaceStability.Evolving def mapGroupsWithState[S: Encoder, U: Encoder]( func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = { - flatMapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))) + val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)) + Dataset[U]( + sparkSession, + FlatMapGroupsWithState[K, V, S, U]( + flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + groupingAttributes, + dataAttributes, + OutputMode.Update, + isMapGroupsWithState = true, + child = logicalPlan)) } /** @@ -267,8 +277,8 @@ class KeyValueGroupedDataset[K, V] private[sql]( func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], outputEncoder: Encoder[U]): Dataset[U] = { - flatMapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s)) + mapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s) 
)(stateEncoder, outputEncoder) } @@ -284,6 +294,8 @@ class KeyValueGroupedDataset[K, V] private[sql]( * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param outputMode The output mode of the function. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. * @since 2.1.1 @@ -291,14 +303,44 @@ class KeyValueGroupedDataset[K, V] private[sql]( @Experimental @InterfaceStability.Evolving def flatMapGroupsWithState[S: Encoder, U: Encoder]( - func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = { + func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: OutputMode): Dataset[U] = { + if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { + throw new IllegalArgumentException("The output mode of function should be append or update") + } Dataset[U]( sparkSession, - MapGroupsWithState[K, V, S, U]( + FlatMapGroupsWithState[K, V, S, U]( func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], groupingAttributes, dataAttributes, - logicalPlan)) + outputMode, + isMapGroupsWithState = false, + child = logicalPlan)) + } + + /** + * ::Experimental:: + * (Scala-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. 
+ * @param func Function to be called on every group. + * @param outputMode The output mode of the function. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving + def flatMapGroupsWithState[S: Encoder, U: Encoder]( + func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: String): Dataset[U] = { + flatMapGroupsWithState(func, InternalOutputModes(outputMode)) } /** @@ -314,6 +356,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. * @param func Function to be called on every group. + * @param outputMode The output mode of the function. * @param stateEncoder Encoder for the state type. * @param outputEncoder Encoder for the output type. * @@ -324,14 +367,46 @@ class KeyValueGroupedDataset[K, V] private[sql]( @InterfaceStability.Evolving def flatMapGroupsWithState[S, U]( func: FlatMapGroupsWithStateFunction[K, V, S, U], + outputMode: OutputMode, stateEncoder: Encoder[S], outputEncoder: Encoder[U]): Dataset[U] = { flatMapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala + (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala, + outputMode )(stateEncoder, outputEncoder) } /** + * ::Experimental:: + * (Java-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. 
+ * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param outputMode The output mode of the function. + * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving + def flatMapGroupsWithState[S, U]( + func: FlatMapGroupsWithStateFunction[K, V, S, U], + outputMode: String, + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = { + flatMapGroupsWithState(func, InternalOutputModes(outputMode), stateEncoder, outputEncoder) + } + + /** * (Scala-specific) * Reduces the elements of each group of data using the specified binary function. * The given function must be commutative and associative or the result may be non-deterministic. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 20bf4925db..0f7aa3709c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -326,14 +326,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** - * Strategy to convert MapGroupsWithState logical operator to physical operator + * Strategy to convert [[FlatMapGroupsWithState]] logical operator to physical operator * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. 
*/ object MapGroupsWithStateStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case MapGroupsWithState( - f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) => - val execPlan = MapGroupsWithStateExec( + case FlatMapGroupsWithState( + f, + keyDeser, + valueDeser, + groupAttr, + dataAttr, + outputAttr, + stateDeser, + stateSer, + outputMode, + child, + _) => + val execPlan = FlatMapGroupsWithStateExec( f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer, planLater(child)) execPlan :: Nil @@ -381,7 +391,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil - case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) => + case logical.FlatMapGroupsWithState( + f, key, value, grouping, data, output, _, _, _, child, _) => execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index ffdcd9b19d..610ce5e1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -103,11 +103,11 @@ class IncrementalExecution( child, Some(stateId), Some(currentEventTimeWatermark)) - case MapGroupsWithStateExec( + case FlatMapGroupsWithStateExec( f, kDeser, vDeser, group, data, output, None, stateDeser, 
stateSer, child) => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - MapGroupsWithStateExec( + FlatMapGroupsWithStateExec( f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index cbf656a204..c3075a3eac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -257,8 +257,8 @@ case class StateStoreSaveExec( } -/** Physical operator for executing streaming mapGroupsWithState. */ -case class MapGroupsWithStateExec( +/** Physical operator for executing streaming flatMapGroupsWithState. */ +case class FlatMapGroupsWithStateExec( func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], keyDeserializer: Expression, valueDeserializer: Expression, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 0f7a33723c..c8fda8cd83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.streaming import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, ForeachWriter} -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import 
org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink} @@ -69,17 +69,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * @since 2.0.0 */ def outputMode(outputMode: String): DataStreamWriter[T] = { - this.outputMode = outputMode.toLowerCase match { - case "append" => - OutputMode.Append - case "complete" => - OutputMode.Complete - case "update" => - OutputMode.Update - case _ => - throw new IllegalArgumentException(s"Unknown output mode $outputMode. " + - "Accepted output modes are 'append', 'complete', 'update'") - } + this.outputMode = InternalOutputModes(outputMode) this } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index e3b0e37cca..d06e35bb44 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,6 +23,7 @@ import java.sql.Date; import java.sql.Timestamp; import java.util.*; +import org.apache.spark.sql.streaming.OutputMode; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; @@ -205,6 +206,7 @@ public class JavaDatasetSuite implements Serializable { } return Collections.singletonList(sb.toString()).iterator(); }, + OutputMode.Append(), Encoders.LONG(), Encoders.STRING()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 6cf4d51f99..902b842e97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore /** Class to check custom state types */ case class RunningCount(count: 
Long) -class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -119,9 +119,9 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str) + .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str) - testStream(result, Append)( + testStream(result, Update)( AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), @@ -162,9 +162,9 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str) + .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str) - testStream(result, Append)( + testStream(result, Update)( AddData(inputData, "a", "a", "b"), CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), StopStream, @@ -185,7 +185,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA Iterator((key, values.size)) } checkAnswer( - Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF, + Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc, Update).toDF, Seq(("a", 2), ("b", 1)).toDF) } @@ -210,7 +210,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA .groupByKey(x => x) .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) - testStream(result, Append)( + testStream(result, Update)( AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), @@ -230,7 +230,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA ) } - test("mapGroupsWithState - streaming + aggregation") { + 
test("flatMapGroupsWithState - streaming + aggregation") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { @@ -238,10 +238,10 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { state.remove() - (key, "-1") + Iterator(key -> "-1") } else { state.update(RunningCount(count)) - (key, count.toString) + Iterator(key -> count.toString) } } @@ -249,7 +249,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA val result = inputData.toDS() .groupByKey(x => x) - .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) + .flatMapGroupsWithState(stateFunc, Append) // Types = State: MyState, Out: (Str, Str) .groupByKey(_._1) .count() @@ -290,7 +290,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA testQuietly("StateStore.abort on task failure handling") { val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") + if (FlatMapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") val count = state.getOption.map(_.count).getOrElse(0L) + values.size state.update(RunningCount(count)) (key, count) @@ -303,11 +303,11 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q => - MapGroupsWithStateSuite.failInTask = value + FlatMapGroupsWithStateSuite.failInTask = value true } - testStream(result, Append)( + testStream(result, Update)( setFailInTask(false), AddData(inputData, 
"a"), CheckLastBatch(("a", 1L)), @@ -321,8 +321,24 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count ) } + + test("disallow complete mode") { + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + Iterator[String]() + } + + var e = intercept[IllegalArgumentException] { + MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, Complete) + } + assert(e.getMessage === "The output mode of function should be append or update") + + e = intercept[IllegalArgumentException] { + MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, "complete") + } + assert(e.getMessage === "The output mode of function should be append or update") + } } -object MapGroupsWithStateSuite { +object FlatMapGroupsWithStateSuite { var failInTask = true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 0470411a0f..f61dcdcbcf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -24,8 +24,7 @@ import scala.concurrent.duration._ import org.apache.hadoop.fs.Path import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.scalatest.PrivateMethodTester.PrivateMethod +import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ @@ -107,7 +106,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { } } -class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with PrivateMethodTester { +class DataStreamReaderWriterSuite extends StreamTest with 
BeforeAndAfter { private def newMetadataDir = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -390,42 +389,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with Pr private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath - test("supported strings in outputMode(string)") { - val outputModeMethod = PrivateMethod[OutputMode]('outputMode) - - def testMode(outputMode: String, expected: OutputMode): Unit = { - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - val w = df.writeStream - w.outputMode(outputMode) - val setOutputMode = w invokePrivate outputModeMethod() - assert(setOutputMode === expected) - } - - testMode("append", OutputMode.Append) - testMode("Append", OutputMode.Append) - testMode("complete", OutputMode.Complete) - testMode("Complete", OutputMode.Complete) - testMode("update", OutputMode.Update) - testMode("Update", OutputMode.Update) - } - - test("unsupported strings in outputMode(string)") { - def testMode(outputMode: String): Unit = { - val acceptedModes = Seq("append", "update", "complete") - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - val w = df.writeStream - val e = intercept[IllegalArgumentException](w.outputMode(outputMode)) - (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) - } - } - testMode("Xyz") - } - test("check foreach() catches null writers") { val df = spark.readStream .format("org.apache.spark.sql.streaming.test") |