author     Reynold Xin <rxin@databricks.com>      2015-05-01 12:49:02 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-05-01 12:49:02 -0700
commit  37537760d19eab878a5e1a48641cc49e6cb4b989 (patch)
tree    a37c553f7c27835399dd46a26fafd8dcc1613437 /sql/core/src
parent  16860327286bc08b4e2283d51b4c8fe024ba5006 (diff)
[SPARK-7274] [SQL] Create Column expression for array/struct creation.
Author: Reynold Xin <rxin@databricks.com>

Closes #5802 from rxin/SPARK-7274 and squashes the following commits:

19aecaa [Reynold Xin] Fixed unicode tests.
bfc1538 [Reynold Xin] Export all Python functions.
2517b8c [Reynold Xin] Code review.
23da335 [Reynold Xin] Fixed Python bug.
132002e [Reynold Xin] Fixed tests.
56fce26 [Reynold Xin] Added Python support.
b0d591a [Reynold Xin] Fixed debug error.
86926a6 [Reynold Xin] Added test suite.
7dbb9ab [Reynold Xin] Ok one more.
470e2f5 [Reynold Xin] One more MLlib ...
e2d14f0 [Reynold Xin] [SPARK-7274][SQL] Create Column expression for array/struct creation.
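
A minimal sketch of how the two new functions are meant to be called from the DataFrame API. The DataFrame and column names below are illustrative, not taken from the patch, and it assumes a SQLContext's implicits (for toDF) are in scope:

import org.apache.spark.sql.functions.{array, col, struct}

// Illustrative DataFrame with two integer columns.
val df = Seq((1, 2), (3, 4)).toDF("a", "b")

// array(...) requires all input columns to share the same data type.
val arrays = df.select(array("a", "b").as("arr"))

// struct(...) requires named columns or aliased expressions.
val structs = df.select(struct(col("a"), (col("b") * 2).as("b2")).as("s"))
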
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala               41
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala  84
2 files changed, 124 insertions(+), 1 deletion(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 242e64d3ff..7e283393d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag}
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -284,6 +284,23 @@ object functions {
def abs(e: Column): Column = Abs(e.expr)
/**
+ * Creates a new array column. The input columns must all have the same data type.
+ *
+ * @group normal_funcs
+ */
+ @scala.annotation.varargs
+ def array(cols: Column*): Column = CreateArray(cols.map(_.expr))
+
+ /**
+ * Creates a new array column. The input columns must all have the same data type.
+ *
+ * @group normal_funcs
+ */
+ def array(colName: String, colNames: String*): Column = {
+ array((colName +: colNames).map(col) : _*)
+ }
+
+ /**
* Returns the first column that is not null.
* {{{
* df.select(coalesce(df("a"), df("b")))
@@ -391,6 +408,28 @@ object functions {
def sqrt(e: Column): Column = Sqrt(e.expr)
/**
+ * Creates a new struct column. The input column must be a column in a [[DataFrame]], or
+ * a derived column expression that is named (i.e. aliased).
+ *
+ * @group normal_funcs
+ */
+ @scala.annotation.varargs
+ def struct(cols: Column*): Column = {
+ require(cols.forall(_.expr.isInstanceOf[NamedExpression]),
+ s"struct input columns must all be named or aliased ($cols)")
+ CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression]))
+ }
+
+ /**
+ * Creates a new struct column that composes multiple input columns.
+ *
+ * @group normal_funcs
+ */
+ def struct(colName: String, colNames: String*): Column = {
+ struct((colName +: colNames).map(col) : _*)
+ }
+
+ /**
* Converts a string expression to upper case.
*
* @group normal_funcs
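
A note on the struct(cols: Column*) overload added above: its require() enforces at runtime that every input is a NamedExpression, so a derived expression has to be aliased before it can be used, while the String-based overload always resolves names via col. A brief hedged sketch (column names are illustrative):

import org.apache.spark.sql.functions.{col, struct}

// Rejected by the require() above: col("a") * 2 is a bare arithmetic
// expression, not a NamedExpression, so this throws IllegalArgumentException.
// struct(col("a") * 2)

// Accepted: the derived expression is aliased; plain columns are already named.
val s = struct((col("a") * 2).as("doubled"), col("b"))
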
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
new file mode 100644
index 0000000000..ca03713ef4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.types._
+
+/**
+ * Test suite for functions in [[org.apache.spark.sql.functions]].
+ */
+class DataFrameFunctionsSuite extends QueryTest {
+
+ test("array with column name") {
+ val df = Seq((0, 1)).toDF("a", "b")
+ val row = df.select(array("a", "b")).first()
+
+ val expectedType = ArrayType(IntegerType, containsNull = false)
+ assert(row.schema(0).dataType === expectedType)
+ assert(row.getAs[Seq[Int]](0) === Seq(0, 1))
+ }
+
+ test("array with column expression") {
+ val df = Seq((0, 1)).toDF("a", "b")
+ val row = df.select(array(col("a"), col("b") + col("b"))).first()
+
+ val expectedType = ArrayType(IntegerType, containsNull = false)
+ assert(row.schema(0).dataType === expectedType)
+ assert(row.getAs[Seq[Int]](0) === Seq(0, 2))
+ }
+
+ // Turn this on once we add a rule to the analyzer to throw a friendly exception
+ ignore("array: throw exception if putting columns of different types into an array") {
+ val df = Seq((0, "str")).toDF("a", "b")
+ intercept[AnalysisException] {
+ df.select(array("a", "b"))
+ }
+ }
+
+ test("struct with column name") {
+ val df = Seq((1, "str")).toDF("a", "b")
+ val row = df.select(struct("a", "b")).first()
+
+ val expectedType = StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", StringType)
+ ))
+ assert(row.schema(0).dataType === expectedType)
+ assert(row.getAs[Row](0) === Row(1, "str"))
+ }
+
+ test("struct with column expression") {
+ val df = Seq((1, "str")).toDF("a", "b")
+ val row = df.select(struct((col("a") * 2).as("c"), col("b"))).first()
+
+ val expectedType = StructType(Seq(
+ StructField("c", IntegerType, nullable = false),
+ StructField("b", StringType)
+ ))
+ assert(row.schema(0).dataType === expectedType)
+ assert(row.getAs[Row](0) === Row(2, "str"))
+ }
+
+ test("struct: must use named column expression") {
+ intercept[IllegalArgumentException] {
+ struct(col("a") * 2)
+ }
+ }
+}