aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala306
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala138
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala353
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala132
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java191
-rw-r--r--unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java94
7 files changed, 1202 insertions, 24 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 5c25181e1c..f62d79f8ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -147,12 +147,24 @@ object FunctionRegistry {
expression[Base64]("base64"),
expression[Encode]("encode"),
expression[Decode]("decode"),
+ expression[StringInstr]("instr"),
expression[Lower]("lcase"),
expression[Lower]("lower"),
expression[StringLength]("length"),
expression[Levenshtein]("levenshtein"),
+ expression[StringLocate]("locate"),
+ expression[StringLPad]("lpad"),
+ expression[StringTrimLeft]("ltrim"),
+ expression[StringFormat]("printf"),
+ expression[StringRPad]("rpad"),
+ expression[StringRepeat]("repeat"),
+ expression[StringReverse]("reverse"),
+ expression[StringTrimRight]("rtrim"),
+ expression[StringSpace]("space"),
+ expression[StringSplit]("split"),
expression[Substring]("substr"),
expression[Substring]("substring"),
+ expression[StringTrim]("trim"),
expression[UnBase64]("unbase64"),
expression[Upper]("ucase"),
expression[Unhex]("unhex"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 57f436485b..f64899c1ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import java.util.Locale
import java.util.regex.Pattern
import org.apache.commons.lang3.StringUtils
@@ -104,7 +105,7 @@ case class RLike(left: Expression, right: Expression)
override def toString: String = s"$left RLIKE $right"
}
-trait CaseConversionExpression extends ExpectsInputTypes {
+trait String2StringExpression extends ExpectsInputTypes {
self: UnaryExpression =>
def convert(v: UTF8String): UTF8String
@@ -119,7 +120,7 @@ trait CaseConversionExpression extends ExpectsInputTypes {
/**
* A function that converts the characters of a string to uppercase.
*/
-case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {
+case class Upper(child: Expression) extends UnaryExpression with String2StringExpression {
override def convert(v: UTF8String): UTF8String = v.toUpperCase
@@ -131,7 +132,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
/**
* A function that converts the characters of a string to lowercase.
*/
-case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {
+case class Lower(child: Expression) extends UnaryExpression with String2StringExpression {
override def convert(v: UTF8String): UTF8String = v.toLowerCase
@@ -188,6 +189,301 @@ case class EndsWith(left: Expression, right: Expression)
}
/**
+ * A function that trim the spaces from both ends for the specified string.
+ */
+case class StringTrim(child: Expression)
+ extends UnaryExpression with String2StringExpression {
+
+ def convert(v: UTF8String): UTF8String = v.trim()
+
+ override def prettyName: String = "trim"
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, c => s"($c).trim()")
+ }
+}
+
+/**
+ * A function that trim the spaces from left end for given string.
+ */
+case class StringTrimLeft(child: Expression)
+ extends UnaryExpression with String2StringExpression {
+
+ def convert(v: UTF8String): UTF8String = v.trimLeft()
+
+ override def prettyName: String = "ltrim"
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, c => s"($c).trimLeft()")
+ }
+}
+
+/**
+ * A function that trim the spaces from right end for given string.
+ */
+case class StringTrimRight(child: Expression)
+ extends UnaryExpression with String2StringExpression {
+
+ def convert(v: UTF8String): UTF8String = v.trimRight()
+
+ override def prettyName: String = "rtrim"
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, c => s"($c).trimRight()")
+ }
+}
+
+/**
+ * A function that returns the position of the first occurrence of substr in the given string.
+ * Returns null if either of the arguments are null and
+ * returns 0 if substr could not be found in str.
+ *
+ * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1.
+ */
+case class StringInstr(str: Expression, substr: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ override def left: Expression = str
+ override def right: Expression = substr
+ override def dataType: DataType = IntegerType
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
+
+ override def nullSafeEval(string: Any, sub: Any): Any = {
+ string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1
+ }
+
+ override def prettyName: String = "instr"
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, (l, r) =>
+ s"($l).indexOf($r, 0) + 1")
+ }
+}
+
+/**
+ * A function that returns the position of the first occurrence of substr
+ * in given string after position pos.
+ */
+case class StringLocate(substr: Expression, str: Expression, start: Expression)
+ extends Expression with ExpectsInputTypes {
+
+ def this(substr: Expression, str: Expression) = {
+ this(substr, str, Literal(0))
+ }
+
+ override def children: Seq[Expression] = substr :: str :: start :: Nil
+ override def foldable: Boolean = children.forall(_.foldable)
+ override def nullable: Boolean = substr.nullable || str.nullable
+ override def dataType: DataType = IntegerType
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
+
+ override def eval(input: InternalRow): Any = {
+ val s = start.eval(input)
+ if (s == null) {
+ // if the start position is null, we need to return 0, (conform to Hive)
+ 0
+ } else {
+ val r = substr.eval(input)
+ if (r == null) {
+ null
+ } else {
+ val l = str.eval(input)
+ if (l == null) {
+ null
+ } else {
+ l.asInstanceOf[UTF8String].indexOf(
+ r.asInstanceOf[UTF8String],
+ s.asInstanceOf[Int]) + 1
+ }
+ }
+ }
+ }
+
+ override def prettyName: String = "locate"
+}
+
+/**
+ * Returns str, left-padded with pad to a length of len.
+ */
+case class StringLPad(str: Expression, len: Expression, pad: Expression)
+ extends Expression with ExpectsInputTypes {
+
+ override def children: Seq[Expression] = str :: len :: pad :: Nil
+ override def foldable: Boolean = children.forall(_.foldable)
+ override def nullable: Boolean = children.exists(_.nullable)
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType)
+
+ override def eval(input: InternalRow): Any = {
+ val s = str.eval(input)
+ if (s == null) {
+ null
+ } else {
+ val l = len.eval(input)
+ if (l == null) {
+ null
+ } else {
+ val p = pad.eval(input)
+ if (p == null) {
+ null
+ } else {
+ val len = l.asInstanceOf[Int]
+ val str = s.asInstanceOf[UTF8String]
+ val pad = p.asInstanceOf[UTF8String]
+
+ str.lpad(len, pad)
+ }
+ }
+ }
+ }
+
+ override def prettyName: String = "lpad"
+}
+
+/**
+ * Returns str, right-padded with pad to a length of len.
+ */
+case class StringRPad(str: Expression, len: Expression, pad: Expression)
+ extends Expression with ExpectsInputTypes {
+
+ override def children: Seq[Expression] = str :: len :: pad :: Nil
+ override def foldable: Boolean = children.forall(_.foldable)
+ override def nullable: Boolean = children.exists(_.nullable)
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType)
+
+ override def eval(input: InternalRow): Any = {
+ val s = str.eval(input)
+ if (s == null) {
+ null
+ } else {
+ val l = len.eval(input)
+ if (l == null) {
+ null
+ } else {
+ val p = pad.eval(input)
+ if (p == null) {
+ null
+ } else {
+ val len = l.asInstanceOf[Int]
+ val str = s.asInstanceOf[UTF8String]
+ val pad = p.asInstanceOf[UTF8String]
+
+ str.rpad(len, pad)
+ }
+ }
+ }
+ }
+
+ override def prettyName: String = "rpad"
+}
+
+/**
+ * Returns the input formatted according do printf-style format strings
+ */
+case class StringFormat(children: Expression*) extends Expression {
+
+ require(children.length >=1, "printf() should take at least 1 argument")
+
+ override def foldable: Boolean = children.forall(_.foldable)
+ override def nullable: Boolean = children(0).nullable
+ override def dataType: DataType = StringType
+ private def format: Expression = children(0)
+ private def args: Seq[Expression] = children.tail
+
+ override def eval(input: InternalRow): Any = {
+ val pattern = format.eval(input)
+ if (pattern == null) {
+ null
+ } else {
+ val sb = new StringBuffer()
+ val formatter = new java.util.Formatter(sb, Locale.US)
+
+ val arglist = args.map(_.eval(input).asInstanceOf[AnyRef])
+ formatter.format(pattern.asInstanceOf[UTF8String].toString(), arglist: _*)
+
+ UTF8String.fromString(sb.toString)
+ }
+ }
+
+ override def prettyName: String = "printf"
+}
+
+/**
+ * Returns the string which repeat the given string value n times.
+ */
+case class StringRepeat(str: Expression, times: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ override def left: Expression = str
+ override def right: Expression = times
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType)
+
+ override def nullSafeEval(string: Any, n: Any): Any = {
+ string.asInstanceOf[UTF8String].repeat(n.asInstanceOf[Integer])
+ }
+
+ override def prettyName: String = "repeat"
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)")
+ }
+}
+
+/**
+ * Returns the reversed given string.
+ */
+case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression {
+ override def convert(v: UTF8String): UTF8String = v.reverse()
+
+ override def prettyName: String = "reverse"
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, c => s"($c).reverse()")
+ }
+}
+
+/**
+ * Returns a n spaces string.
+ */
+case class StringSpace(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[DataType] = Seq(IntegerType)
+
+ override def nullSafeEval(s: Any): Any = {
+ val length = s.asInstanceOf[Integer]
+
+ val spaces = new Array[Byte](if (length < 0) 0 else length)
+ java.util.Arrays.fill(spaces, ' '.asInstanceOf[Byte])
+ UTF8String.fromBytes(spaces)
+ }
+
+ override def prettyName: String = "space"
+}
+
+/**
+ * Splits str around pat (pattern is a regular expression).
+ */
+case class StringSplit(str: Expression, pattern: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ override def left: Expression = str
+ override def right: Expression = pattern
+ override def dataType: DataType = ArrayType(StringType)
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
+
+ override def nullSafeEval(string: Any, regex: Any): Any = {
+ val splits =
+ string.asInstanceOf[UTF8String].toString.split(regex.asInstanceOf[UTF8String].toString, -1)
+ splits.toSeq.map(UTF8String.fromString)
+ }
+
+ override def prettyName: String = "split"
+}
+
+/**
* A function that takes a substring of its first argument starting at a given position.
* Defined for String and Binary types.
*/
@@ -199,8 +495,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
}
override def foldable: Boolean = str.foldable && pos.foldable && len.foldable
-
- override def nullable: Boolean = str.nullable || pos.nullable || len.nullable
+ override def nullable: Boolean = str.nullable || pos.nullable || len.nullable
override def dataType: DataType = {
if (!resolved) {
@@ -373,4 +668,3 @@ case class Encode(value: Expression, charset: Expression)
}
}
-
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
index 69bef1c63e..b19f4ee37a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
@@ -288,4 +288,142 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Levenshtein(Literal("世界千世"), Literal("大a界b")), 4)
// scalastyle:on
}
+
+ test("TRIM/LTRIM/RTRIM") {
+ val s = 'a.string.at(0)
+ checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef "))
+ checkEvaluation(StringTrim(s), "abdef", create_row(" abdef "))
+
+ checkEvaluation(StringTrimLeft(Literal(" aa ")), "aa ", create_row(" abdef "))
+ checkEvaluation(StringTrimLeft(s), "abdef ", create_row(" abdef "))
+
+ checkEvaluation(StringTrimRight(Literal(" aa ")), " aa", create_row(" abdef "))
+ checkEvaluation(StringTrimRight(s), " abdef", create_row(" abdef "))
+
+ // scalastyle:off
+ // non ascii characters are not allowed in the source code, so we disable the scalastyle.
+ checkEvaluation(StringTrimRight(s), " 花花世界", create_row(" 花花世界 "))
+ checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 "))
+ checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 "))
+ // scalastyle:on
+ }
+
+ test("FORMAT") {
+ val f = 'f.string.at(0)
+ val d1 = 'd.int.at(1)
+ val s1 = 's.int.at(2)
+
+ val row1 = create_row("aa%d%s", 12, "cc")
+ val row2 = create_row(null, 12, "cc")
+ checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
+ checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
+ checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
+
+ checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
+ checkEvaluation(StringFormat(f, d1, s1), null, row2)
+ }
+
+ test("INSTR") {
+ val s1 = 'a.string.at(0)
+ val s2 = 'b.string.at(1)
+ val s3 = 'c.string.at(2)
+ val row1 = create_row("aaads", "aa", "zz")
+
+ checkEvaluation(StringInstr(Literal("aaads"), Literal("aa")), 1, row1)
+ checkEvaluation(StringInstr(Literal("aaads"), Literal("de")), 0, row1)
+ checkEvaluation(StringInstr(Literal.create(null, StringType), Literal("de")), null, row1)
+ checkEvaluation(StringInstr(Literal("aaads"), Literal.create(null, StringType)), null, row1)
+
+ checkEvaluation(StringInstr(s1, s2), 1, row1)
+ checkEvaluation(StringInstr(s1, s3), 0, row1)
+
+ // scalastyle:off
+ // non ascii characters are not allowed in the source code, so we disable the scalastyle.
+ checkEvaluation(StringInstr(s1, s2), 3, create_row("花花世界", "世界"))
+ checkEvaluation(StringInstr(s1, s2), 1, create_row("花花世界", "花"))
+ checkEvaluation(StringInstr(s1, s2), 0, create_row("花花世界", "小"))
+ // scalastyle:on
+ }
+
+ test("LOCATE") {
+ val s1 = 'a.string.at(0)
+ val s2 = 'b.string.at(1)
+ val s3 = 'c.string.at(2)
+ val s4 = 'd.int.at(3)
+ val row1 = create_row("aaads", "aa", "zz", 1)
+
+ checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1)
+ checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1)
+ checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 0, row1)
+ checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0, row1)
+ checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0, row1)
+
+ checkEvaluation(new StringLocate(s2, s1), 1, row1)
+ checkEvaluation(StringLocate(s2, s1, s4), 2, row1)
+ checkEvaluation(new StringLocate(s3, s1), 0, row1)
+ checkEvaluation(StringLocate(s3, s1, Literal.create(null, IntegerType)), 0, row1)
+ }
+
+ test("LPAD/RPAD") {
+ val s1 = 'a.string.at(0)
+ val s2 = 'b.int.at(1)
+ val s3 = 'c.string.at(2)
+ val row1 = create_row("hi", 5, "??")
+ val row2 = create_row("hi", 1, "?")
+ val row3 = create_row(null, 1, "?")
+
+ checkEvaluation(StringLPad(Literal("hi"), Literal(5), Literal("??")), "???hi", row1)
+ checkEvaluation(StringLPad(Literal("hi"), Literal(1), Literal("??")), "h", row1)
+ checkEvaluation(StringLPad(s1, s2, s3), "???hi", row1)
+ checkEvaluation(StringLPad(s1, s2, s3), "h", row2)
+ checkEvaluation(StringLPad(s1, s2, s3), null, row3)
+
+ checkEvaluation(StringRPad(Literal("hi"), Literal(5), Literal("??")), "hi???", row1)
+ checkEvaluation(StringRPad(Literal("hi"), Literal(1), Literal("??")), "h", row1)
+ checkEvaluation(StringRPad(s1, s2, s3), "hi???", row1)
+ checkEvaluation(StringRPad(s1, s2, s3), "h", row2)
+ checkEvaluation(StringRPad(s1, s2, s3), null, row3)
+ }
+
+ test("REPEAT") {
+ val s1 = 'a.string.at(0)
+ val s2 = 'b.int.at(1)
+ val row1 = create_row("hi", 2)
+ val row2 = create_row(null, 1)
+
+ checkEvaluation(StringRepeat(Literal("hi"), Literal(2)), "hihi", row1)
+ checkEvaluation(StringRepeat(Literal("hi"), Literal(-1)), "", row1)
+ checkEvaluation(StringRepeat(s1, s2), "hihi", row1)
+ checkEvaluation(StringRepeat(s1, s2), null, row2)
+ }
+
+ test("REVERSE") {
+ val s = 'a.string.at(0)
+ val row1 = create_row("abccc")
+ checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1)
+ checkEvaluation(StringReverse(s), "cccba", row1)
+ }
+
+ test("SPACE") {
+ val s1 = 'b.int.at(0)
+ val row1 = create_row(2)
+ val row2 = create_row(null)
+
+ checkEvaluation(StringSpace(Literal(2)), " ", row1)
+ checkEvaluation(StringSpace(Literal(-1)), "", row1)
+ checkEvaluation(StringSpace(Literal(0)), "", row1)
+ checkEvaluation(StringSpace(s1), " ", row1)
+ checkEvaluation(StringSpace(s1), null, row2)
+ }
+
+ test("SPLIT") {
+ val s1 = 'a.string.at(0)
+ val s2 = 'b.string.at(1)
+ val row1 = create_row("aa2bb3cc", "[1-9]+")
+
+ checkEvaluation(
+ StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1)
+ checkEvaluation(
+ StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 4da9ffc495..08bf37a5c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1627,6 +1627,179 @@ object functions {
def ascii(columnName: String): Column = ascii(Column(columnName))
/**
+ * Trim the spaces from both ends for the specified string value.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def trim(e: Column): Column = StringTrim(e.expr)
+
+ /**
+ * Trim the spaces from both ends for the specified column.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def trim(columnName: String): Column = trim(Column(columnName))
+
+ /**
+ * Trim the spaces from left end for the specified string value.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def ltrim(e: Column): Column = StringTrimLeft(e.expr)
+
+ /**
+ * Trim the spaces from left end for the specified column.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def ltrim(columnName: String): Column = ltrim(Column(columnName))
+
+ /**
+ * Trim the spaces from right end for the specified string value.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def rtrim(e: Column): Column = StringTrimRight(e.expr)
+
+ /**
+ * Trim the spaces from right end for the specified column.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def rtrim(columnName: String): Column = rtrim(Column(columnName))
+
+ /**
+ * Format strings in printf-style.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ @scala.annotation.varargs
+ def formatString(format: Column, arguments: Column*): Column = {
+ StringFormat((format +: arguments).map(_.expr): _*)
+ }
+
+ /**
+ * Format strings in printf-style.
+ * NOTE: `format` is the string value of the formatter, not column name.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ @scala.annotation.varargs
+ def formatString(format: String, arguNames: String*): Column = {
+ StringFormat(lit(format).expr +: arguNames.map(Column(_).expr): _*)
+ }
+
+ /**
+ * Locate the position of the first occurrence of substr value in the given string.
+ * Returns null if either of the arguments are null.
+ *
+ * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
+ * could not be found in str.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def instr(substr: String, sub: String): Column = instr(Column(substr), Column(sub))
+
+ /**
+ * Locate the position of the first occurrence of substr column in the given string.
+ * Returns null if either of the arguments are null.
+ *
+ * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
+ * could not be found in str.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def instr(substr: Column, sub: Column): Column = StringInstr(substr.expr, sub.expr)
+
+ /**
+ * Locate the position of the first occurrence of substr.
+ *
+ * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
+ * could not be found in str.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def locate(substr: String, str: String): Column = {
+ locate(Column(substr), Column(str))
+ }
+
+ /**
+ * Locate the position of the first occurrence of substr.
+ *
+ * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
+ * could not be found in str.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def locate(substr: Column, str: Column): Column = {
+ new StringLocate(substr.expr, str.expr)
+ }
+
+ /**
+ * Locate the position of the first occurrence of substr in a given string after position pos.
+ *
+ * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
+ * could not be found in str.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def locate(substr: String, str: String, pos: String): Column = {
+ locate(Column(substr), Column(str), Column(pos))
+ }
+
+ /**
+ * Locate the position of the first occurrence of substr in a given string after position pos.
+ *
+ * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
+ * could not be found in str.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def locate(substr: Column, str: Column, pos: Column): Column = {
+ StringLocate(substr.expr, str.expr, pos.expr)
+ }
+
+ /**
+ * Locate the position of the first occurrence of substr in a given string after position pos.
+ *
+ * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
+ * could not be found in str.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def locate(substr: Column, str: Column, pos: Int): Column = {
+ StringLocate(substr.expr, str.expr, lit(pos).expr)
+ }
+
+ /**
+ * Locate the position of the first occurrence of substr in a given string after position pos.
+ *
+ * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
+ * could not be found in str.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def locate(substr: String, str: String, pos: Int): Column = {
+ locate(Column(substr), Column(str), lit(pos))
+ }
+
+ /**
* Computes the specified value from binary to a base64 string.
*
* @group string_funcs
@@ -1659,6 +1832,46 @@ object functions {
def unbase64(columnName: String): Column = unbase64(Column(columnName))
/**
+ * Left-padded with pad to a length of len.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def lpad(str: String, len: String, pad: String): Column = {
+ lpad(Column(str), Column(len), Column(pad))
+ }
+
+ /**
+ * Left-padded with pad to a length of len.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def lpad(str: Column, len: Column, pad: Column): Column = {
+ StringLPad(str.expr, len.expr, pad.expr)
+ }
+
+ /**
+ * Left-padded with pad to a length of len.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def lpad(str: Column, len: Int, pad: Column): Column = {
+ StringLPad(str.expr, lit(len).expr, pad.expr)
+ }
+
+ /**
+ * Left-padded with pad to a length of len.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def lpad(str: String, len: Int, pad: String): Column = {
+ lpad(Column(str), len, Column(pad))
+ }
+
+ /**
* Computes the first argument into a binary from a string using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
@@ -1702,6 +1915,146 @@ object functions {
def decode(columnName: String, charset: String): Column =
decode(Column(columnName), charset)
+ /**
+ * Right-padded with pad to a length of len.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def rpad(str: String, len: String, pad: String): Column = {
+ rpad(Column(str), Column(len), Column(pad))
+ }
+
+ /**
+ * Right-padded with pad to a length of len.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def rpad(str: Column, len: Column, pad: Column): Column = {
+ StringRPad(str.expr, len.expr, pad.expr)
+ }
+
+ /**
+ * Right-padded with pad to a length of len.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def rpad(str: String, len: Int, pad: String): Column = {
+ rpad(Column(str), len, Column(pad))
+ }
+
+ /**
+ * Right-padded with pad to a length of len.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def rpad(str: Column, len: Int, pad: Column): Column = {
+ StringRPad(str.expr, lit(len).expr, pad.expr)
+ }
+
+ /**
+ * Repeat the string value of the specified column n times.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def repeat(strColumn: String, timesColumn: String): Column = {
+ repeat(Column(strColumn), Column(timesColumn))
+ }
+
+ /**
+ * Repeat the string expression value n times.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def repeat(str: Column, times: Column): Column = {
+ StringRepeat(str.expr, times.expr)
+ }
+
+ /**
+ * Repeat the string value of the specified column n times.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def repeat(strColumn: String, times: Int): Column = {
+ repeat(Column(strColumn), times)
+ }
+
+ /**
+ * Repeat the string expression value n times.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def repeat(str: Column, times: Int): Column = {
+ StringRepeat(str.expr, lit(times).expr)
+ }
+
+ /**
+ * Splits str around pattern (pattern is a regular expression).
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def split(strColumnName: String, pattern: String): Column = {
+ split(Column(strColumnName), pattern)
+ }
+
+ /**
+ * Splits str around pattern (pattern is a regular expression).
+ * NOTE: pattern is a string represent the regular expression.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def split(str: Column, pattern: String): Column = {
+ StringSplit(str.expr, lit(pattern).expr)
+ }
+
+ /**
+ * Reversed the string for the specified column.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def reverse(str: String): Column = {
+ reverse(Column(str))
+ }
+
+ /**
+ * Reversed the string for the specified value.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def reverse(str: Column): Column = {
+ StringReverse(str.expr)
+ }
+
+ /**
+ * Make a n spaces of string.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def space(n: String): Column = {
+ space(Column(n))
+ }
+
+ /**
+ * Make a n spaces of string.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def space(n: Column): Column = {
+ StringSpace(n.expr)
+ }
//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index afba28515e..173280375c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -209,21 +209,14 @@ class DataFrameFunctionsSuite extends QueryTest {
}
test("string length function") {
+ val df = Seq(("abc", "")).toDF("a", "b")
checkAnswer(
- nullStrings.select(strlen($"s"), strlen("s")),
- nullStrings.collect().toSeq.map { r =>
- val v = r.getString(1)
- val l = if (v == null) null else v.length
- Row(l, l)
- })
+ df.select(strlen($"a"), strlen("b")),
+ Row(3, 0))
checkAnswer(
- nullStrings.selectExpr("length(s)"),
- nullStrings.collect().toSeq.map { r =>
- val v = r.getString(1)
- val l = if (v == null) null else v.length
- Row(l)
- })
+ df.selectExpr("length(a)", "length(b)"),
+ Row(3, 0))
}
test("Levenshtein distance") {
@@ -273,4 +266,119 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(bytes, "大千世界"))
// scalastyle:on
}
+
+ test("string trim functions") {
+ val df = Seq((" example ", "")).toDF("a", "b")
+
+ checkAnswer(
+ df.select(ltrim($"a"), rtrim($"a"), trim($"a")),
+ Row("example ", " example", "example"))
+
+ checkAnswer(
+ df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"),
+ Row("example ", " example", "example"))
+ }
+
+ test("string formatString function") {
+ val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")
+
+ checkAnswer(
+ df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
+ Row("aa123cc", "aa123cc"))
+
+ checkAnswer(
+ df.selectExpr("printf(a, b, c)"),
+ Row("aa123cc"))
+ }
+
+ test("string instr function") {
+ val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c")
+
+ checkAnswer(
+ df.select(instr($"a", $"b"), instr("a", "b")),
+ Row(1, 1))
+
+ checkAnswer(
+ df.selectExpr("instr(a, b)"),
+ Row(1))
+ }
+
+ test("string locate function") {
+ val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
+
+ checkAnswer(
+ df.select(
+ locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1),
+ locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")),
+ Row(1, 1, 2, 2, 2, 2))
+
+ checkAnswer(
+ df.selectExpr("locate(b, a)", "locate(b, a, d)"),
+ Row(1, 2))
+ }
+
+ test("string padding functions") {
+ val df = Seq(("hi", 5, "??")).toDF("a", "b", "c")
+
+ checkAnswer(
+ df.select(
+ lpad($"a", $"b", $"c"), rpad("a", "b", "c"),
+ lpad($"a", 1, $"c"), rpad("a", 1, "c")),
+ Row("???hi", "hi???", "h", "h"))
+
+ checkAnswer(
+ df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"),
+ Row("???hi", "hi???", "h", "h"))
+ }
+
+ test("string repeat function") {
+ val df = Seq(("hi", 2)).toDF("a", "b")
+
+ checkAnswer(
+ df.select(
+ repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")),
+ Row("hihi", "hihi", "hihi", "hihi"))
+
+ checkAnswer(
+ df.selectExpr("repeat(a, 2)", "repeat(a, b)"),
+ Row("hihi", "hihi"))
+ }
+
+ test("string reverse function") {
+ val df = Seq(("hi", "hhhi")).toDF("a", "b")
+
+ checkAnswer(
+ df.select(reverse($"a"), reverse("b")),
+ Row("ih", "ihhh"))
+
+ checkAnswer(
+ df.selectExpr("reverse(b)"),
+ Row("ihhh"))
+ }
+
+ test("string space function") {
+ val df = Seq((2, 3)).toDF("a", "b")
+
+ checkAnswer(
+ df.select(space($"a"), space("b")),
+ Row(" ", " "))
+
+ checkAnswer(
+ df.selectExpr("space(b)"),
+ Row(" "))
+ }
+
+ test("string split function") {
+ val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b")
+
+ checkAnswer(
+ df.select(
+ split($"a", "[1-9]+"),
+ split("a", "[1-9]+")),
+ Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc")))
+
+ checkAnswer(
+ df.selectExpr("split(a, '[1-9]+')"),
+ Row(Seq("aa", "bb", "cc")))
+ }
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 847d80ad58..60d050b0a0 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -25,6 +25,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods;
import static org.apache.spark.unsafe.PlatformDependent.*;
+
/**
* A UTF-8 String for internal Spark use.
* <p>
@@ -204,6 +205,196 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
return fromString(toString().toLowerCase());
}
+ /**
+ * Copy the bytes from the current UTF8String, and make a new UTF8String.
+ * @param start the start position of the current UTF8String in bytes.
+ * @param end the end position of the current UTF8String in bytes.
+ * @return a new UTF8String in the position of [start, end] of current UTF8String bytes.
+ */
+ private UTF8String copyUTF8String(int start, int end) {
+ int len = end - start + 1;
+ byte[] newBytes = new byte[len];
+ copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len);
+ return UTF8String.fromBytes(newBytes);
+ }
+
+ public UTF8String trim() {
+ int s = 0;
+ int e = this.numBytes - 1;
+ // skip all of the space (0x20) in the left side
+ while (s < this.numBytes && getByte(s) == 0x20) s++;
+ // skip all of the space (0x20) in the right side
+ while (e >= 0 && getByte(e) == 0x20) e--;
+
+ if (s > e) {
+ // empty string
+ return UTF8String.fromBytes(new byte[0]);
+ } else {
+ return copyUTF8String(s, e);
+ }
+ }
+
+ public UTF8String trimLeft() {
+ int s = 0;
+ // skip all of the space (0x20) in the left side
+ while (s < this.numBytes && getByte(s) == 0x20) s++;
+ if (s == this.numBytes) {
+ // empty string
+ return UTF8String.fromBytes(new byte[0]);
+ } else {
+ return copyUTF8String(s, this.numBytes - 1);
+ }
+ }
+
+ public UTF8String trimRight() {
+ int e = numBytes - 1;
+ // skip all of the space (0x20) in the right side
+ while (e >= 0 && getByte(e) == 0x20) e--;
+
+ if (e < 0) {
+ // empty string
+ return UTF8String.fromBytes(new byte[0]);
+ } else {
+ return copyUTF8String(0, e);
+ }
+ }
+
+ public UTF8String reverse() {
+ byte[] bytes = getBytes();
+ byte[] result = new byte[bytes.length];
+
+ int i = 0; // position in byte
+ while (i < numBytes) {
+ int len = numBytesForFirstByte(getByte(i));
+ System.arraycopy(bytes, i, result, result.length - i - len, len);
+
+ i += len;
+ }
+
+ return UTF8String.fromBytes(result);
+ }
+
+ public UTF8String repeat(int times) {
+ if (times <=0) {
+ return fromBytes(new byte[0]);
+ }
+
+ byte[] newBytes = new byte[numBytes * times];
+ System.arraycopy(getBytes(), 0, newBytes, 0, numBytes);
+
+ int copied = 1;
+ while (copied < times) {
+ int toCopy = Math.min(copied, times - copied);
+ System.arraycopy(newBytes, 0, newBytes, copied * numBytes, numBytes * toCopy);
+ copied += toCopy;
+ }
+
+ return UTF8String.fromBytes(newBytes);
+ }
+
+ /**
+ * Returns the position of the first occurrence of substr in
+ * current string from the specified position (0-based index).
+ *
+ * @param v the string to be searched
+ * @param start the start position of the current string for searching
+ * @return the position of the first occurrence of substr, if not found, -1 returned.
+ */
+ public int indexOf(UTF8String v, int start) {
+ if (v.numBytes() == 0) {
+ return 0;
+ }
+
+ // locate to the start position.
+ int i = 0; // position in byte
+ int c = 0; // position in character
+ while (i < numBytes && c < start) {
+ i += numBytesForFirstByte(getByte(i));
+ c += 1;
+ }
+
+ do {
+ if (i + v.numBytes > numBytes) {
+ return -1;
+ }
+ if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) {
+ return c;
+ }
+ i += numBytesForFirstByte(getByte(i));
+ c += 1;
+ } while(i < numBytes);
+
+ return -1;
+ }
+
+ /**
+ * Returns str, right-padded with pad to a length of len
+ * For example:
+ * ('hi', 5, '??') => 'hi???'
+ * ('hi', 1, '??') => 'h'
+ */
+ public UTF8String rpad(int len, UTF8String pad) {
+ int spaces = len - this.numChars(); // number of char need to pad
+ if (spaces <= 0) {
+ // no padding at all, return the substring of the current string
+ return substring(0, len);
+ } else {
+ int padChars = pad.numChars();
+ int count = spaces / padChars; // how many padding string needed
+ // the partial string of the padding
+ UTF8String remain = pad.substring(0, spaces - padChars * count);
+
+ byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes];
+ System.arraycopy(getBytes(), 0, data, 0, this.numBytes);
+ int offset = this.numBytes;
+ int idx = 0;
+ byte[] padBytes = pad.getBytes();
+ while (idx < count) {
+ System.arraycopy(padBytes, 0, data, offset, pad.numBytes);
+ ++idx;
+ offset += pad.numBytes;
+ }
+ System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes);
+
+ return UTF8String.fromBytes(data);
+ }
+ }
+
+ /**
+ * Returns str, left-padded with pad to a length of len.
+ * For example:
+ * ('hi', 5, '??') => '???hi'
+ * ('hi', 1, '??') => 'h'
+ */
+ public UTF8String lpad(int len, UTF8String pad) {
+ int spaces = len - this.numChars(); // number of char need to pad
+ if (spaces <= 0) {
+ // no padding at all, return the substring of the current string
+ return substring(0, len);
+ } else {
+ int padChars = pad.numChars();
+ int count = spaces / padChars; // how many padding string needed
+ // the partial string of the padding
+ UTF8String remain = pad.substring(0, spaces - padChars * count);
+
+ byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes];
+
+ int offset = 0;
+ int idx = 0;
+ byte[] padBytes = pad.getBytes();
+ while (idx < count) {
+ System.arraycopy(padBytes, 0, data, offset, pad.numBytes);
+ ++idx;
+ offset += pad.numBytes;
+ }
+ System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes);
+ offset += remain.numBytes;
+ System.arraycopy(getBytes(), 0, data, offset, numBytes());
+
+ return UTF8String.fromBytes(data);
+ }
+ }
+
@Override
public String toString() {
try {
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index fb463ba17f..694bdc29f3 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -121,12 +121,94 @@ public class UTF8StringSuite {
@Test
public void substring() {
- assertEquals(fromString("hello").substring(0, 0), fromString(""));
- assertEquals(fromString("hello").substring(1, 3), fromString("el"));
- assertEquals(fromString("数据砖头").substring(0, 1), fromString("数"));
- assertEquals(fromString("数据砖头").substring(1, 3), fromString("据砖"));
- assertEquals(fromString("数据砖头").substring(3, 5), fromString("头"));
- assertEquals(fromString("ߵ梷").substring(0, 2), fromString("ߵ梷"));
+ assertEquals(fromString(""), fromString("hello").substring(0, 0));
+ assertEquals(fromString("el"), fromString("hello").substring(1, 3));
+ assertEquals(fromString("数"), fromString("数据砖头").substring(0, 1));
+ assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3));
+ assertEquals(fromString("头"), fromString("数据砖头").substring(3, 5));
+ assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0, 2));
+ }
+
+ @Test
+ public void trims() {
+ assertEquals(fromString("hello"), fromString(" hello ").trim());
+ assertEquals(fromString("hello "), fromString(" hello ").trimLeft());
+ assertEquals(fromString(" hello"), fromString(" hello ").trimRight());
+
+ assertEquals(fromString(""), fromString(" ").trim());
+ assertEquals(fromString(""), fromString(" ").trimLeft());
+ assertEquals(fromString(""), fromString(" ").trimRight());
+
+ assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim());
+ assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft());
+ assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight());
+
+ assertEquals(fromString("数据砖头"), fromString("数据砖头").trim());
+ assertEquals(fromString("数据砖头"), fromString("数据砖头").trimLeft());
+ assertEquals(fromString("数据砖头"), fromString("数据砖头").trimRight());
+ }
+
+ @Test
+ public void indexOf() {
+ assertEquals(0, fromString("").indexOf(fromString(""), 0));
+ assertEquals(-1, fromString("").indexOf(fromString("l"), 0));
+ assertEquals(0, fromString("hello").indexOf(fromString(""), 0));
+ assertEquals(2, fromString("hello").indexOf(fromString("l"), 0));
+ assertEquals(3, fromString("hello").indexOf(fromString("l"), 3));
+ assertEquals(-1, fromString("hello").indexOf(fromString("a"), 0));
+ assertEquals(2, fromString("hello").indexOf(fromString("ll"), 0));
+ assertEquals(-1, fromString("hello").indexOf(fromString("ll"), 4));
+ assertEquals(1, fromString("数据砖头").indexOf(fromString("据砖"), 0));
+ assertEquals(-1, fromString("数据砖头").indexOf(fromString("数"), 3));
+ assertEquals(0, fromString("数据砖头").indexOf(fromString("数"), 0));
+ assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0));
+ }
+
+ @Test
+ public void reverse() {
+ assertEquals(fromString("olleh"), fromString("hello").reverse());
+ assertEquals(fromString(""), fromString("").reverse());
+ assertEquals(fromString("者行孙"), fromString("孙行者").reverse());
+ assertEquals(fromString("者行孙 olleh"), fromString("hello 孙行者").reverse());
+ }
+
+ @Test
+ public void repeat() {
+ assertEquals(fromString("数d数d数d数d数d"), fromString("数d").repeat(5));
+ assertEquals(fromString("数d"), fromString("数d").repeat(1));
+ assertEquals(fromString(""), fromString("数d").repeat(-1));
+ }
+
+ @Test
+ public void pad() {
+ assertEquals(fromString("hel"), fromString("hello").lpad(3, fromString("????")));
+ assertEquals(fromString("hello"), fromString("hello").lpad(5, fromString("????")));
+ assertEquals(fromString("?hello"), fromString("hello").lpad(6, fromString("????")));
+ assertEquals(fromString("???????hello"), fromString("hello").lpad(12, fromString("????")));
+ assertEquals(fromString("?????hello"), fromString("hello").lpad(10, fromString("?????")));
+ assertEquals(fromString("???????"), fromString("").lpad(7, fromString("?????")));
+
+ assertEquals(fromString("hel"), fromString("hello").rpad(3, fromString("????")));
+ assertEquals(fromString("hello"), fromString("hello").rpad(5, fromString("????")));
+ assertEquals(fromString("hello?"), fromString("hello").rpad(6, fromString("????")));
+ assertEquals(fromString("hello???????"), fromString("hello").rpad(12, fromString("????")));
+ assertEquals(fromString("hello?????"), fromString("hello").rpad(10, fromString("?????")));
+ assertEquals(fromString("???????"), fromString("").rpad(7, fromString("?????")));
+
+
+ assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, fromString("????")));
+ assertEquals(fromString("?数据砖头"), fromString("数据砖头").lpad(5, fromString("????")));
+ assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????")));
+ assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者")));
+ assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7, fromString("孙行者")));
+ assertEquals(fromString("孙行者孙行者孙行数据砖头"), fromString("数据砖头").lpad(12, fromString("孙行者")));
+
+ assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????")));
+ assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????")));
+ assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????")));
+ assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者")));
+ assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7, fromString("孙行者")));
+ assertEquals(fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12, fromString("孙行者")));
}
@Test