aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorPedro Rodriguez <prodriguez@trulia.com>2015-08-04 22:32:21 -0700
committerDavies Liu <davies.liu@gmail.com>2015-08-04 22:34:02 -0700
commitd34548587ab55bc2136c8f823b9e6ae96e1355a4 (patch)
tree5a1158d13c761f945742d8c6736b6afdfe198ca9 /sql
parenta02bcf20c4fc9e2e182630d197221729e996afc2 (diff)
downloadspark-d34548587ab55bc2136c8f823b9e6ae96e1355a4.tar.gz
spark-d34548587ab55bc2136c8f823b9e6ae96e1355a4.tar.bz2
spark-d34548587ab55bc2136c8f823b9e6ae96e1355a4.zip
[SPARK-8231] [SQL] Add array_contains
This PR is based on #7580 , thanks to EntilZha PR for work on https://issues.apache.org/jira/browse/SPARK-8231 Currently, I have an initial implementation for contains. Based on discussion on JIRA, it should behave same as Hive: https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFArrayContains.java#L102-L128 Main points are: 1. If the array is empty, null, or the value is null, return false 2. If there is a type mismatch, throw error 3. If comparison is not supported, throw error Closes #7580 Author: Pedro Rodriguez <prodriguez@trulia.com> Author: Pedro Rodriguez <ski.rodriguez@gmail.com> Author: Davies Liu <davies@databricks.com> Closes #7949 from davies/array_contains and squashes the following commits: d3c08bc [Davies Liu] use foreach() to avoid copy bc3d1fe [Davies Liu] fix array_contains 719e37d [Davies Liu] Merge branch 'master' of github.com:apache/spark into array_contains e352cf9 [Pedro Rodriguez] fixed diff from master 4d5b0ff [Pedro Rodriguez] added docs and another type check ffc0591 [Pedro Rodriguez] fixed unit test 7a22deb [Pedro Rodriguez] Changed test to use strings instead of long/ints which are different between python 2 and 3 b5ffae8 [Pedro Rodriguez] fixed pyspark test 4e7dce3 [Pedro Rodriguez] added more docs 3082399 [Pedro Rodriguez] fixed unit test 46f9789 [Pedro Rodriguez] reverted change d3ca013 [Pedro Rodriguez] Fixed type checking to match hive behavior, then added tests to ensure this 8528027 [Pedro Rodriguez] added more tests 686e029 [Pedro Rodriguez] fix scala style d262e9d [Pedro Rodriguez] reworked type checking code and added more tests 2517a58 [Pedro Rodriguez] removed unused import 28b4f71 [Pedro Rodriguez] fixed bug with type conversions and re-added tests 12f8795 [Pedro Rodriguez] fix scala style checks e8a20a9 [Pedro Rodriguez] added python df (broken atm) 65b562c [Pedro Rodriguez] made array_contains nullable false 33b45aa [Pedro Rodriguez] reordered test 9623c64 [Pedro 
Rodriguez] fixed test 4b4425b [Pedro Rodriguez] changed Arrays in tests to Seqs 72cb4b1 [Pedro Rodriguez] added checkInputTypes and docs 69c46fb [Pedro Rodriguez] added tests and codegen 9e0bfc4 [Pedro Rodriguez] initial attempt at implementation
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala78
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala15
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala8
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala49
5 files changed, 146 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 43e3e9b910..94c355f838 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -240,6 +240,7 @@ object FunctionRegistry {
// collection functions
expression[Size]("size"),
expression[SortArray]("sort_array"),
+ expression[ArrayContains]("array_contains"),
// misc functions
expression[Crc32]("crc32"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 6ccb56578f..646afa4047 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Comparator
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, CodeGenContext, GeneratedExpressionCode}
-import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.catalyst.expressions.codegen.{
+ CodegenFallback, CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
/**
@@ -115,3 +116,76 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
override def prettyName: String = "sort_array"
}
+
+/**
+ * Checks if the array (left) has the element (right)
+ */
+case class ArrayContains(left: Expression, right: Expression)
+ extends BinaryExpression with ImplicitCastInputTypes {
+
+ override def dataType: DataType = BooleanType
+
+ override def inputTypes: Seq[AbstractDataType] = right.dataType match {
+ case NullType => Seq()
+ case _ => left.dataType match {
+ case n @ ArrayType(element, _) => Seq(n, element)
+ case _ => Seq()
+ }
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (right.dataType == NullType) {
+ TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments")
+ } else if (!left.dataType.isInstanceOf[ArrayType]
+ || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) {
+ TypeCheckResult.TypeCheckFailure(
+ "Arguments must be an array followed by a value of same type as the array members")
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
+ override def nullable: Boolean = false
+
+ override def eval(input: InternalRow): Boolean = {
+ val arr = left.eval(input)
+ if (arr == null) {
+ false
+ } else {
+ val value = right.eval(input)
+ if (value == null) {
+ false
+ } else {
+ arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
+ if (v == value) return true
+ )
+ false
+ }
+ }
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val arrGen = left.gen(ctx)
+ val elementGen = right.gen(ctx)
+ val i = ctx.freshName("i")
+ val getValue = ctx.getValue(arrGen.primitive, right.dataType, i)
+ s"""
+ ${arrGen.code}
+ boolean ${ev.isNull} = false;
+ boolean ${ev.primitive} = false;
+ if (!${arrGen.isNull}) {
+ ${elementGen.code}
+ if (!${elementGen.isNull}) {
+ for (int $i = 0; $i < ${arrGen.primitive}.numElements(); $i ++) {
+ if (${ctx.genEqual(right.dataType, elementGen.primitive, getValue)}) {
+ ${ev.primitive} = true;
+ break;
+ }
+ }
+ }
+ }
+ """
+ }
+
+ override def prettyName: String = "array_contains"
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
index 2c7e85c446..95f0e38212 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -65,4 +65,19 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
}
+
+ test("Array contains") {
+ val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+ val a2 = Literal.create(Seq(null), ArrayType(LongType))
+
+ checkEvaluation(ArrayContains(a0, Literal(1)), true)
+ checkEvaluation(ArrayContains(a0, Literal(0)), false)
+ checkEvaluation(ArrayContains(a0, Literal(null)), false)
+
+ checkEvaluation(ArrayContains(a1, Literal("")), true)
+ checkEvaluation(ArrayContains(a1, Literal(null)), false)
+
+ checkEvaluation(ArrayContains(a2, Literal(null)), false)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index bff7017254..5a10c3891a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2121,6 +2121,14 @@ object functions {
//////////////////////////////////////////////////////////////////////////////////////////////
/**
* Returns true if the array contains the value
+ * @group collection_funcs
+ * @since 1.5.0
+ */
+ def array_contains(column: Column, value: Any): Column =
+ ArrayContains(column.expr, Literal(value))
+
+ /**
* Creates a new row for each element in the given array or map column.
*
* @group collection_funcs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 6137527757..03116a374f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -323,9 +323,9 @@ class DataFrameFunctionsSuite extends QueryTest {
test("array size function") {
val df = Seq(
- (Array[Int](1, 2), "x"),
- (Array[Int](), "y"),
- (Array[Int](1, 2, 3), "z")
+ (Seq[Int](1, 2), "x"),
+ (Seq[Int](), "y"),
+ (Seq[Int](1, 2, 3), "z")
).toDF("a", "b")
checkAnswer(
df.select(size($"a")),
@@ -352,4 +352,47 @@ class DataFrameFunctionsSuite extends QueryTest {
Seq(Row(2), Row(0), Row(3))
)
}
+
+ test("array contains function") {
+ val df = Seq(
+ (Seq[Int](1, 2), "x"),
+ (Seq[Int](), "x")
+ ).toDF("a", "b")
+
+ // Simple test cases
+ checkAnswer(
+ df.select(array_contains(df("a"), 1)),
+ Seq(Row(true), Row(false))
+ )
+ checkAnswer(
+ df.selectExpr("array_contains(a, 1)"),
+ Seq(Row(true), Row(false))
+ )
+ checkAnswer(
+ df.select(array_contains(array(lit(2), lit(null)), 1)),
+ Seq(Row(false), Row(false))
+ )
+
+ // In hive, this errors because null has no type information
+ intercept[AnalysisException] {
+ df.select(array_contains(df("a"), null))
+ }
+ intercept[AnalysisException] {
+ df.selectExpr("array_contains(a, null)")
+ }
+ intercept[AnalysisException] {
+ df.selectExpr("array_contains(null, 1)")
+ }
+
+ // In hive, if either argument of a matching type has a null value, return false, even if
+ // the first argument array contains a null and the second argument is null
+ checkAnswer(
+ df.selectExpr("array_contains(array(array(1), null)[1], 1)"),
+ Seq(Row(false), Row(false))
+ )
+ checkAnswer(
+ df.selectExpr("array_contains(array(0, null), array(1, null)[1])"),
+ Seq(Row(false), Row(false))
+ )
+ }
}