diff options
author | Pedro Rodriguez <prodriguez@trulia.com> | 2015-08-04 22:32:21 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-08-04 22:34:02 -0700 |
commit | d34548587ab55bc2136c8f823b9e6ae96e1355a4 (patch) | |
tree | 5a1158d13c761f945742d8c6736b6afdfe198ca9 /sql/catalyst | |
parent | a02bcf20c4fc9e2e182630d197221729e996afc2 (diff) | |
download | spark-d34548587ab55bc2136c8f823b9e6ae96e1355a4.tar.gz spark-d34548587ab55bc2136c8f823b9e6ae96e1355a4.tar.bz2 spark-d34548587ab55bc2136c8f823b9e6ae96e1355a4.zip |
[SPARK-8231] [SQL] Add array_contains
This PR is based on #7580, thanks to EntilZha
PR for work on https://issues.apache.org/jira/browse/SPARK-8231
Currently, I have an initial implementation for contains. Based on discussion on JIRA, it should behave the same as Hive: https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFArrayContains.java#L102-L128
Main points are:
1. If the array is empty, null, or the value is null, return false
2. If there is a type mismatch, throw error
3. If comparison is not supported, throw error
Closes #7580
Author: Pedro Rodriguez <prodriguez@trulia.com>
Author: Pedro Rodriguez <ski.rodriguez@gmail.com>
Author: Davies Liu <davies@databricks.com>
Closes #7949 from davies/array_contains and squashes the following commits:
d3c08bc [Davies Liu] use foreach() to avoid copy
bc3d1fe [Davies Liu] fix array_contains
719e37d [Davies Liu] Merge branch 'master' of github.com:apache/spark into array_contains
e352cf9 [Pedro Rodriguez] fixed diff from master
4d5b0ff [Pedro Rodriguez] added docs and another type check
ffc0591 [Pedro Rodriguez] fixed unit test
7a22deb [Pedro Rodriguez] Changed test to use strings instead of long/ints which are different between python 2 an 3
b5ffae8 [Pedro Rodriguez] fixed pyspark test
4e7dce3 [Pedro Rodriguez] added more docs
3082399 [Pedro Rodriguez] fixed unit test
46f9789 [Pedro Rodriguez] reverted change
d3ca013 [Pedro Rodriguez] Fixed type checking to match hive behavior, then added tests to insure this
8528027 [Pedro Rodriguez] added more tests
686e029 [Pedro Rodriguez] fix scala style
d262e9d [Pedro Rodriguez] reworked type checking code and added more tests
2517a58 [Pedro Rodriguez] removed unused import
28b4f71 [Pedro Rodriguez] fixed bug with type conversions and re-added tests
12f8795 [Pedro Rodriguez] fix scala style checks
e8a20a9 [Pedro Rodriguez] added python df (broken atm)
65b562c [Pedro Rodriguez] made array_contains nullable false
33b45aa [Pedro Rodriguez] reordered test
9623c64 [Pedro Rodriguez] fixed test
4b4425b [Pedro Rodriguez] changed Arrays in tests to Seqs
72cb4b1 [Pedro Rodriguez] added checkInputTypes and docs
69c46fb [Pedro Rodriguez] added tests and codegen
9e0bfc4 [Pedro Rodriguez] initial attempt at implementation
Diffstat (limited to 'sql/catalyst')
3 files changed, 92 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 43e3e9b910..94c355f838 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -240,6 +240,7 @@ object FunctionRegistry { // collection functions expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayContains]("array_contains"), // misc functions expression[Crc32]("crc32"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6ccb56578f..646afa4047 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.expressions.codegen.{ + CodegenFallback, CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ /** @@ -115,3 +116,76 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } + +/** + * Checks if the array (left) has the element (right) + */ +case class ArrayContains(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = 
BooleanType + + override def inputTypes: Seq[AbstractDataType] = right.dataType match { + case NullType => Seq() + case _ => left.dataType match { + case n @ ArrayType(element, _) => Seq(n, element) + case _ => Seq() + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (right.dataType == NullType) { + TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") + } else if (!left.dataType.isInstanceOf[ArrayType] + || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + "Arguments must be an array followed by a value of same type as the array members") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Boolean = { + val arr = left.eval(input) + if (arr == null) { + false + } else { + val value = right.eval(input) + if (value == null) { + false + } else { + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == value) return true + ) + false + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arrGen = left.gen(ctx) + val elementGen = right.gen(ctx) + val i = ctx.freshName("i") + val getValue = ctx.getValue(arrGen.primitive, right.dataType, i) + s""" + ${arrGen.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = false; + if (!${arrGen.isNull}) { + ${elementGen.code} + if (!${elementGen.isNull}) { + for (int $i = 0; $i < ${arrGen.primitive}.numElements(); $i ++) { + if (${ctx.genEqual(right.dataType, elementGen.primitive, getValue)}) { + ${ev.primitive} = true; + break; + } + } + } + } + """ + } + + override def prettyName: String = "array_contains" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 2c7e85c446..95f0e38212 100644 
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -65,4 +65,19 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, ArrayType(StringType)), null) } + + test("Array contains") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + + checkEvaluation(ArrayContains(a0, Literal(1)), true) + checkEvaluation(ArrayContains(a0, Literal(0)), false) + checkEvaluation(ArrayContains(a0, Literal(null)), false) + + checkEvaluation(ArrayContains(a1, Literal("")), true) + checkEvaluation(ArrayContains(a1, Literal(null)), false) + + checkEvaluation(ArrayContains(a2, Literal(null)), false) + } } |