diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-06-22 18:32:14 +0800 |
---|---|---|
committer | Cheng Lian <lian@databricks.com> | 2016-06-22 18:32:14 +0800 |
commit | 01277d4b259dcf9cad25eece1377162b7a8c946d (patch) | |
tree | 374c1308da7c6b97593f9af9b9f0605b7747265c /sql/catalyst | |
parent | 39ad53f7ffddae5ba0ff0a76089ba671b14c44c8 (diff) | |
download | spark-01277d4b259dcf9cad25eece1377162b7a8c946d.tar.gz spark-01277d4b259dcf9cad25eece1377162b7a8c946d.tar.bz2 spark-01277d4b259dcf9cad25eece1377162b7a8c946d.zip |
[SPARK-16097][SQL] Encoders.tuple should handle null object correctly
## What changes were proposed in this pull request?
Although the top level input object can not be null, but when we use `Encoders.tuple` to combine 2 encoders, their input objects are not top level anymore and can be null. We should handle this case.
## How was this patch tested?
new test in DatasetSuite
Author: Wenchen Fan <wenchen@databricks.com>
Closes #13807 from cloud-fan/bug.
Diffstat (limited to 'sql/catalyst')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 48 |
1 files changed, 35 insertions, 13 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 0023ce64aa..1fac26c438 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} -import org.apache.spark.sql.types.{ObjectType, StructField, StructType} +import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType} import org.apache.spark.util.Utils /** @@ -110,16 +110,34 @@ object ExpressionEncoder { val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - val serializer = encoders.map { - case e if e.flat => e.serializer.head - case other => CreateStruct(other.serializer) - }.zipWithIndex.map { case (expr, index) => - expr.transformUp { - case BoundReference(0, t, _) => - Invoke( - BoundReference(0, ObjectType(cls), nullable = true), - s"_${index + 1}", - t) + val serializer = encoders.zipWithIndex.map { case (enc, index) => + val originalInputObject = enc.serializer.head.collect { case b: BoundReference => b }.head + val newInputObject = Invoke( + BoundReference(0, ObjectType(cls), nullable = true), + s"_${index + 1}", + originalInputObject.dataType) + + val newSerializer = enc.serializer.map(_.transformUp { + case b: BoundReference if b == originalInputObject => newInputObject + }) + + if (enc.flat) { + newSerializer.head + } else { + // For non-flat encoder, the input object is not top level anymore after being combined to + // a tuple encoder, thus it can be null and we should wrap the `CreateStruct` with `If` and + // null check to handle null case correctly. + // e.g. for Encoder[(Int, String)], the serializer expressions will create 2 columns, and is + // not able to handle the case when the input tuple is null. This is not a problem as there + // is a check to make sure the input object won't be null. However, if this encoder is used + // to create a bigger tuple encoder, the original input object becomes a filed of the new + // input tuple and can be null. So instead of creating a struct directly here, we should add + // a null/None check and return a null struct if the null/None check fails. + val struct = CreateStruct(newSerializer) + val nullCheck = Or( + IsNull(newInputObject), + Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil)) + If(nullCheck, Literal.create(null, struct.dataType), struct) } } @@ -203,8 +221,12 @@ case class ExpressionEncoder[T]( // (intermediate value is not an attribute). We assume that all serializer expressions use a same // `BoundReference` to refer to the object, and throw exception if they don't. assert(serializer.forall(_.references.isEmpty), "serializer cannot reference to any attributes.") - assert(serializer.flatMap(_.collect { case b: BoundReference => b}).distinct.length <= 1, - "all serializer expressions must use the same BoundReference.") + assert(serializer.flatMap { ser => + val boundRefs = ser.collect { case b: BoundReference => b } + assert(boundRefs.nonEmpty, + "each serializer expression should contains at least one `BoundReference`") + boundRefs + }.distinct.length <= 1, "all serializer expressions must use the same BoundReference.") /** * Returns a new copy of this encoder, where the `deserializer` is resolved and bound to the |