author    Tathagata Das <tathagata.das1565@gmail.com>    2017-02-07 20:21:00 -0800
committer Shixiong Zhu <shixiong@databricks.com>    2017-02-07 20:21:00 -0800
commit    aeb80348dd40c66b84bbc5cfe60d716fbce25acb (patch)
tree      b1ae06e53c7272cf4a408b7da2e1c7413061b0b2
parent    e33aaa2ac53a6e17e160e4e63821450b3609033b (diff)
[SPARK-19413][SS] MapGroupsWithState for arbitrary stateful operations
## What changes were proposed in this pull request?

`mapGroupsWithState` is a new API for arbitrary stateful operations in Structured Streaming, similar to `DStream.mapWithState`.

*Requirements* - Users should be able to specify a function that can do the following
- Access the input rows corresponding to a key
- Access the previous state corresponding to a key
- Optionally, update or remove the state
- Output any number of new rows (or none at all)

*Proposed API*
```
// ------------ New methods on KeyValueGroupedDataset ------------
class KeyValueGroupedDataset[K, V] {
  // Scala friendly
  def mapGroupsWithState[S: Encoder, U: Encoder](func: (K, Iterator[V], KeyedState[S]) => U)
  def flatMapGroupsWithState[S: Encoder, U: Encoder](func: (K, Iterator[V], KeyedState[S]) => Iterator[U])

  // Java friendly
  def mapGroupsWithState[S, U](func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], resultEncoder: Encoder[U])
  def flatMapGroupsWithState[S, U](func: FlatMapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], resultEncoder: Encoder[U])
}

// ------------------- New Java-friendly function classes -------------------
public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
  R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
}
public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
  Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
}

// ---------------------- Wrapper class for state data ----------------------
trait KeyedState[S] {
  def exists(): Boolean
  def get(): S                  // throws NoSuchElementException if state does not exist
  def getOption(): Option[S]
  def update(newState: S): Unit
  def remove(): Unit            // exists() will be false after this
}
```

Key semantics of the KeyedState class
- The state value cannot be null; updating the state with null throws IllegalArgumentException.
- If state.remove() is called, then state.exists() will return false, and getOption() will return None.
- If state.update(newState) is called after that, then state.exists() will return true, and getOption() will return Some(...).
- None of the operations are thread-safe. This is to avoid memory barriers.

*Usage*
```
val stateFunc = (word: String, words: Iterator[String], runningCount: KeyedState[Long]) => {
  val newCount = words.size + runningCount.getOption.getOrElse(0L)
  runningCount.update(newCount)
  (word, newCount)
}

dataset                                                 // type is Dataset[String]
  .groupByKey[String](w => w)                           // generates KeyValueGroupedDataset[String, String]
  .mapGroupsWithState[Long, (String, Long)](stateFunc)  // returns Dataset[(String, Long)]
```

## How was this patch tested?

New unit tests.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #16758 from tdas/mapWithState.
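As an additional illustration of the state lifecycle described above (read existing state, update it on each trigger, remove it once a condition is met), here is a minimal Scala sketch of `flatMapGroupsWithState`, patterned after the new `MapGroupsWithStateSuite` tests in this patch. The `RunningCount` state type, the threshold of 3, and the `runningCounts` helper are illustrative assumptions, not part of the API.

```
import org.apache.spark.sql.{Dataset, KeyedState}

// Illustrative state type, mirroring the RunningCount case class used by the new test suite.
case class RunningCount(count: Long)

// Keep a per-key running count; once the count reaches 3, remove the state and emit nothing.
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
  val count = state.getOption.map(_.count).getOrElse(0L) + values.size
  if (count >= 3) {
    state.remove()                      // exists() is false from here on; getOption() is None
    Iterator.empty
  } else {
    state.update(RunningCount(count))   // saved and visible in the next trigger for this key
    Iterator((key, count))
  }
}

// Wire the function into a query; `lines` may be a batch or a streaming Dataset[String].
def runningCounts(lines: Dataset[String]): Dataset[(String, Long)] = {
  val spark = lines.sparkSession
  import spark.implicits._
  lines
    .groupByKey(x => x)                 // KeyValueGroupedDataset[String, String]
    .flatMapGroupsWithState(stateFunc)  // per-key state maintained across triggers
}
```

For a batch Dataset the state starts empty and is discarded after the single invocation, so this reduces to `flatMapGroups`; only streaming queries persist the state across triggers, backed by the StateStore changes in this patch.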
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala | 49
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala | 24
-rw-r--r--  sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java | 38
-rw-r--r--  sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java | 38
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala | 113
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala | 142
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala | 22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala | 80
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala) | 134
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 32
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala | 335
18 files changed, 1059 insertions, 36 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index e4fd737b35..07b3558ee2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -46,8 +46,13 @@ object UnsupportedOperationChecker {
"Queries without streaming sources cannot be executed with writeStream.start()")(plan)
}
+ /** Collect all the streaming aggregates in a sub plan */
+ def collectStreamingAggregates(subplan: LogicalPlan): Seq[Aggregate] = {
+ subplan.collect { case a: Aggregate if a.isStreaming => a }
+ }
+
// Disallow multiple streaming aggregations
- val aggregates = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a }
+ val aggregates = collectStreamingAggregates(plan)
if (aggregates.size > 1) {
throwError(
@@ -111,6 +116,10 @@ object UnsupportedOperationChecker {
throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " +
"streaming DataFrames/Datasets")
+ case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty =>
+ throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " +
+ "streaming DataFrame/Dataset")
+
case Join(left, right, joinType, _) =>
joinType match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 0ab4c90166..0be4823bbc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -313,6 +313,55 @@ case class MapGroups(
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectProducer
+/** Internal class representing State */
+trait LogicalKeyedState[S]
+
+/** Factory for constructing new `MapGroupsWithState` nodes. */
+object MapGroupsWithState {
+ def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
+ func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ child: LogicalPlan): LogicalPlan = {
+ val mapped = new MapGroupsWithState(
+ func,
+ UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
+ UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
+ groupingAttributes,
+ dataAttributes,
+ CatalystSerde.generateObjAttr[U],
+ encoderFor[S].resolveAndBind().deserializer,
+ encoderFor[S].namedExpressions,
+ child)
+ CatalystSerde.serialize[U](mapped)
+ }
+}
+
+/**
+ * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`,
+ * while using state data.
+ * Func is invoked with an object representation of the grouping key and an iterator containing the
+ * object representation of all the rows with that key.
+ *
+ * @param keyDeserializer used to extract the key object for each group.
+ * @param valueDeserializer used to extract the items in the iterator from an input row.
+ * @param groupingAttributes used to group the data
+ * @param dataAttributes used to read the data
+ * @param outputObjAttr used to define the output object
+ * @param stateDeserializer used to deserialize state before calling `func`
+ * @param stateSerializer used to serialize updated state after calling `func`
+ */
+case class MapGroupsWithState(
+ func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ outputObjAttr: Attribute,
+ stateDeserializer: Expression,
+ stateSerializer: Seq[NamedExpression],
+ child: LogicalPlan) extends UnaryNode with ObjectProducer
+
/** Factory for constructing new `FlatMapGroupsInR` nodes. */
object FlatMapGroupsInR {
def apply(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index dcdb1ae089..3b756e89d9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -22,13 +22,13 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, LongType}
/** A dummy command for testing unsupported operations. */
case class DummyCommand() extends Command
@@ -111,6 +111,24 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
outputMode = Complete,
expectedMsgs = Seq("distinct aggregation"))
+ // MapGroupsWithState: Not supported after a streaming aggregation
+ val att = new AttributeReference(name = "a", dataType = LongType)()
+ assertSupportedInBatchPlan(
+ "mapGroupsWithState - mapGroupsWithState on batch relation",
+ MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation))
+
+ assertSupportedInStreamingPlan(
+ "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation",
+ MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation),
+ outputMode = Append)
+
+ assertNotSupportedInStreamingPlan(
+ "mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation",
+ MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att),
+ Aggregate(Nil, aggExprs("c"), streamRelation)),
+ outputMode = Complete,
+ expectedMsgs = Seq("(map/flatMap)GroupsWithState"))
+
// Inner joins: Stream-stream not supported
testBinaryOperationInStreamingPlan(
"inner join",
diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
new file mode 100644
index 0000000000..2570c8d02a
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.KeyedState;
+
+/**
+ * ::Experimental::
+ * Base interface for a map function used in
+ * {@link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState(FlatMapGroupsWithStateFunction, Encoder, Encoder)}.
+ * @since 2.1.1
+ */
+@Experimental
+@InterfaceStability.Evolving
+public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
+ Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
+}
diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
new file mode 100644
index 0000000000..614d3925e0
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.KeyedState;
+
+/**
+ * ::Experimental::
+ * Base interface for a map function used in
+ * {@link org.apache.spark.sql.KeyValueGroupedDataset#mapGroupsWithState(MapGroupsWithStateFunction, Encoder, Encoder)}
+ * @since 2.1.1
+ */
+@Experimental
+@InterfaceStability.Evolving
+public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
+ R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 395d709f26..94e689a4d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -219,6 +219,119 @@ class KeyValueGroupedDataset[K, V] private[sql](
}
/**
+ * ::Experimental::
+ * (Scala-specific)
+ * Applies the given function to each group of data, while maintaining a user-defined per-group
+ * state. The result Dataset will represent the objects returned by the function.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+ * updates to each group's state will be saved across invocations.
+ * See [[KeyedState]] for more details.
+ *
+ * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def mapGroupsWithState[S: Encoder, U: Encoder](
+ func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
+ flatMapGroupsWithState[S, U](
+ (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)))
+ }
+
+ /**
+ * ::Experimental::
+ * (Java-specific)
+ * Applies the given function to each group of data, while maintaining a user-defined per-group
+ * state. The result Dataset will represent the objects returned by the function.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+ * updates to each group's state will be saved across invocations.
+ * See [[KeyedState]] for more details.
+ *
+ * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func Function to be called on every group.
+ * @param stateEncoder Encoder for the state type.
+ * @param outputEncoder Encoder for the output type.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def mapGroupsWithState[S, U](
+ func: MapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U]): Dataset[U] = {
+ flatMapGroupsWithState[S, U](
+ (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s))
+ )(stateEncoder, outputEncoder)
+ }
+
+ /**
+ * ::Experimental::
+ * (Scala-specific)
+ * Applies the given function to each group of data, while maintaining a user-defined per-group
+ * state. The result Dataset will represent the objects returned by the function.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+ * updates to each group's state will be saved across invocations.
+ * See [[KeyedState]] for more details.
+ *
+ * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def flatMapGroupsWithState[S: Encoder, U: Encoder](
+ func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = {
+ Dataset[U](
+ sparkSession,
+ MapGroupsWithState[K, V, S, U](
+ func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
+ groupingAttributes,
+ dataAttributes,
+ logicalPlan))
+ }
+
+ /**
+ * ::Experimental::
+ * (Java-specific)
+ * Applies the given function to each group of data, while maintaining a user-defined per-group
+ * state. The result Dataset will represent the objects returned by the function.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+ * updates to each group's state will be saved across invocations.
+ * See [[KeyedState]] for more details.
+ *
+ * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func Function to be called on every group.
+ * @param stateEncoder Encoder for the state type.
+ * @param outputEncoder Encoder for the output type.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def flatMapGroupsWithState[S, U](
+ func: FlatMapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U]): Dataset[U] = {
+ flatMapGroupsWithState[S, U](
+ (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala
+ )(stateEncoder, outputEncoder)
+ }
+
+ /**
* (Scala-specific)
* Reduces the elements of each group of data using the specified binary function.
* The given function must be commutative and associative or the result may be non-deterministic.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
new file mode 100644
index 0000000000..6864b6f6b4
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.lang.IllegalArgumentException
+
+import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
+
+/**
+ * :: Experimental ::
+ *
+ * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and
+ * `flatMapGroupsWithState` operations on
+ * [[KeyValueGroupedDataset]].
+ *
+ * Detailed description of the `[map/flatMap]GroupsWithState` operation
+ * ------------------------------------------------------------
+ * Both `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]]
+ * will invoke the user-given function on each group (defined by the grouping function in
+ * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger.
+ * That is, in every batch of the [[streaming.StreamingQuery StreamingQuery]],
+ * the function will be invoked once for each group that has data in the batch.
+ *
+ * The function is invoked with following parameters.
+ * - The key of the group.
+ * - An iterator containing all the values for this key.
+ * - A user-defined state object set by previous invocations of the given function.
+ * In case of a batch Dataset, there is only one invocation and state object will be empty as
+ * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState`
+ * is equivalent to `[map/flatMap]Groups`.
+ *
+ * Important points to note about the function.
+ * - In a trigger, the function will be called only for the groups present in the batch. So do not
+ * assume that the function will be called in every trigger for every group that has state.
+ * - There is no guaranteed ordering of values in the iterator in the function, neither with
+ * batch, nor with streaming Datasets.
+ * - All the data will be shuffled before applying the function.
+ *
+ * Important points to note about using KeyedState.
+ * - The value of the state cannot be null. So updating state with null will throw
+ * `IllegalArgumentException`.
+ * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers.
+ * - If `remove()` is called, then `exists()` will return `false`,
+ * `get()` will throw `NoSuchElementException` and `getOption()` will return `None`
+ * - After that, if `update(newState)` is called, then `exists()` will again return `true`,
+ * `get()` and `getOption()` will return the updated value.
+ *
+ * Scala example of using KeyedState in `mapGroupsWithState`:
+ * {{{
+ * /* A mapping function that maintains an integer state for string keys and returns a string. */
+ * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = {
+ * // Check if state exists
+ * if (state.exists) {
+ * val existingState = state.get // Get the existing state
+ * val shouldRemove = ... // Decide whether to remove the state
+ * if (shouldRemove) {
+ * state.remove() // Remove the state
+ * } else {
+ * val newState = ...
+ * state.update(newState) // Set the new state
+ * }
+ * } else {
+ * val initialState = ...
+ * state.update(initialState) // Set the initial state
+ * }
+ * ... // return something
+ * }
+ *
+ * }}}
+ *
+ * Java example of using `KeyedState`:
+ * {{{
+ * /* A mapping function that maintains an integer state for string keys and returns a string. */
+ * MapGroupsWithStateFunction<String, Integer, Integer, String> mappingFunction =
+ * new MapGroupsWithStateFunction<String, Integer, Integer, String>() {
+ *
+ * @Override
+ * public String call(String key, Iterator<Integer> value, KeyedState<Integer> state) {
+ * if (state.exists()) {
+ * int existingState = state.get(); // Get the existing state
+ * boolean shouldRemove = ...; // Decide whether to remove the state
+ * if (shouldRemove) {
+ * state.remove(); // Remove the state
+ * } else {
+ * int newState = ...;
+ * state.update(newState); // Set the new state
+ * }
+ * } else {
+ * int initialState = ...; // Set the initial state
+ * state.update(initialState);
+ * }
+ * ... // return something
+ * }
+ * };
+ * }}}
+ *
+ * @tparam S User-defined type of the state to be stored for each key. Must be encodable into
+ * Spark SQL types (see [[Encoder]] for more details).
+ * @since 2.1.1
+ */
+@Experimental
+@InterfaceStability.Evolving
+trait KeyedState[S] extends LogicalKeyedState[S] {
+
+ /** Whether state exists or not. */
+ def exists: Boolean
+
+ /** Get the state value if it exists, or throw NoSuchElementException. */
+ @throws[NoSuchElementException]("when state does not exist")
+ def get: S
+
+ /** Get the state value as a scala Option. */
+ def getOption: Option[S]
+
+ /**
+ * Update the value of the state. Note that `null` is not a valid value, and it throws
+ * IllegalArgumentException.
+ */
+ @throws[IllegalArgumentException]("when updating with null")
+ def update(newState: S): Unit
+
+ /** Remove this keyed state. */
+ def remove(): Unit
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e3ec343479..557181ebd9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
@@ -313,6 +313,23 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
+ /**
+ * Strategy to convert MapGroupsWithState logical operator to physical operator
+ * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
+ */
+ object MapGroupsWithStateStrategy extends Strategy {
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case MapGroupsWithState(
+ f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) =>
+ val execPlan = MapGroupsWithStateExec(
+ f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer,
+ planLater(child))
+ execPlan :: Nil
+ case _ =>
+ Nil
+ }
+ }
+
// Can we automate these 'pass through' operations?
object BasicOperators extends Strategy {
def numPartitions: Int = self.numPartitions
@@ -354,6 +371,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
+ case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) =>
+ execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
execution.CoGroupExec(
f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index fde3b2a528..199ba5ce69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -30,6 +30,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
+import org.apache.spark.sql.execution.streaming.KeyedStateImpl
import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
@@ -144,6 +146,11 @@ object ObjectOperator {
(i: InternalRow) => proj(i).get(0, deserializer.dataType)
}
+ def deserializeRowToObject(deserializer: Expression): InternalRow => Any = {
+ val proj = GenerateSafeProjection.generate(deserializer :: Nil)
+ (i: InternalRow) => proj(i).get(0, deserializer.dataType)
+ }
+
def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = {
val proj = GenerateUnsafeProjection.generate(serializer)
val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head
@@ -344,6 +351,21 @@ case class MapGroupsExec(
}
}
+object MapGroupsExec {
+ def apply(
+ func: (Any, Iterator[Any], LogicalKeyedState[Any]) => TraversableOnce[Any],
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ outputObjAttr: Attribute,
+ child: SparkPlan): MapGroupsExec = {
+ val f = (key: Any, values: Iterator[Any]) => func(key, values, new KeyedStateImpl[Any](None))
+ new MapGroupsExec(f, keyDeserializer, valueDeserializer,
+ groupingAttributes, dataAttributes, outputObjAttr, child)
+ }
+}
+
/**
* Groups the input rows together and calls the R function with each group and an iterator
* containing all elements in the group.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index bd7cec3917..a3e108b29e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.streaming
+import java.util.concurrent.atomic.AtomicInteger
+
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal}
import org.apache.spark.sql.SparkSession
@@ -39,8 +41,9 @@ class IncrementalExecution(
extends QueryExecution(sparkSession, logicalPlan) with Logging {
// TODO: make this always part of planning.
- val stateStrategy =
+ val streamingExtraStrategies =
sparkSession.sessionState.planner.StatefulAggregationStrategy +:
+ sparkSession.sessionState.planner.MapGroupsWithStateStrategy +:
sparkSession.sessionState.planner.StreamingRelationStrategy +:
sparkSession.sessionState.experimentalMethods.extraStrategies
@@ -49,7 +52,7 @@ class IncrementalExecution(
new SparkPlanner(
sparkSession.sparkContext,
sparkSession.sessionState.conf,
- stateStrategy)
+ streamingExtraStrategies)
/**
* See [SPARK-18339]
@@ -68,7 +71,7 @@ class IncrementalExecution(
* Records the current id for a given stateful operator in the query plan as the `state`
* preparation walks the query plan.
*/
- private var operatorId = 0
+ private val operatorId = new AtomicInteger(0)
/** Locates save/restore pairs surrounding aggregation. */
val state = new Rule[SparkPlan] {
@@ -77,8 +80,8 @@ class IncrementalExecution(
case StateStoreSaveExec(keys, None, None, None,
UnaryExecNode(agg,
StateStoreRestoreExec(keys2, None, child))) =>
- val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId)
- operatorId += 1
+ val stateId =
+ OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
StateStoreSaveExec(
keys,
@@ -90,6 +93,12 @@ class IncrementalExecution(
keys,
Some(stateId),
child) :: Nil))
+ case MapGroupsWithStateExec(
+ f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) =>
+ val stateId =
+ OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
+ MapGroupsWithStateExec(
+ f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
new file mode 100644
index 0000000000..eee7ec45dd
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.KeyedState
+
+/** Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. */
+private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] {
+ private var value: S = optionalValue.getOrElse(null.asInstanceOf[S])
+ private var defined: Boolean = optionalValue.isDefined
+ private var updated: Boolean = false // whether value has been updated (but not removed)
+ private var removed: Boolean = false // whether value has been removed
+
+ // ========= Public API =========
+ override def exists: Boolean = defined
+
+ override def get: S = {
+ if (defined) {
+ value
+ } else {
+ throw new NoSuchElementException("State is either not defined or has already been removed")
+ }
+ }
+
+ override def getOption: Option[S] = {
+ if (defined) {
+ Some(value)
+ } else {
+ None
+ }
+ }
+
+ override def update(newValue: S): Unit = {
+ if (newValue == null) {
+ throw new IllegalArgumentException("'null' is not a valid state value")
+ }
+ value = newValue
+ defined = true
+ updated = true
+ removed = false
+ }
+
+ override def remove(): Unit = {
+ defined = false
+ updated = false
+ removed = true
+ }
+
+ override def toString: String = {
+ s"KeyedState(${getOption.map(_.toString).getOrElse("<undefined>")})"
+ }
+
+ // ========= Internal API =========
+
+ /** Whether the state has been marked for removal */
+ def isRemoved: Boolean = {
+ removed
+ }
+
+ /** Whether the state has been updated */
+ def isUpdated: Boolean = {
+ updated
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 1f74fffbe6..693933f95a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -186,7 +186,7 @@ trait ProgressReporter extends Logging {
// lastExecution could belong to one of the previous triggers if `!hasNewData`.
// Walking the plan again should be inexpensive.
val stateNodes = lastExecution.executedPlan.collect {
- case p if p.isInstanceOf[StateStoreSaveExec] => p
+ case p if p.isInstanceOf[StateStoreWriter] => p
}
stateNodes.map { node =>
val numRowsUpdated = if (hasNewData) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 1279b71c4d..61eb601a18 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -147,6 +147,25 @@ private[state] class HDFSBackedStateStoreProvider(
}
}
+ /** Remove a single key. */
+ override def remove(key: UnsafeRow): Unit = {
+ verify(state == UPDATING, "Cannot remove after already committed or aborted")
+ if (mapToUpdate.containsKey(key)) {
+ val value = mapToUpdate.remove(key)
+ Option(allUpdates.get(key)) match {
+ case Some(ValueUpdated(_, _)) | None =>
+ // Value existed in previous version and maybe was updated, mark removed
+ allUpdates.put(key, ValueRemoved(key, value))
+ case Some(ValueAdded(_, _)) =>
+ // Value did not exist in previous version and was added, should not appear in updates
+ allUpdates.remove(key)
+ case Some(ValueRemoved(_, _)) =>
+ // Remove already in update map, no need to change
+ }
+ writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value))
+ }
+ }
+
/** Commit all the updates that have been made to the store, and return the new version. */
override def commit(): Long = {
verify(state == UPDATING, "Cannot commit after already committed or aborted")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index e61d95a1b1..dcb24b26f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -59,6 +59,11 @@ trait StateStore {
def remove(condition: UnsafeRow => Boolean): Unit
/**
+ * Remove a single key.
+ */
+ def remove(key: UnsafeRow): Unit
+
+ /**
* Commit all the updates that have been made to the store, and return the new version.
*/
def commit(): Long
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 1b56c08f72..589042afb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming
import scala.reflect.ClassTag
+import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.internal.SessionState
@@ -59,10 +60,18 @@ package object state {
sessionState: SessionState,
storeCoordinator: Option[StateStoreCoordinatorRef])(
storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
+
val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
+ val wrappedF = (store: StateStore, iter: Iterator[T]) => {
+ // Abort the state store in case of error
+ TaskContext.get().addTaskCompletionListener(_ => {
+ if (!store.hasCommitted) store.abort()
+ })
+ cleanedF(store, iter)
+ }
new StateStoreRDD(
dataRDD,
- cleanedF,
+ wrappedF,
checkpointLocation,
operatorId,
storeVersion,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index d4ccced9ac..1292452574 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -22,16 +22,16 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
-import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.streaming.state._
-import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
-import org.apache.spark.TaskContext
+import org.apache.spark.util.CompletionIterator
/** Used to identify the state store for a given operator. */
@@ -41,7 +41,7 @@ case class OperatorStateId(
batchId: Long)
/**
- * An operator that saves or restores state from the [[StateStore]]. The [[OperatorStateId]] should
+ * An operator that reads or writes state from the [[StateStore]]. The [[OperatorStateId]] should
* be filled in by `prepareForExecution` in [[IncrementalExecution]].
*/
trait StatefulOperator extends SparkPlan {
@@ -54,6 +54,20 @@ trait StatefulOperator extends SparkPlan {
}
}
+/** An operator that reads from a StateStore. */
+trait StateStoreReader extends StatefulOperator {
+ override lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+}
+
+/** An operator that writes to a StateStore. */
+trait StateStoreWriter extends StatefulOperator {
+ override lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+ "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"),
+ "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"))
+}
+
/**
* For each input tuple, the key is calculated and the value from the [[StateStore]] is added
* to the stream (in addition to the input tuple) if present.
@@ -62,10 +76,7 @@ case class StateStoreRestoreExec(
keyExpressions: Seq[Attribute],
stateId: Option[OperatorStateId],
child: SparkPlan)
- extends execution.UnaryExecNode with StatefulOperator {
-
- override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+ extends execution.UnaryExecNode with StateStoreReader {
override protected def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
@@ -102,12 +113,7 @@ case class StateStoreSaveExec(
outputMode: Option[OutputMode] = None,
eventTimeWatermark: Option[Long] = None,
child: SparkPlan)
- extends execution.UnaryExecNode with StatefulOperator {
-
- override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
- "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"),
- "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"))
+ extends execution.UnaryExecNode with StateStoreWriter {
/** Generate a predicate that matches data older than the watermark */
private lazy val watermarkPredicate: Option[Predicate] = {
@@ -151,13 +157,6 @@ case class StateStoreSaveExec(
val numTotalStateRows = longMetric("numTotalStateRows")
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
- // Abort the state store in case of error
- TaskContext.get().addTaskCompletionListener(_ => {
- if (!store.hasCommitted) {
- store.abort()
- }
- })
-
outputMode match {
// Update and output all rows in the StateStore.
case Some(Complete) =>
@@ -184,7 +183,7 @@ case class StateStoreSaveExec(
}
// Assumption: Append mode can be done only when watermark has been specified
- store.remove(watermarkPredicate.get.eval)
+ store.remove(watermarkPredicate.get.eval _)
store.commit()
numTotalStateRows += store.numKeys()
@@ -207,7 +206,7 @@ case class StateStoreSaveExec(
override def hasNext: Boolean = {
if (!baseIterator.hasNext) {
// Remove old aggregates if watermark specified
- if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval)
+ if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval _)
store.commit()
numTotalStateRows += store.numKeys()
false
@@ -235,3 +234,90 @@ case class StateStoreSaveExec(
override def outputPartitioning: Partitioning = child.outputPartitioning
}
+
+
+/** Physical operator for executing streaming mapGroupsWithState. */
+case class MapGroupsWithStateExec(
+ func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ outputObjAttr: Attribute,
+ stateId: Option[OperatorStateId],
+ stateDeserializer: Expression,
+ stateSerializer: Seq[NamedExpression],
+ child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter {
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
+ /** Distribute by grouping attributes */
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(groupingAttributes) :: Nil
+
+ /** Ordering needed for using GroupingIterator */
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsWithStateStore[InternalRow](
+ getStateId.checkpointLocation,
+ operatorId = getStateId.operatorId,
+ storeVersion = getStateId.batchId,
+ groupingAttributes.toStructType,
+ child.output.toStructType,
+ sqlContext.sessionState,
+ Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
+ val numTotalStateRows = longMetric("numTotalStateRows")
+ val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+ val numOutputRows = longMetric("numOutputRows")
+
+ // Generate an iterator that returns the rows grouped by the grouping function
+ val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
+
+ // Converters to and from object and rows
+ val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+ val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
+ val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+ val getStateObj =
+ ObjectOperator.deserializeRowToObject(stateDeserializer)
+ val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer)
+
+ // For every group, get the key, values and corresponding state and call the function,
+ // and return an iterator of rows
+ val allRowsIterator = groupedIter.flatMap { case (keyRow, valueRowIter) =>
+
+ val key = keyRow.asInstanceOf[UnsafeRow]
+ val keyObj = getKeyObj(keyRow) // convert key to objects
+ val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
+ val stateObjOption = store.get(key).map(getStateObj) // get existing state if any
+ val wrappedState = new KeyedStateImpl(stateObjOption)
+ val mappedIterator = func(keyObj, valueObjIter, wrappedState).map { obj =>
+ numOutputRows += 1
+ getOutputRow(obj) // convert back to rows
+ }
+
+ // Return an iterator of rows generated for this key, such that when it is
+ // fully consumed, the updated state value will be saved
+ CompletionIterator[InternalRow, Iterator[InternalRow]](
+ mappedIterator, {
+ // When the iterator is consumed, then write changes to state
+ if (wrappedState.isRemoved) {
+ store.remove(key)
+ numUpdatedStateRows += 1
+ } else if (wrappedState.isUpdated) {
+ store.put(key, outputStateObj(wrappedState.get))
+ numUpdatedStateRows += 1
+ }
+ })
+ }
+
+ // Return an iterator of all the rows generated by all the keys, such that when fully
+ // consumed, all the state updates will be committed by the state store
+ CompletionIterator[InternalRow, Iterator[InternalRow]](allRowsIterator, {
+ store.commit()
+ numTotalStateRows += store.numKeys()
+ })
+ }
+ }
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 8304b728aa..5ef4e887de 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -225,6 +225,38 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList()));
+ Dataset<String> mapped2 = grouped.mapGroupsWithState(
+ new MapGroupsWithStateFunction<Integer, String, Long, String>() {
+ @Override
+ public String call(Integer key, Iterator<String> values, KeyedState<Long> s) throws Exception {
+ StringBuilder sb = new StringBuilder(key.toString());
+ while (values.hasNext()) {
+ sb.append(values.next());
+ }
+ return sb.toString();
+ }
+ },
+ Encoders.LONG(),
+ Encoders.STRING());
+
+ Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped2.collectAsList()));
+
+ Dataset<String> flatMapped2 = grouped.flatMapGroupsWithState(
+ new FlatMapGroupsWithStateFunction<Integer, String, Long, String>() {
+ @Override
+ public Iterator<String> call(Integer key, Iterator<String> values, KeyedState<Long> s) {
+ StringBuilder sb = new StringBuilder(key.toString());
+ while (values.hasNext()) {
+ sb.append(values.next());
+ }
+ return Collections.singletonList(sb.toString()).iterator();
+ }
+ },
+ Encoders.LONG(),
+ Encoders.STRING());
+
+ Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList()));
+
Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups(new ReduceFunction<String>() {
@Override
public String call(String v1, String v2) throws Exception {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
new file mode 100644
index 0000000000..0524898b15
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.KeyedState
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.StateStore
+
+/** Class to check custom state types */
+case class RunningCount(count: Long)
+
+class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
+
+ import testImplicits._
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ StateStore.stop()
+ }
+
+ test("KeyedState - get, exists, update, remove") {
+ var state: KeyedStateImpl[String] = null
+
+ def testState(
+ expectedData: Option[String],
+ shouldBeUpdated: Boolean = false,
+ shouldBeRemoved: Boolean = false): Unit = {
+ if (expectedData.isDefined) {
+ assert(state.exists)
+ assert(state.get === expectedData.get)
+ } else {
+ assert(!state.exists)
+ intercept[NoSuchElementException] {
+ state.get
+ }
+ }
+ assert(state.getOption === expectedData)
+ assert(state.isUpdated === shouldBeUpdated)
+ assert(state.isRemoved === shouldBeRemoved)
+ }
+
+ // Updating empty state
+ state = new KeyedStateImpl[String](None)
+ testState(None)
+ state.update("")
+ testState(Some(""), shouldBeUpdated = true)
+
+ // Updating existing state
+ state = new KeyedStateImpl[String](Some("2"))
+ testState(Some("2"))
+ state.update("3")
+ testState(Some("3"), shouldBeUpdated = true)
+
+ // Removing state
+ state.remove()
+ testState(None, shouldBeRemoved = true, shouldBeUpdated = false)
+ state.remove() // should be still callable
+ state.update("4")
+ testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true)
+
+ // Updating with null should throw an exception
+ intercept[IllegalArgumentException] {
+ state.update(null)
+ }
+ }
+
+ test("KeyedState - primitive type") {
+ var intState = new KeyedStateImpl[Int](None)
+ intercept[NoSuchElementException] {
+ intState.get
+ }
+ assert(intState.getOption === None)
+
+ intState = new KeyedStateImpl[Int](Some(10))
+ assert(intState.get == 10)
+ intState.update(0)
+ assert(intState.get == 0)
+ intState.remove()
+ intercept[NoSuchElementException] {
+ intState.get
+ }
+ }
+
+ test("flatMapGroupsWithState - streaming") {
+ // Function to maintain running count up to 2, and then remove the count
+ // Returns the data and the count if state is defined, otherwise does not return anything
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+ val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+ if (count == 3) {
+ state.remove()
+ Iterator.empty
+ } else {
+ state.update(RunningCount(count))
+ Iterator((key, count.toString))
+ }
+ }
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str)
+
+ testStream(result, Append)(
+ AddData(inputData, "a"),
+ CheckLastBatch(("a", "1")),
+ assertNumStateRows(total = 1, updated = 1),
+ AddData(inputData, "a", "b"),
+ CheckLastBatch(("a", "2"), ("b", "1")),
+ assertNumStateRows(total = 2, updated = 2),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
+ CheckLastBatch(("b", "2")),
+ assertNumStateRows(total = 1, updated = 2),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and
+ CheckLastBatch(("a", "1"), ("c", "1")),
+ assertNumStateRows(total = 3, updated = 2)
+ )
+ }
+
+ test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") {
+ // Function to maintain running count up to 2, and then remove the count
+ // Returns the data and the count if state is defined, otherwise does not return anything
+ // Additionally, it updates state lazily as the returned iterator get consumed
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+ values.flatMap { _ =>
+ val count = state.getOption.map(_.count).getOrElse(0L) + 1
+ if (count == 3) {
+ state.remove()
+ None
+ } else {
+ state.update(RunningCount(count))
+ Some((key, count.toString))
+ }
+ }
+ }
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str)
+
+ testStream(result, Append)(
+ AddData(inputData, "a", "a", "b"),
+ CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
+ CheckLastBatch(("b", "2")),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and
+ CheckLastBatch(("a", "1"), ("c", "1"))
+ )
+ }
+
+ test("flatMapGroupsWithState - batch") {
+ // Function that verifies state is never defined for a batch query and returns the group size
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+ if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+ Iterator((key, values.size))
+ }
+ checkAnswer(
+ Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF,
+ Seq(("a", 2), ("b", 1)).toDF)
+ }
+
+ test("mapGroupsWithState - streaming") {
+ // Function to maintain running count up to 2, and then remove the count
+ // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+ val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+ if (count == 3) {
+ state.remove()
+ (key, "-1")
+ } else {
+ state.update(RunningCount(count))
+ (key, count.toString)
+ }
+ }
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
+
+ testStream(result, Append)(
+ AddData(inputData, "a"),
+ CheckLastBatch(("a", "1")),
+ assertNumStateRows(total = 1, updated = 1),
+ AddData(inputData, "a", "b"),
+ CheckLastBatch(("a", "2"), ("b", "1")),
+ assertNumStateRows(total = 2, updated = 2),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1
+ CheckLastBatch(("a", "-1"), ("b", "2")),
+ assertNumStateRows(total = 1, updated = 2),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+ CheckLastBatch(("a", "1"), ("c", "1")),
+ assertNumStateRows(total = 3, updated = 2)
+ )
+ }
+
+ test("mapGroupsWithState - streaming + aggregation") {
+ // Function to maintain running count up to 2, and then remove the count
+ // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+ val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+ if (count == 3) {
+ state.remove()
+ (key, "-1")
+ } else {
+ state.update(RunningCount(count))
+ (key, count.toString)
+ }
+ }
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
+ .groupByKey(_._1)
+ .count()
+
+ testStream(result, Complete)(
+ AddData(inputData, "a"),
+ CheckLastBatch(("a", 1)),
+ AddData(inputData, "a", "b"),
+ // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1
+ CheckLastBatch(("a", 2), ("b", 1)),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "b"),
+ // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ;
+ // so increment a and b by 1
+ CheckLastBatch(("a", 3), ("b", 2)),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "c"),
+ // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ;
+ // so increment a and c by 1
+ CheckLastBatch(("a", 4), ("b", 2), ("c", 1))
+ )
+ }
+
+ test("mapGroupsWithState - batch") {
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+ if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+ (key, values.size)
+ }
+
+ checkAnswer(
+ spark.createDataset(Seq("a", "a", "b"))
+ .groupByKey(x => x)
+ .mapGroupsWithState(stateFunc)
+ .toDF,
+ spark.createDataset(Seq(("a", 2), ("b", 1))).toDF)
+ }
+
+ testQuietly("StateStore.abort on task failure handling") {
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+ if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
+ val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+ state.update(RunningCount(count))
+ (key, count)
+ }
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
+
+ def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q =>
+ MapGroupsWithStateSuite.failInTask = value
+ true
+ }
+
+ testStream(result, Append)(
+ setFailInTask(false),
+ AddData(inputData, "a"),
+ CheckLastBatch(("a", 1L)),
+ AddData(inputData, "a"),
+ CheckLastBatch(("a", 2L)),
+ setFailInTask(true),
+ AddData(inputData, "a"),
+ ExpectFailure[SparkException](), // task should fail but should not increment count
+ setFailInTask(false),
+ StartStream(),
+ CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count
+ )
+ }
+
+ private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q =>
+ val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
+ assert(progressWithData.stateOperators(0).numRowsTotal === total, "incorrect total rows")
+ assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updated rows")
+ true
+ }
+}
+
+object MapGroupsWithStateSuite {
+ var failInTask = true
+}