diff options
author | Shixiong Zhu <shixiong@databricks.com> | 2016-05-20 12:38:46 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2016-05-20 12:38:46 -0700 |
commit | dfa61f7b136ae060bbe04e3c0da1148da41018c7 (patch) | |
tree | 5a55105b5c3a499572b18b52685c0b7b7731f3a0 /sql/catalyst | |
parent | 22947cd0213856442025baf653be588c6c707e36 (diff) | |
download | spark-dfa61f7b136ae060bbe04e3c0da1148da41018c7.tar.gz spark-dfa61f7b136ae060bbe04e3c0da1148da41018c7.tar.bz2 spark-dfa61f7b136ae060bbe04e3c0da1148da41018c7.zip |
[SPARK-15190][SQL] Support using SQLUserDefinedType for case classes
## What changes were proposed in this pull request?
Right now inferring the schema for case classes happens before searching the SQLUserDefinedType annotation, so the SQLUserDefinedType annotation for case classes doesn't work.
This PR simply changes the inferring order to resolve it. I also reenabled the java.math.BigDecimal test and added two tests for `List`.
## How was this patch tested?
`encodeDecodeTest(UDTCaseClass(new java.net.URI("http://spark.apache.org/")), "udt with case class")`
Author: Shixiong Zhu <shixiong@databricks.com>
Closes #12965 from zsxwing/SPARK-15190.
Diffstat (limited to 'sql/catalyst')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 70 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala | 28 |
2 files changed, 62 insertions, 36 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 58df651da2..36989a20cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -348,6 +348,23 @@ object ScalaReflection extends ScalaReflection { "toScalaMap", keyData :: valueData :: Nil) + case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => + val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) + + case t if UDTRegistration.exists(getClassNameFromType(t)) => + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() + .asInstanceOf[UserDefinedType[_]] + val obj = NewInstance( + udt.getClass, + Nil, + dataType = ObjectType(udt.getClass)) + Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) + case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) @@ -388,23 +405,6 @@ object ScalaReflection extends ScalaReflection { } else { newInstance } - - case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() - val obj = NewInstance( - udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), - Nil, - dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) - - case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() - .asInstanceOf[UserDefinedType[_]] - val obj = NewInstance( - udt.getClass, - Nil, - dataType = ObjectType(udt.getClass)) - Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) } } @@ -522,17 +522,6 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t toCatalystArray(inputObject, elementType) - case t if definedByConstructorParams(t) => - val params = getConstructorParameters(t) - val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => - val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - val clsName = getClassNameFromType(fieldType) - val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath - expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil - }) - val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) - expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) - case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t toCatalystArray(inputObject, elementType) @@ -645,6 +634,17 @@ object ScalaReflection extends ScalaReflection { dataType = ObjectType(udt.getClass)) Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + case t if definedByConstructorParams(t) => + val params = getConstructorParameters(t) + val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => + val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) + val clsName = getClassNameFromType(fieldType) + val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath + expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil + }) + val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) + expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) + case other => throw new UnsupportedOperationException( s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) @@ -743,13 +743,6 @@ object ScalaReflection extends ScalaReflection { val Schema(valueDataType, valueNullable) = schemaFor(valueType) Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if definedByConstructorParams(t) => - val params = getConstructorParameters(t) - Schema(StructType( - params.map { case (fieldName, fieldType) => - val Schema(dataType, nullable) = schemaFor(fieldType) - StructField(fieldName, dataType, nullable) - }), nullable = true) case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true) case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true) @@ -775,6 +768,13 @@ object ScalaReflection extends ScalaReflection { case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + case t if definedByConstructorParams(t) => + val params = getConstructorParameters(t) + Schema(StructType( + params.map { case (fieldName, fieldType) => + val Schema(dataType, nullable) = schemaFor(fieldType) + StructField(fieldName, dataType, nullable) + }), nullable = true) case other => throw new UnsupportedOperationException(s"Schema for type $other is not supported") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index d4387890b4..3d97113b52 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -86,6 +87,25 @@ class JavaSerializable(val value: Int) extends Serializable { } } +/** For testing UDT for a case class */ +@SQLUserDefinedType(udt = classOf[UDTForCaseClass]) +case class UDTCaseClass(uri: java.net.URI) + +class UDTForCaseClass extends UserDefinedType[UDTCaseClass] { + + override def sqlType: DataType = StringType + + override def serialize(obj: UDTCaseClass): UTF8String = { + UTF8String.fromString(obj.uri.toString) + } + + override def userClass: Class[UDTCaseClass] = classOf[UDTCaseClass] + + override def deserialize(datum: Any): UDTCaseClass = datum match { + case uri: UTF8String => UDTCaseClass(new java.net.URI(uri.toString)) + } +} + class ExpressionEncoderSuite extends PlanTest with AnalysisTest { OuterScopes.addOuterScope(this) @@ -147,6 +167,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple") encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple") + encodeDecodeTest(List(1, 2), "list of int") + encodeDecodeTest(List("a", null), "list with String and null") + + encodeDecodeTest( + UDTCaseClass(new java.net.URI("http://spark.apache.org/")), "udt with case class") + // Kryo encoders encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String])) encodeDecodeTest(new KryoSerializable(15), "kryo object")( |