about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
author: Wenchen Fan <wenchen@databricks.com> 2016-05-23 11:13:27 -0700
committer: Michael Armbrust <michael@databricks.com> 2016-05-23 11:13:27 -0700
commit: 07c36a2f07fcf5da6fb395f830ebbfc10eb27dcc (patch)
tree: de8e594c84274d6918f6078e3fea760c903d2e13 /sql
parent: 80091b8a6840b562cf76341926e5b828d4def7e2 (diff)
download: spark-07c36a2f07fcf5da6fb395f830ebbfc10eb27dcc.tar.gz
spark-07c36a2f07fcf5da6fb395f830ebbfc10eb27dcc.tar.bz2
spark-07c36a2f07fcf5da6fb395f830ebbfc10eb27dcc.zip
[SPARK-15471][SQL] ScalaReflection cleanup
## What changes were proposed in this pull request?

1. Simplify the logic of serializing option types.
2. Simplify the logic of serializing array types, and remove `silentSchemaFor`.
3. Remove some unnecessary code.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13250 from cloud-fan/encoder.
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 105
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala | 4
2 files changed, 21 insertions, 88 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 36989a20cb..bdd40f3402 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
@@ -72,6 +72,7 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< definitions.ByteTpe => ByteType
case t if t <:< definitions.BooleanTpe => BooleanType
case t if t <:< localTypeOf[Array[Byte]] => BinaryType
+ case t if t <:< localTypeOf[CalendarInterval] => CalendarIntervalType
case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT
case _ =>
val className = getClassNameFromType(tpe)
@@ -189,7 +190,6 @@ object ScalaReflection extends ScalaReflection {
case _ => UpCast(expr, expected, walkedTypePath)
}
- val className = getClassNameFromType(tpe)
tpe match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
@@ -239,16 +239,14 @@ object ScalaReflection extends ScalaReflection {
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Date]),
"toJavaDate",
- getPath :: Nil,
- propagateNull = true)
+ getPath :: Nil)
case t if t <:< localTypeOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Timestamp]),
"toJavaTimestamp",
- getPath :: Nil,
- propagateNull = true)
+ getPath :: Nil)
case t if t <:< localTypeOf[java.lang.String] =>
Invoke(getPath, "toString", ObjectType(classOf[String]))
@@ -437,17 +435,17 @@ object ScalaReflection extends ScalaReflection {
walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
- val externalDataType = dataTypeFor(elementType)
- val Schema(catalystType, nullable) = silentSchemaFor(elementType)
- if (isNativeType(externalDataType)) {
- NewInstance(
- classOf[GenericArrayData],
- input :: Nil,
- dataType = ArrayType(catalystType, nullable))
- } else {
- val clsName = getClassNameFromType(elementType)
- val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
- MapObjects(serializerFor(_, elementType, newPath), input, externalDataType)
+ dataTypeFor(elementType) match {
+ case dt: ObjectType =>
+ val clsName = getClassNameFromType(elementType)
+ val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
+ MapObjects(serializerFor(_, elementType, newPath), input, dt)
+
+ case dt =>
+ NewInstance(
+ classOf[GenericArrayData],
+ input :: Nil,
+ dataType = ArrayType(dt, schemaFor(elementType).nullable))
}
}
@@ -457,63 +455,10 @@ object ScalaReflection extends ScalaReflection {
tpe match {
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
- optType match {
- // For primitive types we must manually unbox the value of the object.
- case t if t <:< definitions.IntTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject),
- "intValue",
- IntegerType)
- case t if t <:< definitions.LongTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject),
- "longValue",
- LongType)
- case t if t <:< definitions.DoubleTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject),
- "doubleValue",
- DoubleType)
- case t if t <:< definitions.FloatTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject),
- "floatValue",
- FloatType)
- case t if t <:< definitions.ShortTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject),
- "shortValue",
- ShortType)
- case t if t <:< definitions.ByteTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject),
- "byteValue",
- ByteType)
- case t if t <:< definitions.BooleanTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject),
- "booleanValue",
- BooleanType)
-
- // For non-primitives, we can just extract the object from the Option and then recurse.
- case other =>
- val className = getClassNameFromType(optType)
- val newPath = s"""- option value class: "$className"""" +: walkedTypePath
-
- val optionObjectType: DataType = other match {
- // Special handling is required for arrays, as getClassFromType(<Array>) will fail
- // since Scala Arrays map to native Java constructs. E.g. "Array[Int]" will map to
- // the Java type "[I".
- case arr if arr <:< localTypeOf[Array[_]] => arrayClassFor(t)
- case cls => ObjectType(getClassFromType(cls))
- }
- val unwrapped = UnwrapOption(optionObjectType, inputObject)
-
- expressions.If(
- IsNull(unwrapped),
- expressions.Literal.create(null, silentSchemaFor(optType).dataType),
- serializerFor(unwrapped, optType, newPath))
- }
+ val className = getClassNameFromType(optType)
+ val newPath = s"""- option value class: "$className"""" +: walkedTypePath
+ val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)
+ serializerFor(unwrapped, optType, newPath)
// Since List[_] also belongs to localTypeOf[Product], we put this case before
// "case t if definedByConstructorParams(t)" to make sure it will match to the
@@ -704,18 +649,6 @@ object ScalaReflection extends ScalaReflection {
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T])
- /**
- * Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
- *
- * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return
- * `NullType` silently instead.
- */
- def silentSchemaFor(tpe: `Type`): Schema = try {
- schemaFor(tpe)
- } catch {
- case _: UnsupportedOperationException => Schema(NullType, nullable = true)
- }
-
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
tpe match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 5e17f89209..2f2323fa3a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -289,8 +289,8 @@ case class UnwrapOption(
${inputObject.code}
final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
- $javaType ${ev.value} =
- ${ev.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) ${inputObject.value}.get();
+ $javaType ${ev.value} = ${ev.isNull} ?
+ ${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get();
"""
ev.copy(code = code)
}