aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2016-01-05 12:33:21 -0800
committerMichael Armbrust <michael@databricks.com>2016-01-05 12:33:21 -0800
commitd202ad2fc24b54de38ad7bfb646bf7703069e9f7 (patch)
treedd8768195bd3d5a699c597b6ab0d29c0c41dea66
parent8ce645d4eeda203cf5e100c4bdba2d71edd44e6a (diff)
downloadspark-d202ad2fc24b54de38ad7bfb646bf7703069e9f7.tar.gz
spark-d202ad2fc24b54de38ad7bfb646bf7703069e9f7.tar.bz2
spark-d202ad2fc24b54de38ad7bfb646bf7703069e9f7.zip
[SPARK-12439][SQL] Fix toCatalystArray and MapObjects
JIRA: https://issues.apache.org/jira/browse/SPARK-12439 In toCatalystArray, we should look at the data type returned by dataTypeFor instead of silentSchemaFor, to determine if the element is a native type. An obvious problem is when the element is of class Option[Int]: silentSchemaFor will return Int, so we will wrongly recognize the element as a native type. There is another problem when using Option as an array element. When we encode data like Seq(Some(1), Some(2), None) with an encoder, we will use MapObjects to construct an array for it later. But in MapObjects, we don't check whether the return value of lambdaFunction is null. That causes a bug where the decoded data for Seq(Some(1), Some(2), None) would be Seq(1, 2, -1), instead of Seq(1, 2, null). Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #10391 from viirya/fix-catalystarray.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala11
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala3
4 files changed, 14 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c6aa60b0b4..b0efdf3ef4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -405,7 +405,7 @@ object ScalaReflection extends ScalaReflection {
def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
val externalDataType = dataTypeFor(elementType)
val Schema(catalystType, nullable) = silentSchemaFor(elementType)
- if (isNativeType(catalystType)) {
+ if (isNativeType(externalDataType)) {
NewInstance(
classOf[GenericArrayData],
input :: Nil,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 6f3d5ba84c..3903086a4c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -35,7 +35,8 @@ object RowEncoder {
def apply(schema: StructType): ExpressionEncoder[Row] = {
val cls = classOf[Row]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
- val extractExpressions = extractorsFor(inputObject, schema)
+ // We use an If expression to wrap extractorsFor result of StructType
+ val extractExpressions = extractorsFor(inputObject, schema).asInstanceOf[If].falseValue
val constructExpression = constructorFor(schema)
new ExpressionEncoder[Row](
schema,
@@ -129,7 +130,9 @@ object RowEncoder {
Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil),
f.dataType))
}
- CreateStruct(convertedFields)
+ If(IsNull(inputObject),
+ Literal.create(null, inputType),
+ CreateStruct(convertedFields))
}
private def externalDataTypeFor(dt: DataType): DataType = dt match {
@@ -220,6 +223,8 @@ object RowEncoder {
Literal.create(null, externalDataTypeFor(f.dataType)),
constructorFor(GetStructField(input, i)))
}
- CreateExternalRow(convertedFields)
+ If(IsNull(input),
+ Literal.create(null, externalDataTypeFor(input.dataType)),
+ CreateExternalRow(convertedFields))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index fb404c12d5..c0c3e6e891 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -456,10 +456,10 @@ case class MapObjects(
($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
$loopNullCheck
- if (${loopVar.isNull}) {
+ ${genFunction.code}
+ if (${genFunction.isNull}) {
$convertedArray[$loopIndex] = null;
} else {
- ${genFunction.code}
$convertedArray[$loopIndex] = ${genFunction.value};
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 6453f1c191..98f29e53df 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -160,6 +160,9 @@ class ExpressionEncoderSuite extends SparkFunSuite {
productTest(OptionalData(None, None, None, None, None, None, None, None))
+ encodeDecodeTest(Seq(Some(1), None), "Option in array")
+ encodeDecodeTest(Map(1 -> Some(10L), 2 -> Some(20L), 3 -> None), "Option in map")
+
productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
productTest(BoxedData(null, null, null, null, null, null, null))