author     Shixiong Zhu <shixiong@databricks.com>    2016-01-07 17:46:24 -0800
committer  Tathagata Das <tathagata.das1565@gmail.com>    2016-01-07 17:46:24 -0800
commit     28e0e500a2062baeda8c887e17dc8ab2b7d7d4b4
tree       ffe4125cab4f5520a2b1c4159c84c1d26cfc59a1
parent     c94199e977279d9b4658297e8108b46bdf30157b
[SPARK-12591][STREAMING] Register OpenHashMapBasedStateMap for Kryo
The default serializer in Kryo is FieldSerializer: it ignores transient fields and never calls `writeObject` or `readObject`, so a class's custom serialization logic is silently skipped. This patch therefore makes OpenHashMapBasedStateMap implement `KryoSerializable` so that its custom serialization also runs under Kryo.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #10609 from zsxwing/SPARK-12591.
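For context, a minimal sketch (not part of the patch) of the failure mode described above. The `Cache` class is hypothetical; the point is that Kryo's FieldSerializer skips the `@transient` field and never invokes `readObject`, so state that Java serialization would rebuild comes back as null:

    import java.io.ByteArrayOutputStream
    import com.esotericsoftware.kryo.Kryo
    import com.esotericsoftware.kryo.io.{Input, Output}

    // Hypothetical class whose transient field is rebuilt in readObject.
    class Cache(var source: String) extends Serializable {
      def this() = this(null)  // no-arg constructor so Kryo can instantiate it
      @transient var derived: String = if (source == null) null else source.toUpperCase
      private def readObject(in: java.io.ObjectInputStream): Unit = {
        in.defaultReadObject()
        derived = source.toUpperCase  // runs under Java serialization only
      }
    }

    object FieldSerializerDemo {
      def main(args: Array[String]): Unit = {
        val kryo = new Kryo()
        val bytes = new ByteArrayOutputStream()
        val out = new Output(bytes)
        kryo.writeClassAndObject(out, new Cache("spark"))
        out.close()
        val back = kryo.readClassAndObject(new Input(bytes.toByteArray)).asInstanceOf[Cache]
        assert(back.source == "spark")   // plain field survives
        assert(back.derived == null)     // readObject never ran; transient field lost
      }
    }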
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala      | 24
-rw-r--r--  core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala | 20
-rw-r--r--  project/MimaExcludes.scala                                                |  4
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala   | 71
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala   | 96
5 files changed, 174 insertions(+), 41 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index bc9fd50c2c..150ddc12e0 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -17,7 +17,7 @@
package org.apache.spark.serializer
-import java.io.{DataInput, DataOutput, EOFException, InputStream, IOException, OutputStream}
+import java.io._
import java.nio.ByteBuffer
import javax.annotation.Nullable
@@ -378,18 +378,24 @@ private[serializer] object KryoSerializer {
private val toRegisterSerializer = Map[Class[_], KryoClassSerializer[_]](
classOf[RoaringBitmap] -> new KryoClassSerializer[RoaringBitmap]() {
override def write(kryo: Kryo, output: KryoOutput, bitmap: RoaringBitmap): Unit = {
- bitmap.serialize(new KryoOutputDataOutputBridge(output))
+ bitmap.serialize(new KryoOutputObjectOutputBridge(kryo, output))
}
override def read(kryo: Kryo, input: KryoInput, cls: Class[RoaringBitmap]): RoaringBitmap = {
val ret = new RoaringBitmap
- ret.deserialize(new KryoInputDataInputBridge(input))
+ ret.deserialize(new KryoInputObjectInputBridge(kryo, input))
ret
}
}
)
}
-private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends DataInput {
+/**
+ * This is a bridge class to wrap KryoInput as an InputStream and ObjectInput. It forwards all
+ * methods of InputStream and ObjectInput to KryoInput. It's usually helpful when an API expects
+ * an InputStream or ObjectInput but you want to use Kryo.
+ */
+private[spark] class KryoInputObjectInputBridge(
+ kryo: Kryo, input: KryoInput) extends FilterInputStream(input) with ObjectInput {
override def readLong(): Long = input.readLong()
override def readChar(): Char = input.readChar()
override def readFloat(): Float = input.readFloat()
@@ -408,9 +414,16 @@ private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends Dat
override def readBoolean(): Boolean = input.readBoolean()
override def readUnsignedByte(): Int = input.readByteUnsigned()
override def readDouble(): Double = input.readDouble()
+ override def readObject(): AnyRef = kryo.readClassAndObject(input)
}
-private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends DataOutput {
+/**
+ * This is a bridge class to wrap KryoOutput as an OutputStream and ObjectOutput. It forwards all
+ * methods of OutputStream and ObjectOutput to KryoOutput. It's usually helpful when an API expects
+ * an OutputStream or ObjectOutput but you want to use Kryo.
+ */
+private[spark] class KryoOutputObjectOutputBridge(
+ kryo: Kryo, output: KryoOutput) extends FilterOutputStream(output) with ObjectOutput {
override def writeFloat(v: Float): Unit = output.writeFloat(v)
// There is no "readChars" counterpart, except maybe "readLine", which is not supported
override def writeChars(s: String): Unit = throw new UnsupportedOperationException("writeChars")
@@ -426,6 +439,7 @@ private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends
override def writeChar(v: Int): Unit = output.writeChar(v.toChar)
override def writeLong(v: Long): Unit = output.writeLong(v)
override def writeByte(v: Int): Unit = output.writeByte(v)
+ override def writeObject(obj: AnyRef): Unit = kryo.writeClassAndObject(output, obj)
}
/**
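A short usage sketch (again not part of the patch) of what the widened bridge types buy: anything written against java.io's ObjectInput/ObjectOutput, for example an Externalizable class, can now read and write through Kryo streams. `Point` is hypothetical, and since the bridges are private[spark] the sketch assumes it compiles inside the org.apache.spark package tree:

    package org.apache.spark.serializer

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, Externalizable, ObjectInput, ObjectOutput}
    import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}

    // Hypothetical Externalizable class: its API expects ObjectInput/ObjectOutput.
    class Point(var x: Int, var y: Int) extends Externalizable {
      def this() = this(0, 0)
      override def writeExternal(out: ObjectOutput): Unit = { out.writeInt(x); out.writeInt(y) }
      override def readExternal(in: ObjectInput): Unit = { x = in.readInt(); y = in.readInt() }
    }

    object BridgeDemo {
      def main(args: Array[String]): Unit = {
        val bytes = new ByteArrayOutputStream()
        // The Kryo instance may be null here: writeExternal never calls writeObject.
        val out = new KryoOutputObjectOutputBridge(null, new KryoOutput(bytes))
        new Point(1, 2).writeExternal(out)
        out.close()

        val in = new KryoInputObjectInputBridge(
          null, new KryoInput(new ByteArrayInputStream(bytes.toByteArray)))
        val p = new Point()
        p.readExternal(in)
        assert(p.x == 1 && p.y == 2)
      }
    }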
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 8f9b453a6e..f869bcd708 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -362,19 +362,35 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
bitmap.add(1)
bitmap.add(3)
bitmap.add(5)
- bitmap.serialize(new KryoOutputDataOutputBridge(output))
+ // The Kryo instance can be null because RoaringBitmap.serialize never calls writeObject
+ bitmap.serialize(new KryoOutputObjectOutputBridge(null, output))
output.flush()
output.close()
val inStream = new FileInputStream(tmpfile)
val input = new KryoInput(inStream)
val ret = new RoaringBitmap
- ret.deserialize(new KryoInputDataInputBridge(input))
+ // The Kryo instance can be null because RoaringBitmap.deserialize never calls readObject
+ ret.deserialize(new KryoInputObjectInputBridge(null, input))
input.close()
assert(ret == bitmap)
Utils.deleteRecursively(dir)
}
+ test("KryoOutputObjectOutputBridge.writeObject and KryoInputObjectInputBridge.readObject") {
+ val kryo = new KryoSerializer(conf).newKryo()
+
+ val bytesOutput = new ByteArrayOutputStream()
+ val objectOutput = new KryoOutputObjectOutputBridge(kryo, new KryoOutput(bytesOutput))
+ objectOutput.writeObject("test")
+ objectOutput.close()
+
+ val bytesInput = new ByteArrayInputStream(bytesOutput.toByteArray)
+ val objectInput = new KryoInputObjectInputBridge(kryo, new KryoInput(bytesInput))
+ assert(objectInput.readObject() === "test")
+ objectInput.close()
+ }
+
test("getAutoReset") {
val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance]
assert(ser.getAutoReset)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 69e5bc881b..40559a0910 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -120,6 +120,10 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$Multiplier"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$")
) ++ Seq(
+ // SPARK-12591 Register OpenHashMapBasedStateMap for Kryo
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoInputDataInputBridge"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoOutputDataOutputBridge")
+ ) ++ Seq(
// SPARK-12510 Refactor ActorReceiver to support Java
ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver")
)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index 3f139ad138..4e5baebaae 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -17,16 +17,20 @@
package org.apache.spark.streaming.util
-import java.io.{ObjectInputStream, ObjectOutputStream}
+import java.io._
import scala.reflect.ClassTag
+import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
+import com.esotericsoftware.kryo.io.{Input, Output}
+
import org.apache.spark.SparkConf
+import org.apache.spark.serializer.{KryoOutputObjectOutputBridge, KryoInputObjectInputBridge}
import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._
import org.apache.spark.util.collection.OpenHashMap
/** Internal interface for defining the map that keeps track of sessions. */
-private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable {
+private[streaming] abstract class StateMap[K, S] extends Serializable {
/** Get the state for a key if it exists */
def get(key: K): Option[S]
@@ -54,7 +58,7 @@ private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Ser
/** Companion object for [[StateMap]], with utility methods */
private[streaming] object StateMap {
- def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S]
+ def empty[K, S]: StateMap[K, S] = new EmptyStateMap[K, S]
def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = {
val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold",
@@ -64,7 +68,7 @@ private[streaming] object StateMap {
}
/** Implementation of StateMap interface representing an empty map */
-private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] {
+private[streaming] class EmptyStateMap[K, S] extends StateMap[K, S] {
override def put(key: K, session: S, updateTime: Long): Unit = {
throw new NotImplementedError("put() should not be called on an EmptyStateMap")
}
@@ -77,21 +81,26 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa
}
/** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */
-private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
+private[streaming] class OpenHashMapBasedStateMap[K, S](
@transient @volatile var parentStateMap: StateMap[K, S],
- initialCapacity: Int = DEFAULT_INITIAL_CAPACITY,
- deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
- ) extends StateMap[K, S] { self =>
+ private var initialCapacity: Int = DEFAULT_INITIAL_CAPACITY,
+ private var deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
+ )(implicit private var keyClassTag: ClassTag[K], private var stateClassTag: ClassTag[S])
+ extends StateMap[K, S] with KryoSerializable { self =>
- def this(initialCapacity: Int, deltaChainThreshold: Int) = this(
+ def this(initialCapacity: Int, deltaChainThreshold: Int)
+ (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this(
new EmptyStateMap[K, S],
initialCapacity = initialCapacity,
deltaChainThreshold = deltaChainThreshold)
- def this(deltaChainThreshold: Int) = this(
+ def this(deltaChainThreshold: Int)
+ (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this(
initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold)
- def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD)
+ def this()(implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = {
+ this(DELTA_CHAIN_LENGTH_THRESHOLD)
+ }
require(initialCapacity >= 1, "Invalid initial capacity")
require(deltaChainThreshold >= 1, "Invalid delta chain threshold")
@@ -206,11 +215,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
* Serialize the map data. Besides serialization, this method actually compacts the deltas
* (if needed) in a single pass over all the data in the map.
*/
-
- private def writeObject(outputStream: ObjectOutputStream): Unit = {
- // Write all the non-transient fields, especially class tags, etc.
- outputStream.defaultWriteObject()
-
+ private def writeObjectInternal(outputStream: ObjectOutput): Unit = {
// Write the data in the delta of this state map
outputStream.writeInt(deltaMap.size)
val deltaMapIterator = deltaMap.iterator
@@ -262,11 +267,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
}
/** Deserialize the map data. */
- private def readObject(inputStream: ObjectInputStream): Unit = {
-
- // Read the non-transient fields, especially class tags, etc.
- inputStream.defaultReadObject()
-
+ private def readObjectInternal(inputStream: ObjectInput): Unit = {
// Read the data of the delta
val deltaMapSize = inputStream.readInt()
deltaMap = if (deltaMapSize != 0) {
@@ -309,6 +310,34 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
}
parentStateMap = newParentSessionStore
}
+
+ private def writeObject(outputStream: ObjectOutputStream): Unit = {
+ // Write all the non-transient fields, especially class tags, etc.
+ outputStream.defaultWriteObject()
+ writeObjectInternal(outputStream)
+ }
+
+ private def readObject(inputStream: ObjectInputStream): Unit = {
+ // Read the non-transient fields, especially class tags, etc.
+ inputStream.defaultReadObject()
+ readObjectInternal(inputStream)
+ }
+
+ override def write(kryo: Kryo, output: Output): Unit = {
+ output.writeInt(initialCapacity)
+ output.writeInt(deltaChainThreshold)
+ kryo.writeClassAndObject(output, keyClassTag)
+ kryo.writeClassAndObject(output, stateClassTag)
+ writeObjectInternal(new KryoOutputObjectOutputBridge(kryo, output))
+ }
+
+ override def read(kryo: Kryo, input: Input): Unit = {
+ initialCapacity = input.readInt()
+ deltaChainThreshold = input.readInt()
+ keyClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[K]]
+ stateClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[S]]
+ readObjectInternal(new KryoInputObjectInputBridge(kryo, input))
+ }
}
/**
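The signature changes above are the subtle half of the patch: the `K: ClassTag`/`S: ClassTag` context bounds become explicit `implicit private var` parameters because Kryo's `read()` repopulates an already-constructed blank instance, so the tags must be reassignable fields and must be written out explicitly (FieldSerializer would otherwise drop them). A distilled sketch of the pattern; `TaggedBox` is hypothetical and assumes a Kryo instance able to instantiate it, as Spark's KryoSerializer can:

    import scala.reflect.ClassTag
    import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
    import com.esotericsoftware.kryo.io.{Input, Output}

    // The context bound (T: ClassTag) is desugared by hand into an implicit var
    // so that read() can overwrite it after Kryo creates a blank instance.
    class TaggedBox[T](private var value: T)(implicit private var tag: ClassTag[T])
      extends KryoSerializable {

      override def write(kryo: Kryo, output: Output): Unit = {
        kryo.writeClassAndObject(output, tag)    // persist the tag itself
        kryo.writeClassAndObject(output, value)
      }

      override def read(kryo: Kryo, input: Input): Unit = {
        tag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[T]]
        value = kryo.readClassAndObject(input).asInstanceOf[T]
      }

      // Why keep the tag at all: array creation needs it after deserialization.
      def newArray(n: Int): Array[T] = tag.newArray(n)
    }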
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
index c4a01eaea7..ea32bbf95c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
@@ -17,15 +17,23 @@
package org.apache.spark.streaming
+import org.apache.spark.streaming.rdd.MapWithStateRDDRecord
+
import scala.collection.{immutable, mutable, Map}
+import scala.reflect.ClassTag
import scala.util.Random
-import org.apache.spark.SparkFunSuite
+import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
+import com.esotericsoftware.kryo.io.{Output, Input}
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer._
import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap}
-import org.apache.spark.util.Utils
class StateMapSuite extends SparkFunSuite {
+ private val conf = new SparkConf()
+
test("EmptyStateMap") {
val map = new EmptyStateMap[Int, Int]
intercept[scala.NotImplementedError] {
@@ -128,17 +136,17 @@ class StateMapSuite extends SparkFunSuite {
map1.put(2, 200, 2)
testSerialization(map1, "error deserializing and serialized map with data + no delta")
- val map2 = map1.copy()
+ val map2 = map1.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]]
// Do not test compaction
- assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+ assert(map2.shouldCompact === false)
testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data")
map2.put(3, 300, 3)
map2.put(4, 400, 4)
testSerialization(map2, "error deserializing and serialized map with 1 delta + new data")
- val map3 = map2.copy()
- assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+ val map3 = map2.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]]
+ assert(map3.shouldCompact === false)
testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data")
map3.put(3, 600, 3)
map3.remove(2)
@@ -267,18 +275,25 @@ class StateMapSuite extends SparkFunSuite {
assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map")
}
- private def testSerialization[MapType <: StateMap[Int, Int]](
- map: MapType, msg: String): MapType = {
- val deserMap = Utils.deserialize[MapType](
- Utils.serialize(map), Thread.currentThread().getContextClassLoader)
+ private def testSerialization[T: ClassTag](
+ map: OpenHashMapBasedStateMap[T, T], msg: String): OpenHashMapBasedStateMap[T, T] = {
+ testSerialization(new JavaSerializer(conf), map, msg)
+ testSerialization(new KryoSerializer(conf), map, msg)
+ }
+
+ private def testSerialization[T : ClassTag](
+ serializer: Serializer,
+ map: OpenHashMapBasedStateMap[T, T],
+ msg: String): OpenHashMapBasedStateMap[T, T] = {
+ val deserMap = serializeAndDeserialize(serializer, map)
assertMap(deserMap, map, 1, msg)
deserMap
}
// Assert whether all the data and operations on a state map matches that of a reference state map
- private def assertMap(
- mapToTest: StateMap[Int, Int],
- refMapToTestWith: StateMap[Int, Int],
+ private def assertMap[T](
+ mapToTest: StateMap[T, T],
+ refMapToTestWith: StateMap[T, T],
time: Long,
msg: String): Unit = {
withClue(msg) {
@@ -321,4 +336,59 @@ class StateMapSuite extends SparkFunSuite {
}
}
}
+
+ test("OpenHashMapBasedStateMap - serializing and deserializing with KryoSerializable states") {
+ val map = new OpenHashMapBasedStateMap[KryoState, KryoState]()
+ map.put(new KryoState("a"), new KryoState("b"), 1)
+ testSerialization(
+ new KryoSerializer(conf), map, "error deserializing and serialized KryoSerializable states")
+ }
+
+ test("EmptyStateMap - serializing and deserializing") {
+ val map = StateMap.empty[KryoState, KryoState]
+ // Since EmptyStateMap doesn't contain any data, KryoState won't break JavaSerializer.
+ assert(serializeAndDeserialize(new JavaSerializer(conf), map).
+ isInstanceOf[EmptyStateMap[KryoState, KryoState]])
+ assert(serializeAndDeserialize(new KryoSerializer(conf), map).
+ isInstanceOf[EmptyStateMap[KryoState, KryoState]])
+ }
+
+ test("MapWithStateRDDRecord - serializing and deserializing with KryoSerializable states") {
+ val map = new OpenHashMapBasedStateMap[KryoState, KryoState]()
+ map.put(new KryoState("a"), new KryoState("b"), 1)
+
+ val record =
+ MapWithStateRDDRecord[KryoState, KryoState, KryoState](map, Seq(new KryoState("c")))
+ val deserRecord = serializeAndDeserialize(new KryoSerializer(conf), record)
+ assert(!(record eq deserRecord))
+ assert(record.stateMap.getAll().toSeq === deserRecord.stateMap.getAll().toSeq)
+ assert(record.mappedData === deserRecord.mappedData)
+ }
+
+ private def serializeAndDeserialize[T: ClassTag](serializer: Serializer, t: T): T = {
+ val serializerInstance = serializer.newInstance()
+ serializerInstance.deserialize[T](
+ serializerInstance.serialize(t), Thread.currentThread().getContextClassLoader)
+ }
+}
+
+/** A class that only supports Kryo serialization. */
+private[streaming] final class KryoState(var state: String) extends KryoSerializable {
+
+ override def write(kryo: Kryo, output: Output): Unit = {
+ kryo.writeClassAndObject(output, state)
+ }
+
+ override def read(kryo: Kryo, input: Input): Unit = {
+ state = kryo.readClassAndObject(input).asInstanceOf[String]
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case that: KryoState => state == that.state
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ if (state == null) 0 else state.hashCode()
+ }
}
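Finally, a sketch of the user-facing setup this patch unblocks: with OpenHashMapBasedStateMap implementing KryoSerializable, a mapWithState job can simply switch the wire serializer to Kryo and its internal state maps survive shuffles and checkpoints. The socket source and checkpoint path below are placeholders:

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

    object StateWithKryo {
      def main(args: Array[String]): Unit = {
        // No user-side registration needed: the state map itself is KryoSerializable.
        val conf = new SparkConf()
          .setMaster("local[2]")
          .setAppName("StateWithKryo")
          .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        val ssc = new StreamingContext(conf, Seconds(1))
        ssc.checkpoint("/tmp/state-with-kryo")  // placeholder path

        val counts = ssc.socketTextStream("localhost", 9999)  // placeholder source
          .map(word => (word, 1))
          .mapWithState(StateSpec.function(
            (word: String, one: Option[Int], state: State[Int]) => {
              val sum = one.getOrElse(0) + state.getOption().getOrElse(0)
              state.update(sum)
              (word, sum)
            }))
        counts.print()
        ssc.start()
        ssc.awaitTermination()
      }
    }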