author     Wenchen Fan <wenchen@databricks.com>  2016-06-22 18:32:14 +0800
committer  Cheng Lian <lian@databricks.com>      2016-06-22 18:32:14 +0800
commit     01277d4b259dcf9cad25eece1377162b7a8c946d
tree       374c1308da7c6b97593f9af9b9f0605b7747265c /sql
parent     39ad53f7ffddae5ba0ff0a76089ba671b14c44c8
[SPARK-16097][SQL] Encoders.tuple should handle null object correctly
## What changes were proposed in this pull request?

Although the top-level input object cannot be null, when we use `Encoders.tuple` to combine 2 encoders, their input objects are no longer top level and can be null. We should handle this case.

## How was this patch tested?

New test in DatasetSuite.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13807 from cloud-fan/bug.
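To make the scenario concrete, here is a small illustrative Scala sketch (not part of the patch; the names `inner`, `nested`, and `rows` are made up for illustration). Nesting one tuple encoder inside another means the inner encoder's input object becomes a field of the outer tuple and may therefore be null:

import org.apache.spark.sql.Encoders

// Encoder[(String, String)]: fine on its own, its input object is never null.
val inner = Encoders.tuple(Encoders.STRING, Encoders.STRING)

// Encoder[((String, String), String)]: the inner (String, String) value is now
// just a field of the outer tuple, so it can legitimately be null.
val nested = Encoders.tuple(inner, Encoders.STRING)

// A row with a null inner tuple, which the combined serializer must handle.
val rows = Seq((("a", "b"), "c"), (null, "d"))

Before this patch the combined serializer did not null-check the inner input object, which is what the `If`/`IsNull` wrapping added to `ExpressionEncoder.tuple` below addresses.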
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 48
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 7
2 files changed, 42 insertions, 13 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 0023ce64aa..1fac26c438 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
-import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
+import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType}
import org.apache.spark.util.Utils
/**
@@ -110,16 +110,34 @@ object ExpressionEncoder {
    val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
-    val serializer = encoders.map {
-      case e if e.flat => e.serializer.head
-      case other => CreateStruct(other.serializer)
-    }.zipWithIndex.map { case (expr, index) =>
-      expr.transformUp {
-        case BoundReference(0, t, _) =>
-          Invoke(
-            BoundReference(0, ObjectType(cls), nullable = true),
-            s"_${index + 1}",
-            t)
+    val serializer = encoders.zipWithIndex.map { case (enc, index) =>
+      val originalInputObject = enc.serializer.head.collect { case b: BoundReference => b }.head
+      val newInputObject = Invoke(
+        BoundReference(0, ObjectType(cls), nullable = true),
+        s"_${index + 1}",
+        originalInputObject.dataType)
+
+      val newSerializer = enc.serializer.map(_.transformUp {
+        case b: BoundReference if b == originalInputObject => newInputObject
+      })
+
+      if (enc.flat) {
+        newSerializer.head
+      } else {
+        // For a non-flat encoder, the input object is not top level anymore after being combined
+        // into a tuple encoder, thus it can be null, and we should wrap the `CreateStruct` with an
+        // `If` and a null check to handle the null case correctly.
+        // e.g. for Encoder[(Int, String)], the serializer expressions will create 2 columns, and are
+        // not able to handle the case when the input tuple is null. This is not a problem as there
+        // is a check to make sure the input object won't be null. However, if this encoder is used
+        // to create a bigger tuple encoder, the original input object becomes a field of the new
+        // input tuple and can be null. So instead of creating a struct directly here, we should add
+        // a null/None check and return a null struct if the null/None check fails.
+        val struct = CreateStruct(newSerializer)
+        val nullCheck = Or(
+          IsNull(newInputObject),
+          Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil))
+        If(nullCheck, Literal.create(null, struct.dataType), struct)
      }
    }
@@ -203,8 +221,12 @@ case class ExpressionEncoder[T](
  // (intermediate value is not an attribute). We assume that all serializer expressions use a same
  // `BoundReference` to refer to the object, and throw exception if they don't.
  assert(serializer.forall(_.references.isEmpty), "serializer cannot reference to any attributes.")
-  assert(serializer.flatMap(_.collect { case b: BoundReference => b}).distinct.length <= 1,
-    "all serializer expressions must use the same BoundReference.")
+  assert(serializer.flatMap { ser =>
+    val boundRefs = ser.collect { case b: BoundReference => b }
+    assert(boundRefs.nonEmpty,
+      "each serializer expression should contain at least one `BoundReference`")
+    boundRefs
+  }.distinct.length <= 1, "all serializer expressions must use the same BoundReference.")
/**
* Returns a new copy of this encoder, where the `deserializer` is resolved and bound to the
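
For intuition only, the null/None guard added around `CreateStruct` in the hunk above behaves roughly like the following plain-Scala sketch. This is a simplified model of the generated Catalyst expression, not code from the patch; `serializeInner` and `serializeFields` are hypothetical stand-ins for the combined serializer and the inner encoder's rewritten field serializers (`newSerializer`):

import org.apache.spark.sql.catalyst.InternalRow

// Roughly: If(Or(IsNull(input), input == None), null, CreateStruct(fields))
def serializeInner[T <: AnyRef](
    input: T,
    serializeFields: T => InternalRow): InternalRow = {
  if (input == null || None.equals(input)) {
    null  // a null (or None) inner object becomes a null struct column
  } else {
    serializeFields(input)  // otherwise build the struct from the field serializers
  }
}

The extra `None` comparison presumably makes a missing Option-typed inner value behave the same way as a null one.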
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index f02a3141a0..bd8479b2d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -830,6 +830,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
      ds.dropDuplicates("_1", "_2"),
      ("a", 1), ("a", 2), ("b", 1))
  }
+
+  test("SPARK-16097: Encoders.tuple should handle null object correctly") {
+    val enc = Encoders.tuple(Encoders.tuple(Encoders.STRING, Encoders.STRING), Encoders.STRING)
+    val data = Seq((("a", "b"), "c"), (null, "d"))
+    val ds = spark.createDataset(data)(enc)
+    checkDataset(ds, (("a", "b"), "c"), (null, "d"))
+  }
}
case class Generic[T](id: T, value: Double)