diff options
author | Joan <joan@goyeau.com> | 2016-04-19 17:36:31 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2016-04-19 17:36:31 -0700 |
commit | 3ae25f244bd471ef77002c703f2cc7ed6b524f11 (patch) | |
tree | 6eb824fa2ac7b58c855f606a7a96936c91abbd32 /sql/catalyst | |
parent | 10f273d8db999cdc2e6c73bdbe98757de5d11676 (diff) | |
download | spark-3ae25f244bd471ef77002c703f2cc7ed6b524f11.tar.gz spark-3ae25f244bd471ef77002c703f2cc7ed6b524f11.tar.bz2 spark-3ae25f244bd471ef77002c703f2cc7ed6b524f11.zip |
[SPARK-13929] Use Scala reflection for UDTs
## What changes were proposed in this pull request?
Enable ScalaReflection and User Defined Types for plain Scala classes.
This involves the move of `schemaFor` from `ScalaReflection` trait (which is Runtime and Compile time (macros) reflection) to the `ScalaReflection` object (runtime reflection only) as I believe this code wouldn't work at compile time anyway as it manipulates `Class`'s that are not compiled yet.
## How was this patch tested?
Unit test
Author: Joan <joan@goyeau.com>
Closes #12149 from joan38/SPARK-13929-Scala-reflection.
Diffstat (limited to 'sql/catalyst')
3 files changed, 78 insertions, 62 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java index 1e4e5ede8c..110ed460cc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -24,11 +24,6 @@ import org.apache.spark.annotation.DeveloperApi; /** * ::DeveloperApi:: * A user-defined type which can be automatically recognized by a SQLContext and registered. - * <p> - * WARNING: This annotation will only work if both Java and Scala reflection return the same class - * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class - * is enclosed in an object (a singleton). - * <p> * WARNING: UDTs are currently only supported from Scala. */ // TODO: Should I used @Documented ? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4795fc2557..bd723135b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -374,10 +374,8 @@ object ScalaReflection extends ScalaReflection { newInstance } - case t if Utils.classIsLoadable(className) && - Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => - val udt = Utils.classForName(className) - .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => + val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, @@ -432,7 +430,6 @@ object ScalaReflection extends ScalaReflection { if (!inputObject.dataType.isInstanceOf[ObjectType]) { inputObject } else { - val className = getClassNameFromType(tpe) tpe match { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -589,9 +586,8 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) - case t if Utils.classIsLoadable(className) && - Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => - val udt = Utils.classForName(className) + case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => + val udt = getClassFromType(t) .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), @@ -637,24 +633,6 @@ object ScalaReflection extends ScalaReflection { * Retrieves the runtime class corresponding to the provided type. */ def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) -} - -/** - * Support for generating catalyst schemas for scala objects. Note that unlike its companion - * object, this trait able to work in both the runtime and the compile time (macro) universe. - */ -trait ScalaReflection { - /** The universe we work in (runtime or macro) */ - val universe: scala.reflect.api.Universe - - /** The mirror used to access types in the universe */ - def mirror: universe.Mirror - - import universe._ - - // The Predef.Map is scala.collection.immutable.Map. - // Since the map values can be mutable, we explicitly import scala.collection.Map at here. - import scala.collection.Map case class Schema(dataType: DataType, nullable: Boolean) @@ -668,36 +646,22 @@ trait ScalaReflection { def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) /** - * Return the Scala Type for `T` in the current classloader mirror. - * - * Use this method instead of the convenience method `universe.typeOf`, which - * assumes that all types can be found in the classloader that loaded scala-reflect classes. - * That's not necessarily the case when running using Eclipse launchers or even - * Sbt console or test (without `fork := true`). + * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. * - * @see SPARK-5281 + * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return + * `NullType` silently instead. */ - // SPARK-13640: Synchronize this because TypeTag.tpe is not thread-safe in Scala 2.10. - def localTypeOf[T: TypeTag]: `Type` = ScalaReflectionLock.synchronized { - val tag = implicitly[TypeTag[T]] - tag.in(mirror).tpe.normalize + def silentSchemaFor(tpe: `Type`): Schema = try { + schemaFor(tpe) + } catch { + case _: UnsupportedOperationException => Schema(NullType, nullable = true) } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { - val className = getClassNameFromType(tpe) - tpe match { - - case t if Utils.classIsLoadable(className) && - Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => - - // Note: We check for classIsLoadable above since Utils.classForName uses Java reflection, - // whereas className is from Scala reflection. This can make it hard to find classes - // in some cases, such as when a class is enclosed in an object (in which case - // Java appends a '$' to the object name but Scala does not). - val udt = Utils.classForName(className) - .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => + val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() Schema(udt, nullable = true) case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -748,17 +712,39 @@ trait ScalaReflection { throw new UnsupportedOperationException(s"Schema for type $other is not supported") } } +} + +/** + * Support for generating catalyst schemas for scala objects. Note that unlike its companion + * object, this trait able to work in both the runtime and the compile time (macro) universe. + */ +trait ScalaReflection { + /** The universe we work in (runtime or macro) */ + val universe: scala.reflect.api.Universe + + /** The mirror used to access types in the universe */ + def mirror: universe.Mirror + + import universe._ + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map /** - * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. + * Return the Scala Type for `T` in the current classloader mirror. * - * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return - * `NullType` silently instead. + * Use this method instead of the convenience method `universe.typeOf`, which + * assumes that all types can be found in the classloader that loaded scala-reflect classes. + * That's not necessarily the case when running using Eclipse launchers or even + * Sbt console or test (without `fork := true`). + * + * @see SPARK-5281 */ - def silentSchemaFor(tpe: `Type`): Schema = try { - schemaFor(tpe) - } catch { - case _: UnsupportedOperationException => Schema(NullType, nullable = true) + // SPARK-13640: Synchronize this because TypeTag.tpe is not thread-safe in Scala 2.10. + def localTypeOf[T: TypeTag]: `Type` = ScalaReflectionLock.synchronized { + val tag = implicitly[TypeTag[T]] + tag.in(mirror).tpe.normalize } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 5ca5a72512..0672551b29 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -23,7 +23,7 @@ import java.sql.{Date, Timestamp} import scala.reflect.runtime.universe.typeOf import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificMutableRow} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -81,9 +81,44 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) { def this(b: String, a: Int) = this(a, b, c = 1.0) } +object TestingUDT { + @SQLUserDefinedType(udt = classOf[NestedStructUDT]) + class NestedStruct(val a: Integer, val b: Long, val c: Double) + + class NestedStructUDT extends UserDefinedType[NestedStruct] { + override def sqlType: DataType = new StructType() + .add("a", IntegerType, nullable = true) + .add("b", LongType, nullable = false) + .add("c", DoubleType, nullable = false) + + override def serialize(n: NestedStruct): Any = { + val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType)) + row.setInt(0, n.a) + row.setLong(1, n.b) + row.setDouble(2, n.c) + } + + override def userClass: Class[NestedStruct] = classOf[NestedStruct] + + override def deserialize(datum: Any): NestedStruct = datum match { + case row: InternalRow => + new NestedStruct(row.getInt(0), row.getLong(1), row.getDouble(2)) + } + } +} + + class ScalaReflectionSuite extends SparkFunSuite { import org.apache.spark.sql.catalyst.ScalaReflection._ + test("SQLUserDefinedType annotation on Scala structure") { + val schema = schemaFor[TestingUDT.NestedStruct] + assert(schema === Schema( + new TestingUDT.NestedStructUDT, + nullable = true + )) + } + test("primitive data") { val schema = schemaFor[PrimitiveData] assert(schema === Schema( |