diff options
author | Pedro Rodriguez <prodriguez@trulia.com> | 2015-08-04 22:32:21 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-08-04 22:35:45 -0700 |
commit | 28bb977302ff3077c82bb8ee7518eb36bddaf2b3 (patch) | |
tree | 4e7c2a7811ddd6d03c1f69dacee38c161a801293 | |
parent | bca196754ddf2ccd057d775bd5c3f7d3e5657e6f (diff) | |
download | spark-28bb977302ff3077c82bb8ee7518eb36bddaf2b3.tar.gz spark-28bb977302ff3077c82bb8ee7518eb36bddaf2b3.tar.bz2 spark-28bb977302ff3077c82bb8ee7518eb36bddaf2b3.zip |
[SPARK-8231] [SQL] Add array_contains
This PR is based on #7580 , thanks to EntilZha
PR for work on https://issues.apache.org/jira/browse/SPARK-8231
Currently, I have an initial implementation for contains. Based on discussion on JIRA, it should behave same as Hive: https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFArrayContains.java#L102-L128
Main points are:
1. If the array is empty, null, or the value is null, return false
2. If there is a type mismatch, throw error
3. If comparison is not supported, throw error
Closes #7580
Author: Pedro Rodriguez <prodriguez@trulia.com>
Author: Pedro Rodriguez <ski.rodriguez@gmail.com>
Author: Davies Liu <davies@databricks.com>
Closes #7949 from davies/array_contains and squashes the following commits:
d3c08bc [Davies Liu] use foreach() to avoid copy
bc3d1fe [Davies Liu] fix array_contains
719e37d [Davies Liu] Merge branch 'master' of github.com:apache/spark into array_contains
e352cf9 [Pedro Rodriguez] fixed diff from master
4d5b0ff [Pedro Rodriguez] added docs and another type check
ffc0591 [Pedro Rodriguez] fixed unit test
7a22deb [Pedro Rodriguez] Changed test to use strings instead of long/ints which are different between python 2 an 3
b5ffae8 [Pedro Rodriguez] fixed pyspark test
4e7dce3 [Pedro Rodriguez] added more docs
3082399 [Pedro Rodriguez] fixed unit test
46f9789 [Pedro Rodriguez] reverted change
d3ca013 [Pedro Rodriguez] Fixed type checking to match hive behavior, then added tests to insure this
8528027 [Pedro Rodriguez] added more tests
686e029 [Pedro Rodriguez] fix scala style
d262e9d [Pedro Rodriguez] reworked type checking code and added more tests
2517a58 [Pedro Rodriguez] removed unused import
28b4f71 [Pedro Rodriguez] fixed bug with type conversions and re-added tests
12f8795 [Pedro Rodriguez] fix scala style checks
e8a20a9 [Pedro Rodriguez] added python df (broken atm)
65b562c [Pedro Rodriguez] made array_contains nullable false
33b45aa [Pedro Rodriguez] reordered test
9623c64 [Pedro Rodriguez] fixed test
4b4425b [Pedro Rodriguez] changed Arrays in tests to Seqs
72cb4b1 [Pedro Rodriguez] added checkInputTypes and docs
69c46fb [Pedro Rodriguez] added tests and codegen
9e0bfc4 [Pedro Rodriguez] initial attempt at implementation
6 files changed, 163 insertions, 5 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e65b14dc0e..9f0d71d796 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1311,6 +1311,23 @@ def array(*cols): return Column(jc) +@since(1.5) +def array_contains(col, value): + """ + Collection function: returns True if the array contains the given value. The collection + elements and value must be of the same type. + + :param col: name of column containing array + :param value: value to check for in array + + >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df.select(array_contains(df.data, "a")).collect() + [Row(array_contains(data,a)=True), Row(array_contains(data,a)=False)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 43e3e9b910..94c355f838 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -240,6 +240,7 @@ object FunctionRegistry { // collection functions expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayContains]("array_contains"), // misc functions expression[Crc32]("crc32"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6ccb56578f..646afa4047 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.expressions.codegen.{ + CodegenFallback, CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ /** @@ -115,3 +116,76 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } + +/** + * Checks if the array (left) has the element (right) + */ +case class ArrayContains(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = right.dataType match { + case NullType => Seq() + case _ => left.dataType match { + case n @ ArrayType(element, _) => Seq(n, element) + case _ => Seq() + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (right.dataType == NullType) { + TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") + } else if (!left.dataType.isInstanceOf[ArrayType] + || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + "Arguments must be an array followed by a value of same type as the array members") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Boolean = { + val arr = left.eval(input) + if (arr == null) { + false + } else { + val value = right.eval(input) + if (value == null) { + false + } else { + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == value) return true + ) + false + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arrGen = left.gen(ctx) + val elementGen = right.gen(ctx) + val i = ctx.freshName("i") + val getValue = ctx.getValue(arrGen.primitive, right.dataType, i) + s""" + ${arrGen.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = false; + if (!${arrGen.isNull}) { + ${elementGen.code} + if (!${elementGen.isNull}) { + for (int $i = 0; $i < ${arrGen.primitive}.numElements(); $i ++) { + if (${ctx.genEqual(right.dataType, elementGen.primitive, getValue)}) { + ${ev.primitive} = true; + break; + } + } + } + } + """ + } + + override def prettyName: String = "array_contains" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 2c7e85c446..95f0e38212 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -65,4 +65,19 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, ArrayType(StringType)), null) } + + test("Array contains") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + + checkEvaluation(ArrayContains(a0, Literal(1)), true) + checkEvaluation(ArrayContains(a0, Literal(0)), false) + checkEvaluation(ArrayContains(a0, Literal(null)), false) + + checkEvaluation(ArrayContains(a1, Literal("")), true) + checkEvaluation(ArrayContains(a1, Literal(null)), false) + + checkEvaluation(ArrayContains(a2, Literal(null)), false) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bff7017254..5a10c3891a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2121,6 +2121,14 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** + * Returns true if the array contain the value + * @group collection_funcs + * @since 1.5.0 + */ + def array_contains(column: Column, value: Any): Column = + ArrayContains(column.expr, Literal(value)) + + /** * Creates a new row for each element in the given array or map column. * * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6137527757..03116a374f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -323,9 +323,9 @@ class DataFrameFunctionsSuite extends QueryTest { test("array size function") { val df = Seq( - (Array[Int](1, 2), "x"), - (Array[Int](), "y"), - (Array[Int](1, 2, 3), "z") + (Seq[Int](1, 2), "x"), + (Seq[Int](), "y"), + (Seq[Int](1, 2, 3), "z") ).toDF("a", "b") checkAnswer( df.select(size($"a")), @@ -352,4 +352,47 @@ class DataFrameFunctionsSuite extends QueryTest { Seq(Row(2), Row(0), Row(3)) ) } + + test("array contains function") { + val df = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "x") + ).toDF("a", "b") + + // Simple test cases + checkAnswer( + df.select(array_contains(df("a"), 1)), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(a, 1)"), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.select(array_contains(array(lit(2), lit(null)), 1)), + Seq(Row(false), Row(false)) + ) + + // In hive, this errors because null has no type information + intercept[AnalysisException] { + df.select(array_contains(df("a"), null)) + } + intercept[AnalysisException] { + df.selectExpr("array_contains(a, null)") + } + intercept[AnalysisException] { + df.selectExpr("array_contains(null, 1)") + } + + // In hive, if either argument has a matching type has a null value, return false, even if + // the first argument array contains a null and the second argument is null + checkAnswer( + df.selectExpr("array_contains(array(array(1), null)[1], 1)"), + Seq(Row(false), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(array(0, null), array(1, null)[1])"), + Seq(Row(false), Row(false)) + ) + } } |