aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorJakob Odersky <jakob@odersky.com>2016-03-16 16:59:36 -0700
committerXiangrui Meng <meng@databricks.com>2016-03-16 16:59:36 -0700
commitd4d84936fb82bee91f4b04608de9f75c293ccc9e (patch)
tree9cbaeb2ad0de147a9febe516c2fd537fbbb9263c /sql/core
parent77ba3021c12dc63cb7d831f964f901e0474acd96 (diff)
downloadspark-d4d84936fb82bee91f4b04608de9f75c293ccc9e.tar.gz
spark-d4d84936fb82bee91f4b04608de9f75c293ccc9e.tar.bz2
spark-d4d84936fb82bee91f4b04608de9f75c293ccc9e.zip
[SPARK-11011][SQL] Narrow type of UDT serialization
## What changes were proposed in this pull request?

Narrow down the parameter type of `UserDefinedType#serialize()`. Currently, the parameter type is `Any`; however, it would logically make more sense to narrow it down to the type of the actual user-defined type.

## How was this patch tested?

Existing tests were successfully run on a local machine.

Author: Jakob Odersky <jakob@odersky.com>

Closes #11379 from jodersky/SPARK-11011-udt-types.
Diffstat (limited to 'sql/core')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala | 13
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 7
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala | 11
3 files changed, 11 insertions, 20 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index e2c9fc421b..695a5ad78a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -42,14 +42,11 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
- override def serialize(obj: Any): GenericArrayData = {
- obj match {
- case p: ExamplePoint =>
- val output = new Array[Any](2)
- output(0) = p.x
- output(1) = p.y
- new GenericArrayData(output)
- }
+ override def serialize(p: ExamplePoint): GenericArrayData = {
+ val output = new Array[Any](2)
+ output(0) = p.x
+ output(1) = p.y
+ new GenericArrayData(output)
}
override def deserialize(datum: Any): ExamplePoint = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 9081bc722a..8c4afb605b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -45,11 +45,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
- override def serialize(obj: Any): ArrayData = {
- obj match {
- case features: MyDenseVector =>
- new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
- }
+ override def serialize(features: MyDenseVector): ArrayData = {
+ new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
}
override def deserialize(datum: Any): MyDenseVector = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index fb99b0c7e2..f8166c7ddc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -590,14 +590,11 @@ object TestingUDT {
.add("b", LongType, nullable = false)
.add("c", DoubleType, nullable = false)
- override def serialize(obj: Any): Any = {
+ override def serialize(n: NestedStruct): Any = {
val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType))
- obj match {
- case n: NestedStruct =>
- row.setInt(0, n.a)
- row.setLong(1, n.b)
- row.setDouble(2, n.c)
- }
+ row.setInt(0, n.a)
+ row.setLong(1, n.b)
+ row.setDouble(2, n.c)
}
override def userClass: Class[NestedStruct] = classOf[NestedStruct]