diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-05-06 01:08:04 +0800 |
---|---|---|
committer | Cheng Lian <lian@databricks.com> | 2016-05-06 01:08:04 +0800 |
commit | 55cc1c991a9e39efb14177a948b09b7909e53e25 (patch) | |
tree | 4ba9dafd35df7e8374688169b72e37a5a51cb196 /sql/catalyst | |
parent | 77361a433adce109c2b752b11dda25b56eca0352 (diff) | |
download | spark-55cc1c991a9e39efb14177a948b09b7909e53e25.tar.gz spark-55cc1c991a9e39efb14177a948b09b7909e53e25.tar.bz2 spark-55cc1c991a9e39efb14177a948b09b7909e53e25.zip |
[SPARK-14139][SQL] RowEncoder should preserve schema nullability
## What changes were proposed in this pull request?
The problem is: In `RowEncoder`, we use `Invoke` to get the field of an external row, which lose the nullability information. This PR creates a `GetExternalRowField` expression, so that we can preserve the nullability info.
TODO: simplify the null handling logic in `RowEncoder`, to remove so many if branches, in follow-up PR.
## How was this patch tested?
new tests in `RowEncoderSuite`
Note that, This PR takes over https://github.com/apache/spark/pull/11980, with a little simplification, so all credits should go to koertkuipers
Author: Wenchen Fan <wenchen@databricks.com>
Author: Koert Kuipers <koert@tresata.com>
Closes #12364 from cloud-fan/nullable.
Diffstat (limited to 'sql/catalyst')
3 files changed, 71 insertions, 15 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 44e135cbf8..cfde3bfbec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -35,9 +35,8 @@ import org.apache.spark.unsafe.types.UTF8String object RowEncoder { def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - // We use an If expression to wrap extractorsFor result of StructType - val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue + val inputObject = BoundReference(0, ObjectType(cls), nullable = false) + val serializer = serializerFor(inputObject, schema) val deserializer = deserializerFor(schema) new ExpressionEncoder[Row]( schema, @@ -130,21 +129,28 @@ object RowEncoder { case StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => - val method = if (f.dataType.isInstanceOf[StructType]) { - "getStruct" + val fieldValue = serializerFor( + GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)), + f.dataType + ) + if (f.nullable) { + If( + Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), + Literal.create(null, f.dataType), + fieldValue + ) } else { - "get" + fieldValue } - If( - Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), - Literal.create(null, f.dataType), - serializerFor( - Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil), - f.dataType)) } - If(IsNull(inputObject), - Literal.create(null, inputType), - CreateStruct(convertedFields)) + + if (inputObject.nullable) { + If(IsNull(inputObject), + Literal.create(null, inputType), + CreateStruct(convertedFields)) + } else { + CreateStruct(convertedFields) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 523eed825f..dbaff1625e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -688,3 +688,45 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) ev.copy(code = code, isNull = "false", value = childGen.value) } } + +/** + * Returns the value of field at index `index` from the external row `child`. + * This class can be viewed as [[GetStructField]] for [[Row]]s instead of [[InternalRow]]s. + * + * Note that the input row and the field we try to get are both guaranteed to be not null, if they + * are null, a runtime exception will be thrown. + */ +case class GetExternalRowField( + child: Expression, + index: Int, + dataType: DataType) extends UnaryExpression with NonSQLExpression { + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val row = child.genCode(ctx) + + val getField = dataType match { + case ObjectType(x) if x == classOf[Row] => s"""${row.value}.getStruct($index)""" + case _ => s"""(${ctx.boxedType(dataType)}) ${row.value}.get($index)""" + } + + val code = s""" + ${row.code} + + if (${row.isNull}) { + throw new RuntimeException("The input external row cannot be null."); + } + + if (${row.value}.isNullAt($index)) { + throw new RuntimeException("The ${index}th field of input row cannot be null."); + } + + final ${ctx.javaType(dataType)} ${ev.value} = $getField; + """ + ev.copy(code = code, isNull = "false") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index a8fa372b1e..98be3b053d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -160,6 +160,14 @@ class RowEncoderSuite extends SparkFunSuite { .compareTo(convertedBack.getDecimal(3)) == 0) } + test("RowEncoder should preserve schema nullability") { + val schema = new StructType().add("int", IntegerType, nullable = false) + val encoder = RowEncoder(schema) + assert(encoder.serializer.length == 1) + assert(encoder.serializer.head.dataType == IntegerType) + assert(encoder.serializer.head.nullable == false) + } + private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { val encoder = RowEncoder(schema) |