aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTathagata Das <tathagata.das1565@gmail.com>2015-11-24 23:13:01 -0800
committerShixiong Zhu <shixiong@databricks.com>2015-11-24 23:13:01 -0800
commit2169886883d33b33acf378ac42a626576b342df1 (patch)
treec1b94aab923d9d5d605f940dc489a19518249534
parent151d7c2baf18403e6e59e97c80c8bcded6148038 (diff)
downloadspark-2169886883d33b33acf378ac42a626576b342df1.tar.gz
spark-2169886883d33b33acf378ac42a626576b342df1.tar.bz2
spark-2169886883d33b33acf378ac42a626576b342df1.zip
[SPARK-11979][STREAMING] Empty TrackStateRDD cannot be checkpointed and recovered from checkpoint file
This solves the following exception caused when empty state RDD is checkpointed and recovered. The root cause is that an empty OpenHashMapBasedStateMap cannot be deserialized as the initialCapacity is set to zero. ``` Job aborted due to stage failure: Task 0 in stage 6.0 failed 1 times, most recent failure: Lost task 0.0 in stage 6.0 (TID 20, localhost): java.lang.IllegalArgumentException: requirement failed: Invalid initial capacity at scala.Predef$.require(Predef.scala:233) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.<init>(StateMap.scala:96) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.<init>(StateMap.scala:86) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.readObject(StateMap.scala:291) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:606) at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1017) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1893) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1798) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1350) at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:1990) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1915) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1798) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1350) at java.io.ObjectInputStream.readObject(ObjectInputStream.java:370) at org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:76) at org.apache.spark.serializer.DeserializationStream$$anon$1.getNext(Serializer.scala:181) at org.apache.spark.util.NextIterator.hasNext(NextIterator.scala:73) at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:371) at scala.collection.Iterator$class.foreach(Iterator.scala:727) at scala.collection.AbstractIterator.foreach(Iterator.scala:1157) at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47) at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273) at scala.collection.AbstractIterator.to(Iterator.scala:1157) at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265) at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157) at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252) at scala.collection.AbstractIterator.toArray(Iterator.scala:1157) at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:921) at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:921) at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1858) at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1858) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66) at org.apache.spark.scheduler.Task.run(Task.scala:88) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:744) ``` Author: Tathagata Das <tathagata.das1565@gmail.com> Closes #9958 from tdas/SPARK-11979.
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala19
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala30
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala10
3 files changed, 42 insertions, 17 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index 34287c3e00..3f139ad138 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -59,7 +59,7 @@ private[streaming] object StateMap {
def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = {
val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold",
DELTA_CHAIN_LENGTH_THRESHOLD)
- new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold)
+ new OpenHashMapBasedStateMap[K, S](deltaChainThreshold)
}
}
@@ -79,7 +79,7 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa
/** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */
private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
@transient @volatile var parentStateMap: StateMap[K, S],
- initialCapacity: Int = 64,
+ initialCapacity: Int = DEFAULT_INITIAL_CAPACITY,
deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
) extends StateMap[K, S] { self =>
@@ -89,12 +89,14 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
deltaChainThreshold = deltaChainThreshold)
def this(deltaChainThreshold: Int) = this(
- initialCapacity = 64, deltaChainThreshold = deltaChainThreshold)
+ initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold)
def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD)
- @transient @volatile private var deltaMap =
- new OpenHashMap[K, StateInfo[S]](initialCapacity)
+ require(initialCapacity >= 1, "Invalid initial capacity")
+ require(deltaChainThreshold >= 1, "Invalid delta chain threshold")
+
+ @transient @volatile private var deltaMap = new OpenHashMap[K, StateInfo[S]](initialCapacity)
/** Get the session data if it exists */
override def get(key: K): Option[S] = {
@@ -284,9 +286,10 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
// Read the data of the parent map. Keep reading records, until the limiter is reached
// First read the approximate number of records to expect and allocate properly size
// OpenHashMap
- val parentSessionStoreSizeHint = inputStream.readInt()
+ val parentStateMapSizeHint = inputStream.readInt()
+ val newStateMapInitialCapacity = math.max(parentStateMapSizeHint, DEFAULT_INITIAL_CAPACITY)
val newParentSessionStore = new OpenHashMapBasedStateMap[K, S](
- initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold)
+ initialCapacity = newStateMapInitialCapacity, deltaChainThreshold)
// Read the records until the limit marking object has been reached
var parentSessionLoopDone = false
@@ -338,4 +341,6 @@ private[streaming] object OpenHashMapBasedStateMap {
class LimitMarker(val num: Int) extends Serializable
val DELTA_CHAIN_LENGTH_THRESHOLD = 20
+
+ val DEFAULT_INITIAL_CAPACITY = 64
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
index 48d3b41b66..c4a01eaea7 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
@@ -122,23 +122,27 @@ class StateMapSuite extends SparkFunSuite {
test("OpenHashMapBasedStateMap - serializing and deserializing") {
val map1 = new OpenHashMapBasedStateMap[Int, Int]()
+ testSerialization(map1, "error deserializing and serialized empty map")
+
map1.put(1, 100, 1)
map1.put(2, 200, 2)
+ testSerialization(map1, "error deserializing and serialized map with data + no delta")
val map2 = map1.copy()
+ // Do not test compaction
+ assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+ testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data")
+
map2.put(3, 300, 3)
map2.put(4, 400, 4)
+ testSerialization(map2, "error deserializing and serialized map with 1 delta + new data")
val map3 = map2.copy()
+ assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+ testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data")
map3.put(3, 600, 3)
map3.remove(2)
-
- // Do not test compaction
- assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
-
- val deser_map3 = Utils.deserialize[StateMap[Int, Int]](
- Utils.serialize(map3), Thread.currentThread().getContextClassLoader)
- assertMap(deser_map3, map3, 1, "Deserialized map not same as original map")
+ testSerialization(map3, "error deserializing and serialized map with 2 delta + new data")
}
test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") {
@@ -156,11 +160,9 @@ class StateMapSuite extends SparkFunSuite {
assert(map.deltaChainLength > deltaChainThreshold)
assert(map.shouldCompact === true)
- val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]](
- Utils.serialize(map), Thread.currentThread().getContextClassLoader)
+ val deser_map = testSerialization(map, "Deserialized + compacted map not same as original map")
assert(deser_map.deltaChainLength < deltaChainThreshold)
assert(deser_map.shouldCompact === false)
- assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map")
}
test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") {
@@ -265,6 +267,14 @@ class StateMapSuite extends SparkFunSuite {
assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map")
}
+ private def testSerialization[MapType <: StateMap[Int, Int]](
+ map: MapType, msg: String): MapType = {
+ val deserMap = Utils.deserialize[MapType](
+ Utils.serialize(map), Thread.currentThread().getContextClassLoader)
+ assertMap(deserMap, map, 1, msg)
+ deserMap
+ }
+
// Assert whether all the data and operations on a state map matches that of a reference state map
private def assertMap(
mapToTest: StateMap[Int, Int],
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
index 0feb3af1ab..3b2d43f2ce 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
@@ -332,6 +332,16 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _)
}
+ test("checkpointing empty state RDD") {
+ val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int](
+ sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0))
+ emptyStateRDD.checkpoint()
+ assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
+ val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]](
+ emptyStateRDD.getCheckpointFile.get)
+ assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
+ }
+
/** Assert whether the `trackStateByKey` operation generates expected results */
private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
testStateRDD: TrackStateRDD[K, V, S, T],