author     Reynold Xin <rxin@databricks.com>    2015-11-18 00:09:29 -0800
committer  Reynold Xin <rxin@databricks.com>    2015-11-18 00:09:29 -0800
commit     5e2b44474c2b838bebeffe5ba5cd72961b0cd31e
tree       9e9763e7f8503897b9e69ab84ec6696a5370282c
parent     8019f66df5c65e21d6e4e7e8fbfb7d0471ba3e37
[SPARK-11802][SQL] Kryo-based encoder for opaque types in Datasets
While working on this, I also found a bug with self-joins returning incorrect results in the Dataset API. Two test cases are attached, and SPARK-11803 has been filed to track the fix.

Author: Reynold Xin <rxin@databricks.com>

Closes #9789 from rxin/SPARK-11802.
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala | 31
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 4
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala | 2
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala | 69
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala | 18
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala | 1
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 70
8 files changed, 178 insertions(+), 23 deletions(-)
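
For context, a minimal sketch of how the new API is meant to be used (the SQLContext and implicits setup is assumed, not part of this diff; the Opaque class is hypothetical):

    import org.apache.spark.sql.Encoders

    // An "opaque" class: no case-class structure for Catalyst to reflect on.
    class Opaque(val n: Int)

    // A Kryo-backed encoder stores the whole object in a single binary column
    // instead of flattening it into typed fields.
    implicit val opaqueEncoder = Encoders.kryo[Opaque]
    val ds = sqlContext.createDataset(Seq(new Opaque(1), new Opaque(2)))
    println(ds.count())  // count() on Dataset is also added in this commit
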
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index c8b017e251..79c2255641 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql
-import scala.reflect.ClassTag
+import scala.reflect.{ClassTag, classTag}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.catalyst.expressions.{DeserializeWithKryo, BoundReference, SerializeWithKryo}
+import org.apache.spark.sql.types._
/**
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -37,7 +38,33 @@ trait Encoder[T] extends Serializable {
def clsTag: ClassTag[T]
}
+/**
+ * Methods for creating encoders.
+ */
object Encoders {
+
+ /**
+ * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
+ * This encoder maps T into a single byte array (binary) field.
+ */
+ def kryo[T: ClassTag]: Encoder[T] = {
+ val ser = SerializeWithKryo(BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true))
+ val deser = DeserializeWithKryo[T](BoundReference(0, BinaryType, nullable = true), classTag[T])
+ ExpressionEncoder[T](
+ schema = new StructType().add("value", BinaryType),
+ flat = true,
+ toRowExpressions = Seq(ser),
+ fromRowExpression = deser,
+ clsTag = classTag[T]
+ )
+ }
+
+ /**
+ * Creates an encoder that serializes objects of type T using Kryo.
+ * This encoder maps T into a single byte array (binary) field.
+ */
+ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
+
def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
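
The two kryo overloads above target different call sites; a short sketch of both from Scala (the Class-based overload exists for Java callers, who cannot supply a ClassTag implicitly):

    import org.apache.spark.sql.{Encoder, Encoders}

    // Scala: the compiler supplies the ClassTag implicitly.
    val scalaEnc: Encoder[String] = Encoders.kryo[String]

    // Java-friendly: pass the Class explicitly; it is wrapped in a ClassTag.
    val javaEnc: Encoder[String] = Encoders.kryo(classOf[String])
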
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 9a1a8f5cbb..b977f278c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -161,7 +161,9 @@ case class ExpressionEncoder[T](
@transient
private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
- private val inputRow = new GenericMutableRow(1)
+
+ @transient
+ private lazy val inputRow = new GenericMutableRow(1)
@transient
private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)
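
Making inputRow a @transient lazy val means the scratch row is no longer written out when the encoder itself is Java-serialized (for example, shipped inside a task closure); it is simply rebuilt on first use on the other side. A minimal sketch of the pattern, using a hypothetical class rather than anything from this commit:

    class Scratch extends Serializable {
      // Dropped from the serialized form; re-created lazily on first
      // access after deserialization.
      @transient private lazy val row = new Array[Any](1)
    }
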
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
index 414adb2116..55c4ee11b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -230,7 +230,7 @@ object ProductEncoder {
Invoke(inputObject, "booleanValue", BooleanType)
case other =>
- throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
+ throw new UnsupportedOperationException(s"Encoder for type $other is not supported")
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 5cd19de683..489c6126f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -17,13 +17,15 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.language.existentials
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkConf
+import org.apache.spark.serializer.{KryoSerializerInstance, KryoSerializer}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
import org.apache.spark.sql.catalyst.util.GenericArrayData
-
-import scala.language.existentials
-
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types._
@@ -514,3 +516,64 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy
"""
}
}
+
+/** Serializes an input object using Kryo serializer. */
+case class SerializeWithKryo(child: Expression) extends UnaryExpression {
+
+ override def eval(input: InternalRow): Any =
+ throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val input = child.gen(ctx)
+ val kryo = ctx.freshName("kryoSerializer")
+ val kryoClass = classOf[KryoSerializer].getName
+ val kryoInstanceClass = classOf[KryoSerializerInstance].getName
+ val sparkConfClass = classOf[SparkConf].getName
+ ctx.addMutableState(
+ kryoInstanceClass,
+ kryo,
+ s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();")
+
+ s"""
+ ${input.code}
+ final boolean ${ev.isNull} = ${input.isNull};
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${ev.value} = $kryo.serialize(${input.value}, null).array();
+ }
+ """
+ }
+
+ override def dataType: DataType = BinaryType
+}
+
+/**
+ * Deserializes an input object using Kryo serializer. Note that the ClassTag is not an implicit
+ * parameter because TreeNode cannot copy implicit parameters.
+ */
+case class DeserializeWithKryo[T](child: Expression, tag: ClassTag[T]) extends UnaryExpression {
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val input = child.gen(ctx)
+ val kryo = ctx.freshName("kryoSerializer")
+ val kryoClass = classOf[KryoSerializer].getName
+ val kryoInstanceClass = classOf[KryoSerializerInstance].getName
+ val sparkConfClass = classOf[SparkConf].getName
+ ctx.addMutableState(
+ kryoInstanceClass,
+ kryo,
+ s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();")
+
+ s"""
+ ${input.code}
+ final boolean ${ev.isNull} = ${input.isNull};
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${ev.value} = (${ctx.javaType(dataType)})
+ $kryo.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null);
+ }
+ """
+ }
+
+ override def dataType: DataType = ObjectType(tag.runtimeClass)
+}
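
At runtime the generated Java boils down to a Kryo round-trip through Spark's serializer API. A plain Scala sketch of the same calls (the `null` second argument in the generated code stands in for an unused ClassTag):

    import java.nio.ByteBuffer
    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoSerializer

    val ser = new KryoSerializer(new SparkConf()).newInstance()

    // SerializeWithKryo: object -> single binary field
    val bytes: Array[Byte] = ser.serialize("hello").array()

    // DeserializeWithKryo: binary field -> object of the ClassTag's class
    val back: String = ser.deserialize[String](ByteBuffer.wrap(bytes))
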
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
index 55821c4370..2729db8489 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.encoders
import java.sql.{Date, Timestamp}
+import org.apache.spark.sql.Encoders
class FlatEncoderSuite extends ExpressionEncoderSuite {
encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean")
@@ -71,4 +72,21 @@ class FlatEncoderSuite extends ExpressionEncoderSuite {
encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null")
encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)),
FlatEncoder[Map[Int, Map[String, Int]]], "map of map")
+
+ // Kryo encoders
+ encodeDecodeTest(
+ "hello",
+ Encoders.kryo[String].asInstanceOf[ExpressionEncoder[String]],
+ "kryo string")
+ encodeDecodeTest(
+ new NotJavaSerializable(15),
+ Encoders.kryo[NotJavaSerializable].asInstanceOf[ExpressionEncoder[NotJavaSerializable]],
+ "kryo object serialization")
+}
+
+
+class NotJavaSerializable(val value: Int) {
+ override def equals(other: Any): Boolean = {
+ this.value == other.asInstanceOf[NotJavaSerializable].value
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 718ed812dd..817c20fdbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -147,6 +147,12 @@ class Dataset[T] private[sql](
}
}
+ /**
+ * Returns the number of elements in the [[Dataset]].
+ * @since 1.6.0
+ */
+ def count(): Long = toDF().count()
+
/* *********************** *
* Functional Operations *
* *********************** */
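
A one-line usage sketch of the new method (spark-shell style, with the implicit conversions assumed in scope):

    val n: Long = Seq(1, 2, 3).toDS().count()  // delegates to DataFrame.count()
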
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 467cd42b9b..c66162ee21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql
-
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index ea29428c55..a522894c37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,21 +24,6 @@ import scala.language.postfixOps
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
-case class ClassData(a: String, b: Int)
-
-/**
- * A class used to test serialization using encoders. This class throws exceptions when using
- * Java serialization -- so the only way it can be "serialized" is through our encoders.
- */
-case class NonSerializableCaseClass(value: String) extends Externalizable {
- override def readExternal(in: ObjectInput): Unit = {
- throw new UnsupportedOperationException
- }
-
- override def writeExternal(out: ObjectOutput): Unit = {
- throw new UnsupportedOperationException
- }
-}
class DatasetSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -362,8 +347,63 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(joined, ("2", 2))
}
+ ignore("self join") {
+ val ds = Seq("1", "2").toDS().as("a")
+ val joined = ds.joinWith(ds, lit(true))
+ checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
+ }
+
test("toString") {
val ds = Seq((1, 2)).toDS()
assert(ds.toString == "[_1: int, _2: int]")
}
+
+ test("kryo encoder") {
+ implicit val kryoEncoder = Encoders.kryo[KryoData]
+ val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))
+
+ assert(ds.groupBy(p => p).count().collect().toSeq ==
+ Seq((KryoData(1), 1L), (KryoData(2), 1L)))
+ }
+
+ ignore("kryo encoder self join") {
+ implicit val kryoEncoder = Encoders.kryo[KryoData]
+ val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))
+ assert(ds.joinWith(ds, lit(true)).collect().toSet ==
+ Set(
+ (KryoData(1), KryoData(1)),
+ (KryoData(1), KryoData(2)),
+ (KryoData(2), KryoData(1)),
+ (KryoData(2), KryoData(2))))
+ }
+}
+
+
+case class ClassData(a: String, b: Int)
+
+/**
+ * A class used to test serialization using encoders. This class throws exceptions when using
+ * Java serialization -- so the only way it can be "serialized" is through our encoders.
+ */
+case class NonSerializableCaseClass(value: String) extends Externalizable {
+ override def readExternal(in: ObjectInput): Unit = {
+ throw new UnsupportedOperationException
+ }
+
+ override def writeExternal(out: ObjectOutput): Unit = {
+ throw new UnsupportedOperationException
+ }
+}
+
+/** Used to test Kryo encoder. */
+class KryoData(val a: Int) {
+ override def equals(other: Any): Boolean = {
+ a == other.asInstanceOf[KryoData].a
+ }
+ override def hashCode: Int = a
+ override def toString: String = s"KryoData($a)"
+}
+
+object KryoData {
+ def apply(a: Int): KryoData = new KryoData(a)
}