From d202ad2fc24b54de38ad7bfb646bf7703069e9f7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 5 Jan 2016 12:33:21 -0800 Subject: [SPARK-12439][SQL] Fix toCatalystArray and MapObjects JIRA: https://issues.apache.org/jira/browse/SPARK-12439 In toCatalystArray, we should look at the data type returned by dataTypeFor instead of silentSchemaFor, to determine if the element is native type. An obvious problem is when the element is Option[Int] class, catalsilentSchemaFor will return Int, then we will wrongly recognize the element is native type. There is another problem when using Option as array element. When we encode data like Seq(Some(1), Some(2), None) with encoder, we will use MapObjects to construct an array for it later. But in MapObjects, we don't check if the return value of lambdaFunction is null or not. That causes a bug that the decoded data for Seq(Some(1), Some(2), None) would be Seq(1, 2, -1), instead of Seq(1, 2, null). Author: Liang-Chi Hsieh Closes #10391 from viirya/fix-catalystarray. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 2 +- .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 11 ++++++++--- .../org/apache/spark/sql/catalyst/expressions/objects.scala | 4 ++-- .../spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala | 3 +++ 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c6aa60b0b4..b0efdf3ef4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -405,7 +405,7 @@ object ScalaReflection extends ScalaReflection { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { val externalDataType = dataTypeFor(elementType) val Schema(catalystType, nullable) = silentSchemaFor(elementType) - if (isNativeType(catalystType)) { + if (isNativeType(externalDataType)) { NewInstance( classOf[GenericArrayData], input :: Nil, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 6f3d5ba84c..3903086a4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -35,7 +35,8 @@ object RowEncoder { def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpressions = extractorsFor(inputObject, schema) + // We use an If expression to wrap extractorsFor result of StructType + val extractExpressions = extractorsFor(inputObject, schema).asInstanceOf[If].falseValue val constructExpression = constructorFor(schema) new ExpressionEncoder[Row]( schema, @@ -129,7 +130,9 @@ object RowEncoder { Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil), f.dataType)) } - CreateStruct(convertedFields) + If(IsNull(inputObject), + Literal.create(null, inputType), + CreateStruct(convertedFields)) } private def externalDataTypeFor(dt: DataType): DataType = dt match { @@ -220,6 +223,8 @@ object RowEncoder { Literal.create(null, externalDataTypeFor(f.dataType)), constructorFor(GetStructField(input, i))) } - CreateExternalRow(convertedFields) + If(IsNull(input), + Literal.create(null, externalDataTypeFor(input.dataType)), + CreateExternalRow(convertedFields)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index fb404c12d5..c0c3e6e891 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -456,10 +456,10 @@ case class MapObjects( ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; $loopNullCheck - if (${loopVar.isNull}) { + ${genFunction.code} + if (${genFunction.isNull}) { $convertedArray[$loopIndex] = null; } else { - ${genFunction.code} $convertedArray[$loopIndex] = ${genFunction.value}; } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 6453f1c191..98f29e53df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -160,6 +160,9 @@ class ExpressionEncoderSuite extends SparkFunSuite { productTest(OptionalData(None, None, None, None, None, None, None, None)) + encodeDecodeTest(Seq(Some(1), None), "Option in array") + encodeDecodeTest(Map(1 -> Some(10L), 2 -> Some(20L), 3 -> None), "Option in map") + productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) productTest(BoxedData(null, null, null, null, null, null, null)) -- cgit v1.2.3