diff options
author | Yijie Shen <henry.yijieshen@gmail.com> | 2015-07-02 10:12:25 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-07-02 10:12:25 -0700 |
commit | 52302a803967114b29a8bf6b74459477364c5b88 (patch) | |
tree | 0a694ba254c0c113eb69277a445898f1421a386d | |
parent | afa021e03f0a1a326be2ed742332845b77f94c55 (diff) | |
download | spark-52302a803967114b29a8bf6b74459477364c5b88.tar.gz spark-52302a803967114b29a8bf6b74459477364c5b88.tar.bz2 spark-52302a803967114b29a8bf6b74459477364c5b88.zip |
[SPARK-8407] [SQL] complex type constructors: struct and named_struct
This is a follow up of [SPARK-8283](https://issues.apache.org/jira/browse/SPARK-8283) ([PR-6828](https://github.com/apache/spark/pull/6828)), to support both `struct` and `named_struct` in Spark SQL.
After [#6725](https://github.com/apache/spark/pull/6828), the semantics of [`CreateStruct`](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala#L56) have changed slightly: its children are no longer limited to `NamedExpression` columns, and non-`NamedExpression` fields are named following the Hive convention (col1, col2, ...).
This PR both loosens [`struct`](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/functions.scala#L723) to accept children of any `Expression` type and adds `named_struct` support.
Author: Yijie Shen <henry.yijieshen@gmail.com>
Closes #6874 from yijieshen/SPARK-8283 and squashes the following commits:
4cd3375ac [Yijie Shen] change struct documentation
d599d0b [Yijie Shen] rebase code
9a7039e [Yijie Shen] fix reviews and regenerate golden answers
b487354 [Yijie Shen] replace assert using checkAnswer
f07e114 [Yijie Shen] tiny fix
9613be9 [Yijie Shen] review fix
7fef712 [Yijie Shen] Fix checkInputTypes' implementation using foldable and nullable
60812a7 [Yijie Shen] Fix type check
828d694 [Yijie Shen] remove unnecessary resolved assertion inside dataType method
fd3cd8e [Yijie Shen] remove type check from eval
7a71255 [Yijie Shen] tiny fix
ccbbd86 [Yijie Shen] Fix reviews
47da332 [Yijie Shen] remove nameStruct API from DataFrame
917e680 [Yijie Shen] Fix reviews
4bd75ad [Yijie Shen] loosen struct method in functions.scala to take Expression children
0acb7be [Yijie Shen] Add CreateNamedStruct in both DataFrame function API and FunctionRegistery
-rw-r--r-- | python/pyspark/sql/functions.py | 1 | ||||
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 1 | ||||
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala | 49 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala | 11 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala | 24 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 11 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 40 | ||||
-rw-r--r-- | sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad (renamed from sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55) | 0 | ||||
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 2 |
9 files changed, 126 insertions, 13 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bccde6083c..12263e6a75 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -467,7 +467,6 @@ def struct(*cols): """Creates a new struct column. :param cols: list of column names (string) or list of :class:`Column` expressions - that are named or aliased. >>> df.select(struct('age', 'name').alias("struct")).collect() [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index aa051b1633..e7e4d1c4ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -96,6 +96,7 @@ object FunctionRegistry { expression[Rand]("rand"), expression[Randn]("randn"), expression[CreateStruct]("struct"), + expression[CreateNamedStruct]("named_struct"), expression[Sqrt]("sqrt"), // math functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67e7dc4ec8..fa70409353 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * 
Returns an Array containing the evaluation of all children expressions. @@ -54,6 +57,8 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) + override lazy val resolved: Boolean = childrenResolved + override lazy val dataType: StructType = { val fields = children.zipWithIndex.map { case (child, idx) => child match { @@ -74,3 +79,47 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def prettyName: String = "struct" } + +/** + * Creates a struct with the given field names and values + * + * @param children Seq(name1, val1, name2, val2, ...) + */ +case class CreateNamedStruct(children: Seq[Expression]) extends Expression { + + private lazy val (nameExprs, valExprs) = + children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + + private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) + + override lazy val dataType: StructType = { + val fields = names.zip(valExprs).map { case (name, valExpr) => + StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + } + StructType(fields) + } + + override def foldable: Boolean = valExprs.forall(_.foldable) + + override def nullable: Boolean = false + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size % 2 != 0) { + TypeCheckResult.TypeCheckFailure("CreateNamedStruct expects an even number of arguments.") + } else { + val invalidNames = + nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable) + if (invalidNames.size != 0) { + TypeCheckResult.TypeCheckFailure( + s"Odd position only allow foldable and not-null StringType expressions, got :" + + s" ${invalidNames.mkString(",")}") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + } + + override def eval(input: InternalRow): Any = { + InternalRow(valExprs.map(_.eval(input)): _*) + } +} diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index bc1537b071..8e0551b23e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -160,4 +160,15 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Explode('intField), "input to function explode should be array or map type") } + + test("check types for CreateNamedStruct") { + assertError( + CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") + assertError( + CreateNamedStruct(Seq(1, "a", "b", 2.0)), + "Odd position only allow foldable and not-null StringType expressions") + assertError( + CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), + "Odd position only allow foldable and not-null StringType expressions") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 3515d044b2..a09014e1ff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import org.scalatest.exceptions.TestFailedException + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -119,11 +121,29 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { test("CreateStruct") { val row = create_row(1, 2, 3) - val c1 = 'a.int.at(0).as("a") - val c3 = 'c.int.at(2).as("c") + val c1 = 'a.int.at(0) 
+ val c3 = 'c.int.at(2) checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row) } + test("CreateNamedStruct") { + val row = InternalRow(1, 2, 3) + val c1 = 'a.int.at(0) + val c3 = 'c.int.at(2) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row) + } + + test("CreateNamedStruct with literal field") { + val row = InternalRow(1, 2, 3) + val c1 = 'a.int.at(0) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), InternalRow(1, "y"), row) + } + + test("CreateNamedStruct from all literal fields") { + checkEvaluation( + CreateNamedStruct(Seq("a", "x", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty) + } + test("test dsl for complex type") { def quickResolve(u: UnresolvedExtractValue): Expression = { ExtractValue(u.child, u.extraction, _ == _) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a5b6828685..4ee1fb8374 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -739,17 +739,18 @@ object functions { def sqrt(colName: String): Column = sqrt(Column(colName)) /** - * Creates a new struct column. The input column must be a column in a [[DataFrame]], or - * a derived column expression that is named (i.e. aliased). + * Creates a new struct column. + * If the input column is a column in a [[DataFrame]], or a derived column expression + * that is named (i.e. aliased), its name would be remained as the StructField's name, + * otherwise, the newly generated StructField's name would be auto generated as col${index + 1}, + * i.e. col1, col2, col3, ... 
* * @group normal_funcs * @since 1.4.0 */ @scala.annotation.varargs def struct(cols: Column*): Column = { - require(cols.forall(_.expr.isInstanceOf[NamedExpression]), - s"struct input columns must all be named or aliased ($cols)") - CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression])) + CreateStruct(cols.map(_.expr)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7ae89bcb1b..0d43aca877 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -79,10 +79,42 @@ class DataFrameFunctionsSuite extends QueryTest { assert(row.getAs[Row](0) === Row(2, "str")) } - test("struct: must use named column expression") { - intercept[IllegalArgumentException] { - struct(col("a") * 2) - } + test("struct with column expression to be automatically named") { + val df = Seq((1, "str")).toDF("a", "b") + val result = df.select(struct((col("a") * 2), col("b"))) + + val expectedType = StructType(Seq( + StructField("col1", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Row(Row(2, "str"))) + } + + test("struct with literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val result = df.select(struct((col("a") * 2), lit(5.0))) + + val expectedType = StructType(Seq( + StructField("col1", IntegerType, nullable = false), + StructField("col2", DoubleType, nullable = false) + )) + + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Seq(Row(Row(2, 5.0)), Row(Row(4, 5.0)))) + } + + test("struct with all literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val result = df.select(struct(lit("v"), lit(5.0))) + + val expectedType = StructType(Seq( + StructField("col1", StringType, 
nullable = false), + StructField("col2", DoubleType, nullable = false) + )) + + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Seq(Row(Row("v", 5.0)), Row(Row("v", 5.0)))) } test("constant functions") { diff --git a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad index 7bc77e7f2a..7bc77e7f2a 100644 --- a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 +++ b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4cdba03b27..991da2f829 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -132,7 +132,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { lower("AA"), "10", repeat(lower("AA"), 3), "11", lower(repeat("AA", 3)), "12", - printf("Bb%d", 12), "13", + printf("bb%d", 12), "13", repeat(printf("s%d", 14), 2), "14") FROM src LIMIT 1""") createQueryTest("NaN to Decimal", |