aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala105
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala4
2 files changed, 21 insertions, 88 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 36989a20cb..bdd40f3402 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
@@ -72,6 +72,7 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< definitions.ByteTpe => ByteType
case t if t <:< definitions.BooleanTpe => BooleanType
case t if t <:< localTypeOf[Array[Byte]] => BinaryType
+ case t if t <:< localTypeOf[CalendarInterval] => CalendarIntervalType
case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT
case _ =>
val className = getClassNameFromType(tpe)
@@ -189,7 +190,6 @@ object ScalaReflection extends ScalaReflection {
case _ => UpCast(expr, expected, walkedTypePath)
}
- val className = getClassNameFromType(tpe)
tpe match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
@@ -239,16 +239,14 @@ object ScalaReflection extends ScalaReflection {
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Date]),
"toJavaDate",
- getPath :: Nil,
- propagateNull = true)
+ getPath :: Nil)
case t if t <:< localTypeOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Timestamp]),
"toJavaTimestamp",
- getPath :: Nil,
- propagateNull = true)
+ getPath :: Nil)
case t if t <:< localTypeOf[java.lang.String] =>
Invoke(getPath, "toString", ObjectType(classOf[String]))
@@ -437,17 +435,17 @@ object ScalaReflection extends ScalaReflection {
walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
- val externalDataType = dataTypeFor(elementType)
- val Schema(catalystType, nullable) = silentSchemaFor(elementType)
- if (isNativeType(externalDataType)) {
- NewInstance(
- classOf[GenericArrayData],
- input :: Nil,
- dataType = ArrayType(catalystType, nullable))
- } else {
- val clsName = getClassNameFromType(elementType)
- val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
- MapObjects(serializerFor(_, elementType, newPath), input, externalDataType)
+ dataTypeFor(elementType) match {
+ case dt: ObjectType =>
+ val clsName = getClassNameFromType(elementType)
+ val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
+ MapObjects(serializerFor(_, elementType, newPath), input, dt)
+
+ case dt =>
+ NewInstance(
+ classOf[GenericArrayData],
+ input :: Nil,
+ dataType = ArrayType(dt, schemaFor(elementType).nullable))
}
}
@@ -457,63 +455,10 @@ object ScalaReflection extends ScalaReflection {
tpe match {
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
- optType match {
- // For primitive types we must manually unbox the value of the object.
- case t if t <:< definitions.IntTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject),
- "intValue",
- IntegerType)
- case t if t <:< definitions.LongTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject),
- "longValue",
- LongType)
- case t if t <:< definitions.DoubleTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject),
- "doubleValue",
- DoubleType)
- case t if t <:< definitions.FloatTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject),
- "floatValue",
- FloatType)
- case t if t <:< definitions.ShortTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject),
- "shortValue",
- ShortType)
- case t if t <:< definitions.ByteTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject),
- "byteValue",
- ByteType)
- case t if t <:< definitions.BooleanTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject),
- "booleanValue",
- BooleanType)
-
- // For non-primitives, we can just extract the object from the Option and then recurse.
- case other =>
- val className = getClassNameFromType(optType)
- val newPath = s"""- option value class: "$className"""" +: walkedTypePath
-
- val optionObjectType: DataType = other match {
- // Special handling is required for arrays, as getClassFromType(<Array>) will fail
- // since Scala Arrays map to native Java constructs. E.g. "Array[Int]" will map to
- // the Java type "[I".
- case arr if arr <:< localTypeOf[Array[_]] => arrayClassFor(t)
- case cls => ObjectType(getClassFromType(cls))
- }
- val unwrapped = UnwrapOption(optionObjectType, inputObject)
-
- expressions.If(
- IsNull(unwrapped),
- expressions.Literal.create(null, silentSchemaFor(optType).dataType),
- serializerFor(unwrapped, optType, newPath))
- }
+ val className = getClassNameFromType(optType)
+ val newPath = s"""- option value class: "$className"""" +: walkedTypePath
+ val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)
+ serializerFor(unwrapped, optType, newPath)
// Since List[_] also belongs to localTypeOf[Product], we put this case before
// "case t if definedByConstructorParams(t)" to make sure it will match to the
@@ -704,18 +649,6 @@ object ScalaReflection extends ScalaReflection {
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T])
- /**
- * Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
- *
- * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return
- * `NullType` silently instead.
- */
- def silentSchemaFor(tpe: `Type`): Schema = try {
- schemaFor(tpe)
- } catch {
- case _: UnsupportedOperationException => Schema(NullType, nullable = true)
- }
-
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
tpe match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 5e17f89209..2f2323fa3a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -289,8 +289,8 @@ case class UnwrapOption(
${inputObject.code}
final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
- $javaType ${ev.value} =
- ${ev.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) ${inputObject.value}.get();
+ $javaType ${ev.value} = ${ev.isNull} ?
+ ${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get();
"""
ev.copy(code = code)
}