author     Shixiong Zhu <shixiong@databricks.com>        2017-03-08 13:18:07 -0800
committer  Tathagata Das <tathagata.das1565@gmail.com>   2017-03-08 13:18:07 -0800
commit     1bf9012380de2aa7bdf39220b55748defde8b700 (patch)
tree       3efc5be6f6506eef72a98c08132d95c0ce9f8fcd /sql/core/src
parent     e9e2c612d58a19ddcb4b6abfb7389a4b0f7ef6f8 (diff)
[SPARK-19858][SS] Add output mode to flatMapGroupsWithState and disallow invalid cases
## What changes were proposed in this pull request?

Add an output mode parameter to `flatMapGroupsWithState` and define `mapGroupsWithState` as `flatMapGroupsWithState(Update)`. `UnsupportedOperationChecker` is modified to disallow unsupported cases.

- Batch mapGroupsWithState or flatMapGroupsWithState is always allowed.
- For streaming (map/flatMap)GroupsWithState, see the following table:

| Operators | Supported Query Output Mode |
| ------------- | ------------- |
| flatMapGroupsWithState(Update) without aggregation | Update |
| flatMapGroupsWithState(Update) with aggregation | None |
| flatMapGroupsWithState(Append) without aggregation | Append |
| flatMapGroupsWithState(Append) before aggregation | Append, Update, Complete |
| flatMapGroupsWithState(Append) after aggregation | None |
| Multiple flatMapGroupsWithState(Append)s | Append |
| Multiple mapGroupsWithStates | None |
| Mixing mapGroupsWithStates and flatMapGroupsWithStates | None |
| Other cases of multiple flatMapGroupsWithState | None |

## How was this patch tested?

The added unit tests. Here are the tests related to (map/flatMap)GroupsWithState:

```
[info] - batch plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation: supported (1 millisecond)
[info] - batch plan - flatMapGroupsWithState - multiple flatMapGroupsWithState(Append)s on batch relation: supported (0 milliseconds)
[info] - batch plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation: supported (0 milliseconds)
[info] - batch plan - flatMapGroupsWithState - multiple flatMapGroupsWithState(Update)s on batch relation: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in update mode: supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in append mode: not supported (7 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in complete mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Append mode: not supported (11 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Update mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Complete mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation without aggregation in append mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation without aggregation in update mode: not supported (6 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Append mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Update mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Complete mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation after aggregation in Append mode: not supported (6 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation after aggregation in Update mode: not supported (4 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation in complete mode: not supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation inside streaming relation in Append output mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation inside streaming relation in Update output mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation inside streaming relation in Append output mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation inside streaming relation in Update output mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are in append mode: supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some are not in append mode: not supported (7 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation without aggregation in append mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation without aggregation in complete mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Append mode: not supported (6 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Update mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Complete mode: not supported (4 milliseconds)
[info] - streaming plan - mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are in append mode: not supported (4 milliseconds)
[info] - streaming plan - mapGroupsWithState - mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation: not supported (4 milliseconds)
```

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #17197 from zsxwing/mapgroups-check.
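For readers skimming the diff below, here is a minimal sketch of the new Scala API surface. It is illustrative only: the `events` source and the `Event`/`SessionState` case classes are hypothetical, and encoders are assumed to come from `spark.implicits._`.

```scala
import org.apache.spark.sql.streaming.{KeyedState, OutputMode}

// Hypothetical input and state types, for illustration only.
case class Event(user: String, action: String)
case class SessionState(count: Long)

// events: Dataset[Event] read from some streaming source (assumed).
val counts = events
  .groupByKey(_.user)
  .flatMapGroupsWithState(
    (user: String, it: Iterator[Event], state: KeyedState[SessionState]) => {
      val count = state.getOption.map(_.count).getOrElse(0L) + it.size
      state.update(SessionState(count))
      Iterator(user -> count) // may emit zero or more rows per group
    },
    OutputMode.Update) // per the table above: an Update-mode function
                       // without aggregation requires an Update-mode query

counts.writeStream
  .outputMode("update")
  .format("console")
  .start()
```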
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala                    | 91
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                 | 21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala  |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala     |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala                | 16
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java                       |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala (renamed from sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala) | 46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala | 41
8 files changed, 141 insertions(+), 84 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 3a548c251f..ab956ffd64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -24,8 +24,10 @@ import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.ReduceAggregator
+import org.apache.spark.sql.streaming.OutputMode
/**
* :: Experimental ::
@@ -238,8 +240,16 @@ class KeyValueGroupedDataset[K, V] private[sql](
@InterfaceStability.Evolving
def mapGroupsWithState[S: Encoder, U: Encoder](
func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
- flatMapGroupsWithState[S, U](
- (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)))
+ val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))
+ Dataset[U](
+ sparkSession,
+ FlatMapGroupsWithState[K, V, S, U](
+ flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
+ groupingAttributes,
+ dataAttributes,
+ OutputMode.Update,
+ isMapGroupsWithState = true,
+ child = logicalPlan))
}
/**
@@ -267,8 +277,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
func: MapGroupsWithStateFunction[K, V, S, U],
stateEncoder: Encoder[S],
outputEncoder: Encoder[U]): Dataset[U] = {
- flatMapGroupsWithState[S, U](
- (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s))
+ mapGroupsWithState[S, U](
+ (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s)
)(stateEncoder, outputEncoder)
}
@@ -284,6 +294,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func Function to be called on every group.
+ * @param outputMode The output mode of the function.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 2.1.1
@@ -291,14 +303,44 @@ class KeyValueGroupedDataset[K, V] private[sql](
@Experimental
@InterfaceStability.Evolving
def flatMapGroupsWithState[S: Encoder, U: Encoder](
- func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = {
+ func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: OutputMode): Dataset[U] = {
+ if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) {
+ throw new IllegalArgumentException("The output mode of function should be append or update")
+ }
Dataset[U](
sparkSession,
- MapGroupsWithState[K, V, S, U](
+ FlatMapGroupsWithState[K, V, S, U](
func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
groupingAttributes,
dataAttributes,
- logicalPlan))
+ outputMode,
+ isMapGroupsWithState = false,
+ child = logicalPlan))
+ }
+
+ /**
+ * ::Experimental::
+ * (Scala-specific)
+ * Applies the given function to each group of data, while maintaining a user-defined per-group
+ * state. The result Dataset will represent the objects returned by the function.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+ * updates to each group's state will be saved across invocations.
+ * See [[KeyedState]] for more details.
+ *
+ * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func Function to be called on every group.
+ * @param outputMode The output mode of the function.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def flatMapGroupsWithState[S: Encoder, U: Encoder](
+ func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: String): Dataset[U] = {
+ flatMapGroupsWithState(func, InternalOutputModes(outputMode))
}
/**
@@ -314,6 +356,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @param func Function to be called on every group.
+ * @param outputMode The output mode of the function.
* @param stateEncoder Encoder for the state type.
* @param outputEncoder Encoder for the output type.
*
@@ -324,14 +367,46 @@ class KeyValueGroupedDataset[K, V] private[sql](
@InterfaceStability.Evolving
def flatMapGroupsWithState[S, U](
func: FlatMapGroupsWithStateFunction[K, V, S, U],
+ outputMode: OutputMode,
stateEncoder: Encoder[S],
outputEncoder: Encoder[U]): Dataset[U] = {
flatMapGroupsWithState[S, U](
- (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala
+ (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala,
+ outputMode
)(stateEncoder, outputEncoder)
}
/**
+ * ::Experimental::
+ * (Java-specific)
+ * Applies the given function to each group of data, while maintaining a user-defined per-group
+ * state. The result Dataset will represent the objects returned by the function.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+ * updates to each group's state will be saved across invocations.
+ * See [[KeyedState]] for more details.
+ *
+ * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func Function to be called on every group.
+ * @param outputMode The output mode of the function.
+ * @param stateEncoder Encoder for the state type.
+ * @param outputEncoder Encoder for the output type.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def flatMapGroupsWithState[S, U](
+ func: FlatMapGroupsWithStateFunction[K, V, S, U],
+ outputMode: String,
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U]): Dataset[U] = {
+ flatMapGroupsWithState(func, InternalOutputModes(outputMode), stateEncoder, outputEncoder)
+ }
+
+ /**
* (Scala-specific)
* Reduces the elements of each group of data using the specified binary function.
* The given function must be commutative and associative or the result may be non-deterministic.
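Since `mapGroupsWithState` is now defined as the Update-mode case of `FlatMapGroupsWithState`, the two calls below build equivalent plans (modulo the `isMapGroupsWithState` flag the checker uses). A sketch, reusing the suite's `RunningCount` state type and assuming `ds: Dataset[String]`:

```scala
import org.apache.spark.sql.streaming.{KeyedState, OutputMode}

case class RunningCount(count: Long) // as in the test suite below

// Per-group update function; returns exactly one result per invocation.
val updateFn = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
  val count = state.getOption.map(_.count).getOrElse(0L) + values.size
  state.update(RunningCount(count))
  key -> count
}

val viaMap = ds.groupByKey(x => x).mapGroupsWithState(updateFn)

// ...is equivalent to wrapping each result in a single-element Iterator:
val viaFlatMap = ds.groupByKey(x => x).flatMapGroupsWithState(
  (k: String, vs: Iterator[String], s: KeyedState[RunningCount]) =>
    Iterator(updateFn(k, vs, s)),
  OutputMode.Update)
```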
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 20bf4925db..0f7aa3709c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -326,14 +326,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
/**
- * Strategy to convert MapGroupsWithState logical operator to physical operator
+ * Strategy to convert [[FlatMapGroupsWithState]] logical operator to physical operator
* in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
*/
object MapGroupsWithStateStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case MapGroupsWithState(
- f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) =>
- val execPlan = MapGroupsWithStateExec(
+ case FlatMapGroupsWithState(
+ f,
+ keyDeser,
+ valueDeser,
+ groupAttr,
+ dataAttr,
+ outputAttr,
+ stateDeser,
+ stateSer,
+ outputMode,
+ child,
+ _) =>
+ val execPlan = FlatMapGroupsWithStateExec(
f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer,
planLater(child))
execPlan :: Nil
@@ -381,7 +391,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
- case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) =>
+ case logical.FlatMapGroupsWithState(
+ f, key, value, grouping, data, output, _, _, _, child, _) =>
execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
execution.CoGroupExec(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index ffdcd9b19d..610ce5e1eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -103,11 +103,11 @@ class IncrementalExecution(
child,
Some(stateId),
Some(currentEventTimeWatermark))
- case MapGroupsWithStateExec(
+ case FlatMapGroupsWithStateExec(
f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) =>
val stateId =
OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
- MapGroupsWithStateExec(
+ FlatMapGroupsWithStateExec(
f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index cbf656a204..c3075a3eac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -257,8 +257,8 @@ case class StateStoreSaveExec(
}
-/** Physical operator for executing streaming mapGroupsWithState. */
-case class MapGroupsWithStateExec(
+/** Physical operator for executing streaming flatMapGroupsWithState. */
+case class FlatMapGroupsWithStateExec(
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 0f7a33723c..c8fda8cd83 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.streaming
import scala.collection.JavaConverters._
import org.apache.spark.annotation.{Experimental, InterfaceStability}
-import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, ForeachWriter}
-import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink}
@@ -69,17 +69,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
* @since 2.0.0
*/
def outputMode(outputMode: String): DataStreamWriter[T] = {
- this.outputMode = outputMode.toLowerCase match {
- case "append" =>
- OutputMode.Append
- case "complete" =>
- OutputMode.Complete
- case "update" =>
- OutputMode.Update
- case _ =>
- throw new IllegalArgumentException(s"Unknown output mode $outputMode. " +
- "Accepted output modes are 'append', 'complete', 'update'")
- }
+ this.outputMode = InternalOutputModes(outputMode)
this
}
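A sketch of the string parsing the writer now delegates to (the real logic lives in `org.apache.spark.sql.catalyst.streaming.InternalOutputModes`; the standalone method name here is hypothetical, and the behavior shown simply mirrors the match removed above):

```scala
import org.apache.spark.sql.streaming.OutputMode

// Hypothetical standalone equivalent of InternalOutputModes(outputMode):
// case-insensitive parse of the user-facing output mode strings.
def parseOutputMode(outputMode: String): OutputMode = outputMode.toLowerCase match {
  case "append"   => OutputMode.Append
  case "complete" => OutputMode.Complete
  case "update"   => OutputMode.Update
  case _ => throw new IllegalArgumentException(
    s"Unknown output mode $outputMode. " +
      "Accepted output modes are 'append', 'complete', 'update'")
}
```

Centralizing the parse in one catalyst-side helper keeps the writer and the new string overloads of `flatMapGroupsWithState` consistent.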
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index e3b0e37cca..d06e35bb44 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -23,6 +23,7 @@ import java.sql.Date;
import java.sql.Timestamp;
import java.util.*;
+import org.apache.spark.sql.streaming.OutputMode;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
@@ -205,6 +206,7 @@ public class JavaDatasetSuite implements Serializable {
}
return Collections.singletonList(sb.toString()).iterator();
},
+ OutputMode.Append(),
Encoders.LONG(),
Encoders.STRING());
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 6cf4d51f99..902b842e97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore
/** Class to check custom state types */
case class RunningCount(count: Long)
-class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
+class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
import testImplicits._
@@ -119,9 +119,9 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
val result =
inputData.toDS()
.groupByKey(x => x)
- .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str)
+ .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str)
- testStream(result, Append)(
+ testStream(result, Update)(
AddData(inputData, "a"),
CheckLastBatch(("a", "1")),
assertNumStateRows(total = 1, updated = 1),
@@ -162,9 +162,9 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
val result =
inputData.toDS()
.groupByKey(x => x)
- .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str)
+ .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str)
- testStream(result, Append)(
+ testStream(result, Update)(
AddData(inputData, "a", "a", "b"),
CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
StopStream,
@@ -185,7 +185,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
Iterator((key, values.size))
}
checkAnswer(
- Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF,
+ Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc, Update).toDF,
Seq(("a", 2), ("b", 1)).toDF)
}
@@ -210,7 +210,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
.groupByKey(x => x)
.mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
- testStream(result, Append)(
+ testStream(result, Update)(
AddData(inputData, "a"),
CheckLastBatch(("a", "1")),
assertNumStateRows(total = 1, updated = 1),
@@ -230,7 +230,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
)
}
- test("mapGroupsWithState - streaming + aggregation") {
+ test("flatMapGroupsWithState - streaming + aggregation") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
@@ -238,10 +238,10 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
if (count == 3) {
state.remove()
- (key, "-1")
+ Iterator(key -> "-1")
} else {
state.update(RunningCount(count))
- (key, count.toString)
+ Iterator(key -> count.toString)
}
}
@@ -249,7 +249,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
val result =
inputData.toDS()
.groupByKey(x => x)
- .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
+ .flatMapGroupsWithState(stateFunc, Append) // Types = State: MyState, Out: (Str, Str)
.groupByKey(_._1)
.count()
@@ -290,7 +290,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
testQuietly("StateStore.abort on task failure handling") {
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
- if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
+ if (FlatMapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
state.update(RunningCount(count))
(key, count)
@@ -303,11 +303,11 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
.mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q =>
- MapGroupsWithStateSuite.failInTask = value
+ FlatMapGroupsWithStateSuite.failInTask = value
true
}
- testStream(result, Append)(
+ testStream(result, Update)(
setFailInTask(false),
AddData(inputData, "a"),
CheckLastBatch(("a", 1L)),
@@ -321,8 +321,24 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count
)
}
+
+ test("disallow complete mode") {
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+ Iterator[String]()
+ }
+
+ var e = intercept[IllegalArgumentException] {
+ MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, Complete)
+ }
+ assert(e.getMessage === "The output mode of function should be append or update")
+
+ e = intercept[IllegalArgumentException] {
+ MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, "complete")
+ }
+ assert(e.getMessage === "The output mode of function should be append or update")
+ }
}
-object MapGroupsWithStateSuite {
+object FlatMapGroupsWithStateSuite {
var failInTask = true
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
index 0470411a0f..f61dcdcbcf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
@@ -24,8 +24,7 @@ import scala.concurrent.duration._
import org.apache.hadoop.fs.Path
import org.mockito.Mockito._
-import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
-import org.scalatest.PrivateMethodTester.PrivateMethod
+import org.scalatest.BeforeAndAfter
import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming._
@@ -107,7 +106,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
}
}
-class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with PrivateMethodTester {
+class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
private def newMetadataDir =
Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -390,42 +389,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with Pr
private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath
- test("supported strings in outputMode(string)") {
- val outputModeMethod = PrivateMethod[OutputMode]('outputMode)
-
- def testMode(outputMode: String, expected: OutputMode): Unit = {
- val df = spark.readStream
- .format("org.apache.spark.sql.streaming.test")
- .load()
- val w = df.writeStream
- w.outputMode(outputMode)
- val setOutputMode = w invokePrivate outputModeMethod()
- assert(setOutputMode === expected)
- }
-
- testMode("append", OutputMode.Append)
- testMode("Append", OutputMode.Append)
- testMode("complete", OutputMode.Complete)
- testMode("Complete", OutputMode.Complete)
- testMode("update", OutputMode.Update)
- testMode("Update", OutputMode.Update)
- }
-
- test("unsupported strings in outputMode(string)") {
- def testMode(outputMode: String): Unit = {
- val acceptedModes = Seq("append", "update", "complete")
- val df = spark.readStream
- .format("org.apache.spark.sql.streaming.test")
- .load()
- val w = df.writeStream
- val e = intercept[IllegalArgumentException](w.outputMode(outputMode))
- (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s =>
- assert(e.getMessage.toLowerCase.contains(s.toLowerCase))
- }
- }
- testMode("Xyz")
- }
-
test("check foreach() catches null writers") {
val df = spark.readStream
.format("org.apache.spark.sql.streaming.test")