diff options
Diffstat (limited to 'streaming')
-rw-r--r-- | streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala | 71 | ||||
-rw-r--r-- | streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala | 96 |
2 files changed, 133 insertions, 34 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 3f139ad138..4e5baebaae 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -17,16 +17,20 @@ package org.apache.spark.streaming.util -import java.io.{ObjectInputStream, ObjectOutputStream} +import java.io._ import scala.reflect.ClassTag +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.SparkConf +import org.apache.spark.serializer.{KryoOutputObjectOutputBridge, KryoInputObjectInputBridge} import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._ import org.apache.spark.util.collection.OpenHashMap /** Internal interface for defining the map that keeps track of sessions. */ -private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable { +private[streaming] abstract class StateMap[K, S] extends Serializable { /** Get the state for a key if it exists */ def get(key: K): Option[S] @@ -54,7 +58,7 @@ private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Ser /** Companion object for [[StateMap]], with utility methods */ private[streaming] object StateMap { - def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S] + def empty[K, S]: StateMap[K, S] = new EmptyStateMap[K, S] def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", @@ -64,7 +68,7 @@ private[streaming] object StateMap { } /** Implementation of StateMap interface representing an empty map */ -private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] { +private[streaming] class EmptyStateMap[K, S] extends StateMap[K, S] { override def put(key: K, session: S, updateTime: Long): Unit = { throw new NotImplementedError("put() should not be called on an EmptyStateMap") } @@ -77,21 +81,26 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa } /** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ -private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( +private[streaming] class OpenHashMapBasedStateMap[K, S]( @transient @volatile var parentStateMap: StateMap[K, S], - initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, - deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD - ) extends StateMap[K, S] { self => + private var initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, + private var deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD + )(implicit private var keyClassTag: ClassTag[K], private var stateClassTag: ClassTag[S]) + extends StateMap[K, S] with KryoSerializable { self => - def this(initialCapacity: Int, deltaChainThreshold: Int) = this( + def this(initialCapacity: Int, deltaChainThreshold: Int) + (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this( new EmptyStateMap[K, S], initialCapacity = initialCapacity, deltaChainThreshold = deltaChainThreshold) - def this(deltaChainThreshold: Int) = this( + def this(deltaChainThreshold: Int) + (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this( initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold) - def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) + def this()(implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = { + this(DELTA_CHAIN_LENGTH_THRESHOLD) + } require(initialCapacity >= 1, "Invalid initial capacity") require(deltaChainThreshold >= 1, "Invalid delta chain threshold") @@ -206,11 +215,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( * Serialize the map data. Besides serialization, this method actually compact the deltas * (if needed) in a single pass over all the data in the map. */ - - private def writeObject(outputStream: ObjectOutputStream): Unit = { - // Write all the non-transient fields, especially class tags, etc. - outputStream.defaultWriteObject() - + private def writeObjectInternal(outputStream: ObjectOutput): Unit = { // Write the data in the delta of this state map outputStream.writeInt(deltaMap.size) val deltaMapIterator = deltaMap.iterator @@ -262,11 +267,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( } /** Deserialize the map data. */ - private def readObject(inputStream: ObjectInputStream): Unit = { - - // Read the non-transient fields, especially class tags, etc. - inputStream.defaultReadObject() - + private def readObjectInternal(inputStream: ObjectInput): Unit = { // Read the data of the delta val deltaMapSize = inputStream.readInt() deltaMap = if (deltaMapSize != 0) { @@ -309,6 +310,34 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( } parentStateMap = newParentSessionStore } + + private def writeObject(outputStream: ObjectOutputStream): Unit = { + // Write all the non-transient fields, especially class tags, etc. + outputStream.defaultWriteObject() + writeObjectInternal(outputStream) + } + + private def readObject(inputStream: ObjectInputStream): Unit = { + // Read the non-transient fields, especially class tags, etc. + inputStream.defaultReadObject() + readObjectInternal(inputStream) + } + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeInt(initialCapacity) + output.writeInt(deltaChainThreshold) + kryo.writeClassAndObject(output, keyClassTag) + kryo.writeClassAndObject(output, stateClassTag) + writeObjectInternal(new KryoOutputObjectOutputBridge(kryo, output)) + } + + override def read(kryo: Kryo, input: Input): Unit = { + initialCapacity = input.readInt() + deltaChainThreshold = input.readInt() + keyClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[K]] + stateClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[S]] + readObjectInternal(new KryoInputObjectInputBridge(kryo, input)) + } } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index c4a01eaea7..ea32bbf95c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -17,15 +17,23 @@ package org.apache.spark.streaming +import org.apache.spark.streaming.rdd.MapWithStateRDDRecord + import scala.collection.{immutable, mutable, Map} +import scala.reflect.ClassTag import scala.util.Random -import org.apache.spark.SparkFunSuite +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Output, Input} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer._ import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap} -import org.apache.spark.util.Utils class StateMapSuite extends SparkFunSuite { + private val conf = new SparkConf() + test("EmptyStateMap") { val map = new EmptyStateMap[Int, Int] intercept[scala.NotImplementedError] { @@ -128,17 +136,17 @@ class StateMapSuite extends SparkFunSuite { map1.put(2, 200, 2) testSerialization(map1, "error deserializing and serialized map with data + no delta") - val map2 = map1.copy() + val map2 = map1.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] // Do not test compaction - assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + assert(map2.shouldCompact === false) testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data") map2.put(3, 300, 3) map2.put(4, 400, 4) testSerialization(map2, "error deserializing and serialized map with 1 delta + new data") - val map3 = map2.copy() - assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + val map3 = map2.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] + assert(map3.shouldCompact === false) testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data") map3.put(3, 600, 3) map3.remove(2) @@ -267,18 +275,25 @@ class StateMapSuite extends SparkFunSuite { assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map") } - private def testSerialization[MapType <: StateMap[Int, Int]]( - map: MapType, msg: String): MapType = { - val deserMap = Utils.deserialize[MapType]( - Utils.serialize(map), Thread.currentThread().getContextClassLoader) + private def testSerialization[T: ClassTag]( + map: OpenHashMapBasedStateMap[T, T], msg: String): OpenHashMapBasedStateMap[T, T] = { + testSerialization(new JavaSerializer(conf), map, msg) + testSerialization(new KryoSerializer(conf), map, msg) + } + + private def testSerialization[T : ClassTag]( + serializer: Serializer, + map: OpenHashMapBasedStateMap[T, T], + msg: String): OpenHashMapBasedStateMap[T, T] = { + val deserMap = serializeAndDeserialize(serializer, map) assertMap(deserMap, map, 1, msg) deserMap } // Assert whether all the data and operations on a state map matches that of a reference state map - private def assertMap( - mapToTest: StateMap[Int, Int], - refMapToTestWith: StateMap[Int, Int], + private def assertMap[T]( + mapToTest: StateMap[T, T], + refMapToTestWith: StateMap[T, T], time: Long, msg: String): Unit = { withClue(msg) { @@ -321,4 +336,59 @@ class StateMapSuite extends SparkFunSuite { } } } + + test("OpenHashMapBasedStateMap - serializing and deserializing with KryoSerializable states") { + val map = new OpenHashMapBasedStateMap[KryoState, KryoState]() + map.put(new KryoState("a"), new KryoState("b"), 1) + testSerialization( + new KryoSerializer(conf), map, "error deserializing and serialized KryoSerializable states") + } + + test("EmptyStateMap - serializing and deserializing") { + val map = StateMap.empty[KryoState, KryoState] + // Since EmptyStateMap doesn't contains any date, KryoState won't break JavaSerializer. + assert(serializeAndDeserialize(new JavaSerializer(conf), map). + isInstanceOf[EmptyStateMap[KryoState, KryoState]]) + assert(serializeAndDeserialize(new KryoSerializer(conf), map). + isInstanceOf[EmptyStateMap[KryoState, KryoState]]) + } + + test("MapWithStateRDDRecord - serializing and deserializing with KryoSerializable states") { + val map = new OpenHashMapBasedStateMap[KryoState, KryoState]() + map.put(new KryoState("a"), new KryoState("b"), 1) + + val record = + MapWithStateRDDRecord[KryoState, KryoState, KryoState](map, Seq(new KryoState("c"))) + val deserRecord = serializeAndDeserialize(new KryoSerializer(conf), record) + assert(!(record eq deserRecord)) + assert(record.stateMap.getAll().toSeq === deserRecord.stateMap.getAll().toSeq) + assert(record.mappedData === deserRecord.mappedData) + } + + private def serializeAndDeserialize[T: ClassTag](serializer: Serializer, t: T): T = { + val serializerInstance = serializer.newInstance() + serializerInstance.deserialize[T]( + serializerInstance.serialize(t), Thread.currentThread().getContextClassLoader) + } +} + +/** A class that only supports Kryo serialization. */ +private[streaming] final class KryoState(var state: String) extends KryoSerializable { + + override def write(kryo: Kryo, output: Output): Unit = { + kryo.writeClassAndObject(output, state) + } + + override def read(kryo: Kryo, input: Input): Unit = { + state = kryo.readClassAndObject(input).asInstanceOf[String] + } + + override def equals(other: Any): Boolean = other match { + case that: KryoState => state == that.state + case _ => false + } + + override def hashCode(): Int = { + if (state == null) 0 else state.hashCode() + } } |