aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-05-06 01:08:04 +0800
committerCheng Lian <lian@databricks.com>2016-05-06 01:08:04 +0800
commit55cc1c991a9e39efb14177a948b09b7909e53e25 (patch)
tree4ba9dafd35df7e8374688169b72e37a5a51cb196 /sql
parent77361a433adce109c2b752b11dda25b56eca0352 (diff)
downloadspark-55cc1c991a9e39efb14177a948b09b7909e53e25.tar.gz
spark-55cc1c991a9e39efb14177a948b09b7909e53e25.tar.bz2
spark-55cc1c991a9e39efb14177a948b09b7909e53e25.zip
[SPARK-14139][SQL] RowEncoder should preserve schema nullability
## What changes were proposed in this pull request? The problem is: in `RowEncoder`, we use `Invoke` to get the field of an external row, which loses the nullability information. This PR creates a `GetExternalRowField` expression, so that we can preserve the nullability info. TODO: simplify the null handling logic in `RowEncoder`, to remove the many if branches, in a follow-up PR. ## How was this patch tested? New tests in `RowEncoderSuite`. Note that this PR takes over https://github.com/apache/spark/pull/11980, with a little simplification, so all credit should go to koertkuipers. Author: Wenchen Fan <wenchen@databricks.com> Author: Koert Kuipers <koert@tresata.com> Closes #12364 from cloud-fan/nullable.
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 36
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala | 42
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 8
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 18
4 files changed, 88 insertions, 16 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 44e135cbf8..cfde3bfbec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -35,9 +35,8 @@ import org.apache.spark.unsafe.types.UTF8String
object RowEncoder {
def apply(schema: StructType): ExpressionEncoder[Row] = {
val cls = classOf[Row]
- val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
- // We use an If expression to wrap extractorsFor result of StructType
- val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue
+ val inputObject = BoundReference(0, ObjectType(cls), nullable = false)
+ val serializer = serializerFor(inputObject, schema)
val deserializer = deserializerFor(schema)
new ExpressionEncoder[Row](
schema,
@@ -130,21 +129,28 @@ object RowEncoder {
case StructType(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
- val method = if (f.dataType.isInstanceOf[StructType]) {
- "getStruct"
+ val fieldValue = serializerFor(
+ GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)),
+ f.dataType
+ )
+ if (f.nullable) {
+ If(
+ Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
+ Literal.create(null, f.dataType),
+ fieldValue
+ )
} else {
- "get"
+ fieldValue
}
- If(
- Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
- Literal.create(null, f.dataType),
- serializerFor(
- Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil),
- f.dataType))
}
- If(IsNull(inputObject),
- Literal.create(null, inputType),
- CreateStruct(convertedFields))
+
+ if (inputObject.nullable) {
+ If(IsNull(inputObject),
+ Literal.create(null, inputType),
+ CreateStruct(convertedFields))
+ } else {
+ CreateStruct(convertedFields)
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 523eed825f..dbaff1625e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -688,3 +688,45 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
ev.copy(code = code, isNull = "false", value = childGen.value)
}
}
+
+/**
+ * Returns the value of field at index `index` from the external row `child`.
+ * This class can be viewed as [[GetStructField]] for [[Row]]s instead of [[InternalRow]]s.
+ *
+ * Note that the input row and the field we try to get are both expected to be non-null; if either
+ * is null, a runtime exception will be thrown.
+ */
+case class GetExternalRowField(
+ child: Expression,
+ index: Int,
+ dataType: DataType) extends UnaryExpression with NonSQLExpression {
+
+ override def nullable: Boolean = false
+
+ override def eval(input: InternalRow): Any =
+ throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val row = child.genCode(ctx)
+
+ val getField = dataType match {
+ case ObjectType(x) if x == classOf[Row] => s"""${row.value}.getStruct($index)"""
+ case _ => s"""(${ctx.boxedType(dataType)}) ${row.value}.get($index)"""
+ }
+
+ val code = s"""
+ ${row.code}
+
+ if (${row.isNull}) {
+ throw new RuntimeException("The input external row cannot be null.");
+ }
+
+ if (${row.value}.isNullAt($index)) {
+ throw new RuntimeException("The ${index}th field of input row cannot be null.");
+ }
+
+ final ${ctx.javaType(dataType)} ${ev.value} = $getField;
+ """
+ ev.copy(code = code, isNull = "false")
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index a8fa372b1e..98be3b053d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -160,6 +160,14 @@ class RowEncoderSuite extends SparkFunSuite {
.compareTo(convertedBack.getDecimal(3)) == 0)
}
+ test("RowEncoder should preserve schema nullability") {
+ val schema = new StructType().add("int", IntegerType, nullable = false)
+ val encoder = RowEncoder(schema)
+ assert(encoder.serializer.length == 1)
+ assert(encoder.serializer.head.dataType == IntegerType)
+ assert(encoder.serializer.head.nullable == false)
+ }
+
private def encodeDecodeTest(schema: StructType): Unit = {
test(s"encode/decode: ${schema.simpleString}") {
val encoder = RowEncoder(schema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 68a12b0622..3cb4e52c6d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import scala.language.postfixOps
-import org.apache.spark.sql.catalyst.encoders.OuterScopes
+import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -658,6 +658,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val dataset = Seq(1, 2, 3).toDS()
checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4)
}
+
+ test("runtime null check for RowEncoder") {
+ val schema = new StructType().add("i", IntegerType, nullable = false)
+ val df = sqlContext.range(10).map(l => {
+ if (l % 5 == 0) {
+ Row(null)
+ } else {
+ Row(l)
+ }
+ })(RowEncoder(schema))
+
+ val message = intercept[Exception] {
+ df.collect()
+ }.getMessage
+ assert(message.contains("The 0th field of input row cannot be null"))
+ }
}
case class OtherTuple(_1: String, _2: Int)