aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala12
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java10
-rw-r--r--unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java8
3 files changed, 25 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 6608036f01..e42be85367 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -593,17 +593,19 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2
* Returns a n spaces string.
*/
case class StringSpace(child: Expression)
- extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback {
+ extends UnaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(IntegerType)
override def nullSafeEval(s: Any): Any = {
- val length = s.asInstanceOf[Integer]
+ val length = s.asInstanceOf[Int]
+ UTF8String.blankString(if (length < 0) 0 else length)
+ }
- val spaces = new Array[Byte](if (length < 0) 0 else length)
- java.util.Arrays.fill(spaces, ' '.asInstanceOf[Byte])
- UTF8String.fromBytes(spaces)
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (length) =>
+ s"""${ev.primitive} = UTF8String.blankString(($length < 0) ? 0 : $length);""")
}
override def prettyName: String = "space"
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 3eecd657e6..819639f300 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -20,6 +20,7 @@ package org.apache.spark.unsafe.types;
import javax.annotation.Nonnull;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
+import java.util.Arrays;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -77,6 +78,15 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
}
}
+ /**
+ * Creates an UTF8String that contains `length` spaces.
+ */
+ public static UTF8String blankString(int length) {
+ byte[] spaces = new byte[length];
+ Arrays.fill(spaces, (byte) ' ');
+ return fromBytes(spaces);
+ }
+
protected UTF8String(Object base, long offset, int size) {
this.base = base;
this.offset = offset;
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 7d0c49e2fb..6a21c27461 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -286,4 +286,12 @@ public class UTF8StringSuite {
assertEquals(
UTF8String.fromString("世界千世").levenshteinDistance(UTF8String.fromString("千a世b")),4);
}
+
+ @Test
+ public void createBlankString() {
+ assertEquals(fromString(" "), blankString(1));
+ assertEquals(fromString(" "), blankString(2));
+ assertEquals(fromString(" "), blankString(3));
+ assertEquals(fromString(""), blankString(0));
+ }
}