-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala            19
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala            30
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala   10
3 files changed, 42 insertions(+), 17 deletions(-)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index 34287c3e00..3f139ad138 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -59,7 +59,7 @@ private[streaming] object StateMap {
def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = {
val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold",
DELTA_CHAIN_LENGTH_THRESHOLD)
- new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold)
+ new OpenHashMapBasedStateMap[K, S](deltaChainThreshold)
}
}
@@ -79,7 +79,7 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa
/** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */
private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
@transient @volatile var parentStateMap: StateMap[K, S],
- initialCapacity: Int = 64,
+ initialCapacity: Int = DEFAULT_INITIAL_CAPACITY,
deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
) extends StateMap[K, S] { self =>
@@ -89,12 +89,14 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
deltaChainThreshold = deltaChainThreshold)
def this(deltaChainThreshold: Int) = this(
- initialCapacity = 64, deltaChainThreshold = deltaChainThreshold)
+ initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold)
def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD)
- @transient @volatile private var deltaMap =
- new OpenHashMap[K, StateInfo[S]](initialCapacity)
+ require(initialCapacity >= 1, "Invalid initial capacity")
+ require(deltaChainThreshold >= 1, "Invalid delta chain threshold")
+
+ @transient @volatile private var deltaMap = new OpenHashMap[K, StateInfo[S]](initialCapacity)
/** Get the session data if it exists */
override def get(key: K): Option[S] = {
@@ -284,9 +286,10 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
// Read the data of the parent map. Keep reading records, until the limiter is reached
// First read the approximate number of records to expect and allocate properly size
// OpenHashMap
- val parentSessionStoreSizeHint = inputStream.readInt()
+ val parentStateMapSizeHint = inputStream.readInt()
+ val newStateMapInitialCapacity = math.max(parentStateMapSizeHint, DEFAULT_INITIAL_CAPACITY)
val newParentSessionStore = new OpenHashMapBasedStateMap[K, S](
- initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold)
+ initialCapacity = newStateMapInitialCapacity, deltaChainThreshold)
// Read the records until the limit marking object has been reached
var parentSessionLoopDone = false
@@ -338,4 +341,6 @@ private[streaming] object OpenHashMapBasedStateMap {
class LimitMarker(val num: Int) extends Serializable
val DELTA_CHAIN_LENGTH_THRESHOLD = 20
+
+ val DEFAULT_INITIAL_CAPACITY = 64
}
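
Note: the change above guards against an empty parent map. Its serialized size hint is 0, which the newly added require(initialCapacity >= 1, ...) check would reject, so deserialization now takes the max of the hint and DEFAULT_INITIAL_CAPACITY. A minimal standalone sketch of that sizing rule, assuming only the constant value shown in the diff:

  object CapacitySketch {
    val DEFAULT_INITIAL_CAPACITY = 64

    // An empty parent StateMap writes a size hint of 0; taking the max with the
    // default keeps the "initial capacity >= 1" requirement satisfied.
    def initialCapacityFor(parentStateMapSizeHint: Int): Int =
      math.max(parentStateMapSizeHint, DEFAULT_INITIAL_CAPACITY)

    def main(args: Array[String]): Unit = {
      assert(initialCapacityFor(0) == 64)      // empty parent map falls back to the default
      assert(initialCapacityFor(1000) == 1000) // a larger hint is kept as-is
    }
  }
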
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
index 48d3b41b66..c4a01eaea7 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
@@ -122,23 +122,27 @@ class StateMapSuite extends SparkFunSuite {
test("OpenHashMapBasedStateMap - serializing and deserializing") {
val map1 = new OpenHashMapBasedStateMap[Int, Int]()
+ testSerialization(map1, "error deserializing and serialized empty map")
+
map1.put(1, 100, 1)
map1.put(2, 200, 2)
+ testSerialization(map1, "error deserializing and serialized map with data + no delta")
val map2 = map1.copy()
+ // Do not test compaction
+ assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+ testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data")
+
map2.put(3, 300, 3)
map2.put(4, 400, 4)
+ testSerialization(map2, "error deserializing and serialized map with 1 delta + new data")
val map3 = map2.copy()
+ assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+ testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data")
map3.put(3, 600, 3)
map3.remove(2)
-
- // Do not test compaction
- assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
-
- val deser_map3 = Utils.deserialize[StateMap[Int, Int]](
- Utils.serialize(map3), Thread.currentThread().getContextClassLoader)
- assertMap(deser_map3, map3, 1, "Deserialized map not same as original map")
+ testSerialization(map3, "error deserializing and serialized map with 2 delta + new data")
}
test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") {
@@ -156,11 +160,9 @@ class StateMapSuite extends SparkFunSuite {
assert(map.deltaChainLength > deltaChainThreshold)
assert(map.shouldCompact === true)
- val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]](
- Utils.serialize(map), Thread.currentThread().getContextClassLoader)
+ val deser_map = testSerialization(map, "Deserialized + compacted map not same as original map")
assert(deser_map.deltaChainLength < deltaChainThreshold)
assert(deser_map.shouldCompact === false)
- assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map")
}
test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") {
@@ -265,6 +267,14 @@ class StateMapSuite extends SparkFunSuite {
assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map")
}
+ private def testSerialization[MapType <: StateMap[Int, Int]](
+ map: MapType, msg: String): MapType = {
+ val deserMap = Utils.deserialize[MapType](
+ Utils.serialize(map), Thread.currentThread().getContextClassLoader)
+ assertMap(deserMap, map, 1, msg)
+ deserMap
+ }
+
// Assert whether all the data and operations on a state map matches that of a reference state map
private def assertMap(
mapToTest: StateMap[Int, Int],
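
Note: the new testSerialization helper is simply a serialize/deserialize round trip followed by an equality check via assertMap. A minimal sketch of that round-trip pattern, using plain Java serialization instead of Spark's Utils (an assumption made to keep the snippet self-contained; the state maps here are Serializable):

  import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

  def roundTrip[T](obj: T): T = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(obj)
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    in.readObject().asInstanceOf[T]
  }

  // e.g. roundTrip(Map(1 -> 100, 2 -> 200)) == Map(1 -> 100, 2 -> 200)
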
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
index 0feb3af1ab..3b2d43f2ce 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
@@ -332,6 +332,16 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _)
}
+ test("checkpointing empty state RDD") {
+ val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int](
+ sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0))
+ emptyStateRDD.checkpoint()
+ assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
+ val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]](
+ emptyStateRDD.getCheckpointFile.get)
+ assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
+ }
+
/** Assert whether the `trackStateByKey` operation generates expected results */
private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
testStateRDD: TrackStateRDD[K, V, S, T],
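
Note: the new test above exercises the fixed path end to end: an empty TrackStateRDD is checkpointed and read back, and both sides must yield an empty state map. A minimal sketch of the same checkpoint round trip with a plain empty pair RDD (local master, checkpoint directory, and object name are placeholders, not part of the patch):

  import org.apache.spark.{SparkConf, SparkContext}

  object EmptyCheckpointSketch {
    def main(args: Array[String]): Unit = {
      val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
      sc.setCheckpointDir("/tmp/checkpoint-sketch")
      val empty = sc.emptyRDD[(Int, Int)]
      empty.checkpoint()
      assert(empty.collect().isEmpty)   // running an action materializes the checkpoint
      assert(empty.isCheckpointed)
      println(empty.getCheckpointFile)  // Some(<path to the checkpointed data>)
      sc.stop()
    }
  }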