 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala                 | 28 +++++++++++++++++++++++++---
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala |  2 ++
 sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala                        | 10 ++++++++++
 3 files changed, 37 insertions(+), 3 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index bf07f4557a..5e1672c779 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -476,11 +476,17 @@ object ScalaReflection extends ScalaReflection {
             // For non-primitives, we can just extract the object from the Option and then recurse.
             case other =>
               val className = getClassNameFromType(optType)
-              val classObj = Utils.classForName(className)
-              val optionObjectType = ObjectType(classObj)
               val newPath = s"""- option value class: "$className"""" +: walkedTypePath
+              val optionObjectType: DataType = other match {
+                // Special handling is required for arrays, as getClassFromType(<Array>) will fail
+                // since Scala Arrays map to native Java constructs. E.g. "Array[Int]" will map to
+                // the Java type "[I".
+                case arr if arr <:< localTypeOf[Array[_]] => arrayClassFor(t)
+                case cls => ObjectType(getClassFromType(cls))
+              }
 
               val unwrapped = UnwrapOption(optionObjectType, inputObject)
+
               expressions.If(
                 IsNull(unwrapped),
                 expressions.Literal.create(null, silentSchemaFor(optType).dataType),
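
[editor's note] The comment in the new match is easy to verify outside Spark. The standalone sketch below (not part of the patch; the demo object name is made up) shows why a class-name lookup cannot work for Array types: Array[Int] erases to the JVM primitive array class "[I", and no class named "scala.Array" is loadable.

    import scala.reflect.runtime.universe._

    object ArrayErasureDemo extends App {
      // Scala reflection names the type constructor "scala.Array" ...
      println(typeOf[Array[Int]].erasure.typeSymbol.asClass.fullName)  // scala.Array

      // ... but the runtime class is the JVM primitive int array, "[I".
      println(classOf[Array[Int]].getName)                             // [I

      // So a name-based lookup (what the removed Utils.classForName call
      // amounted to) cannot succeed for array types:
      try Class.forName("scala.Array")
      catch { case e: ClassNotFoundException => println(s"not loadable: $e") }

      // The array class must be built from its element class instead,
      // which is essentially what arrayClassFor does:
      println(java.lang.reflect.Array.newInstance(classOf[Int], 0).getClass.getName)  // [I
    }
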
@@ -626,6 +632,9 @@ object ScalaReflection extends ScalaReflection {
     constructParams(t).map(_.name.toString)
   }
 
+  /*
+   * Retrieves the runtime class corresponding to the provided type.
+   */
   def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
 
 }
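
[editor's note] The new getClassFromType goes through the runtime mirror rather than through a class-name string, which is what makes it safe for nested classes. A minimal self-contained sketch of the same technique (the mirror setup and demo names are assumptions, not code from the patch):

    import scala.reflect.runtime.universe._

    object Outer { case class Inner(value: Int) }

    object RuntimeClassDemo extends App {
      private val mirror = runtimeMirror(getClass.getClassLoader)

      // Resolve the Class[_] directly from the reflected symbol, never
      // round-tripping through the (possibly non-loadable) Scala name.
      def getClassFromType(tpe: Type): Class[_] =
        mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)

      // Scala calls it "Outer.Inner"; the JVM knows it as "Outer$Inner".
      println(getClassFromType(typeOf[Outer.Inner]).getName)
    }
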
@@ -676,9 +685,12 @@ trait ScalaReflection {
   /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
   def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
     val className = getClassNameFromType(tpe)
+
     tpe match {
+
       case t if Utils.classIsLoadable(className) &&
         Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+
         // Note: We check for classIsLoadable above since Utils.classForName uses Java reflection,
         // whereas className is from Scala reflection. This can make it hard to find classes
         // in some cases, such as when a class is enclosed in an object (in which case
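
[editor's note] The note in this hunk can be demonstrated concisely. In the sketch below, classIsLoadable is a simplified stand-in for Spark's Utils.classIsLoadable, and the enclosing object is hypothetical:

    import scala.reflect.runtime.universe._
    import scala.util.Try

    object Enclosing { class Inside }

    object LoadableDemo extends App {
      // Simplified stand-in for Utils.classIsLoadable: is the name visible
      // to Java reflection?
      def classIsLoadable(name: String): Boolean = Try(Class.forName(name)).isSuccess

      val scalaName = typeOf[Enclosing.Inside].typeSymbol.asClass.fullName
      println(scalaName)                   // Enclosing.Inside (the Scala-reflection name)
      println(classIsLoadable(scalaName))  // false: the JVM knows it as Enclosing$Inside
    }
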
@@ -748,7 +760,16 @@ trait ScalaReflection {
       case _: UnsupportedOperationException => Schema(NullType, nullable = true)
     }
 
-  /** Returns the full class name for a type. */
+  /**
+   * Returns the full class name for a type. The returned name is the canonical
+   * Scala name, where each component is separated by a period. It is NOT the
+   * Java-equivalent runtime name (no dollar signs).
+   *
+   * In simple cases, the Scala and Java names are the same; however, when Scala
+   * generates constructs that have no direct Java equivalent, such as singleton
+   * objects or classes nested in package objects, it uses the dollar sign ($) to
+   * create synthetic classes, emulating the naming scheme of Java bytecode.
+   */
   def getClassNameFromType(tpe: `Type`): String = {
     tpe.erasure.typeSymbol.asClass.fullName
   }
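
[editor's note] The package-object case called out in the new scaladoc is exactly what the DatasetPrimitiveSuite test below exercises. A standalone illustration (package and object names are hypothetical):

    package demo {
      package object models {
        case class PackageClass(value: Int)
      }

      object NameDemo extends App {
        import scala.reflect.runtime.universe._

        // Canonical Scala name: components separated by periods, no dollar signs.
        println(typeOf[models.PackageClass].typeSymbol.asClass.fullName)
        // demo.models.PackageClass

        // Java runtime name: the synthetic package object surfaces as "package$".
        println(classOf[models.PackageClass].getName)
        // demo.models.package$PackageClass
      }
    }
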
@@ -792,4 +813,5 @@ trait ScalaReflection {
     }
     params.flatten
   }
+
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index cca320fae9..3024858b06 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -152,6 +152,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   productTest(InnerClass(1))
   encodeDecodeTest(Array(InnerClass(1)), "array of inner class")
 
+  encodeDecodeTest(Array(Option(InnerClass(1))), "array of optional inner class")
+
   productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
 
   productTest(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 6e9840e4a7..ff022b2dc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -23,6 +23,10 @@ import org.apache.spark.sql.test.SharedSQLContext
 
 case class IntClass(value: Int)
 
+package object packageobject {
+  case class PackageClass(value: Int)
+}
+
 class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
@@ -127,4 +131,10 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(Seq(Array("test")).toDS(), Array("test"))
     checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
   }
+
+  test("package objects") {
+    import packageobject._
+    checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
+  }
+
 }
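
[editor's note] Taken together, the patch lets classes whose Scala names differ from their JVM names round-trip through the Dataset API. The suite's checkDataset helper hides the session setup; a rough standalone equivalent is sketched below, assuming a Spark 2.x-or-later build (SparkSession postdates the SQLContext used by the suite) and a made-up demo object:

    import org.apache.spark.sql.SparkSession

    object DatasetRoundTrip extends App {
      case class Inner(value: Int)  // nested in an object, like the suite's InnerClass

      val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
      import spark.implicits._

      // Deriving this encoder needs getClassFromType to resolve the runtime
      // class "DatasetRoundTrip$Inner" from the Scala name "DatasetRoundTrip.Inner".
      val ds = Seq(Inner(1)).toDS()
      ds.collect().foreach(println)  // Inner(1)
      spark.stop()
    }
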