aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPedro Rodriguez <ski.rodriguez@gmail.com>2015-07-21 00:53:20 -0700
committerReynold Xin <rxin@databricks.com>2015-07-21 00:53:20 -0700
commit560c658a7462844c698b5bda09a4cfb4094fd65b (patch)
tree1efc951d4071fe9cc9c220d28971ee2fbf3bd775
parent8c8f0ef59e12b6f13d5a0bf2d7bf1248b5c1e369 (diff)
downloadspark-560c658a7462844c698b5bda09a4cfb4094fd65b.tar.gz
spark-560c658a7462844c698b5bda09a4cfb4094fd65b.tar.bz2
spark-560c658a7462844c698b5bda09a4cfb4094fd65b.zip
[SPARK-8230][SQL] Add array/map size method
Pull Request for: https://issues.apache.org/jira/browse/SPARK-8230 Primary issue resolved is to implement array/map size for Spark SQL. Code is ready for review by a committer. Chen Hao is on the JIRA ticket, but I don't know his username on github, rxin is also on JIRA ticket. Things to review: 1. Where to put added functions namespace wise, they seem to be part of a few operations on collections which includes `sort_array` and `array_contains`. Hence the name given `collectionOperations.scala` and `_collection_functions` in python. 2. In Python code, should it be in a `1.5.0` function array or in a collections array? 3. Are there any missing methods on the `Size` case class? Looks like many of these functions have generated Java code, is that also needed in this case? 4. Something else? Author: Pedro Rodriguez <ski.rodriguez@gmail.com> Author: Pedro Rodriguez <prodriguez@trulia.com> Closes #7462 from EntilZha/SPARK-8230 and squashes the following commits: 9a442ae [Pedro Rodriguez] fixed functions and sorted __all__ 9aea3bb [Pedro Rodriguez] removed imports from python docs 15d4bf1 [Pedro Rodriguez] Added null test case and changed to nullSafeCodeGen d88247c [Pedro Rodriguez] removed python code bd5f0e4 [Pedro Rodriguez] removed duplicate function from rebase/merge 59931b4 [Pedro Rodriguez] fixed compile bug instroduced when merging c187175 [Pedro Rodriguez] updated code to add size to __all__ directly and removed redundent pretty print 130839f [Pedro Rodriguez] fixed failing test aa9bade [Pedro Rodriguez] fix style e093473 [Pedro Rodriguez] updated python code with docs, switched classes/traits implemented, added (failing) expression tests 0449377 [Pedro Rodriguez] refactored code to use better abstract classes/traits and implementations 9a1a2ff [Pedro Rodriguez] added unit tests for map size 2bfbcb6 [Pedro Rodriguez] added unit test for size 20df2b4 [Pedro Rodriguez] Finished working version of size function and added it to python b503e75 [Pedro Rodriguez] First attempt at implementing size for maps and arrays 99a6a5c [Pedro Rodriguez] fixed failing test cac75ac [Pedro Rodriguez] fix style 933d843 [Pedro Rodriguez] updated python code with docs, switched classes/traits implemented, added (failing) expression tests 42bb7d4 [Pedro Rodriguez] refactored code to use better abstract classes/traits and implementations f9c3b8a [Pedro Rodriguez] added unit tests for map size 2515d9f [Pedro Rodriguez] added documentation 0e60541 [Pedro Rodriguez] added unit test for size acf9853 [Pedro Rodriguez] Finished working version of size function and added it to python 84a5d38 [Pedro Rodriguez] First attempt at implementing size for maps and arrays
-rw-r--r--python/pyspark/sql/functions.py15
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala37
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala46
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala20
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala31
6 files changed, 152 insertions, 1 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3c134faa0a..719e623a1a 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -50,6 +50,7 @@ __all__ = [
'regexp_replace',
'sha1',
'sha2',
+ 'size',
'sparkPartitionId',
'struct',
'udf',
@@ -825,6 +826,20 @@ def weekofyear(col):
return Column(sc._jvm.functions.weekofyear(col))
+@since(1.5)
+def size(col):
+ """
+ Collection function: returns the length of the array or map stored in the column.
+ :param col: name of column or expression
+
+ >>> df = sqlContext.createDataFrame([([1, 2, 3],),([1],),([],)], ['data'])
+ >>> df.select(size(df.data)).collect()
+ [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.size(_to_java_column(col)))
+
+
class UserDefinedFunction(object):
"""
User defined function in Python
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index aec392379c..13523720da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -195,8 +195,10 @@ object FunctionRegistry {
expression[Quarter]("quarter"),
expression[Second]("second"),
expression[WeekOfYear]("weekofyear"),
- expression[Year]("year")
+ expression[Year]("year"),
+ // collection functions
+ expression[Size]("size")
)
val builtin: FunctionRegistry = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
new file mode 100644
index 0000000000..2d92dcf23a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.types._
+
+/**
+ * Given an array or map, returns its size.
+ */
+case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+ override def dataType: DataType = IntegerType
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
+
+ override def nullSafeEval(value: Any): Int = child.dataType match {
+ case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size
+ case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();")
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
new file mode 100644
index 0000000000..28c41b5716
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+
+class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ test("Array and Map Size") {
+ val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
+ val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
+
+ checkEvaluation(Size(a0), 3)
+ checkEvaluation(Size(a1), 0)
+ checkEvaluation(Size(a2), 2)
+
+ val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType))
+ val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+ val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType))
+
+ checkEvaluation(Size(m0), 2)
+ checkEvaluation(Size(m1), 0)
+ checkEvaluation(Size(m2), 1)
+
+ checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null)
+ checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6d60dae624..60b089180c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -42,6 +42,7 @@ import org.apache.spark.util.Utils
* @groupname misc_funcs Misc functions
* @groupname window_funcs Window functions
* @groupname string_funcs String functions
+ * @groupname collection_funcs Collection functions
* @groupname Ungrouped Support functions for DataFrames.
* @since 1.3.0
*/
@@ -2054,6 +2055,25 @@ object functions {
def weekofyear(columnName: String): Column = weekofyear(Column(columnName))
//////////////////////////////////////////////////////////////////////////////////////////////
+ // Collection functions
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Returns length of array or map
+ * @group collection_funcs
+ * @since 1.5.0
+ */
+ def size(columnName: String): Column = size(Column(columnName))
+
+ /**
+ * Returns length of array or map
+ * @group collection_funcs
+ * @since 1.5.0
+ */
+ def size(column: Column): Column = Size(column.expr)
+
+
+ //////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////
// scalastyle:off
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 8d2ff2f969..1baec5d376 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -267,4 +267,35 @@ class DataFrameFunctionsSuite extends QueryTest {
)
}
+ test("array size function") {
+ val df = Seq(
+ (Array[Int](1, 2), "x"),
+ (Array[Int](), "y"),
+ (Array[Int](1, 2, 3), "z")
+ ).toDF("a", "b")
+ checkAnswer(
+ df.select(size("a")),
+ Seq(Row(2), Row(0), Row(3))
+ )
+ checkAnswer(
+ df.selectExpr("size(a)"),
+ Seq(Row(2), Row(0), Row(3))
+ )
+ }
+
+ test("map size function") {
+ val df = Seq(
+ (Map[Int, Int](1 -> 1, 2 -> 2), "x"),
+ (Map[Int, Int](), "y"),
+ (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z")
+ ).toDF("a", "b")
+ checkAnswer(
+ df.select(size("a")),
+ Seq(Row(2), Row(0), Row(3))
+ )
+ checkAnswer(
+ df.selectExpr("size(a)"),
+ Seq(Row(2), Row(0), Row(3))
+ )
+ }
}