diff options
author | Jakob Odersky <jakob@odersky.com> | 2016-09-21 19:09:33 -0700 |
---|---|---|
committer | Jakob Odersky <jakob@odersky.com> | 2016-09-21 19:09:33 -0700 |
commit | 8fe58dbd9dbd70f2160acd79754e2d3729243c9e (patch) | |
tree | 43df71e7f1a6e1fb61bc82542577e8ab998cf984 | |
parent | 8f5a3081ee2751b8fae00ddc30eaab4d21a0aca4 (diff) | |
download | spark-8fe58dbd9dbd70f2160acd79754e2d3729243c9e.tar.gz spark-8fe58dbd9dbd70f2160acd79754e2d3729243c9e.tar.bz2 spark-8fe58dbd9dbd70f2160acd79754e2d3729243c9e.zip |
It compiles!
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaMacros.scala | 62 | ||||
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 18 |
2 files changed, 72 insertions, 8 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaMacros.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaMacros.scala new file mode 100644 index 0000000000..343ce1ace1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaMacros.scala @@ -0,0 +1,62 @@ +package org.apache.spark.sql.catalyst + +import scala.reflect.macros.blackbox.Context +import scala.reflect.api.Universe + + +class ScalaMacros(val context: Context) extends ScalaReflection { + + val universe: context.universe.type = context.universe + + def mirror: universe.Mirror = context.mirror + +} + + +import scala.language.experimental.macros + +import org.apache.spark.sql.Encoder +//import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance} +import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts +import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} +import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType} +import org.apache.spark.util.Utils + +object Macros { + + implicit def encoder[T]: Encoder[T] = macro encoderImpl[T] + + def encoderImpl[T: c.WeakTypeTag](c: Context): c.Expr[Encoder[T]] = { + val helper = new ScalaMacros(c) + import helper.universe._ + + val tag = implicitly[WeakTypeTag[T]] + + val tpe = tag.tpe + val flat = !helper.definedByConstructorParams(tpe) + + val inputObject = BoundReference(0, helper.dataTypeFor[T](tag), nullable = true) + val nullSafeInput = if (flat) { + inputObject + } else { + // For input object of non-flat type, we can't encode it to row if it's null, as Spark SQL + // doesn't allow top-level row to be null, only its columns can be null. + AssertNotNull(inputObject, Seq("top level non-flat input object")) + } + val serializer = helper.serializerFor[T](nullSafeInput)(tag) + val deserializer = helper.deserializerFor[T](tag) + + val schema = helper.schemaFor[T](tag) match { + case helper.Schema(s: StructType, _) => s + case helper.Schema(dt, nullable) => new StructType().add("value", dt, nullable) + } + + ??? + } + + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 64569b3e60..53848641eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -37,6 +37,7 @@ trait DefinedByConstructorParams * A default version of ScalaReflection that uses the runtime universe. */ object ScalaReflection extends ScalaReflection { + val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe // Since we are creating a runtime mirror using the class loader of current thread, // we need to use def at here. So, every time we call mirror, it is using the @@ -60,6 +61,7 @@ object ScalaReflection extends ScalaReflection { * object, this trait able to work in both the runtime and the compile time (macro) universe. */ trait ScalaReflection { + /** The universe we work in (runtime or macro) */ val universe: scala.reflect.api.Universe @@ -109,9 +111,9 @@ trait ScalaReflection { * * @see SPARK-5281 */ - // SPARK-13640: Synchronize this because TypeTag.tpe is not thread-safe in Scala 2.10. - def localTypeOf[T: TypeTag]: `Type` = ScalaReflectionLock.synchronized { - val tag = implicitly[TypeTag[T]] + // SPARK-13640: Synchronize this because WeakTypeTag.tpe is not thread-safe in Scala 2.10. + def localTypeOf[T: WeakTypeTag]: `Type` = ScalaReflectionLock.synchronized { + val tag = implicitly[WeakTypeTag[T]] tag.in(mirror).tpe.normalize } @@ -181,7 +183,7 @@ trait ScalaReflection { * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type * system. As a result, ObjectType will be returned for things like boxed Integers */ - def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) + def dataTypeFor[T : WeakTypeTag]: DataType = dataTypeFor(localTypeOf[T]) private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { tpe match { @@ -250,7 +252,7 @@ trait ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. */ - def deserializerFor[T : TypeTag]: Expression = { + def deserializerFor[T : WeakTypeTag]: Expression = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil @@ -539,7 +541,7 @@ trait ScalaReflection { * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` */ - def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { + def serializerFor[T : WeakTypeTag](inputObject: Expression): CreateNamedStruct = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil @@ -723,13 +725,13 @@ trait ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) /** Returns a Sequence of attributes for the given case class type. */ - def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { + def attributesFor[T: WeakTypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => s.toAttributes } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) + def schemaFor[T: WeakTypeTag]: Schema = schemaFor(localTypeOf[T]) /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { |