aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorzhichao.li <zhichao.li@intel.com>2015-07-31 21:18:01 -0700
committerReynold Xin <rxin@databricks.com>2015-07-31 21:18:01 -0700
commit6996bd2e81bf6597dcda499d9a9a80927a43e30f (patch)
tree765e38451f122e762c1e7a8e497f77ab34671131 /sql
parent03377d2522776267a07b7d6ae9bddf79a4e0f516 (diff)
downloadspark-6996bd2e81bf6597dcda499d9a9a80927a43e30f.tar.gz
spark-6996bd2e81bf6597dcda499d9a9a80927a43e30f.tar.bz2
spark-6996bd2e81bf6597dcda499d9a9a80927a43e30f.zip
[SPARK-8264][SQL]add substring_index function
This PR is based on #7533 , thanks to zhichao-li Closes #7533 Author: zhichao.li <zhichao.li@intel.com> Author: Davies Liu <davies@databricks.com> Closes #7843 from davies/str_index and squashes the following commits: 391347b [Davies Liu] add python api 3ce7802 [Davies Liu] fix substringIndex f2d29a1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into str_index 515519b [zhichao.li] add foldable and remove null checking 9546991 [zhichao.li] scala style 67c253a [zhichao.li] hide some apis and clean code b19b013 [zhichao.li] add codegen and clean code ac863e9 [zhichao.li] reduce the calling of numChars 12e108f [zhichao.li] refine unittest d92951b [zhichao.li] add lastIndexOf 52d7b03 [zhichao.li] add substring_index function
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala25
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala31
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala12
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala57
5 files changed, 125 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 3f61a9af1f..ee44cbcba6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -199,6 +199,7 @@ object FunctionRegistry {
expression[StringSplit]("split"),
expression[Substring]("substr"),
expression[Substring]("substring"),
+ expression[SubstringIndex]("substring_index"),
expression[StringTrim]("trim"),
expression[UnBase64]("unbase64"),
expression[Upper]("ucase"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 160e72f384..5dd387a418 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -422,6 +422,31 @@ case class StringInstr(str: Expression, substr: Expression)
}
/**
+ * Returns the substring from string str before count occurrences of the delimiter delim.
+ * If count is positive, everything the left of the final delimiter (counting from left) is
+ * returned. If count is negative, every to the right of the final delimiter (counting from the
+ * right) is returned. substring_index performs a case-sensitive match when searching for delim.
+ */
+case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
+ extends TernaryExpression with ImplicitCastInputTypes {
+
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
+ override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr)
+ override def prettyName: String = "substring_index"
+
+ override def nullSafeEval(str: Any, delim: Any, count: Any): Any = {
+ str.asInstanceOf[UTF8String].subStringIndex(
+ delim.asInstanceOf[UTF8String],
+ count.asInstanceOf[Int])
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)")
+ }
+}
+
+/**
* A function that returns the position of the first occurrence of substr
* in given string after position pos.
*/
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index fb72fe1714..ad87ab36fd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -187,6 +188,36 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(s.substring(0), "example", row)
}
+ test("string substring_index function") {
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(1)), "www")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(0)), "")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-1)), "org")
+ checkEvaluation(
+ SubstringIndex(Literal(""), Literal("."), Literal(-2)), "")
+ checkEvaluation(
+ SubstringIndex(Literal.create(null, StringType), Literal("."), Literal(-2)), null)
+ checkEvaluation(SubstringIndex(
+ Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null)
+ // non ascii chars
+ // scalastyle:off
+ checkEvaluation(
+ SubstringIndex(Literal("大千世界大千世界"), Literal( "千"), Literal(2)), "大千世界大")
+ // scalastyle:on
+ checkEvaluation(
+ SubstringIndex(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache")
+ }
+
test("LIKE literal Regular Expression") {
checkEvaluation(Literal.create(null, StringType).like("a"), null)
checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 89ffa9c50d..57bb00a741 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1788,8 +1788,18 @@ object functions {
def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr)
/**
- * Locate the position of the first occurrence of substr in a string column.
+ * Returns the substring from string str before count occurrences of the delimiter delim.
+ * If count is positive, everything the left of the final delimiter (counting from left) is
+ * returned. If count is negative, every to the right of the final delimiter (counting from the
+ * right) is returned. substring_index performs a case-sensitive match when searching for delim.
*
+ * @group string_funcs
+ */
+ def substring_index(str: Column, delim: String, count: Int): Column =
+ SubstringIndex(str.expr, lit(delim).expr, lit(count).expr)
+
+ /**
+ * Locate the position of the first occurrence of substr.
* NOTE: The position is not zero based, but 1 based index, returns 0 if substr
* could not be found in str.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index b7f073cccb..628da95298 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -163,6 +163,63 @@ class StringFunctionsSuite extends QueryTest {
Row(1))
}
+ test("string substring_index function") {
+ val df = Seq(("www.apache.org", ".", "zz")).toDF("a", "b", "c")
+ checkAnswer(
+ df.select(substring_index($"a", ".", 3)),
+ Row("www.apache.org"))
+ checkAnswer(
+ df.select(substring_index($"a", ".", 2)),
+ Row("www.apache"))
+ checkAnswer(
+ df.select(substring_index($"a", ".", 1)),
+ Row("www"))
+ checkAnswer(
+ df.select(substring_index($"a", ".", 0)),
+ Row(""))
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), ".", -1)),
+ Row("org"))
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), ".", -2)),
+ Row("apache.org"))
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), ".", -3)),
+ Row("www.apache.org"))
+ // str is empty string
+ checkAnswer(
+ df.select(substring_index(lit(""), ".", 1)),
+ Row(""))
+ // empty string delim
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), "", 1)),
+ Row(""))
+ // delim does not exist in str
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), "#", 1)),
+ Row("www.apache.org"))
+ // delim is 2 chars
+ checkAnswer(
+ df.select(substring_index(lit("www||apache||org"), "||", 2)),
+ Row("www||apache"))
+ checkAnswer(
+ df.select(substring_index(lit("www||apache||org"), "||", -2)),
+ Row("apache||org"))
+ // null
+ checkAnswer(
+ df.select(substring_index(lit(null), "||", 2)),
+ Row(null))
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), null, 2)),
+ Row(null))
+ // non ascii chars
+ // scalastyle:off
+ checkAnswer(
+ df.selectExpr("""substring_index("大千世界大千世界", "千", 2)"""),
+ Row("大千世界大"))
+ // scalastyle:on
+ }
+
test("string locate function") {
val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")