author     Sela <ansela@paypal.com>                    2016-06-10 14:36:51 -0700
committer  Michael Armbrust <michael@databricks.com>   2016-06-10 14:36:51 -0700
commit     127a6678d7af6b5164a115be7c64525bb80001fe (patch)
tree       f0c5fe53afa3d8f7388ec225da3a1f843be299e3
parent     aec502d9114ad8e18bfbbd63f38780e076d326d1 (diff)
[SPARK-15489][SQL] Dataset kryo encoder won't load custom user settings
## What changes were proposed in this pull request?

Serializer instantiation now picks up the existing `SparkConf` from `SparkEnv` when one is available, instead of always creating a fresh, default `SparkConf`. This makes Dataset kryo encoders honor user settings such as `spark.kryo.registrator`.

## How was this patch tested?

Manually tested with Guava's `ImmutableList` and the `Immutable*Serializer` implementations from the `kryo-serializers` library. Also added a test suite.

Author: Sela <ansela@paypal.com>

Closes #13424 from amitsela/SPARK-15489.
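For context, a minimal sketch of the user-facing scenario this patch fixes; `MyRegistrator` and the app name are illustrative, not part of the patch:

```scala
import com.esotericsoftware.kryo.Kryo
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoRegistrator

// Hypothetical user registrator wiring custom Kryo serializers.
class MyRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = {
    // Register custom serializers here, e.g. for Guava immutable collections.
  }
}

object Example {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("kryo-registrator-example")
      .set("spark.kryo.registrator", classOf[MyRegistrator].getName)
    // Before this patch, Dataset kryo encoders silently ignored this setting
    // because the generated code instantiated the serializer with a fresh
    // SparkConf rather than the one above.
  }
}
```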
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala | 30
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala        | 68
2 files changed, 89 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 87c8a2e54a..c597a2a709 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -22,7 +22,7 @@ import java.lang.reflect.Modifier
import scala.language.existentials
import scala.reflect.ClassTag
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
@@ -547,11 +547,17 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
(classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
}
}
+ // try conf from env, otherwise create a new one
+ val env = s"${classOf[SparkEnv].getName}.get()"
val sparkConf = s"new ${classOf[SparkConf].getName}()"
- ctx.addMutableState(
- serializerInstanceClass,
- serializer,
- s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();")
+ val serializerInit = s"""
+ if ($env == null) {
+ $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
+ } else {
+ $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
+ }
+ """
+ ctx.addMutableState(serializerInstanceClass, serializer, serializerInit)
// Code to serialize.
val input = child.genCode(ctx)
@@ -587,11 +593,17 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean)
(classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
}
}
+ // try conf from env, otherwise create a new one
+ val env = s"${classOf[SparkEnv].getName}.get()"
val sparkConf = s"new ${classOf[SparkConf].getName}()"
- ctx.addMutableState(
- serializerInstanceClass,
- serializer,
- s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();")
+ val serializerInit = s"""
+ if ($env == null) {
+ $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
+ } else {
+ $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
+ }
+ """
+ ctx.addMutableState(serializerInstanceClass, serializer, serializerInit)
// Code to deserialize.
val input = child.genCode(ctx)
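For the `kryo = true` branch, the generated initializer in both hunks is equivalent to the following Scala sketch (the generated code itself is Java source assembled from the strings above; `GeneratedInitSketch` is an illustrative name):

```scala
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer.{KryoSerializer, SerializerInstance}

object GeneratedInitSketch {
  // Prefer the SparkConf from the live SparkEnv, which carries user settings
  // such as spark.kryo.registrator; fall back to a default SparkConf only
  // when no SparkEnv exists.
  def newKryoSerializerInstance(): SerializerInstance = {
    val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
    new KryoSerializer(conf).newInstance()
  }
}
```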
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala
new file mode 100644
index 0000000000..0f3d0cefe3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import com.esotericsoftware.kryo.{Kryo, Serializer}
+import com.esotericsoftware.kryo.io.{Input, Output}
+
+import org.apache.spark.serializer.KryoRegistrator
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.TestSparkSession
+
+/**
+ * Test suite to test Kryo custom registrators.
+ */
+class DatasetSerializerRegistratorSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
+
+ /**
+ * Initialize the [[TestSparkSession]] with a [[KryoRegistrator]].
+ */
+ protected override def beforeAll(): Unit = {
+ sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName)
+ super.beforeAll()
+ }
+
+ test("Kryo registrator") {
+ implicit val kryoEncoder = Encoders.kryo[KryoData]
+ val ds = Seq(KryoData(1), KryoData(2)).toDS()
+ assert(ds.collect().toSet == Set(KryoData(0), KryoData(0)))
+ }
+
+}
+
+/** Used to test user provided registrator. */
+class TestRegistrator extends KryoRegistrator {
+ override def registerClasses(kryo: Kryo): Unit =
+ kryo.register(classOf[KryoData], new ZeroKryoDataSerializer())
+}
+
+object TestRegistrator {
+ def apply(): TestRegistrator = new TestRegistrator()
+}
+
+/** A [[Serializer]] that takes a [[KryoData]] and serializes it as KryoData(0). */
+class ZeroKryoDataSerializer extends Serializer[KryoData] {
+ override def write(kryo: Kryo, output: Output, t: KryoData): Unit = {
+ output.writeInt(0)
+ }
+
+ override def read(kryo: Kryo, input: Input, aClass: Class[KryoData]): KryoData = {
+ KryoData(input.readInt())
+ }
+}
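As a usage note, the commit message's manual test can be approximated with a registrator like the one below. This is a hedged sketch assuming the `kryo-serializers` Guava module is on the classpath; `ImmutableListSerializer.registerSerializers` is that library's API, not part of this patch:

```scala
import com.esotericsoftware.kryo.Kryo
import de.javakaffee.kryoserializers.guava.ImmutableListSerializer
import org.apache.spark.serializer.KryoRegistrator

// Registrator mirroring the manual test: wire kryo-serializers' Guava
// support for com.google.common.collect.ImmutableList.
class GuavaRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = {
    ImmutableListSerializer.registerSerializers(kryo)
  }
}
```

With `spark.kryo.registrator` set to this class, a Dataset of `ImmutableList` values created via `Encoders.kryo` round-trips correctly, because the generated (de)serialization code now picks up the `SparkConf` from `SparkEnv`.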