diff options
author | Tarek Auel <tarek.auel@googlemail.com> | 2015-08-04 08:59:42 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-08-04 08:59:42 -0700 |
commit | b1f88a38d53aebe7cabb762cdd2f1cc64726b0b4 (patch) | |
tree | b57775883940817e2986a963c4128c97ba1d8b43 /sql | |
parent | d702d53732b44e8242448ce5302738bd130717d8 (diff) | |
download | spark-b1f88a38d53aebe7cabb762cdd2f1cc64726b0b4.tar.gz spark-b1f88a38d53aebe7cabb762cdd2f1cc64726b0b4.tar.bz2 spark-b1f88a38d53aebe7cabb762cdd2f1cc64726b0b4.zip |
[SPARK-8244] [SQL] string function: find in set
This PR is based on #7186 (just fix the conflict), thanks to tarekauel .
find_in_set(string str, string strList): int
Returns the first occurance of str in strList where strList is a comma-delimited string. Returns null if either argument is null. Returns 0 if the first argument contains any commas. For example, find_in_set('ab', 'abc,b,ab,c,def') returns 3.
Only add this to SQL, not DataFrame.
Closes #7186
Author: Tarek Auel <tarek.auel@googlemail.com>
Author: Davies Liu <davies@databricks.com>
Closes #7900 from davies/find_in_set and squashes the following commits:
4334209 [Davies Liu] Merge branch 'master' of github.com:apache/spark into find_in_set
8f00572 [Davies Liu] Merge branch 'master' of github.com:apache/spark into find_in_set
243ede4 [Tarek Auel] [SPARK-8244][SQL] hive compatibility
1aaf64e [Tarek Auel] [SPARK-8244][SQL] unit test fix
e4093a4 [Tarek Auel] [SPARK-8244][SQL] final modifier for COMMA_UTF8
0d05df5 [Tarek Auel] Merge branch 'master' into SPARK-8244
208d710 [Tarek Auel] [SPARK-8244] address comments & bug fix
71b2e69 [Tarek Auel] [SPARK-8244] find_in_set
66c7fda [Tarek Auel] Merge branch 'master' into SPARK-8244
61b8ca2 [Tarek Auel] [SPARK-8224] removed loop and split; use unsafe String comparison
4f75a65 [Tarek Auel] Merge branch 'master' into SPARK-8244
e3b20c8 [Tarek Auel] [SPARK-8244] added type check
1c2bbb7 [Tarek Auel] [SPARK-8244] findInSet
Diffstat (limited to 'sql')
4 files changed, 42 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index bc08466461..6140d1b129 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -177,6 +177,7 @@ object FunctionRegistry { expression[ConcatWs]("concat_ws"), expression[Encode]("encode"), expression[Decode]("decode"), + expression[FindInSet]("find_in_set"), expression[FormatNumber]("format_number"), expression[InitCap]("initcap"), expression[Lower]("lcase"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 56225290cd..0cc785d9f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.text.DecimalFormat -import java.util.Arrays -import java.util.Locale +import java.util.{Arrays, Locale} import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils @@ -351,6 +350,28 @@ case class EndsWith(left: Expression, right: Expression) } /** + * A function that returns the index (1-based) of the given string (left) in the comma- + * delimited list (right). Returns 0, if the string wasn't found or if the given + * string (left) contains a comma. + */ +case class FindInSet(left: Expression, right: Expression) extends BinaryExpression + with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + + override protected def nullSafeEval(word: Any, set: Any): Any = + set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String]) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (word, set) => + s"${ev.primitive} = $set.findInSet($word);" + ) + } + + override def dataType: DataType = IntegerType +} + +/** * A function that trim the spaces from both ends for the specified string. */ case class StringTrim(child: Expression) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 906be701be..23f36ca43d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -675,4 +675,14 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null) } + + test("find in set") { + checkEvaluation( + FindInSet(Literal.create(null, StringType), Literal.create(null, StringType)), null) + checkEvaluation(FindInSet(Literal("ab"), Literal.create(null, StringType)), null) + checkEvaluation(FindInSet(Literal.create(null, StringType), Literal("abc,b,ab,c,def")), null) + checkEvaluation(FindInSet(Literal("ab"), Literal("abc,b,ab,c,def")), 3) + checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0) + checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 431dcf7382..6137527757 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -208,6 +208,14 @@ class DataFrameFunctionsSuite extends QueryTest { Row(2743272264L, 2180413220L)) } + test("string function find_in_set") { + val df = Seq(("abc,b,ab,c,def", "abc,b,ab,c,def")).toDF("a", "b") + + checkAnswer( + df.selectExpr("find_in_set('ab', a)", "find_in_set('x', b)"), + Row(3, 0)) + } + test("conditional function: least") { checkAnswer( testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1), |