about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCheng Lian <lian@databricks.com>2015-12-23 10:21:00 +0800
committerCheng Lian <lian@databricks.com>2015-12-23 10:21:00 +0800
commit86761e10e145b6867cbe86b1e924ec237ba408af (patch)
tree715614c3e8df0fe537f64bf278868746ec00003f
parent20591afd790799327f99485c5a969ed7412eca45 (diff)
downloadspark-86761e10e145b6867cbe86b1e924ec237ba408af.tar.gz
spark-86761e10e145b6867cbe86b1e924ec237ba408af.tar.bz2
spark-86761e10e145b6867cbe86b1e924ec237ba408af.zip
[SPARK-12478][SQL] Bugfix: Dataset fields of product types can't be null
When creating extractors for product types (i.e. case classes and tuples), a null check is missing, thus we always assume input product values are non-null. This PR adds a null check in the extractor expression for product types. The null check is stripped off for top level product fields, which are mapped to the outermost `Row`s, since they can't be null. Thanks cloud-fan for helping investigate this issue! Author: Cheng Lian <lian@databricks.com> Closes #10431 from liancheng/spark-12478.top-level-null-field.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala8
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala11
2 files changed, 15 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index becd019cae..8a22b37d07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -380,7 +380,7 @@ object ScalaReflection extends ScalaReflection {
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
extractorFor(inputObject, tpe, walkedTypePath) match {
- case s: CreateNamedStruct => s
+ case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
}
}
@@ -466,14 +466,14 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Product] =>
val params = getConstructorParameters(t)
-
- CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+ val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
val clsName = getClassNameFromType(fieldType)
val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-
expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
})
+ val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+ expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3337996309..7fe66e461c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -546,6 +546,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
"Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
))
}
+
+ test("SPARK-12478: top level null field") {
+ val ds0 = Seq(NestedStruct(null)).toDS()
+ checkAnswer(ds0, NestedStruct(null))
+ checkAnswer(ds0.toDF(), Row(null))
+
+ val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS()
+ checkAnswer(ds1, DeepNestedStruct(NestedStruct(null)))
+ checkAnswer(ds1.toDF(), Row(Row(null)))
+ }
}
case class ClassData(a: String, b: Int)
@@ -553,6 +563,7 @@ case class ClassData2(c: String, d: Int)
case class ClassNullableData(a: String, b: Integer)
case class NestedStruct(f: ClassData)
+case class DeepNestedStruct(f: NestedStruct)
/**
* A class used to test serialization using encoders. This class throws exceptions when using