aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
author: Takuya UESHIN <ueshin@happy-camper.st> 2016-05-20 09:34:55 -0700
committer: Reynold Xin <rxin@databricks.com> 2016-05-20 09:34:55 -0700
commit: d2e1aa97ef5bf7cfffc777a178f44ab8fa775266 (patch)
tree: e4ec3fc987637f5745e4997addf3e95a516da7a3 /sql
parent: 9a9c6f5c22248c5a891e9d3b788ff12b6b4718b2 (diff)
download: spark-d2e1aa97ef5bf7cfffc777a178f44ab8fa775266.tar.gz
spark-d2e1aa97ef5bf7cfffc777a178f44ab8fa775266.tar.bz2
spark-d2e1aa97ef5bf7cfffc777a178f44ab8fa775266.zip
[SPARK-15308][SQL] RowEncoder should preserve nested column name.
## What changes were proposed in this pull request?

The following code generates wrong schema:

```
val schema = new StructType().add(
  "struct",
  new StructType()
    .add("i", IntegerType, nullable = false)
    .add(
      "s",
      new StructType().add("int", IntegerType, nullable = false),
      nullable = false),
  nullable = false)
val ds = sqlContext.range(10).map(l => Row(l, Row(l)))(RowEncoder(schema))
ds.printSchema()
```

This should print as follows:

```
root
 |-- struct: struct (nullable = false)
 |    |-- i: integer (nullable = false)
 |    |-- s: struct (nullable = false)
 |    |    |-- int: integer (nullable = false)
```

but the result is:

```
root
 |-- struct: struct (nullable = false)
 |    |-- col1: integer (nullable = false)
 |    |-- col2: struct (nullable = false)
 |    |    |-- col1: integer (nullable = false)
```

This PR fixes `RowEncoder` to preserve nested column name.

## How was this patch tested?

Existing tests and I added a test to check if `RowEncoder` preserves nested column name.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #13090 from ueshin/issues/SPARK-15308.
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 22
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 22
2 files changed, 34 insertions, 10 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 71b39c54fa..2f8ba33f35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -62,7 +62,7 @@ object RowEncoder {
new ExpressionEncoder[Row](
schema,
flat = false,
- serializer.asInstanceOf[CreateStruct].children,
+ serializer.asInstanceOf[CreateNamedStruct].flatten,
deserializer,
ClassTag(cls))
}
@@ -148,28 +148,30 @@ object RowEncoder {
dataType = t)
case StructType(fields) =>
- val convertedFields = fields.zipWithIndex.map { case (f, i) =>
+ val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>
val fieldValue = serializerFor(
- GetExternalRowField(inputObject, i, f.name, externalDataTypeForInput(f.dataType)),
- f.dataType
+ GetExternalRowField(
+ inputObject, index, field.name, externalDataTypeForInput(field.dataType)),
+ field.dataType
)
- if (f.nullable) {
+ val convertedField = if (field.nullable) {
If(
- Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
- Literal.create(null, f.dataType),
+ Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil),
+ Literal.create(null, field.dataType),
fieldValue
)
} else {
fieldValue
}
- }
+ Literal(field.name) :: convertedField :: Nil
+ })
if (inputObject.nullable) {
If(IsNull(inputObject),
Literal.create(null, inputType),
- CreateStruct(convertedFields))
+ nonNullOutput)
} else {
- CreateStruct(convertedFields)
+ nonNullOutput
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 7bb006c173..39fcc7225b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -185,6 +185,28 @@ class RowEncoderSuite extends SparkFunSuite {
assert(encoder.serializer.head.nullable == false)
}
+ test("RowEncoder should preserve nested column name") {
+ val schema = new StructType().add(
+ "struct",
+ new StructType()
+ .add("i", IntegerType, nullable = false)
+ .add(
+ "s",
+ new StructType().add("int", IntegerType, nullable = false),
+ nullable = false),
+ nullable = false)
+ val encoder = RowEncoder(schema)
+ assert(encoder.serializer.length == 1)
+ assert(encoder.serializer.head.dataType ==
+ new StructType()
+ .add("i", IntegerType, nullable = false)
+ .add(
+ "s",
+ new StructType().add("int", IntegerType, nullable = false),
+ nullable = false))
+ assert(encoder.serializer.head.nullable == false)
+ }
+
test("RowEncoder should support array as the external type for ArrayType") {
val schema = new StructType()
.add("array", ArrayType(IntegerType))