Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala           |  9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala |  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala     |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala     | 62
4 files changed, 68 insertions(+), 12 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 518fc9e57c..69a620e1ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -78,6 +78,8 @@ case class Exchange(
}
override def execute(): RDD[Row] = attachTree(this , "execute") {
+ lazy val sparkConf = child.sqlContext.sparkContext.getConf
+
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
// TODO: Eliminate redundant expressions in grouping key and value.
@@ -109,7 +111,7 @@ case class Exchange(
} else {
new ShuffledRDD[Row, Row, Row](rdd, part)
}
- shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+ shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
shuffled.map(_._2)
case RangePartitioning(sortingExpressions, numPartitions) =>
@@ -132,8 +134,7 @@ case class Exchange(
} else {
new ShuffledRDD[Row, Null, Null](rdd, part)
}
- shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
-
+ shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
shuffled.map(_._1)
case SinglePartition =>
@@ -151,7 +152,7 @@ case class Exchange(
}
val partitioner = new HashPartitioner(1)
val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
- shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+ shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
shuffled.map(_._2)
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
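
Note: the Exchange hunks above all replace new SparkConf(false) with the conf of the running SparkContext. The distinction matters because a SparkConf created with loadDefaults = false starts out empty, so any spark.kryo.* settings the user configured never reach the shuffle serializer. A minimal standalone sketch of that distinction (not part of the patch):

import org.apache.spark.SparkConf

// A conf built with loadDefaults = false does not pick up system properties,
// so user settings such as spark.kryo.registrator are simply absent from it.
val blank = new SparkConf(false)
assert(!blank.contains("spark.kryo.registrator"))

// The conf returned by the running SparkContext (what the patch obtains via
// child.sqlContext.sparkContext.getConf) still carries those settings, so the
// SparkSqlSerializer built from it honors user-registered Kryo serializers.
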
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 914f387dec..eea15aff5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -65,12 +65,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
private[execution] class KryoResourcePool(size: Int)
extends ResourcePool[SerializerInstance](size) {
- val ser: KryoSerializer = {
+ val ser: SparkSqlSerializer = {
val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
- // TODO (lian) Using KryoSerializer here is workaround, needs further investigation
- // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization
- // related error.
- new KryoSerializer(sparkConf)
+ new SparkSqlSerializer(sparkConf)
}
def newInstance(): SerializerInstance = ser.newInstance()
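
Note: with the pool now handing out SparkSqlSerializer instances instead of plain KryoSerializer ones, the companion-object helpers that borrow from it (SparkSqlSerializer.serialize is used in the test below) also pick up SQL-specific and user-supplied Kryo registrations. A rough usage sketch, assuming the deserialize helper mirrors serialize:

// Round-trip a value through the pooled serializer instances.
val bytes: Array[Byte] = SparkSqlSerializer.serialize(Map(1 -> "a"))
val back = SparkSqlSerializer.deserialize[Map[Int, String]](bytes)
assert(back == Map(1 -> "a"))
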
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 308dae236a..d286fe81be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -121,7 +121,7 @@ case class Limit(limit: Int, child: SparkPlan)
}
val part = new HashPartitioner(1)
val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part)
- shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+ shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf))
shuffled.mapPartitions(_.take(limit).map(_._2))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index c86ef338fc..b48bed1871 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -20,9 +20,12 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
import java.sql.Timestamp
+import com.esotericsoftware.kryo.{Serializer, Kryo}
+import com.esotericsoftware.kryo.io.{Input, Output}
+import org.apache.spark.serializer.KryoRegistrator
import org.scalatest.FunSuite
-import org.apache.spark.Logging
+import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.SparkSqlSerializer
@@ -73,7 +76,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
checkActualSize(BINARY, binary, 4 + 4)
val generic = Map(1 -> "a")
- checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11)
+ checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
}
testNativeColumnType[BooleanType.type](
@@ -158,6 +161,41 @@ class ColumnTypeSuite extends FunSuite with Logging {
}
}
+ test("CUSTOM") {
+ val conf = new SparkConf()
+ conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator")
+ val serializer = new SparkSqlSerializer(conf).newInstance()
+
+ val buffer = ByteBuffer.allocate(512)
+ val obj = CustomClass(Int.MaxValue,Long.MaxValue)
+ val serializedObj = serializer.serialize(obj).array()
+
+ GENERIC.append(serializer.serialize(obj).array(), buffer)
+ buffer.rewind()
+
+ val length = buffer.getInt
+ assert(length === serializedObj.length)
+ assert(13 == length) // id (1) + int (4) + long (8)
+
+ val genericSerializedObj = SparkSqlSerializer.serialize(obj)
+ assert(length != genericSerializedObj.length)
+ assert(length < genericSerializedObj.length)
+
+ assertResult(obj, "Custom deserialized object didn't equal the original object") {
+ val bytes = new Array[Byte](length)
+ buffer.get(bytes, 0, length)
+ serializer.deserialize(ByteBuffer.wrap(bytes))
+ }
+
+ buffer.rewind()
+ buffer.putInt(serializedObj.length).put(serializedObj)
+
+ assertResult(obj, "Custom deserialized object didn't equal the original object") {
+ buffer.rewind()
+ serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer)))
+ }
+ }
+
def testNativeColumnType[T <: NativeType](
columnType: NativeColumnType[T],
putter: (ByteBuffer, T#JvmType) => Unit,
@@ -229,3 +267,23 @@ class ColumnTypeSuite extends FunSuite with Logging {
}
}
}
+
+private[columnar] final case class CustomClass(a: Int, b: Long)
+
+private[columnar] object CustomerSerializer extends Serializer[CustomClass] {
+ override def write(kryo: Kryo, output: Output, t: CustomClass) {
+ output.writeInt(t.a)
+ output.writeLong(t.b)
+ }
+ override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = {
+ val a = input.readInt()
+ val b = input.readLong()
+ CustomClass(a,b)
+ }
+}
+
+private[columnar] final class Registrator extends KryoRegistrator {
+ override def registerClasses(kryo: Kryo) {
+ kryo.register(classOf[CustomClass], CustomerSerializer)
+ }
+}
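
For reference, a sketch (hypothetical names, not part of the patch) of how an application would plug its own Kryo serializer into Spark SQL once the shuffle serializer sees the real SparkConf: define a Kryo Serializer and a KryoRegistrator, then point spark.kryo.registrator at the registrator's fully qualified class name, exactly as the "CUSTOM" test above does with Registrator.

import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoRegistrator

// Hypothetical user type and a hand-written Kryo serializer for it.
case class Point(x: Double, y: Double)

class PointSerializer extends Serializer[Point] {
  override def write(kryo: Kryo, output: Output, p: Point) {
    output.writeDouble(p.x)
    output.writeDouble(p.y)
  }
  override def read(kryo: Kryo, input: Input, cls: Class[Point]): Point =
    Point(input.readDouble(), input.readDouble())
}

// Registrator that wires the serializer to the type; its fully qualified
// class name is what spark.kryo.registrator should point at.
class PointRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo) {
    kryo.register(classOf[Point], new PointSerializer)
  }
}

val conf = new SparkConf()
  .set("spark.kryo.registrator", "com.example.PointRegistrator") // hypothetical package

With this patch, that setting flows from the application's SparkConf into the SparkSqlSerializer used for SQL shuffles, so Point values crossing an Exchange are written by PointSerializer rather than Kryo's default field serializer.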