diff options
3 files changed, 37 insertions(+), 3 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index bf07f4557a..5e1672c779 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -476,11 +476,17 @@ object ScalaReflection extends ScalaReflection { // For non-primitives, we can just extract the object from the Option and then recurse. case other => val className = getClassNameFromType(optType) - val classObj = Utils.classForName(className) - val optionObjectType = ObjectType(classObj) val newPath = s"""- option value class: "$className"""" +: walkedTypePath + val optionObjectType: DataType = other match { + // Special handling is required for arrays, as getClassFromType(<Array>) will fail + // since Scala Arrays map to native Java constructs. E.g. "Array[Int]" will map to + // the Java type "[I". + case arr if arr <:< localTypeOf[Array[_]] => arrayClassFor(t) + case cls => ObjectType(getClassFromType(cls)) + } val unwrapped = UnwrapOption(optionObjectType, inputObject) + expressions.If( IsNull(unwrapped), expressions.Literal.create(null, silentSchemaFor(optType).dataType), @@ -626,6 +632,9 @@ object ScalaReflection extends ScalaReflection { constructParams(t).map(_.name.toString) } + /* + * Retrieves the runtime class corresponding to the provided type. + */ def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) } @@ -676,9 +685,12 @@ trait ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. 
*/ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { val className = getClassNameFromType(tpe) + tpe match { + case t if Utils.classIsLoadable(className) && Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + // Note: We check for classIsLoadable above since Utils.classForName uses Java reflection, // whereas className is from Scala reflection. This can make it hard to find classes // in some cases, such as when a class is enclosed in an object (in which case @@ -748,7 +760,16 @@ trait ScalaReflection { case _: UnsupportedOperationException => Schema(NullType, nullable = true) } - /** Returns the full class name for a type. */ + /** + * Returns the full class name for a type. The returned name is the canonical + * Scala name, where each component is separated by a period. It is NOT the + * Java-equivalent runtime name (no dollar signs). + * + * In simple cases, both the Scala and Java names are the same, however when Scala + * generates constructs that do not map to a Java equivalent, such as singleton objects + * or nested classes in package objects, it uses the dollar sign ($) to create + * synthetic classes, emulating behaviour in Java bytecode. 
+ */ def getClassNameFromType(tpe: `Type`): String = { tpe.erasure.typeSymbol.asClass.fullName } @@ -792,4 +813,5 @@ trait ScalaReflection { } params.flatten } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index cca320fae9..3024858b06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -152,6 +152,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { productTest(InnerClass(1)) encodeDecodeTest(Array(InnerClass(1)), "array of inner class") + encodeDecodeTest(Array(Option(InnerClass(1))), "array of optional inner class") + productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) productTest( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 6e9840e4a7..ff022b2dc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -23,6 +23,10 @@ import org.apache.spark.sql.test.SharedSQLContext case class IntClass(value: Int) +package object packageobject { + case class PackageClass(value: Int) +} + class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -127,4 +131,10 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(Seq(Array("test")).toDS(), Array("test")) checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1))) } + + test("package objects") { + import packageobject._ + checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) + } + } |