From 28e0e500a2062baeda8c887e17dc8ab2b7d7d4b4 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 7 Jan 2016 17:46:24 -0800 Subject: [SPARK-12591][STREAMING] Register OpenHashMapBasedStateMap for Kryo The default serializer in Kryo is FieldSerializer and it ignores transient fields and never calls `writeObject` or `readObject`. So we should register OpenHashMapBasedStateMap using `DefaultSerializer` to make it work with Kryo. Author: Shixiong Zhu Closes #10609 from zsxwing/SPARK-12591. --- .../org/apache/spark/streaming/util/StateMap.scala | 71 +++++++++++++++------- 1 file changed, 50 insertions(+), 21 deletions(-) (limited to 'streaming/src/main') diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 3f139ad138..4e5baebaae 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -17,16 +17,20 @@ package org.apache.spark.streaming.util -import java.io.{ObjectInputStream, ObjectOutputStream} +import java.io._ import scala.reflect.ClassTag +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.SparkConf +import org.apache.spark.serializer.{KryoOutputObjectOutputBridge, KryoInputObjectInputBridge} import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._ import org.apache.spark.util.collection.OpenHashMap /** Internal interface for defining the map that keeps track of sessions. */ -private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable { +private[streaming] abstract class StateMap[K, S] extends Serializable { /** Get the state for a key if it exists */ def get(key: K): Option[S] @@ -54,7 +58,7 @@ private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Ser /** Companion object for [[StateMap]], with utility methods */ private[streaming] object StateMap { - def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S] + def empty[K, S]: StateMap[K, S] = new EmptyStateMap[K, S] def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", @@ -64,7 +68,7 @@ private[streaming] object StateMap { } /** Implementation of StateMap interface representing an empty map */ -private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] { +private[streaming] class EmptyStateMap[K, S] extends StateMap[K, S] { override def put(key: K, session: S, updateTime: Long): Unit = { throw new NotImplementedError("put() should not be called on an EmptyStateMap") } @@ -77,21 +81,26 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa } /** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ -private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( +private[streaming] class OpenHashMapBasedStateMap[K, S]( @transient @volatile var parentStateMap: StateMap[K, S], - initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, - deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD - ) extends StateMap[K, S] { self => + private var initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, + private var deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD + )(implicit private var keyClassTag: ClassTag[K], private var stateClassTag: ClassTag[S]) + extends StateMap[K, S] with KryoSerializable { self => - def this(initialCapacity: Int, deltaChainThreshold: Int) = this( + def this(initialCapacity: Int, deltaChainThreshold: Int) + (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this( new EmptyStateMap[K, S], initialCapacity = initialCapacity, deltaChainThreshold = deltaChainThreshold) - def this(deltaChainThreshold: Int) = this( + def this(deltaChainThreshold: Int) + (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this( initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold) - def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) + def this()(implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = { + this(DELTA_CHAIN_LENGTH_THRESHOLD) + } require(initialCapacity >= 1, "Invalid initial capacity") require(deltaChainThreshold >= 1, "Invalid delta chain threshold") @@ -206,11 +215,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( * Serialize the map data. Besides serialization, this method actually compact the deltas * (if needed) in a single pass over all the data in the map. */ - - private def writeObject(outputStream: ObjectOutputStream): Unit = { - // Write all the non-transient fields, especially class tags, etc. - outputStream.defaultWriteObject() - + private def writeObjectInternal(outputStream: ObjectOutput): Unit = { // Write the data in the delta of this state map outputStream.writeInt(deltaMap.size) val deltaMapIterator = deltaMap.iterator @@ -262,11 +267,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( } /** Deserialize the map data. */ - private def readObject(inputStream: ObjectInputStream): Unit = { - - // Read the non-transient fields, especially class tags, etc. - inputStream.defaultReadObject() - + private def readObjectInternal(inputStream: ObjectInput): Unit = { // Read the data of the delta val deltaMapSize = inputStream.readInt() deltaMap = if (deltaMapSize != 0) { @@ -309,6 +310,34 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( } parentStateMap = newParentSessionStore } + + private def writeObject(outputStream: ObjectOutputStream): Unit = { + // Write all the non-transient fields, especially class tags, etc. + outputStream.defaultWriteObject() + writeObjectInternal(outputStream) + } + + private def readObject(inputStream: ObjectInputStream): Unit = { + // Read the non-transient fields, especially class tags, etc. + inputStream.defaultReadObject() + readObjectInternal(inputStream) + } + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeInt(initialCapacity) + output.writeInt(deltaChainThreshold) + kryo.writeClassAndObject(output, keyClassTag) + kryo.writeClassAndObject(output, stateClassTag) + writeObjectInternal(new KryoOutputObjectOutputBridge(kryo, output)) + } + + override def read(kryo: Kryo, input: Input): Unit = { + initialCapacity = input.readInt() + deltaChainThreshold = input.readInt() + keyClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[K]] + stateClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[S]] + readObjectInternal(new KryoInputObjectInputBridge(kryo, input)) + } } /** -- cgit v1.2.3