aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJakob Odersky <jakob@odersky.com>2016-09-21 19:09:33 -0700
committerJakob Odersky <jakob@odersky.com>2016-09-21 19:09:33 -0700
commit8fe58dbd9dbd70f2160acd79754e2d3729243c9e (patch)
tree43df71e7f1a6e1fb61bc82542577e8ab998cf984
parent8f5a3081ee2751b8fae00ddc30eaab4d21a0aca4 (diff)
downloadspark-8fe58dbd9dbd70f2160acd79754e2d3729243c9e.tar.gz
spark-8fe58dbd9dbd70f2160acd79754e2d3729243c9e.tar.bz2
spark-8fe58dbd9dbd70f2160acd79754e2d3729243c9e.zip
It compiles!
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaMacros.scala62
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala18
2 files changed, 72 insertions, 8 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaMacros.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaMacros.scala
new file mode 100644
index 0000000000..343ce1ace1
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaMacros.scala
@@ -0,0 +1,62 @@
+package org.apache.spark.sql.catalyst
+
+import scala.reflect.macros.blackbox.Context
+import scala.reflect.api.Universe
+
+
+class ScalaMacros(val context: Context) extends ScalaReflection {
+
+ val universe: context.universe.type = context.universe
+
+ def mirror: universe.Mirror = context.mirror
+
+}
+
+
+import scala.language.experimental.macros
+
+import org.apache.spark.sql.Encoder
+//import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
+import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
+import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
+import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+object Macros {
+
+ implicit def encoder[T]: Encoder[T] = macro encoderImpl[T]
+
+ def encoderImpl[T: c.WeakTypeTag](c: Context): c.Expr[Encoder[T]] = {
+ val helper = new ScalaMacros(c)
+ import helper.universe._
+
+ val tag = implicitly[WeakTypeTag[T]]
+
+ val tpe = tag.tpe
+ val flat = !helper.definedByConstructorParams(tpe)
+
+ val inputObject = BoundReference(0, helper.dataTypeFor[T](tag), nullable = true)
+ val nullSafeInput = if (flat) {
+ inputObject
+ } else {
+ // For input object of non-flat type, we can't encode it to row if it's null, as Spark SQL
+ // doesn't allow top-level row to be null, only its columns can be null.
+ AssertNotNull(inputObject, Seq("top level non-flat input object"))
+ }
+ val serializer = helper.serializerFor[T](nullSafeInput)(tag)
+ val deserializer = helper.deserializerFor[T](tag)
+
+ val schema = helper.schemaFor[T](tag) match {
+ case helper.Schema(s: StructType, _) => s
+ case helper.Schema(dt, nullable) => new StructType().add("value", dt, nullable)
+ }
+
+ ???
+ }
+
+
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 64569b3e60..53848641eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -37,6 +37,7 @@ trait DefinedByConstructorParams
* A default version of ScalaReflection that uses the runtime universe.
*/
object ScalaReflection extends ScalaReflection {
+
val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
// Since we are creating a runtime mirror using the class loader of current thread,
// we need to use def at here. So, every time we call mirror, it is using the
@@ -60,6 +61,7 @@ object ScalaReflection extends ScalaReflection {
* object, this trait able to work in both the runtime and the compile time (macro) universe.
*/
trait ScalaReflection {
+
/** The universe we work in (runtime or macro) */
val universe: scala.reflect.api.Universe
@@ -109,9 +111,9 @@ trait ScalaReflection {
*
* @see SPARK-5281
*/
- // SPARK-13640: Synchronize this because TypeTag.tpe is not thread-safe in Scala 2.10.
- def localTypeOf[T: TypeTag]: `Type` = ScalaReflectionLock.synchronized {
- val tag = implicitly[TypeTag[T]]
+ // SPARK-13640: Synchronize this because WeakTypeTag.tpe is not thread-safe in Scala 2.10.
+ def localTypeOf[T: WeakTypeTag]: `Type` = ScalaReflectionLock.synchronized {
+ val tag = implicitly[WeakTypeTag[T]]
tag.in(mirror).tpe.normalize
}
@@ -181,7 +183,7 @@ trait ScalaReflection {
* Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type
* system. As a result, ObjectType will be returned for things like boxed Integers
*/
- def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T])
+ def dataTypeFor[T : WeakTypeTag]: DataType = dataTypeFor(localTypeOf[T])
private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized {
tpe match {
@@ -250,7 +252,7 @@ trait ScalaReflection {
* from ordinal 0 (since there are no names to map to). The actual location can be moved by
* calling resolve/bind with a new schema.
*/
- def deserializerFor[T : TypeTag]: Expression = {
+ def deserializerFor[T : WeakTypeTag]: Expression = {
val tpe = localTypeOf[T]
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
@@ -539,7 +541,7 @@ trait ScalaReflection {
* * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
* * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
*/
- def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
+ def serializerFor[T : WeakTypeTag](inputObject: Expression): CreateNamedStruct = {
val tpe = localTypeOf[T]
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
@@ -723,13 +725,13 @@ trait ScalaReflection {
case class Schema(dataType: DataType, nullable: Boolean)
/** Returns a Sequence of attributes for the given case class type. */
- def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
+ def attributesFor[T: WeakTypeTag]: Seq[Attribute] = schemaFor[T] match {
case Schema(s: StructType, _) =>
s.toAttributes
}
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
- def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T])
+ def schemaFor[T: WeakTypeTag]: Schema = schemaFor(localTypeOf[T])
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {