diff options
author | Reynold Xin <rxin@databricks.com> | 2015-07-19 16:48:47 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-07-19 16:48:47 -0700 |
commit | 163e3f1df94f6b7d3dadb46a87dbb3a2bade3f95 (patch) | |
tree | da58016b49a92c5b0318ace791844194b3838324 /unsafe/src | |
parent | 7a81245345f2d6124423161786bb0d9f1c278ab8 (diff) | |
download | spark-163e3f1df94f6b7d3dadb46a87dbb3a2bade3f95.tar.gz spark-163e3f1df94f6b7d3dadb46a87dbb3a2bade3f95.tar.bz2 spark-163e3f1df94f6b7d3dadb46a87dbb3a2bade3f95.zip |
[SPARK-8241][SQL] string function: concat_ws.
I also changed the semantics of concat w.r.t. null back to the same behavior as Hive.
That is to say, concat now returns null if any input is null.
Author: Reynold Xin <rxin@databricks.com>
Closes #7504 from rxin/concat_ws and squashes the following commits:
83fd950 [Reynold Xin] Fixed type casting.
3ae85f7 [Reynold Xin] Write null better.
cdc7be6 [Reynold Xin] Added code generation for pure string mode.
a61c4e4 [Reynold Xin] Updated comments.
2d51406 [Reynold Xin] [SPARK-8241][SQL] string function: concat_ws.
Diffstat (limited to 'unsafe/src')
-rw-r--r-- | unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java | 58 | ||||
-rw-r--r-- | unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java | 62 |
2 files changed, 102 insertions, 18 deletions
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 9723b6e083..3eecd657e6 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -397,19 +397,16 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { } /** - * Concatenates input strings together into a single string. A null input is skipped. - * For example, concat("a", null, "c") would yield "ac". + * Concatenates input strings together into a single string. Returns null if any input is null. */ public static UTF8String concat(UTF8String... inputs) { - if (inputs == null) { - return fromBytes(new byte[0]); - } - // Compute the total length of the result. int totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { totalLength += inputs[i].numBytes; + } else { + return null; } } @@ -417,6 +414,45 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { final byte[] result = new byte[totalLength]; int offset = 0; for (int i = 0; i < inputs.length; i++) { + int len = inputs[i].numBytes; + PlatformDependent.copyMemory( + inputs[i].base, inputs[i].offset, + result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + len); + offset += len; + } + return fromBytes(result); + } + + /** + * Concatenates input strings together into a single string using the separator. + * A null input is skipped. For example, concatWs(",", "a", null, "c") would yield "a,c". + */ + public static UTF8String concatWs(UTF8String separator, UTF8String... 
inputs) { + if (separator == null) { + return null; + } + + int numInputBytes = 0; // total number of bytes from the inputs + int numInputs = 0; // number of non-null inputs + for (int i = 0; i < inputs.length; i++) { + if (inputs[i] != null) { + numInputBytes += inputs[i].numBytes; + numInputs++; + } + } + + if (numInputs == 0) { + // Return an empty string if there is no input, or all the inputs are null. + return fromBytes(new byte[0]); + } + + // Allocate a new byte array, and copy the inputs one by one into it. + // The size of the new array is the size of all inputs, plus the separators. + final byte[] result = new byte[numInputBytes + (numInputs - 1) * separator.numBytes]; + int offset = 0; + + for (int i = 0, j = 0; i < inputs.length; i++) { if (inputs[i] != null) { int len = inputs[i].numBytes; PlatformDependent.copyMemory( @@ -424,6 +460,16 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, len); offset += len; + + j++; + // Add separator if this is not the last input. 
+ if (j < numInputs) { + PlatformDependent.copyMemory( + separator.base, separator.offset, + result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + separator.numBytes); + offset += separator.numBytes; + } } } return fromBytes(result); diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 0db7522b50..7d0c49e2fb 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -88,16 +88,50 @@ public class UTF8StringSuite { @Test public void concatTest() { - assertEquals(concat(), fromString("")); - assertEquals(concat(null), fromString("")); - assertEquals(concat(fromString("")), fromString("")); - assertEquals(concat(fromString("ab")), fromString("ab")); - assertEquals(concat(fromString("a"), fromString("b")), fromString("ab")); - assertEquals(concat(fromString("a"), fromString("b"), fromString("c")), fromString("abc")); - assertEquals(concat(fromString("a"), null, fromString("c")), fromString("ac")); - assertEquals(concat(fromString("a"), null, null), fromString("a")); - assertEquals(concat(null, null, null), fromString("")); - assertEquals(concat(fromString("数据"), fromString("砖头")), fromString("数据砖头")); + assertEquals(fromString(""), concat()); + assertEquals(null, concat((UTF8String) null)); + assertEquals(fromString(""), concat(fromString(""))); + assertEquals(fromString("ab"), concat(fromString("ab"))); + assertEquals(fromString("ab"), concat(fromString("a"), fromString("b"))); + assertEquals(fromString("abc"), concat(fromString("a"), fromString("b"), fromString("c"))); + assertEquals(null, concat(fromString("a"), null, fromString("c"))); + assertEquals(null, concat(fromString("a"), null, null)); + assertEquals(null, concat(null, null, null)); + assertEquals(fromString("数据砖头"), concat(fromString("数据"), fromString("砖头"))); + } + + @Test + public void 
concatWsTest() { + // Returns null if the separator is null + assertEquals(null, concatWs(null, (UTF8String)null)); + assertEquals(null, concatWs(null, fromString("a"))); + + // If separator is not null, concatWs should skip all null inputs and never return null. + UTF8String sep = fromString("哈哈"); + assertEquals( + fromString(""), + concatWs(sep, fromString(""))); + assertEquals( + fromString("ab"), + concatWs(sep, fromString("ab"))); + assertEquals( + fromString("a哈哈b"), + concatWs(sep, fromString("a"), fromString("b"))); + assertEquals( + fromString("a哈哈b哈哈c"), + concatWs(sep, fromString("a"), fromString("b"), fromString("c"))); + assertEquals( + fromString("a哈哈c"), + concatWs(sep, fromString("a"), null, fromString("c"))); + assertEquals( + fromString("a"), + concatWs(sep, fromString("a"), null, null)); + assertEquals( + fromString(""), + concatWs(sep, null, null, null)); + assertEquals( + fromString("数据哈哈砖头"), + concatWs(sep, fromString("数据"), fromString("砖头"))); } @Test @@ -215,14 +249,18 @@ public class UTF8StringSuite { assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????"))); assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者"))); assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7, fromString("孙行者"))); - assertEquals(fromString("孙行者孙行者孙行数据砖头"), fromString("数据砖头").lpad(12, fromString("孙行者"))); + assertEquals( + fromString("孙行者孙行者孙行数据砖头"), + fromString("数据砖头").lpad(12, fromString("孙行者"))); assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????"))); assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????"))); assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????"))); assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者"))); assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7, fromString("孙行者"))); - assertEquals(fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12, fromString("孙行者"))); + 
assertEquals( + fromString("数据砖头孙行者孙行者孙行"), + fromString("数据砖头").rpad(12, fromString("孙行者"))); } @Test |