about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJakob Odersky <jakob@odersky.com>2016-03-16 21:53:16 -0700
committerReynold Xin <rxin@databricks.com>2016-03-16 21:53:16 -0700
commit7eef2463ade693f8e87ecd913a5adf03b69ac14e (patch)
treeee4aa4e324e41d49f2f580783cd751fa83a17b18
parentc100d31ddc6db9c03b7a65a20a7dd56dcdc18baf (diff)
downloadspark-7eef2463ade693f8e87ecd913a5adf03b69ac14e.tar.gz
spark-7eef2463ade693f8e87ecd913a5adf03b69ac14e.tar.bz2
spark-7eef2463ade693f8e87ecd913a5adf03b69ac14e.zip
[SPARK-13118][SQL] Expression encoding for optional synthetic classes
## What changes were proposed in this pull request?

Fix expression generation for optional types. Standard Java reflection causes issues when dealing with synthetic Scala objects (things that do not map to Java and thus contain a dollar sign in their name). This patch introduces Scala reflection in such cases. This patch also adds a regression test for Dataset's handling of classes defined in package objects (which was the initial purpose of this PR).

## How was this patch tested?

A new test in ExpressionEncoderSuite that tests optional inner classes and a regression test for Dataset's handling of package objects.

Author: Jakob Odersky <jakob@odersky.com>

Closes #11708 from jodersky/SPARK-13118-package-objects.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala28
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala10
3 files changed, 37 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index bf07f4557a..5e1672c779 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -476,11 +476,17 @@ object ScalaReflection extends ScalaReflection {
// For non-primitives, we can just extract the object from the Option and then recurse.
case other =>
val className = getClassNameFromType(optType)
- val classObj = Utils.classForName(className)
- val optionObjectType = ObjectType(classObj)
val newPath = s"""- option value class: "$className"""" +: walkedTypePath
+ val optionObjectType: DataType = other match {
+ // Special handling is required for arrays, as getClassFromType(<Array>) will fail
+ // since Scala Arrays map to native Java constructs. E.g. "Array[Int]" will map to
+ // the Java type "[I".
+ case arr if arr <:< localTypeOf[Array[_]] => arrayClassFor(t)
+ case cls => ObjectType(getClassFromType(cls))
+ }
val unwrapped = UnwrapOption(optionObjectType, inputObject)
+
expressions.If(
IsNull(unwrapped),
expressions.Literal.create(null, silentSchemaFor(optType).dataType),
@@ -626,6 +632,9 @@ object ScalaReflection extends ScalaReflection {
constructParams(t).map(_.name.toString)
}
+ /*
+ * Retrieves the runtime class corresponding to the provided type.
+ */
def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
}
@@ -676,9 +685,12 @@ trait ScalaReflection {
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
val className = getClassNameFromType(tpe)
+
tpe match {
+
case t if Utils.classIsLoadable(className) &&
Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+
// Note: We check for classIsLoadable above since Utils.classForName uses Java reflection,
// whereas className is from Scala reflection. This can make it hard to find classes
// in some cases, such as when a class is enclosed in an object (in which case
@@ -748,7 +760,16 @@ trait ScalaReflection {
case _: UnsupportedOperationException => Schema(NullType, nullable = true)
}
- /** Returns the full class name for a type. */
+ /**
+ * Returns the full class name for a type. The returned name is the canonical
+ * Scala name, where each component is separated by a period. It is NOT the
+ * Java-equivalent runtime name (no dollar signs).
+ *
+ * In simple cases, both the Scala and Java names are the same, however when Scala
+ * generates constructs that do not map to a Java equivalent, such as singleton objects
+ * or nested classes in package objects, it uses the dollar sign ($) to create
+ * synthetic classes, emulating behaviour in Java bytecode.
+ */
def getClassNameFromType(tpe: `Type`): String = {
tpe.erasure.typeSymbol.asClass.fullName
}
@@ -792,4 +813,5 @@ trait ScalaReflection {
}
params.flatten
}
+
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index cca320fae9..3024858b06 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -152,6 +152,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
productTest(InnerClass(1))
encodeDecodeTest(Array(InnerClass(1)), "array of inner class")
+ encodeDecodeTest(Array(Option(InnerClass(1))), "array of optional inner class")
+
productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
productTest(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 6e9840e4a7..ff022b2dc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -23,6 +23,10 @@ import org.apache.spark.sql.test.SharedSQLContext
case class IntClass(value: Int)
+package object packageobject {
+ case class PackageClass(value: Int)
+}
+
class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -127,4 +131,10 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
checkDataset(Seq(Array("test")).toDS(), Array("test"))
checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
}
+
+ test("package objects") {
+ import packageobject._
+ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
+ }
+
}