diff options
author | Yijie Shen <henry.yijieshen@gmail.com> | 2015-07-02 10:12:25 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-07-02 10:12:25 -0700 |
commit | 52302a803967114b29a8bf6b74459477364c5b88 (patch) | |
tree | 0a694ba254c0c113eb69277a445898f1421a386d /sql/core | |
parent | afa021e03f0a1a326be2ed742332845b77f94c55 (diff) | |
download | spark-52302a803967114b29a8bf6b74459477364c5b88.tar.gz spark-52302a803967114b29a8bf6b74459477364c5b88.tar.bz2 spark-52302a803967114b29a8bf6b74459477364c5b88.zip |
[SPARK-8407] [SQL] complex type constructors: struct and named_struct
This is a follow up of [SPARK-8283](https://issues.apache.org/jira/browse/SPARK-8283) ([PR-6828](https://github.com/apache/spark/pull/6828)), to support both `struct` and `named_struct` in Spark SQL.
After [#6725](https://github.com/apache/spark/pull/6828), the semantic of [`CreateStruct`](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala#L56) methods have changed a little and do not limited to cols of `NamedExpressions`, it will name non-NamedExpression fields following the hive convention, col1, col2 ...
This PR would both loosen [`struct`](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/functions.scala#L723) to take children of `Expression` type and add `named_struct` support.
Author: Yijie Shen <henry.yijieshen@gmail.com>
Closes #6874 from yijieshen/SPARK-8283 and squashes the following commits:
4cd3375ac [Yijie Shen] change struct documentation
d599d0b [Yijie Shen] rebase code
9a7039e [Yijie Shen] fix reviews and regenerate golden answers
b487354 [Yijie Shen] replace assert using checkAnswer
f07e114 [Yijie Shen] tiny fix
9613be9 [Yijie Shen] review fix
7fef712 [Yijie Shen] Fix checkInputTypes' implementation using foldable and nullable
60812a7 [Yijie Shen] Fix type check
828d694 [Yijie Shen] remove unnecessary resolved assertion inside dataType method
fd3cd8e [Yijie Shen] remove type check from eval
7a71255 [Yijie Shen] tiny fix
ccbbd86 [Yijie Shen] Fix reviews
47da332 [Yijie Shen] remove nameStruct API from DataFrame
917e680 [Yijie Shen] Fix reviews
4bd75ad [Yijie Shen] loosen struct method in functions.scala to take Expression children
0acb7be [Yijie Shen] Add CreateNamedStruct in both DataFrame function API and FunctionRegistery
Diffstat (limited to 'sql/core')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 11 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 40 |
2 files changed, 42 insertions, 9 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a5b6828685..4ee1fb8374 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -739,17 +739,18 @@ object functions { def sqrt(colName: String): Column = sqrt(Column(colName)) /** - * Creates a new struct column. The input column must be a column in a [[DataFrame]], or - * a derived column expression that is named (i.e. aliased). + * Creates a new struct column. + * If the input column is a column in a [[DataFrame]], or a derived column expression + * that is named (i.e. aliased), its name would be remained as the StructField's name, + * otherwise, the newly generated StructField's name would be auto generated as col${index + 1}, + * i.e. col1, col2, col3, ... * * @group normal_funcs * @since 1.4.0 */ @scala.annotation.varargs def struct(cols: Column*): Column = { - require(cols.forall(_.expr.isInstanceOf[NamedExpression]), - s"struct input columns must all be named or aliased ($cols)") - CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression])) + CreateStruct(cols.map(_.expr)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7ae89bcb1b..0d43aca877 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -79,10 +79,42 @@ class DataFrameFunctionsSuite extends QueryTest { assert(row.getAs[Row](0) === Row(2, "str")) } - test("struct: must use named column expression") { - intercept[IllegalArgumentException] { - struct(col("a") * 2) - } + test("struct with column expression to be automatically named") { + val df = Seq((1, "str")).toDF("a", "b") + val result = df.select(struct((col("a") * 2), col("b"))) + + val expectedType = StructType(Seq( + StructField("col1", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Row(Row(2, "str"))) + } + + test("struct with literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val result = df.select(struct((col("a") * 2), lit(5.0))) + + val expectedType = StructType(Seq( + StructField("col1", IntegerType, nullable = false), + StructField("col2", DoubleType, nullable = false) + )) + + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Seq(Row(Row(2, 5.0)), Row(Row(4, 5.0)))) + } + + test("struct with all literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val result = df.select(struct(lit("v"), lit(5.0))) + + val expectedType = StructType(Seq( + StructField("col1", StringType, nullable = false), + StructField("col2", DoubleType, nullable = false) + )) + + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Seq(Row(Row("v", 5.0)), Row(Row("v", 5.0)))) } test("constant functions") { |