path: root/sql/core
author     Tathagata Das <tathagata.das1565@gmail.com>    2016-03-23 12:48:05 -0700
committer  Tathagata Das <tathagata.das1565@gmail.com>    2016-03-23 12:48:05 -0700
commit     8c826880f5eaa3221c4e9e7d3fece54e821a0b98 (patch)
tree       b6dbe3670844bac231b787ccd9a97d2797f0a181 /sql/core
parent     0a64294fcb4b64bfe095c63c3a494e0f40e22743 (diff)
[SPARK-13809][SQL] State store for streaming aggregations
## What changes were proposed in this pull request?

In this PR, I am implementing a new abstraction for management of streaming state data - State Store. It is a key-value store for persisting running aggregates for aggregate operations in streaming dataframes. The motivation and design are discussed here: https://docs.google.com/document/d/1-ncawFx8JS5Zyfq1HAEGBx56RDet9wfVp_hDM8ZL254/edit#

## How was this patch tested?

- [x] Unit tests
- [x] Cluster tests

**Coverage from unit tests**

<img width="952" alt="screen shot 2016-03-21 at 3 09 40 pm" src="https://cloud.githubusercontent.com/assets/663212/13935872/fdc8ba86-ef76-11e5-93e8-9fa310472c7b.png">

## TODO

- [x] Fix updates() iterator to avoid duplicate updates for same key
- [x] Use Coordinator in ContinuousQueryManager
- [x] Plugging in hadoop conf and other confs
- [x] Unit tests
  - [x] StateStore object lifecycle and methods
  - [x] StateStoreCoordinator communication and logic
  - [x] StateStoreRDD fault-tolerance
  - [x] StateStoreRDD preferred location using StateStoreCoordinator
- [ ] Cluster tests
  - [ ] Whether preferred locations are set correctly
  - [ ] Whether recovery works correctly with distributed storage
- [x] Basic performance tests
- [x] Docs

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #11645 from tdas/state-store.
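For orientation, here is a rough, hypothetical sketch of the low-level `StateStore` API introduced by this patch, pieced together from the scaladocs and test suites in the diff below. The `stringToRow`/`intToRow`/`rowToInt`/`rowToString` helpers are assumed stand-ins for `UnsafeRow` conversion (the real tests define similar helpers in `StateStoreSuite`), and the checkpoint path is made up:

```scala
import org.apache.hadoop.conf.Configuration

import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val keySchema = StructType(Seq(StructField("key", StringType)))
val valueSchema = StructType(Seq(StructField("value", IntegerType)))

// Load (or create) the store for operator 0 / partition 0; version 0 means an empty store.
val storeId = StateStoreId("/tmp/state-checkpoint", operatorId = 0, partitionId = 0)
val store = StateStore.get(
  storeId, keySchema, valueSchema, 0, StateStoreConf.empty, new Configuration)

// Increment a running count for a key; the update function sees the previous value, if any.
// stringToRow / intToRow / rowToInt / rowToString are assumed UnsafeRow conversion helpers.
store.update(stringToRow("a"), oldRow => intToRow(oldRow.map(rowToInt).getOrElse(0) + 1))

// Remove keys that match a condition.
store.remove(key => rowToString(key) == "stale")

// Commit writes the buffered updates to a delta file and returns the new store version.
val newVersion = store.commit()

// After committing, the full contents and this commit's updates can be iterated.
store.iterator().foreach { case (k, v) => println((rowToString(k), rowToInt(v))) }
store.updates().foreach(println)
```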
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala  584
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala  247
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala  37
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala  146
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala  70
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala  75
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala  13
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala  123
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala  192
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala  562
11 files changed, 2052 insertions, 0 deletions
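At the RDD level, the new `package.scala` adds an implicit `mapPartitionWithStateStore` operation that wraps a data RDD in a `StateStoreRDD`. Below is a condensed, hypothetical sketch adapted from `StateStoreRDDSuite`, reusing the schemas and row-conversion helpers assumed in the sketch above, plus a SparkContext `sc`:

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.state._

// The implicit SQLContext supplies the StateStoreConf and the StateStoreCoordinator reference.
implicit val sqlContext: SQLContext = new SQLContext(sc)

// The update function receives the partition's StateStore at the requested version and must
// commit before returning its output iterator (StateStoreRDD asserts store.hasCommitted).
val increment = (store: StateStore, rows: Iterator[String]) => {
  rows.foreach { s =>
    store.update(stringToRow(s), old => intToRow(old.map(rowToInt).getOrElse(0) + 1))
  }
  store.commit()
  store.iterator().map { case (k, v) => (rowToString(k), rowToInt(v)) }
}

// Arguments: update function, checkpoint location, operator id, store version, key/value schemas.
val counts = sc.makeRDD(Seq("a", "b", "a"), 2).mapPartitionWithStateStore(
  increment, "/tmp/state-checkpoint", 0, 0, keySchema, valueSchema)

counts.collect()  // against version 0: counts of ("a" -> 2, "b" -> 1) across the partitions
```

Each executor reports its loaded store instances to the `StateStoreCoordinator`, so later versions of the computation are scheduled preferentially on the executors that already hold the store.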
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
index fa8219bbed..465feeb604 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution}
+import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
import org.apache.spark.sql.util.ContinuousQueryListener
/**
@@ -33,6 +34,8 @@ import org.apache.spark.sql.util.ContinuousQueryListener
@Experimental
class ContinuousQueryManager(sqlContext: SQLContext) {
+ private[sql] val stateStoreCoordinator =
+ StateStoreCoordinatorRef.forDriver(sqlContext.sparkContext.env)
private val listenerBus = new ContinuousQueryListenerBus(sqlContext.sparkContext.listenerBus)
private val activeQueries = new mutable.HashMap[String, ContinuousQuery]
private val activeQueriesLock = new Object
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
new file mode 100644
index 0000000000..ee015baf3f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -0,0 +1,584 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.{DataInputStream, DataOutputStream, IOException}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.util.Random
+import scala.util.control.NonFatal
+
+import com.google.common.io.ByteStreams
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.internal.Logging
+import org.apache.spark.io.LZ4CompressionCodec
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
+
+
+/**
+ * An implementation of [[StateStoreProvider]] and [[StateStore]] in which all the data is backed
+ * by files in an HDFS-compatible file system. All updates to the store have to be done in sets
+ * transactionally, and each set of updates increments the store's version. These versions can
+ * be used to re-execute the updates (by retries in RDD operations) on the correct version of
+ * the store, and regenerate the store version.
+ *
+ * Usage:
+ * To update the data in the state store, the following order of operations is needed.
+ *
+ * - val store = StateStore.get(operatorId, partitionId, version) // to get the right store
+ * - store.update(...)
+ * - store.remove(...)
+ * - store.commit() // commits all the updates made and returns the new version number
+ * - store.iterator() // key-value data after the last commit as an iterator
+ * - store.updates() // updates made in the last commit as an iterator
+ *
+ * Fault-tolerance model:
+ * - Every set of updates is written to a delta file before committing.
+ * - The state store is responsible for managing, collapsing and cleaning up of delta files.
+ * - Multiple attempts to commit the same version of updates may overwrite each other.
+ *   Consistency guarantees depend on whether multiple attempts have the same updates and
+ *   on the overwrite semantics of the underlying file system.
+ * - Background maintenance of files ensures that the last few versions of the store are always
+ *   recoverable, so that re-executed RDD operations can re-apply updates on the correct past
+ *   version of the store.
+ */
+private[state] class HDFSBackedStateStoreProvider(
+ val id: StateStoreId,
+ keySchema: StructType,
+ valueSchema: StructType,
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration
+ ) extends StateStoreProvider with Logging {
+
+ type MapType = java.util.HashMap[UnsafeRow, UnsafeRow]
+
+  /** Implementation of the [[StateStore]] API which is backed by an HDFS-compatible file system */
+ class HDFSBackedStateStore(val version: Long, mapToUpdate: MapType)
+ extends StateStore {
+
+ /** Trait and classes representing the internal state of the store */
+ trait STATE
+ case object UPDATING extends STATE
+ case object COMMITTED extends STATE
+ case object CANCELLED extends STATE
+
+ private val newVersion = version + 1
+ private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}")
+ private val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true))
+
+ private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]()
+
+ @volatile private var state: STATE = UPDATING
+ @volatile private var finalDeltaFile: Path = null
+
+ override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id
+
+ /**
+ * Update the value of a key using the value generated by the update function.
+ * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
+ * versions of the store data.
+ */
+ override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = {
+ verify(state == UPDATING, "Cannot update after already committed or cancelled")
+ val oldValueOption = Option(mapToUpdate.get(key))
+ val value = updateFunc(oldValueOption)
+ mapToUpdate.put(key, value)
+
+ Option(allUpdates.get(key)) match {
+ case Some(ValueAdded(_, _)) =>
+ // Value did not exist in previous version and was added already, keep it marked as added
+ allUpdates.put(key, ValueAdded(key, value))
+ case Some(ValueUpdated(_, _)) | Some(KeyRemoved(_)) =>
+ // Value existed in prev version and updated/removed, mark it as updated
+ allUpdates.put(key, ValueUpdated(key, value))
+ case None =>
+ // There was no prior update, so mark this as added or updated according to its presence
+ // in previous version.
+ val update =
+ if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value)
+ allUpdates.put(key, update)
+ }
+ writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value))
+ }
+
+    /** Remove keys that match the given condition */
+ override def remove(condition: UnsafeRow => Boolean): Unit = {
+ verify(state == UPDATING, "Cannot remove after already committed or cancelled")
+ val keyIter = mapToUpdate.keySet().iterator()
+ while (keyIter.hasNext) {
+ val key = keyIter.next
+ if (condition(key)) {
+ keyIter.remove()
+
+ Option(allUpdates.get(key)) match {
+ case Some(ValueUpdated(_, _)) | None =>
+ // Value existed in previous version and maybe was updated, mark removed
+ allUpdates.put(key, KeyRemoved(key))
+ case Some(ValueAdded(_, _)) =>
+ // Value did not exist in previous version and was added, should not appear in updates
+ allUpdates.remove(key)
+ case Some(KeyRemoved(_)) =>
+ // Remove already in update map, no need to change
+ }
+ writeToDeltaFile(tempDeltaFileStream, KeyRemoved(key))
+ }
+ }
+ }
+
+ /** Commit all the updates that have been made to the store, and return the new version. */
+ override def commit(): Long = {
+ verify(state == UPDATING, "Cannot commit again after already committed or cancelled")
+
+ try {
+ finalizeDeltaFile(tempDeltaFileStream)
+ finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile)
+ state = COMMITTED
+ logInfo(s"Committed version $newVersion for $this")
+ newVersion
+ } catch {
+ case NonFatal(e) =>
+ throw new IllegalStateException(
+ s"Error committing version $newVersion into ${HDFSBackedStateStoreProvider.this}", e)
+ }
+ }
+
+ /** Cancel all the updates made on this store. This store will not be usable any more. */
+ override def cancel(): Unit = {
+ state = CANCELLED
+ if (tempDeltaFileStream != null) {
+ tempDeltaFileStream.close()
+ }
+ if (tempDeltaFile != null && fs.exists(tempDeltaFile)) {
+ fs.delete(tempDeltaFile, true)
+ }
+ logInfo("Canceled ")
+ }
+
+ /**
+ * Get an iterator of all the store data. This can be called only after committing the
+ * updates.
+ */
+ override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
+      verify(state == COMMITTED, "Cannot get iterator of store data before committing")
+ HDFSBackedStateStoreProvider.this.iterator(newVersion)
+ }
+
+ /**
+ * Get an iterator of all the updates made to the store in the current version.
+ * This can be called only after committing the updates.
+ */
+ override def updates(): Iterator[StoreUpdate] = {
+ verify(state == COMMITTED, "Cannot get iterator of updates before committing")
+ allUpdates.values().asScala.toIterator
+ }
+
+ /**
+ * Whether all updates have been committed
+ */
+ override def hasCommitted: Boolean = {
+ state == COMMITTED
+ }
+ }
+
+ /** Get the state store for making updates to create a new `version` of the store. */
+ override def getStore(version: Long): StateStore = synchronized {
+ require(version >= 0, "Version cannot be less than 0")
+ val newMap = new MapType()
+ if (version > 0) {
+ newMap.putAll(loadMap(version))
+ }
+ val store = new HDFSBackedStateStore(version, newMap)
+ logInfo(s"Retrieved version $version of $this for update")
+ store
+ }
+
+  /** Do maintenance on the backing data files, including creating snapshots and cleaning up old files */
+ override def doMaintenance(): Unit = {
+ try {
+ doSnapshot()
+ cleanup()
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error performing snapshot and cleaning up $this")
+ }
+ }
+
+ override def toString(): String = {
+ s"StateStore[id = (op=${id.operatorId},part=${id.partitionId}), dir = $baseDir]"
+ }
+
+ /* Internal classes and methods */
+
+ private val loadedMaps = new mutable.HashMap[Long, MapType]
+ private val baseDir =
+ new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}")
+ private val fs = baseDir.getFileSystem(hadoopConf)
+ private val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
+
+ initialize()
+
+ private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean)
+
+ /** Commit a set of updates to the store with the given new version */
+ private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = {
+ synchronized {
+ val finalDeltaFile = deltaFile(newVersion)
+ fs.rename(tempDeltaFile, finalDeltaFile)
+ loadedMaps.put(newVersion, map)
+ finalDeltaFile
+ }
+ }
+
+ /**
+ * Get iterator of all the data of the latest version of the store.
+   * Note that this will look up the files to determine the latest known version.
+ */
+ private[state] def latestIterator(): Iterator[(UnsafeRow, UnsafeRow)] = synchronized {
+ val versionsInFiles = fetchFiles().map(_.version).toSet
+ val versionsLoaded = loadedMaps.keySet
+ val allKnownVersions = versionsInFiles ++ versionsLoaded
+ if (allKnownVersions.nonEmpty) {
+ loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { x =>
+ (x.getKey, x.getValue)
+ }
+ } else Iterator.empty
+ }
+
+ /** Get iterator of a specific version of the store */
+ private[state] def iterator(version: Long): Iterator[(UnsafeRow, UnsafeRow)] = synchronized {
+ loadMap(version).entrySet().iterator().asScala.map { x =>
+ (x.getKey, x.getValue)
+ }
+ }
+
+ /** Initialize the store provider */
+ private def initialize(): Unit = {
+ if (!fs.exists(baseDir)) {
+ fs.mkdirs(baseDir)
+ } else {
+ if (!fs.isDirectory(baseDir)) {
+ throw new IllegalStateException(
+ s"Cannot use ${id.checkpointLocation} for storing state data for $this as" +
+ s"$baseDir already exists and is not a directory")
+ }
+ }
+ }
+
+ /** Load the required version of the map data from the backing files */
+ private def loadMap(version: Long): MapType = {
+ if (version <= 0) return new MapType
+ synchronized { loadedMaps.get(version) }.getOrElse {
+ val mapFromFile = readSnapshotFile(version).getOrElse {
+ val prevMap = loadMap(version - 1)
+ val newMap = new MapType(prevMap)
+ newMap.putAll(prevMap)
+ updateFromDeltaFile(version, newMap)
+ newMap
+ }
+ loadedMaps.put(version, mapFromFile)
+ mapFromFile
+ }
+ }
+
+ private def writeToDeltaFile(output: DataOutputStream, update: StoreUpdate): Unit = {
+
+ def writeUpdate(key: UnsafeRow, value: UnsafeRow): Unit = {
+ val keyBytes = key.getBytes()
+ val valueBytes = value.getBytes()
+ output.writeInt(keyBytes.size)
+ output.write(keyBytes)
+ output.writeInt(valueBytes.size)
+ output.write(valueBytes)
+ }
+
+ def writeRemove(key: UnsafeRow): Unit = {
+ val keyBytes = key.getBytes()
+ output.writeInt(keyBytes.size)
+ output.write(keyBytes)
+ output.writeInt(-1)
+ }
+
+ update match {
+ case ValueAdded(key, value) =>
+ writeUpdate(key, value)
+ case ValueUpdated(key, value) =>
+ writeUpdate(key, value)
+ case KeyRemoved(key) =>
+ writeRemove(key)
+ }
+ }
+
+ private def finalizeDeltaFile(output: DataOutputStream): Unit = {
+ output.writeInt(-1) // Write this magic number to signify end of file
+ output.close()
+ }
+
+ private def updateFromDeltaFile(version: Long, map: MapType): Unit = {
+ val fileToRead = deltaFile(version)
+ if (!fs.exists(fileToRead)) {
+ throw new IllegalStateException(
+ s"Error reading delta file $fileToRead of $this: $fileToRead does not exist")
+ }
+ var input: DataInputStream = null
+ try {
+ input = decompressStream(fs.open(fileToRead))
+ var eof = false
+
+ while(!eof) {
+ val keySize = input.readInt()
+ if (keySize == -1) {
+ eof = true
+ } else if (keySize < 0) {
+ throw new IOException(
+ s"Error reading delta file $fileToRead of $this: key size cannot be $keySize")
+ } else {
+ val keyRowBuffer = new Array[Byte](keySize)
+ ByteStreams.readFully(input, keyRowBuffer, 0, keySize)
+
+ val keyRow = new UnsafeRow(keySchema.fields.length)
+ keyRow.pointTo(keyRowBuffer, keySize)
+
+ val valueSize = input.readInt()
+ if (valueSize < 0) {
+ map.remove(keyRow)
+ } else {
+ val valueRowBuffer = new Array[Byte](valueSize)
+ ByteStreams.readFully(input, valueRowBuffer, 0, valueSize)
+ val valueRow = new UnsafeRow(valueSchema.fields.length)
+ valueRow.pointTo(valueRowBuffer, valueSize)
+ map.put(keyRow, valueRow)
+ }
+ }
+ }
+ } finally {
+ if (input != null) input.close()
+ }
+ logInfo(s"Read delta file for version $version of $this from $fileToRead")
+ }
+
+ private def writeSnapshotFile(version: Long, map: MapType): Unit = {
+ val fileToWrite = snapshotFile(version)
+ var output: DataOutputStream = null
+ Utils.tryWithSafeFinally {
+ output = compressStream(fs.create(fileToWrite, false))
+ val iter = map.entrySet().iterator()
+ while(iter.hasNext) {
+ val entry = iter.next()
+ val keyBytes = entry.getKey.getBytes()
+ val valueBytes = entry.getValue.getBytes()
+ output.writeInt(keyBytes.size)
+ output.write(keyBytes)
+ output.writeInt(valueBytes.size)
+ output.write(valueBytes)
+ }
+ output.writeInt(-1)
+ } {
+ if (output != null) output.close()
+ }
+ logInfo(s"Written snapshot file for version $version of $this at $fileToWrite")
+ }
+
+ private def readSnapshotFile(version: Long): Option[MapType] = {
+ val fileToRead = snapshotFile(version)
+ if (!fs.exists(fileToRead)) return None
+
+ val map = new MapType()
+ var input: DataInputStream = null
+
+ try {
+ input = decompressStream(fs.open(fileToRead))
+ var eof = false
+
+ while (!eof) {
+ val keySize = input.readInt()
+ if (keySize == -1) {
+ eof = true
+ } else if (keySize < 0) {
+ throw new IOException(
+ s"Error reading snapshot file $fileToRead of $this: key size cannot be $keySize")
+ } else {
+ val keyRowBuffer = new Array[Byte](keySize)
+ ByteStreams.readFully(input, keyRowBuffer, 0, keySize)
+
+ val keyRow = new UnsafeRow(keySchema.fields.length)
+ keyRow.pointTo(keyRowBuffer, keySize)
+
+ val valueSize = input.readInt()
+ if (valueSize < 0) {
+ throw new IOException(
+ s"Error reading snapshot file $fileToRead of $this: value size cannot be $valueSize")
+ } else {
+ val valueRowBuffer = new Array[Byte](valueSize)
+ ByteStreams.readFully(input, valueRowBuffer, 0, valueSize)
+ val valueRow = new UnsafeRow(valueSchema.fields.length)
+ valueRow.pointTo(valueRowBuffer, valueSize)
+ map.put(keyRow, valueRow)
+ }
+ }
+ }
+ logInfo(s"Read snapshot file for version $version of $this from $fileToRead")
+ Some(map)
+ } finally {
+ if (input != null) input.close()
+ }
+ }
+
+
+ /** Perform a snapshot of the store to allow delta files to be consolidated */
+ private def doSnapshot(): Unit = {
+ try {
+ val files = fetchFiles()
+ if (files.nonEmpty) {
+ val lastVersion = files.last.version
+ val deltaFilesForLastVersion =
+ filesForVersion(files, lastVersion).filter(_.isSnapshot == false)
+ synchronized { loadedMaps.get(lastVersion) } match {
+ case Some(map) =>
+ if (deltaFilesForLastVersion.size > storeConf.maxDeltasForSnapshot) {
+ writeSnapshotFile(lastVersion, map)
+ }
+ case None =>
+            // The last map is not loaded, probably some other instance is in charge
+ }
+
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error doing snapshots for $this", e)
+ }
+ }
+
+ /**
+ * Clean up old snapshots and delta files that are not needed any more. It ensures that last
+ * few versions of the store can be recovered from the files, so re-executed RDD operations
+ * can re-apply updates on the past versions of the store.
+ */
+ private[state] def cleanup(): Unit = {
+ try {
+ val files = fetchFiles()
+ if (files.nonEmpty) {
+ val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain
+ if (earliestVersionToRetain > 0) {
+ val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head
+ synchronized {
+ val mapsToRemove = loadedMaps.keys.filter(_ < earliestVersionToRetain).toSeq
+ mapsToRemove.foreach(loadedMaps.remove)
+ }
+ files.filter(_.version < earliestFileToRetain.version).foreach { f =>
+ fs.delete(f.path, true)
+ }
+ logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this")
+ }
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error cleaning up files for $this", e)
+ }
+ }
+
+ /** Files needed to recover the given version of the store */
+ private def filesForVersion(allFiles: Seq[StoreFile], version: Long): Seq[StoreFile] = {
+ require(version >= 0)
+ require(allFiles.exists(_.version == version))
+
+ val latestSnapshotFileBeforeVersion = allFiles
+ .filter(_.isSnapshot == true)
+ .takeWhile(_.version <= version)
+ .lastOption
+ val deltaBatchFiles = latestSnapshotFileBeforeVersion match {
+ case Some(snapshotFile) =>
+ val deltaBatchIds = (snapshotFile.version + 1) to version
+
+ val deltaFiles = allFiles.filter { file =>
+ file.version > snapshotFile.version && file.version <= version
+ }
+ verify(
+ deltaFiles.size == version - snapshotFile.version,
+ s"Unexpected list of delta files for version $version for $this: $deltaFiles"
+ )
+ deltaFiles
+
+ case None =>
+ allFiles.takeWhile(_.version <= version)
+ }
+ latestSnapshotFileBeforeVersion.toSeq ++ deltaBatchFiles
+ }
+
+ /** Fetch all the files that back the store */
+ private def fetchFiles(): Seq[StoreFile] = {
+ val files: Seq[FileStatus] = try {
+ fs.listStatus(baseDir)
+ } catch {
+ case _: java.io.FileNotFoundException =>
+ Seq.empty
+ }
+ val versionToFiles = new mutable.HashMap[Long, StoreFile]
+ files.foreach { status =>
+ val path = status.getPath
+ val nameParts = path.getName.split("\\.")
+ if (nameParts.size == 2) {
+ val version = nameParts(0).toLong
+ nameParts(1).toLowerCase match {
+ case "delta" =>
+              // only add the delta file if a snapshot file does not already exist for this version
+ if (!versionToFiles.contains(version)) {
+ versionToFiles.put(version, StoreFile(version, path, isSnapshot = false))
+ }
+ case "snapshot" =>
+ versionToFiles.put(version, StoreFile(version, path, isSnapshot = true))
+ case _ =>
+ logWarning(s"Could not identify file $path for $this")
+ }
+ }
+ }
+ val storeFiles = versionToFiles.values.toSeq.sortBy(_.version)
+ logDebug(s"Current set of files for $this: $storeFiles")
+ storeFiles
+ }
+
+ private def compressStream(outputStream: DataOutputStream): DataOutputStream = {
+ val compressed = new LZ4CompressionCodec(sparkConf).compressedOutputStream(outputStream)
+ new DataOutputStream(compressed)
+ }
+
+ private def decompressStream(inputStream: DataInputStream): DataInputStream = {
+ val compressed = new LZ4CompressionCodec(sparkConf).compressedInputStream(inputStream)
+ new DataInputStream(compressed)
+ }
+
+ private def deltaFile(version: Long): Path = {
+ new Path(baseDir, s"$version.delta")
+ }
+
+ private def snapshotFile(version: Long): Path = {
+ new Path(baseDir, s"$version.snapshot")
+ }
+
+ private def verify(condition: => Boolean, msg: String): Unit = {
+ if (!condition) {
+ throw new IllegalStateException(msg)
+ }
+ }
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
new file mode 100644
index 0000000000..ca5c864d9e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.util.Timer
+import java.util.concurrent.{ScheduledFuture, TimeUnit}
+
+import scala.collection.mutable
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ThreadUtils
+
+
+/** Unique identifier for a [[StateStore]] */
+case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int)
+
+
+/**
+ * Base trait for a versioned key-value store used for streaming aggregations
+ */
+trait StateStore {
+
+ /** Unique identifier of the store */
+ def id: StateStoreId
+
+ /** Version of the data in this store before committing updates. */
+ def version: Long
+
+ /**
+ * Update the value of a key using the value generated by the update function.
+ * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
+ * versions of the store data.
+ */
+ def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit
+
+ /**
+   * Remove keys that match the given condition.
+ */
+ def remove(condition: UnsafeRow => Boolean): Unit
+
+ /**
+ * Commit all the updates that have been made to the store, and return the new version.
+ */
+ def commit(): Long
+
+ /** Cancel all the updates that have been made to the store. */
+ def cancel(): Unit
+
+ /**
+ * Iterator of store data after a set of updates have been committed.
+   * This can be called only after commit() has been called in the current thread.
+ */
+ def iterator(): Iterator[(UnsafeRow, UnsafeRow)]
+
+ /**
+ * Iterator of the updates that have been committed.
+   * This can be called only after commit() has been called in the current thread.
+ */
+ def updates(): Iterator[StoreUpdate]
+
+ /**
+ * Whether all updates have been committed
+ */
+ def hasCommitted: Boolean
+}
+
+
+/** Trait representing a provider of a specific version of a [[StateStore]]. */
+trait StateStoreProvider {
+
+ /** Get the store with the existing version. */
+ def getStore(version: Long): StateStore
+
+ /** Optional method for providers to allow for background maintenance */
+ def doMaintenance(): Unit = { }
+}
+
+
+/** Trait representing updates made to a [[StateStore]]. */
+sealed trait StoreUpdate
+
+case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
+
+case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
+
+case class KeyRemoved(key: UnsafeRow) extends StoreUpdate
+
+
+/**
+ * Companion object to [[StateStore]] that provides helper methods to create and retrieve stores
+ * by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null),
+ * it also runs a periodic background task to do maintenance on the loaded stores. For each
+ * store, it uses the [[StateStoreCoordinator]] to check whether the currently loaded instance of
+ * the store is the active instance. Accordingly, it either keeps it loaded and performs
+ * maintenance, or unloads the store.
+ */
+private[state] object StateStore extends Logging {
+
+ val MAINTENANCE_INTERVAL_CONFIG = "spark.streaming.stateStore.maintenanceInterval"
+ val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
+
+ private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]()
+ private val maintenanceTaskExecutor =
+ ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task")
+
+ @volatile private var maintenanceTask: ScheduledFuture[_] = null
+ @volatile private var _coordRef: StateStoreCoordinatorRef = null
+
+ /** Get or create a store associated with the id. */
+ def get(
+ storeId: StateStoreId,
+ keySchema: StructType,
+ valueSchema: StructType,
+ version: Long,
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration): StateStore = {
+ require(version >= 0)
+ val storeProvider = loadedProviders.synchronized {
+ startMaintenanceIfNeeded()
+ val provider = loadedProviders.getOrElseUpdate(
+ storeId,
+ new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, storeConf, hadoopConf))
+ reportActiveStoreInstance(storeId)
+ provider
+ }
+ storeProvider.getStore(version)
+ }
+
+ /** Unload a state store provider */
+ def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized {
+ loadedProviders.remove(storeId)
+ }
+
+ /** Whether a state store provider is loaded or not */
+ def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized {
+ loadedProviders.contains(storeId)
+ }
+
+ /** Unload and stop all state store providers */
+ def stop(): Unit = loadedProviders.synchronized {
+ loadedProviders.clear()
+ _coordRef = null
+ if (maintenanceTask != null) {
+ maintenanceTask.cancel(false)
+ maintenanceTask = null
+ }
+ logInfo("StateStore stopped")
+ }
+
+ /** Start the periodic maintenance task if not already started and if Spark active */
+ private def startMaintenanceIfNeeded(): Unit = loadedProviders.synchronized {
+ val env = SparkEnv.get
+ if (maintenanceTask == null && env != null) {
+ val periodMs = env.conf.getTimeAsMs(
+ MAINTENANCE_INTERVAL_CONFIG, s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s")
+ val runnable = new Runnable {
+ override def run(): Unit = { doMaintenance() }
+ }
+ maintenanceTask = maintenanceTaskExecutor.scheduleAtFixedRate(
+ runnable, periodMs, periodMs, TimeUnit.MILLISECONDS)
+ logInfo("State Store maintenance task started")
+ }
+ }
+
+ /**
+ * Execute background maintenance task in all the loaded store providers if they are still
+ * the active instances according to the coordinator.
+ */
+ private def doMaintenance(): Unit = {
+ logDebug("Doing maintenance")
+ loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
+ try {
+ if (verifyIfStoreInstanceActive(id)) {
+ provider.doMaintenance()
+ } else {
+ unload(id)
+ logInfo(s"Unloaded $provider")
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error managing $provider")
+ }
+ }
+ }
+
+ private def reportActiveStoreInstance(storeId: StateStoreId): Unit = {
+ try {
+ val host = SparkEnv.get.blockManager.blockManagerId.host
+ val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
+ coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId))
+ logDebug(s"Reported that the loaded instance $storeId is active")
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error reporting active instance of $storeId")
+ }
+ }
+
+ private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = {
+ try {
+ val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
+ val verified =
+ coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false)
+ logDebug(s"Verifyied whether the loaded instance $storeId is active: $verified" )
+ verified
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error verifying active instance of $storeId")
+ false
+ }
+ }
+
+ private def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized {
+ val env = SparkEnv.get
+ if (env != null) {
+ if (_coordRef == null) {
+ _coordRef = StateStoreCoordinatorRef.forExecutor(env)
+ }
+ logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}")
+ Some(_coordRef)
+ } else {
+ _coordRef = null
+ None
+ }
+ }
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
new file mode 100644
index 0000000000..cca22a0af8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.internal.SQLConf
+
+/** A class that contains configuration parameters for [[StateStore]]s. */
+private[state] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
+
+ def this() = this(new SQLConf)
+
+ import SQLConf._
+
+ val maxDeltasForSnapshot = conf.getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
+
+ val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN)
+}
+
+private[state] object StateStoreConf {
+ val empty = new StateStoreConf()
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
new file mode 100644
index 0000000000..5aa0636850
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.util.RpcUtils
+
+/** Trait representing all messages to [[StateStoreCoordinator]] */
+private sealed trait StateStoreCoordinatorMessage extends Serializable
+
+/** Classes representing messages */
+private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String)
+ extends StateStoreCoordinatorMessage
+
+private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String)
+ extends StateStoreCoordinatorMessage
+
+private case class GetLocation(storeId: StateStoreId)
+ extends StateStoreCoordinatorMessage
+
+private case class DeactivateInstances(storeRootLocation: String)
+ extends StateStoreCoordinatorMessage
+
+private object StopCoordinator
+ extends StateStoreCoordinatorMessage
+
+/** Helper object used to create reference to [[StateStoreCoordinator]]. */
+private[sql] object StateStoreCoordinatorRef extends Logging {
+
+ private val endpointName = "StateStoreCoordinator"
+
+ /**
+   * Create a reference to a [[StateStoreCoordinator]]. This can be called from the driver as well
+   * as from the executors.
+ */
+ def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized {
+ try {
+ val coordinator = new StateStoreCoordinator(env.rpcEnv)
+ val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator)
+ logInfo("Registered StateStoreCoordinator endpoint")
+ new StateStoreCoordinatorRef(coordinatorRef)
+ } catch {
+ case e: IllegalArgumentException =>
+ val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv)
+ logDebug("Retrieved existing StateStoreCoordinator endpoint")
+ new StateStoreCoordinatorRef(rpcEndpointRef)
+ }
+ }
+
+ def forExecutor(env: SparkEnv): StateStoreCoordinatorRef = synchronized {
+ val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv)
+ logDebug("Retrieved existing StateStoreCoordinator endpoint")
+ new StateStoreCoordinatorRef(rpcEndpointRef)
+ }
+}
+
+/**
+ * Reference to a [[StateStoreCoordinator]] that can be used to coordinate instances of
+ * [[StateStore]]s across all the executors, and to get their locations for job scheduling.
+ */
+private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
+
+ private[state] def reportActiveInstance(
+ storeId: StateStoreId,
+ host: String,
+ executorId: String): Unit = {
+ rpcEndpointRef.send(ReportActiveInstance(storeId, host, executorId))
+ }
+
+ /** Verify whether the given executor has the active instance of a state store */
+ private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = {
+ rpcEndpointRef.askWithRetry[Boolean](VerifyIfInstanceActive(storeId, executorId))
+ }
+
+ /** Get the location of the state store */
+ private[state] def getLocation(storeId: StateStoreId): Option[String] = {
+ rpcEndpointRef.askWithRetry[Option[String]](GetLocation(storeId))
+ }
+
+  /** Deactivate all instances associated with the given checkpoint location */
+ private[state] def deactivateInstances(storeRootLocation: String): Unit = {
+ rpcEndpointRef.askWithRetry[Boolean](DeactivateInstances(storeRootLocation))
+ }
+
+ private[state] def stop(): Unit = {
+ rpcEndpointRef.askWithRetry[Boolean](StopCoordinator)
+ }
+}
+
+
+/**
+ * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster,
+ * and for getting their locations for job scheduling.
+ */
+private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends RpcEndpoint {
+ private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation]
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case ReportActiveInstance(id, host, executorId) =>
+ instances.put(id, ExecutorCacheTaskLocation(host, executorId))
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case VerifyIfInstanceActive(id, execId) =>
+ val response = instances.get(id) match {
+ case Some(location) => location.executorId == execId
+ case None => false
+ }
+ context.reply(response)
+
+ case GetLocation(id) =>
+ context.reply(instances.get(id).map(_.toString))
+
+ case DeactivateInstances(loc) =>
+ val storeIdsToRemove =
+ instances.keys.filter(_.checkpointLocation == loc).toSeq
+ instances --= storeIdsToRemove
+ context.reply(true)
+
+ case StopCoordinator =>
+ stop() // Stop before replying to ensure that endpoint name has been deregistered
+ context.reply(true)
+ }
+}
+
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
new file mode 100644
index 0000000000..3318660895
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Partition, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.{SerializableConfiguration, Utils}
+
+/**
+ * An RDD that allows computations to be executed against [[StateStore]]s. It
+ * uses the [[StateStoreCoordinator]] to get the locations of loaded state stores and uses them as
+ * preferred locations.
+ */
+class StateStoreRDD[T: ClassTag, U: ClassTag](
+ dataRDD: RDD[T],
+ storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+ checkpointLocation: String,
+ operatorId: Long,
+ storeVersion: Long,
+ keySchema: StructType,
+ valueSchema: StructType,
+ storeConf: StateStoreConf,
+ @transient private val storeCoordinator: Option[StateStoreCoordinatorRef])
+ extends RDD[U](dataRDD) {
+
+ // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
+ private val confBroadcast = dataRDD.context.broadcast(
+ new SerializableConfiguration(dataRDD.context.hadoopConfiguration))
+
+ override protected def getPartitions: Array[Partition] = dataRDD.partitions
+
+ override def getPreferredLocations(partition: Partition): Seq[String] = {
+ val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
+ storeCoordinator.flatMap(_.getLocation(storeId)).toSeq
+ }
+
+ override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
+ var store: StateStore = null
+
+ Utils.tryWithSafeFinally {
+ val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
+ store = StateStore.get(
+ storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
+ val inputIter = dataRDD.iterator(partition, ctxt)
+ val outputIter = storeUpdateFunction(store, inputIter)
+ assert(store.hasCommitted)
+ outputIter
+ } {
+ if (store != null) store.cancel()
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
new file mode 100644
index 0000000000..b249e37921
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.types.StructType
+
+package object state {
+
+ implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) {
+
+ /** Map each partition of a RDD along with data in a [[StateStore]]. */
+ def mapPartitionWithStateStore[U: ClassTag](
+ storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+ checkpointLocation: String,
+ operatorId: Long,
+ storeVersion: Long,
+ keySchema: StructType,
+ valueSchema: StructType
+ )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = {
+
+ mapPartitionWithStateStore(
+ storeUpdateFunction,
+ checkpointLocation,
+ operatorId,
+ storeVersion,
+ keySchema,
+ valueSchema,
+ new StateStoreConf(sqlContext.conf),
+ Some(sqlContext.streams.stateStoreCoordinator))
+ }
+
+ /** Map each partition of a RDD along with data in a [[StateStore]]. */
+ private[state] def mapPartitionWithStateStore[U: ClassTag](
+ storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+ checkpointLocation: String,
+ operatorId: Long,
+ storeVersion: Long,
+ keySchema: StructType,
+ valueSchema: StructType,
+ storeConf: StateStoreConf,
+ storeCoordinator: Option[StateStoreCoordinatorRef]
+ ): StateStoreRDD[T, U] = {
+ val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
+ new StateStoreRDD(
+ dataRDD,
+ cleanedF,
+ checkpointLocation,
+ operatorId,
+ storeVersion,
+ keySchema,
+ valueSchema,
+ storeConf,
+ storeCoordinator)
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index fd1d77f514..863a876afe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -524,6 +524,19 @@ object SQLConf {
doc = "When true, the planner will try to find out duplicated exchanges and re-use them.",
isPublic = false)
+ val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = intConf(
+ "spark.sql.streaming.stateStore.minDeltasForSnapshot",
+ defaultValue = Some(10),
+ doc = "Minimum number of state store delta files that needs to be generated before they " +
+ "consolidated into snapshots.",
+ isPublic = false)
+
+ val STATE_STORE_MIN_VERSIONS_TO_RETAIN = intConf(
+ "spark.sql.streaming.stateStore.minBatchesToRetain",
+ defaultValue = Some(2),
+ doc = "Minimum number of versions of a state store's data to retain after cleaning.",
+ isPublic = false)
+
val CHECKPOINT_LOCATION = stringConf("spark.sql.streaming.checkpointLocation",
defaultValue = None,
doc = "The default location for storing checkpoint data for continuously executing queries.",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
new file mode 100644
index 0000000000..c99c2f505f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite}
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+
+class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
+
+ import StateStoreCoordinatorSuite._
+
+ test("report, verify, getLocation") {
+ withCoordinatorRef(sc) { coordinatorRef =>
+ val id = StateStoreId("x", 0, 0)
+
+ assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false)
+ assert(coordinatorRef.getLocation(id) === None)
+
+ coordinatorRef.reportActiveInstance(id, "hostX", "exec1")
+ eventually(timeout(5 seconds)) {
+ assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === true)
+ assert(
+ coordinatorRef.getLocation(id) ===
+ Some(ExecutorCacheTaskLocation("hostX", "exec1").toString))
+ }
+
+ coordinatorRef.reportActiveInstance(id, "hostX", "exec2")
+
+ eventually(timeout(5 seconds)) {
+ assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false)
+ assert(coordinatorRef.verifyIfInstanceActive(id, "exec2") === true)
+
+ assert(
+ coordinatorRef.getLocation(id) ===
+ Some(ExecutorCacheTaskLocation("hostX", "exec2").toString))
+ }
+ }
+ }
+
+ test("make inactive") {
+ withCoordinatorRef(sc) { coordinatorRef =>
+ val id1 = StateStoreId("x", 0, 0)
+ val id2 = StateStoreId("y", 1, 0)
+ val id3 = StateStoreId("x", 0, 1)
+ val host = "hostX"
+ val exec = "exec1"
+
+ coordinatorRef.reportActiveInstance(id1, host, exec)
+ coordinatorRef.reportActiveInstance(id2, host, exec)
+ coordinatorRef.reportActiveInstance(id3, host, exec)
+
+ eventually(timeout(5 seconds)) {
+ assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === true)
+ assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true)
+ assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true)
+
+ }
+
+ coordinatorRef.deactivateInstances("x")
+
+ assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false)
+ assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true)
+ assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === false)
+
+ assert(coordinatorRef.getLocation(id1) === None)
+ assert(
+ coordinatorRef.getLocation(id2) ===
+ Some(ExecutorCacheTaskLocation(host, exec).toString))
+ assert(coordinatorRef.getLocation(id3) === None)
+
+ coordinatorRef.deactivateInstances("y")
+ assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false)
+ assert(coordinatorRef.getLocation(id2) === None)
+ }
+ }
+
+ test("multiple references have same underlying coordinator") {
+ withCoordinatorRef(sc) { coordRef1 =>
+ val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env)
+
+ val id = StateStoreId("x", 0, 0)
+
+ coordRef1.reportActiveInstance(id, "hostX", "exec1")
+
+ eventually(timeout(5 seconds)) {
+ assert(coordRef2.verifyIfInstanceActive(id, "exec1") === true)
+ assert(
+ coordRef2.getLocation(id) ===
+ Some(ExecutorCacheTaskLocation("hostX", "exec1").toString))
+ }
+ }
+ }
+}
+
+object StateStoreCoordinatorSuite {
+ def withCoordinatorRef(sc: SparkContext)(body: StateStoreCoordinatorRef => Unit): Unit = {
+ var coordinatorRef: StateStoreCoordinatorRef = null
+ try {
+ coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env)
+ body(coordinatorRef)
+ } finally {
+ if (coordinatorRef != null) coordinatorRef.stop()
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
new file mode 100644
index 0000000000..24cec30fa3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.File
+import java.nio.file.Files
+
+import scala.util.Random
+
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.LocalSparkContext._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
+
+ private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
+ private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString
+ private val keySchema = StructType(Seq(StructField("key", StringType, true)))
+ private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+
+ import StateStoreSuite._
+
+ after {
+ StateStore.stop()
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ Utils.deleteRecursively(new File(tempDir))
+ }
+
+ test("versioning and immutability") {
+ quietly {
+ withSpark(new SparkContext(sparkConf)) { sc =>
+        implicit val sqlContext = new SQLContext(sc)
+ val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+ val increment = (store: StateStore, iter: Iterator[String]) => {
+ iter.foreach { s =>
+ store.update(
+ stringToRow(s), oldRow => {
+ val oldValue = oldRow.map(rowToInt).getOrElse(0)
+ intToRow(oldValue + 1)
+ })
+ }
+ store.commit()
+ store.iterator().map(rowsToStringInt)
+ }
+ val opId = 0
+ val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
+ increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+
+ // Generate next version of stores
+ val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
+ increment, path, opId, storeVersion = 1, keySchema, valueSchema)
+ assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+
+ // Make sure the previous RDD still has the same data.
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+ }
+ }
+ }
+
+ test("recovering from files") {
+ quietly {
+ val opId = 0
+ val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+
+ def makeStoreRDD(
+ sc: SparkContext,
+ seq: Seq[String],
+ storeVersion: Int): RDD[(String, Int)] = {
+ implicit val sqlContext = new SQLContext(sc)
+ makeRDD(sc, Seq("a")).mapPartitionWithStateStore(
+ increment, path, opId, storeVersion, keySchema, valueSchema)
+ }
+
+ // Generate RDDs and state store data
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ for (i <- 1 to 20) {
+ require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
+ }
+ }
+
+ // With a new context, try using the earlier state store data
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
+ }
+ }
+ }
+
+ test("preferred locations using StateStoreCoordinator") {
+ quietly {
+ val opId = 0
+ val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ implicit val sqlContext = new SQLContext(sc)
+ val coordinatorRef = sqlContext.streams.stateStoreCoordinator
+ coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1")
+ coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2")
+ assert(
+ coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
+ Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
+
+ val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
+ increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+ require(rdd.partitions.length === 2)
+
+ assert(
+ rdd.preferredLocations(rdd.partitions(0)) ===
+ Seq(ExecutorCacheTaskLocation("host1", "exec1").toString))
+
+ assert(
+ rdd.preferredLocations(rdd.partitions(1)) ===
+ Seq(ExecutorCacheTaskLocation("host2", "exec2").toString))
+
+ rdd.collect()
+ }
+ }
+ }
+
+ test("distributed test") {
+ quietly {
+ withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc =>
+        implicit val sqlContext = new SQLContext(sc)
+ val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+ val increment = (store: StateStore, iter: Iterator[String]) => {
+ iter.foreach { s =>
+ store.update(
+ stringToRow(s), oldRow => {
+ val oldValue = oldRow.map(rowToInt).getOrElse(0)
+ intToRow(oldValue + 1)
+ })
+ }
+ store.commit()
+ store.iterator().map(rowsToStringInt)
+ }
+ val opId = 0
+ val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
+ increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+
+ // Generate next version of stores
+ val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
+ increment, path, opId, storeVersion = 1, keySchema, valueSchema)
+ assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+
+ // Make sure the previous RDD still has the same data.
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+ }
+ }
+ }
+
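+  /** Makes a 2-partition RDD; the groupBy shuffle keeps identical strings in the same partition. */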
+ private def makeRDD(sc: SparkContext, seq: Seq[String]): RDD[String] = {
+ sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2)
+ }
+
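+  /** Increments each input string's count in the store, commits, and returns the contents. */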
+ private val increment = (store: StateStore, iter: Iterator[String]) => {
+ iter.foreach { s =>
+ store.update(
+ stringToRow(s), oldRow => {
+ val oldValue = oldRow.map(rowToInt).getOrElse(0)
+ intToRow(oldValue + 1)
+ })
+ }
+ store.commit()
+ store.iterator().map(rowsToStringInt)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
new file mode 100644
index 0000000000..22b2f4f75d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.File
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite}
+import org.apache.spark.LocalSparkContext._
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
+
+class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester {
+ type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
+
+ import StateStoreCoordinatorSuite._
+ import StateStoreSuite._
+
+ private val tempDir = Utils.createTempDir().toString
+ private val keySchema = StructType(Seq(StructField("key", StringType, true)))
+ private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+
+ after {
+ StateStore.stop()
+ }
+
+ test("update, remove, commit, and all data iterator") {
+ val provider = newStoreProvider()
+
+ // Verify state before starting a new set of updates
+ assert(provider.latestIterator().isEmpty)
+
+ val store = provider.getStore(0)
+ assert(!store.hasCommitted)
+ intercept[IllegalStateException] {
+ store.iterator()
+ }
+ intercept[IllegalStateException] {
+ store.updates()
+ }
+
+ // Verify state after updating
+ update(store, "a", 1)
+ intercept[IllegalStateException] {
+ store.iterator()
+ }
+ intercept[IllegalStateException] {
+ store.updates()
+ }
+ assert(provider.latestIterator().isEmpty)
+
+ // Make updates, commit and then verify state
+ update(store, "b", 2)
+ update(store, "aa", 3)
+ remove(store, _.startsWith("a"))
+ assert(store.commit() === 1)
+
+ assert(store.hasCommitted)
+ assert(rowsToSet(store.iterator()) === Set("b" -> 2))
+ assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2))
+ assert(fileExists(provider, version = 1, isSnapshot = false))
+
+ assert(getDataFromFiles(provider) === Set("b" -> 2))
+
+ // Trying to get newer versions should fail
+ intercept[Exception] {
+ provider.getStore(2)
+ }
+ intercept[Exception] {
+ getDataFromFiles(provider, 2)
+ }
+
+    // New updates to a reloaded store create a new version and do not change the old version
+ val reloadedProvider = new HDFSBackedStateStoreProvider(
+ store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
+ val reloadedStore = reloadedProvider.getStore(1)
+ update(reloadedStore, "c", 4)
+ assert(reloadedStore.commit() === 2)
+ assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
+ assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4))
+ assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2))
+ assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4))
+ }
+
+ test("updates iterator with all combos of updates and removes") {
+ val provider = newStoreProvider()
+ var currentVersion: Int = 0
+ def withStore(body: StateStore => Unit): Unit = {
+ val store = provider.getStore(currentVersion)
+ body(store)
+ currentVersion += 1
+ }
+
+    // New keys should be seen in updates as added values, even if they were updated multiple times
+ withStore { store =>
+ update(store, "a", 1)
+ update(store, "aa", 1)
+ update(store, "aa", 2)
+ store.commit()
+ assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2)))
+ assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2))
+ }
+
+ // Multiple updates to same key should be collapsed in the updates as a single value update
+ // Keys that have not been updated should not appear in the updates
+ withStore { store =>
+ update(store, "a", 4)
+ update(store, "a", 6)
+ store.commit()
+ assert(updatesToSet(store.updates()) === Set(Updated("a", 6)))
+ assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
+ }
+
+ // Keys added, updated and finally removed before commit should not appear in updates
+ withStore { store =>
+ update(store, "b", 4) // Added, finally removed
+ update(store, "bb", 5) // Added, updated, finally removed
+ update(store, "bb", 6)
+ remove(store, _.startsWith("b"))
+ store.commit()
+ assert(updatesToSet(store.updates()) === Set.empty)
+ assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
+ }
+
+ // Removed data should be seen in updates as a key removed
+ // Removed, but re-added data should be seen in updates as a value update
+ withStore { store =>
+ remove(store, _.startsWith("a"))
+ update(store, "a", 10)
+ store.commit()
+ assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa")))
+ assert(rowsToSet(store.iterator()) === Set("a" -> 10))
+ }
+ }
+
+ test("cancel") {
+ val provider = newStoreProvider()
+ val store = provider.getStore(0)
+ update(store, "a", 1)
+ store.commit()
+ assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+
+ // cancelUpdates should not change the data in the files
+ val store1 = provider.getStore(1)
+ update(store1, "b", 1)
+ store1.cancel()
+ assert(getDataFromFiles(provider) === Set("a" -> 1))
+ }
+
+ test("getStore with unexpected versions") {
+ val provider = newStoreProvider()
+
+ intercept[IllegalArgumentException] {
+ provider.getStore(-1)
+ }
+
+    // Prepare some data in the store
+ val store = provider.getStore(0)
+ update(store, "a", 1)
+ assert(store.commit() === 1)
+ assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+
+ intercept[IllegalStateException] {
+ provider.getStore(2)
+ }
+
+ // Update store version with some data
+ val store1 = provider.getStore(1)
+ update(store1, "b", 1)
+ assert(store1.commit() === 2)
+ assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
+ assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1))
+
+ // Overwrite the version with other data
+ val store2 = provider.getStore(1)
+ update(store2, "c", 1)
+ assert(store2.commit() === 2)
+ assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1))
+ assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1))
+ }
+
+ test("snapshotting") {
+ val provider = newStoreProvider(minDeltasForSnapshot = 5)
+
+ var currentVersion = 0
+ def updateVersionTo(targetVersion: Int): Unit = {
+ for (i <- currentVersion + 1 to targetVersion) {
+ val store = provider.getStore(currentVersion)
+ update(store, "a", i)
+ store.commit()
+ currentVersion += 1
+ }
+ require(currentVersion === targetVersion)
+ }
+
+ updateVersionTo(2)
+ require(getDataFromFiles(provider) === Set("a" -> 2))
+ provider.doMaintenance() // should not generate snapshot files
+ assert(getDataFromFiles(provider) === Set("a" -> 2))
+
+ for (i <- 1 to currentVersion) {
+ assert(fileExists(provider, i, isSnapshot = false)) // all delta files present
+ assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present
+ }
+
+ // After version 6, snapshotting should generate one snapshot file
+ updateVersionTo(6)
+ require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly")
+ provider.doMaintenance() // should generate snapshot files
+
+ val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
+ assert(snapshotVersion.nonEmpty, "snapshot file not generated")
+ deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
+ assert(
+ getDataFromFiles(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
+ "snapshotting messed up the data of the snapshotted version")
+ assert(
+ getDataFromFiles(provider) === Set("a" -> 6),
+ "snapshotting messed up the data of the final version")
+
+ // After version 20, snapshotting should generate newer snapshot files
+ updateVersionTo(20)
+ require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly")
+ provider.doMaintenance() // do snapshot
+
+ val latestSnapshotVersion = (0 to 20).filter(version =>
+ fileExists(provider, version, isSnapshot = true)).lastOption
+ assert(latestSnapshotVersion.nonEmpty, "no snapshot file found")
+ assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
+
+ deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
+ assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data")
+ }
+
+ test("cleaning") {
+ val provider = newStoreProvider(minDeltasForSnapshot = 5)
+
+ for (i <- 1 to 20) {
+ val store = provider.getStore(i - 1)
+ update(store, "a", i)
+ store.commit()
+ provider.doMaintenance() // do cleanup
+ }
+ require(
+ rowsToSet(provider.latestIterator()) === Set("a" -> 20),
+ "store not updated correctly")
+
+ assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted
+
+ // last couple of versions should be retrievable
+ assert(getDataFromFiles(provider, 20) === Set("a" -> 20))
+ assert(getDataFromFiles(provider, 19) === Set("a" -> 19))
+ }
+
+ test("corrupted file handling") {
+ val provider = newStoreProvider(minDeltasForSnapshot = 5)
+ for (i <- 1 to 6) {
+ val store = provider.getStore(i - 1)
+ update(store, "a", i)
+ store.commit()
+ provider.doMaintenance() // do cleanup
+ }
+    val snapshotVersion = (0 to 10).find(version =>
+ fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found"))
+
+ // Corrupt snapshot file and verify that it throws error
+ assert(getDataFromFiles(provider, snapshotVersion) === Set("a" -> snapshotVersion))
+ corruptFile(provider, snapshotVersion, isSnapshot = true)
+ intercept[Exception] {
+ getDataFromFiles(provider, snapshotVersion)
+ }
+
+ // Corrupt delta file and verify that it throws error
+ assert(getDataFromFiles(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1)))
+ corruptFile(provider, snapshotVersion - 1, isSnapshot = false)
+ intercept[Exception] {
+ getDataFromFiles(provider, snapshotVersion - 1)
+ }
+
+ // Delete delta file and verify that it throws error
+ deleteFilesEarlierThanVersion(provider, snapshotVersion)
+ intercept[Exception] {
+ getDataFromFiles(provider, snapshotVersion - 1)
+ }
+ }
+
+ test("StateStore.get") {
+ quietly {
+ val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+ val storeId = StateStoreId(dir, 0, 0)
+ val storeConf = StateStoreConf.empty
+ val hadoopConf = new Configuration()
+
+      // Verify that trying to get incorrect versions throws errors
+ intercept[IllegalArgumentException] {
+ StateStore.get(storeId, keySchema, valueSchema, -1, storeConf, hadoopConf)
+ }
+ assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store
+
+ intercept[IllegalStateException] {
+ StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+ }
+
+ // Increase version of the store
+ val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
+ assert(store0.version === 0)
+ update(store0, "a", 1)
+ store0.commit()
+
+ assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1)
+ assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0)
+
+ // Verify that you can remove the store and still reload and use it
+ StateStore.unload(storeId)
+ assert(!StateStore.isLoaded(storeId))
+
+ val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+ assert(StateStore.isLoaded(storeId))
+ update(store1, "a", 2)
+ assert(store1.commit() === 2)
+ assert(rowsToSet(store1.iterator()) === Set("a" -> 2))
+ }
+ }
+
+ test("maintenance") {
+ val conf = new SparkConf()
+ .setMaster("local")
+ .setAppName("test")
+ .set(StateStore.MAINTENANCE_INTERVAL_CONFIG, "10ms")
+ .set("spark.rpc.numRetries", "1")
+ val opId = 0
+ val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+ val storeId = StateStoreId(dir, opId, 0)
+ val storeConf = StateStoreConf.empty
+ val hadoopConf = new Configuration()
+ val provider = new HDFSBackedStateStoreProvider(
+ storeId, keySchema, valueSchema, storeConf, hadoopConf)
+
+ quietly {
+ withSpark(new SparkContext(conf)) { sc =>
+ withCoordinatorRef(sc) { coordinatorRef =>
+ for (i <- 1 to 20) {
+ val store = StateStore.get(
+ storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf)
+ update(store, "a", i)
+ store.commit()
+ }
+ eventually(timeout(10 seconds)) {
+ assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported")
+ }
+
+ // Background maintenance should clean up and generate snapshots
+ eventually(timeout(10 seconds)) {
+ // Earliest delta file should get cleaned up
+ assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
+
+ // Some snapshots should have been generated
+ val snapshotVersions = (0 to 20).filter { version =>
+ fileExists(provider, version, isSnapshot = true)
+ }
+ assert(snapshotVersions.nonEmpty, "no snapshot file found")
+ }
+
+ // If driver decides to deactivate all instances of the store, then this instance
+ // should be unloaded
+ coordinatorRef.deactivateInstances(dir)
+ eventually(timeout(10 seconds)) {
+ assert(!StateStore.isLoaded(storeId))
+ }
+
+ // Reload the store and verify
+ StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf)
+ assert(StateStore.isLoaded(storeId))
+
+ // If some other executor loads the store, then this instance should be unloaded
+ coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec")
+ eventually(timeout(10 seconds)) {
+ assert(!StateStore.isLoaded(storeId))
+ }
+
+ // Reload the store and verify
+ StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf)
+ assert(StateStore.isLoaded(storeId))
+ }
+ }
+
+    // Verify that the instance is unloaded once the SparkContext is stopped
+ require(SparkEnv.get === null)
+ eventually(timeout(10 seconds)) {
+ assert(!StateStore.isLoaded(storeId))
+ }
+ }
+ }
+
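+  /** Reloads the data of a version (latest if negative) from files using a fresh provider. */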
+ def getDataFromFiles(
+ provider: HDFSBackedStateStoreProvider,
+ version: Int = -1): Set[(String, Int)] = {
+ val reloadedProvider = new HDFSBackedStateStoreProvider(
+ provider.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
+ if (version < 0) {
+ reloadedProvider.latestIterator().map(rowsToStringInt).toSet
+ } else {
+ reloadedProvider.iterator(version).map(rowsToStringInt).toSet
+ }
+ }
+
+ def assertMap(
+ testMapOption: Option[MapType],
+ expectedMap: Map[String, Int]): Unit = {
+ assert(testMapOption.nonEmpty, "no map present")
+ val convertedMap = testMapOption.get.map(rowsToStringInt)
+ assert(convertedMap === expectedMap)
+ }
+
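+  /** Checks whether a version's delta or snapshot file exists in the provider's directory. */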
+ def fileExists(
+ provider: HDFSBackedStateStoreProvider,
+ version: Long,
+ isSnapshot: Boolean): Boolean = {
+ val method = PrivateMethod[Path]('baseDir)
+ val basePath = provider invokePrivate method()
+ val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta"
+ val filePath = new File(basePath.toString, fileName)
+ filePath.exists
+ }
+
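+  /** Deletes the delta and snapshot files of all versions earlier than the given version. */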
+ def deleteFilesEarlierThanVersion(provider: HDFSBackedStateStoreProvider, version: Long): Unit = {
+ val method = PrivateMethod[Path]('baseDir)
+ val basePath = provider invokePrivate method()
+    for (ver <- 0 until version.toInt) {
+      for (isSnapshot <- Seq(false, true)) {
+        val fileName = if (isSnapshot) s"$ver.snapshot" else s"$ver.delta"
+ val filePath = new File(basePath.toString, fileName)
+ if (filePath.exists) filePath.delete()
+ }
+ }
+ }
+
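+  /** Replaces a version's delta or snapshot file with an empty file to simulate corruption. */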
+ def corruptFile(
+ provider: HDFSBackedStateStoreProvider,
+ version: Long,
+ isSnapshot: Boolean): Unit = {
+ val method = PrivateMethod[Path]('baseDir)
+ val basePath = provider invokePrivate method()
+ val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta"
+ val filePath = new File(basePath.toString, fileName)
+ filePath.delete()
+ filePath.createNewFile()
+ }
+
+ def storeLoaded(storeId: StateStoreId): Boolean = {
+ val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores)
+ val loadedStores = StateStore invokePrivate method()
+ loadedStores.contains(storeId)
+ }
+
+ def unloadStore(storeId: StateStoreId): Boolean = {
+ val method = PrivateMethod('remove)
+ StateStore invokePrivate method(storeId)
+ }
+
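+  /** Creates a provider over a fresh temp directory, with a configurable snapshot threshold. */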
+ def newStoreProvider(
+ opId: Long = Random.nextLong,
+ partition: Int = 0,
+ minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get
+ ): HDFSBackedStateStoreProvider = {
+ val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+ val sqlConf = new SQLConf()
+ sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot)
+ new HDFSBackedStateStoreProvider(
+ StateStoreId(dir, opId, partition),
+ keySchema,
+ valueSchema,
+ new StateStoreConf(sqlConf),
+ new Configuration())
+ }
+
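+  /** Removes all keys whose string form satisfies the condition. */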
+ def remove(store: StateStore, condition: String => Boolean): Unit = {
+ store.remove(row => condition(rowToString(row)))
+ }
+
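+  /** Sets the value of a key, overwriting any existing value. */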
+ private def update(store: StateStore, key: String, value: Int): Unit = {
+ store.update(stringToRow(key), _ => intToRow(value))
+ }
+}
+
+private[state] object StateStoreSuite {
+
+ /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */
+ trait TestUpdate
+ case class Added(key: String, value: Int) extends TestUpdate
+ case class Updated(key: String, value: Int) extends TestUpdate
+ case class Removed(key: String) extends TestUpdate
+
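+  // Projections for converting Strings and Ints into single-column UnsafeRows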
+ val strProj = UnsafeProjection.create(Array[DataType](StringType))
+ val intProj = UnsafeProjection.create(Array[DataType](IntegerType))
+
+ def stringToRow(s: String): UnsafeRow = {
+ strProj.apply(new GenericInternalRow(Array[Any](UTF8String.fromString(s)))).copy()
+ }
+
+ def intToRow(i: Int): UnsafeRow = {
+ intProj.apply(new GenericInternalRow(Array[Any](i))).copy()
+ }
+
+ def rowToString(row: UnsafeRow): String = {
+ row.getUTF8String(0).toString
+ }
+
+ def rowToInt(row: UnsafeRow): Int = {
+ row.getInt(0)
+ }
+
+ def rowsToIntInt(row: (UnsafeRow, UnsafeRow)): (Int, Int) = {
+ (rowToInt(row._1), rowToInt(row._2))
+ }
+
+ def rowsToStringInt(row: (UnsafeRow, UnsafeRow)): (String, Int) = {
+ (rowToString(row._1), rowToInt(row._2))
+ }
+
+ def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = {
+ iterator.map(rowsToStringInt).toSet
+ }
+
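+  /** Converts a StoreUpdate iterator into a set of comparable TestUpdate values. */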
+ def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = {
+    iterator.map {
+      case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value))
+      case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value))
+      case KeyRemoved(key) => Removed(rowToString(key))
+    }.toSet
+ }
+}