Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala              | 27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala          | 19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala  | 28
3 files changed, 41 insertions(+), 33 deletions(-)
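
This patch adds a returnNullable flag to Invoke: the code generator currently
assumes every non-primitive method result may be null and re-checks the null
bit after each call, and the new flag lets call sites that invoke methods
contracted never to return null opt out of that check. A minimal sketch of the
constructor shape implied by the call sites below (parameter names, order, and
the default value are assumptions inferred from this diff, not verbatim
source):

    case class Invoke(
        targetObject: Expression,          // receiver expression
        functionName: String,              // method invoked on the receiver
        dataType: DataType,                // catalyst type of the result
        arguments: Seq[Expression] = Nil,
        propagateNull: Boolean = true,
        returnNullable: Boolean = true)    // new: false => result never null

Every call site updated below passes returnNullable = false; callers that
cannot make that guarantee keep the defensive default.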
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 206ae2f0e5..198122759e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -251,19 +251,22 @@ object ScalaReflection extends ScalaReflection {
           getPath :: Nil)

       case t if t <:< localTypeOf[java.lang.String] =>
-        Invoke(getPath, "toString", ObjectType(classOf[String]))
+        Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false)

       case t if t <:< localTypeOf[java.math.BigDecimal] =>
-        Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+        Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
+          returnNullable = false)

       case t if t <:< localTypeOf[BigDecimal] =>
-        Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]))
+        Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false)

       case t if t <:< localTypeOf[java.math.BigInteger] =>
-        Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]))
+        Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
+          returnNullable = false)

       case t if t <:< localTypeOf[scala.math.BigInt] =>
-        Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]))
+        Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
+          returnNullable = false)

       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
@@ -284,7 +287,7 @@ object ScalaReflection extends ScalaReflection {
         val arrayCls = arrayClassFor(elementType)

         if (elementNullable) {
-          Invoke(arrayData, "array", arrayCls)
+          Invoke(arrayData, "array", arrayCls, returnNullable = false)
         } else {
           val primitiveMethod = elementType match {
             case t if t <:< definitions.IntTpe => "toIntArray"
@@ -297,7 +300,7 @@ object ScalaReflection extends ScalaReflection {
             case other => throw new IllegalStateException("expect primitive array element type " +
               "but got " + other)
           }
-          Invoke(arrayData, primitiveMethod, arrayCls)
+          Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false)
         }

       case t if t <:< localTypeOf[Seq[_]] =>
@@ -330,19 +333,21 @@ object ScalaReflection extends ScalaReflection {
           Invoke(
             MapObjects(
               p => deserializerFor(keyType, Some(p), walkedTypePath),
-              Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
+              Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType),
+                returnNullable = false),
               schemaFor(keyType).dataType),
             "array",
-            ObjectType(classOf[Array[Any]]))
+            ObjectType(classOf[Array[Any]]), returnNullable = false)

         val valueData =
           Invoke(
             MapObjects(
               p => deserializerFor(valueType, Some(p), walkedTypePath),
-              Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
+              Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType),
+                returnNullable = false),
               schemaFor(valueType).dataType),
             "array",
-            ObjectType(classOf[Array[Any]]))
+            ObjectType(classOf[Array[Any]]), returnNullable = false)

         StaticInvoke(
           ArrayBasedMapData.getClass,
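
All of the deserializer call sites above target methods that are specified to
return non-null results (UTF8String.toString, Decimal.toJavaBigDecimal and
friends, ArrayData.array and the primitive toXxxArray variants, MapData's
keyArray/valueArray), so marking them returnNullable = false is sound. A
hedged sketch of the payoff in the expression tree; GetColumnByOrdinal stands
in for any child expression, and the exact nullability formula is an
assumption, since this diff does not show Invoke.nullable:

    val path = GetColumnByOrdinal(0, StringType)
    val str = Invoke(path, "toString", ObjectType(classOf[String]),
      returnNullable = false)
    // With returnNullable = false, str's nullability can follow the receiver
    // alone rather than being pessimistically true, and the generated code
    // can drop the result null branch (see objects.scala below).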
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index e95e97b9dc..0f8282d3b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -89,7 +89,7 @@ object RowEncoder {
         udtClass,
         Nil,
         dataType = ObjectType(udtClass), false)
-      Invoke(obj, "serialize", udt, inputObject :: Nil)
+      Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)

     case TimestampType =>
       StaticInvoke(
@@ -136,16 +136,18 @@ object RowEncoder {
     case t @ MapType(kt, vt, valueNullable) =>
       val keys =
         Invoke(
-          Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
+          Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]),
+            returnNullable = false),
           "toSeq",
-          ObjectType(classOf[scala.collection.Seq[_]]))
+          ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
       val convertedKeys = serializerFor(keys, ArrayType(kt, false))

       val values =
         Invoke(
-          Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
+          Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]),
+            returnNullable = false),
           "toSeq",
-          ObjectType(classOf[scala.collection.Seq[_]]))
+          ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
       val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))

       NewInstance(
@@ -262,17 +264,18 @@ object RowEncoder {
         input :: Nil)

     case _: DecimalType =>
-      Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+      Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
+        returnNullable = false)

     case StringType =>
-      Invoke(input, "toString", ObjectType(classOf[String]))
+      Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false)

     case ArrayType(et, nullable) =>
       val arrayData =
         Invoke(
           MapObjects(deserializerFor(_), input, et),
           "array",
-          ObjectType(classOf[Array[_]]))
+          ObjectType(classOf[Array[_]]), returnNullable = false)
       StaticInvoke(
         scala.collection.mutable.WrappedArray.getClass,
         ObjectType(classOf[Seq[_]]),
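
The same flag flows through both directions of RowEncoder: serializerFor marks
the keysIterator/valuesIterator/toSeq plumbing non-nullable, and
deserializerFor does the same for Decimal, String, and array conversions. A
hypothetical usage sketch showing where these trees get built (the schema here
is just an example, not taken from the patch):

    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    val schema = new StructType()
      .add("name", StringType)
      .add("price", DecimalType(10, 2))
    // The encoder's serializer/deserializer now contain
    // Invoke(..., returnNullable = false) nodes for these columns.
    val encoder = RowEncoder(schema)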
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 53842ef348..6d94764f1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -225,25 +225,26 @@ case class Invoke(
       getFuncResult(ev.value, s"${obj.value}.$functionName($argString)")
     } else {
       val funcResult = ctx.freshName("funcResult")
+      // If the function can return null, we do an extra check to make sure our null bit is still
+      // set correctly.
+      val assignResult = if (!returnNullable) {
+        s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;"
+      } else {
+        s"""
+          if ($funcResult != null) {
+            ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
+          } else {
+            ${ev.isNull} = true;
+          }
+        """
+      }
       s"""
         Object $funcResult = null;
         ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")}
-        if ($funcResult == null) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
-        }
+        $assignResult
       """
     }

-    // If the function can return null, we do an extra check to make sure our null bit is still set
-    // correctly.
-    val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
-      s"${ev.isNull} = ${ev.value} == null;"
-    } else {
-      ""
-    }
-
     val code = s"""
       ${obj.code}
       boolean ${ev.isNull} = true;
@@ -254,7 +255,6 @@ case class Invoke(
         if (!${ev.isNull}) {
           $evaluate
         }
-        $postNullCheck
       }
      """
     ev.copy(code = code)
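
For a non-primitive result the template above expands to roughly the following
generated Java (illustrative only; funcResult, value, and isNull stand in for
the fresh names the codegen context would actually produce):

    // returnNullable = true (default): extra branch keeps the null bit correct
    Object funcResult = null;
    funcResult = obj.toString();
    if (funcResult != null) {
      value = (java.lang.String) funcResult;
    } else {
      isNull = true;
    }

    // returnNullable = false: the branch is elided entirely
    Object funcResult = null;
    funcResult = obj.toString();
    value = (java.lang.String) funcResult;

The old postNullCheck ran for every result whose default value was null;
folding the check into assignResult means only invocations that can actually
return null pay for the extra branch.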