aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author    Liang-Chi Hsieh <viirya@gmail.com>    2016-01-05 10:19:56 -0800
committer Michael Armbrust <michael@databricks.com>    2016-01-05 10:19:56 -0800
commitb3c48e39f4a0a42a0b6b433511b2cce0d1e3f03d (patch)
tree488e56d4c16f379ceec604904b4ee824ce3a474a
parent1cdc42d2b99edfec01066699a7620cca02b61f0e (diff)
downloadspark-b3c48e39f4a0a42a0b6b433511b2cce0d1e3f03d.tar.gz
spark-b3c48e39f4a0a42a0b6b433511b2cce0d1e3f03d.tar.bz2
spark-b3c48e39f4a0a42a0b6b433511b2cce0d1e3f03d.zip
[SPARK-12438][SQL] Add SQLUserDefinedType support for encoder
JIRA: https://issues.apache.org/jira/browse/SPARK-12438 ScalaReflection lacks support for SQLUserDefinedType. We should add it. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #10390 from viirya/encoder-udt.
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 22
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 14
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala | 2
3 files changed, 38 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 9784c96966..c6aa60b0b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -177,6 +177,7 @@ object ScalaReflection extends ScalaReflection {
case _ => UpCast(expr, expected, walkedTypePath)
}
+ val className = getClassNameFromType(tpe)
tpe match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
@@ -360,6 +361,16 @@ object ScalaReflection extends ScalaReflection {
} else {
newInstance
}
+
+ case t if Utils.classIsLoadable(className) &&
+ Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+ val udt = Utils.classForName(className)
+ .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+ val obj = NewInstance(
+ udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+ Nil,
+ dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+ Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
}
}
@@ -409,6 +420,7 @@ object ScalaReflection extends ScalaReflection {
if (!inputObject.dataType.isInstanceOf[ObjectType]) {
inputObject
} else {
+ val className = getClassNameFromType(tpe)
tpe match {
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
@@ -559,6 +571,16 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.lang.Boolean] =>
Invoke(inputObject, "booleanValue", BooleanType)
+ case t if Utils.classIsLoadable(className) &&
+ Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+ val udt = Utils.classForName(className)
+ .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+ val obj = NewInstance(
+ udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+ Nil,
+ dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+ Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
+
case other =>
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index b18f49f320..d82d3edae4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.math.{BigDecimal => JavaBigDecimal}
+import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -81,6 +82,9 @@ object Cast {
toField.nullable)
}
+ case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt1.userClass == udt2.userClass =>
+ true
+
case _ => false
}
@@ -431,6 +435,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
case map: MapType => castMap(from.asInstanceOf[MapType], map)
case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
+ case udt: UserDefinedType[_]
+ if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+ identity[Any]
+ case _: UserDefinedType[_] =>
+ throw new SparkException(s"Cannot cast $from to $to.")
}
private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
@@ -473,6 +482,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
+ case udt: UserDefinedType[_]
+ if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+ (c, evPrim, evNull) => s"$evPrim = $c;"
+ case _: UserDefinedType[_] =>
+ throw new SparkException(s"Cannot cast $from to $to.")
}
// Since we need to cast child expressions recursively inside ComplexTypes, such as Map's
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 3740dea8aa..6453f1c191 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -244,6 +244,8 @@ class ExpressionEncoderSuite extends SparkFunSuite {
ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
}
+ productTest(("UDT", new ExamplePoint(0.1, 0.2)))
+
test("nullable of encoder schema") {
def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = {
assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq)