author     Michael Armbrust <michael@databricks.com>    2016-04-01 15:15:16 -0700
committer  Michael Armbrust <michael@databricks.com>    2016-04-01 15:15:16 -0700
commit     0fc4aaa71c5b4531b3a7c8ac71d62ea8e66b6f0c (patch)
tree       5c72eb22fb2ef033a6d08f989dcf4fa18d66a84f /sql/core/src/test/scala/org/apache
parent     0b7d4966ca7e02f351c4b92a74789cef4799fcb1 (diff)
[SPARK-14255][SQL] Streaming Aggregation
This PR adds the ability to perform aggregations inside of a `ContinuousQuery`. To implement this feature, the planning of aggregations has been augmented with a new `StatefulAggregationStrategy`. Unlike batch aggregation, stateful aggregation uses the `StateStore` (introduced in #11645) to persist the results of partial aggregation across different invocations. The resulting physical plan performs the aggregation using the following progression (a query that exercises it is sketched after the list):
- Partial Aggregation
- Shuffle
- Partial Merge (now there is at most 1 tuple per group)
- StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
- Partial Merge (now there is at most 1 tuple per group)
- StateStoreSave (saves the tuple for the next batch)
- Complete (output the current result of the aggregation)
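For concreteness, the simplest query from the new `StreamingAggregationSuite` added by this patch exercises this progression; the snippet assumes a `SQLContext` and its `testImplicits` are in scope, as they are in the suite:

```scala
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions._

// A running count per value: each batch's partial counts are merged with
// the counts restored from the StateStore, then saved for the next batch.
val inputData = MemoryStream[Int]

val aggregated =
  inputData.toDF()
    .groupBy($"value")
    .agg(count("*"))
    .as[(Int, Long)]
```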
The following refactoring was also performed to allow us to plug into existing code:
- The get/put implementation is taken from #12013; the reworked `increment` test helper, sketched after this list, shows the resulting API.
- The logic for breaking down and de-duplicating the physical execution of aggregation has been moved into a new pattern, `PhysicalAggregation` (see the planner sketch after this list).
- The `AttributeReference` used to identify the result of an `AggregateFunction` has been moved into the `AggregateExpression` container. This change moves the reference into the same object as the other intermediate references used in aggregation and eliminates the need to pass around a `Map[(AggregateFunction, Boolean), Attribute]` (see the before/after sketch below). Further cleanup (using a different aggregation container for logical/physical plans) is deferred to a follow-up.
- Some planning logic has been moved from the `SessionState` into the `QueryExecution` to make it easier to override in the streaming case (see the `SparkPlanTest` excerpt below).
- The ability to write a `StreamTest` that checks only the output of the last batch has been added, to simulate the future addition of output modes (usage shown below).
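The reworked `increment` helper from `StateStoreRDDSuite` shows the get/put API in use; `stringToRow`, `rowToInt`, `intToRow`, and `rowsToStringInt` are row-conversion helpers defined in that suite:

```scala
// For each input string, read the key's current count (get), write back
// the incremented count (put), then commit the new store version and
// return an iterator over the store's contents.
private val increment = (store: StateStore, iter: Iterator[String]) => {
  iter.foreach { s =>
    val key = stringToRow(s)
    val oldValue = store.get(key).map(rowToInt).getOrElse(0)
    store.put(key, intToRow(oldValue + 1))
  }
  store.commit()
  store.iterator().map(rowsToStringInt)
}
```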
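A hedged sketch of how a planner strategy can consume the new `PhysicalAggregation` pattern. The four-part destructuring (grouping expressions, aggregate expressions, result expressions, child) follows the usual shape of such extractors, and `planStreamingAggregation` is a hypothetical helper used only for illustration:

```scala
// Destructure a logical aggregate in one match, instead of re-deriving
// the grouping/aggregate/result split at every call site.
plan match {
  case PhysicalAggregation(groupingExprs, aggregateExprs, resultExprs, child) =>
    // Hypothetical helper: plan partial aggregation below the shuffle
    // and the stateful restore/merge/save steps above it.
    planStreamingAggregation(groupingExprs, aggregateExprs, resultExprs, planLater(child))
  case _ => Nil
}
```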
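A hedged before/after sketch of the `AttributeReference` move; the accessor name `resultAttribute` is an assumption about the new container field, not a confirmed API:

```scala
// Before: a side map threaded through every aggregation planning phase.
// val resultAttr = aggAttributes((aggExpr.aggregateFunction, aggExpr.isDistinct))

// After: the attribute travels with the expression itself
// (accessor name assumed).
// val resultAttr = aggExpr.resultAttribute
```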
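With the planning logic on `QueryExecution`, a test can inject an arbitrary physical plan and still run the standard preparation rules; this is a simplified excerpt of what the updated `SparkPlanTest` in this patch does:

```scala
// Build a QueryExecution whose physical plan is supplied directly (the
// logical plan is unused here, hence the null), then execute it after
// the usual preparations have been applied.
val execution = new QueryExecution(sqlContext, null) {
  override lazy val sparkPlan: SparkPlan = outputPlan
}
val result: Seq[Row] = execution.executedPlan.executeCollectPublic().toSeq
```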
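And the new `CheckLastBatch` action in use, taken from the "simple count" test added in this patch; unlike `CheckAnswer`, it compares only `sink.lastBatch` rather than `sink.allData`:

```scala
testStream(aggregated)(
  AddData(inputData, 3),
  CheckLastBatch((3, 1)),         // the batch emitted (3 -> 1)
  AddData(inputData, 3, 2),
  CheckLastBatch((3, 2), (2, 1))  // the running count for 3 is now 2
)
```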
Author: Michael Armbrust <michael@databricks.com>
Closes #12048 from marmbrus/statefulAgg.
Diffstat (limited to 'sql/core/src/test/scala/org/apache')
5 files changed, 278 insertions, 113 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index b5be7ef47e..550c3c6f9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -116,15 +116,30 @@ trait StreamTest extends QueryTest with Timeouts {
     def apply[A : Encoder](data: A*): CheckAnswerRows = {
       val encoder = encoderFor[A]
       val toExternalRow = RowEncoder(encoder.schema)
-      CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))))
+      CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false)
     }
 
-    def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows)
+    def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false)
   }
 
-  case class CheckAnswerRows(expectedAnswer: Seq[Row])
+  /**
+   * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`.
+   * This operation automatically blocks until all added data has been processed.
+   */
+  object CheckLastBatch {
+    def apply[A : Encoder](data: A*): CheckAnswerRows = {
+      val encoder = encoderFor[A]
+      val toExternalRow = RowEncoder(encoder.schema)
+      CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true)
+    }
+
+    def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true)
+  }
+
+  case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean)
     extends StreamAction with StreamMustBeRunning {
-    override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}"
+    override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}"
+    private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
   }
 
   /** Stops the stream. It must currently be running. */
@@ -224,11 +239,8 @@ trait StreamTest extends QueryTest with Timeouts {
       """.stripMargin
 
     def verify(condition: => Boolean, message: String): Unit = {
-      try {
-        Assertions.assert(condition)
-      } catch {
-        case NonFatal(e) =>
-          failTest(message, e)
+      if (!condition) {
+        failTest(message)
       }
     }
 
@@ -351,7 +363,7 @@ trait StreamTest extends QueryTest with Timeouts {
         case a: AddData =>
           awaiting.put(a.source, a.addData())
 
-        case CheckAnswerRows(expectedAnswer) =>
+        case CheckAnswerRows(expectedAnswer, lastOnly) =>
           verify(currentStream != null, "stream not running")
 
           // Block until all data added has been processed
@@ -361,12 +373,12 @@ trait StreamTest extends QueryTest with Timeouts {
             }
           }
 
-          val allData = try sink.allData catch {
+          val sparkAnswer = try if (lastOnly) sink.lastBatch else sink.allData catch {
             case e: Exception =>
               failTest("Exception while getting data from sink", e)
           }
 
-          QueryTest.sameRows(expectedAnswer, allData).foreach {
+          QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach {
             error => failTest(error)
           }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index ed0d3f56e5..38318740a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -231,10 +231,8 @@ object SparkPlanTest {
   }
 
   private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
-    // A very simple resolver to make writing tests easier. In contrast to the real resolver
-    // this is always case sensitive and does not try to handle scoping or complex type resolution.
-    val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute(
-      outputPlan transform {
+    val execution = new QueryExecution(sqlContext, null) {
+      override lazy val sparkPlan: SparkPlan = outputPlan transform {
         case plan: SparkPlan =>
           val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
           plan transformExpressions {
@@ -243,8 +241,8 @@ object SparkPlanTest {
               sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
           }
       }
-    )
-    resolvedPlan.executeCollectPublic().toSeq
+    }
+    execution.executedPlan.executeCollectPublic().toSeq
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 85db05157c..6be94eb24f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.util.quietly
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{CompletionIterator, Utils}
 
 class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
 
@@ -54,62 +54,93 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
   }
 
   test("versioning and immutability") {
-    quietly {
-      withSpark(new SparkContext(sparkConf)) { sc =>
-        implicit val sqlContet = new SQLContext(sc)
-        val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
-        val increment = (store: StateStore, iter: Iterator[String]) => {
-          iter.foreach { s =>
-            store.update(
-              stringToRow(s), oldRow => {
-                val oldValue = oldRow.map(rowToInt).getOrElse(0)
-                intToRow(oldValue + 1)
-              })
-          }
-          store.commit()
-          store.iterator().map(rowsToStringInt)
-        }
-        val opId = 0
-        val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion = 0, keySchema, valueSchema)
-        assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+    withSpark(new SparkContext(sparkConf)) { sc =>
+      val sqlContext = new SQLContext(sc)
+      val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+      val opId = 0
+      val rdd1 =
+        makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+          sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(
+          increment)
+      assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+
+      // Generate next version of stores
+      val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
+      assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+
+      // Make sure the previous RDD still has the same data.
+      assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+    }
+  }
 
-        // Generate next version of stores
-        val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion = 1, keySchema, valueSchema)
-        assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+  test("recovering from files") {
+    val opId = 0
+    val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+
+    def makeStoreRDD(
+        sc: SparkContext,
+        seq: Seq[String],
+        storeVersion: Int): RDD[(String, Int)] = {
+      implicit val sqlContext = new SQLContext(sc)
+      makeRDD(sc, Seq("a")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment)
+    }
 
-        // Make sure the previous RDD still has the same data.
-        assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+    // Generate RDDs and state store data
+    withSpark(new SparkContext(sparkConf)) { sc =>
+      for (i <- 1 to 20) {
+        require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
       }
     }
+
+    // With a new context, try using the earlier state store data
+    withSpark(new SparkContext(sparkConf)) { sc =>
+      assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
+    }
   }
 
-  test("recovering from files") {
-    quietly {
-      val opId = 0
+  test("usage with iterators - only gets and only puts") {
+    withSpark(new SparkContext(sparkConf)) { sc =>
+      implicit val sqlContext = new SQLContext(sc)
       val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+      val opId = 0
 
-      def makeStoreRDD(
-          sc: SparkContext,
-          seq: Seq[String],
-          storeVersion: Int): RDD[(String, Int)] = {
-        implicit val sqlContext = new SQLContext(sc)
-        makeRDD(sc, Seq("a")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion, keySchema, valueSchema)
+      // Returns an iterator of the incremented value made into the store
+      def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = {
+        val resIterator = iter.map { s =>
+          val key = stringToRow(s)
+          val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+          val newValue = oldValue + 1
+          store.put(key, intToRow(newValue))
+          (s, newValue)
+        }
+        CompletionIterator[(String, Int), Iterator[(String, Int)]](resIterator, {
+          store.commit()
+        })
       }
 
-      // Generate RDDs and state store data
-      withSpark(new SparkContext(sparkConf)) { sc =>
-        for (i <- 1 to 20) {
-          require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
+      def iteratorOfGets(
+          store: StateStore,
+          iter: Iterator[String]): Iterator[(String, Option[Int])] = {
+        iter.map { s =>
+          val key = stringToRow(s)
+          val value = store.get(key).map(rowToInt)
+          (s, value)
         }
       }
 
-      // With a new context, try using the earlier state store data
-      withSpark(new SparkContext(sparkConf)) { sc =>
-        assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
-      }
+      val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
+      assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))
+
+      val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts)
+      assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1))
+
+      val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets)
+      assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None))
     }
   }
 
@@ -128,8 +159,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
         Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
 
-    val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
-      increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+    val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+      sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
     require(rdd.partitions.length === 2)
 
     assert(
@@ -148,27 +179,16 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
   test("distributed test") {
     quietly {
       withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc =>
-        implicit val sqlContet = new SQLContext(sc)
+        implicit val sqlContext = new SQLContext(sc)
         val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
-        val increment = (store: StateStore, iter: Iterator[String]) => {
-          iter.foreach { s =>
-            store.update(
-              stringToRow(s), oldRow => {
-                val oldValue = oldRow.map(rowToInt).getOrElse(0)
-                intToRow(oldValue + 1)
-              })
-          }
-          store.commit()
-          store.iterator().map(rowsToStringInt)
-        }
         val opId = 0
-        val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+        val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+          sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
         assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
 
         // Generate next version of stores
-        val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion = 1, keySchema, valueSchema)
+        val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore(
+          sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
         assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
 
         // Make sure the previous RDD still has the same data.
@@ -183,11 +203,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
 
   private val increment = (store: StateStore, iter: Iterator[String]) => {
     iter.foreach { s =>
-      store.update(
-        stringToRow(s), oldRow => {
-          val oldValue = oldRow.map(rowToInt).getOrElse(0)
-          intToRow(oldValue + 1)
-        })
+      val key = stringToRow(s)
+      val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+      store.put(key, intToRow(oldValue + 1))
     }
     store.commit()
     store.iterator().map(rowsToStringInt)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 22b2f4f75d..0e5936d53f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -51,7 +51,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     StateStore.stop()
   }
 
-  test("update, remove, commit, and all data iterator") {
+  test("get, put, remove, commit, and all data iterator") {
     val provider = newStoreProvider()
 
     // Verify state before starting a new set of updates
@@ -67,7 +67,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     }
 
     // Verify state after updating
-    update(store, "a", 1)
+    put(store, "a", 1)
     intercept[IllegalStateException] {
       store.iterator()
     }
@@ -77,8 +77,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     assert(provider.latestIterator().isEmpty)
 
     // Make updates, commit and then verify state
-    update(store, "b", 2)
-    update(store, "aa", 3)
+    put(store, "b", 2)
+    put(store, "aa", 3)
     remove(store, _.startsWith("a"))
     assert(store.commit() === 1)
 
@@ -101,7 +101,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     val reloadedProvider = new HDFSBackedStateStoreProvider(
       store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
     val reloadedStore = reloadedProvider.getStore(1)
-    update(reloadedStore, "c", 4)
+    put(reloadedStore, "c", 4)
     assert(reloadedStore.commit() === 2)
     assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
     assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4))
@@ -112,6 +112,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   test("updates iterator with all combos of updates and removes") {
     val provider = newStoreProvider()
     var currentVersion: Int = 0
+
     def withStore(body: StateStore => Unit): Unit = {
       val store = provider.getStore(currentVersion)
       body(store)
@@ -120,9 +121,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // New data should be seen in updates as value added, even if they had multiple updates
    withStore { store =>
-      update(store, "a", 1)
-      update(store, "aa", 1)
-      update(store, "aa", 2)
+      put(store, "a", 1)
+      put(store, "aa", 1)
+      put(store, "aa", 2)
       store.commit()
       assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2)))
       assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2))
@@ -131,8 +132,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     // Multiple updates to same key should be collapsed in the updates as a single value update
     // Keys that have not been updated should not appear in the updates
     withStore { store =>
-      update(store, "a", 4)
-      update(store, "a", 6)
+      put(store, "a", 4)
+      put(store, "a", 6)
       store.commit()
       assert(updatesToSet(store.updates()) === Set(Updated("a", 6)))
       assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
@@ -140,9 +141,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Keys added, updated and finally removed before commit should not appear in updates
     withStore { store =>
-      update(store, "b", 4) // Added, finally removed
-      update(store, "bb", 5) // Added, updated, finally removed
-      update(store, "bb", 6)
+      put(store, "b", 4) // Added, finally removed
+      put(store, "bb", 5) // Added, updated, finally removed
+      put(store, "bb", 6)
       remove(store, _.startsWith("b"))
       store.commit()
       assert(updatesToSet(store.updates()) === Set.empty)
@@ -153,7 +154,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     // Removed, but re-added data should be seen in updates as a value update
     withStore { store =>
       remove(store, _.startsWith("a"))
-      update(store, "a", 10)
+      put(store, "a", 10)
       store.commit()
       assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa")))
       assert(rowsToSet(store.iterator()) === Set("a" -> 10))
@@ -163,14 +164,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   test("cancel") {
     val provider = newStoreProvider()
     val store = provider.getStore(0)
-    update(store, "a", 1)
+    put(store, "a", 1)
     store.commit()
     assert(rowsToSet(store.iterator()) === Set("a" -> 1))
 
     // cancelUpdates should not change the data in the files
     val store1 = provider.getStore(1)
-    update(store1, "b", 1)
-    store1.cancel()
+    put(store1, "b", 1)
+    store1.abort()
     assert(getDataFromFiles(provider) === Set("a" -> 1))
   }
 
@@ -183,7 +184,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     // Prepare some data in the stoer
     val store = provider.getStore(0)
-    update(store, "a", 1)
+    put(store, "a", 1)
     assert(store.commit() === 1)
     assert(rowsToSet(store.iterator()) === Set("a" -> 1))
 
@@ -193,14 +194,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     // Update store version with some data
     val store1 = provider.getStore(1)
-    update(store1, "b", 1)
+    put(store1, "b", 1)
     assert(store1.commit() === 2)
     assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
     assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1))
 
     // Overwrite the version with other data
     val store2 = provider.getStore(1)
-    update(store2, "c", 1)
+    put(store2, "c", 1)
     assert(store2.commit() === 2)
     assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1))
     assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1))
@@ -213,7 +214,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     def updateVersionTo(targetVersion: Int): Unit = {
       for (i <- currentVersion + 1 to targetVersion) {
         val store = provider.getStore(currentVersion)
-        update(store, "a", i)
+        put(store, "a", i)
         store.commit()
         currentVersion += 1
       }
@@ -264,7 +265,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     for (i <- 1 to 20) {
       val store = provider.getStore(i - 1)
-      update(store, "a", i)
+      put(store, "a", i)
       store.commit()
       provider.doMaintenance() // do cleanup
     }
@@ -284,7 +285,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     val provider = newStoreProvider(minDeltasForSnapshot = 5)
     for (i <- 1 to 6) {
       val store = provider.getStore(i - 1)
-      update(store, "a", i)
+      put(store, "a", i)
       store.commit()
       provider.doMaintenance() // do cleanup
     }
@@ -333,7 +334,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     // Increase version of the store
     val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
     assert(store0.version === 0)
-    update(store0, "a", 1)
+    put(store0, "a", 1)
     store0.commit()
     assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1)
 
@@ -345,7 +346,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
     assert(StateStore.isLoaded(storeId))
-    update(store1, "a", 2)
+    put(store1, "a", 2)
     assert(store1.commit() === 2)
     assert(rowsToSet(store1.iterator()) === Set("a" -> 2))
   }
@@ -371,7 +372,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     for (i <- 1 to 20) {
       val store = StateStore.get(
         storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf)
-      update(store, "a", i)
+      put(store, "a", i)
       store.commit()
     }
     eventually(timeout(10 seconds)) {
@@ -507,8 +508,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     store.remove(row => condition(rowToString(row)))
   }
 
-  private def update(store: StateStore, key: String, value: Int): Unit = {
-    store.update(stringToRow(key), _ => intToRow(value))
+  private def put(store: StateStore, key: String, value: Int): Unit = {
+    store.put(stringToRow(key), intToRow(value))
+  }
+
+  private def get(store: StateStore, key: String): Option[Int] = {
+    store.get(stringToRow(key)).map(rowToInt)
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
new file mode 100644
index 0000000000..b63ce89d18
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{Encoder, StreamTest, SumOf, TypedColumn}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+object FailureSinglton {
+  var firstTime = true
+}
+
+class StreamingAggregationSuite extends StreamTest with SharedSQLContext {
+
+  import testImplicits._
+
+  test("simple count") {
+    val inputData = MemoryStream[Int]
+
+    val aggregated =
+      inputData.toDF()
+        .groupBy($"value")
+        .agg(count("*"))
+        .as[(Int, Long)]
+
+    testStream(aggregated)(
+      AddData(inputData, 3),
+      CheckLastBatch((3, 1)),
+      AddData(inputData, 3, 2),
+      CheckLastBatch((3, 2), (2, 1)),
+      StopStream,
+      StartStream,
+      AddData(inputData, 3, 2, 1),
+      CheckLastBatch((3, 3), (2, 2), (1, 1)),
+      // By default we run in new tuple mode.
+      AddData(inputData, 4, 4, 4, 4),
+      CheckLastBatch((4, 4))
+    )
+  }
+
+  test("multiple keys") {
+    val inputData = MemoryStream[Int]
+
+    val aggregated =
+      inputData.toDF()
+        .groupBy($"value", $"value" + 1)
+        .agg(count("*"))
+        .as[(Int, Int, Long)]
+
+    testStream(aggregated)(
+      AddData(inputData, 1, 2),
+      CheckLastBatch((1, 2, 1), (2, 3, 1)),
+      AddData(inputData, 1, 2),
+      CheckLastBatch((1, 2, 2), (2, 3, 2))
+    )
+  }
+
+  test("multiple aggregations") {
+    val inputData = MemoryStream[Int]
+
+    val aggregated =
+      inputData.toDF()
+        .groupBy($"value")
+        .agg(count("*") as 'count)
+        .groupBy($"value" % 2)
+        .agg(sum($"count"))
+        .as[(Int, Long)]
+
+    testStream(aggregated)(
+      AddData(inputData, 1, 2, 3, 4),
+      CheckLastBatch((0, 2), (1, 2)),
+      AddData(inputData, 1, 3, 5),
+      CheckLastBatch((1, 5))
+    )
+  }
+
+  testQuietly("midbatch failure") {
+    val inputData = MemoryStream[Int]
+    FailureSinglton.firstTime = true
+    val aggregated =
+      inputData.toDS()
+        .map { i =>
+          if (i == 4 && FailureSinglton.firstTime) {
+            FailureSinglton.firstTime = false
+            sys.error("injected failure")
+          }
+
+          i
+        }
+        .groupBy($"value")
+        .agg(count("*"))
+        .as[(Int, Long)]
+
+    testStream(aggregated)(
+      StartStream,
+      AddData(inputData, 1, 2, 3, 4),
+      ExpectFailure[SparkException](),
+      StartStream,
+      CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1))
+    )
+  }
+
+  test("typed aggregators") {
+    def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
+      new SumOf(f).toColumn
+
+    val inputData = MemoryStream[(String, Int)]
+    val aggregated = inputData.toDS().groupByKey(_._1).agg(sum(_._2))
+
+    testStream(aggregated)(
+      AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)),
+      CheckLastBatch(("a", 30), ("b", 3), ("c", 1))
+    )
+  }
+}