diff options
author | Takuya UESHIN <ueshin@happy-camper.st> | 2016-05-20 09:34:55 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-05-20 09:34:55 -0700 |
commit | d2e1aa97ef5bf7cfffc777a178f44ab8fa775266 (patch) | |
tree | e4ec3fc987637f5745e4997addf3e95a516da7a3 | |
parent | 9a9c6f5c22248c5a891e9d3b788ff12b6b4718b2 (diff) | |
download | spark-d2e1aa97ef5bf7cfffc777a178f44ab8fa775266.tar.gz spark-d2e1aa97ef5bf7cfffc777a178f44ab8fa775266.tar.bz2 spark-d2e1aa97ef5bf7cfffc777a178f44ab8fa775266.zip |
[SPARK-15308][SQL] RowEncoder should preserve nested column name.
## What changes were proposed in this pull request?
The following code generates the wrong schema:
```
val schema = new StructType().add(
  "struct",
  new StructType()
    .add("i", IntegerType, nullable = false)
    .add(
      "s",
      new StructType().add("int", IntegerType, nullable = false),
      nullable = false),
  nullable = false)
val ds = sqlContext.range(10).map(l => Row(l, Row(l)))(RowEncoder(schema))
ds.printSchema()
```
This should print as follows:
```
root
|-- struct: struct (nullable = false)
| |-- i: integer (nullable = false)
| |-- s: struct (nullable = false)
| | |-- int: integer (nullable = false)
```
but the result is:
```
root
|-- struct: struct (nullable = false)
| |-- col1: integer (nullable = false)
| |-- col2: struct (nullable = false)
| | |-- col1: integer (nullable = false)
```
This PR fixes `RowEncoder` so that it preserves nested column names.
## How was this patch tested?
Ran the existing tests and added a new test that checks whether `RowEncoder` preserves nested column names.
Author: Takuya UESHIN <ueshin@happy-camper.st>
Closes #13090 from ueshin/issues/SPARK-15308.
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 22 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 22 |
2 files changed, 34 insertions, 10 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 71b39c54fa..2f8ba33f35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -62,7 +62,7 @@ object RowEncoder { new ExpressionEncoder[Row]( schema, flat = false, - serializer.asInstanceOf[CreateStruct].children, + serializer.asInstanceOf[CreateNamedStruct].flatten, deserializer, ClassTag(cls)) } @@ -148,28 +148,30 @@ object RowEncoder { dataType = t) case StructType(fields) => - val convertedFields = fields.zipWithIndex.map { case (f, i) => + val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) => val fieldValue = serializerFor( - GetExternalRowField(inputObject, i, f.name, externalDataTypeForInput(f.dataType)), - f.dataType + GetExternalRowField( + inputObject, index, field.name, externalDataTypeForInput(field.dataType)), + field.dataType ) - if (f.nullable) { + val convertedField = if (field.nullable) { If( - Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), - Literal.create(null, f.dataType), + Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil), + Literal.create(null, field.dataType), fieldValue ) } else { fieldValue } - } + Literal(field.name) :: convertedField :: Nil + }) if (inputObject.nullable) { If(IsNull(inputObject), Literal.create(null, inputType), - CreateStruct(convertedFields)) + nonNullOutput) } else { - CreateStruct(convertedFields) + nonNullOutput } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 7bb006c173..39fcc7225b 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -185,6 +185,28 @@ class RowEncoderSuite extends SparkFunSuite { assert(encoder.serializer.head.nullable == false) } + test("RowEncoder should preserve nested column name") { + val schema = new StructType().add( + "struct", + new StructType() + .add("i", IntegerType, nullable = false) + .add( + "s", + new StructType().add("int", IntegerType, nullable = false), + nullable = false), + nullable = false) + val encoder = RowEncoder(schema) + assert(encoder.serializer.length == 1) + assert(encoder.serializer.head.dataType == + new StructType() + .add("i", IntegerType, nullable = false) + .add( + "s", + new StructType().add("int", IntegerType, nullable = false), + nullable = false)) + assert(encoder.serializer.head.nullable == false) + } + test("RowEncoder should support array as the external type for ArrayType") { val schema = new StructType() .add("array", ArrayType(IntegerType)) |