aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Reynold Xin <rxin@databricks.com> 2015-11-18 15:42:07 -0800
committer Reynold Xin <rxin@databricks.com> 2015-11-18 15:42:07 -0800
commit 5df08949f5d9e5b4b0e9c2db50c1b4eb93383de3 (patch)
tree 61e12a73b845f34afd2b781d0395ca0094521272
parent 54db79702513e11335c33bcf3a03c59e965e6f16 (diff)
download spark-5df08949f5d9e5b4b0e9c2db50c1b4eb93383de3.tar.gz
spark-5df08949f5d9e5b4b0e9c2db50c1b4eb93383de3.tar.bz2
spark-5df08949f5d9e5b4b0e9c2db50c1b4eb93383de3.zip
[SPARK-11810][SQL] Java-based encoder for opaque types in Datasets.
This patch refactors the existing Kryo encoder expressions and adds support for Java serialization. Author: Reynold Xin <rxin@databricks.com> Closes #9802 from rxin/SPARK-11810.
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala 41
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala 67
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala 27
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 36
4 files changed, 130 insertions, 41 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 79c2255641..1ed5111440 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.reflect.{ClassTag, classTag}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
-import org.apache.spark.sql.catalyst.expressions.{DeserializeWithKryo, BoundReference, SerializeWithKryo}
+import org.apache.spark.sql.catalyst.expressions.{DecodeUsingSerializer, BoundReference, EncodeUsingSerializer}
import org.apache.spark.sql.types._
/**
@@ -43,28 +43,49 @@ trait Encoder[T] extends Serializable {
*/
object Encoders {
- /**
- * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
- * This encoder maps T into a single byte array (binary) field.
- */
- def kryo[T: ClassTag]: Encoder[T] = {
- val ser = SerializeWithKryo(BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true))
- val deser = DeserializeWithKryo[T](BoundReference(0, BinaryType, nullable = true), classTag[T])
+ /** A way to construct encoders using generic serializers. */
+ private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = {
ExpressionEncoder[T](
schema = new StructType().add("value", BinaryType),
flat = true,
- toRowExpressions = Seq(ser),
- fromRowExpression = deser,
+ toRowExpressions = Seq(
+ EncodeUsingSerializer(
+ BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
+ fromRowExpression =
+ DecodeUsingSerializer[T](
+ BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo),
clsTag = classTag[T]
)
}
/**
+ * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
+ * This encoder maps T into a single byte array (binary) field.
+ */
+ def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true)
+
+ /**
* Creates an encoder that serializes objects of type T using Kryo.
* This encoder maps T into a single byte array (binary) field.
*/
def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
+ /**
+ * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java
+ * serialization. This encoder maps T into a single byte array (binary) field.
+ *
+ * Note that this is extremely inefficient and should only be used as the last resort.
+ */
+ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false)
+
+ /**
+ * Creates an encoder that serializes objects of type T using generic Java serialization.
+ * This encoder maps T into a single byte array (binary) field.
+ *
+ * Note that this is extremely inefficient and should only be used as the last resort.
+ */
+ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz))
+
def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 489c6126f8..acf0da2400 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -21,7 +21,7 @@ import scala.language.existentials
import scala.reflect.ClassTag
import org.apache.spark.SparkConf
-import org.apache.spark.serializer.{KryoSerializerInstance, KryoSerializer}
+import org.apache.spark.serializer._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
@@ -517,29 +517,39 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy
}
}
-/** Serializes an input object using Kryo serializer. */
-case class SerializeWithKryo(child: Expression) extends UnaryExpression {
+/**
+ * Serializes an input object using a generic serializer (Kryo or Java).
+ * @param kryo if true, use Kryo. Otherwise, use Java.
+ */
+case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends UnaryExpression {
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val input = child.gen(ctx)
- val kryo = ctx.freshName("kryoSerializer")
- val kryoClass = classOf[KryoSerializer].getName
- val kryoInstanceClass = classOf[KryoSerializerInstance].getName
- val sparkConfClass = classOf[SparkConf].getName
+ // Code to initialize the serializer.
+ val serializer = ctx.freshName("serializer")
+ val (serializerClass, serializerInstanceClass) = {
+ if (kryo) {
+ (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName)
+ } else {
+ (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
+ }
+ }
+ val sparkConf = s"new ${classOf[SparkConf].getName}()"
ctx.addMutableState(
- kryoInstanceClass,
- kryo,
- s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();")
+ serializerInstanceClass,
+ serializer,
+ s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();")
+ // Code to serialize.
+ val input = child.gen(ctx)
s"""
${input.code}
final boolean ${ev.isNull} = ${input.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
- ${ev.value} = $kryo.serialize(${input.value}, null).array();
+ ${ev.value} = $serializer.serialize(${input.value}, null).array();
}
"""
}
@@ -548,29 +558,38 @@ case class SerializeWithKryo(child: Expression) extends UnaryExpression {
}
/**
- * Deserializes an input object using Kryo serializer. Note that the ClassTag is not an implicit
- * parameter because TreeNode cannot copy implicit parameters.
+ * Deserializes an input object using a generic serializer (Kryo or Java). Note that the ClassTag
+ * is not an implicit parameter because TreeNode cannot copy implicit parameters.
+ * @param kryo if true, use Kryo. Otherwise, use Java.
*/
-case class DeserializeWithKryo[T](child: Expression, tag: ClassTag[T]) extends UnaryExpression {
+case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean)
+ extends UnaryExpression {
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val input = child.gen(ctx)
- val kryo = ctx.freshName("kryoSerializer")
- val kryoClass = classOf[KryoSerializer].getName
- val kryoInstanceClass = classOf[KryoSerializerInstance].getName
- val sparkConfClass = classOf[SparkConf].getName
+ // Code to initialize the serializer.
+ val serializer = ctx.freshName("serializer")
+ val (serializerClass, serializerInstanceClass) = {
+ if (kryo) {
+ (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName)
+ } else {
+ (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
+ }
+ }
+ val sparkConf = s"new ${classOf[SparkConf].getName}()"
ctx.addMutableState(
- kryoInstanceClass,
- kryo,
- s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();")
+ serializerInstanceClass,
+ serializer,
+ s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();")
+ // Code to deserialize.
+ val input = child.gen(ctx)
s"""
${input.code}
final boolean ${ev.isNull} = ${input.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = (${ctx.javaType(dataType)})
- $kryo.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null);
+ $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null);
}
"""
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
index 2729db8489..6e0322fb6e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
@@ -76,17 +76,34 @@ class FlatEncoderSuite extends ExpressionEncoderSuite {
// Kryo encoders
encodeDecodeTest(
"hello",
- Encoders.kryo[String].asInstanceOf[ExpressionEncoder[String]],
+ encoderFor(Encoders.kryo[String]),
"kryo string")
encodeDecodeTest(
- new NotJavaSerializable(15),
- Encoders.kryo[NotJavaSerializable].asInstanceOf[ExpressionEncoder[NotJavaSerializable]],
+ new KryoSerializable(15),
+ encoderFor(Encoders.kryo[KryoSerializable]),
"kryo object serialization")
+
+ // Java encoders
+ encodeDecodeTest(
+ "hello",
+ encoderFor(Encoders.javaSerialization[String]),
+ "java string")
+ encodeDecodeTest(
+ new JavaSerializable(15),
+ encoderFor(Encoders.javaSerialization[JavaSerializable]),
+ "java object serialization")
}
+/** For testing Kryo serialization based encoder. */
+class KryoSerializable(val value: Int) {
+ override def equals(other: Any): Boolean = {
+ this.value == other.asInstanceOf[KryoSerializable].value
+ }
+}
-class NotJavaSerializable(val value: Int) {
+/** For testing Java serialization based encoder. */
+class JavaSerializable(val value: Int) extends Serializable {
override def equals(other: Any): Boolean = {
- this.value == other.asInstanceOf[NotJavaSerializable].value
+ this.value == other.asInstanceOf[JavaSerializable].value
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index b6db583dfe..89d964aa3e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -357,7 +357,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.toString == "[_1: int, _2: int]")
}
- test("kryo encoder") {
+ test("Kryo encoder") {
implicit val kryoEncoder = Encoders.kryo[KryoData]
val ds = Seq(KryoData(1), KryoData(2)).toDS()
@@ -365,7 +365,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
Seq((KryoData(1), 1L), (KryoData(2), 1L)))
}
- test("kryo encoder self join") {
+ test("Kryo encoder self join") {
implicit val kryoEncoder = Encoders.kryo[KryoData]
val ds = Seq(KryoData(1), KryoData(2)).toDS()
assert(ds.joinWith(ds, lit(true)).collect().toSet ==
@@ -375,6 +375,25 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
(KryoData(2), KryoData(1)),
(KryoData(2), KryoData(2))))
}
+
+ test("Java encoder") {
+ implicit val kryoEncoder = Encoders.javaSerialization[JavaData]
+ val ds = Seq(JavaData(1), JavaData(2)).toDS()
+
+ assert(ds.groupBy(p => p).count().collect().toSeq ==
+ Seq((JavaData(1), 1L), (JavaData(2), 1L)))
+ }
+
+ ignore("Java encoder self join") {
+ implicit val kryoEncoder = Encoders.javaSerialization[JavaData]
+ val ds = Seq(JavaData(1), JavaData(2)).toDS()
+ assert(ds.joinWith(ds, lit(true)).collect().toSet ==
+ Set(
+ (JavaData(1), JavaData(1)),
+ (JavaData(1), JavaData(2)),
+ (JavaData(2), JavaData(1)),
+ (JavaData(2), JavaData(2))))
+ }
}
@@ -406,3 +425,16 @@ class KryoData(val a: Int) {
object KryoData {
def apply(a: Int): KryoData = new KryoData(a)
}
+
+/** Used to test Java encoder. */
+class JavaData(val a: Int) extends Serializable {
+ override def equals(other: Any): Boolean = {
+ a == other.asInstanceOf[JavaData].a
+ }
+ override def hashCode: Int = a
+ override def toString: String = s"JavaData($a)"
+}
+
+object JavaData {
+ def apply(a: Int): JavaData = new JavaData(a)
+}