aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzhichao.li <zhichao.li@intel.com>2015-07-31 21:18:01 -0700
committerReynold Xin <rxin@databricks.com>2015-07-31 21:18:01 -0700
commit6996bd2e81bf6597dcda499d9a9a80927a43e30f (patch)
tree765e38451f122e762c1e7a8e497f77ab34671131
parent03377d2522776267a07b7d6ae9bddf79a4e0f516 (diff)
downloadspark-6996bd2e81bf6597dcda499d9a9a80927a43e30f.tar.gz
spark-6996bd2e81bf6597dcda499d9a9a80927a43e30f.tar.bz2
spark-6996bd2e81bf6597dcda499d9a9a80927a43e30f.zip
[SPARK-8264][SQL]add substring_index function
This PR is based on #7533 , thanks to zhichao-li Closes #7533 Author: zhichao.li <zhichao.li@intel.com> Author: Davies Liu <davies@databricks.com> Closes #7843 from davies/str_index and squashes the following commits: 391347b [Davies Liu] add python api 3ce7802 [Davies Liu] fix substringIndex f2d29a1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into str_index 515519b [zhichao.li] add foldable and remove null checking 9546991 [zhichao.li] scala style 67c253a [zhichao.li] hide some apis and clean code b19b013 [zhichao.li] add codegen and clean code ac863e9 [zhichao.li] reduce the calling of numChars 12e108f [zhichao.li] refine unittest d92951b [zhichao.li] add lastIndexOf 52d7b03 [zhichao.li] add substring_index function
-rw-r--r--python/pyspark/sql/functions.py19
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala25
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala31
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala12
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala57
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java80
-rw-r--r--unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java38
8 files changed, 261 insertions, 2 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index bb9926ce8c..89a2a5ceaa 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -921,6 +921,25 @@ def trunc(date, format):
@since(1.5)
+@ignore_unicode_prefix
+def substring_index(str, delim, count):
+ """
+ Returns the substring from string str before count occurrences of the delimiter delim.
+ If count is positive, everything to the left of the final delimiter (counting from the left)
+ is returned. If count is negative, everything to the right of the final delimiter (counting
+ from the right) is returned. substring_index performs a case-sensitive match when searching
+ for delim.
+
+ >>> df = sqlContext.createDataFrame([('a.b.c.d',)], ['s'])
+ >>> df.select(substring_index(df.s, '.', 2).alias('s')).collect()
+ [Row(s=u'a.b')]
+ >>> df.select(substring_index(df.s, '.', -3).alias('s')).collect()
+ [Row(s=u'b.c.d')]
+ """
+ sc = SparkContext._active_spark_context
+ # Only str is converted to a Java column; delim and count are passed through as given.
+ return Column(sc._jvm.functions.substring_index(_to_java_column(str), delim, count))
+
+
+@since(1.5)
def size(col):
"""
Collection function: returns the length of the array or map stored in the column.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 3f61a9af1f..ee44cbcba6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -199,6 +199,7 @@ object FunctionRegistry {
expression[StringSplit]("split"),
expression[Substring]("substr"),
expression[Substring]("substring"),
+ expression[SubstringIndex]("substring_index"),
expression[StringTrim]("trim"),
expression[UnBase64]("unbase64"),
expression[Upper]("ucase"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 160e72f384..5dd387a418 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -422,6 +422,31 @@ case class StringInstr(str: Expression, substr: Expression)
}
/**
+ * Returns the substring from string str before count occurrences of the delimiter delim.
+ * If count is positive, everything to the left of the final delimiter (counting from the left)
+ * is returned. If count is negative, everything to the right of the final delimiter (counting
+ * from the right) is returned. substring_index performs a case-sensitive match when searching
+ * for delim.
+ */
+case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
+ extends TernaryExpression with ImplicitCastInputTypes {
+
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
+ override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr)
+ override def prettyName: String = "substring_index"
+
+ // Casts the three evaluated arguments to their concrete types and delegates to
+ // UTF8String.subStringIndex.
+ override def nullSafeEval(str: Any, delim: Any, count: Any): Any = {
+ str.asInstanceOf[UTF8String].subStringIndex(
+ delim.asInstanceOf[UTF8String],
+ count.asInstanceOf[Int])
+ }
+
+ // Generates code that calls the same subStringIndex method on the str argument.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)")
+ }
+}
+
+/**
* A function that returns the position of the first occurrence of substr
* in given string after position pos.
*/
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index fb72fe1714..ad87ab36fd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -187,6 +188,36 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(s.substring(0), "example", row)
}
+ test("string substring_index function") {
+ // more delimiters requested than the two present: the whole string comes back
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org")
+ // positive count: prefix before the count-th delimiter, counting from the left
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(1)), "www")
+ // count of zero: empty string
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(0)), "")
+ // negative count: suffix after the |count|-th delimiter, counting from the right
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-1)), "org")
+ // empty input string
+ checkEvaluation(
+ SubstringIndex(Literal(""), Literal("."), Literal(-2)), "")
+ // a null str or a null delim evaluates to null
+ checkEvaluation(
+ SubstringIndex(Literal.create(null, StringType), Literal("."), Literal(-2)), null)
+ checkEvaluation(SubstringIndex(
+ Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null)
+ // non ascii chars
+ // scalastyle:off
+ checkEvaluation(
+ SubstringIndex(Literal("大千世界大千世界"), Literal( "千"), Literal(2)), "大千世界大")
+ // scalastyle:on
+ // delimiter longer than one character
+ checkEvaluation(
+ SubstringIndex(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache")
+ }
+
test("LIKE literal Regular Expression") {
checkEvaluation(Literal.create(null, StringType).like("a"), null)
checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 89ffa9c50d..57bb00a741 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1788,8 +1788,18 @@ object functions {
def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr)
/**
- * Locate the position of the first occurrence of substr in a string column.
+ * Returns the substring from string str before count occurrences of the delimiter delim.
+ * If count is positive, everything to the left of the final delimiter (counting from the left)
+ * is returned. If count is negative, everything to the right of the final delimiter (counting
+ * from the right) is returned. substring_index performs a case-sensitive match when searching
+ * for delim.
*
+ * @group string_funcs
+ */
+ def substring_index(str: Column, delim: String, count: Int): Column =
+ SubstringIndex(str.expr, lit(delim).expr, lit(count).expr)
+
+ /**
+ * Locate the position of the first occurrence of substr.
* NOTE: The position is not zero based, but 1 based index, returns 0 if substr
* could not be found in str.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index b7f073cccb..628da95298 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -163,6 +163,63 @@ class StringFunctionsSuite extends QueryTest {
Row(1))
}
+ test("string substring_index function") {
+ // One-row DataFrame; only column `a` is referenced below. The lit(...) cases use df
+ // just to get a single output row.
+ val df = Seq(("www.apache.org", ".", "zz")).toDF("a", "b", "c")
+ // more delimiters requested than the two present: the whole string comes back
+ checkAnswer(
+ df.select(substring_index($"a", ".", 3)),
+ Row("www.apache.org"))
+ // positive count: prefix before the count-th delimiter, counting from the left
+ checkAnswer(
+ df.select(substring_index($"a", ".", 2)),
+ Row("www.apache"))
+ checkAnswer(
+ df.select(substring_index($"a", ".", 1)),
+ Row("www"))
+ // count of zero: empty string
+ checkAnswer(
+ df.select(substring_index($"a", ".", 0)),
+ Row(""))
+ // negative count: suffix after the |count|-th delimiter, counting from the right
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), ".", -1)),
+ Row("org"))
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), ".", -2)),
+ Row("apache.org"))
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), ".", -3)),
+ Row("www.apache.org"))
+ // str is empty string
+ checkAnswer(
+ df.select(substring_index(lit(""), ".", 1)),
+ Row(""))
+ // empty string delim
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), "", 1)),
+ Row(""))
+ // delim does not exist in str
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), "#", 1)),
+ Row("www.apache.org"))
+ // delim is 2 chars
+ checkAnswer(
+ df.select(substring_index(lit("www||apache||org"), "||", 2)),
+ Row("www||apache"))
+ checkAnswer(
+ df.select(substring_index(lit("www||apache||org"), "||", -2)),
+ Row("apache||org"))
+ // a null str or a null delim yields null
+ checkAnswer(
+ df.select(substring_index(lit(null), "||", 2)),
+ Row(null))
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), null, 2)),
+ Row(null))
+ // non ascii chars
+ // scalastyle:off
+ checkAnswer(
+ df.selectExpr("""substring_index("大千世界大千世界", "千", 2)"""),
+ Row("大千世界大"))
+ // scalastyle:on
+ }
+
test("string locate function") {
val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 9d4998fd48..2561c1c2a1 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -198,7 +198,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
*/
public UTF8String substring(final int start, final int until) {
if (until <= start || start >= numBytes) {
- return fromBytes(new byte[0]);
+ return UTF8String.EMPTY_UTF8;
}
int i = 0;
@@ -407,6 +407,84 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
}
/**
+ * Find the `str` from left to right.
+ *
+ * Tries byte offsets start, start + 1, ... up to numBytes - str.numBytes, comparing
+ * str.numBytes bytes at each one.
+ *
+ * @param str the string to look for; must be non-empty (asserted)
+ * @param start first byte offset in this string to try
+ * @return the first matching byte offset at or after start, or -1 if there is none
+ */
+ private int find(UTF8String str, int start) {
+ assert (str.numBytes > 0);
+ while (start <= numBytes - str.numBytes) {
+ if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) {
+ return start;
+ }
+ start += 1;
+ }
+ return -1;
+ }
+
+ /**
+ * Find the `str` from right to left.
+ *
+ * Tries byte offsets start, start - 1, ... down to 0, comparing str.numBytes bytes at each
+ * one. A negative start skips the loop and returns -1.
+ *
+ * NOTE(review): start has no upper-bound check here; presumably the caller must ensure
+ * start + str.numBytes <= numBytes (subStringIndex starts at numBytes - delim.numBytes).
+ *
+ * @param str the string to look for; must be non-empty (asserted)
+ * @param start last (right-most) byte offset in this string to try
+ * @return the first matching byte offset found scanning down from start, or -1 if none
+ */
+ private int rfind(UTF8String str, int start) {
+ assert (str.numBytes > 0);
+ while (start >= 0) {
+ if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) {
+ return start;
+ }
+ start -= 1;
+ }
+ return -1;
+ }
+
+ /**
+  * Returns the substring from this string before count occurrences of the delimiter delim.
+  * If count is positive, everything to the left of the final delimiter (counting from the left)
+  * is returned. If count is negative, everything to the right of the final delimiter (counting
+  * from the right) is returned. subStringIndex performs a case-sensitive match when searching
+  * for delim. Each search resumes one byte past the previous match, so overlapping
+  * occurrences of delim are each counted.
+  *
+  * @param delim delimiter to search for; an empty delimiter yields the empty string
+  * @param count number of delimiter occurrences to count; 0 yields the empty string
+  * @return the selected prefix or suffix, or this string itself when it contains fewer than
+  *         |count| occurrences of delim
+  */
+ public UTF8String subStringIndex(UTF8String delim, int count) {
+   if (delim.numBytes == 0 || count == 0) {
+     return EMPTY_UTF8;
+   }
+   if (count > 0) {
+     // idx is the byte offset of the most recent match; -1 makes the first search start at 0
+     int idx = -1;
+     while (count > 0) {
+       idx = find(delim, idx + 1);
+       if (idx >= 0) {
+         count--;
+       } else {
+         // can not find enough delim
+         return this;
+       }
+     }
+     if (idx == 0) {
+       return EMPTY_UTF8;
+     }
+     byte[] bytes = new byte[idx];
+     copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx);
+     return fromBytes(bytes);
+
+   } else {
+     // One past the right-most offset where delim could start, so the first rfind begins there
+     int idx = numBytes - delim.numBytes + 1;
+     // Count up towards zero rather than negating count: -Integer.MIN_VALUE overflows back to
+     // Integer.MIN_VALUE, which would skip the loop and yield a negative array size below.
+     while (count < 0) {
+       idx = rfind(delim, idx - 1);
+       if (idx >= 0) {
+         count++;
+       } else {
+         // can not find enough delim
+         return this;
+       }
+     }
+     if (idx + delim.numBytes == numBytes) {
+       return EMPTY_UTF8;
+     }
+     int size = numBytes - delim.numBytes - idx;
+     byte[] bytes = new byte[size];
+     copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size);
+     return fromBytes(bytes);
+   }
+ }
+
+ /**
* Returns str, right-padded with pad to a length of len
* For example:
* ('hi', 5, '??') =&gt; 'hi???'
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index c565210872..43eed70632 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -241,6 +241,44 @@ public class UTF8StringSuite {
}
@Test
+ public void substring_index() {
+ // more delimiters requested than the two present: the whole string comes back
+ assertEquals(fromString("www.apache.org"),
+ fromString("www.apache.org").subStringIndex(fromString("."), 3));
+ // positive count: prefix before the count-th delimiter, counting from the left
+ assertEquals(fromString("www.apache"),
+ fromString("www.apache.org").subStringIndex(fromString("."), 2));
+ assertEquals(fromString("www"),
+ fromString("www.apache.org").subStringIndex(fromString("."), 1));
+ // count of zero: empty string
+ assertEquals(fromString(""),
+ fromString("www.apache.org").subStringIndex(fromString("."), 0));
+ // negative count: suffix after the |count|-th delimiter, counting from the right
+ assertEquals(fromString("org"),
+ fromString("www.apache.org").subStringIndex(fromString("."), -1));
+ assertEquals(fromString("apache.org"),
+ fromString("www.apache.org").subStringIndex(fromString("."), -2));
+ assertEquals(fromString("www.apache.org"),
+ fromString("www.apache.org").subStringIndex(fromString("."), -3));
+ // str is empty string
+ assertEquals(fromString(""),
+ fromString("").subStringIndex(fromString("."), 1));
+ // empty string delim
+ assertEquals(fromString(""),
+ fromString("www.apache.org").subStringIndex(fromString(""), 1));
+ // delim does not exist in str
+ assertEquals(fromString("www.apache.org"),
+ fromString("www.apache.org").subStringIndex(fromString("#"), 2));
+ // delim is 2 chars
+ assertEquals(fromString("www||apache"),
+ fromString("www||apache||org").subStringIndex(fromString("||"), 2));
+ assertEquals(fromString("apache||org"),
+ fromString("www||apache||org").subStringIndex(fromString("||"), -2));
+ // non ascii chars
+ assertEquals(fromString("大千世界大"),
+ fromString("大千世界大千世界").subStringIndex(fromString("千"), 2));
+ // overlapped delim: matches that share bytes with an earlier match are still counted
+ assertEquals(fromString("||"), fromString("||||||").subStringIndex(fromString("|||"), 3));
+ assertEquals(fromString("|||"), fromString("||||||").subStringIndex(fromString("|||"), -4));
+ }
+
+ @Test
public void reverse() {
assertEquals(fromString("olleh"), fromString("hello").reverse());
assertEquals(EMPTY_UTF8, EMPTY_UTF8.reverse());