diff options
author | Jakob Odersky <jakob@odersky.com> | 2016-03-16 16:59:36 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-03-16 16:59:36 -0700 |
commit | d4d84936fb82bee91f4b04608de9f75c293ccc9e (patch) | |
tree | 9cbaeb2ad0de147a9febe516c2fd537fbbb9263c /sql/catalyst | |
parent | 77ba3021c12dc63cb7d831f964f901e0474acd96 (diff) | |
download | spark-d4d84936fb82bee91f4b04608de9f75c293ccc9e.tar.gz spark-d4d84936fb82bee91f4b04608de9f75c293ccc9e.tar.bz2 spark-d4d84936fb82bee91f4b04608de9f75c293ccc9e.zip |
[SPARK-11011][SQL] Narrow type of UDT serialization
## What changes were proposed in this pull request?
Narrow down the parameter type of `UserDefinedType#serialize()`. Currently, the parameter type is `Any`, however it would logically make more sense to narrow it down to the type of the actual user defined type.
## How was this patch tested?
Existing tests were successfully run on local machine.
Author: Jakob Odersky <jakob@odersky.com>
Closes #11379 from jodersky/SPARK-11011-udt-types.
Diffstat (limited to 'sql/catalyst')
4 files changed, 17 insertions, 30 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 2ec0ff53c8..9bfc381639 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -136,16 +136,16 @@ object CatalystTypeConverters { override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column, dataType) } - private case class UDTConverter( - udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { + private case class UDTConverter[A >: Null]( + udt: UserDefinedType[A]) extends CatalystTypeConverter[A, A, Any] { // toCatalyst (it calls toCatalystImpl) will do null check. - override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) + override def toCatalystImpl(scalaValue: A): Any = udt.serialize(scalaValue) - override def toScala(catalystValue: Any): Any = { + override def toScala(catalystValue: Any): A = { if (catalystValue == null) null else udt.deserialize(catalystValue) } - override def toScalaImpl(row: InternalRow, column: Int): Any = + override def toScalaImpl(row: InternalRow, column: Int): A = toScala(row.get(column, udt.sqlType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 9d2449f3b7..dabf9a2fc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -37,7 +37,7 @@ import org.apache.spark.annotation.DeveloperApi * The conversion via `deserialize` occurs when reading from a `DataFrame`. */ @DeveloperApi -abstract class UserDefinedType[UserType] extends DataType with Serializable { +abstract class UserDefinedType[UserType >: Null] extends DataType with Serializable { /** Underlying storage type for this UDT */ def sqlType: DataType @@ -50,11 +50,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** * Convert the user type to a SQL datum - * - * TODO: Can we make this take obj: UserType? The issue is in - * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType. */ - def serialize(obj: Any): Any + def serialize(obj: UserType): Any /** Convert a SQL datum to the user type */ def deserialize(datum: Any): UserType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 4e7bbc38d6..1b297525bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -36,11 +36,7 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { override def sqlType: DataType = IntegerType - override def serialize(obj: Any): Int = { - obj match { - case groupableData: GroupableData => groupableData.data - } - } + override def serialize(groupableData: GroupableData): Int = groupableData.data override def deserialize(datum: Any): GroupableData = { datum match { @@ -60,13 +56,10 @@ private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] { override def sqlType: DataType = MapType(IntegerType, IntegerType) - override def serialize(obj: Any): MapData = { - obj match { - case groupableData: UngroupableData => - val keyArray = new GenericArrayData(groupableData.data.keys.toSeq) - val valueArray = new GenericArrayData(groupableData.data.values.toSeq) - new ArrayBasedMapData(keyArray, valueArray) - } + override def serialize(ungroupableData: UngroupableData): MapData = { + val keyArray = new GenericArrayData(ungroupableData.data.keys.toSeq) + val valueArray = new GenericArrayData(ungroupableData.data.values.toSeq) + new ArrayBasedMapData(keyArray, valueArray) } override def deserialize(datum: Any): UngroupableData = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index f119c6f4f7..bf0360c5e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -47,14 +47,11 @@ class ExamplePointUDT extends UserDefinedType[ExamplePoint] { override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" - override def serialize(obj: Any): GenericArrayData = { - obj match { - case p: ExamplePoint => - val output = new Array[Any](2) - output(0) = p.x - output(1) = p.y - new GenericArrayData(output) - } + override def serialize(p: ExamplePoint): GenericArrayData = { + val output = new Array[Any](2) + output(0) = p.x + output(1) = p.y + new GenericArrayData(output) } override def deserialize(datum: Any): ExamplePoint = { |