Diffstat (limited to 'sql')
 -rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala                  | 70
 -rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala | 28
 2 files changed, 62 insertions(+), 36 deletions(-)
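In short: this commit reorders the pattern-match cases in ScalaReflection.scala so that the SQLUserDefinedType-annotation and UDTRegistration cases are tried before the generic definedByConstructorParams case, in all three of deserializerFor, serializerFor, and schemaFor. Because a Scala match takes the first case that applies, a case class carrying a UDT previously fell into the constructor-parameters branch and was encoded as a plain struct of its fields; with this ordering the UDT takes precedence. The test file adds a UDT defined over a case class plus two List round-trip tests; illustrative sketches of the effect follow each file's diff below.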
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 58df651da2..36989a20cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -348,6 +348,23 @@ object ScalaReflection extends ScalaReflection {
"toScalaMap",
keyData :: valueData :: Nil)
+ case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
+ val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+ val obj = NewInstance(
+ udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+ Nil,
+ dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+ Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+
+ case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+ val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+ .asInstanceOf[UserDefinedType[_]]
+ val obj = NewInstance(
+ udt.getClass,
+ Nil,
+ dataType = ObjectType(udt.getClass))
+ Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+
case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)
@@ -388,23 +405,6 @@ object ScalaReflection extends ScalaReflection {
} else {
newInstance
}
-
- case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
- val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
- val obj = NewInstance(
- udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
- Nil,
- dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
- Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
-
- case t if UDTRegistration.exists(getClassNameFromType(t)) =>
- val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
- .asInstanceOf[UserDefinedType[_]]
- val obj = NewInstance(
- udt.getClass,
- Nil,
- dataType = ObjectType(udt.getClass))
- Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
}
}
@@ -522,17 +522,6 @@ object ScalaReflection extends ScalaReflection {
val TypeRef(_, _, Seq(elementType)) = t
toCatalystArray(inputObject, elementType)
- case t if definedByConstructorParams(t) =>
- val params = getConstructorParameters(t)
- val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
- val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
- val clsName = getClassNameFromType(fieldType)
- val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
- expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
- })
- val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
- expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
-
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
toCatalystArray(inputObject, elementType)
@@ -645,6 +634,17 @@ object ScalaReflection extends ScalaReflection {
dataType = ObjectType(udt.getClass))
Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
+ case t if definedByConstructorParams(t) =>
+ val params = getConstructorParameters(t)
+ val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+ val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+ val clsName = getClassNameFromType(fieldType)
+ val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
+ expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
+ })
+ val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+ expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+
case other =>
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
@@ -743,13 +743,6 @@ object ScalaReflection extends ScalaReflection {
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
- case t if definedByConstructorParams(t) =>
- val params = getConstructorParameters(t)
- Schema(StructType(
- params.map { case (fieldName, fieldType) =>
- val Schema(dataType, nullable) = schemaFor(fieldType)
- StructField(fieldName, dataType, nullable)
- }), nullable = true)
case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
@@ -775,6 +768,13 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+ case t if definedByConstructorParams(t) =>
+ val params = getConstructorParameters(t)
+ Schema(StructType(
+ params.map { case (fieldName, fieldType) =>
+ val Schema(dataType, nullable) = schemaFor(fieldType)
+ StructField(fieldName, dataType, nullable)
+ }), nullable = true)
case other =>
throw new UnsupportedOperationException(s"Schema for type $other is not supported")
}
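To make the ordering concrete, here is a minimal hypothetical sketch (Point and PointUDT are illustrative names, not part of this patch) of a type that matches both the annotation case and definedByConstructorParams. Before this change, schemaFor(typeOf[Point]) would hit the constructor-parameters case first and produce a StructType with fields x and y; with this change the annotation case matches first and the schema is the UDT's sqlType:

    import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
    import org.apache.spark.sql.types._

    // A case class is "defined by constructor params", but the annotation
    // should win now that the UDT cases are matched first.
    @SQLUserDefinedType(udt = classOf[PointUDT])
    case class Point(x: Double, y: Double)

    class PointUDT extends UserDefinedType[Point] {
      // Catalyst-side representation: a fixed pair of doubles.
      override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

      override def serialize(p: Point): ArrayData =
        new GenericArrayData(Array(p.x, p.y))

      override def deserialize(datum: Any): Point = datum match {
        case a: ArrayData => Point(a.getDouble(0), a.getDouble(1))
      }

      override def userClass: Class[Point] = classOf[Point]
    }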
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index d4387890b4..3d97113b52 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
case class RepeatedStruct(s: Seq[PrimitiveData])
@@ -86,6 +87,25 @@ class JavaSerializable(val value: Int) extends Serializable {
}
}
+/** For testing UDT for a case class */
+@SQLUserDefinedType(udt = classOf[UDTForCaseClass])
+case class UDTCaseClass(uri: java.net.URI)
+
+class UDTForCaseClass extends UserDefinedType[UDTCaseClass] {
+
+ override def sqlType: DataType = StringType
+
+ override def serialize(obj: UDTCaseClass): UTF8String = {
+ UTF8String.fromString(obj.uri.toString)
+ }
+
+ override def userClass: Class[UDTCaseClass] = classOf[UDTCaseClass]
+
+ override def deserialize(datum: Any): UDTCaseClass = datum match {
+ case uri: UTF8String => UDTCaseClass(new java.net.URI(uri.toString))
+ }
+}
+
class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
OuterScopes.addOuterScope(this)
@@ -147,6 +167,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple")
encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple")
+ encodeDecodeTest(List(1, 2), "list of int")
+ encodeDecodeTest(List("a", null), "list with String and null")
+
+ encodeDecodeTest(
+ UDTCaseClass(new java.net.URI("http://spark.apache.org/")), "udt with case class")
+
// Kryo encoders
encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String]))
encodeDecodeTest(new KryoSerializable(15), "kryo object")(
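For reference, the round trip that the new "udt with case class" test exercises looks roughly like this (a sketch, assuming the Spark 2.x ExpressionEncoder API with toRow, fromRow, and resolveAndBind):

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    // Serialize the annotated case class to an InternalRow and back.
    val encoder = ExpressionEncoder[UDTCaseClass]()
    val input = UDTCaseClass(new java.net.URI("http://spark.apache.org/"))
    val row = encoder.toRow(input)                      // goes through UDTForCaseClass.serialize
    val output = encoder.resolveAndBind().fromRow(row)  // goes through UDTForCaseClass.deserialize
    assert(output == input)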