From 8c826880f5eaa3221c4e9e7d3fece54e821a0b98 Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Wed, 23 Mar 2016 12:48:05 -0700
Subject: [SPARK-13809][SQL] State store for streaming aggregations

## What changes were proposed in this pull request?

In this PR, I am implementing a new abstraction for managing streaming state data - the State Store. It is a key-value store for persisting running aggregates for aggregate operations in streaming dataframes. The motivation and design are discussed here.
https://docs.google.com/document/d/1-ncawFx8JS5Zyfq1HAEGBx56RDet9wfVp_hDM8ZL254/edit#

## How was this patch tested?

- [x] Unit tests
- [x] Cluster tests

**Coverage from unit tests**

screen shot 2016-03-21 at 3 09 40 pm

## TODO

- [x] Fix updates() iterator to avoid duplicate updates for same key
- [x] Use Coordinator in ContinuousQueryManager
- [x] Plugging in hadoop conf and other confs
- [x] Unit tests
  - [x] StateStore object lifecycle and methods
  - [x] StateStoreCoordinator communication and logic
  - [x] StateStoreRDD fault-tolerance
  - [x] StateStoreRDD preferred location using StateStoreCoordinator
- [ ] Cluster tests
  - [ ] Whether preferred locations are set correctly
  - [ ] Whether recovery works correctly with distributed storage
- [x] Basic performance tests
- [x] Docs

Author: Tathagata Das

Closes #11645 from tdas/state-store.

---
 .../apache/spark/sql/ContinuousQueryManager.scala  |   3 +
 .../state/HDFSBackedStateStoreProvider.scala       | 584 +++++++++++++++++++++
 .../sql/execution/streaming/state/StateStore.scala | 247 +++++++++
 .../execution/streaming/state/StateStoreConf.scala |  37 ++
 .../streaming/state/StateStoreCoordinator.scala    | 146 ++++++
 .../execution/streaming/state/StateStoreRDD.scala  |  70 +++
 .../sql/execution/streaming/state/package.scala    |  75 +++
 .../org/apache/spark/sql/internal/SQLConf.scala    |  13 +
 .../state/StateStoreCoordinatorSuite.scala         | 123 +++++
 .../streaming/state/StateStoreRDDSuite.scala       | 192 +++++++
 .../streaming/state/StateStoreSuite.scala          | 562 ++++++++++++++++++++
 11 files changed, 2052 insertions(+)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
index fa8219bbed..465feeb604 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution}
+import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
 import org.apache.spark.sql.util.ContinuousQueryListener
 
 /**
@@ -33,6 +34,8 @@ import org.apache.spark.sql.util.ContinuousQueryListener
 @Experimental
 class ContinuousQueryManager(sqlContext: SQLContext) {
 
+  private[sql] val stateStoreCoordinator =
+    StateStoreCoordinatorRef.forDriver(sqlContext.sparkContext.env)
   private val listenerBus = new ContinuousQueryListenerBus(sqlContext.sparkContext.listenerBus)
   private val activeQueries = new mutable.HashMap[String, ContinuousQuery]
   private val activeQueriesLock = new Object

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
new file mode 100644
index 0000000000..ee015baf3f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -0,0 +1,584 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.{DataInputStream, DataOutputStream, IOException}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.util.Random
+import scala.util.control.NonFatal
+
+import com.google.common.io.ByteStreams
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.internal.Logging
+import org.apache.spark.io.LZ4CompressionCodec
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
+
+
+/**
+ * An implementation of [[StateStoreProvider]] and [[StateStore]] in which all the data is backed
+ * by files in an HDFS-compatible file system. All updates to the store have to be done in sets
+ * transactionally, and each set of updates increments the store's version. These versions can
+ * be used to re-execute the updates (by retries in RDD operations) on the correct version of
+ * the store, and regenerate the store version.
+ *
+ * Usage:
+ * To update the data in the state store, the following order of operations is needed.
+ *
+ * - val store = StateStore.get(operatorId, partitionId, version)  // to get the right store
+ * - store.update(...)
+ * - store.remove(...)
+ * - store.commit()    // commits all the updates made, with a new version number
+ * - store.iterator()  // key-value data after the last commit, as an iterator
+ * - store.updates()   // updates made in the last commit, as an iterator
+ *
+ * Fault-tolerance model:
+ * - Every set of updates is written to a delta file before committing.
+ * - The state store is responsible for managing, collapsing and cleaning up delta files.
+ * - Multiple attempts to commit the same version of updates may overwrite each other.
+ *   Consistency guarantees depend on whether multiple attempts have the same updates and
+ *   on the overwrite semantics of the underlying file system.
+ * - Background maintenance of files ensures that the last versions of the store are always
+ *   recoverable, so that re-executed RDD operations re-apply updates on the correct past
+ *   version of the store.
+ */
+private[state] class HDFSBackedStateStoreProvider(
+    val id: StateStoreId,
+    keySchema: StructType,
+    valueSchema: StructType,
+    storeConf: StateStoreConf,
+    hadoopConf: Configuration
+  ) extends StateStoreProvider with Logging {
+
+  type MapType = java.util.HashMap[UnsafeRow, UnsafeRow]
+
+  /** Implementation of the [[StateStore]] API backed by an HDFS-compatible file system */
+  class HDFSBackedStateStore(val version: Long, mapToUpdate: MapType)
+    extends StateStore {
+
+    /** Trait and classes representing the internal state of the store */
+    trait STATE
+    case object UPDATING extends STATE
+    case object COMMITTED extends STATE
+    case object CANCELLED extends STATE
+
+    private val newVersion = version + 1
+    private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}")
+    private val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true))
+
+    private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]()
+
+    @volatile private var state: STATE = UPDATING
+    @volatile private var finalDeltaFile: Path = null
+
+    override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id
+
+    /**
+     * Update the value of a key using the value generated by the update function.
+     * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
+     *       versions of the store data.
+     */
+    override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = {
+      verify(state == UPDATING, "Cannot update after already committed or cancelled")
+      val oldValueOption = Option(mapToUpdate.get(key))
+      val value = updateFunc(oldValueOption)
+      mapToUpdate.put(key, value)
+
+      Option(allUpdates.get(key)) match {
+        case Some(ValueAdded(_, _)) =>
+          // Value did not exist in the previous version and was already added,
+          // keep it marked as added
+          allUpdates.put(key, ValueAdded(key, value))
+        case Some(ValueUpdated(_, _)) | Some(KeyRemoved(_)) =>
+          // Value existed in the previous version and was updated/removed, mark it as updated
+          allUpdates.put(key, ValueUpdated(key, value))
+        case None =>
+          // There was no prior update, so mark this as added or updated according to its
+          // presence in the previous version.
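+          // (For example, the first update of key "a" in a version where "a" was absent
+          // records ValueAdded; a second update of "a" in the same version collapses into a
+          // single ValueAdded carrying the latest value, via the first case above.)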
+          val update =
+            if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value)
+          allUpdates.put(key, update)
+      }
+      writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value))
+    }
+
+    /** Remove keys that match the given condition */
+    override def remove(condition: UnsafeRow => Boolean): Unit = {
+      verify(state == UPDATING, "Cannot remove after already committed or cancelled")
+      val keyIter = mapToUpdate.keySet().iterator()
+      while (keyIter.hasNext) {
+        val key = keyIter.next
+        if (condition(key)) {
+          keyIter.remove()
+
+          Option(allUpdates.get(key)) match {
+            case Some(ValueUpdated(_, _)) | None =>
+              // Value existed in the previous version and may have been updated, mark removed
+              allUpdates.put(key, KeyRemoved(key))
+            case Some(ValueAdded(_, _)) =>
+              // Value did not exist in the previous version and was added,
+              // should not appear in updates
+              allUpdates.remove(key)
+            case Some(KeyRemoved(_)) =>
+              // Removal already in the update map, no need to change
+          }
+          writeToDeltaFile(tempDeltaFileStream, KeyRemoved(key))
+        }
+      }
+    }
+
+    /** Commit all the updates that have been made to the store, and return the new version. */
+    override def commit(): Long = {
+      verify(state == UPDATING, "Cannot commit again after already committed or cancelled")
+
+      try {
+        finalizeDeltaFile(tempDeltaFileStream)
+        finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile)
+        state = COMMITTED
+        logInfo(s"Committed version $newVersion for $this")
+        newVersion
+      } catch {
+        case NonFatal(e) =>
+          throw new IllegalStateException(
+            s"Error committing version $newVersion into ${HDFSBackedStateStoreProvider.this}", e)
+      }
+    }
+
+    /** Cancel all the updates made on this store. This store will not be usable any more. */
+    override def cancel(): Unit = {
+      state = CANCELLED
+      if (tempDeltaFileStream != null) {
+        tempDeltaFileStream.close()
+      }
+      if (tempDeltaFile != null && fs.exists(tempDeltaFile)) {
+        fs.delete(tempDeltaFile, true)
+      }
+      logInfo("Canceled")
+    }
+
+    /**
+     * Get an iterator of all the store data. This can be called only after committing the
+     * updates.
+     */
+    override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
+      verify(state == COMMITTED, "Cannot get iterator of store data before committing")
+      HDFSBackedStateStoreProvider.this.iterator(newVersion)
+    }
+
+    /**
+     * Get an iterator of all the updates made to the store in the current version.
+     * This can be called only after committing the updates.
+     */
+    override def updates(): Iterator[StoreUpdate] = {
+      verify(state == COMMITTED, "Cannot get iterator of updates before committing")
+      allUpdates.values().asScala.toIterator
+    }
+
+    /** Whether all updates have been committed */
+    override def hasCommitted: Boolean = {
+      state == COMMITTED
+    }
+  }
+
+  /** Get the state store for making updates to create a new `version` of the store.
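+   *
+   * A sketch of the expected call pattern (hypothetical rows; loading version `v` and
+   * committing produces version `v + 1`):
+   * {{{
+   *   val store = provider.getStore(v)
+   *   store.update(keyRow, oldValue => newValueRow)
+   *   val newVersion = store.commit()  // returns v + 1
+   * }}}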
+   */
+  override def getStore(version: Long): StateStore = synchronized {
+    require(version >= 0, "Version cannot be less than 0")
+    val newMap = new MapType()
+    if (version > 0) {
+      newMap.putAll(loadMap(version))
+    }
+    val store = new HDFSBackedStateStore(version, newMap)
+    logInfo(s"Retrieved version $version of $this for update")
+    store
+  }
+
+  /** Do maintenance on the backing data files, including creating snapshots and cleaning up old files */
+  override def doMaintenance(): Unit = {
+    try {
+      doSnapshot()
+      cleanup()
+    } catch {
+      case NonFatal(e) =>
+        logWarning(s"Error performing snapshot and cleaning up $this", e)
+    }
+  }
+
+  override def toString(): String = {
+    s"StateStore[id = (op=${id.operatorId},part=${id.partitionId}), dir = $baseDir]"
+  }
+
+  /* Internal classes and methods */
+
+  private val loadedMaps = new mutable.HashMap[Long, MapType]
+  private val baseDir =
+    new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}")
+  private val fs = baseDir.getFileSystem(hadoopConf)
+  private val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
+
+  initialize()
+
+  private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean)
+
+  /** Commit a set of updates to the store with the given new version */
+  private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = {
+    synchronized {
+      val finalDeltaFile = deltaFile(newVersion)
+      fs.rename(tempDeltaFile, finalDeltaFile)
+      loadedMaps.put(newVersion, map)
+      finalDeltaFile
+    }
+  }
+
+  /**
+   * Get an iterator of all the data of the latest version of the store.
+   * Note that this will look up the files to determine the latest known version.
+   */
+  private[state] def latestIterator(): Iterator[(UnsafeRow, UnsafeRow)] = synchronized {
+    val versionsInFiles = fetchFiles().map(_.version).toSet
+    val versionsLoaded = loadedMaps.keySet
+    val allKnownVersions = versionsInFiles ++ versionsLoaded
+    if (allKnownVersions.nonEmpty) {
+      loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { x =>
+        (x.getKey, x.getValue)
+      }
+    } else Iterator.empty
+  }
+
+  /** Get an iterator of a specific version of the store */
+  private[state] def iterator(version: Long): Iterator[(UnsafeRow, UnsafeRow)] = synchronized {
+    loadMap(version).entrySet().iterator().asScala.map { x =>
+      (x.getKey, x.getValue)
+    }
+  }
+
+  /** Initialize the store provider */
+  private def initialize(): Unit = {
+    if (!fs.exists(baseDir)) {
+      fs.mkdirs(baseDir)
+    } else {
+      if (!fs.isDirectory(baseDir)) {
+        throw new IllegalStateException(
+          s"Cannot use ${id.checkpointLocation} for storing state data for $this as " +
+            s"$baseDir already exists and is not a directory")
+      }
+    }
+  }
+
+  /** Load the required version of the map data from the backing files */
+  private def loadMap(version: Long): MapType = {
+    if (version <= 0) return new MapType
+    synchronized { loadedMaps.get(version) }.getOrElse {
+      val mapFromFile = readSnapshotFile(version).getOrElse {
+        // Copy the previous version's map, then apply this version's delta on top of it
+        val prevMap = loadMap(version - 1)
+        val newMap = new MapType(prevMap)
+        updateFromDeltaFile(version, newMap)
+        newMap
+      }
+      loadedMaps.put(version, mapFromFile)
+      mapFromFile
+    }
+  }
+
+  private def writeToDeltaFile(output: DataOutputStream, update: StoreUpdate): Unit = {
+
+    def writeUpdate(key: UnsafeRow, value: UnsafeRow): Unit = {
+      val keyBytes = key.getBytes()
+      val valueBytes = value.getBytes()
+      output.writeInt(keyBytes.size)
+      output.write(keyBytes)
+      output.writeInt(valueBytes.size)
+      output.write(valueBytes)
+    }
+
+    def
writeRemove(key: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(-1) + } + + update match { + case ValueAdded(key, value) => + writeUpdate(key, value) + case ValueUpdated(key, value) => + writeUpdate(key, value) + case KeyRemoved(key) => + writeRemove(key) + } + } + + private def finalizeDeltaFile(output: DataOutputStream): Unit = { + output.writeInt(-1) // Write this magic number to signify end of file + output.close() + } + + private def updateFromDeltaFile(version: Long, map: MapType): Unit = { + val fileToRead = deltaFile(version) + if (!fs.exists(fileToRead)) { + throw new IllegalStateException( + s"Error reading delta file $fileToRead of $this: $fileToRead does not exist") + } + var input: DataInputStream = null + try { + input = decompressStream(fs.open(fileToRead)) + var eof = false + + while(!eof) { + val keySize = input.readInt() + if (keySize == -1) { + eof = true + } else if (keySize < 0) { + throw new IOException( + s"Error reading delta file $fileToRead of $this: key size cannot be $keySize") + } else { + val keyRowBuffer = new Array[Byte](keySize) + ByteStreams.readFully(input, keyRowBuffer, 0, keySize) + + val keyRow = new UnsafeRow(keySchema.fields.length) + keyRow.pointTo(keyRowBuffer, keySize) + + val valueSize = input.readInt() + if (valueSize < 0) { + map.remove(keyRow) + } else { + val valueRowBuffer = new Array[Byte](valueSize) + ByteStreams.readFully(input, valueRowBuffer, 0, valueSize) + val valueRow = new UnsafeRow(valueSchema.fields.length) + valueRow.pointTo(valueRowBuffer, valueSize) + map.put(keyRow, valueRow) + } + } + } + } finally { + if (input != null) input.close() + } + logInfo(s"Read delta file for version $version of $this from $fileToRead") + } + + private def writeSnapshotFile(version: Long, map: MapType): Unit = { + val fileToWrite = snapshotFile(version) + var output: DataOutputStream = null + Utils.tryWithSafeFinally { + output = compressStream(fs.create(fileToWrite, false)) + val iter = map.entrySet().iterator() + while(iter.hasNext) { + val entry = iter.next() + val keyBytes = entry.getKey.getBytes() + val valueBytes = entry.getValue.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(valueBytes.size) + output.write(valueBytes) + } + output.writeInt(-1) + } { + if (output != null) output.close() + } + logInfo(s"Written snapshot file for version $version of $this at $fileToWrite") + } + + private def readSnapshotFile(version: Long): Option[MapType] = { + val fileToRead = snapshotFile(version) + if (!fs.exists(fileToRead)) return None + + val map = new MapType() + var input: DataInputStream = null + + try { + input = decompressStream(fs.open(fileToRead)) + var eof = false + + while (!eof) { + val keySize = input.readInt() + if (keySize == -1) { + eof = true + } else if (keySize < 0) { + throw new IOException( + s"Error reading snapshot file $fileToRead of $this: key size cannot be $keySize") + } else { + val keyRowBuffer = new Array[Byte](keySize) + ByteStreams.readFully(input, keyRowBuffer, 0, keySize) + + val keyRow = new UnsafeRow(keySchema.fields.length) + keyRow.pointTo(keyRowBuffer, keySize) + + val valueSize = input.readInt() + if (valueSize < 0) { + throw new IOException( + s"Error reading snapshot file $fileToRead of $this: value size cannot be $valueSize") + } else { + val valueRowBuffer = new Array[Byte](valueSize) + ByteStreams.readFully(input, valueRowBuffer, 0, valueSize) + val valueRow = new 
UnsafeRow(valueSchema.fields.length) + valueRow.pointTo(valueRowBuffer, valueSize) + map.put(keyRow, valueRow) + } + } + } + logInfo(s"Read snapshot file for version $version of $this from $fileToRead") + Some(map) + } finally { + if (input != null) input.close() + } + } + + + /** Perform a snapshot of the store to allow delta files to be consolidated */ + private def doSnapshot(): Unit = { + try { + val files = fetchFiles() + if (files.nonEmpty) { + val lastVersion = files.last.version + val deltaFilesForLastVersion = + filesForVersion(files, lastVersion).filter(_.isSnapshot == false) + synchronized { loadedMaps.get(lastVersion) } match { + case Some(map) => + if (deltaFilesForLastVersion.size > storeConf.maxDeltasForSnapshot) { + writeSnapshotFile(lastVersion, map) + } + case None => + // The last map is not loaded, probably some other instance is incharge + } + + } + } catch { + case NonFatal(e) => + logWarning(s"Error doing snapshots for $this", e) + } + } + + /** + * Clean up old snapshots and delta files that are not needed any more. It ensures that last + * few versions of the store can be recovered from the files, so re-executed RDD operations + * can re-apply updates on the past versions of the store. + */ + private[state] def cleanup(): Unit = { + try { + val files = fetchFiles() + if (files.nonEmpty) { + val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain + if (earliestVersionToRetain > 0) { + val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head + synchronized { + val mapsToRemove = loadedMaps.keys.filter(_ < earliestVersionToRetain).toSeq + mapsToRemove.foreach(loadedMaps.remove) + } + files.filter(_.version < earliestFileToRetain.version).foreach { f => + fs.delete(f.path, true) + } + logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this") + } + } + } catch { + case NonFatal(e) => + logWarning(s"Error cleaning up files for $this", e) + } + } + + /** Files needed to recover the given version of the store */ + private def filesForVersion(allFiles: Seq[StoreFile], version: Long): Seq[StoreFile] = { + require(version >= 0) + require(allFiles.exists(_.version == version)) + + val latestSnapshotFileBeforeVersion = allFiles + .filter(_.isSnapshot == true) + .takeWhile(_.version <= version) + .lastOption + val deltaBatchFiles = latestSnapshotFileBeforeVersion match { + case Some(snapshotFile) => + val deltaBatchIds = (snapshotFile.version + 1) to version + + val deltaFiles = allFiles.filter { file => + file.version > snapshotFile.version && file.version <= version + } + verify( + deltaFiles.size == version - snapshotFile.version, + s"Unexpected list of delta files for version $version for $this: $deltaFiles" + ) + deltaFiles + + case None => + allFiles.takeWhile(_.version <= version) + } + latestSnapshotFileBeforeVersion.toSeq ++ deltaBatchFiles + } + + /** Fetch all the files that back the store */ + private def fetchFiles(): Seq[StoreFile] = { + val files: Seq[FileStatus] = try { + fs.listStatus(baseDir) + } catch { + case _: java.io.FileNotFoundException => + Seq.empty + } + val versionToFiles = new mutable.HashMap[Long, StoreFile] + files.foreach { status => + val path = status.getPath + val nameParts = path.getName.split("\\.") + if (nameParts.size == 2) { + val version = nameParts(0).toLong + nameParts(1).toLowerCase match { + case "delta" => + // ignore the file otherwise, snapshot file already exists for that batch id + if (!versionToFiles.contains(version)) { + versionToFiles.put(version, 
StoreFile(version, path, isSnapshot = false)) + } + case "snapshot" => + versionToFiles.put(version, StoreFile(version, path, isSnapshot = true)) + case _ => + logWarning(s"Could not identify file $path for $this") + } + } + } + val storeFiles = versionToFiles.values.toSeq.sortBy(_.version) + logDebug(s"Current set of files for $this: $storeFiles") + storeFiles + } + + private def compressStream(outputStream: DataOutputStream): DataOutputStream = { + val compressed = new LZ4CompressionCodec(sparkConf).compressedOutputStream(outputStream) + new DataOutputStream(compressed) + } + + private def decompressStream(inputStream: DataInputStream): DataInputStream = { + val compressed = new LZ4CompressionCodec(sparkConf).compressedInputStream(inputStream) + new DataInputStream(compressed) + } + + private def deltaFile(version: Long): Path = { + new Path(baseDir, s"$version.delta") + } + + private def snapshotFile(version: Long): Path = { + new Path(baseDir, s"$version.snapshot") + } + + private def verify(condition: => Boolean, msg: String): Unit = { + if (!condition) { + throw new IllegalStateException(msg) + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala new file mode 100644 index 0000000000..ca5c864d9e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.Timer +import java.util.concurrent.{ScheduledFuture, TimeUnit} + +import scala.collection.mutable +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ThreadUtils + + +/** Unique identifier for a [[StateStore]] */ +case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int) + + +/** + * Base trait for a versioned key-value store used for streaming aggregations + */ +trait StateStore { + + /** Unique identifier of the store */ + def id: StateStoreId + + /** Version of the data in this store before committing updates. */ + def version: Long + + /** + * Update the value of a key using the value generated by the update function. + * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous + * versions of the store data. 
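+   *       If the new value is derived from the old one, copy the old row first; a sketch,
+   *       where `transform` stands for the caller's own function:
+   *       `store.update(key, old => transform(old.map(_.copy())))`
+   *       (`UnsafeRow.copy()` returns a row backed by freshly allocated memory).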
+   */
+  def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit
+
+  /**
+   * Remove keys that match the given condition.
+   */
+  def remove(condition: UnsafeRow => Boolean): Unit
+
+  /**
+   * Commit all the updates that have been made to the store, and return the new version.
+   */
+  def commit(): Long
+
+  /** Cancel all the updates that have been made to the store. */
+  def cancel(): Unit
+
+  /**
+   * Iterator of store data after a set of updates have been committed.
+   * This can be called only after commit() has been called in the current thread.
+   */
+  def iterator(): Iterator[(UnsafeRow, UnsafeRow)]
+
+  /**
+   * Iterator of the updates that have been committed.
+   * This can be called only after commit() has been called in the current thread.
+   */
+  def updates(): Iterator[StoreUpdate]
+
+  /** Whether all updates have been committed */
+  def hasCommitted: Boolean
+}
+
+
+/** Trait representing a provider of a specific version of a [[StateStore]]. */
+trait StateStoreProvider {
+
+  /** Get the store with the existing version. */
+  def getStore(version: Long): StateStore
+
+  /** Optional method for providers to allow for background maintenance */
+  def doMaintenance(): Unit = { }
+}
+
+
+/** Trait representing updates made to a [[StateStore]]. */
+sealed trait StoreUpdate
+
+case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
+
+case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
+
+case class KeyRemoved(key: UnsafeRow) extends StoreUpdate
+
+
+/**
+ * Companion object to [[StateStore]] that provides helper methods to create and retrieve stores
+ * by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null),
+ * it also runs a periodic background task to do maintenance on the loaded stores. For each
+ * store, it uses the [[StateStoreCoordinator]] to check whether the currently loaded instance of
+ * the store is the active instance. Accordingly, it either keeps it loaded and performs
+ * maintenance, or unloads the store.
+ */
+private[state] object StateStore extends Logging {
+
+  val MAINTENANCE_INTERVAL_CONFIG = "spark.streaming.stateStore.maintenanceInterval"
+  val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
+
+  private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]()
+  private val maintenanceTaskExecutor =
+    ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task")
+
+  @volatile private var maintenanceTask: ScheduledFuture[_] = null
+  @volatile private var _coordRef: StateStoreCoordinatorRef = null
+
+  /** Get or create a store associated with the id.
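+   *
+   * A sketch of executor-side usage (hypothetical identifiers; the schemas and confs come
+   * from the query plan being executed):
+   * {{{
+   *   val storeId = StateStoreId(checkpointLocation, operatorId = 0, partitionId = 0)
+   *   val store = StateStore.get(
+   *     storeId, keySchema, valueSchema, version = 5, storeConf, hadoopConf)
+   *   // make updates, then store.commit() to produce version 6
+   * }}}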
*/ + def get( + storeId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + version: Long, + storeConf: StateStoreConf, + hadoopConf: Configuration): StateStore = { + require(version >= 0) + val storeProvider = loadedProviders.synchronized { + startMaintenanceIfNeeded() + val provider = loadedProviders.getOrElseUpdate( + storeId, + new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, storeConf, hadoopConf)) + reportActiveStoreInstance(storeId) + provider + } + storeProvider.getStore(version) + } + + /** Unload a state store provider */ + def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized { + loadedProviders.remove(storeId) + } + + /** Whether a state store provider is loaded or not */ + def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized { + loadedProviders.contains(storeId) + } + + /** Unload and stop all state store providers */ + def stop(): Unit = loadedProviders.synchronized { + loadedProviders.clear() + _coordRef = null + if (maintenanceTask != null) { + maintenanceTask.cancel(false) + maintenanceTask = null + } + logInfo("StateStore stopped") + } + + /** Start the periodic maintenance task if not already started and if Spark active */ + private def startMaintenanceIfNeeded(): Unit = loadedProviders.synchronized { + val env = SparkEnv.get + if (maintenanceTask == null && env != null) { + val periodMs = env.conf.getTimeAsMs( + MAINTENANCE_INTERVAL_CONFIG, s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s") + val runnable = new Runnable { + override def run(): Unit = { doMaintenance() } + } + maintenanceTask = maintenanceTaskExecutor.scheduleAtFixedRate( + runnable, periodMs, periodMs, TimeUnit.MILLISECONDS) + logInfo("State Store maintenance task started") + } + } + + /** + * Execute background maintenance task in all the loaded store providers if they are still + * the active instances according to the coordinator. 
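+   * Providers whose store instance is no longer active on this executor (according to the
+   * coordinator) are unloaded instead of maintained.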
+   */
+  private def doMaintenance(): Unit = {
+    logDebug("Doing maintenance")
+    loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
+      try {
+        if (verifyIfStoreInstanceActive(id)) {
+          provider.doMaintenance()
+        } else {
+          unload(id)
+          logInfo(s"Unloaded $provider")
+        }
+      } catch {
+        case NonFatal(e) =>
+          logWarning(s"Error managing $provider", e)
+      }
+    }
+  }
+
+  private def reportActiveStoreInstance(storeId: StateStoreId): Unit = {
+    try {
+      val host = SparkEnv.get.blockManager.blockManagerId.host
+      val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
+      coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId))
+      logDebug(s"Reported that the loaded instance $storeId is active")
+    } catch {
+      case NonFatal(e) =>
+        logWarning(s"Error reporting active instance of $storeId", e)
+    }
+  }
+
+  private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = {
+    try {
+      val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
+      val verified =
+        coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false)
+      logDebug(s"Verified whether the loaded instance $storeId is active: $verified")
+      verified
+    } catch {
+      case NonFatal(e) =>
+        logWarning(s"Error verifying active instance of $storeId", e)
+        false
+    }
+  }
+
+  private def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized {
+    val env = SparkEnv.get
+    if (env != null) {
+      if (_coordRef == null) {
+        _coordRef = StateStoreCoordinatorRef.forExecutor(env)
+      }
+      logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}")
+      Some(_coordRef)
+    } else {
+      _coordRef = null
+      None
+    }
+  }
+}

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
new file mode 100644
index 0000000000..cca22a0af8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.internal.SQLConf
+
+/** A class that contains configuration parameters for [[StateStore]]s.
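+ * The relevant [[SQLConf]] values are captured into fields when an instance is created, and
+ * the conf itself is transient, so instances can be serialized into executor-side tasks.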
+ */
+private[state] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
+
+  def this() = this(new SQLConf)
+
+  import SQLConf._
+
+  val maxDeltasForSnapshot = conf.getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
+
+  val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN)
+}
+
+private[state] object StateStoreConf {
+  val empty = new StateStoreConf()
+}

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
new file mode 100644
index 0000000000..5aa0636850
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.util.RpcUtils
+
+/** Trait representing all messages to [[StateStoreCoordinator]] */
+private sealed trait StateStoreCoordinatorMessage extends Serializable
+
+/** Classes representing messages */
+private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String)
+  extends StateStoreCoordinatorMessage
+
+private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String)
+  extends StateStoreCoordinatorMessage
+
+private case class GetLocation(storeId: StateStoreId)
+  extends StateStoreCoordinatorMessage
+
+private case class DeactivateInstances(storeRootLocation: String)
+  extends StateStoreCoordinatorMessage
+
+private object StopCoordinator
+  extends StateStoreCoordinatorMessage
+
+/** Helper object used to create a reference to [[StateStoreCoordinator]]. */
+private[sql] object StateStoreCoordinatorRef extends Logging {
+
+  private val endpointName = "StateStoreCoordinator"
+
+  /**
+   * Create a reference to a [[StateStoreCoordinator]]. This can be called from the driver as
+   * well as executors.
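+   * On the driver it tries to register a new coordinator endpoint, falling back to the
+   * already registered one if registration fails; on executors it looks up the driver's
+   * endpoint.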
+   */
+  def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized {
+    try {
+      val coordinator = new StateStoreCoordinator(env.rpcEnv)
+      val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator)
+      logInfo("Registered StateStoreCoordinator endpoint")
+      new StateStoreCoordinatorRef(coordinatorRef)
+    } catch {
+      case e: IllegalArgumentException =>
+        val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv)
+        logDebug("Retrieved existing StateStoreCoordinator endpoint")
+        new StateStoreCoordinatorRef(rpcEndpointRef)
+    }
+  }
+
+  def forExecutor(env: SparkEnv): StateStoreCoordinatorRef = synchronized {
+    val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv)
+    logDebug("Retrieved existing StateStoreCoordinator endpoint")
+    new StateStoreCoordinatorRef(rpcEndpointRef)
+  }
+}

+/**
+ * Reference to a [[StateStoreCoordinator]] that can be used to coordinate instances of
+ * [[StateStore]]s across all the executors, and to get their locations for job scheduling.
+ */
+private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
+
+  private[state] def reportActiveInstance(
+      storeId: StateStoreId,
+      host: String,
+      executorId: String): Unit = {
+    rpcEndpointRef.send(ReportActiveInstance(storeId, host, executorId))
+  }
+
+  /** Verify whether the given executor has the active instance of a state store */
+  private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = {
+    rpcEndpointRef.askWithRetry[Boolean](VerifyIfInstanceActive(storeId, executorId))
+  }
+
+  /** Get the location of the state store */
+  private[state] def getLocation(storeId: StateStoreId): Option[String] = {
+    rpcEndpointRef.askWithRetry[Option[String]](GetLocation(storeId))
+  }
+
+  /** Deactivate all instances related to a checkpoint location */
+  private[state] def deactivateInstances(storeRootLocation: String): Unit = {
+    rpcEndpointRef.askWithRetry[Boolean](DeactivateInstances(storeRootLocation))
+  }
+
+  private[state] def stop(): Unit = {
+    rpcEndpointRef.askWithRetry[Boolean](StopCoordinator)
+  }
+}
+
+
+/**
+ * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster,
+ * and for getting their locations for job scheduling.
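+ *
+ * It keeps a simple in-memory map from [[StateStoreId]] to the executor last reported as
+ * hosting that store's active instance, and answers [[VerifyIfInstanceActive]] and
+ * [[GetLocation]] queries from that map.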
+ */ +private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends RpcEndpoint { + private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] + + override def receive: PartialFunction[Any, Unit] = { + case ReportActiveInstance(id, host, executorId) => + instances.put(id, ExecutorCacheTaskLocation(host, executorId)) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case VerifyIfInstanceActive(id, execId) => + val response = instances.get(id) match { + case Some(location) => location.executorId == execId + case None => false + } + context.reply(response) + + case GetLocation(id) => + context.reply(instances.get(id).map(_.toString)) + + case DeactivateInstances(loc) => + val storeIdsToRemove = + instances.keys.filter(_.checkpointLocation == loc).toSeq + instances --= storeIdsToRemove + context.reply(true) + + case StopCoordinator => + stop() // Stop before replying to ensure that endpoint name has been deregistered + context.reply(true) + } +} + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala new file mode 100644 index 0000000000..3318660895 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{SerializableConfiguration, Utils} + +/** + * An RDD that allows computations to be executed against [[StateStore]]s. It + * uses the [[StateStoreCoordinator]] to use the locations of loaded state stores as + * preferred locations. 
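+ *
+ * Callers are expected to construct this through the `mapPartitionWithStateStore` implicit
+ * on RDDs (see the `state` package object) rather than directly; a sketch:
+ * {{{
+ *   dataRDD.mapPartitionWithStateStore(
+ *     storeUpdateFunction, checkpointLocation, operatorId = 0, storeVersion = 0,
+ *     keySchema, valueSchema)
+ * }}}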
+ */ +class StateStoreRDD[T: ClassTag, U: ClassTag]( + dataRDD: RDD[T], + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + checkpointLocation: String, + operatorId: Long, + storeVersion: Long, + keySchema: StructType, + valueSchema: StructType, + storeConf: StateStoreConf, + @transient private val storeCoordinator: Option[StateStoreCoordinatorRef]) + extends RDD[U](dataRDD) { + + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it + private val confBroadcast = dataRDD.context.broadcast( + new SerializableConfiguration(dataRDD.context.hadoopConfiguration)) + + override protected def getPartitions: Array[Partition] = dataRDD.partitions + + override def getPreferredLocations(partition: Partition): Seq[String] = { + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + storeCoordinator.flatMap(_.getLocation(storeId)).toSeq + } + + override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { + var store: StateStore = null + + Utils.tryWithSafeFinally { + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + store = StateStore.get( + storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) + val inputIter = dataRDD.iterator(partition, ctxt) + val outputIter = storeUpdateFunction(store, inputIter) + assert(store.hasCommitted) + outputIter + } { + if (store != null) store.cancel() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala new file mode 100644 index 0000000000..b249e37921 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.types.StructType + +package object state { + + implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { + + /** Map each partition of a RDD along with data in a [[StateStore]]. 
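+     *
+     * For example, a word-count style sketch (`toKeyRow` and `incremented` are hypothetical
+     * helpers; an implicit SQLContext must be in scope, and the update function must commit
+     * the store before returning its iterator):
+     * {{{
+     *   val counts = words.mapPartitionWithStateStore(
+     *     (store: StateStore, iter: Iterator[String]) => {
+     *       iter.foreach { w => store.update(toKeyRow(w), old => incremented(old)) }
+     *       store.commit()
+     *       store.iterator()
+     *     },
+     *     checkpointLocation, operatorId = 0, storeVersion = 0, keySchema, valueSchema)
+     * }}}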
+     */
+    def mapPartitionWithStateStore[U: ClassTag](
+        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+        checkpointLocation: String,
+        operatorId: Long,
+        storeVersion: Long,
+        keySchema: StructType,
+        valueSchema: StructType
+      )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = {
+
+      mapPartitionWithStateStore(
+        storeUpdateFunction,
+        checkpointLocation,
+        operatorId,
+        storeVersion,
+        keySchema,
+        valueSchema,
+        new StateStoreConf(sqlContext.conf),
+        Some(sqlContext.streams.stateStoreCoordinator))
+    }
+
+    /** Map each partition of an RDD along with data in a [[StateStore]]. */
+    private[state] def mapPartitionWithStateStore[U: ClassTag](
+        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+        checkpointLocation: String,
+        operatorId: Long,
+        storeVersion: Long,
+        keySchema: StructType,
+        valueSchema: StructType,
+        storeConf: StateStoreConf,
+        storeCoordinator: Option[StateStoreCoordinatorRef]
+      ): StateStoreRDD[T, U] = {
+      val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
+      new StateStoreRDD(
+        dataRDD,
+        cleanedF,
+        checkpointLocation,
+        operatorId,
+        storeVersion,
+        keySchema,
+        valueSchema,
+        storeConf,
+        storeCoordinator)
+    }
+  }
+}

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index fd1d77f514..863a876afe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -524,6 +524,19 @@ object SQLConf {
     doc = "When true, the planner will try to find out duplicated exchanges and re-use them.",
     isPublic = false)
 
+  val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = intConf(
+    "spark.sql.streaming.stateStore.minDeltasForSnapshot",
+    defaultValue = Some(10),
+    doc = "Minimum number of state store delta files that need to be generated before they " +
+      "are consolidated into snapshots.",
+    isPublic = false)
+
+  val STATE_STORE_MIN_VERSIONS_TO_RETAIN = intConf(
+    "spark.sql.streaming.stateStore.minBatchesToRetain",
+    defaultValue = Some(2),
+    doc = "Minimum number of versions of a state store's data to retain after cleaning.",
+    isPublic = false)
+
   val CHECKPOINT_LOCATION = stringConf("spark.sql.streaming.checkpointLocation",
     defaultValue = None,
     doc = "The default location for storing checkpoint data for continuously executing queries.",

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
new file mode 100644
index 0000000000..c99c2f505f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.ExecutorCacheTaskLocation + +class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { + + import StateStoreCoordinatorSuite._ + + test("report, verify, getLocation") { + withCoordinatorRef(sc) { coordinatorRef => + val id = StateStoreId("x", 0, 0) + + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinatorRef.getLocation(id) === None) + + coordinatorRef.reportActiveInstance(id, "hostX", "exec1") + eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === true) + assert( + coordinatorRef.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) + } + + coordinatorRef.reportActiveInstance(id, "hostX", "exec2") + + eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinatorRef.verifyIfInstanceActive(id, "exec2") === true) + + assert( + coordinatorRef.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec2").toString)) + } + } + } + + test("make inactive") { + withCoordinatorRef(sc) { coordinatorRef => + val id1 = StateStoreId("x", 0, 0) + val id2 = StateStoreId("y", 1, 0) + val id3 = StateStoreId("x", 0, 1) + val host = "hostX" + val exec = "exec1" + + coordinatorRef.reportActiveInstance(id1, host, exec) + coordinatorRef.reportActiveInstance(id2, host, exec) + coordinatorRef.reportActiveInstance(id3, host, exec) + + eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true) + + } + + coordinatorRef.deactivateInstances("x") + + assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false) + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === false) + + assert(coordinatorRef.getLocation(id1) === None) + assert( + coordinatorRef.getLocation(id2) === + Some(ExecutorCacheTaskLocation(host, exec).toString)) + assert(coordinatorRef.getLocation(id3) === None) + + coordinatorRef.deactivateInstances("y") + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false) + assert(coordinatorRef.getLocation(id2) === None) + } + } + + test("multiple references have same underlying coordinator") { + withCoordinatorRef(sc) { coordRef1 => + val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) + + val id = StateStoreId("x", 0, 0) + + coordRef1.reportActiveInstance(id, "hostX", "exec1") + + eventually(timeout(5 seconds)) { + assert(coordRef2.verifyIfInstanceActive(id, "exec1") === true) + assert( + coordRef2.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) + } + } + } +} + +object StateStoreCoordinatorSuite { + def withCoordinatorRef(sc: SparkContext)(body: StateStoreCoordinatorRef => Unit): Unit = { + var coordinatorRef: StateStoreCoordinatorRef = null + try { + coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env) + body(coordinatorRef) + } finally { + if (coordinatorRef != null) coordinatorRef.stop() + } + } +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala new file mode 100644 index 0000000000..24cec30fa3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.File +import java.nio.file.Files + +import scala.util.Random + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.LocalSparkContext._ +import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.util.Utils + +class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { + + private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) + private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString + private val keySchema = StructType(Seq(StructField("key", StringType, true))) + private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + + import StateStoreSuite._ + + after { + StateStore.stop() + } + + override def afterAll(): Unit = { + super.afterAll() + Utils.deleteRecursively(new File(tempDir)) + } + + test("versioning and immutability") { + quietly { + withSpark(new SparkContext(sparkConf)) { sc => + implicit val sqlContet = new SQLContext(sc) + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update( + stringToRow(s), oldRow => { + val oldValue = oldRow.map(rowToInt).getOrElse(0) + intToRow(oldValue + 1) + }) + } + store.commit() + store.iterator().map(rowsToStringInt) + } + val opId = 0 + val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 0, keySchema, valueSchema) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + // Generate next version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 1, keySchema, valueSchema) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. 
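+        // (rdd1 was created with storeVersion = 0, so recomputing it re-reads version 0 and
+        // re-applies its own updates; the commit from rdd2 created a separate version 2)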
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + } + } + } + + test("recovering from files") { + quietly { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + def makeStoreRDD( + sc: SparkContext, + seq: Seq[String], + storeVersion: Int): RDD[(String, Int)] = { + implicit val sqlContext = new SQLContext(sc) + makeRDD(sc, Seq("a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion, keySchema, valueSchema) + } + + // Generate RDDs and state store data + withSpark(new SparkContext(sparkConf)) { sc => + for (i <- 1 to 20) { + require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) + } + } + + // With a new context, try using the earlier state store data + withSpark(new SparkContext(sparkConf)) { sc => + assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + } + } + } + + test("preferred locations using StateStoreCoordinator") { + quietly { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + withSpark(new SparkContext(sparkConf)) { sc => + implicit val sqlContext = new SQLContext(sc) + val coordinatorRef = sqlContext.streams.stateStoreCoordinator + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") + assert( + coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === + Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) + + val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 0, keySchema, valueSchema) + require(rdd.partitions.length === 2) + + assert( + rdd.preferredLocations(rdd.partitions(0)) === + Seq(ExecutorCacheTaskLocation("host1", "exec1").toString)) + + assert( + rdd.preferredLocations(rdd.partitions(1)) === + Seq(ExecutorCacheTaskLocation("host2", "exec2").toString)) + + rdd.collect() + } + } + } + + test("distributed test") { + quietly { + withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc => + implicit val sqlContet = new SQLContext(sc) + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update( + stringToRow(s), oldRow => { + val oldValue = oldRow.map(rowToInt).getOrElse(0) + intToRow(oldValue + 1) + }) + } + store.commit() + store.iterator().map(rowsToStringInt) + } + val opId = 0 + val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 0, keySchema, valueSchema) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + // Generate next version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 1, keySchema, valueSchema) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. 
+
+  private def makeRDD(sc: SparkContext, seq: Seq[String]): RDD[String] = {
+    sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2)
+  }
+
+  // Shared per-partition update function: counts occurrences of each key, commits the
+  // new store version, and returns the committed contents as (key, count) pairs.
+  private val increment = (store: StateStore, iter: Iterator[String]) => {
+    iter.foreach { s =>
+      store.update(stringToRow(s), oldRow => {
+        val oldValue = oldRow.map(rowToInt).getOrElse(0)
+        intToRow(oldValue + 1)
+      })
+    }
+    store.commit()
+    store.iterator().map(rowsToStringInt)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
new file mode 100644
index 0000000000..22b2f4f75d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.File
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite}
+import org.apache.spark.LocalSparkContext._
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
+
+class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester {
+  type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
+
+  import StateStoreCoordinatorSuite._
+  import StateStoreSuite._
+
+  private val tempDir = Utils.createTempDir().toString
+  private val keySchema = StructType(Seq(StructField("key", StringType, true)))
+  private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+
+  after {
+    StateStore.stop()
+  }
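+
+  // File naming convention relied upon throughout this suite (see fileExists below):
+  // committing version v writes a "$v.delta" file, and maintenance may compact state
+  // into a "$v.snapshot" file. For example, one update cycle looks like:
+  //
+  //   val store = provider.getStore(0)  // load version 0 (empty for a new provider)
+  //   update(store, "a", 1)
+  //   store.commit()                    // returns 1 and writes 1.delta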
+
+  test("update, remove, commit, and all data iterator") {
+    val provider = newStoreProvider()
+
+    // Verify state before starting a new set of updates
+    assert(provider.latestIterator().isEmpty)
+
+    val store = provider.getStore(0)
+    assert(!store.hasCommitted)
+    intercept[IllegalStateException] {
+      store.iterator()
+    }
+    intercept[IllegalStateException] {
+      store.updates()
+    }
+
+    // Verify state after updating
+    update(store, "a", 1)
+    intercept[IllegalStateException] {
+      store.iterator()
+    }
+    intercept[IllegalStateException] {
+      store.updates()
+    }
+    assert(provider.latestIterator().isEmpty)
+
+    // Make updates, commit and then verify state
+    update(store, "b", 2)
+    update(store, "aa", 3)
+    remove(store, _.startsWith("a"))
+    assert(store.commit() === 1)
+
+    assert(store.hasCommitted)
+    assert(rowsToSet(store.iterator()) === Set("b" -> 2))
+    assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2))
+    assert(fileExists(provider, version = 1, isSnapshot = false))
+    assert(getDataFromFiles(provider) === Set("b" -> 2))
+
+    // Trying to get newer versions should fail
+    intercept[Exception] {
+      provider.getStore(2)
+    }
+    intercept[Exception] {
+      getDataFromFiles(provider, 2)
+    }
+
+    // New updates to the reloaded store create a new version and do not change the old version
+    val reloadedProvider = new HDFSBackedStateStoreProvider(
+      store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
+    val reloadedStore = reloadedProvider.getStore(1)
+    update(reloadedStore, "c", 4)
+    assert(reloadedStore.commit() === 2)
+    assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
+    assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4))
+    assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2))
+    assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4))
+  }
+
+  test("updates iterator with all combos of updates and removes") {
+    val provider = newStoreProvider()
+    var currentVersion: Int = 0
+    def withStore(body: StateStore => Unit): Unit = {
+      val store = provider.getStore(currentVersion)
+      body(store)
+      currentVersion += 1
+    }
+
+    // New data should be seen in updates as value added, even if they had multiple updates
+    withStore { store =>
+      update(store, "a", 1)
+      update(store, "aa", 1)
+      update(store, "aa", 2)
+      store.commit()
+      assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2)))
+      assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2))
+    }
+
+    // Multiple updates to the same key should be collapsed in the updates as a single value update
+    // Keys that have not been updated should not appear in the updates
+    withStore { store =>
+      update(store, "a", 4)
+      update(store, "a", 6)
+      store.commit()
+      assert(updatesToSet(store.updates()) === Set(Updated("a", 6)))
+      assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
+    }
+
+    // Keys added, updated and finally removed before commit should not appear in updates
+    withStore { store =>
+      update(store, "b", 4)   // Added, finally removed
+      update(store, "bb", 5)  // Added, updated, finally removed
+      update(store, "bb", 6)
+      remove(store, _.startsWith("b"))
+      store.commit()
+      assert(updatesToSet(store.updates()) === Set.empty)
+      assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
+    }
+
+    // Removed data should be seen in updates as a key removed
+    // Removed, but re-added data should be seen in updates as a value update
+    withStore { store =>
+      remove(store, _.startsWith("a"))
+      update(store, "a", 10)
+      store.commit()
+      assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa")))
+      assert(rowsToSet(store.iterator()) === Set("a" -> 10))
+    }
+  }
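+
+  // To summarize the combos above: within one uncommitted set of updates, multiple
+  // updates to a key collapse to a single ValueAdded/ValueUpdated, add-then-remove
+  // produces no update at all, and remove-then-re-add surfaces as ValueUpdated.
+  // (This is a paraphrase of the assertions, not an additional API guarantee.)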
+
+  test("cancel") {
+    val provider = newStoreProvider()
+    val store = provider.getStore(0)
+    update(store, "a", 1)
+    store.commit()
+    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+
+    // cancelUpdates should not change the data in the files
+    val store1 = provider.getStore(1)
+    update(store1, "b", 1)
+    store1.cancel()
+    assert(getDataFromFiles(provider) === Set("a" -> 1))
+  }
+
+  test("getStore with unexpected versions") {
+    val provider = newStoreProvider()
+
+    intercept[IllegalArgumentException] {
+      provider.getStore(-1)
+    }
+
+    // Prepare some data in the store
+    val store = provider.getStore(0)
+    update(store, "a", 1)
+    assert(store.commit() === 1)
+    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+
+    intercept[IllegalStateException] {
+      provider.getStore(2)
+    }
+
+    // Update store version with some data
+    val store1 = provider.getStore(1)
+    update(store1, "b", 1)
+    assert(store1.commit() === 2)
+    assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
+    assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1))
+
+    // Overwrite the version with other data
+    val store2 = provider.getStore(1)
+    update(store2, "c", 1)
+    assert(store2.commit() === 2)
+    assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1))
+    assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1))
+  }
+
+  test("snapshotting") {
+    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+
+    var currentVersion = 0
+    def updateVersionTo(targetVersion: Int): Unit = {
+      for (i <- currentVersion + 1 to targetVersion) {
+        val store = provider.getStore(currentVersion)
+        update(store, "a", i)
+        store.commit()
+        currentVersion += 1
+      }
+      require(currentVersion === targetVersion)
+    }
+
+    updateVersionTo(2)
+    require(getDataFromFiles(provider) === Set("a" -> 2))
+    provider.doMaintenance() // should not generate snapshot files
+    assert(getDataFromFiles(provider) === Set("a" -> 2))
+
+    for (i <- 1 to currentVersion) {
+      assert(fileExists(provider, i, isSnapshot = false))  // all delta files present
+      assert(!fileExists(provider, i, isSnapshot = true))  // no snapshot files present
+    }
+
+    // After version 6, snapshotting should generate one snapshot file
+    updateVersionTo(6)
+    require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly")
+    provider.doMaintenance() // should generate snapshot files
+
+    val snapshotVersion = (0 to 6).find { version =>
+      fileExists(provider, version, isSnapshot = true)
+    }
+    assert(snapshotVersion.nonEmpty, "snapshot file not generated")
+    deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
+    assert(
+      getDataFromFiles(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
+      "snapshotting messed up the data of the snapshotted version")
+    assert(
+      getDataFromFiles(provider) === Set("a" -> 6),
+      "snapshotting messed up the data of the final version")
+
+    // After version 20, snapshotting should generate newer snapshot files
+    updateVersionTo(20)
+    require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly")
+    provider.doMaintenance() // do snapshot
+
+    val latestSnapshotVersion = (0 to 20).filter(version =>
+      fileExists(provider, version, isSnapshot = true)).lastOption
+    assert(latestSnapshotVersion.nonEmpty, "no snapshot file found")
+    assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
+
+    deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
+    assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data")
+  }
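+
+  // A sketch of why deleting files older than a snapshot is safe (the assumed recovery
+  // procedure, consistent with the assertions above): a version v is reconstructed from
+  // the latest snapshot s at or below v plus the delta files (s+1).delta .. v.delta,
+  // so files earlier than the snapshot are redundant.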
+
+  test("cleaning") {
+    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+
+    for (i <- 1 to 20) {
+      val store = provider.getStore(i - 1)
+      update(store, "a", i)
+      store.commit()
+      provider.doMaintenance() // do cleanup
+    }
+    require(
+      rowsToSet(provider.latestIterator()) === Set("a" -> 20),
+      "store not updated correctly")
+
+    assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted
+
+    // The last couple of versions should still be retrievable
+    assert(getDataFromFiles(provider, 20) === Set("a" -> 20))
+    assert(getDataFromFiles(provider, 19) === Set("a" -> 19))
+  }
+
+  test("corrupted file handling") {
+    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+    for (i <- 1 to 6) {
+      val store = provider.getStore(i - 1)
+      update(store, "a", i)
+      store.commit()
+      provider.doMaintenance() // do cleanup
+    }
+    val snapshotVersion = (0 to 10).find(version =>
+      fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found"))
+
+    // Corrupt the snapshot file and verify that reading it throws an error
+    assert(getDataFromFiles(provider, snapshotVersion) === Set("a" -> snapshotVersion))
+    corruptFile(provider, snapshotVersion, isSnapshot = true)
+    intercept[Exception] {
+      getDataFromFiles(provider, snapshotVersion)
+    }
+
+    // Corrupt a delta file and verify that reading it throws an error
+    assert(getDataFromFiles(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1)))
+    corruptFile(provider, snapshotVersion - 1, isSnapshot = false)
+    intercept[Exception] {
+      getDataFromFiles(provider, snapshotVersion - 1)
+    }
+
+    // Delete a needed delta file and verify that reading throws an error
+    deleteFilesEarlierThanVersion(provider, snapshotVersion)
+    intercept[Exception] {
+      getDataFromFiles(provider, snapshotVersion - 1)
+    }
+  }
+
+  test("StateStore.get") {
+    quietly {
+      val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+      val storeId = StateStoreId(dir, 0, 0)
+      val storeConf = StateStoreConf.empty
+      val hadoopConf = new Configuration()
+
+      // Verify that trying to get incorrect versions throws errors
+      intercept[IllegalArgumentException] {
+        StateStore.get(storeId, keySchema, valueSchema, -1, storeConf, hadoopConf)
+      }
+      assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store
+
+      intercept[IllegalStateException] {
+        StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+      }
+
+      // Increase version of the store
+      val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
+      assert(store0.version === 0)
+      update(store0, "a", 1)
+      store0.commit()
+
+      assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1)
+      assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0)
+
+      // Verify that you can remove the store and still reload and use it
+      StateStore.unload(storeId)
+      assert(!StateStore.isLoaded(storeId))
+
+      val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+      assert(StateStore.isLoaded(storeId))
+      update(store1, "a", 2)
+      assert(store1.commit() === 2)
+      assert(rowsToSet(store1.iterator()) === Set("a" -> 2))
+    }
+  }
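+
+  // The maintenance test below drives the background maintenance task end to end:
+  // snapshot generation, cleanup of old delta files, and unloading of the store when
+  // the coordinator deactivates this instance or another executor claims it.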
+
+  test("maintenance") {
+    val conf = new SparkConf()
+      .setMaster("local")
+      .setAppName("test")
+      .set(StateStore.MAINTENANCE_INTERVAL_CONFIG, "10ms")
+      // Use a single RPC retry so that a stopped coordinator is detected quickly
+      .set("spark.rpc.numRetries", "1")
+    val opId = 0
+    val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+    val storeId = StateStoreId(dir, opId, 0)
+    val storeConf = StateStoreConf.empty
+    val hadoopConf = new Configuration()
+    val provider = new HDFSBackedStateStoreProvider(
+      storeId, keySchema, valueSchema, storeConf, hadoopConf)
+
+    quietly {
+      withSpark(new SparkContext(conf)) { sc =>
+        withCoordinatorRef(sc) { coordinatorRef =>
+          for (i <- 1 to 20) {
+            val store = StateStore.get(
+              storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf)
+            update(store, "a", i)
+            store.commit()
+          }
+          eventually(timeout(10 seconds)) {
+            assert(
+              coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported")
+          }
+
+          // Background maintenance should clean up and generate snapshots
+          eventually(timeout(10 seconds)) {
+            // Earliest delta file should get cleaned up
+            assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
+
+            // Some snapshots should have been generated
+            val snapshotVersions = (0 to 20).filter { version =>
+              fileExists(provider, version, isSnapshot = true)
+            }
+            assert(snapshotVersions.nonEmpty, "no snapshot file found")
+          }
+
+          // If the driver decides to deactivate all instances of the store, then this
+          // instance should be unloaded
+          coordinatorRef.deactivateInstances(dir)
+          eventually(timeout(10 seconds)) {
+            assert(!StateStore.isLoaded(storeId))
+          }
+
+          // Reload the store and verify
+          StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf)
+          assert(StateStore.isLoaded(storeId))
+
+          // If some other executor loads the store, then this instance should be unloaded
+          coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec")
+          eventually(timeout(10 seconds)) {
+            assert(!StateStore.isLoaded(storeId))
+          }
+
+          // Reload the store and verify
+          StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf)
+          assert(StateStore.isLoaded(storeId))
+        }
+      }
+
+      // Verify that the instance is unloaded when the SparkContext is stopped
+      require(SparkEnv.get === null)
+      eventually(timeout(10 seconds)) {
+        assert(!StateStore.isLoaded(storeId))
+      }
+    }
+  }
+
+  def getDataFromFiles(
+      provider: HDFSBackedStateStoreProvider,
+      version: Int = -1): Set[(String, Int)] = {
+    val reloadedProvider = new HDFSBackedStateStoreProvider(
+      provider.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
+    if (version < 0) {
+      reloadedProvider.latestIterator().map(rowsToStringInt).toSet
+    } else {
+      reloadedProvider.iterator(version).map(rowsToStringInt).toSet
+    }
+  }
+
+  def assertMap(
+      testMapOption: Option[MapType],
+      expectedMap: Map[String, Int]): Unit = {
+    assert(testMapOption.nonEmpty, "no map present")
+    val convertedMap = testMapOption.get.map(rowsToStringInt)
+    assert(convertedMap === expectedMap)
+  }
+
+  def fileExists(
+      provider: HDFSBackedStateStoreProvider,
+      version: Long,
+      isSnapshot: Boolean): Boolean = {
+    val method = PrivateMethod[Path]('baseDir)
+    val basePath = provider invokePrivate method()
+    val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta"
+    val filePath = new File(basePath.toString, fileName)
+    filePath.exists
+  }
+
+  def deleteFilesEarlierThanVersion(provider: HDFSBackedStateStoreProvider, version: Long): Unit = {
+    val method = PrivateMethod[Path]('baseDir)
+    val basePath = provider invokePrivate method()
+    for (ver <- 0 until version.toInt) {
+      for (isSnapshot <- Seq(false, true)) {
+        val fileName = if (isSnapshot) s"$ver.snapshot" else s"$ver.delta"
+        val filePath = new File(basePath.toString, fileName)
+        if (filePath.exists) filePath.delete()
+      }
+    }
+  }
+
+  def corruptFile(
+      provider: HDFSBackedStateStoreProvider,
+      version: Long,
+      isSnapshot: Boolean): Unit = {
+    val method = PrivateMethod[Path]('baseDir)
+    val basePath = provider invokePrivate method()
+    val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta"
+    val filePath = new File(basePath.toString, fileName)
+    filePath.delete()
+    filePath.createNewFile()
+  }
+
+  def storeLoaded(storeId: StateStoreId): Boolean = {
+    val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores)
+    val loadedStores = StateStore invokePrivate method()
+    loadedStores.contains(storeId)
+  }
+
+  def unloadStore(storeId: StateStoreId): Boolean = {
+    val method = PrivateMethod('remove)
+    StateStore invokePrivate method(storeId)
+  }
+
+  def newStoreProvider(
+      opId: Long = Random.nextLong,
+      partition: Int = 0,
+      minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get
+    ): HDFSBackedStateStoreProvider = {
+    val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+    val sqlConf = new SQLConf()
+    sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot)
+    new HDFSBackedStateStoreProvider(
+      StateStoreId(dir, opId, partition),
+      keySchema,
+      valueSchema,
+      new StateStoreConf(sqlConf),
+      new Configuration())
+  }
+
+  def remove(store: StateStore, condition: String => Boolean): Unit = {
+    store.remove(row => condition(rowToString(row)))
+  }
+
+  private def update(store: StateStore, key: String, value: Int): Unit = {
+    store.update(stringToRow(key), _ => intToRow(value))
+  }
+}
+
+private[state] object StateStoreSuite {
+
+  /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */
+  trait TestUpdate
+  case class Added(key: String, value: Int) extends TestUpdate
+  case class Updated(key: String, value: Int) extends TestUpdate
+  case class Removed(key: String) extends TestUpdate
+
+  val strProj = UnsafeProjection.create(Array[DataType](StringType))
+  val intProj = UnsafeProjection.create(Array[DataType](IntegerType))
+
+  def stringToRow(s: String): UnsafeRow = {
+    strProj.apply(new GenericInternalRow(Array[Any](UTF8String.fromString(s)))).copy()
+  }
+
+  def intToRow(i: Int): UnsafeRow = {
+    intProj.apply(new GenericInternalRow(Array[Any](i))).copy()
+  }
+
+  def rowToString(row: UnsafeRow): String = {
+    row.getUTF8String(0).toString
+  }
+
+  def rowToInt(row: UnsafeRow): Int = {
+    row.getInt(0)
+  }
+
+  def rowsToIntInt(row: (UnsafeRow, UnsafeRow)): (Int, Int) = {
+    (rowToInt(row._1), rowToInt(row._2))
+  }
+
+  def rowsToStringInt(row: (UnsafeRow, UnsafeRow)): (String, Int) = {
+    (rowToString(row._1), rowToInt(row._2))
+  }
+
+  def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = {
+    iterator.map(rowsToStringInt).toSet
+  }
+
+  def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = {
+    iterator.map {
+      case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value))
+      case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value))
+      case KeyRemoved(key) => Removed(rowToString(key))
+    }.toSet
+  }
+}
-- 