From 86761e10e145b6867cbe86b1e924ec237ba408af Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Wed, 23 Dec 2015 10:21:00 +0800
Subject: [SPARK-12478][SQL] Bugfix: Dataset fields of product types can't be
 null

When creating extractors for product types (i.e. case classes and tuples), a
null check is missing, so input product values are always assumed to be
non-null.

This PR adds a null check in the extractor expression for product types. The
null check is stripped off for top-level product fields, which are mapped to
the outermost `Row`s, since they can never be null.

Thanks cloud-fan for helping investigate this issue!

Author: Cheng Lian

Closes #10431 from liancheng/spark-12478.top-level-null-field.
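For illustration, a minimal repro sketch (an editorial addition, not part of
the original patch): it assumes Spark 1.6 on the classpath, and the driver
object name is made up. It mirrors the new DatasetSuite test below.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    // Same shapes as the test classes added to DatasetSuite below.
    case class ClassData(a: String, b: Int)
    case class NestedStruct(f: ClassData)

    // Hypothetical driver object, named here only for this sketch.
    object Spark12478Repro {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("spark-12478").setMaster("local[1]"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        // `f` is a product-typed (case class) field holding null. Without
        // the null check, the generated extractor invokes the getters of
        // the null `f` unconditionally, so encoding fails at runtime; with
        // this patch the null propagates into the encoded row instead.
        val ds = Seq(NestedStruct(null)).toDS()
        println(ds.collect().toSeq)         // List(NestedStruct(null)) with the fix
        println(ds.toDF().collect().toSeq)  // List([null]), i.e. Row(null)

        sc.stop()
      }
    }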
---
 .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 8 ++++----
 .../src/test/scala/org/apache/spark/sql/DatasetSuite.scala    | 11 +++++++++++
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index becd019cae..8a22b37d07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -380,7 +380,7 @@ object ScalaReflection extends ScalaReflection {
     val clsName = getClassNameFromType(tpe)
     val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
     extractorFor(inputObject, tpe, walkedTypePath) match {
-      case s: CreateNamedStruct => s
+      case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s
       case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
     }
   }
@@ -466,14 +466,14 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Product] =>
         val params = getConstructorParameters(t)
-
-        CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+        val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
           val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
           val clsName = getClassNameFromType(fieldType)
           val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-
           expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
         })
+        val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+        expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
 
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3337996309..7fe66e461c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -546,6 +546,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
     ))
   }
+
+  test("SPARK-12478: top level null field") {
+    val ds0 = Seq(NestedStruct(null)).toDS()
+    checkAnswer(ds0, NestedStruct(null))
+    checkAnswer(ds0.toDF(), Row(null))
+
+    val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS()
+    checkAnswer(ds1, DeepNestedStruct(NestedStruct(null)))
+    checkAnswer(ds1.toDF(), Row(Row(null)))
+  }
 }
 
 case class ClassData(a: String, b: Int)
@@ -553,6 +563,7 @@ case class ClassData2(c: String, d: Int)
 case class ClassNullableData(a: String, b: Integer)
 
 case class NestedStruct(f: ClassData)
+case class DeepNestedStruct(f: NestedStruct)
 
 /**
  * A class used to test serialization using encoders. This class throws exceptions when using
--
cgit v1.2.3