diff options
6 files changed, 163 insertions, 5 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e65b14dc0e..9f0d71d796 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1311,6 +1311,23 @@ def array(*cols): return Column(jc) +@since(1.5) +def array_contains(col, value): + """ + Collection function: returns True if the array contains the given value. The collection + elements and value must be of the same type. + + :param col: name of column containing array + :param value: value to check for in array + + >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df.select(array_contains(df.data, "a")).collect() + [Row(array_contains(data,a)=True), Row(array_contains(data,a)=False)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 43e3e9b910..94c355f838 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -240,6 +240,7 @@ object FunctionRegistry { // collection functions expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayContains]("array_contains"), // misc functions expression[Crc32]("crc32"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6ccb56578f..646afa4047 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.expressions.codegen.{ + CodegenFallback, CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ /** @@ -115,3 +116,76 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } + +/** + * Checks if the array (left) has the element (right) + */ +case class ArrayContains(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = right.dataType match { + case NullType => Seq() + case _ => left.dataType match { + case n @ ArrayType(element, _) => Seq(n, element) + case _ => Seq() + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (right.dataType == NullType) { + TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") + } else if (!left.dataType.isInstanceOf[ArrayType] + || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + "Arguments must be an array followed by a value of same type as the array members") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Boolean = { + val arr = left.eval(input) + if (arr == null) { + false + } else { + val value = right.eval(input) + if (value == null) { + false + } else { + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == value) return true + ) + false + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arrGen = left.gen(ctx) + val elementGen = right.gen(ctx) + val i = ctx.freshName("i") + val getValue = ctx.getValue(arrGen.primitive, right.dataType, i) + s""" + ${arrGen.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = false; + if (!${arrGen.isNull}) { + ${elementGen.code} + if (!${elementGen.isNull}) { + for (int $i = 0; $i < ${arrGen.primitive}.numElements(); $i ++) { + if (${ctx.genEqual(right.dataType, elementGen.primitive, getValue)}) { + ${ev.primitive} = true; + break; + } + } + } + } + """ + } + + override def prettyName: String = "array_contains" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 2c7e85c446..95f0e38212 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -65,4 +65,19 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, ArrayType(StringType)), null) } + + test("Array contains") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + + checkEvaluation(ArrayContains(a0, Literal(1)), true) + checkEvaluation(ArrayContains(a0, Literal(0)), false) + checkEvaluation(ArrayContains(a0, Literal(null)), false) + + checkEvaluation(ArrayContains(a1, Literal("")), true) + checkEvaluation(ArrayContains(a1, Literal(null)), false) + + checkEvaluation(ArrayContains(a2, Literal(null)), false) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bff7017254..5a10c3891a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2121,6 +2121,14 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** + * Returns true if the array contain the value + * @group collection_funcs + * @since 1.5.0 + */ + def array_contains(column: Column, value: Any): Column = + ArrayContains(column.expr, Literal(value)) + + /** * Creates a new row for each element in the given array or map column. * * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6137527757..03116a374f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -323,9 +323,9 @@ class DataFrameFunctionsSuite extends QueryTest { test("array size function") { val df = Seq( - (Array[Int](1, 2), "x"), - (Array[Int](), "y"), - (Array[Int](1, 2, 3), "z") + (Seq[Int](1, 2), "x"), + (Seq[Int](), "y"), + (Seq[Int](1, 2, 3), "z") ).toDF("a", "b") checkAnswer( df.select(size($"a")), @@ -352,4 +352,47 @@ class DataFrameFunctionsSuite extends QueryTest { Seq(Row(2), Row(0), Row(3)) ) } + + test("array contains function") { + val df = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "x") + ).toDF("a", "b") + + // Simple test cases + checkAnswer( + df.select(array_contains(df("a"), 1)), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(a, 1)"), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.select(array_contains(array(lit(2), lit(null)), 1)), + Seq(Row(false), Row(false)) + ) + + // In hive, this errors because null has no type information + intercept[AnalysisException] { + df.select(array_contains(df("a"), null)) + } + intercept[AnalysisException] { + df.selectExpr("array_contains(a, null)") + } + intercept[AnalysisException] { + df.selectExpr("array_contains(null, 1)") + } + + // In hive, if either argument has a matching type has a null value, return false, even if + // the first argument array contains a null and the second argument is null + checkAnswer( + df.selectExpr("array_contains(array(array(1), null)[1], 1)"), + Seq(Row(false), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(array(0, null), array(1, null)[1])"), + Seq(Row(false), Row(false)) + ) + } } |