aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala17
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala18
2 files changed, 30 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index d986d9dca6..f60d278c54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -252,9 +252,13 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
private lazy val names = nameExprs.map(_.eval(EmptyRow))
override lazy val dataType: StructType = {
- val fields = names.zip(valExprs).map { case (name, valExpr) =>
- StructField(name.asInstanceOf[UTF8String].toString,
- valExpr.dataType, valExpr.nullable, Metadata.empty)
+ val fields = names.zip(valExprs).map {
+ case (name, valExpr: NamedExpression) =>
+ StructField(name.asInstanceOf[UTF8String].toString,
+ valExpr.dataType, valExpr.nullable, valExpr.metadata)
+ case (name, valExpr) =>
+ StructField(name.asInstanceOf[UTF8String].toString,
+ valExpr.dataType, valExpr.nullable, Metadata.empty)
}
StructType(fields)
}
@@ -365,8 +369,11 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression
private lazy val names = nameExprs.map(_.eval(EmptyRow).toString)
override lazy val dataType: StructType = {
- val fields = names.zip(valExprs).map { case (name, valExpr) =>
- StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
+ val fields = names.zip(valExprs).map {
+ case (name, valExpr: NamedExpression) =>
+ StructField(name, valExpr.dataType, valExpr.nullable, valExpr.metadata)
+ case (name, valExpr) =>
+ StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
}
StructType(fields)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 7c009a7360..ec7be4d4b8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -228,4 +228,22 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkErrorMessage(structType, IntegerType, "Field name should be String Literal")
checkErrorMessage(otherType, StringType, "Can't extract value from")
}
+
+ test("ensure to preserve metadata") {
+ val metadata = new MetadataBuilder()
+ .putString("key", "value")
+ .build()
+
+ def checkMetadata(expr: Expression): Unit = {
+ assert(expr.dataType.asInstanceOf[StructType]("a").metadata === metadata)
+ assert(expr.dataType.asInstanceOf[StructType]("b").metadata === Metadata.empty)
+ }
+
+ val a = AttributeReference("a", IntegerType, metadata = metadata)()
+ val b = AttributeReference("b", IntegerType)()
+ checkMetadata(CreateStruct(Seq(a, b)))
+ checkMetadata(CreateNamedStruct(Seq("a", a, "b", b)))
+ checkMetadata(CreateStructUnsafe(Seq(a, b)))
+ checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b)))
+ }
}