 mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala                                  |  2 +-
 mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala                                   |  2 +-
 project/MimaExcludes.scala                                                                         |  2 ++
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala             | 10 +++++-----
 sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala                       |  7 ++-----
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala        | 17 +++++------------
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala           | 13 +++++--------
 sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala                            | 13 +++++--------
 sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala                            |  7 ++-----
 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala | 11 ++++-------
 10 files changed, 32 insertions(+), 52 deletions(-)
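
In one picture, this commit (SPARK-11011) narrows the parameter of UserDefinedType.serialize from Any to the UDT's own type parameter, so implementations no longer pattern-match the argument back to their own type. A condensed before/after sketch of the abstract class, simplified from the full hunks below:

// Before: any value could be passed in; every UDT re-matched it.
abstract class UserDefinedType[UserType] extends DataType with Serializable {
  def serialize(obj: Any): Any
  def deserialize(datum: Any): UserType
}

// After: the compiler enforces the argument type. The `>: Null` lower
// bound keeps null handling possible in the converters.
abstract class UserDefinedType[UserType >: Null] extends DataType with Serializable {
  def serialize(obj: UserType): Any
  def deserialize(datum: Any): UserType
}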
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index fdede2ad39..157f2dbf5d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -177,7 +177,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
))
}
- override def serialize(obj: Any): InternalRow = {
+ override def serialize(obj: Matrix): InternalRow = {
val row = new GenericMutableRow(7)
obj match {
case sm: SparseMatrix =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index cecfd067bd..0f0c3a2df5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -203,7 +203,7 @@ class VectorUDT extends UserDefinedType[Vector] {
StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
}
- override def serialize(obj: Any): InternalRow = {
+ override def serialize(obj: Vector): InternalRow = {
obj match {
case SparseVector(size, indices, values) =>
val row = new GenericMutableRow(4)
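
Note that the narrowed parameter costs MatrixUDT and VectorUDT nothing: both still pattern-match, but now only over subclasses of the declared type (SparseMatrix/DenseMatrix, SparseVector/DenseVector) rather than over Any. A standalone sketch of that shape, with hypothetical types rather than Spark's:

// Hypothetical analogue of a typed serialize that dispatches on subtypes.
sealed trait Vec
final case class Dense(values: Array[Double]) extends Vec
final case class Sparse(size: Int, indices: Array[Int], values: Array[Double]) extends Vec

def serialize(obj: Vec): Array[Any] = obj match {
  case Sparse(size, indices, values) => Array(0, size, indices, values) // tag 0 = sparse
  case Dense(values)                 => Array(1, null, null, values)    // tag 1 = dense
}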
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 59c7e7db2e..ffc6fa0599 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -292,6 +292,8 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$")
) ++ Seq(
+      // SPARK-11011: UserDefinedType serialization should be strongly typed
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"),
// SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"),
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition")
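
VectorUDT is public, so the signature change is a binary break: the erased serialize(Object) method that the previous release exposed is now only a compiler-generated bridge returning Object rather than InternalRow, which is presumably why MiMa reports an IncompatibleResultTypeProblem. The new filter acknowledges that break deliberately. A minimal sketch of such an entry, with the real file's per-version grouping elided:

import com.typesafe.tools.mima.core._

val excludes: Seq[ProblemFilter] = Seq(
  // SPARK-11011: serialize's erased signature changed relative to the
  // previous release, so the mismatch must be excluded explicitly.
  ProblemFilters.exclude[IncompatibleResultTypeProblem](
    "org.apache.spark.mllib.linalg.VectorUDT.serialize")
)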
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 2ec0ff53c8..9bfc381639 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -136,16 +136,16 @@ object CatalystTypeConverters {
override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column, dataType)
}
- private case class UDTConverter(
- udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
+ private case class UDTConverter[A >: Null](
+ udt: UserDefinedType[A]) extends CatalystTypeConverter[A, A, Any] {
// toCatalyst (it calls toCatalystImpl) will do null check.
- override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
+ override def toCatalystImpl(scalaValue: A): Any = udt.serialize(scalaValue)
- override def toScala(catalystValue: Any): Any = {
+ override def toScala(catalystValue: Any): A = {
if (catalystValue == null) null else udt.deserialize(catalystValue)
}
- override def toScalaImpl(row: InternalRow, column: Int): Any =
+ override def toScalaImpl(row: InternalRow, column: Int): A =
toScala(row.get(column, udt.sqlType))
}
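
The A >: Null bound on UDTConverter exists for exactly one line: the null in toScala. An unbounded type parameter admits value types such as Int, for which null is not a legal value, so the bound is what makes the null-returning branch typecheck. A standalone illustration with a hypothetical helper:

// Without the bound this does not compile, since `null` is not a value
// of an arbitrary A:
//   def orNull[A](a: Option[A]): A = a.getOrElse(null)   // type error
// With `A >: Null`, A is constrained to nullable reference types:
def orNull[A >: Null](a: Option[A]): A = a.getOrElse(null)

// orNull(Some("x")) == "x"
// orNull(None: Option[String]) == null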
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 9d2449f3b7..dabf9a2fc0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -37,7 +37,7 @@ import org.apache.spark.annotation.DeveloperApi
* The conversion via `deserialize` occurs when reading from a `DataFrame`.
*/
@DeveloperApi
-abstract class UserDefinedType[UserType] extends DataType with Serializable {
+abstract class UserDefinedType[UserType >: Null] extends DataType with Serializable {
/** Underlying storage type for this UDT */
def sqlType: DataType
@@ -50,11 +50,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
/**
* Convert the user type to a SQL datum
- *
- * TODO: Can we make this take obj: UserType? The issue is in
- * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType.
*/
- def serialize(obj: Any): Any
+ def serialize(obj: UserType): Any
/** Convert a SQL datum to the user type */
def deserialize(datum: Any): UserType
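
With the TODO resolved, a subclass states its user type once in the class header and serialize receives it already typed; only deserialize still takes Any, because the runtime type of a catalyst value is not statically known. A self-contained sketch of a UDT under the new signature, modeled on the ExamplePointUDT test classes below (the Point2D names are illustrative, not Spark API):

import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._

class Point2D(val x: Double, val y: Double) extends Serializable

class Point2DUDT extends UserDefinedType[Point2D] {
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // The compiler guarantees `p` is a Point2D; no defensive match needed.
  override def serialize(p: Point2D): GenericArrayData =
    new GenericArrayData(Array[Any](p.x, p.y))

  // Deserialization keeps the match: the catalyst-side value arrives as Any.
  override def deserialize(datum: Any): Point2D = datum match {
    case values: ArrayData => new Point2D(values.getDouble(0), values.getDouble(1))
  }

  override def userClass: Class[Point2D] = classOf[Point2D]
}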
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 4e7bbc38d6..1b297525bd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -36,11 +36,7 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
override def sqlType: DataType = IntegerType
- override def serialize(obj: Any): Int = {
- obj match {
- case groupableData: GroupableData => groupableData.data
- }
- }
+ override def serialize(groupableData: GroupableData): Int = groupableData.data
override def deserialize(datum: Any): GroupableData = {
datum match {
@@ -60,13 +56,10 @@ private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {
override def sqlType: DataType = MapType(IntegerType, IntegerType)
- override def serialize(obj: Any): MapData = {
- obj match {
- case groupableData: UngroupableData =>
- val keyArray = new GenericArrayData(groupableData.data.keys.toSeq)
- val valueArray = new GenericArrayData(groupableData.data.values.toSeq)
- new ArrayBasedMapData(keyArray, valueArray)
- }
+ override def serialize(ungroupableData: UngroupableData): MapData = {
+ val keyArray = new GenericArrayData(ungroupableData.data.keys.toSeq)
+ val valueArray = new GenericArrayData(ungroupableData.data.values.toSeq)
+ new ArrayBasedMapData(keyArray, valueArray)
}
override def deserialize(datum: Any): UngroupableData = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index f119c6f4f7..bf0360c5e1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -47,14 +47,11 @@ class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
- override def serialize(obj: Any): GenericArrayData = {
- obj match {
- case p: ExamplePoint =>
- val output = new Array[Any](2)
- output(0) = p.x
- output(1) = p.y
- new GenericArrayData(output)
- }
+ override def serialize(p: ExamplePoint): GenericArrayData = {
+ val output = new Array[Any](2)
+ output(0) = p.x
+ output(1) = p.y
+ new GenericArrayData(output)
}
override def deserialize(datum: Any): ExamplePoint = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index e2c9fc421b..695a5ad78a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -42,14 +42,11 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
- override def serialize(obj: Any): GenericArrayData = {
- obj match {
- case p: ExamplePoint =>
- val output = new Array[Any](2)
- output(0) = p.x
- output(1) = p.y
- new GenericArrayData(output)
- }
+ override def serialize(p: ExamplePoint): GenericArrayData = {
+ val output = new Array[Any](2)
+ output(0) = p.x
+ output(1) = p.y
+ new GenericArrayData(output)
}
override def deserialize(datum: Any): ExamplePoint = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 9081bc722a..8c4afb605b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -45,11 +45,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
- override def serialize(obj: Any): ArrayData = {
- obj match {
- case features: MyDenseVector =>
- new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
- }
+ override def serialize(features: MyDenseVector): ArrayData = {
+ new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
}
override def deserialize(datum: Any): MyDenseVector = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index fb99b0c7e2..f8166c7ddc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -590,14 +590,11 @@ object TestingUDT {
.add("b", LongType, nullable = false)
.add("c", DoubleType, nullable = false)
- override def serialize(obj: Any): Any = {
+ override def serialize(n: NestedStruct): Any = {
val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType))
- obj match {
- case n: NestedStruct =>
- row.setInt(0, n.a)
- row.setLong(1, n.b)
- row.setDouble(2, n.c)
- }
+ row.setInt(0, n.a)
+ row.setLong(1, n.b)
+ row.setDouble(2, n.c)
}
override def userClass: Class[NestedStruct] = classOf[NestedStruct]
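
One caveat in this last hunk: as rendered, the final expression of the new body is row.setDouble(2, n.c), a Unit-returning setter, so the populated row is never returned (unless the viewer dropped a trailing row context line). A corrected standalone sketch of the struct-backed serializer shape, with an Array[Any] standing in for the internal SpecificMutableRow:

case class NestedStruct(a: Int, b: Long, c: Double)

def serialize(n: NestedStruct): Any = {
  val row = new Array[Any](3)
  row(0) = n.a
  row(1) = n.b
  row(2) = n.c
  row // the populated row must be the last expression, or serialize returns Unit
}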