aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCheng Hao <hao.cheng@intel.com>2015-07-27 23:02:23 -0700
committerDavies Liu <davies.liu@gmail.com>2015-07-27 23:02:23 -0700
commit9c5612f4e197dec82a5eac9542896d6216a866b7 (patch)
tree197cef432df57b209d0b9a2eb247ea839a175f15
parent60f08c7c8775c0462b74bc65b41397be6eb24b6d (diff)
downloadspark-9c5612f4e197dec82a5eac9542896d6216a866b7.tar.gz
spark-9c5612f4e197dec82a5eac9542896d6216a866b7.tar.bz2
spark-9c5612f4e197dec82a5eac9542896d6216a866b7.zip
[MINOR] [SQL] Support mutable expression unit test with codegen projection
This is actually contains 3 minor issues: 1) Enable the unit test(codegen) for mutable expressions (FormatNumber, Regexp_Replace/Regexp_Extract) 2) Use the `PlatformDependent.copyMemory` instead of the `System.arrayCopy` Author: Cheng Hao <hao.cheng@intel.com> Closes #7566 from chenghao-intel/codegen_ut and squashes the following commits: 24f43ea [Cheng Hao] enable codegen for mutable expression & UTF8String performance
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala1
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala34
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java32
3 files changed, 41 insertions, 26 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 38b0fb37de..edfffbc01c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -777,7 +777,6 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
override def dataType: DataType = IntegerType
-
protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any =
leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 0f9c986f64..8e0ea76d15 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -57,19 +57,27 @@ class StringFunctionsSuite extends QueryTest {
}
test("string regex_replace / regex_extract") {
- val df = Seq(("100-200", "")).toDF("a", "b")
+ val df = Seq(
+ ("100-200", "(\\d+)-(\\d+)", "300"),
+ ("100-200", "(\\d+)-(\\d+)", "400"),
+ ("100-200", "(\\d+)", "400")).toDF("a", "b", "c")
checkAnswer(
df.select(
regexp_replace($"a", "(\\d+)", "num"),
regexp_extract($"a", "(\\d+)-(\\d+)", 1)),
- Row("num-num", "100"))
-
- checkAnswer(
- df.selectExpr(
- "regexp_replace(a, '(\\d+)', 'num')",
- "regexp_extract(a, '(\\d+)-(\\d+)', 2)"),
- Row("num-num", "200"))
+ Row("num-num", "100") :: Row("num-num", "100") :: Row("num-num", "100") :: Nil)
+
+ // for testing the mutable state of the expression in code gen.
+ // This is a hack way to enable the codegen, thus the codegen is enable by default,
+ // it will still use the interpretProjection if projection followed by a LocalRelation,
+ // hence we add a filter operator.
+ // See the optimizer rule `ConvertToLocalRelation`
+ checkAnswer(
+ df.filter("isnotnull(a)").selectExpr(
+ "regexp_replace(a, b, c)",
+ "regexp_extract(a, b, 1)"),
+ Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil)
}
test("string ascii function") {
@@ -290,5 +298,15 @@ class StringFunctionsSuite extends QueryTest {
df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable
Row("5.0000"))
}
+
+ // for testing the mutable state of the expression in code gen.
+ // This is a hack way to enable the codegen, thus the codegen is enable by default,
+ // it will still use the interpretProjection if projection follows by a LocalRelation,
+ // hence we add a filter operator.
+ // See the optimizer rule `ConvertToLocalRelation`
+ val df2 = Seq((5L, 4), (4L, 3), (3L, 2)).toDF("a", "b")
+ checkAnswer(
+ df2.filter("b>0").selectExpr("format_number(a, b)"),
+ Row("5.0000") :: Row("4.000") :: Row("3.00") :: Nil)
}
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 85381cf0ef..3e1cc67dbf 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -300,13 +300,13 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
}
public UTF8String reverse() {
- byte[] bytes = getBytes();
- byte[] result = new byte[bytes.length];
+ byte[] result = new byte[this.numBytes];
int i = 0; // position in byte
while (i < numBytes) {
int len = numBytesForFirstByte(getByte(i));
- System.arraycopy(bytes, i, result, result.length - i - len, len);
+ copyMemory(this.base, this.offset + i, result,
+ BYTE_ARRAY_OFFSET + result.length - i - len, len);
i += len;
}
@@ -316,11 +316,11 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
public UTF8String repeat(int times) {
if (times <=0) {
- return fromBytes(new byte[0]);
+ return EMPTY_UTF8;
}
byte[] newBytes = new byte[numBytes * times];
- System.arraycopy(getBytes(), 0, newBytes, 0, numBytes);
+ copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes);
int copied = 1;
while (copied < times) {
@@ -385,16 +385,15 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
UTF8String remain = pad.substring(0, spaces - padChars * count);
byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes];
- System.arraycopy(getBytes(), 0, data, 0, this.numBytes);
+ copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes);
int offset = this.numBytes;
int idx = 0;
- byte[] padBytes = pad.getBytes();
while (idx < count) {
- System.arraycopy(padBytes, 0, data, offset, pad.numBytes);
+ copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes);
++idx;
offset += pad.numBytes;
}
- System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes);
+ copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes);
return UTF8String.fromBytes(data);
}
@@ -421,15 +420,14 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
int offset = 0;
int idx = 0;
- byte[] padBytes = pad.getBytes();
while (idx < count) {
- System.arraycopy(padBytes, 0, data, offset, pad.numBytes);
+ copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes);
++idx;
offset += pad.numBytes;
}
- System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes);
+ copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes);
offset += remain.numBytes;
- System.arraycopy(getBytes(), 0, data, offset, numBytes());
+ copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes());
return UTF8String.fromBytes(data);
}
@@ -454,9 +452,9 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
int offset = 0;
for (int i = 0; i < inputs.length; i++) {
int len = inputs[i].numBytes;
- PlatformDependent.copyMemory(
+ copyMemory(
inputs[i].base, inputs[i].offset,
- result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
+ result, BYTE_ARRAY_OFFSET + offset,
len);
offset += len;
}
@@ -494,7 +492,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
for (int i = 0, j = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
int len = inputs[i].numBytes;
- PlatformDependent.copyMemory(
+ copyMemory(
inputs[i].base, inputs[i].offset,
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
len);
@@ -503,7 +501,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
j++;
// Add separator if this is not the last input.
if (j < numInputs) {
- PlatformDependent.copyMemory(
+ copyMemory(
separator.base, separator.offset,
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
separator.numBytes);