author    Pedro Rodriguez <ski.rodriguez@gmail.com>  2015-07-21 00:53:20 -0700
committer Reynold Xin <rxin@databricks.com>  2015-07-21 00:53:20 -0700
commit    560c658a7462844c698b5bda09a4cfb4094fd65b
tree      1efc951d4071fe9cc9c220d28971ee2fbf3bd775
parent    8c8f0ef59e12b6f13d5a0bf2d7bf1248b5c1e369
[SPARK-8230][SQL] Add array/map size method
Pull Request for: https://issues.apache.org/jira/browse/SPARK-8230
The primary issue resolved is implementing array/map size for Spark SQL. The code is ready for review by a committer. Chen Hao is on the JIRA ticket, but I don't know his GitHub username; rxin is also on the JIRA ticket.
Things to review:
1. Where to put the added functions namespace-wise: they seem to belong with a small group of collection operations that includes `sort_array` and `array_contains`, hence the names `collectionOperations.scala` (Scala) and `_collection_functions` (Python).
2. In the Python code, should `size` go in a `1.5.0` function list or in a collections list?
3. Are there any missing methods on the `Size` case class? It looks like many of these functions have generated Java code; is that also needed here?
4. Something else?
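For reference, here is a rough sketch of the intended end-user API (not part of the patch; it assumes a Spark 1.5 `SQLContext` named `sqlContext` is in scope, e.g. in spark-shell):

```scala
import org.apache.spark.sql.functions.size
import sqlContext.implicits._  // assumes a SQLContext named sqlContext (e.g. spark-shell)

// DataFrame API: size() returns the number of elements in an array or map column.
val df = Seq((Seq(1, 2, 3), "x"), (Seq.empty[Int], "y")).toDF("data", "label")
df.select(size($"data")).collect()  // Array(Row(3), Row(0))

// SQL: the same expression is registered under the name "size" in FunctionRegistry.
df.registerTempTable("t")
sqlContext.sql("SELECT size(data) FROM t").collect()  // Array(Row(3), Row(0))
```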
Author: Pedro Rodriguez <ski.rodriguez@gmail.com>
Author: Pedro Rodriguez <prodriguez@trulia.com>
Closes #7462 from EntilZha/SPARK-8230 and squashes the following commits:
9a442ae [Pedro Rodriguez] fixed functions and sorted __all__
9aea3bb [Pedro Rodriguez] removed imports from python docs
15d4bf1 [Pedro Rodriguez] Added null test case and changed to nullSafeCodeGen
d88247c [Pedro Rodriguez] removed python code
bd5f0e4 [Pedro Rodriguez] removed duplicate function from rebase/merge
59931b4 [Pedro Rodriguez] fixed compile bug introduced when merging
c187175 [Pedro Rodriguez] updated code to add size to __all__ directly and removed redundant pretty print
130839f [Pedro Rodriguez] fixed failing test
aa9bade [Pedro Rodriguez] fix style
e093473 [Pedro Rodriguez] updated python code with docs, switched classes/traits implemented, added (failing) expression tests
0449377 [Pedro Rodriguez] refactored code to use better abstract classes/traits and implementations
9a1a2ff [Pedro Rodriguez] added unit tests for map size
2bfbcb6 [Pedro Rodriguez] added unit test for size
20df2b4 [Pedro Rodriguez] Finished working version of size function and added it to python
b503e75 [Pedro Rodriguez] First attempt at implementing size for maps and arrays
99a6a5c [Pedro Rodriguez] fixed failing test
cac75ac [Pedro Rodriguez] fix style
933d843 [Pedro Rodriguez] updated python code with docs, switched classes/traits implemented, added (failing) expression tests
42bb7d4 [Pedro Rodriguez] refactored code to use better abstract classes/traits and implementations
f9c3b8a [Pedro Rodriguez] added unit tests for map size
2515d9f [Pedro Rodriguez] added documentation
0e60541 [Pedro Rodriguez] added unit test for size
acf9853 [Pedro Rodriguez] Finished working version of size function and added it to python
84a5d38 [Pedro Rodriguez] First attempt at implementing size for maps and arrays
6 files changed, 152 insertions(+), 1 deletion(-)
```diff
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3c134faa0a..719e623a1a 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -50,6 +50,7 @@ __all__ = [
     'regexp_replace',
     'sha1',
     'sha2',
+    'size',
     'sparkPartitionId',
     'struct',
     'udf',
@@ -825,6 +826,20 @@ def weekofyear(col):
     return Column(sc._jvm.functions.weekofyear(col))
 
 
+@since(1.5)
+def size(col):
+    """
+    Collection function: returns the length of the array or map stored in the column.
+    :param col: name of column or expression
+
+    >>> df = sqlContext.createDataFrame([([1, 2, 3],),([1],),([],)], ['data'])
+    >>> df.select(size(df.data)).collect()
+    [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.size(_to_java_column(col)))
+
+
 class UserDefinedFunction(object):
     """
     User defined function in Python
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index aec392379c..13523720da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -195,8 +195,10 @@ object FunctionRegistry {
     expression[Quarter]("quarter"),
     expression[Second]("second"),
     expression[WeekOfYear]("weekofyear"),
-    expression[Year]("year")
+    expression[Year]("year"),
+
+    // collection functions
+    expression[Size]("size")
   )
 
   val builtin: FunctionRegistry = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
new file mode 100644
index 0000000000..2d92dcf23a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.types._
+
+/**
+ * Given an array or map, returns its size.
+ */
+case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+  override def dataType: DataType = IntegerType
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
+
+  override def nullSafeEval(value: Any): Int = child.dataType match {
+    case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size
+    case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();")
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
new file mode 100644
index 0000000000..28c41b5716
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+
+class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  test("Array and Map Size") {
+    val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
+    val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
+
+    checkEvaluation(Size(a0), 3)
+    checkEvaluation(Size(a1), 0)
+    checkEvaluation(Size(a2), 2)
+
+    val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType))
+    val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+    val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType))
+
+    checkEvaluation(Size(m0), 2)
+    checkEvaluation(Size(m1), 0)
+    checkEvaluation(Size(m2), 1)
+
+    checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null)
+    checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6d60dae624..60b089180c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -42,6 +42,7 @@ import org.apache.spark.util.Utils
 * @groupname misc_funcs Misc functions
 * @groupname window_funcs Window functions
 * @groupname string_funcs String functions
+ * @groupname collection_funcs Collection functions
 * @groupname Ungrouped Support functions for DataFrames.
 * @since 1.3.0
 */
@@ -2054,6 +2055,25 @@ object functions {
   def weekofyear(columnName: String): Column = weekofyear(Column(columnName))
 
   //////////////////////////////////////////////////////////////////////////////////////////////
+  // Collection functions
+  //////////////////////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * Returns length of array or map
+   * @group collection_funcs
+   * @since 1.5.0
+   */
+  def size(columnName: String): Column = size(Column(columnName))
+
+  /**
+   * Returns length of array or map
+   * @group collection_funcs
+   * @since 1.5.0
+   */
+  def size(column: Column): Column = Size(column.expr)
+
+
+  //////////////////////////////////////////////////////////////////////////////////////////////
   //////////////////////////////////////////////////////////////////////////////////////////////
 
   // scalastyle:off
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 8d2ff2f969..1baec5d376 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -267,4 +267,35 @@ class DataFrameFunctionsSuite extends QueryTest {
     )
   }
 
+  test("array size function") {
+    val df = Seq(
+      (Array[Int](1, 2), "x"),
+      (Array[Int](), "y"),
+      (Array[Int](1, 2, 3), "z")
+    ).toDF("a", "b")
+    checkAnswer(
+      df.select(size("a")),
+      Seq(Row(2), Row(0), Row(3))
+    )
+    checkAnswer(
+      df.selectExpr("size(a)"),
+      Seq(Row(2), Row(0), Row(3))
+    )
+  }
+
+  test("map size function") {
+    val df = Seq(
+      (Map[Int, Int](1 -> 1, 2 -> 2), "x"),
+      (Map[Int, Int](), "y"),
+      (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z")
+    ).toDF("a", "b")
+    checkAnswer(
+      df.select(size("a")),
+      Seq(Row(2), Row(0), Row(3))
+    )
+    checkAnswer(
+      df.selectExpr("size(a)"),
+      Seq(Row(2), Row(0), Row(3))
+    )
+  }
 }
```
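One behavior worth noting for reviewers: because `Size` uses `nullSafeEval`/`nullSafeCodeGen`, evaluation is skipped when the input is null, so a NULL array or map column yields NULL rather than 0 or an exception. A minimal sketch of those semantics (same spark-shell-style `sqlContext` assumption as above):

```scala
import org.apache.spark.sql.functions.size
import sqlContext.implicits._  // assumes a SQLContext named sqlContext (e.g. spark-shell)

// nullSafeEval/nullSafeCodeGen skip evaluation for null inputs, so a NULL
// array column propagates NULL instead of throwing a NullPointerException.
val df = Seq(Tuple1(Seq(1, 2)), Tuple1(null.asInstanceOf[Seq[Int]])).toDF("data")
df.select(size($"data")).collect()  // Array(Row(2), Row(null))
```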