author    Wenchen Fan <wenchen@databricks.com>  2016-06-03 14:28:19 -0700
committer Cheng Lian <lian@databricks.com>      2016-06-03 14:28:19 -0700
commit    11c83f83d5172167cb64513d5311b4178797d40e
tree      fa4a41d3d6d244f4887bb0cd14205b15af8d16f3
parent    28ad0f7b0dc7bf24fac251c4f131aca74ba1c1d2
[SPARK-15140][SQL] make the semantics of null input object for encoder clear
## What changes were proposed in this pull request?

For an input object of non-flat type, we can't encode it to a row if it's null, as Spark SQL doesn't allow a row itself to be null; only its columns can be null. This PR makes that constraint explicit and throws an exception when users break it.

## How was this patch tested?

Several new tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13469 from cloud-fan/null-object.
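A minimal sketch of the resulting semantics, mirroring the `DatasetSuite` tests added below. It assumes a Spark 2.0-era `spark-shell` session with `spark.implicits._` in scope; `ClassData` stands in for the suite's existing test fixture case class:

```scala
import spark.implicits._

// Flat types (e.g. String, Int) map to a single column, so a null element
// is simply encoded as a null column value and round-trips fine.
Seq("a", null).toDS().collect()          // Array("a", null)

// Non-flat types (case classes) map to an entire row. Spark SQL has no
// notion of a null top-level row, so the encoder now fails fast instead
// of producing a corrupt row.
case class ClassData(a: String, b: Int)
Seq(ClassData("a", 1), null).toDS()
// java.lang.RuntimeException: Null value appeared in non-nullable field:
//   top level non-flat input object
```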
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 13
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 7
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala | 4
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 8
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 10
5 files changed, 33 insertions(+), 9 deletions(-)
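The same guard also covers the `RowEncoder` path, as exercised by the new `RowEncoderSuite` test in the diff below. A sketch of that failure mode, assuming `RowEncoder` (an internal catalyst API at this point in the source tree) is reachable on the classpath:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StructType}

val schema = new StructType().add("int", IntegerType)
val encoder = RowEncoder(schema)

// Encoding a present row works as before; serialization needs no binding.
encoder.toRow(Row(1))

// Encoding a null top-level Row now throws instead of silently producing
// an invalid internal row:
// java.lang.RuntimeException: Null value appeared in non-nullable field:
//   top level row object
encoder.toRow(null)
```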
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index cc59d06fa3..688082dcce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaRefle
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
@@ -50,8 +50,15 @@ object ExpressionEncoder {
val cls = mirror.runtimeClass(tpe)
val flat = !ScalaReflection.definedByConstructorParams(tpe)
- val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false)
- val serializer = ScalaReflection.serializerFor[T](inputObject)
+ val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
+ val nullSafeInput = if (flat) {
+ inputObject
+ } else {
+ // For input object of non-flat type, we can't encode it to row if it's null, as Spark SQL
+ // doesn't allow top-level row to be null, only its columns can be null.
+ AssertNotNull(inputObject, Seq("top level non-flat input object"))
+ }
+ val serializer = ScalaReflection.serializerFor[T](nullSafeInput)
val deserializer = ScalaReflection.deserializerFor[T]
val schema = ScalaReflection.schemaFor[T] match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 3c6ae1c5cc..6cd7b34ceb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -57,8 +57,8 @@ import org.apache.spark.unsafe.types.UTF8String
object RowEncoder {
def apply(schema: StructType): ExpressionEncoder[Row] = {
val cls = classOf[Row]
- val inputObject = BoundReference(0, ObjectType(cls), nullable = false)
- val serializer = serializerFor(inputObject, schema)
+ val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+ val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), schema)
val deserializer = deserializerFor(schema)
new ExpressionEncoder[Row](
schema,
@@ -153,8 +153,7 @@ object RowEncoder {
val fieldValue = serializerFor(
GetExternalRowField(
inputObject, index, field.name, externalDataTypeForInput(field.dataType)),
- field.dataType
- )
+ field.dataType)
val convertedField = if (field.nullable) {
If(
Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index c2e3ab82ff..d4c71bffe8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -519,7 +519,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
val code = s"""
$values = new Object[${children.size}];
$childrenCode
- final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);
+ final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField);
"""
ev.copy(code = code, isNull = "false")
}
@@ -675,7 +675,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
${childGen.code}
if (${childGen.isNull}) {
- throw new RuntimeException(this.$errMsgField);
+ throw new RuntimeException($errMsgField);
}
"""
ev.copy(code = code, isNull = "false", value = childGen.value)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 6f1bc80c1c..16abde064f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -224,6 +224,14 @@ class RowEncoderSuite extends SparkFunSuite {
assert(convertedBack.getSeq(2) == Seq(Seq(Seq(0L, null), null), null))
}
+ test("RowEncoder should throw RuntimeException if input row object is null") {
+ val schema = new StructType().add("int", IntegerType)
+ val encoder = RowEncoder(schema)
+ val e = intercept[RuntimeException](encoder.toRow(null))
+ assert(e.getMessage.contains("Null value appeared in non-nullable field"))
+ assert(e.getMessage.contains("top level row object"))
+ }
+
private def encodeDecodeTest(schema: StructType): Unit = {
test(s"encode/decode: ${schema.simpleString}") {
val encoder = RowEncoder(schema).resolveAndBind()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index d1c232974e..bf2b0a2c7c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -790,6 +790,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(e.getMessage.contains(
"`abstract` is a reserved keyword and cannot be used as field name"))
}
+
+ test("Dataset should support flat input object to be null") {
+ checkDataset(Seq("a", null).toDS(), "a", null)
+ }
+
+ test("Dataset should throw RuntimeException if non-flat input object is null") {
+ val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS())
+ assert(e.getMessage.contains("Null value appeared in non-nullable field"))
+ assert(e.getMessage.contains("top level non-flat input object"))
+ }
}
case class Generic[T](id: T, value: Double)