diff options
author | Takuya UESHIN <ueshin@happy-camper.st> | 2016-05-20 09:38:34 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-05-20 09:38:34 -0700 |
commit | 2cbe96e64d5f84474b2eb59bed9ce3ab543d8aff (patch) | |
tree | b4d345f6f6693bddd2291f23b8cee47f1bbb111a /sql/catalyst | |
parent | e8adc552df80af413e1d31b020489612d13a8770 (diff) | |
download | spark-2cbe96e64d5f84474b2eb59bed9ce3ab543d8aff.tar.gz spark-2cbe96e64d5f84474b2eb59bed9ce3ab543d8aff.tar.bz2 spark-2cbe96e64d5f84474b2eb59bed9ce3ab543d8aff.zip |
[SPARK-15400][SQL] CreateNamedStruct and CreateNamedStructUnsafe should preserve metadata of value expressions if it is NamedExpression.
## What changes were proposed in this pull request?
`CreateNamedStruct` and `CreateNamedStructUnsafe` should preserve metadata of value expressions if it is `NamedExpression` like `CreateStruct` or `CreateStructUnsafe` are doing.
## How was this patch tested?
Existing tests.
Author: Takuya UESHIN <ueshin@happy-camper.st>
Closes #13193 from ueshin/issues/SPARK-15400.
Diffstat (limited to 'sql/catalyst')
2 files changed, 30 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index d986d9dca6..f60d278c54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -252,9 +252,13 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val names = nameExprs.map(_.eval(EmptyRow)) override lazy val dataType: StructType = { - val fields = names.zip(valExprs).map { case (name, valExpr) => - StructField(name.asInstanceOf[UTF8String].toString, - valExpr.dataType, valExpr.nullable, Metadata.empty) + val fields = names.zip(valExprs).map { + case (name, valExpr: NamedExpression) => + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.dataType, valExpr.nullable, valExpr.metadata) + case (name, valExpr) => + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.dataType, valExpr.nullable, Metadata.empty) } StructType(fields) } @@ -365,8 +369,11 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) override lazy val dataType: StructType = { - val fields = names.zip(valExprs).map { case (name, valExpr) => - StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + val fields = names.zip(valExprs).map { + case (name, valExpr: NamedExpression) => + StructField(name, valExpr.dataType, valExpr.nullable, valExpr.metadata) + case (name, valExpr) => + StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) } StructType(fields) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 7c009a7360..ec7be4d4b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -228,4 +228,22 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkErrorMessage(structType, IntegerType, "Field name should be String Literal") checkErrorMessage(otherType, StringType, "Can't extract value from") } + + test("ensure to preserve metadata") { + val metadata = new MetadataBuilder() + .putString("key", "value") + .build() + + def checkMetadata(expr: Expression): Unit = { + assert(expr.dataType.asInstanceOf[StructType]("a").metadata === metadata) + assert(expr.dataType.asInstanceOf[StructType]("b").metadata === Metadata.empty) + } + + val a = AttributeReference("a", IntegerType, metadata = metadata)() + val b = AttributeReference("b", IntegerType)() + checkMetadata(CreateStruct(Seq(a, b))) + checkMetadata(CreateNamedStruct(Seq("a", a, "b", b))) + checkMetadata(CreateStructUnsafe(Seq(a, b))) + checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b))) + } } |