From 1bf9012380de2aa7bdf39220b55748defde8b700 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu
Date: Wed, 8 Mar 2017 13:18:07 -0800
Subject: [SPARK-19858][SS] Add output mode to flatMapGroupsWithState and
 disallow invalid cases

## What changes were proposed in this pull request?

Add an output mode parameter to `flatMapGroupsWithState` and define `mapGroupsWithState` as `flatMapGroupsWithState(Update)`. `UnsupportedOperationChecker` is modified to disallow unsupported cases.

- Batch mapGroupsWithState or flatMapGroupsWithState is always allowed.
- For streaming (map/flatMap)GroupsWithState, see the following table:

| Operators | Supported Query Output Mode |
| ------------- | ------------- |
| flatMapGroupsWithState(Update) without aggregation | Update |
| flatMapGroupsWithState(Update) with aggregation | None |
| flatMapGroupsWithState(Append) without aggregation | Append |
| flatMapGroupsWithState(Append) before aggregation | Append, Update, Complete |
| flatMapGroupsWithState(Append) after aggregation | None |
| Multiple flatMapGroupsWithState(Append)s | Append |
| Multiple mapGroupsWithStates | None |
| Mixing mapGroupsWithStates and flatMapGroupsWithStates | None |
| Other cases of multiple flatMapGroupsWithState | None |

## How was this patch tested?

The added unit tests. Here are the tests related to (map/flatMap)GroupsWithState:

```
[info] - batch plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation: supported (1 millisecond)
[info] - batch plan - flatMapGroupsWithState - multiple flatMapGroupsWithState(Append)s on batch relation: supported (0 milliseconds)
[info] - batch plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation: supported (0 milliseconds)
[info] - batch plan - flatMapGroupsWithState - multiple flatMapGroupsWithState(Update)s on batch relation: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in update mode: supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in append mode: not supported (7 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in complete mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Append mode: not supported (11 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Update mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Complete mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation without aggregation in append mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation without aggregation in update mode: not supported (6 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Append mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Update mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Complete mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation after aggregation in Append mode: not supported (6 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation after aggregation in Update mode: not supported (4 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation in complete mode: not supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation inside streaming relation in Append output mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation inside streaming relation in Update output mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation inside streaming relation in Append output mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation inside streaming relation in Update output mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are in append mode: supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation but some are not in append mode: not supported (7 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation without aggregation in append mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation without aggregation in complete mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Append mode: not supported (6 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Update mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Complete mode: not supported (4 milliseconds)
[info] - streaming plan - mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are in append mode: not supported (4 milliseconds)
[info] - streaming plan - mapGroupsWithState - mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation: not supported (4 milliseconds)
```

Author: Shixiong Zhu

Closes #17197 from zsxwing/mapgroups-check.
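To illustrate the new API (a usage sketch, not code from this patch; `spark` is assumed to be an active `SparkSession` and `inputStream` a streaming `Dataset[String]`):

```scala
import org.apache.spark.sql.KeyedState
import org.apache.spark.sql.streaming.OutputMode
import spark.implicits._

case class RunningCount(count: Long)

// Maintain a per-key running count in state; emit (key, count) on every trigger.
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
  val count = state.getOption.map(_.count).getOrElse(0L) + values.size
  state.update(RunningCount(count))
  Iterator((key, count))
}

val counts = inputStream
  .groupByKey(x => x)
  .flatMapGroupsWithState(stateFunc, OutputMode.Update) // the new, mandatory output mode argument

// Per the table above, flatMapGroupsWithState(Update) without downstream
// aggregation is only allowed when the query itself runs in update mode.
counts.writeStream
  .outputMode("update")
  .format("console")
  .start()
```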
---
 .../apache/spark/sql/KeyValueGroupedDataset.scala  |  91 +++++-
 .../spark/sql/execution/SparkStrategies.scala      |  21 +-
 .../execution/streaming/IncrementalExecution.scala |   4 +-
 .../execution/streaming/statefulOperators.scala    |   4 +-
 .../spark/sql/streaming/DataStreamWriter.scala     |  16 +-
 .../org/apache/spark/sql/JavaDatasetSuite.java     |   2 +
 .../streaming/FlatMapGroupsWithStateSuite.scala    | 344 +++++++++++++++++++++
 .../sql/streaming/MapGroupsWithStateSuite.scala    | 328 --------------------
 .../test/DataStreamReaderWriterSuite.scala         |  41 +--
 9 files changed, 454 insertions(+), 397 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
 delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 3a548c251f..ab956ffd64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -24,8 +24,10 @@ import org.apache.spark.api.java.function._
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.expressions.ReduceAggregator
+import org.apache.spark.sql.streaming.OutputMode
 
 /**
  * :: Experimental ::
@@ -238,8 +240,16 @@ class KeyValueGroupedDataset[K, V] private[sql](
   @InterfaceStability.Evolving
   def mapGroupsWithState[S: Encoder, U: Encoder](
       func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
-    flatMapGroupsWithState[S, U](
-      (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)))
+    val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))
+    Dataset[U](
+      sparkSession,
+      FlatMapGroupsWithState[K, V, S, U](
+        flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
+        groupingAttributes,
+        dataAttributes,
+        OutputMode.Update,
+        isMapGroupsWithState = true,
+        child = logicalPlan))
   }
 
   /**
@@ -267,8 +277,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
       func: MapGroupsWithStateFunction[K, V, S, U],
       stateEncoder: Encoder[S],
       outputEncoder: Encoder[U]): Dataset[U] = {
-    flatMapGroupsWithState[S, U](
-      (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s))
+    mapGroupsWithState[S, U](
+      (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s)
     )(stateEncoder, outputEncoder)
   }
 
@@ -284,6 +294,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
    *
    * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
    * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func Function to be called on every group.
+   * @param outputMode The output mode of the function.
    *
    * See [[Encoder]] for more details on what types are encodable to Spark SQL.
    * @since 2.1.1
@@ -291,14 +303,44 @@ class KeyValueGroupedDataset[K, V] private[sql](
   @Experimental
   @InterfaceStability.Evolving
   def flatMapGroupsWithState[S: Encoder, U: Encoder](
-      func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = {
+      func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: OutputMode): Dataset[U] = {
+    if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) {
+      throw new IllegalArgumentException("The output mode of the function should be append or update")
+    }
     Dataset[U](
       sparkSession,
-      MapGroupsWithState[K, V, S, U](
+      FlatMapGroupsWithState[K, V, S, U](
         func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
         groupingAttributes,
         dataAttributes,
-        logicalPlan))
+        outputMode,
+        isMapGroupsWithState = false,
+        child = logicalPlan))
+  }
+
+  /**
+   * ::Experimental::
+   * (Scala-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[KeyedState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func Function to be called on every group.
+   * @param outputMode The output mode of the function.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.1.1
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def flatMapGroupsWithState[S: Encoder, U: Encoder](
+      func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: String): Dataset[U] = {
+    flatMapGroupsWithState(func, InternalOutputModes(outputMode))
   }
 
   /**
@@ -314,6 +356,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
    * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
    * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
    * @param func Function to be called on every group.
+   * @param outputMode The output mode of the function.
    * @param stateEncoder Encoder for the state type.
    * @param outputEncoder Encoder for the output type.
    *
@@ -324,13 +367,45 @@ class KeyValueGroupedDataset[K, V] private[sql](
   @InterfaceStability.Evolving
   def flatMapGroupsWithState[S, U](
       func: FlatMapGroupsWithStateFunction[K, V, S, U],
+      outputMode: OutputMode,
       stateEncoder: Encoder[S],
       outputEncoder: Encoder[U]): Dataset[U] = {
     flatMapGroupsWithState[S, U](
-      (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala
+      (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala,
+      outputMode
     )(stateEncoder, outputEncoder)
   }
 
+  /**
+   * ::Experimental::
+   * (Java-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[KeyedState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func Function to be called on every group.
+   * @param outputMode The output mode of the function.
+   * @param stateEncoder Encoder for the state type.
+   * @param outputEncoder Encoder for the output type.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.1.1
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def flatMapGroupsWithState[S, U](
+      func: FlatMapGroupsWithStateFunction[K, V, S, U],
+      outputMode: String,
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U]): Dataset[U] = {
+    flatMapGroupsWithState(func, InternalOutputModes(outputMode), stateEncoder, outputEncoder)
+  }
+
   /**
    * (Scala-specific)
    * Reduces the elements of each group of data using the specified binary function.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 20bf4925db..0f7aa3709c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -326,14 +326,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   }
 
   /**
-   * Strategy to convert MapGroupsWithState logical operator to physical operator
+   * Strategy to convert [[FlatMapGroupsWithState]] logical operator to physical operator
    * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
    */
   object MapGroupsWithStateStrategy extends Strategy {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case MapGroupsWithState(
-          f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) =>
-        val execPlan = MapGroupsWithStateExec(
+      case FlatMapGroupsWithState(
+          f,
+          keyDeser,
+          valueDeser,
+          groupAttr,
+          dataAttr,
+          outputAttr,
+          stateDeser,
+          stateSer,
+          outputMode,
+          child,
+          _) =>
+        val execPlan = FlatMapGroupsWithStateExec(
           f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser,
           stateSer, planLater(child))
         execPlan :: Nil
@@ -381,7 +391,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
       case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
         execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
-      case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) =>
+      case logical.FlatMapGroupsWithState(
+          f, key, value, grouping, data, output, _, _, _, child, _) =>
         execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil
       case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
         execution.CoGroupExec(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index ffdcd9b19d..610ce5e1eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -103,11 +103,11 @@ class IncrementalExecution(
           child,
           Some(stateId),
           Some(currentEventTimeWatermark))
-      case MapGroupsWithStateExec(
+      case FlatMapGroupsWithStateExec(
             f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) =>
        val stateId =
          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
-        MapGroupsWithStateExec(
+        FlatMapGroupsWithStateExec(
          f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index cbf656a204..c3075a3eac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -257,8 +257,8 @@ case class StateStoreSaveExec(
 }
 
 
-/** Physical operator for executing streaming mapGroupsWithState. */
-case class MapGroupsWithStateExec(
+/** Physical operator for executing streaming flatMapGroupsWithState. */
+case class FlatMapGroupsWithStateExec(
     func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
     keyDeserializer: Expression,
     valueDeserializer: Expression,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 0f7a33723c..c8fda8cd83 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.streaming
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.{Experimental, InterfaceStability}
-import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, ForeachWriter}
-import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink}
@@ -69,17 +69,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
    * @since 2.0.0
    */
   def outputMode(outputMode: String): DataStreamWriter[T] = {
-    this.outputMode = outputMode.toLowerCase match {
-      case "append" =>
-        OutputMode.Append
-      case "complete" =>
-        OutputMode.Complete
-      case "update" =>
-        OutputMode.Update
-      case _ =>
-        throw new IllegalArgumentException(s"Unknown output mode $outputMode. " +
-          "Accepted output modes are 'append', 'complete', 'update'")
-    }
+    this.outputMode = InternalOutputModes(outputMode)
     this
   }
 
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index e3b0e37cca..d06e35bb44 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -23,6 +23,7 @@ import java.sql.Date;
 import java.sql.Timestamp;
 import java.util.*;
 
+import org.apache.spark.sql.streaming.OutputMode;
 import scala.Tuple2;
 import scala.Tuple3;
 import scala.Tuple4;
@@ -205,6 +206,7 @@ public class JavaDatasetSuite implements Serializable {
         }
         return Collections.singletonList(sb.toString()).iterator();
       },
+      OutputMode.Append(),
       Encoders.LONG(),
       Encoders.STRING());
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
new file mode 100644
index 0000000000..902b842e97
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.KeyedState
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.StateStore
+
+/** Class to check custom state types */
+case class RunningCount(count: Long)
+
+class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
+
+  import testImplicits._
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    StateStore.stop()
+  }
+
+  test("KeyedState - get, exists, update, remove") {
+    var state: KeyedStateImpl[String] = null
+
+    def testState(
+        expectedData: Option[String],
+        shouldBeUpdated: Boolean = false,
+        shouldBeRemoved: Boolean = false): Unit = {
+      if (expectedData.isDefined) {
+        assert(state.exists)
+        assert(state.get === expectedData.get)
+      } else {
+        assert(!state.exists)
+        intercept[NoSuchElementException] {
+          state.get
+        }
+      }
+      assert(state.getOption === expectedData)
+      assert(state.isUpdated === shouldBeUpdated)
+      assert(state.isRemoved === shouldBeRemoved)
+    }
+
+    // Updating empty state
+    state = new KeyedStateImpl[String](None)
+    testState(None)
+    state.update("")
+    testState(Some(""), shouldBeUpdated = true)
+
+    // Updating existing state
+    state = new KeyedStateImpl[String](Some("2"))
+    testState(Some("2"))
+    state.update("3")
+    testState(Some("3"), shouldBeUpdated = true)
+
+    // Removing state
+    state.remove()
+    testState(None, shouldBeRemoved = true, shouldBeUpdated = false)
+    state.remove() // should still be callable
+    state.update("4")
+    testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true)
+
+    // Updating with null throws an exception
+    intercept[IllegalArgumentException] {
+      state.update(null)
+    }
+  }
+
+  test("KeyedState - primitive type") {
+    var intState = new KeyedStateImpl[Int](None)
+    intercept[NoSuchElementException] {
+      intState.get
+    }
+    assert(intState.getOption === None)
+
+    intState = new KeyedStateImpl[Int](Some(10))
+    assert(intState.get == 10)
+    intState.update(0)
+    assert(intState.get == 0)
+    intState.remove()
+    intercept[NoSuchElementException] {
+      intState.get
+    }
+  }
+
+  test("flatMapGroupsWithState - streaming") {
+    // Function to maintain running count up to 2, and then remove the count
+    // Returns the data and the count if state is defined, otherwise does not return anything
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      if (count == 3) {
+        state.remove()
+        Iterator.empty
+      } else {
+        state.update(RunningCount(count))
+        Iterator((key, count.toString))
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(stateFunc, Update) // State: RunningCount, Out: (Str, Str)
+
+    testStream(result, Update)(
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", "1")),
+      assertNumStateRows(total = 1, updated = 1),
+      AddData(inputData, "a", "b"),
+      CheckLastBatch(("a", "2"), ("b", "1")),
+      assertNumStateRows(total = 2, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
+      CheckLastBatch(("b", "2")),
+      assertNumStateRows(total = 1, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+      CheckLastBatch(("a", "1"), ("c", "1")),
+      assertNumStateRows(total = 3, updated = 2)
+    )
+  }
+
+  test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") {
+    // Function to maintain running count up to 2, and then remove the count
+    // Returns the data and the count if state is defined, otherwise does not return anything
+    // Additionally, it updates state lazily as the returned iterator gets consumed
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      values.flatMap { _ =>
+        val count = state.getOption.map(_.count).getOrElse(0L) + 1
+        if (count == 3) {
+          state.remove()
+          None
+        } else {
+          state.update(RunningCount(count))
+          Some((key, count.toString))
+        }
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(stateFunc, Update) // State: RunningCount, Out: (Str, Str)
+
+    testStream(result, Update)(
+      AddData(inputData, "a", "a", "b"),
+      CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
+      CheckLastBatch(("b", "2")),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+      CheckLastBatch(("a", "1"), ("c", "1"))
+    )
+  }
+
+  test("flatMapGroupsWithState - batch") {
+    // Function that returns the count of each group; state should never be defined in batch queries
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+      Iterator((key, values.size))
+    }
+    checkAnswer(
+      Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc, Update).toDF,
+      Seq(("a", 2), ("b", 1)).toDF)
+  }
+
+  test("mapGroupsWithState - streaming") {
+    // Function to maintain running count up to 2, and then remove the count
+    // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      if (count == 3) {
+        state.remove()
+        (key, "-1")
+      } else {
+        state.update(RunningCount(count))
+        (key, count.toString)
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc) // Types = State: RunningCount, Out: (Str, Str)
+
+    testStream(result, Update)(
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", "1")),
+      assertNumStateRows(total = 1, updated = 1),
+      AddData(inputData, "a", "b"),
+      CheckLastBatch(("a", "2"), ("b", "1")),
+      assertNumStateRows(total = 2, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1
+      CheckLastBatch(("a", "-1"), ("b", "2")),
+      assertNumStateRows(total = 1, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+      CheckLastBatch(("a", "1"), ("c", "1")),
+      assertNumStateRows(total = 3, updated = 2)
+    )
+  }
+
+  test("flatMapGroupsWithState - streaming + aggregation") {
+    // Function to maintain running count up to 2, and then remove the count
+    // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      if (count == 3) {
+        state.remove()
+        Iterator(key -> "-1")
+      } else {
+        state.update(RunningCount(count))
+        Iterator(key -> count.toString)
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(stateFunc, Append) // Types = State: RunningCount, Out: (Str, Str)
+        .groupByKey(_._1)
+        .count()
+
+    testStream(result, Complete)(
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", 1)),
+      AddData(inputData, "a", "b"),
+      // the state function generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1
+      CheckLastBatch(("a", 2), ("b", 1)),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"),
+      // the state function should remove state for "a" and generate ("a", "-1"), ("b", "2");
+      // so increment a and b by 1
+      CheckLastBatch(("a", 3), ("b", 2)),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"),
+      // the state function should recreate state for "a" and generate ("a", "1"), ("c", "1");
+      // so increment a and c by 1
+      CheckLastBatch(("a", 4), ("b", 2), ("c", 1))
+    )
+  }
+
+  test("mapGroupsWithState - batch") {
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+      (key, values.size)
+    }
+
+    checkAnswer(
+      spark.createDataset(Seq("a", "a", "b"))
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc)
+        .toDF,
+      spark.createDataset(Seq(("a", 2), ("b", 1))).toDF)
+  }
+
+  testQuietly("StateStore.abort on task failure handling") {
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      if (FlatMapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      state.update(RunningCount(count))
+      (key, count)
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc) // Types = State: RunningCount, Out: (Str, Str)
+
+    def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q =>
+      FlatMapGroupsWithStateSuite.failInTask = value
+      true
+    }
+
+    testStream(result, Update)(
+      setFailInTask(false),
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", 1L)),
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", 2L)),
+      setFailInTask(true),
+      AddData(inputData, "a"),
+      ExpectFailure[SparkException](), // task should fail but should not increment count
+      setFailInTask(false),
+      StartStream(),
+      CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count
+    )
+  }
+
+  test("disallow complete mode") {
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      Iterator[String]()
+    }
+
+    var e = intercept[IllegalArgumentException] {
+      MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, Complete)
+    }
+    assert(e.getMessage === "The output mode of the function should be append or update")
+
+    e = intercept[IllegalArgumentException] {
+      MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, "complete")
+    }
+    assert(e.getMessage === "The output mode of the function should be append or update")
+  }
+}
+
+object FlatMapGroupsWithStateSuite {
+  var failInTask = true
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
deleted file mode 100644
index 6cf4d51f99..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
+++ /dev/null
@@ -1,328 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.streaming
-
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark.SparkException
-import org.apache.spark.sql.KeyedState
-import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream}
-import org.apache.spark.sql.execution.streaming.state.StateStore
-
-/** Class to check custom state types */
-case class RunningCount(count: Long)
-
-class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
-
-  import testImplicits._
-
-  override def afterAll(): Unit = {
-    super.afterAll()
-    StateStore.stop()
-  }
-
-  test("KeyedState - get, exists, update, remove") {
-    var state: KeyedStateImpl[String] = null
-
-    def testState(
-        expectedData: Option[String],
-        shouldBeUpdated: Boolean = false,
-        shouldBeRemoved: Boolean = false): Unit = {
-      if (expectedData.isDefined) {
-        assert(state.exists)
-        assert(state.get === expectedData.get)
-      } else {
-        assert(!state.exists)
-        intercept[NoSuchElementException] {
-          state.get
-        }
-      }
-      assert(state.getOption === expectedData)
-      assert(state.isUpdated === shouldBeUpdated)
-      assert(state.isRemoved === shouldBeRemoved)
-    }
-
-    // Updating empty state
-    state = new KeyedStateImpl[String](None)
-    testState(None)
-    state.update("")
-    testState(Some(""), shouldBeUpdated = true)
-
-    // Updating exiting state
-    state = new KeyedStateImpl[String](Some("2"))
-    testState(Some("2"))
-    state.update("3")
-    testState(Some("3"), shouldBeUpdated = true)
-
-    // Removing state
-    state.remove()
-    testState(None, shouldBeRemoved = true, shouldBeUpdated = false)
-    state.remove() // should be still callable
-    state.update("4")
-    testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true)
-
-    // Updating by null throw exception
-    intercept[IllegalArgumentException] {
-      state.update(null)
-    }
-  }
-
-  test("KeyedState - primitive type") {
-    var intState = new KeyedStateImpl[Int](None)
-    intercept[NoSuchElementException] {
-      intState.get
-    }
-    assert(intState.getOption === None)
-
-    intState = new KeyedStateImpl[Int](Some(10))
-    assert(intState.get == 10)
-    intState.update(0)
-    assert(intState.get == 0)
-    intState.remove()
-    intercept[NoSuchElementException] {
-      intState.get
-    }
-  }
-
-  test("flatMapGroupsWithState - streaming") {
-    // Function to maintain running count up to 2, and then remove the count
-    // Returns the data and the count if state is defined, otherwise does not return anything
-    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-
-      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
-      if (count == 3) {
-        state.remove()
-        Iterator.empty
-      } else {
-        state.update(RunningCount(count))
-        Iterator((key, count.toString))
-      }
-    }
-
-    val inputData = MemoryStream[String]
-    val result =
-      inputData.toDS()
-        .groupByKey(x => x)
-        .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str)
-
-    testStream(result, Append)(
-      AddData(inputData, "a"),
-      CheckLastBatch(("a", "1")),
-      assertNumStateRows(total = 1, updated = 1),
-      AddData(inputData, "a", "b"),
-      CheckLastBatch(("a", "2"), ("b", "1")),
-      assertNumStateRows(total = 2, updated = 2),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
-      CheckLastBatch(("b", "2")),
-      assertNumStateRows(total = 1, updated = 2),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and
-      CheckLastBatch(("a", "1"), ("c", "1")),
-      assertNumStateRows(total = 3, updated = 2)
-    )
-  }
-
-  test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") {
-    // Function to maintain running count up to 2, and then remove the count
-    // Returns the data and the count if state is defined, otherwise does not return anything
-    // Additionally, it updates state lazily as the returned iterator get consumed
-    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-      values.flatMap { _ =>
-        val count = state.getOption.map(_.count).getOrElse(0L) + 1
-        if (count == 3) {
-          state.remove()
-          None
-        } else {
-          state.update(RunningCount(count))
-          Some((key, count.toString))
-        }
-      }
-    }
-
-    val inputData = MemoryStream[String]
-    val result =
-      inputData.toDS()
-        .groupByKey(x => x)
-        .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str)
-
-    testStream(result, Append)(
-      AddData(inputData, "a", "a", "b"),
-      CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
-      CheckLastBatch(("b", "2")),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and
-      CheckLastBatch(("a", "1"), ("c", "1"))
-    )
-  }
-
-  test("flatMapGroupsWithState - batch") {
-    // Function that returns running count only if its even, otherwise does not return
-    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-      if (state.exists) throw new IllegalArgumentException("state.exists should be false")
-      Iterator((key, values.size))
-    }
-    checkAnswer(
-      Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF,
-      Seq(("a", 2), ("b", 1)).toDF)
-  }
-
-  test("mapGroupsWithState - streaming") {
-    // Function to maintain running count up to 2, and then remove the count
-    // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
-    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-
-      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
-      if (count == 3) {
-        state.remove()
-        (key, "-1")
-      } else {
-        state.update(RunningCount(count))
-        (key, count.toString)
-      }
-    }
-
-    val inputData = MemoryStream[String]
-    val result =
-      inputData.toDS()
-        .groupByKey(x => x)
-        .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
-
-    testStream(result, Append)(
-      AddData(inputData, "a"),
-      CheckLastBatch(("a", "1")),
-      assertNumStateRows(total = 1, updated = 1),
-      AddData(inputData, "a", "b"),
-      CheckLastBatch(("a", "2"), ("b", "1")),
-      assertNumStateRows(total = 2, updated = 2),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1
-      CheckLastBatch(("a", "-1"), ("b", "2")),
-      assertNumStateRows(total = 1, updated = 2),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
-      CheckLastBatch(("a", "1"), ("c", "1")),
-      assertNumStateRows(total = 3, updated = 2)
-    )
-  }
-
-  test("mapGroupsWithState - streaming + aggregation") {
-    // Function to maintain running count up to 2, and then remove the count
-    // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
-    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-
-      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
-      if (count == 3) {
-        state.remove()
-        (key, "-1")
-      } else {
-        state.update(RunningCount(count))
-        (key, count.toString)
-      }
-    }
-
-    val inputData = MemoryStream[String]
-    val result =
-      inputData.toDS()
-        .groupByKey(x => x)
-        .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
-        .groupByKey(_._1)
-        .count()
-
-    testStream(result, Complete)(
-      AddData(inputData, "a"),
-      CheckLastBatch(("a", 1)),
-      AddData(inputData, "a", "b"),
-      // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1
-      CheckLastBatch(("a", 2), ("b", 1)),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "b"),
-      // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ;
-      // so increment a and b by 1
-      CheckLastBatch(("a", 3), ("b", 2)),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "c"),
-      // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ;
-      // so increment a and c by 1
-      CheckLastBatch(("a", 4), ("b", 2), ("c", 1))
-    )
-  }
-
-  test("mapGroupsWithState - batch") {
-    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-      if (state.exists) throw new IllegalArgumentException("state.exists should be false")
-      (key, values.size)
-    }
-
-    checkAnswer(
-      spark.createDataset(Seq("a", "a", "b"))
-        .groupByKey(x => x)
-        .mapGroupsWithState(stateFunc)
-        .toDF,
-      spark.createDataset(Seq(("a", 2), ("b", 1))).toDF)
-  }
-
-  testQuietly("StateStore.abort on task failure handling") {
-    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-      if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
-      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
-      state.update(RunningCount(count))
-      (key, count)
-    }
-
-    val inputData = MemoryStream[String]
-    val result =
-      inputData.toDS()
-        .groupByKey(x => x)
-        .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
-
-    def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q =>
-      MapGroupsWithStateSuite.failInTask = value
-      true
-    }
-
-    testStream(result, Append)(
-      setFailInTask(false),
-      AddData(inputData, "a"),
-      CheckLastBatch(("a", 1L)),
-      AddData(inputData, "a"),
-      CheckLastBatch(("a", 2L)),
-      setFailInTask(true),
-      AddData(inputData, "a"),
-      ExpectFailure[SparkException](), // task should fail but should not increment count
-      setFailInTask(false),
-      StartStream(),
-      CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count
-    )
-  }
-}
-
-object MapGroupsWithStateSuite {
-  var failInTask = true
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
index 0470411a0f..f61dcdcbcf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
@@ -24,8 +24,7 @@ import scala.concurrent.duration._
 
 import org.apache.hadoop.fs.Path
 import org.mockito.Mockito._
-import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
-import org.scalatest.PrivateMethodTester.PrivateMethod
+import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.execution.streaming._
@@ -107,7 +106,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
   }
 }
 
-class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with PrivateMethodTester {
+class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
 
   private def newMetadataDir =
     Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -390,42 +389,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with Pr
 
   private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath
 
-  test("supported strings in outputMode(string)") {
-    val outputModeMethod = PrivateMethod[OutputMode]('outputMode)
-
-    def testMode(outputMode: String, expected: OutputMode): Unit = {
-      val df = spark.readStream
-        .format("org.apache.spark.sql.streaming.test")
-        .load()
-      val w = df.writeStream
-      w.outputMode(outputMode)
-      val setOutputMode = w invokePrivate outputModeMethod()
-      assert(setOutputMode === expected)
-    }
-
-    testMode("append", OutputMode.Append)
-    testMode("Append", OutputMode.Append)
-    testMode("complete", OutputMode.Complete)
-    testMode("Complete", OutputMode.Complete)
-    testMode("update", OutputMode.Update)
-    testMode("Update", OutputMode.Update)
-  }
-
-  test("unsupported strings in outputMode(string)") {
-    def testMode(outputMode: String): Unit = {
-      val acceptedModes = Seq("append", "update", "complete")
-      val df = spark.readStream
-        .format("org.apache.spark.sql.streaming.test")
-        .load()
-      val w = df.writeStream
-      val e = intercept[IllegalArgumentException](w.outputMode(outputMode))
-      (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s =>
-        assert(e.getMessage.toLowerCase.contains(s.toLowerCase))
-      }
-    }
-    testMode("Xyz")
-  }
-
   test("check foreach() catches null writers") {
     val df = spark.readStream
       .format("org.apache.spark.sql.streaming.test")
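Note on `InternalOutputModes`: the string parsing removed from `DataStreamWriter.outputMode` above now goes through `org.apache.spark.sql.catalyst.streaming.InternalOutputModes`, whose definition lives in catalyst and is therefore not visible in this sql/core-limited diff. A sketch of what that helper plausibly looks like, reconstructed from the match statement deleted from `DataStreamWriter` (the catalyst-side code itself is an assumption):

```scala
import org.apache.spark.sql.streaming.OutputMode

// Assumed shape of the catalyst-side helper; the parsing logic mirrors the
// match statement removed from DataStreamWriter.outputMode in this patch.
object InternalOutputModes {
  case object Append extends OutputMode   // emit only new rows
  case object Complete extends OutputMode // emit the full result every trigger
  case object Update extends OutputMode   // emit rows updated since the last trigger

  def apply(outputMode: String): OutputMode = outputMode.toLowerCase match {
    case "append" => Append
    case "complete" => Complete
    case "update" => Update
    case _ =>
      throw new IllegalArgumentException(s"Unknown output mode $outputMode. " +
        "Accepted output modes are 'append', 'complete', 'update'")
  }
}
```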