author    Shixiong Zhu <shixiong@databricks.com>  2016-05-20 12:38:46 -0700
committer Michael Armbrust <michael@databricks.com>  2016-05-20 12:38:46 -0700
commit    dfa61f7b136ae060bbe04e3c0da1148da41018c7 (patch)
tree      5a55105b5c3a499572b18b52685c0b7b7731f3a0 /sql
parent    22947cd0213856442025baf653be588c6c707e36 (diff)
[SPARK-15190][SQL] Support using SQLUserDefinedType for case classes
## What changes were proposed in this pull request?

Right now, schema inference for case classes happens before the SQLUserDefinedType annotation is checked, so the SQLUserDefinedType annotation on a case class has no effect. This PR simply changes the inference order to resolve that. I also re-enabled the java.math.BigDecimal test and added two tests for `List`.

## How was this patch tested?

`encodeDecodeTest(UDTCaseClass(new java.net.URI("http://spark.apache.org/")), "udt with case class")`

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #12965 from zsxwing/SPARK-15190.
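As a minimal illustration of the pattern this patch enables (mirroring the `UDTCaseClass` test added below, but with hypothetical `Point`/`PointUDT` names; this snippet is not part of the patch):

```scala
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// A case class annotated with its own UDT. Before this fix, schema
// inference matched the case-class branch first and this annotation
// was silently ignored.
@SQLUserDefinedType(udt = classOf[PointUDT])
case class Point(x: Double, y: Double)

class PointUDT extends UserDefinedType[Point] {

  // Catalyst stores the value using this SQL type, not a struct of fields.
  override def sqlType: DataType = StringType

  override def serialize(obj: Point): UTF8String =
    UTF8String.fromString(s"${obj.x},${obj.y}")

  override def deserialize(datum: Any): Point = datum match {
    case s: UTF8String =>
      val Array(x, y) = s.toString.split(",")
      Point(x.toDouble, y.toDouble)
  }

  override def userClass: Class[Point] = classOf[Point]
}
```

With the old matching order, `Point` would have been encoded as a two-field struct and the annotation ignored; after this change it round-trips through `PointUDT`.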
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala                   | 70
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala   | 28
2 files changed, 62 insertions(+), 36 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 58df651da2..36989a20cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -348,6 +348,23 @@ object ScalaReflection extends ScalaReflection {
"toScalaMap",
keyData :: valueData :: Nil)
+ case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
+ val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+ val obj = NewInstance(
+ udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+ Nil,
+ dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+ Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+
+ case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+ val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+ .asInstanceOf[UserDefinedType[_]]
+ val obj = NewInstance(
+ udt.getClass,
+ Nil,
+ dataType = ObjectType(udt.getClass))
+ Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+
case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)
@@ -388,23 +405,6 @@ object ScalaReflection extends ScalaReflection {
} else {
newInstance
}
-
- case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
- val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
- val obj = NewInstance(
- udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
- Nil,
- dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
- Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
-
- case t if UDTRegistration.exists(getClassNameFromType(t)) =>
- val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
- .asInstanceOf[UserDefinedType[_]]
- val obj = NewInstance(
- udt.getClass,
- Nil,
- dataType = ObjectType(udt.getClass))
- Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
}
}
@@ -522,17 +522,6 @@ object ScalaReflection extends ScalaReflection {
val TypeRef(_, _, Seq(elementType)) = t
toCatalystArray(inputObject, elementType)
- case t if definedByConstructorParams(t) =>
- val params = getConstructorParameters(t)
- val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
- val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
- val clsName = getClassNameFromType(fieldType)
- val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
- expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
- })
- val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
- expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
-
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
toCatalystArray(inputObject, elementType)
@@ -645,6 +634,17 @@ object ScalaReflection extends ScalaReflection {
dataType = ObjectType(udt.getClass))
Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
+ case t if definedByConstructorParams(t) =>
+ val params = getConstructorParameters(t)
+ val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+ val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+ val clsName = getClassNameFromType(fieldType)
+ val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
+ expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
+ })
+ val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+ expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+
case other =>
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
@@ -743,13 +743,6 @@ object ScalaReflection extends ScalaReflection {
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
- case t if definedByConstructorParams(t) =>
- val params = getConstructorParameters(t)
- Schema(StructType(
- params.map { case (fieldName, fieldType) =>
- val Schema(dataType, nullable) = schemaFor(fieldType)
- StructField(fieldName, dataType, nullable)
- }), nullable = true)
case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
@@ -775,6 +768,13 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+ case t if definedByConstructorParams(t) =>
+ val params = getConstructorParameters(t)
+ Schema(StructType(
+ params.map { case (fieldName, fieldType) =>
+ val Schema(dataType, nullable) = schemaFor(fieldType)
+ StructField(fieldName, dataType, nullable)
+ }), nullable = true)
case other =>
throw new UnsupportedOperationException(s"Schema for type $other is not supported")
}
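All three hunks above apply the same mechanical fix: Scala `match` cases are tried top to bottom, and `definedByConstructorParams(t)` holds for any case class, including UDT-annotated ones, so the general case must come after the more specific UDT cases. A self-contained sketch of that first-match-wins behavior (hypothetical names, not Spark code):

```scala
object MatchOrder {
  // Simulates the dispatch in ScalaReflection: a UDT-annotated case class
  // satisfies both guards, so whichever case appears first decides the outcome.
  def dispatch(isCaseClass: Boolean, hasUdtAnnotation: Boolean): String =
    (isCaseClass, hasUdtAnnotation) match {
      case (_, true) => "use the declared UDT"   // specific guard first
      case (true, _) => "infer a struct schema"  // general guard second
      case _         => "unsupported"
    }

  def main(args: Array[String]): Unit = {
    // With the specific case first, an annotated case class now resolves
    // to its UDT, which is what SPARK-15190 fixes.
    println(dispatch(isCaseClass = true, hasUdtAnnotation = true))
  }
}
```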
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index d4387890b4..3d97113b52 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
case class RepeatedStruct(s: Seq[PrimitiveData])
@@ -86,6 +87,25 @@ class JavaSerializable(val value: Int) extends Serializable {
}
}
+/** For testing UDT for a case class */
+@SQLUserDefinedType(udt = classOf[UDTForCaseClass])
+case class UDTCaseClass(uri: java.net.URI)
+
+class UDTForCaseClass extends UserDefinedType[UDTCaseClass] {
+
+ override def sqlType: DataType = StringType
+
+ override def serialize(obj: UDTCaseClass): UTF8String = {
+ UTF8String.fromString(obj.uri.toString)
+ }
+
+ override def userClass: Class[UDTCaseClass] = classOf[UDTCaseClass]
+
+ override def deserialize(datum: Any): UDTCaseClass = datum match {
+ case uri: UTF8String => UDTCaseClass(new java.net.URI(uri.toString))
+ }
+}
+
class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
OuterScopes.addOuterScope(this)
@@ -147,6 +167,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple")
encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple")
+ encodeDecodeTest(List(1, 2), "list of int")
+ encodeDecodeTest(List("a", null), "list with String and null")
+
+ encodeDecodeTest(
+ UDTCaseClass(new java.net.URI("http://spark.apache.org/")), "udt with case class")
+
// Kryo encoders
encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String]))
encodeDecodeTest(new KryoSerializable(15), "kryo object")(