author    Michael Armbrust <michael@databricks.com>  2016-04-01 15:15:16 -0700
committer Michael Armbrust <michael@databricks.com>  2016-04-01 15:15:16 -0700
commit    0fc4aaa71c5b4531b3a7c8ac71d62ea8e66b6f0c (patch)
tree      5c72eb22fb2ef033a6d08f989dcf4fa18d66a84f /sql/core/src/test/scala/org/apache
parent    0b7d4966ca7e02f351c4b92a74789cef4799fcb1 (diff)
[SPARK-14255][SQL] Streaming Aggregation
This PR adds the ability to perform aggregations inside of a `ContinuousQuery`. In order to implement this feature, the planning of aggregation has been augmented with a new `StatefulAggregationStrategy`. Unlike batch aggregation, stateful aggregation uses the `StateStore` (introduced in #11645) to persist the results of partial aggregation across different invocations. The resulting physical plan performs the aggregation using the following progression:

 - Partial Aggregation
 - Shuffle
 - Partial Merge (now there is at most 1 tuple per group)
 - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
 - Partial Merge (now there is at most 1 tuple per group)
 - StateStoreSave (saves the tuple for the next batch)
 - Complete (output the current result of the aggregation)

The following refactoring was also performed to allow us to plug into existing code:

 - The get/put implementation is taken from #12013
 - The logic for breaking down and de-duping the physical execution of aggregation has been moved into a new pattern, `PhysicalAggregation`
 - The `AttributeReference` used to identify the result of an `AggregateFunction` has been moved into the `AggregateExpression` container. This change moves the reference into the same object as the other intermediate references used in aggregation and eliminates the need to pass around a `Map[(AggregateFunction, Boolean), Attribute]`. Further cleanup (using a different aggregation container for logical/physical plans) is deferred to a followup.
 - Some planning logic is moved from the `SessionState` into the `QueryExecution` to make it easier to override in the streaming case.
 - The ability to write a `StreamTest` that checks only the output of the last batch has been added to simulate the future addition of output modes.

Author: Michael Armbrust <michael@databricks.com>

Closes #12048 from marmbrus/statefulAgg.
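For illustration, here is a minimal sketch of the kind of query this enables, modeled on the StreamingAggregationSuite added below. MemoryStream, testStream, AddData, and CheckLastBatch are the test utilities touched by this patch; the enclosing test class and implicits are assumed.

    val inputData = MemoryStream[Int]

    // Running count per value; the state for each group is carried across
    // batches via the StateStoreRestore/StateStoreSave operators in the plan.
    val aggregated =
      inputData.toDF()
        .groupBy($"value")
        .agg(count("*"))
        .as[(Int, Long)]

    testStream(aggregated)(
      AddData(inputData, 1, 2, 2),
      CheckLastBatch((1, 1), (2, 2)),   // counts after the first batch
      AddData(inputData, 2),
      CheckLastBatch((2, 3))            // state for key 2 is restored and updated
    )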
Diffstat (limited to 'sql/core/src/test/scala/org/apache')
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala                                       36
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala                          10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala    152
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala        61
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala             132
5 files changed, 278 insertions, 113 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index b5be7ef47e..550c3c6f9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -116,15 +116,30 @@ trait StreamTest extends QueryTest with Timeouts {
def apply[A : Encoder](data: A*): CheckAnswerRows = {
val encoder = encoderFor[A]
val toExternalRow = RowEncoder(encoder.schema)
- CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))))
+ CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false)
}
- def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows)
+ def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false)
}
- case class CheckAnswerRows(expectedAnswer: Seq[Row])
+ /**
+ * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`.
+ * This operation automatically blocks until all added data has been processed.
+ */
+ object CheckLastBatch {
+ def apply[A : Encoder](data: A*): CheckAnswerRows = {
+ val encoder = encoderFor[A]
+ val toExternalRow = RowEncoder(encoder.schema)
+ CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true)
+ }
+
+ def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true)
+ }
+
+ case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean)
extends StreamAction with StreamMustBeRunning {
- override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}"
+ override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}"
+ private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
}
/** Stops the stream. It must currently be running. */
@@ -224,11 +239,8 @@ trait StreamTest extends QueryTest with Timeouts {
""".stripMargin
def verify(condition: => Boolean, message: String): Unit = {
- try {
- Assertions.assert(condition)
- } catch {
- case NonFatal(e) =>
- failTest(message, e)
+ if (!condition) {
+ failTest(message)
}
}
@@ -351,7 +363,7 @@ trait StreamTest extends QueryTest with Timeouts {
case a: AddData =>
awaiting.put(a.source, a.addData())
- case CheckAnswerRows(expectedAnswer) =>
+ case CheckAnswerRows(expectedAnswer, lastOnly) =>
verify(currentStream != null, "stream not running")
// Block until all data added has been processed
@@ -361,12 +373,12 @@ trait StreamTest extends QueryTest with Timeouts {
}
}
- val allData = try sink.allData catch {
+ val sparkAnswer = try if (lastOnly) sink.lastBatch else sink.allData catch {
case e: Exception =>
failTest("Exception while getting data from sink", e)
}
- QueryTest.sameRows(expectedAnswer, allData).foreach {
+ QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach {
error => failTest(error)
}
}
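Usage note (hedged): with the boolean flag added above, CheckAnswer keeps comparing against all rows ever written to the sink, while CheckLastBatch compares only against the rows produced by the most recent batch (sink.lastBatch). A minimal sketch inside a StreamTest; `inputData` and the pass-through Dataset `mapped` are assumed setup, not part of this patch.

    testStream(mapped)(
      AddData(inputData, 1, 2),
      CheckAnswer(1, 2),          // all data seen so far
      AddData(inputData, 3),
      CheckAnswer(1, 2, 3),       // still cumulative (sink.allData)
      CheckLastBatch(3)           // only the most recent batch (sink.lastBatch)
    )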
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index ed0d3f56e5..38318740a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -231,10 +231,8 @@ object SparkPlanTest {
}
private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
- // A very simple resolver to make writing tests easier. In contrast to the real resolver
- // this is always case sensitive and does not try to handle scoping or complex type resolution.
- val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute(
- outputPlan transform {
+ val execution = new QueryExecution(sqlContext, null) {
+ override lazy val sparkPlan: SparkPlan = outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
plan transformExpressions {
@@ -243,8 +241,8 @@ object SparkPlanTest {
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
}
}
- )
- resolvedPlan.executeCollectPublic().toSeq
+ }
+ execution.executedPlan.executeCollectPublic().toSeq
}
}
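Design note (hedged summary of the change above): instead of calling prepareForExecution directly, the test now builds an anonymous QueryExecution whose sparkPlan is the already-constructed physical plan, so the standard executedPlan preparation path is reused. A simplified sketch; the attribute-resolution transform is elided and `outputPlan` is assumed already resolved.

    // logicalPlan is unused (null) because sparkPlan is overridden directly.
    val execution = new QueryExecution(sqlContext, null) {
      override lazy val sparkPlan: SparkPlan = outputPlan
    }
    val result: Seq[Row] = execution.executedPlan.executeCollectPublic().toSeq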
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 85db05157c..6be94eb24f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{CompletionIterator, Utils}
class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
@@ -54,62 +54,93 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
}
test("versioning and immutability") {
- quietly {
- withSpark(new SparkContext(sparkConf)) { sc =>
- implicit val sqlContet = new SQLContext(sc)
- val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
- val increment = (store: StateStore, iter: Iterator[String]) => {
- iter.foreach { s =>
- store.update(
- stringToRow(s), oldRow => {
- val oldValue = oldRow.map(rowToInt).getOrElse(0)
- intToRow(oldValue + 1)
- })
- }
- store.commit()
- store.iterator().map(rowsToStringInt)
- }
- val opId = 0
- val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 0, keySchema, valueSchema)
- assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ val sqlContext = new SQLContext(sc)
+ val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+ val opId = 0
+ val rdd1 =
+ makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(
+ increment)
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+
+ // Generate next version of stores
+ val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
+ assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+
+ // Make sure the previous RDD still has the same data.
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+ }
+ }
- // Generate next version of stores
- val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 1, keySchema, valueSchema)
- assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+ test("recovering from files") {
+ val opId = 0
+ val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+
+ def makeStoreRDD(
+ sc: SparkContext,
+ seq: Seq[String],
+ storeVersion: Int): RDD[(String, Int)] = {
+ implicit val sqlContext = new SQLContext(sc)
+ makeRDD(sc, Seq("a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment)
+ }
- // Make sure the previous RDD still has the same data.
- assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+ // Generate RDDs and state store data
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ for (i <- 1 to 20) {
+ require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
}
}
+
+ // With a new context, try using the earlier state store data
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
+ }
}
- test("recovering from files") {
- quietly {
- val opId = 0
+ test("usage with iterators - only gets and only puts") {
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ implicit val sqlContext = new SQLContext(sc)
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+ val opId = 0
- def makeStoreRDD(
- sc: SparkContext,
- seq: Seq[String],
- storeVersion: Int): RDD[(String, Int)] = {
- implicit val sqlContext = new SQLContext(sc)
- makeRDD(sc, Seq("a")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion, keySchema, valueSchema)
+ // Returns an iterator of the incremented value made into the store
+ def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = {
+ val resIterator = iter.map { s =>
+ val key = stringToRow(s)
+ val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+ val newValue = oldValue + 1
+ store.put(key, intToRow(newValue))
+ (s, newValue)
+ }
+ CompletionIterator[(String, Int), Iterator[(String, Int)]](resIterator, {
+ store.commit()
+ })
}
- // Generate RDDs and state store data
- withSpark(new SparkContext(sparkConf)) { sc =>
- for (i <- 1 to 20) {
- require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
+ def iteratorOfGets(
+ store: StateStore,
+ iter: Iterator[String]): Iterator[(String, Option[Int])] = {
+ iter.map { s =>
+ val key = stringToRow(s)
+ val value = store.get(key).map(rowToInt)
+ (s, value)
}
}
- // With a new context, try using the earlier state store data
- withSpark(new SparkContext(sparkConf)) { sc =>
- assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
- }
+ val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
+ assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))
+
+ val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts)
+ assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1))
+
+ val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets)
+ assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None))
}
}
@@ -128,8 +159,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
- val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+ val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
require(rdd.partitions.length === 2)
assert(
@@ -148,27 +179,16 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
test("distributed test") {
quietly {
withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc =>
- implicit val sqlContet = new SQLContext(sc)
+ implicit val sqlContext = new SQLContext(sc)
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
- val increment = (store: StateStore, iter: Iterator[String]) => {
- iter.foreach { s =>
- store.update(
- stringToRow(s), oldRow => {
- val oldValue = oldRow.map(rowToInt).getOrElse(0)
- intToRow(oldValue + 1)
- })
- }
- store.commit()
- store.iterator().map(rowsToStringInt)
- }
val opId = 0
- val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+ val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
// Generate next version of stores
- val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 1, keySchema, valueSchema)
+ val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
// Make sure the previous RDD still has the same data.
@@ -183,11 +203,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
private val increment = (store: StateStore, iter: Iterator[String]) => {
iter.foreach { s =>
- store.update(
- stringToRow(s), oldRow => {
- val oldValue = oldRow.map(rowToInt).getOrElse(0)
- intToRow(oldValue + 1)
- })
+ val key = stringToRow(s)
+ val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+ store.put(key, intToRow(oldValue + 1))
}
store.commit()
store.iterator().map(rowsToStringInt)
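To consolidate the API change exercised in this file (a hedged restatement of the fragments above, not new functionality): the per-partition closure now uses get/put instead of update, and it is passed as a curried argument to mapPartitionsWithStateStore along with the SQLContext.

    // Per-partition closure: read the old count with get, write the new one with put.
    val increment = (store: StateStore, iter: Iterator[String]) => {
      iter.foreach { s =>
        val key = stringToRow(s)
        val oldValue = store.get(key).map(rowToInt).getOrElse(0)
        store.put(key, intToRow(oldValue + 1))
      }
      store.commit()
      store.iterator().map(rowsToStringInt)
    }

    // New call shape: configuration first, the state closure curried at the end.
    val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
      sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)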
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 22b2f4f75d..0e5936d53f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -51,7 +51,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
StateStore.stop()
}
- test("update, remove, commit, and all data iterator") {
+ test("get, put, remove, commit, and all data iterator") {
val provider = newStoreProvider()
// Verify state before starting a new set of updates
@@ -67,7 +67,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
}
// Verify state after updating
- update(store, "a", 1)
+ put(store, "a", 1)
intercept[IllegalStateException] {
store.iterator()
}
@@ -77,8 +77,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
assert(provider.latestIterator().isEmpty)
// Make updates, commit and then verify state
- update(store, "b", 2)
- update(store, "aa", 3)
+ put(store, "b", 2)
+ put(store, "aa", 3)
remove(store, _.startsWith("a"))
assert(store.commit() === 1)
@@ -101,7 +101,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
val reloadedProvider = new HDFSBackedStateStoreProvider(
store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
val reloadedStore = reloadedProvider.getStore(1)
- update(reloadedStore, "c", 4)
+ put(reloadedStore, "c", 4)
assert(reloadedStore.commit() === 2)
assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4))
@@ -112,6 +112,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
test("updates iterator with all combos of updates and removes") {
val provider = newStoreProvider()
var currentVersion: Int = 0
+
def withStore(body: StateStore => Unit): Unit = {
val store = provider.getStore(currentVersion)
body(store)
@@ -120,9 +121,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// New data should be seen in updates as value added, even if they had multiple updates
withStore { store =>
- update(store, "a", 1)
- update(store, "aa", 1)
- update(store, "aa", 2)
+ put(store, "a", 1)
+ put(store, "aa", 1)
+ put(store, "aa", 2)
store.commit()
assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2)))
assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2))
@@ -131,8 +132,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Multiple updates to same key should be collapsed in the updates as a single value update
// Keys that have not been updated should not appear in the updates
withStore { store =>
- update(store, "a", 4)
- update(store, "a", 6)
+ put(store, "a", 4)
+ put(store, "a", 6)
store.commit()
assert(updatesToSet(store.updates()) === Set(Updated("a", 6)))
assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
@@ -140,9 +141,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Keys added, updated and finally removed before commit should not appear in updates
withStore { store =>
- update(store, "b", 4) // Added, finally removed
- update(store, "bb", 5) // Added, updated, finally removed
- update(store, "bb", 6)
+ put(store, "b", 4) // Added, finally removed
+ put(store, "bb", 5) // Added, updated, finally removed
+ put(store, "bb", 6)
remove(store, _.startsWith("b"))
store.commit()
assert(updatesToSet(store.updates()) === Set.empty)
@@ -153,7 +154,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Removed, but re-added data should be seen in updates as a value update
withStore { store =>
remove(store, _.startsWith("a"))
- update(store, "a", 10)
+ put(store, "a", 10)
store.commit()
assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa")))
assert(rowsToSet(store.iterator()) === Set("a" -> 10))
@@ -163,14 +164,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
test("cancel") {
val provider = newStoreProvider()
val store = provider.getStore(0)
- update(store, "a", 1)
+ put(store, "a", 1)
store.commit()
assert(rowsToSet(store.iterator()) === Set("a" -> 1))
// cancelUpdates should not change the data in the files
val store1 = provider.getStore(1)
- update(store1, "b", 1)
- store1.cancel()
+ put(store1, "b", 1)
+ store1.abort()
assert(getDataFromFiles(provider) === Set("a" -> 1))
}
@@ -183,7 +184,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Prepare some data in the store
val store = provider.getStore(0)
- update(store, "a", 1)
+ put(store, "a", 1)
assert(store.commit() === 1)
assert(rowsToSet(store.iterator()) === Set("a" -> 1))
@@ -193,14 +194,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Update store version with some data
val store1 = provider.getStore(1)
- update(store1, "b", 1)
+ put(store1, "b", 1)
assert(store1.commit() === 2)
assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1))
// Overwrite the version with other data
val store2 = provider.getStore(1)
- update(store2, "c", 1)
+ put(store2, "c", 1)
assert(store2.commit() === 2)
assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1))
assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1))
@@ -213,7 +214,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
def updateVersionTo(targetVersion: Int): Unit = {
for (i <- currentVersion + 1 to targetVersion) {
val store = provider.getStore(currentVersion)
- update(store, "a", i)
+ put(store, "a", i)
store.commit()
currentVersion += 1
}
@@ -264,7 +265,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
for (i <- 1 to 20) {
val store = provider.getStore(i - 1)
- update(store, "a", i)
+ put(store, "a", i)
store.commit()
provider.doMaintenance() // do cleanup
}
@@ -284,7 +285,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
val provider = newStoreProvider(minDeltasForSnapshot = 5)
for (i <- 1 to 6) {
val store = provider.getStore(i - 1)
- update(store, "a", i)
+ put(store, "a", i)
store.commit()
provider.doMaintenance() // do cleanup
}
@@ -333,7 +334,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Increase version of the store
val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
assert(store0.version === 0)
- update(store0, "a", 1)
+ put(store0, "a", 1)
store0.commit()
assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1)
@@ -345,7 +346,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
assert(StateStore.isLoaded(storeId))
- update(store1, "a", 2)
+ put(store1, "a", 2)
assert(store1.commit() === 2)
assert(rowsToSet(store1.iterator()) === Set("a" -> 2))
}
@@ -371,7 +372,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
for (i <- 1 to 20) {
val store = StateStore.get(
storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf)
- update(store, "a", i)
+ put(store, "a", i)
store.commit()
}
eventually(timeout(10 seconds)) {
@@ -507,8 +508,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
store.remove(row => condition(rowToString(row)))
}
- private def update(store: StateStore, key: String, value: Int): Unit = {
- store.update(stringToRow(key), _ => intToRow(value))
+ private def put(store: StateStore, key: String, value: Int): Unit = {
+ store.put(stringToRow(key), intToRow(value))
+ }
+
+ private def get(store: StateStore, key: String): Option[Int] = {
+ store.get(stringToRow(key)).map(rowToInt)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
new file mode 100644
index 0000000000..b63ce89d18
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{Encoder, StreamTest, SumOf, TypedColumn}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+object FailureSinglton {
+ var firstTime = true
+}
+
+class StreamingAggregationSuite extends StreamTest with SharedSQLContext {
+
+ import testImplicits._
+
+ test("simple count") {
+ val inputData = MemoryStream[Int]
+
+ val aggregated =
+ inputData.toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated)(
+ AddData(inputData, 3),
+ CheckLastBatch((3, 1)),
+ AddData(inputData, 3, 2),
+ CheckLastBatch((3, 2), (2, 1)),
+ StopStream,
+ StartStream,
+ AddData(inputData, 3, 2, 1),
+ CheckLastBatch((3, 3), (2, 2), (1, 1)),
+ // By default we run in new tuple mode.
+ AddData(inputData, 4, 4, 4, 4),
+ CheckLastBatch((4, 4))
+ )
+ }
+
+ test("multiple keys") {
+ val inputData = MemoryStream[Int]
+
+ val aggregated =
+ inputData.toDF()
+ .groupBy($"value", $"value" + 1)
+ .agg(count("*"))
+ .as[(Int, Int, Long)]
+
+ testStream(aggregated)(
+ AddData(inputData, 1, 2),
+ CheckLastBatch((1, 2, 1), (2, 3, 1)),
+ AddData(inputData, 1, 2),
+ CheckLastBatch((1, 2, 2), (2, 3, 2))
+ )
+ }
+
+ test("multiple aggregations") {
+ val inputData = MemoryStream[Int]
+
+ val aggregated =
+ inputData.toDF()
+ .groupBy($"value")
+ .agg(count("*") as 'count)
+ .groupBy($"value" % 2)
+ .agg(sum($"count"))
+ .as[(Int, Long)]
+
+ testStream(aggregated)(
+ AddData(inputData, 1, 2, 3, 4),
+ CheckLastBatch((0, 2), (1, 2)),
+ AddData(inputData, 1, 3, 5),
+ CheckLastBatch((1, 5))
+ )
+ }
+
+ testQuietly("midbatch failure") {
+ val inputData = MemoryStream[Int]
+ FailureSinglton.firstTime = true
+ val aggregated =
+ inputData.toDS()
+ .map { i =>
+ if (i == 4 && FailureSinglton.firstTime) {
+ FailureSinglton.firstTime = false
+ sys.error("injected failure")
+ }
+
+ i
+ }
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated)(
+ StartStream,
+ AddData(inputData, 1, 2, 3, 4),
+ ExpectFailure[SparkException](),
+ StartStream,
+ CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1))
+ )
+ }
+
+ test("typed aggregators") {
+ def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
+ new SumOf(f).toColumn
+
+ val inputData = MemoryStream[(String, Int)]
+ val aggregated = inputData.toDS().groupByKey(_._1).agg(sum(_._2))
+
+ testStream(aggregated)(
+ AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)),
+ CheckLastBatch(("a", 30), ("b", 3), ("c", 1))
+ )
+ }
+}