aboutsummaryrefslogtreecommitdiff
path: root/streaming/src/main
diff options
context:
space:
mode:
author Shixiong Zhu <shixiong@databricks.com> 2016-01-07 17:46:24 -0800
committer Tathagata Das <tathagata.das1565@gmail.com> 2016-01-07 17:46:24 -0800
commit 28e0e500a2062baeda8c887e17dc8ab2b7d7d4b4 (patch)
tree ffe4125cab4f5520a2b1c4159c84c1d26cfc59a1 /streaming/src/main
parent c94199e977279d9b4658297e8108b46bdf30157b (diff)
download spark-28e0e500a2062baeda8c887e17dc8ab2b7d7d4b4.tar.gz
spark-28e0e500a2062baeda8c887e17dc8ab2b7d7d4b4.tar.bz2
spark-28e0e500a2062baeda8c887e17dc8ab2b7d7d4b4.zip
[SPARK-12591][STREAMING] Register OpenHashMapBasedStateMap for Kryo
The default serializer in Kryo is FieldSerializer and it ignores transient fields and never calls `writeObject` or `readObject`. So we should register OpenHashMapBasedStateMap using `DefaultSerializer` to make it work with Kryo.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #10609 from zsxwing/SPARK-12591.
Diffstat (limited to 'streaming/src/main')
-rw-r--r-- streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala | 71
1 files changed, 50 insertions, 21 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index 3f139ad138..4e5baebaae 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -17,16 +17,20 @@
package org.apache.spark.streaming.util
-import java.io.{ObjectInputStream, ObjectOutputStream}
+import java.io._
import scala.reflect.ClassTag
+import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
+import com.esotericsoftware.kryo.io.{Input, Output}
+
import org.apache.spark.SparkConf
+import org.apache.spark.serializer.{KryoOutputObjectOutputBridge, KryoInputObjectInputBridge}
import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._
import org.apache.spark.util.collection.OpenHashMap
/** Internal interface for defining the map that keeps track of sessions. */
-private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable {
+private[streaming] abstract class StateMap[K, S] extends Serializable {
/** Get the state for a key if it exists */
def get(key: K): Option[S]
@@ -54,7 +58,7 @@ private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Ser
/** Companion object for [[StateMap]], with utility methods */
private[streaming] object StateMap {
- def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S]
+ def empty[K, S]: StateMap[K, S] = new EmptyStateMap[K, S]
def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = {
val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold",
@@ -64,7 +68,7 @@ private[streaming] object StateMap {
}
/** Implementation of StateMap interface representing an empty map */
-private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] {
+private[streaming] class EmptyStateMap[K, S] extends StateMap[K, S] {
override def put(key: K, session: S, updateTime: Long): Unit = {
throw new NotImplementedError("put() should not be called on an EmptyStateMap")
}
@@ -77,21 +81,26 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa
}
/** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */
-private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
+private[streaming] class OpenHashMapBasedStateMap[K, S](
@transient @volatile var parentStateMap: StateMap[K, S],
- initialCapacity: Int = DEFAULT_INITIAL_CAPACITY,
- deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
- ) extends StateMap[K, S] { self =>
+ private var initialCapacity: Int = DEFAULT_INITIAL_CAPACITY,
+ private var deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
+ )(implicit private var keyClassTag: ClassTag[K], private var stateClassTag: ClassTag[S])
+ extends StateMap[K, S] with KryoSerializable { self =>
- def this(initialCapacity: Int, deltaChainThreshold: Int) = this(
+ def this(initialCapacity: Int, deltaChainThreshold: Int)
+ (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this(
new EmptyStateMap[K, S],
initialCapacity = initialCapacity,
deltaChainThreshold = deltaChainThreshold)
- def this(deltaChainThreshold: Int) = this(
+ def this(deltaChainThreshold: Int)
+ (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this(
initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold)
- def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD)
+ def this()(implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = {
+ this(DELTA_CHAIN_LENGTH_THRESHOLD)
+ }
require(initialCapacity >= 1, "Invalid initial capacity")
require(deltaChainThreshold >= 1, "Invalid delta chain threshold")
@@ -206,11 +215,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
* Serialize the map data. Besides serialization, this method actually compact the deltas
* (if needed) in a single pass over all the data in the map.
*/
-
- private def writeObject(outputStream: ObjectOutputStream): Unit = {
- // Write all the non-transient fields, especially class tags, etc.
- outputStream.defaultWriteObject()
-
+ private def writeObjectInternal(outputStream: ObjectOutput): Unit = {
// Write the data in the delta of this state map
outputStream.writeInt(deltaMap.size)
val deltaMapIterator = deltaMap.iterator
@@ -262,11 +267,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
}
/** Deserialize the map data. */
- private def readObject(inputStream: ObjectInputStream): Unit = {
-
- // Read the non-transient fields, especially class tags, etc.
- inputStream.defaultReadObject()
-
+ private def readObjectInternal(inputStream: ObjectInput): Unit = {
// Read the data of the delta
val deltaMapSize = inputStream.readInt()
deltaMap = if (deltaMapSize != 0) {
@@ -309,6 +310,34 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
}
parentStateMap = newParentSessionStore
}
+
+ private def writeObject(outputStream: ObjectOutputStream): Unit = {
+ // Write all the non-transient fields, especially class tags, etc.
+ outputStream.defaultWriteObject()
+ writeObjectInternal(outputStream)
+ }
+
+ private def readObject(inputStream: ObjectInputStream): Unit = {
+ // Read the non-transient fields, especially class tags, etc.
+ inputStream.defaultReadObject()
+ readObjectInternal(inputStream)
+ }
+
+ override def write(kryo: Kryo, output: Output): Unit = {
+ output.writeInt(initialCapacity)
+ output.writeInt(deltaChainThreshold)
+ kryo.writeClassAndObject(output, keyClassTag)
+ kryo.writeClassAndObject(output, stateClassTag)
+ writeObjectInternal(new KryoOutputObjectOutputBridge(kryo, output))
+ }
+
+ override def read(kryo: Kryo, input: Input): Unit = {
+ initialCapacity = input.readInt()
+ deltaChainThreshold = input.readInt()
+ keyClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[K]]
+ stateClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[S]]
+ readObjectInternal(new KryoInputObjectInputBridge(kryo, input))
+ }
}
/**