diff options
Diffstat (limited to 'sql/catalyst/src')
4 files changed, 83 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index aa051b1633..e7e4d1c4ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -96,6 +96,7 @@ object FunctionRegistry { expression[Rand]("rand"), expression[Randn]("randn"), expression[CreateStruct]("struct"), + expression[CreateNamedStruct]("named_struct"), expression[Sqrt]("sqrt"), // math functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67e7dc4ec8..fa70409353 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Returns an Array containing the evaluation of all children expressions. @@ -54,6 +57,8 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) + override lazy val resolved: Boolean = childrenResolved + override lazy val dataType: StructType = { val fields = children.zipWithIndex.map { case (child, idx) => child match { @@ -74,3 +79,47 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def prettyName: String = "struct" } + +/** + * Creates a struct with the given field names and values + * + * @param children Seq(name1, val1, name2, val2, ...) + */ +case class CreateNamedStruct(children: Seq[Expression]) extends Expression { + + private lazy val (nameExprs, valExprs) = + children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + + private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) + + override lazy val dataType: StructType = { + val fields = names.zip(valExprs).map { case (name, valExpr) => + StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + } + StructType(fields) + } + + override def foldable: Boolean = valExprs.forall(_.foldable) + + override def nullable: Boolean = false + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size % 2 != 0) { + TypeCheckResult.TypeCheckFailure("CreateNamedStruct expects an even number of arguments.") + } else { + val invalidNames = + nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable) + if (invalidNames.size != 0) { + TypeCheckResult.TypeCheckFailure( + s"Odd position only allow foldable and not-null StringType expressions, got :" + + s" ${invalidNames.mkString(",")}") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + } + + override def eval(input: InternalRow): Any = { + InternalRow(valExprs.map(_.eval(input)): _*) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index bc1537b071..8e0551b23e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -160,4 +160,15 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Explode('intField), "input to function explode should be array or map type") } + + test("check types for CreateNamedStruct") { + assertError( + CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") + assertError( + CreateNamedStruct(Seq(1, "a", "b", 2.0)), + "Odd position only allow foldable and not-null StringType expressions") + assertError( + CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), + "Odd position only allow foldable and not-null StringType expressions") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 3515d044b2..a09014e1ff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import org.scalatest.exceptions.TestFailedException + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -119,11 +121,29 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { test("CreateStruct") { val row = create_row(1, 2, 3) - val c1 = 'a.int.at(0).as("a") - val c3 = 'c.int.at(2).as("c") + val c1 = 'a.int.at(0) + val c3 = 'c.int.at(2) checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row) } + test("CreateNamedStruct") { + val row = InternalRow(1, 2, 3) + val c1 = 'a.int.at(0) + val c3 = 'c.int.at(2) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row) + } + + test("CreateNamedStruct with literal field") { + val row = InternalRow(1, 2, 3) + val c1 = 'a.int.at(0) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), InternalRow(1, "y"), row) + } + + test("CreateNamedStruct from all literal fields") { + checkEvaluation( + CreateNamedStruct(Seq("a", "x", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty) + } + test("test dsl for complex type") { def quickResolve(u: UnresolvedExtractValue): Expression = { ExtractValue(u.child, u.extraction, _ == _) |