3 files changed, 48 insertions, 23 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 9521506325..9948292470 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -113,7 +113,7 @@ case class KeyRemoved(key: UnsafeRow) extends StoreUpdate
  * the store is the active instance. Accordingly, it either keeps it loaded and performs
  * maintenance, or unloads the store.
  */
-private[state] object StateStore extends Logging {
+private[sql] object StateStore extends Logging {
 
   val MAINTENANCE_INTERVAL_CONFIG = "spark.streaming.stateStore.maintenanceInterval"
   val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
@@ -155,6 +155,10 @@ private[state] object StateStore extends Logging {
     loadedProviders.contains(storeId)
   }
 
+  def isMaintenanceRunning: Boolean = loadedProviders.synchronized {
+    maintenanceTask != null
+  }
+
   /** Unload and stop all state store providers */
   def stop(): Unit = loadedProviders.synchronized {
     loadedProviders.clear()
@@ -187,44 +191,44 @@ private[state] object StateStore extends Logging {
    */
   private def doMaintenance(): Unit = {
     logDebug("Doing maintenance")
-    loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
-      try {
-        if (verifyIfStoreInstanceActive(id)) {
-          provider.doMaintenance()
-        } else {
-          unload(id)
-          logInfo(s"Unloaded $provider")
+    if (SparkEnv.get == null) {
+      stop()
+    } else {
+      loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
+        try {
+          if (verifyIfStoreInstanceActive(id)) {
+            provider.doMaintenance()
+          } else {
+            unload(id)
+            logInfo(s"Unloaded $provider")
+          }
+        } catch {
+          case NonFatal(e) =>
+            logWarning(s"Error managing $provider, stopping management thread")
+            stop()
         }
-      } catch {
-        case NonFatal(e) =>
-          logWarning(s"Error managing $provider")
       }
     }
   }
 
   private def reportActiveStoreInstance(storeId: StateStoreId): Unit = {
-    try {
+    if (SparkEnv.get != null) {
       val host = SparkEnv.get.blockManager.blockManagerId.host
       val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
       coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId))
       logDebug(s"Reported that the loaded instance $storeId is active")
-    } catch {
-      case NonFatal(e) =>
-        logWarning(s"Error reporting active instance of $storeId")
     }
   }
 
   private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = {
-    try {
+    if (SparkEnv.get != null) {
       val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
       val verified =
         coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false)
-      logDebug(s"Verified whether the loaded instance $storeId is active: $verified" )
+      logDebug(s"Verified whether the loaded instance $storeId is active: $verified")
       verified
-    } catch {
-      case NonFatal(e) =>
-        logWarning(s"Error verifying active instance of $storeId")
-        false
+    } else {
+      false
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index dd23925716..f8f8bc7d6f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -47,8 +47,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   private val keySchema = StructType(Seq(StructField("key", StringType, true)))
   private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
 
+  before {
+    StateStore.stop()
+    require(!StateStore.isMaintenanceRunning)
+  }
+
   after {
     StateStore.stop()
+    require(!StateStore.isMaintenanceRunning)
   }
 
   test("get, put, remove, commit, and all data iterator") {
@@ -352,7 +358,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     }
   }
 
-  ignore("maintenance") {
+  test("maintenance") {
     val conf = new SparkConf()
       .setMaster("local")
       .setAppName("test")
@@ -366,20 +372,26 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     val provider = new HDFSBackedStateStoreProvider(
       storeId, keySchema, valueSchema, storeConf, hadoopConf)
 
+    quietly {
     withSpark(new SparkContext(conf)) { sc =>
       withCoordinatorRef(sc) { coordinatorRef =>
+        require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running")
+
         for (i <- 1 to 20) {
           val store = StateStore.get(
             storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf)
           put(store, "a", i)
           store.commit()
         }
+
         eventually(timeout(10 seconds)) {
           assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported")
         }
 
         // Background maintenance should clean up and generate snapshots
+        assert(StateStore.isMaintenanceRunning, "Maintenance task is not running")
+
         eventually(timeout(10 seconds)) {
           // Earliest delta file should get cleaned up
           assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
@@ -418,6 +430,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
       require(SparkEnv.get === null)
       eventually(timeout(10 seconds)) {
         assert(!StateStore.isLoaded(storeId))
+        assert(!StateStore.isMaintenanceRunning)
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index bdf40f5cd4..8da7742ffe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.sql.streaming
 
+import org.scalatest.BeforeAndAfterAll
+
 import org.apache.spark.SparkException
 import org.apache.spark.sql.StreamTest
 import org.apache.spark.sql.catalyst.analysis.Update
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.expressions.scala.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -29,7 +32,12 @@ object FailureSinglton {
   var firstTime = true
 }
 
-class StreamingAggregationSuite extends StreamTest with SharedSQLContext {
+class StreamingAggregationSuite extends StreamTest with SharedSQLContext with BeforeAndAfterAll {
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    StateStore.stop()
+  }
 
   import testImplicits._
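
The heart of the patch is that the background maintenance task now detects a dead SparkEnv and shuts itself down, and the new StateStore.isMaintenanceRunning flag lets tests assert on that. Below is a minimal, self-contained sketch of this pattern using a plain JDK scheduler; FakeEnv, MaintenanceSketch, and startIfNeeded are illustrative names invented for the sketch, not Spark APIs.

    import java.util.concurrent.{Executors, ScheduledFuture, TimeUnit}

    // Stand-in for SparkEnv.get, which returns null once the SparkContext
    // (and with it the executor environment) has been stopped.
    object FakeEnv {
      @volatile var env: AnyRef = new Object
      def get: AnyRef = env
    }

    object MaintenanceSketch {
      private val scheduler = Executors.newSingleThreadScheduledExecutor()
      private var task: ScheduledFuture[_] = null

      // Mirrors StateStore.isMaintenanceRunning: a non-null task handle
      // doubles as the "background thread is alive" flag.
      def isMaintenanceRunning: Boolean = synchronized { task != null }

      def startIfNeeded(): Unit = synchronized {
        if (task == null) {
          val runnable = new Runnable { def run(): Unit = doMaintenance() }
          task = scheduler.scheduleAtFixedRate(runnable, 0L, 1L, TimeUnit.SECONDS)
        }
      }

      def stop(): Unit = synchronized {
        if (task != null) {
          task.cancel(false) // cancel future runs, let an in-flight run finish
          task = null
        }
      }

      private def doMaintenance(): Unit = {
        // The guard added by this patch: if the environment is gone, stop
        // the periodic task instead of failing (and logging) forever.
        if (FakeEnv.get == null) {
          stop()
        } else {
          // ... verify each loaded store is still active; snapshot or unload ...
        }
      }
    }

Under these assumptions, calling MaintenanceSketch.startIfNeeded() and then setting FakeEnv.env = null makes isMaintenanceRunning flip to false within one period. That self-termination is essentially what the un-ignored "maintenance" test checks against the real StateStore, and why the new before/after hooks can require that no maintenance thread leaks from one test into the next.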