aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorTarek Auel <tarek.auel@googlemail.com>2015-08-04 08:59:42 -0700
committerDavies Liu <davies.liu@gmail.com>2015-08-04 08:59:42 -0700
commitb1f88a38d53aebe7cabb762cdd2f1cc64726b0b4 (patch)
treeb57775883940817e2986a963c4128c97ba1d8b43 /sql
parentd702d53732b44e8242448ce5302738bd130717d8 (diff)
downloadspark-b1f88a38d53aebe7cabb762cdd2f1cc64726b0b4.tar.gz
spark-b1f88a38d53aebe7cabb762cdd2f1cc64726b0b4.tar.bz2
spark-b1f88a38d53aebe7cabb762cdd2f1cc64726b0b4.zip
[SPARK-8244] [SQL] string function: find in set
This PR is based on #7186 (just fix the conflict), thanks to tarekauel . find_in_set(string str, string strList): int Returns the first occurrence of str in strList where strList is a comma-delimited string. Returns null if either argument is null. Returns 0 if the first argument contains any commas. For example, find_in_set('ab', 'abc,b,ab,c,def') returns 3. Only add this to SQL, not DataFrame. Closes #7186 Author: Tarek Auel <tarek.auel@googlemail.com> Author: Davies Liu <davies@databricks.com> Closes #7900 from davies/find_in_set and squashes the following commits: 4334209 [Davies Liu] Merge branch 'master' of github.com:apache/spark into find_in_set 8f00572 [Davies Liu] Merge branch 'master' of github.com:apache/spark into find_in_set 243ede4 [Tarek Auel] [SPARK-8244][SQL] hive compatibility 1aaf64e [Tarek Auel] [SPARK-8244][SQL] unit test fix e4093a4 [Tarek Auel] [SPARK-8244][SQL] final modifier for COMMA_UTF8 0d05df5 [Tarek Auel] Merge branch 'master' into SPARK-8244 208d710 [Tarek Auel] [SPARK-8244] address comments & bug fix 71b2e69 [Tarek Auel] [SPARK-8244] find_in_set 66c7fda [Tarek Auel] Merge branch 'master' into SPARK-8244 61b8ca2 [Tarek Auel] [SPARK-8224] removed loop and split; use unsafe String comparison 4f75a65 [Tarek Auel] Merge branch 'master' into SPARK-8244 e3b20c8 [Tarek Auel] [SPARK-8244] added type check 1c2bbb7 [Tarek Auel] [SPARK-8244] findInSet
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala25
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala8
4 files changed, 42 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index bc08466461..6140d1b129 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -177,6 +177,7 @@ object FunctionRegistry {
expression[ConcatWs]("concat_ws"),
expression[Encode]("encode"),
expression[Decode]("decode"),
+ expression[FindInSet]("find_in_set"),
expression[FormatNumber]("format_number"),
expression[InitCap]("initcap"),
expression[Lower]("lcase"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 56225290cd..0cc785d9f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -18,8 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import java.text.DecimalFormat
-import java.util.Arrays
-import java.util.Locale
+import java.util.{Arrays, Locale}
import java.util.regex.{MatchResult, Pattern}
import org.apache.commons.lang3.StringEscapeUtils
@@ -351,6 +350,28 @@ case class EndsWith(left: Expression, right: Expression)
}
/**
+ * A function that returns the index (1-based) of the given string (left) in the comma-
+ * delimited list (right). Returns 0, if the string wasn't found or if the given
+ * string (left) contains a comma.
+ */
+case class FindInSet(left: Expression, right: Expression) extends BinaryExpression
+ with ImplicitCastInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
+
+ override protected def nullSafeEval(word: Any, set: Any): Any =
+ set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String])
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (word, set) =>
+ s"${ev.primitive} = $set.findInSet($word);"
+ )
+ }
+
+ override def dataType: DataType = IntegerType
+}
+
+/**
* A function that trim the spaces from both ends for the specified string.
*/
case class StringTrim(child: Expression)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 906be701be..23f36ca43d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -675,4 +675,14 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null)
}
+
+ test("find in set") {
+ checkEvaluation(
+ FindInSet(Literal.create(null, StringType), Literal.create(null, StringType)), null)
+ checkEvaluation(FindInSet(Literal("ab"), Literal.create(null, StringType)), null)
+ checkEvaluation(FindInSet(Literal.create(null, StringType), Literal("abc,b,ab,c,def")), null)
+ checkEvaluation(FindInSet(Literal("ab"), Literal("abc,b,ab,c,def")), 3)
+ checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0)
+ checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 431dcf7382..6137527757 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -208,6 +208,14 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(2743272264L, 2180413220L))
}
+ test("string function find_in_set") {
+ val df = Seq(("abc,b,ab,c,def", "abc,b,ab,c,def")).toDF("a", "b")
+
+ checkAnswer(
+ df.selectExpr("find_in_set('ab', a)", "find_in_set('x', b)"),
+ Row(3, 0))
+ }
+
test("conditional function: least") {
checkAnswer(
testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1),