 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala |  8 ++++----
 sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                 | 11 +++++++++++
 2 files changed, 15 insertions(+), 4 deletions(-)
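
Background for the patch below: SPARK-12478 reports that encoding a case class whose Product-typed field is null fails at runtime, because the generated extractor invokes the field accessors of the null object unconditionally. A minimal reproduction sketch, assuming a Spark 1.6-era shell with a SQLContext in scope (not part of the commit):

    import sqlContext.implicits._

    case class ClassData(a: String, b: Int)
    case class NestedStruct(f: ClassData)

    // Before this patch, collecting a Dataset with a null nested struct
    // failed at runtime; afterwards the null round-trips cleanly.
    val ds = Seq(NestedStruct(null)).toDS()
    ds.collect()  // Array(NestedStruct(null))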
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index becd019cae..8a22b37d07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -380,7 +380,7 @@ object ScalaReflection extends ScalaReflection {
     val clsName = getClassNameFromType(tpe)
     val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
     extractorFor(inputObject, tpe, walkedTypePath) match {
-      case s: CreateNamedStruct => s
+      case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s
       case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
     }
   }
@@ -466,14 +466,14 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Product] =>
         val params = getConstructorParameters(t)
-
-        CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+        val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
           val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
           val clsName = getClassNameFromType(fieldType)
           val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-
           expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
         })
+        val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+        expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
 
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
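
The net effect of the change above: for Product types, extractorFor now wraps the CreateNamedStruct in a null guard, If(IsNull(inputObject), Literal(null, structType), struct), and the root-level match in the first hunk unwraps that If again, since a top-level row can never itself be null. The same semantics in a self-contained plain-Scala sketch (no Spark dependency; extractStruct is a hypothetical stand-in for the generated extractor):

    case class ClassData(a: String, b: Int)

    // A null Product value now maps to a null struct cell instead of
    // dereferencing fields of a null object.
    def extractStruct(f: ClassData): Seq[Any] =
      if (f == null) null else Seq(f.a, f.b)

    assert(extractStruct(null) == null)
    assert(extractStruct(ClassData("x", 1)) == Seq("x", 1))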
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3337996309..7fe66e461c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -546,6 +546,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
     ))
   }
+
+  test("SPARK-12478: top level null field") {
+    val ds0 = Seq(NestedStruct(null)).toDS()
+    checkAnswer(ds0, NestedStruct(null))
+    checkAnswer(ds0.toDF(), Row(null))
+
+    val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS()
+    checkAnswer(ds1, DeepNestedStruct(NestedStruct(null)))
+    checkAnswer(ds1.toDF(), Row(Row(null)))
+  }
 }
 
 case class ClassData(a: String, b: Int)
@@ -553,6 +563,7 @@ case class ClassData2(c: String, d: Int)
 case class ClassNullableData(a: String, b: Integer)
 
 case class NestedStruct(f: ClassData)
+case class DeepNestedStruct(f: NestedStruct)
 
 /**
  * A class used to test serialization using encoders. This class throws exceptions when using
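
The new test covers both a directly null struct field and a null struct nested one level deeper. Roughly the same check in a spark-shell session, assuming the Spark 1.6-era API (the commented results restate what the test asserts, not a captured transcript):

    import sqlContext.implicits._

    case class ClassData(a: String, b: Int)
    case class NestedStruct(f: ClassData)
    case class DeepNestedStruct(f: NestedStruct)

    Seq(NestedStruct(null)).toDS().toDF().collect()
    // Array([null]) -- the null struct becomes a single NULL column

    Seq(DeepNestedStruct(NestedStruct(null))).toDS().toDF().collect()
    // Array([[null]]) -- only the inner struct is NULL; the outer row survives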