diff options
author | Tathagata Das <tathagata.das1565@gmail.com> | 2016-03-23 12:48:05 -0700 |
---|---|---|
committer | Tathagata Das <tathagata.das1565@gmail.com> | 2016-03-23 12:48:05 -0700 |
commit | 8c826880f5eaa3221c4e9e7d3fece54e821a0b98 (patch) | |
tree | b6dbe3670844bac231b787ccd9a97d2797f0a181 /sql/core/src/test/scala | |
parent | 0a64294fcb4b64bfe095c63c3a494e0f40e22743 (diff) | |
download | spark-8c826880f5eaa3221c4e9e7d3fece54e821a0b98.tar.gz spark-8c826880f5eaa3221c4e9e7d3fece54e821a0b98.tar.bz2 spark-8c826880f5eaa3221c4e9e7d3fece54e821a0b98.zip |
[SPARK-13809][SQL] State store for streaming aggregations
## What changes were proposed in this pull request?
In this PR, I am implementing a new abstraction for management of streaming state data - State Store. It is a key-value store for persisting running aggregates for aggregate operations in streaming dataframes. The motivation and design is discussed here.
https://docs.google.com/document/d/1-ncawFx8JS5Zyfq1HAEGBx56RDet9wfVp_hDM8ZL254/edit#
## How was this patch tested?
- [x] Unit tests
- [x] Cluster tests
**Coverage from unit tests**
<img width="952" alt="screen shot 2016-03-21 at 3 09 40 pm" src="https://cloud.githubusercontent.com/assets/663212/13935872/fdc8ba86-ef76-11e5-93e8-9fa310472c7b.png">
## TODO
- [x] Fix updates() iterator to avoid duplicate updates for same key
- [x] Use Coordinator in ContinuousQueryManager
- [x] Plugging in hadoop conf and other confs
- [x] Unit tests
- [x] StateStore object lifecycle and methods
- [x] StateStoreCoordinator communication and logic
- [x] StateStoreRDD fault-tolerance
- [x] StateStoreRDD preferred location using StateStoreCoordinator
- [ ] Cluster tests
- [ ] Whether preferred locations are set correctly
- [ ] Whether recovery works correctly with distributed storage
- [x] Basic performance tests
- [x] Docs
Author: Tathagata Das <tathagata.das1565@gmail.com>
Closes #11645 from tdas/state-store.
Diffstat (limited to 'sql/core/src/test/scala')
3 files changed, 877 insertions, 0 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala new file mode 100644 index 0000000000..c99c2f505f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.ExecutorCacheTaskLocation + +class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { + + import StateStoreCoordinatorSuite._ + + test("report, verify, getLocation") { + withCoordinatorRef(sc) { coordinatorRef => + val id = StateStoreId("x", 0, 0) + + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinatorRef.getLocation(id) === None) + + coordinatorRef.reportActiveInstance(id, "hostX", "exec1") + eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === true) + assert( + coordinatorRef.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) + } + + coordinatorRef.reportActiveInstance(id, "hostX", "exec2") + + eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinatorRef.verifyIfInstanceActive(id, "exec2") === true) + + assert( + coordinatorRef.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec2").toString)) + } + } + } + + test("make inactive") { + withCoordinatorRef(sc) { coordinatorRef => + val id1 = StateStoreId("x", 0, 0) + val id2 = StateStoreId("y", 1, 0) + val id3 = StateStoreId("x", 0, 1) + val host = "hostX" + val exec = "exec1" + + coordinatorRef.reportActiveInstance(id1, host, exec) + coordinatorRef.reportActiveInstance(id2, host, exec) + coordinatorRef.reportActiveInstance(id3, host, exec) + + eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true) + + } + + coordinatorRef.deactivateInstances("x") + + assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false) + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === false) + + assert(coordinatorRef.getLocation(id1) === None) + assert( + coordinatorRef.getLocation(id2) === + Some(ExecutorCacheTaskLocation(host, exec).toString)) + assert(coordinatorRef.getLocation(id3) === None) + + coordinatorRef.deactivateInstances("y") + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false) + assert(coordinatorRef.getLocation(id2) === None) + } + } + + test("multiple references have same underlying coordinator") { + withCoordinatorRef(sc) { coordRef1 => + val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) + + val id = StateStoreId("x", 0, 0) + + coordRef1.reportActiveInstance(id, "hostX", "exec1") + + eventually(timeout(5 seconds)) { + assert(coordRef2.verifyIfInstanceActive(id, "exec1") === true) + assert( + coordRef2.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) + } + } + } +} + +object StateStoreCoordinatorSuite { + def withCoordinatorRef(sc: SparkContext)(body: StateStoreCoordinatorRef => Unit): Unit = { + var coordinatorRef: StateStoreCoordinatorRef = null + try { + coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env) + body(coordinatorRef) + } finally { + if (coordinatorRef != null) coordinatorRef.stop() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala new file mode 100644 index 0000000000..24cec30fa3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.File +import java.nio.file.Files + +import scala.util.Random + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.LocalSparkContext._ +import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.util.Utils + +class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { + + private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) + private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString + private val keySchema = StructType(Seq(StructField("key", StringType, true))) + private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + + import StateStoreSuite._ + + after { + StateStore.stop() + } + + override def afterAll(): Unit = { + super.afterAll() + Utils.deleteRecursively(new File(tempDir)) + } + + test("versioning and immutability") { + quietly { + withSpark(new SparkContext(sparkConf)) { sc => + implicit val sqlContet = new SQLContext(sc) + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update( + stringToRow(s), oldRow => { + val oldValue = oldRow.map(rowToInt).getOrElse(0) + intToRow(oldValue + 1) + }) + } + store.commit() + store.iterator().map(rowsToStringInt) + } + val opId = 0 + val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 0, keySchema, valueSchema) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + // Generate next version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 1, keySchema, valueSchema) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + } + } + } + + test("recovering from files") { + quietly { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + def makeStoreRDD( + sc: SparkContext, + seq: Seq[String], + storeVersion: Int): RDD[(String, Int)] = { + implicit val sqlContext = new SQLContext(sc) + makeRDD(sc, Seq("a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion, keySchema, valueSchema) + } + + // Generate RDDs and state store data + withSpark(new SparkContext(sparkConf)) { sc => + for (i <- 1 to 20) { + require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) + } + } + + // With a new context, try using the earlier state store data + withSpark(new SparkContext(sparkConf)) { sc => + assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + } + } + } + + test("preferred locations using StateStoreCoordinator") { + quietly { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + withSpark(new SparkContext(sparkConf)) { sc => + implicit val sqlContext = new SQLContext(sc) + val coordinatorRef = sqlContext.streams.stateStoreCoordinator + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") + assert( + coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === + Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) + + val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 0, keySchema, valueSchema) + require(rdd.partitions.length === 2) + + assert( + rdd.preferredLocations(rdd.partitions(0)) === + Seq(ExecutorCacheTaskLocation("host1", "exec1").toString)) + + assert( + rdd.preferredLocations(rdd.partitions(1)) === + Seq(ExecutorCacheTaskLocation("host2", "exec2").toString)) + + rdd.collect() + } + } + } + + test("distributed test") { + quietly { + withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc => + implicit val sqlContet = new SQLContext(sc) + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update( + stringToRow(s), oldRow => { + val oldValue = oldRow.map(rowToInt).getOrElse(0) + intToRow(oldValue + 1) + }) + } + store.commit() + store.iterator().map(rowsToStringInt) + } + val opId = 0 + val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 0, keySchema, valueSchema) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + // Generate next version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 1, keySchema, valueSchema) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + } + } + } + + private def makeRDD(sc: SparkContext, seq: Seq[String]): RDD[String] = { + sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2) + } + + private val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update( + stringToRow(s), oldRow => { + val oldValue = oldRow.map(rowToInt).getOrElse(0) + intToRow(oldValue + 1) + }) + } + store.commit() + store.iterator().map(rowsToStringInt) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala new file mode 100644 index 0000000000..22b2f4f75d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -0,0 +1,562 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.File + +import scala.collection.mutable +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark.LocalSparkContext._ +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester { + type MapType = mutable.HashMap[UnsafeRow, UnsafeRow] + + import StateStoreCoordinatorSuite._ + import StateStoreSuite._ + + private val tempDir = Utils.createTempDir().toString + private val keySchema = StructType(Seq(StructField("key", StringType, true))) + private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + + after { + StateStore.stop() + } + + test("update, remove, commit, and all data iterator") { + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + assert(provider.latestIterator().isEmpty) + + val store = provider.getStore(0) + assert(!store.hasCommitted) + intercept[IllegalStateException] { + store.iterator() + } + intercept[IllegalStateException] { + store.updates() + } + + // Verify state after updating + update(store, "a", 1) + intercept[IllegalStateException] { + store.iterator() + } + intercept[IllegalStateException] { + store.updates() + } + assert(provider.latestIterator().isEmpty) + + // Make updates, commit and then verify state + update(store, "b", 2) + update(store, "aa", 3) + remove(store, _.startsWith("a")) + assert(store.commit() === 1) + + assert(store.hasCommitted) + assert(rowsToSet(store.iterator()) === Set("b" -> 2)) + assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2)) + assert(fileExists(provider, version = 1, isSnapshot = false)) + + assert(getDataFromFiles(provider) === Set("b" -> 2)) + + // Trying to get newer versions should fail + intercept[Exception] { + provider.getStore(2) + } + intercept[Exception] { + getDataFromFiles(provider, 2) + } + + // New updates to the reloaded store with new version, and does not change old version + val reloadedProvider = new HDFSBackedStateStoreProvider( + store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) + val reloadedStore = reloadedProvider.getStore(1) + update(reloadedStore, "c", 4) + assert(reloadedStore.commit() === 2) + assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) + assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) + assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2)) + assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4)) + } + + test("updates iterator with all combos of updates and removes") { + val provider = newStoreProvider() + var currentVersion: Int = 0 + def withStore(body: StateStore => Unit): Unit = { + val store = provider.getStore(currentVersion) + body(store) + currentVersion += 1 + } + + // New data should be seen in updates as value added, even if they had multiple updates + withStore { store => + update(store, "a", 1) + update(store, "aa", 1) + update(store, "aa", 2) + store.commit() + assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2))) + assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2)) + } + + // Multiple updates to same key should be collapsed in the updates as a single value update + // Keys that have not been updated should not appear in the updates + withStore { store => + update(store, "a", 4) + update(store, "a", 6) + store.commit() + assert(updatesToSet(store.updates()) === Set(Updated("a", 6))) + assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) + } + + // Keys added, updated and finally removed before commit should not appear in updates + withStore { store => + update(store, "b", 4) // Added, finally removed + update(store, "bb", 5) // Added, updated, finally removed + update(store, "bb", 6) + remove(store, _.startsWith("b")) + store.commit() + assert(updatesToSet(store.updates()) === Set.empty) + assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) + } + + // Removed data should be seen in updates as a key removed + // Removed, but re-added data should be seen in updates as a value update + withStore { store => + remove(store, _.startsWith("a")) + update(store, "a", 10) + store.commit() + assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa"))) + assert(rowsToSet(store.iterator()) === Set("a" -> 10)) + } + } + + test("cancel") { + val provider = newStoreProvider() + val store = provider.getStore(0) + update(store, "a", 1) + store.commit() + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) + + // cancelUpdates should not change the data in the files + val store1 = provider.getStore(1) + update(store1, "b", 1) + store1.cancel() + assert(getDataFromFiles(provider) === Set("a" -> 1)) + } + + test("getStore with unexpected versions") { + val provider = newStoreProvider() + + intercept[IllegalArgumentException] { + provider.getStore(-1) + } + + // Prepare some data in the stoer + val store = provider.getStore(0) + update(store, "a", 1) + assert(store.commit() === 1) + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) + + intercept[IllegalStateException] { + provider.getStore(2) + } + + // Update store version with some data + val store1 = provider.getStore(1) + update(store1, "b", 1) + assert(store1.commit() === 2) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) + assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) + + // Overwrite the version with other data + val store2 = provider.getStore(1) + update(store2, "c", 1) + assert(store2.commit() === 2) + assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1)) + assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1)) + } + + test("snapshotting") { + val provider = newStoreProvider(minDeltasForSnapshot = 5) + + var currentVersion = 0 + def updateVersionTo(targetVersion: Int): Unit = { + for (i <- currentVersion + 1 to targetVersion) { + val store = provider.getStore(currentVersion) + update(store, "a", i) + store.commit() + currentVersion += 1 + } + require(currentVersion === targetVersion) + } + + updateVersionTo(2) + require(getDataFromFiles(provider) === Set("a" -> 2)) + provider.doMaintenance() // should not generate snapshot files + assert(getDataFromFiles(provider) === Set("a" -> 2)) + + for (i <- 1 to currentVersion) { + assert(fileExists(provider, i, isSnapshot = false)) // all delta files present + assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present + } + + // After version 6, snapshotting should generate one snapshot file + updateVersionTo(6) + require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly") + provider.doMaintenance() // should generate snapshot files + + val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true)) + assert(snapshotVersion.nonEmpty, "snapshot file not generated") + deleteFilesEarlierThanVersion(provider, snapshotVersion.get) + assert( + getDataFromFiles(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get), + "snapshotting messed up the data of the snapshotted version") + assert( + getDataFromFiles(provider) === Set("a" -> 6), + "snapshotting messed up the data of the final version") + + // After version 20, snapshotting should generate newer snapshot files + updateVersionTo(20) + require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly") + provider.doMaintenance() // do snapshot + + val latestSnapshotVersion = (0 to 20).filter(version => + fileExists(provider, version, isSnapshot = true)).lastOption + assert(latestSnapshotVersion.nonEmpty, "no snapshot file found") + assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated") + + deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get) + assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data") + } + + test("cleaning") { + val provider = newStoreProvider(minDeltasForSnapshot = 5) + + for (i <- 1 to 20) { + val store = provider.getStore(i - 1) + update(store, "a", i) + store.commit() + provider.doMaintenance() // do cleanup + } + require( + rowsToSet(provider.latestIterator()) === Set("a" -> 20), + "store not updated correctly") + + assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted + + // last couple of versions should be retrievable + assert(getDataFromFiles(provider, 20) === Set("a" -> 20)) + assert(getDataFromFiles(provider, 19) === Set("a" -> 19)) + } + + + test("corrupted file handling") { + val provider = newStoreProvider(minDeltasForSnapshot = 5) + for (i <- 1 to 6) { + val store = provider.getStore(i - 1) + update(store, "a", i) + store.commit() + provider.doMaintenance() // do cleanup + } + val snapshotVersion = (0 to 10).find( version => + fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found")) + + // Corrupt snapshot file and verify that it throws error + assert(getDataFromFiles(provider, snapshotVersion) === Set("a" -> snapshotVersion)) + corruptFile(provider, snapshotVersion, isSnapshot = true) + intercept[Exception] { + getDataFromFiles(provider, snapshotVersion) + } + + // Corrupt delta file and verify that it throws error + assert(getDataFromFiles(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1))) + corruptFile(provider, snapshotVersion - 1, isSnapshot = false) + intercept[Exception] { + getDataFromFiles(provider, snapshotVersion - 1) + } + + // Delete delta file and verify that it throws error + deleteFilesEarlierThanVersion(provider, snapshotVersion) + intercept[Exception] { + getDataFromFiles(provider, snapshotVersion - 1) + } + } + + test("StateStore.get") { + quietly { + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val storeId = StateStoreId(dir, 0, 0) + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() + + + // Verify that trying to get incorrect versions throw errors + intercept[IllegalArgumentException] { + StateStore.get(storeId, keySchema, valueSchema, -1, storeConf, hadoopConf) + } + assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store + + intercept[IllegalStateException] { + StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + } + + // Increase version of the store + val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) + assert(store0.version === 0) + update(store0, "a", 1) + store0.commit() + + assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1) + assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0) + + // Verify that you can remove the store and still reload and use it + StateStore.unload(storeId) + assert(!StateStore.isLoaded(storeId)) + + val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeId)) + update(store1, "a", 2) + assert(store1.commit() === 2) + assert(rowsToSet(store1.iterator()) === Set("a" -> 2)) + } + } + + test("maintenance") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set(StateStore.MAINTENANCE_INTERVAL_CONFIG, "10ms") + .set("spark.rpc.numRetries", "1") + val opId = 0 + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val storeId = StateStoreId(dir, opId, 0) + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() + val provider = new HDFSBackedStateStoreProvider( + storeId, keySchema, valueSchema, storeConf, hadoopConf) + + quietly { + withSpark(new SparkContext(conf)) { sc => + withCoordinatorRef(sc) { coordinatorRef => + for (i <- 1 to 20) { + val store = StateStore.get( + storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf) + update(store, "a", i) + store.commit() + } + eventually(timeout(10 seconds)) { + assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported") + } + + // Background maintenance should clean up and generate snapshots + eventually(timeout(10 seconds)) { + // Earliest delta file should get cleaned up + assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + + // Some snapshots should have been generated + val snapshotVersions = (0 to 20).filter { version => + fileExists(provider, version, isSnapshot = true) + } + assert(snapshotVersions.nonEmpty, "no snapshot file found") + } + + // If driver decides to deactivate all instances of the store, then this instance + // should be unloaded + coordinatorRef.deactivateInstances(dir) + eventually(timeout(10 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } + + // Reload the store and verify + StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeId)) + + // If some other executor loads the store, then this instance should be unloaded + coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec") + eventually(timeout(10 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } + + // Reload the store and verify + StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeId)) + } + } + + // Verify if instance is unloaded if SparkContext is stopped + require(SparkEnv.get === null) + eventually(timeout(10 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } + } + } + + def getDataFromFiles( + provider: HDFSBackedStateStoreProvider, + version: Int = -1): Set[(String, Int)] = { + val reloadedProvider = new HDFSBackedStateStoreProvider( + provider.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) + if (version < 0) { + reloadedProvider.latestIterator().map(rowsToStringInt).toSet + } else { + reloadedProvider.iterator(version).map(rowsToStringInt).toSet + } + } + + def assertMap( + testMapOption: Option[MapType], + expectedMap: Map[String, Int]): Unit = { + assert(testMapOption.nonEmpty, "no map present") + val convertedMap = testMapOption.get.map(rowsToStringInt) + assert(convertedMap === expectedMap) + } + + def fileExists( + provider: HDFSBackedStateStoreProvider, + version: Long, + isSnapshot: Boolean): Boolean = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + filePath.exists + } + + def deleteFilesEarlierThanVersion(provider: HDFSBackedStateStoreProvider, version: Long): Unit = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + for (version <- 0 until version.toInt) { + for (isSnapshot <- Seq(false, true)) { + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + if (filePath.exists) filePath.delete() + } + } + } + + def corruptFile( + provider: HDFSBackedStateStoreProvider, + version: Long, + isSnapshot: Boolean): Unit = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + filePath.delete() + filePath.createNewFile() + } + + def storeLoaded(storeId: StateStoreId): Boolean = { + val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores) + val loadedStores = StateStore invokePrivate method() + loadedStores.contains(storeId) + } + + def unloadStore(storeId: StateStoreId): Boolean = { + val method = PrivateMethod('remove) + StateStore invokePrivate method(storeId) + } + + def newStoreProvider( + opId: Long = Random.nextLong, + partition: Int = 0, + minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get + ): HDFSBackedStateStoreProvider = { + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val sqlConf = new SQLConf() + sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) + new HDFSBackedStateStoreProvider( + StateStoreId(dir, opId, partition), + keySchema, + valueSchema, + new StateStoreConf(sqlConf), + new Configuration()) + } + + def remove(store: StateStore, condition: String => Boolean): Unit = { + store.remove(row => condition(rowToString(row))) + } + + private def update(store: StateStore, key: String, value: Int): Unit = { + store.update(stringToRow(key), _ => intToRow(value)) + } +} + +private[state] object StateStoreSuite { + + /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */ + trait TestUpdate + case class Added(key: String, value: Int) extends TestUpdate + case class Updated(key: String, value: Int) extends TestUpdate + case class Removed(key: String) extends TestUpdate + + val strProj = UnsafeProjection.create(Array[DataType](StringType)) + val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + + def stringToRow(s: String): UnsafeRow = { + strProj.apply(new GenericInternalRow(Array[Any](UTF8String.fromString(s)))).copy() + } + + def intToRow(i: Int): UnsafeRow = { + intProj.apply(new GenericInternalRow(Array[Any](i))).copy() + } + + def rowToString(row: UnsafeRow): String = { + row.getUTF8String(0).toString + } + + def rowToInt(row: UnsafeRow): Int = { + row.getInt(0) + } + + def rowsToIntInt(row: (UnsafeRow, UnsafeRow)): (Int, Int) = { + (rowToInt(row._1), rowToInt(row._2)) + } + + + def rowsToStringInt(row: (UnsafeRow, UnsafeRow)): (String, Int) = { + (rowToString(row._1), rowToInt(row._2)) + } + + def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = { + iterator.map(rowsToStringInt).toSet + } + + def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = { + iterator.map { _ match { + case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value)) + case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value)) + case KeyRemoved(key) => Removed(rowToString(key)) + }}.toSet + } +} |