aboutsummaryrefslogtreecommitdiff
path: root/unsafe
diff options
context:
space:
mode:
authorzhichao.li <zhichao.li@intel.com>2015-07-31 21:18:01 -0700
committerReynold Xin <rxin@databricks.com>2015-07-31 21:18:01 -0700
commit6996bd2e81bf6597dcda499d9a9a80927a43e30f (patch)
tree765e38451f122e762c1e7a8e497f77ab34671131 /unsafe
parent03377d2522776267a07b7d6ae9bddf79a4e0f516 (diff)
downloadspark-6996bd2e81bf6597dcda499d9a9a80927a43e30f.tar.gz
spark-6996bd2e81bf6597dcda499d9a9a80927a43e30f.tar.bz2
spark-6996bd2e81bf6597dcda499d9a9a80927a43e30f.zip
[SPARK-8264][SQL]add substring_index function
This PR is based on #7533 , thanks to zhichao-li Closes #7533 Author: zhichao.li <zhichao.li@intel.com> Author: Davies Liu <davies@databricks.com> Closes #7843 from davies/str_index and squashes the following commits: 391347b [Davies Liu] add python api 3ce7802 [Davies Liu] fix substringIndex f2d29a1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into str_index 515519b [zhichao.li] add foldable and remove null checking 9546991 [zhichao.li] scala style 67c253a [zhichao.li] hide some apis and clean code b19b013 [zhichao.li] add codegen and clean code ac863e9 [zhichao.li] reduce the calling of numChars 12e108f [zhichao.li] refine unittest d92951b [zhichao.li] add lastIndexOf 52d7b03 [zhichao.li] add substring_index function
Diffstat (limited to 'unsafe')
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java80
-rw-r--r--unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java38
2 files changed, 117 insertions, 1 deletions
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 9d4998fd48..2561c1c2a1 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -198,7 +198,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
*/
public UTF8String substring(final int start, final int until) {
if (until <= start || start >= numBytes) {
- return fromBytes(new byte[0]);
+ return UTF8String.EMPTY_UTF8;
}
int i = 0;
@@ -407,6 +407,84 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
}
/**
+ * Find the `str` from left to right.
+ */
+ private int find(UTF8String str, int start) {
+ assert (str.numBytes > 0);
+ while (start <= numBytes - str.numBytes) {
+ if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) {
+ return start;
+ }
+ start += 1;
+ }
+ return -1;
+ }
+
+ /**
+ * Find the `str` from right to left.
+ */
+ private int rfind(UTF8String str, int start) {
+ assert (str.numBytes > 0);
+ while (start >= 0) {
+ if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) {
+ return start;
+ }
+ start -= 1;
+ }
+ return -1;
+ }
+
+ /**
+ * Returns the substring from string str before count occurrences of the delimiter delim.
+ * If count is positive, everything the left of the final delimiter (counting from left) is
+ * returned. If count is negative, every to the right of the final delimiter (counting from the
+ * right) is returned. subStringIndex performs a case-sensitive match when searching for delim.
+ */
+ public UTF8String subStringIndex(UTF8String delim, int count) {
+ if (delim.numBytes == 0 || count == 0) {
+ return EMPTY_UTF8;
+ }
+ if (count > 0) {
+ int idx = -1;
+ while (count > 0) {
+ idx = find(delim, idx + 1);
+ if (idx >= 0) {
+ count --;
+ } else {
+ // can not find enough delim
+ return this;
+ }
+ }
+ if (idx == 0) {
+ return EMPTY_UTF8;
+ }
+ byte[] bytes = new byte[idx];
+ copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx);
+ return fromBytes(bytes);
+
+ } else {
+ int idx = numBytes - delim.numBytes + 1;
+ count = -count;
+ while (count > 0) {
+ idx = rfind(delim, idx - 1);
+ if (idx >= 0) {
+ count --;
+ } else {
+ // can not find enough delim
+ return this;
+ }
+ }
+ if (idx + delim.numBytes == numBytes) {
+ return EMPTY_UTF8;
+ }
+ int size = numBytes - delim.numBytes - idx;
+ byte[] bytes = new byte[size];
+ copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size);
+ return fromBytes(bytes);
+ }
+ }
+
+ /**
* Returns str, right-padded with pad to a length of len
* For example:
* ('hi', 5, '??') =&gt; 'hi???'
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index c565210872..43eed70632 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -241,6 +241,44 @@ public class UTF8StringSuite {
}
@Test
+ public void substring_index() {
+ assertEquals(fromString("www.apache.org"),
+ fromString("www.apache.org").subStringIndex(fromString("."), 3));
+ assertEquals(fromString("www.apache"),
+ fromString("www.apache.org").subStringIndex(fromString("."), 2));
+ assertEquals(fromString("www"),
+ fromString("www.apache.org").subStringIndex(fromString("."), 1));
+ assertEquals(fromString(""),
+ fromString("www.apache.org").subStringIndex(fromString("."), 0));
+ assertEquals(fromString("org"),
+ fromString("www.apache.org").subStringIndex(fromString("."), -1));
+ assertEquals(fromString("apache.org"),
+ fromString("www.apache.org").subStringIndex(fromString("."), -2));
+ assertEquals(fromString("www.apache.org"),
+ fromString("www.apache.org").subStringIndex(fromString("."), -3));
+ // str is empty string
+ assertEquals(fromString(""),
+ fromString("").subStringIndex(fromString("."), 1));
+ // empty string delim
+ assertEquals(fromString(""),
+ fromString("www.apache.org").subStringIndex(fromString(""), 1));
+ // delim does not exist in str
+ assertEquals(fromString("www.apache.org"),
+ fromString("www.apache.org").subStringIndex(fromString("#"), 2));
+ // delim is 2 chars
+ assertEquals(fromString("www||apache"),
+ fromString("www||apache||org").subStringIndex(fromString("||"), 2));
+ assertEquals(fromString("apache||org"),
+ fromString("www||apache||org").subStringIndex(fromString("||"), -2));
+ // non ascii chars
+ assertEquals(fromString("大千世界大"),
+ fromString("大千世界大千世界").subStringIndex(fromString("千"), 2));
+ // overlapped delim
+ assertEquals(fromString("||"), fromString("||||||").subStringIndex(fromString("|||"), 3));
+ assertEquals(fromString("|||"), fromString("||||||").subStringIndex(fromString("|||"), -4));
+ }
+
+ @Test
public void reverse() {
assertEquals(fromString("olleh"), fromString("hello").reverse());
assertEquals(EMPTY_UTF8, EMPTY_UTF8.reverse());