From 6996bd2e81bf6597dcda499d9a9a80927a43e30f Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Fri, 31 Jul 2015 21:18:01 -0700 Subject: [SPARK-8264][SQL]add substring_index function This PR is based on #7533 , thanks to zhichao-li Closes #7533 Author: zhichao.li Author: Davies Liu Closes #7843 from davies/str_index and squashes the following commits: 391347b [Davies Liu] add python api 3ce7802 [Davies Liu] fix substringIndex f2d29a1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into str_index 515519b [zhichao.li] add foldable and remove null checking 9546991 [zhichao.li] scala style 67c253a [zhichao.li] hide some apis and clean code b19b013 [zhichao.li] add codegen and clean code ac863e9 [zhichao.li] reduce the calling of numChars 12e108f [zhichao.li] refine unittest d92951b [zhichao.li] add lastIndexOf 52d7b03 [zhichao.li] add substring_index function --- .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../catalyst/expressions/stringOperations.scala | 25 ++++++++++ .../expressions/StringExpressionsSuite.scala | 31 ++++++++++++ .../scala/org/apache/spark/sql/functions.scala | 12 ++++- .../apache/spark/sql/StringFunctionsSuite.scala | 57 ++++++++++++++++++++++ 5 files changed, 125 insertions(+), 1 deletion(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3f61a9af1f..ee44cbcba6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -199,6 +199,7 @@ object FunctionRegistry { expression[StringSplit]("split"), expression[Substring]("substr"), expression[Substring]("substring"), + expression[SubstringIndex]("substring_index"), expression[StringTrim]("trim"), expression[UnBase64]("unbase64"), expression[Upper]("ucase"), diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 160e72f384..5dd387a418 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -421,6 +421,31 @@ case class StringInstr(str: Expression, substr: Expression) } } +/** + * Returns the substring from string str before count occurrences of the delimiter delim. + * If count is positive, everything to the left of the final delimiter (counting from left) is + * returned. If count is negative, everything to the right of the final delimiter (counting + * from the right) is returned. The match on delim is case-sensitive. + */ +case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr) + override def prettyName: String = "substring_index" + + override def nullSafeEval(str: Any, delim: Any, count: Any): Any = { + str.asInstanceOf[UTF8String].subStringIndex( + delim.asInstanceOf[UTF8String], + count.asInstanceOf[Int]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)") + } +} + /** * A function that returns the position of the first occurrence of substr * in given string after position pos. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index fb72fe1714..ad87ab36fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -187,6 +188,36 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(s.substring(0), "example", row) } + test("string substring_index function") { + checkEvaluation( + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org") + checkEvaluation( + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache") + checkEvaluation( + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(1)), "www") + checkEvaluation( + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(0)), "") + checkEvaluation( + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org") + checkEvaluation( + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org") + checkEvaluation( + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-1)), "org") + checkEvaluation( + SubstringIndex(Literal(""), Literal("."), Literal(-2)), "") + checkEvaluation( + SubstringIndex(Literal.create(null, StringType), Literal("."), Literal(-2)), null) + checkEvaluation(SubstringIndex( + Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null) + // 
non ascii chars + // scalastyle:off + checkEvaluation( + SubstringIndex(Literal("大千世界大千世界"), Literal( "千"), Literal(2)), "大千世界大") + // scalastyle:on + checkEvaluation( + SubstringIndex(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache") + } + test("LIKE literal Regular Expression") { checkEvaluation(Literal.create(null, StringType).like("a"), null) checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 89ffa9c50d..57bb00a741 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1788,8 +1788,18 @@ object functions { def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr) /** - * Locate the position of the first occurrence of substr in a string column. + * Returns the substring from string str before count occurrences of the delimiter delim. + * If count is positive, everything to the left of the final delimiter (counting from left) is + * returned. If count is negative, everything to the right of the final delimiter (counting + * from the right) is returned. The match on delim is case-sensitive. * + * @group string_funcs + */ + def substring_index(str: Column, delim: String, count: Int): Column = + SubstringIndex(str.expr, lit(delim).expr, lit(count).expr) + + /** + * Locate the position of the first occurrence of substr. * NOTE: The position is not zero based, but 1 based index, returns 0 if substr * could not be found in str. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index b7f073cccb..628da95298 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -163,6 +163,63 @@ class StringFunctionsSuite extends QueryTest { Row(1)) } + test("string substring_index function") { + val df = Seq(("www.apache.org", ".", "zz")).toDF("a", "b", "c") + checkAnswer( + df.select(substring_index($"a", ".", 3)), + Row("www.apache.org")) + checkAnswer( + df.select(substring_index($"a", ".", 2)), + Row("www.apache")) + checkAnswer( + df.select(substring_index($"a", ".", 1)), + Row("www")) + checkAnswer( + df.select(substring_index($"a", ".", 0)), + Row("")) + checkAnswer( + df.select(substring_index(lit("www.apache.org"), ".", -1)), + Row("org")) + checkAnswer( + df.select(substring_index(lit("www.apache.org"), ".", -2)), + Row("apache.org")) + checkAnswer( + df.select(substring_index(lit("www.apache.org"), ".", -3)), + Row("www.apache.org")) + // str is empty string: result is the empty string + checkAnswer( + df.select(substring_index(lit(""), ".", 1)), + Row("")) + // empty string delim: result is the empty string + checkAnswer( + df.select(substring_index(lit("www.apache.org"), "", 1)), + Row("")) + // delim does not exist in str: the whole str is returned + checkAnswer( + df.select(substring_index(lit("www.apache.org"), "#", 1)), + Row("www.apache.org")) + // delim is 2 chars, with positive and negative count + checkAnswer( + df.select(substring_index(lit("www||apache||org"), "||", 2)), + Row("www||apache")) + checkAnswer( + df.select(substring_index(lit("www||apache||org"), "||", -2)), + Row("apache||org")) + // null str or null delim: result is null + checkAnswer( + df.select(substring_index(lit(null), "||", 2)), + Row(null)) + checkAnswer( + df.select(substring_index(lit("www.apache.org"), null, 2)), + Row(null)) + // non ascii chars + // scalastyle:off + checkAnswer( + df.selectExpr("""substring_index("大千世界大千世界", "千", 2)"""), 
+ Row("大千世界大")) + // scalastyle:on + } + test("string locate function") { val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") -- cgit v1.2.3