author     Max Seiden <max@platfora.com>    2015-04-15 16:15:11 -0700
committer  Michael Armbrust <michael@databricks.com>    2015-04-15 16:15:11 -0700
commit     8a53de16fc8208358b76d0f3d45538e0304bcc8e (patch)
tree       1d5da52bbb0c2aa015a743f25f68554a635b183c /sql
parent     d5f1b9650b6e46cf6a9d61f01cda0df0cda5b1c9 (diff)
[SPARK-5277][SQL] - SparkSqlSerializer doesn't always register user specified KryoRegistrators
There were a few places where new SparkSqlSerializer instances were created with new, empty SparkConfs, so user-specified registrators were sometimes not initialized. The fix is to pull a conf from the SparkEnv and construct a new conf (one that loads defaults) only if none can be found.

The changes touch:
1) SparkSqlSerializer's resource pool (this appears to fix the issue noted in the comment)
2) execution.Exchange (for all of the partitioners)
3) execution.Limit (for the HashPartitioner)

A few tests were added to ColumnTypeSuite, ensuring that a custom registrator and serde are initialized and used when in-memory columns are written.

Author: Max Seiden <max@platfora.com>

This patch had conflicts when merged, resolved by
Committer: Michael Armbrust <michael@databricks.com>

Closes #5237 from mhseiden/sql_udt_kryo and squashes the following commits:

3175c2f [Max Seiden] [SPARK-5277][SQL] - address code review comments
e5011fb [Max Seiden] [SPARK-5277][SQL] - SparkSqlSerializer does not register user specified KryoRegistrators
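A minimal sketch of the pattern the fix relies on is shown below: prefer the conf held by the running SparkEnv (which carries spark.kryo.registrator) and only fall back to a defaults-loading SparkConf. The property name, SparkEnv lookup, and KryoRegistrator signature are real Spark APIs; MyRegistrator and the commented-out types are illustrative assumptions, not part of the patch.

```scala
import com.esotericsoftware.kryo.Kryo

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer.KryoRegistrator

// Illustrative user-side registrator; the class name and registered types are hypothetical.
class MyRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = {
    // e.g. kryo.register(classOf[MyUdt], new MyUdtSerializer)  // hypothetical types
  }
}

object SerializerConf {
  // User code opts in through the standard Kryo property:
  //   new SparkConf().set("spark.kryo.registrator", classOf[MyRegistrator].getName)

  // The fix's lookup: use the conf from the live SparkEnv when available, so
  // spark.kryo.registrator survives; the old `new SparkConf(false)` dropped it.
  def resolve(): SparkConf =
    Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
}
```

Inside the execution operators this resolved conf is what gets handed to `new SparkSqlSerializer(...)`, as the hunks below show.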
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala           |  9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala |  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala     |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala     | 62
4 files changed, 68 insertions(+), 12 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 518fc9e57c..69a620e1ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -78,6 +78,8 @@ case class Exchange(
}
override def execute(): RDD[Row] = attachTree(this , "execute") {
+ lazy val sparkConf = child.sqlContext.sparkContext.getConf
+
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
// TODO: Eliminate redundant expressions in grouping key and value.
@@ -109,7 +111,7 @@ case class Exchange(
} else {
new ShuffledRDD[Row, Row, Row](rdd, part)
}
- shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+ shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
shuffled.map(_._2)
case RangePartitioning(sortingExpressions, numPartitions) =>
@@ -132,8 +134,7 @@ case class Exchange(
} else {
new ShuffledRDD[Row, Null, Null](rdd, part)
}
- shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
-
+ shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
shuffled.map(_._1)
case SinglePartition =>
@@ -151,7 +152,7 @@ case class Exchange(
}
val partitioner = new HashPartitioner(1)
val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
- shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+ shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
shuffled.map(_._2)
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 914f387dec..eea15aff5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -65,12 +65,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
private[execution] class KryoResourcePool(size: Int)
extends ResourcePool[SerializerInstance](size) {
- val ser: KryoSerializer = {
+ val ser: SparkSqlSerializer = {
val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
- // TODO (lian) Using KryoSerializer here is workaround, needs further investigation
- // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization
- // related error.
- new KryoSerializer(sparkConf)
+ new SparkSqlSerializer(sparkConf)
}
def newInstance(): SerializerInstance = ser.newInstance()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 308dae236a..d286fe81be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -121,7 +121,7 @@ case class Limit(limit: Int, child: SparkPlan)
}
val part = new HashPartitioner(1)
val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part)
- shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+ shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf))
shuffled.mapPartitions(_.take(limit).map(_._2))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index c86ef338fc..b48bed1871 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -20,9 +20,12 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
import java.sql.Timestamp
+import com.esotericsoftware.kryo.{Serializer, Kryo}
+import com.esotericsoftware.kryo.io.{Input, Output}
+import org.apache.spark.serializer.KryoRegistrator
import org.scalatest.FunSuite
-import org.apache.spark.Logging
+import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.SparkSqlSerializer
@@ -73,7 +76,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
checkActualSize(BINARY, binary, 4 + 4)
val generic = Map(1 -> "a")
- checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11)
+ checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
}
testNativeColumnType[BooleanType.type](
@@ -158,6 +161,41 @@ class ColumnTypeSuite extends FunSuite with Logging {
}
}
+ test("CUSTOM") {
+ val conf = new SparkConf()
+ conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator")
+ val serializer = new SparkSqlSerializer(conf).newInstance()
+
+ val buffer = ByteBuffer.allocate(512)
+ val obj = CustomClass(Int.MaxValue,Long.MaxValue)
+ val serializedObj = serializer.serialize(obj).array()
+
+ GENERIC.append(serializer.serialize(obj).array(), buffer)
+ buffer.rewind()
+
+ val length = buffer.getInt
+ assert(length === serializedObj.length)
+ assert(13 == length) // id (1) + int (4) + long (8)
+
+ val genericSerializedObj = SparkSqlSerializer.serialize(obj)
+ assert(length != genericSerializedObj.length)
+ assert(length < genericSerializedObj.length)
+
+ assertResult(obj, "Custom deserialized object didn't equal the original object") {
+ val bytes = new Array[Byte](length)
+ buffer.get(bytes, 0, length)
+ serializer.deserialize(ByteBuffer.wrap(bytes))
+ }
+
+ buffer.rewind()
+ buffer.putInt(serializedObj.length).put(serializedObj)
+
+ assertResult(obj, "Custom deserialized object didn't equal the original object") {
+ buffer.rewind()
+ serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer)))
+ }
+ }
+
def testNativeColumnType[T <: NativeType](
columnType: NativeColumnType[T],
putter: (ByteBuffer, T#JvmType) => Unit,
@@ -229,3 +267,23 @@ class ColumnTypeSuite extends FunSuite with Logging {
}
}
}
+
+private[columnar] final case class CustomClass(a: Int, b: Long)
+
+private[columnar] object CustomerSerializer extends Serializer[CustomClass] {
+ override def write(kryo: Kryo, output: Output, t: CustomClass) {
+ output.writeInt(t.a)
+ output.writeLong(t.b)
+ }
+ override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = {
+ val a = input.readInt()
+ val b = input.readLong()
+ CustomClass(a,b)
+ }
+}
+
+private[columnar] final class Registrator extends KryoRegistrator {
+ override def registerClasses(kryo: Kryo) {
+ kryo.register(classOf[CustomClass], CustomerSerializer)
+ }
+}