From aeb80348dd40c66b84bbc5cfe60d716fbce25acb Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Tue, 7 Feb 2017 20:21:00 -0800
Subject: [SPARK-19413][SS] MapGroupsWithState for arbitrary stateful operations

## What changes were proposed in this pull request?

`mapGroupsWithState` is a new API for arbitrary stateful operations in Structured Streaming, similar to `DStream.mapWithState`.

*Requirements*
- Users should be able to specify a function that can do the following
  - Access the input rows corresponding to a key
  - Access the previous state corresponding to a key
  - Optionally, update or remove the state
  - Output any number of new rows (or none at all)

*Proposed API*
```
// ------------ New methods on KeyValueGroupedDataset ------------
class KeyValueGroupedDataset[K, V] {
  // Scala friendly
  def mapGroupsWithState[S: Encoder, U: Encoder](func: (K, Iterator[V], KeyedState[S]) => U)
  def flatMapGroupsWithState[S: Encoder, U: Encoder](func: (K, Iterator[V], KeyedState[S]) => Iterator[U])

  // Java friendly
  def mapGroupsWithState[S, U](func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], resultEncoder: Encoder[U])
  def flatMapGroupsWithState[S, U](func: FlatMapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], resultEncoder: Encoder[U])
}

// ------------------- New Java-friendly function classes -------------------
public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
  R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
}

public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
  Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
}

// ---------------------- Wrapper class for state data ----------------------
trait KeyedState[S] {
  def exists(): Boolean
  def get(): S                  // throws NoSuchElementException if state does not exist
  def getOption(): Option[S]
  def update(newState: S): Unit
  def remove(): Unit            // exists() will be false after this
}
```

Key semantics of the KeyedState class
- The state cannot be null; updating the state with null throws IllegalArgumentException.
- If `state.remove()` is called, then `state.exists()` will return false, and `state.getOption()` will return None.
- If `state.update(newState)` is called after that, then `state.exists()` will return true, and `state.getOption()` will return Some(...).
- None of the operations are thread-safe. This is to avoid memory barriers.

*Usage*
```
val stateFunc = (word: String, words: Iterator[String], runningCount: KeyedState[Long]) => {
  val newCount = words.size + runningCount.getOption.getOrElse(0L)
  runningCount.update(newCount)
  (word, newCount)
}

dataset                                                 // type is Dataset[String]
  .groupByKey[String](w => w)                           // generates KeyValueGroupedDataset[String, String]
  .mapGroupsWithState[Long, (String, Long)](stateFunc)  // returns Dataset[(String, Long)]
```

## How was this patch tested?

New unit tests.

Author: Tathagata Das

Closes #16758 from tdas/mapWithState.
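As an illustration of the proposed API (not part of this patch), the sketch below shows how `flatMapGroupsWithState` and `KeyedState.remove()` might be used together on a streaming Dataset. The socket source, the count threshold of 100, the "active"/"expired" labels, and the object name are all hypothetical; only the `KeyValueGroupedDataset` and `KeyedState` methods come from the API above.

```
// Hedged sketch, assuming the API proposed in this patch: keep a running count per word
// and drop the per-key state once a word has been seen at least 100 times.
import org.apache.spark.sql.{Dataset, KeyedState, SparkSession}

object FlatMapGroupsWithStateSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("sketch").master("local[2]").getOrCreate()
    import spark.implicits._

    // A streaming Dataset[String] of words (hypothetical source).
    val words: Dataset[String] = spark.readStream
      .format("socket").option("host", "localhost").option("port", 9999)
      .load().as[String]

    // The function returns an Iterator, so each group may emit zero or more output rows.
    val stateFunc = (word: String, values: Iterator[String], count: KeyedState[Long]) => {
      val newCount = count.getOption.getOrElse(0L) + values.size
      if (newCount >= 100L) {
        count.remove()            // exists() will be false from the next trigger onwards
        Iterator((word, newCount, "expired"))
      } else {
        count.update(newCount)    // state is saved for the next trigger
        Iterator((word, newCount, "active"))
      }
    }

    val counts = words
      .groupByKey((w: String) => w)                                     // KeyValueGroupedDataset[String, String]
      .flatMapGroupsWithState[Long, (String, Long, String)](stateFunc)  // Dataset[(String, Long, String)]

    counts.writeStream.outputMode("append").format("console").start().awaitTermination()
  }
}
```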
--- .../analysis/UnsupportedOperationChecker.scala | 11 +- .../spark/sql/catalyst/plans/logical/object.scala | 49 +++ .../analysis/UnsupportedOperationsSuite.scala | 24 +- .../function/FlatMapGroupsWithStateFunction.java | 38 +++ .../java/function/MapGroupsWithStateFunction.java | 38 +++ .../apache/spark/sql/KeyValueGroupedDataset.scala | 113 +++++++ .../scala/org/apache/spark/sql/KeyedState.scala | 142 +++++++++ .../spark/sql/execution/SparkStrategies.scala | 21 +- .../org/apache/spark/sql/execution/objects.scala | 22 ++ .../execution/streaming/IncrementalExecution.scala | 19 +- .../sql/execution/streaming/KeyedStateImpl.scala | 80 +++++ .../sql/execution/streaming/ProgressReporter.scala | 2 +- .../execution/streaming/StatefulAggregate.scala | 237 --------------- .../state/HDFSBackedStateStoreProvider.scala | 19 ++ .../sql/execution/streaming/state/StateStore.scala | 5 + .../sql/execution/streaming/state/package.scala | 11 +- .../execution/streaming/statefulOperators.scala | 323 ++++++++++++++++++++ .../org/apache/spark/sql/JavaDatasetSuite.java | 32 ++ .../sql/streaming/MapGroupsWithStateSuite.scala | 335 +++++++++++++++++++++ 19 files changed, 1272 insertions(+), 249 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java create mode 100644 sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index e4fd737b35..07b3558ee2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -46,8 +46,13 @@ object UnsupportedOperationChecker { "Queries without streaming sources cannot be executed with writeStream.start()")(plan) } + /** Collect all the streaming aggregates in a sub plan */ + def collectStreamingAggregates(subplan: LogicalPlan): Seq[Aggregate] = { + subplan.collect { case a: Aggregate if a.isStreaming => a } + } + // Disallow multiple streaming aggregations - val aggregates = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a } + val aggregates = collectStreamingAggregates(plan) if (aggregates.size > 1) { throwError( @@ -111,6 +116,10 @@ object UnsupportedOperationChecker { throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " + "streaming DataFrames/Datasets") + case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty => + throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " + + "streaming DataFrame/Dataset") + case Join(left, right, joinType, _) => joinType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 0ab4c90166..0be4823bbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -313,6 +313,55 @@ case class MapGroups( outputObjAttr: Attribute, child: LogicalPlan) extends UnaryNode with ObjectProducer +/** Internal class representing State */ +trait LogicalKeyedState[S] + +/** Factory for constructing new `MapGroupsWithState` nodes. */ +object MapGroupsWithState { + def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( + func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + child: LogicalPlan): LogicalPlan = { + val mapped = new MapGroupsWithState( + func, + UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), + UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes), + groupingAttributes, + dataAttributes, + CatalystSerde.generateObjAttr[U], + encoderFor[S].resolveAndBind().deserializer, + encoderFor[S].namedExpressions, + child) + CatalystSerde.serialize[U](mapped) + } +} + +/** + * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`, + * while using state data. + * Func is invoked with an object representation of the grouping key and an iterator containing the + * object representation of all the rows with that key. + * + * @param keyDeserializer used to extract the key object for each group. + * @param valueDeserializer used to extract the items in the iterator from an input row. + * @param groupingAttributes used to group the data + * @param dataAttributes used to read the data + * @param outputObjAttr used to define the output object + * @param stateDeserializer used to deserialize state before calling `func` + * @param stateSerializer used to serialize updated state after calling `func` + */ +case class MapGroupsWithState( + func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + stateDeserializer: Expression, + stateSerializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectProducer + /** Factory for constructing new `FlatMapGroupsInR` nodes.
*/ object FlatMapGroupsInR { def apply( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index dcdb1ae089..3b756e89d9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -22,13 +22,13 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{IntegerType, LongType} /** A dummy command for testing unsupported operations. */ case class DummyCommand() extends Command @@ -111,6 +111,24 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Complete, expectedMsgs = Seq("distinct aggregation")) + // MapGroupsWithState: Not supported after a streaming aggregation + val att = new AttributeReference(name = "a", dataType = LongType)() + assertSupportedInBatchPlan( + "mapGroupsWithState - mapGroupsWithState on batch relation", + MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation)) + + assertSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation", + MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation), + outputMode = Append) + + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation", + MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), + Aggregate(Nil, aggExprs("c"), streamRelation)), + outputMode = Complete, + expectedMsgs = Seq("(map/flatMap)GroupsWithState")) + // Inner joins: Stream-stream not supported testBinaryOperationInStreamingPlan( "inner join", diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java new file mode 100644 index 0000000000..2570c8d02a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.KeyedState; + +/** + * ::Experimental:: + * Base interface for a map function used in + * {@link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState(FlatMapGroupsWithStateFunction, Encoder, Encoder)}. + * @since 2.1.1 + */ +@Experimental +@InterfaceStability.Evolving +public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable { + Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java new file mode 100644 index 0000000000..614d3925e0 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.KeyedState; + +/** + * ::Experimental:: + * Base interface for a map function used in + * {@link org.apache.spark.sql.KeyValueGroupedDataset#mapGroupsWithState(MapGroupsWithStateFunction, Encoder, Encoder)} + * @since 2.1.1 + */ +@Experimental +@InterfaceStability.Evolving +public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable { + R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception; +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 395d709f26..94e689a4d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -218,6 +218,119 @@ class KeyValueGroupedDataset[K, V] private[sql]( mapGroups((key, data) => f.call(key, data.asJava))(encoder) } + /** + * ::Experimental:: + * (Scala-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving + def mapGroupsWithState[S: Encoder, U: Encoder]( + func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = { + flatMapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))) + } + + /** + * ::Experimental:: + * (Java-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = { + flatMapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s)) + )(stateEncoder, outputEncoder) + } + + /** + * ::Experimental:: + * (Scala-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving + def flatMapGroupsWithState[S: Encoder, U: Encoder]( + func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = { + Dataset[U]( + sparkSession, + MapGroupsWithState[K, V, S, U]( + func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + groupingAttributes, + dataAttributes, + logicalPlan)) + } + + /** + * ::Experimental:: + * (Java-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving + def flatMapGroupsWithState[S, U]( + func: FlatMapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = { + flatMapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala + )(stateEncoder, outputEncoder) + } + /** * (Scala-specific) * Reduces the elements of each group of data using the specified binary function. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala new file mode 100644 index 0000000000..6864b6f6b4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.lang.IllegalArgumentException + +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState + +/** + * :: Experimental :: + * + * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and + * `flatMapGroupsWithState` operations on + * [[KeyValueGroupedDataset]]. + * + * Detailed description of the `[map/flatMap]GroupsWithState` operation + * ------------------------------------------------------------ + * Both `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]] + * will invoke the user-given function on each group (defined by the grouping function in + * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger. + * That is, in every batch of the [[streaming.StreamingQuery StreamingQuery]], + * the function will be invoked once for each group that has data in the batch. + * + * The function is invoked with the following parameters. + * - The key of the group. + * - An iterator containing all the values for this key. + * - A user-defined state object set by previous invocations of the given function. + * In case of a batch Dataset, there is only one invocation and the state object will be empty as + * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` + * is equivalent to `[map/flatMap]Groups`. + * + * Important points to note about the function. + * - In a trigger, the function will be called only for the groups present in the batch. So do not + * assume that the function will be called in every trigger for every group that has state. + * - There is no guaranteed ordering of values in the iterator in the function, neither with + * batch nor with streaming Datasets. + * - All the data will be shuffled before applying the function. + * + * Important points to note about using KeyedState. + * - The value of the state cannot be null. So updating state with null will throw + * `IllegalArgumentException`. + * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers. + * - If `remove()` is called, then `exists()` will return `false`, + * `get()` will throw `NoSuchElementException` and `getOption()` will return `None`. + * - After that, if `update(newState)` is called, then `exists()` will again return `true`, + * and `get()` and `getOption()` will return the updated value. + * + * Scala example of using KeyedState in `mapGroupsWithState`: + * {{{ + * /* A mapping function that maintains an integer state for string keys and returns a string.
*/ + * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = { + * // Check if state exists + * if (state.exists) { + * val existingState = state.get // Get the existing state + * val shouldRemove = ... // Decide whether to remove the state + * if (shouldRemove) { + * state.remove() // Remove the state + * } else { + * val newState = ... + * state.update(newState) // Set the new state + * } + * } else { + * val initialState = ... + * state.update(initialState) // Set the initial state + * } + * ... // return something + * } + * + * }}} + * + * Java example of using `KeyedState`: + * {{{ + * /* A mapping function that maintains an integer state for string keys and returns a string. */ + * MapGroupsWithStateFunction mappingFunction = + * new MapGroupsWithStateFunction() { + * + * @Override + * public String call(String key, Iterator value, KeyedState state) { + * if (state.exists()) { + * int existingState = state.get(); // Get the existing state + * boolean shouldRemove = ...; // Decide whether to remove the state + * if (shouldRemove) { + * state.remove(); // Remove the state + * } else { + * int newState = ...; + * state.update(newState); // Set the new state + * } + * } else { + * int initialState = ...; // Set the initial state + * state.update(initialState); + * } + * ... // return something + * } + * }; + * }}} + * + * @tparam S User-defined type of the state to be stored for each key. Must be encodable into + * Spark SQL types (see [[Encoder]] for more details). + * @since 2.1.1 + */ +@Experimental +@InterfaceStability.Evolving +trait KeyedState[S] extends LogicalKeyedState[S] { + + /** Whether state exists or not. */ + def exists: Boolean + + /** Get the state value if it exists, or throw NoSuchElementException. */ + @throws[NoSuchElementException]("when state does not exist") + def get: S + + /** Get the state value as a scala Option. */ + def getOption: Option[S] + + /** + * Update the value of the state. Note that `null` is not a valid value, and it throws + * IllegalArgumentException. + */ + @throws[IllegalArgumentException]("when updating with null") + def update(newState: S): Unit + + /** Remove this keyed state. */ + def remove(): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e3ec343479..557181ebd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} @@ -313,6 +313,23 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Strategy to convert MapGroupsWithState logical operator to physical operator + * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. 
+ */ + object MapGroupsWithStateStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case MapGroupsWithState( + f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) => + val execPlan = MapGroupsWithStateExec( + f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer, + planLater(child)) + execPlan :: Nil + case _ => + Nil + } + } + // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { def numPartitions: Int = self.numPartitions @@ -354,6 +371,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil + case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) => + execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index fde3b2a528..199ba5ce69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -30,6 +30,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState +import org.apache.spark.sql.execution.streaming.KeyedStateImpl import org.apache.spark.sql.types.{DataType, ObjectType, StructType} @@ -144,6 +146,11 @@ object ObjectOperator { (i: InternalRow) => proj(i).get(0, deserializer.dataType) } + def deserializeRowToObject(deserializer: Expression): InternalRow => Any = { + val proj = GenerateSafeProjection.generate(deserializer :: Nil) + (i: InternalRow) => proj(i).get(0, deserializer.dataType) + } + def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = { val proj = GenerateUnsafeProjection.generate(serializer) val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head @@ -344,6 +351,21 @@ case class MapGroupsExec( } } +object MapGroupsExec { + def apply( + func: (Any, Iterator[Any], LogicalKeyedState[Any]) => TraversableOnce[Any], + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + child: SparkPlan): MapGroupsExec = { + val f = (key: Any, values: Iterator[Any]) => func(key, values, new KeyedStateImpl[Any](None)) + new MapGroupsExec(f, keyDeserializer, valueDeserializer, + groupingAttributes, dataAttributes, outputObjAttr, child) + } +} + /** * Groups the input rows together and calls the R function with each group and an iterator * containing all elements in the group. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index bd7cec3917..a3e108b29e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.util.concurrent.atomic.AtomicInteger + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal} import org.apache.spark.sql.SparkSession @@ -39,8 +41,9 @@ class IncrementalExecution( extends QueryExecution(sparkSession, logicalPlan) with Logging { // TODO: make this always part of planning. - val stateStrategy = + val streamingExtraStrategies = sparkSession.sessionState.planner.StatefulAggregationStrategy +: + sparkSession.sessionState.planner.MapGroupsWithStateStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: sparkSession.sessionState.experimentalMethods.extraStrategies @@ -49,7 +52,7 @@ class IncrementalExecution( new SparkPlanner( sparkSession.sparkContext, sparkSession.sessionState.conf, - stateStrategy) + streamingExtraStrategies) /** * See [SPARK-18339] @@ -68,7 +71,7 @@ class IncrementalExecution( * Records the current id for a given stateful operator in the query plan as the `state` * preparation walks the query plan. */ - private var operatorId = 0 + private val operatorId = new AtomicInteger(0) /** Locates save/restore pairs surrounding aggregation. */ val state = new Rule[SparkPlan] { @@ -77,8 +80,8 @@ class IncrementalExecution( case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => - val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId) - operatorId += 1 + val stateId = + OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) StateStoreSaveExec( keys, @@ -90,6 +93,12 @@ class IncrementalExecution( keys, Some(stateId), child) :: Nil)) + case MapGroupsWithStateExec( + f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) => + val stateId = + OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) + MapGroupsWithStateExec( + f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala new file mode 100644 index 0000000000..eee7ec45dd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.KeyedState + +/** Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. */ +private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] { + private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) + private var defined: Boolean = optionalValue.isDefined + private var updated: Boolean = false + // whether value has been updated (but not removed) + private var removed: Boolean = false // whether value has been removed + + // ========= Public API ========= + override def exists: Boolean = defined + + override def get: S = { + if (defined) { + value + } else { + throw new NoSuchElementException("State is either not defined or has already been removed") + } + } + + override def getOption: Option[S] = { + if (defined) { + Some(value) + } else { + None + } + } + + override def update(newValue: S): Unit = { + if (newValue == null) { + throw new IllegalArgumentException("'null' is not a valid state value") + } + value = newValue + defined = true + updated = true + removed = false + } + + override def remove(): Unit = { + defined = false + updated = false + removed = true + } + + override def toString: String = { + s"KeyedState(${getOption.map(_.toString).getOrElse("")})" + } + + // ========= Internal API ========= + + /** Whether the state has been marked for removing */ + def isRemoved: Boolean = { + removed + } + + /** Whether the state has been been updated */ + def isUpdated: Boolean = { + updated + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 1f74fffbe6..693933f95a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -186,7 +186,7 @@ trait ProgressReporter extends Logging { // lastExecution could belong to one of the previous triggers if `!hasNewData`. // Walking the plan again should be inexpensive. val stateNodes = lastExecution.executedPlan.collect { - case p if p.isInstanceOf[StateStoreSaveExec] => p + case p if p.isInstanceOf[StateStoreWriter] => p } stateNodes.map { node => val numRowsUpdated = if (hasNewData) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala deleted file mode 100644 index d4ccced9ac..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ /dev/null @@ -1,237 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark -import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.StructType -import org.apache.spark.TaskContext - - -/** Used to identify the state store for a given operator. */ -case class OperatorStateId( - checkpointLocation: String, - operatorId: Long, - batchId: Long) - -/** - * An operator that saves or restores state from the [[StateStore]]. The [[OperatorStateId]] should - * be filled in by `prepareForExecution` in [[IncrementalExecution]]. - */ -trait StatefulOperator extends SparkPlan { - def stateId: Option[OperatorStateId] - - protected def getStateId: OperatorStateId = attachTree(this) { - stateId.getOrElse { - throw new IllegalStateException("State location not present for execution") - } - } -} - -/** - * For each input tuple, the key is calculated and the value from the [[StateStore]] is added - * to the stream (in addition to the input tuple) if present. - */ -case class StateStoreRestoreExec( - keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId], - child: SparkPlan) - extends execution.UnaryExecNode with StatefulOperator { - - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - - override protected def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeVersion = getStateId.batchId, - keyExpressions.toStructType, - child.output.toStructType, - sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - iter.flatMap { row => - val key = getKey(row) - val savedState = store.get(key) - numOutputRows += 1 - row +: savedState.toSeq - } - } - } - - override def output: Seq[Attribute] = child.output - - override def outputPartitioning: Partitioning = child.outputPartitioning -} - -/** - * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]]. 
- */ -case class StateStoreSaveExec( - keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId] = None, - outputMode: Option[OutputMode] = None, - eventTimeWatermark: Option[Long] = None, - child: SparkPlan) - extends execution.UnaryExecNode with StatefulOperator { - - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), - "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) - - /** Generate a predicate that matches data older than the watermark */ - private lazy val watermarkPredicate: Option[Predicate] = { - val optionalWatermarkAttribute = - keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) - - optionalWatermarkAttribute.map { watermarkAttribute => - // If we are evicting based on a window, use the end of the window. Otherwise just - // use the attribute itself. - val evictionExpression = - if (watermarkAttribute.dataType.isInstanceOf[StructType]) { - LessThanOrEqual( - GetStructField(watermarkAttribute, 1), - Literal(eventTimeWatermark.get * 1000)) - } else { - LessThanOrEqual( - watermarkAttribute, - Literal(eventTimeWatermark.get * 1000)) - } - - logInfo(s"Filtering state store on: $evictionExpression") - newPredicate(evictionExpression, keyExpressions) - } - } - - override protected def doExecute(): RDD[InternalRow] = { - metrics // force lazy init at driver - assert(outputMode.nonEmpty, - "Incorrect planning in IncrementalExecution, outputMode has not been set") - - child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeVersion = getStateId.batchId, - keyExpressions.toStructType, - child.output.toStructType, - sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - val numOutputRows = longMetric("numOutputRows") - val numTotalStateRows = longMetric("numTotalStateRows") - val numUpdatedStateRows = longMetric("numUpdatedStateRows") - - // Abort the state store in case of error - TaskContext.get().addTaskCompletionListener(_ => { - if (!store.hasCommitted) { - store.abort() - } - }) - - outputMode match { - // Update and output all rows in the StateStore. - case Some(Complete) => - while (iter.hasNext) { - val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numUpdatedStateRows += 1 - } - store.commit() - numTotalStateRows += store.numKeys() - store.iterator().map { case (k, v) => - numOutputRows += 1 - v.asInstanceOf[InternalRow] - } - - // Update and output only rows being evicted from the StateStore - case Some(Append) => - while (iter.hasNext) { - val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numUpdatedStateRows += 1 - } - - // Assumption: Append mode can be done only when watermark has been specified - store.remove(watermarkPredicate.get.eval) - store.commit() - - numTotalStateRows += store.numKeys() - store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed => - numOutputRows += 1 - removed.value.asInstanceOf[InternalRow] - } - - // Update and output modified rows from the StateStore. 
- case Some(Update) => - - new Iterator[InternalRow] { - - // Filter late date using watermark if specified - private[this] val baseIterator = watermarkPredicate match { - case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) - case None => iter - } - - override def hasNext: Boolean = { - if (!baseIterator.hasNext) { - // Remove old aggregates if watermark specified - if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval) - store.commit() - numTotalStateRows += store.numKeys() - false - } else { - true - } - } - - override def next(): InternalRow = { - val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numOutputRows += 1 - numUpdatedStateRows += 1 - row - } - } - - case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode") - } - } - } - - override def output: Seq[Attribute] = child.output - - override def outputPartitioning: Partitioning = child.outputPartitioning -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 1279b71c4d..61eb601a18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -147,6 +147,25 @@ private[state] class HDFSBackedStateStoreProvider( } } + /** Remove a single key. */ + override def remove(key: UnsafeRow): Unit = { + verify(state == UPDATING, "Cannot remove after already committed or aborted") + if (mapToUpdate.containsKey(key)) { + val value = mapToUpdate.remove(key) + Option(allUpdates.get(key)) match { + case Some(ValueUpdated(_, _)) | None => + // Value existed in previous version and maybe was updated, mark removed + allUpdates.put(key, ValueRemoved(key, value)) + case Some(ValueAdded(_, _)) => + // Value did not exist in previous version and was added, should not appear in updates + allUpdates.remove(key) + case Some(ValueRemoved(_, _)) => + // Remove already in update map, no need to change + } + writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value)) + } + } + /** Commit all the updates that have been made to the store, and return the new version. */ override def commit(): Long = { verify(state == UPDATING, "Cannot commit after already committed or aborted") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index e61d95a1b1..dcb24b26f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -58,6 +58,11 @@ trait StateStore { */ def remove(condition: UnsafeRow => Boolean): Unit + /** + * Remove a single key. + */ + def remove(key: UnsafeRow): Unit + /** * Commit all the updates that have been made to the store, and return the new version. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 1b56c08f72..589042afb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming import scala.reflect.ClassTag +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.internal.SessionState @@ -59,10 +60,18 @@ package object state { sessionState: SessionState, storeCoordinator: Option[StateStoreCoordinatorRef])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { + val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) + val wrappedF = (store: StateStore, iter: Iterator[T]) => { + // Abort the state store in case of error + TaskContext.get().addTaskCompletionListener(_ => { + if (!store.hasCommitted) store.abort() + }) + cleanedF(store, iter) + } new StateStoreRDD( dataRDD, - cleanedF, + wrappedF, checkpointLocation, operatorId, storeVersion, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala new file mode 100644 index 0000000000..1292452574 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.CompletionIterator + + +/** Used to identify the state store for a given operator. 
*/ +case class OperatorStateId( + checkpointLocation: String, + operatorId: Long, + batchId: Long) + +/** + * An operator that reads or writes state from the [[StateStore]]. The [[OperatorStateId]] should + * be filled in by `prepareForExecution` in [[IncrementalExecution]]. + */ +trait StatefulOperator extends SparkPlan { + def stateId: Option[OperatorStateId] + + protected def getStateId: OperatorStateId = attachTree(this) { + stateId.getOrElse { + throw new IllegalStateException("State location not present for execution") + } + } +} + +/** An operator that reads from a StateStore. */ +trait StateStoreReader extends StatefulOperator { + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) +} + +/** An operator that writes to a StateStore. */ +trait StateStoreWriter extends StatefulOperator { + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), + "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) +} + +/** + * For each input tuple, the key is calculated and the value from the [[StateStore]] is added + * to the stream (in addition to the input tuple) if present. + */ +case class StateStoreRestoreExec( + keyExpressions: Seq[Attribute], + stateId: Option[OperatorStateId], + child: SparkPlan) + extends execution.UnaryExecNode with StateStoreReader { + + override protected def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + operatorId = getStateId.operatorId, + storeVersion = getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + iter.flatMap { row => + val key = getKey(row) + val savedState = store.get(key) + numOutputRows += 1 + row +: savedState.toSeq + } + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning +} + +/** + * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]]. + */ +case class StateStoreSaveExec( + keyExpressions: Seq[Attribute], + stateId: Option[OperatorStateId] = None, + outputMode: Option[OutputMode] = None, + eventTimeWatermark: Option[Long] = None, + child: SparkPlan) + extends execution.UnaryExecNode with StateStoreWriter { + + /** Generate a predicate that matches data older than the watermark */ + private lazy val watermarkPredicate: Option[Predicate] = { + val optionalWatermarkAttribute = + keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) + + optionalWatermarkAttribute.map { watermarkAttribute => + // If we are evicting based on a window, use the end of the window. Otherwise just + // use the attribute itself. 
+ val evictionExpression = + if (watermarkAttribute.dataType.isInstanceOf[StructType]) { + LessThanOrEqual( + GetStructField(watermarkAttribute, 1), + Literal(eventTimeWatermark.get * 1000)) + } else { + LessThanOrEqual( + watermarkAttribute, + Literal(eventTimeWatermark.get * 1000)) + } + + logInfo(s"Filtering state store on: $evictionExpression") + newPredicate(evictionExpression, keyExpressions) + } + } + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + assert(outputMode.nonEmpty, + "Incorrect planning in IncrementalExecution, outputMode has not been set") + + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + operatorId = getStateId.operatorId, + storeVersion = getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val numOutputRows = longMetric("numOutputRows") + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + + outputMode match { + // Update and output all rows in the StateStore. + case Some(Complete) => + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + numUpdatedStateRows += 1 + } + store.commit() + numTotalStateRows += store.numKeys() + store.iterator().map { case (k, v) => + numOutputRows += 1 + v.asInstanceOf[InternalRow] + } + + // Update and output only rows being evicted from the StateStore + case Some(Append) => + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + numUpdatedStateRows += 1 + } + + // Assumption: Append mode can be done only when watermark has been specified + store.remove(watermarkPredicate.get.eval _) + store.commit() + + numTotalStateRows += store.numKeys() + store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed => + numOutputRows += 1 + removed.value.asInstanceOf[InternalRow] + } + + // Update and output modified rows from the StateStore. + case Some(Update) => + + new Iterator[InternalRow] { + + // Filter late date using watermark if specified + private[this] val baseIterator = watermarkPredicate match { + case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case None => iter + } + + override def hasNext: Boolean = { + if (!baseIterator.hasNext) { + // Remove old aggregates if watermark specified + if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval _) + store.commit() + numTotalStateRows += store.numKeys() + false + } else { + true + } + } + + override def next(): InternalRow = { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + numOutputRows += 1 + numUpdatedStateRows += 1 + row + } + } + + case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode") + } + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning +} + + +/** Physical operator for executing streaming mapGroupsWithState. 
+/** Physical operator for executing streaming mapGroupsWithState. */
+case class MapGroupsWithStateExec(
+    func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
+    keyDeserializer: Expression,
+    valueDeserializer: Expression,
+    groupingAttributes: Seq[Attribute],
+    dataAttributes: Seq[Attribute],
+    outputObjAttr: Attribute,
+    stateId: Option[OperatorStateId],
+    stateDeserializer: Expression,
+    stateSerializer: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter {
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  /** Distribute by grouping attributes */
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(groupingAttributes) :: Nil
+
+  /** Ordering needed for using GroupingIterator */
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsWithStateStore[InternalRow](
+      getStateId.checkpointLocation,
+      operatorId = getStateId.operatorId,
+      storeVersion = getStateId.batchId,
+      groupingAttributes.toStructType,
+      child.output.toStructType,
+      sqlContext.sessionState,
+      Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
+        val numTotalStateRows = longMetric("numTotalStateRows")
+        val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+        val numOutputRows = longMetric("numOutputRows")
+
+        // Generate an iterator that returns the rows grouped by the grouping function
+        val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
+
+        // Converters to and from objects and rows
+        val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+        val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
+        val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+        val getStateObj =
+          ObjectOperator.deserializeRowToObject(stateDeserializer)
+        val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer)
+
+        // For every group, get the key, values and corresponding state, call the function,
+        // and return an iterator of rows
+        val allRowsIterator = groupedIter.flatMap { case (keyRow, valueRowIter) =>
+
+          val key = keyRow.asInstanceOf[UnsafeRow]
+          val keyObj = getKeyObj(keyRow)                         // convert key row to object
+          val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
+          val stateObjOption = store.get(key).map(getStateObj)   // get existing state if any
+          val wrappedState = new KeyedStateImpl(stateObjOption)
+          val mappedIterator = func(keyObj, valueObjIter, wrappedState).map { obj =>
+            numOutputRows += 1
+            getOutputRow(obj)                                    // convert back to rows
+          }
+
+          // Return an iterator of rows generated for this key, such that when it has been
+          // fully consumed, the updated state value will be saved
+          CompletionIterator[InternalRow, Iterator[InternalRow]](
+            mappedIterator, {
+              // When the iterator has been consumed, write the changes to state
+              if (wrappedState.isRemoved) {
+                store.remove(key)
+                numUpdatedStateRows += 1
+              } else if (wrappedState.isUpdated) {
+                store.put(key, outputStateObj(wrappedState.get))
+                numUpdatedStateRows += 1
+              }
+            })
+        }
+
+        // Return an iterator of all the rows generated by all the keys, such that when fully
+        // consumed, all the state updates will be committed by the state store
+        CompletionIterator[InternalRow, Iterator[InternalRow]](allRowsIterator, {
+          store.commit()
+          numTotalStateRows += store.numKeys()
+        })
+    }
+  }
+}
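As an aside for reviewers, the per-key flow implemented above (look up the stored state, run the user function, write the update or removal back only once that key's output has been consumed, then commit once per partition) can be mimicked with plain Scala collections. This is only an illustrative sketch under made-up names (`MockState`, `runBatch`); the real operator works on UnsafeRows against a versioned StateStore, not a HashMap.

```scala
import scala.collection.mutable

// Minimal stand-in for KeyedStateImpl: tracks whether the function updated or removed the state.
class MockState[S](initial: Option[S]) {
  private var value: Option[S] = initial
  var isUpdated = false
  var isRemoved = false
  def getOption: Option[S] = value
  def update(s: S): Unit = { value = Some(s); isUpdated = true; isRemoved = false }
  def remove(): Unit = { value = None; isRemoved = true; isUpdated = false }
}

object MapGroupsWithStateSketch {
  // Runs one "batch" of values against a mutable map standing in for the state store.
  def runBatch(store: mutable.Map[String, Long], batch: Seq[String]): Seq[(String, Long)] = {
    batch.groupBy(identity).toSeq.flatMap { case (key, values) =>
      val state = new MockState[Long](store.get(key))   // analogous to store.get(key).map(getStateObj)
      val newCount = state.getOption.getOrElse(0L) + values.size
      state.update(newCount)
      val out = List(key -> newCount)                   // the rows produced for this key
      if (state.isRemoved) store.remove(key)            // mirrors the per-key CompletionIterator callback
      else if (state.isUpdated) state.getOption.foreach(store.put(key, _))
      out
    }
    // a real partition would call store.commit() once all keys have been processed
  }

  def main(args: Array[String]): Unit = {
    val store = mutable.Map.empty[String, Long]
    println(runBatch(store, Seq("a", "a", "b")).sorted)  // List((a,2), (b,1))
    println(runBatch(store, Seq("a", "c")).sorted)       // List((a,3), (c,1))
    println(store)                                       // a -> 3, b -> 1, c -> 1
  }
}
```
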
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 8304b728aa..5ef4e887de 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -225,6 +225,38 @@ public class JavaDatasetSuite implements Serializable {
     Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList()));
 
+    Dataset<String> mapped2 = grouped.mapGroupsWithState(
+      new MapGroupsWithStateFunction<Integer, String, Long, String>() {
+        @Override
+        public String call(Integer key, Iterator<String> values, KeyedState<Long> s) throws Exception {
+          StringBuilder sb = new StringBuilder(key.toString());
+          while (values.hasNext()) {
+            sb.append(values.next());
+          }
+          return sb.toString();
+        }
+      },
+      Encoders.LONG(),
+      Encoders.STRING());
+
+    Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped2.collectAsList()));
+
+    Dataset<String> flatMapped2 = grouped.flatMapGroupsWithState(
+      new FlatMapGroupsWithStateFunction<Integer, String, Long, String>() {
+        @Override
+        public Iterator<String> call(Integer key, Iterator<String> values, KeyedState<Long> s) {
+          StringBuilder sb = new StringBuilder(key.toString());
+          while (values.hasNext()) {
+            sb.append(values.next());
+          }
+          return Collections.singletonList(sb.toString()).iterator();
+        }
+      },
+      Encoders.LONG(),
+      Encoders.STRING());
+
+    Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList()));
+
     Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups(new ReduceFunction<String>() {
       @Override
       public String call(String v1, String v2) throws Exception {
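The Scala-side equivalent of the Java calls above is shorter because the state and result encoders are picked up implicitly. The snippet below is only a sketch for comparison, assuming `spark.implicits._` is in scope and the same length-based grouping as in the Java suite.

```scala
import org.apache.spark.sql.KeyedState

// Concatenate the key with all values in the group; state is unused here, as in the Java test.
val concatFunc = (key: Int, values: Iterator[String], state: KeyedState[Long]) =>
  key.toString + values.mkString

val ds = Seq("a", "foo", "bar").toDS()                 // Dataset[String]
val grouped = ds.groupByKey(_.length)                  // KeyValueGroupedDataset[Int, String]
val mapped = grouped.mapGroupsWithState(concatFunc)    // Dataset[String]
// mapped.collect().toSet would be Set("1a", "3foobar"), matching the Java assertions above
```
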
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
new file mode 100644
index 0000000000..0524898b15
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.KeyedState
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.StateStore
+
+/** Class to check custom state types */
+case class RunningCount(count: Long)
+
+class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
+
+  import testImplicits._
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    StateStore.stop()
+  }
+
+  test("KeyedState - get, exists, update, remove") {
+    var state: KeyedStateImpl[String] = null
+
+    def testState(
+        expectedData: Option[String],
+        shouldBeUpdated: Boolean = false,
+        shouldBeRemoved: Boolean = false): Unit = {
+      if (expectedData.isDefined) {
+        assert(state.exists)
+        assert(state.get === expectedData.get)
+      } else {
+        assert(!state.exists)
+        intercept[NoSuchElementException] {
+          state.get
+        }
+      }
+      assert(state.getOption === expectedData)
+      assert(state.isUpdated === shouldBeUpdated)
+      assert(state.isRemoved === shouldBeRemoved)
+    }
+
+    // Updating empty state
+    state = new KeyedStateImpl[String](None)
+    testState(None)
+    state.update("")
+    testState(Some(""), shouldBeUpdated = true)
+
+    // Updating existing state
+    state = new KeyedStateImpl[String](Some("2"))
+    testState(Some("2"))
+    state.update("3")
+    testState(Some("3"), shouldBeUpdated = true)
+
+    // Removing state
+    state.remove()
+    testState(None, shouldBeRemoved = true, shouldBeUpdated = false)
+    state.remove()      // should still be callable
+    state.update("4")
+    testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true)
+
+    // Updating with null should throw an exception
+    intercept[IllegalArgumentException] {
+      state.update(null)
+    }
+  }
+
+  test("KeyedState - primitive type") {
+    var intState = new KeyedStateImpl[Int](None)
+    intercept[NoSuchElementException] {
+      intState.get
+    }
+    assert(intState.getOption === None)
+
+    intState = new KeyedStateImpl[Int](Some(10))
+    assert(intState.get == 10)
+    intState.update(0)
+    assert(intState.get == 0)
+    intState.remove()
+    intercept[NoSuchElementException] {
+      intState.get
+    }
+  }
+
+  test("flatMapGroupsWithState - streaming") {
+    // Function to maintain the running count up to 2, and then remove the count.
+    // Returns the data and the count if state is defined, otherwise does not return anything.
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      if (count == 3) {
+        state.remove()
+        Iterator.empty
+      } else {
+        state.update(RunningCount(count))
+        Iterator((key, count.toString))
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(stateFunc)  // State: RunningCount, Out: (Str, Str)
+
+    testStream(result, Append)(
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", "1")),
+      assertNumStateRows(total = 1, updated = 1),
+      AddData(inputData, "a", "b"),
+      CheckLastBatch(("a", "2"), ("b", "1")),
+      assertNumStateRows(total = 2, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for "a"
+      CheckLastBatch(("b", "2")),
+      assertNumStateRows(total = 1, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+      CheckLastBatch(("a", "1"), ("c", "1")),
+      assertNumStateRows(total = 3, updated = 2)
+    )
+  }
+
+  test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") {
+    // Function to maintain the running count up to 2, and then remove the count.
+    // Returns the data and the count if state is defined, otherwise does not return anything.
+    // Additionally, it updates state lazily as the returned iterator gets consumed.
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      values.flatMap { _ =>
+        val count = state.getOption.map(_.count).getOrElse(0L) + 1
+        if (count == 3) {
+          state.remove()
+          None
+        } else {
+          state.update(RunningCount(count))
+          Some((key, count.toString))
+        }
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(stateFunc)  // State: RunningCount, Out: (Str, Str)
+
+    testStream(result, Append)(
+      AddData(inputData, "a", "a", "b"),
+      CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for "a"
+      CheckLastBatch(("b", "2")),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+      CheckLastBatch(("a", "1"), ("c", "1"))
+    )
+  }
+
+  test("flatMapGroupsWithState - batch") {
+    // Function that asserts state is never defined in a batch query and returns the group size
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+      Iterator((key, values.size))
+    }
+    checkAnswer(
+      Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF,
+      Seq(("a", 2), ("b", 1)).toDF)
+  }
+
+  test("mapGroupsWithState - streaming") {
+    // Function to maintain the running count up to 2, and then remove the count.
+    // Returns the data and the count (-1 if the count went beyond 2 and state was just removed).
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      if (count == 3) {
+        state.remove()
+        (key, "-1")
+      } else {
+        state.update(RunningCount(count))
+        (key, count.toString)
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc)  // Types = State: RunningCount, Out: (Str, Str)
+
+    testStream(result, Append)(
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", "1")),
+      assertNumStateRows(total = 1, updated = 1),
+      AddData(inputData, "a", "b"),
+      CheckLastBatch(("a", "2"), ("b", "1")),
+      assertNumStateRows(total = 2, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1
+      CheckLastBatch(("a", "-1"), ("b", "2")),
+      assertNumStateRows(total = 1, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+      CheckLastBatch(("a", "1"), ("c", "1")),
+      assertNumStateRows(total = 3, updated = 2)
+    )
+  }
+
+  test("mapGroupsWithState - streaming + aggregation") {
+    // Function to maintain the running count up to 2, and then remove the count.
+    // Returns the data and the count (-1 if the count went beyond 2 and state was just removed).
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      if (count == 3) {
+        state.remove()
+        (key, "-1")
+      } else {
+        state.update(RunningCount(count))
+        (key, count.toString)
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc)  // Types = State: RunningCount, Out: (Str, Str)
+        .groupByKey(_._1)
+        .count()
+
+    testStream(result, Complete)(
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", 1)),
+      AddData(inputData, "a", "b"),
+      // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1
+      CheckLastBatch(("a", 2), ("b", 1)),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"),
+      // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2");
+      // so increases counts of a and b by 1
+      CheckLastBatch(("a", 3), ("b", 2)),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"),
+      // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1");
+      // so increases counts of a and c by 1
+      CheckLastBatch(("a", 4), ("b", 2), ("c", 1))
+    )
+  }
+
+  test("mapGroupsWithState - batch") {
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+      (key, values.size)
+    }
+
+    checkAnswer(
+      spark.createDataset(Seq("a", "a", "b"))
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc)
+        .toDF,
+      spark.createDataset(Seq(("a", 2), ("b", 1))).toDF)
+  }
+
+  testQuietly("StateStore.abort on task failure handling") {
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      state.update(RunningCount(count))
+      (key, count)
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc)  // Types = State: RunningCount, Out: (Str, Long)
+
+    def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q =>
+      MapGroupsWithStateSuite.failInTask = value
+      true
+    }
+
+    testStream(result, Append)(
+      setFailInTask(false),
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", 1L)),
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", 2L)),
+      setFailInTask(true),
+      AddData(inputData, "a"),
+      ExpectFailure[SparkException](),  // task should fail but should not increment count
+      setFailInTask(false),
+      StartStream(),
+      CheckLastBatch(("a", 3L))         // task should not fail, and should show correct count
+    )
+  }
+
+  private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q =>
+    val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
+    assert(progressWithData.stateOperators(0).numRowsTotal === total, "incorrect total rows")
+    assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updated rows")
+    true
+  }
+}
+
+object MapGroupsWithStateSuite {
+  var failInTask = true
+}
-- cgit v1.2.3
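For readers who want to try the new API outside the test harness, the running-count pattern exercised by these tests looks roughly like the standalone job below. This is a sketch only: the socket source, console sink, local master, port, and the `RunningCountExample`/`WordCount` names are illustrative choices, not part of this patch.

```scala
import org.apache.spark.sql.{KeyedState, SparkSession}

case class WordCount(word: String, count: Long)

object RunningCountExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("running-count").master("local[2]").getOrCreate()
    import spark.implicits._

    // Same shape as the tests' stateFunc: fold this batch's values into the stored count.
    val stateFunc = (word: String, values: Iterator[String], state: KeyedState[Long]) => {
      val newCount = state.getOption.getOrElse(0L) + values.size
      state.update(newCount)
      WordCount(word, newCount)
    }

    // Lines arriving on a local socket, split into words.
    val lines = spark.readStream
      .format("socket")
      .option("host", "localhost")
      .option("port", 9999)
      .load()

    val counts = lines.as[String]
      .flatMap(_.split(" "))
      .groupByKey(identity)
      .mapGroupsWithState(stateFunc)   // Dataset[WordCount]

    counts.writeStream
      .outputMode("append")
      .format("console")
      .start()
      .awaitTermination()
  }
}
```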