diff options
author | Cheng Hao <hao.cheng@intel.com> | 2015-07-09 11:11:34 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-07-09 11:11:34 -0700 |
commit | 0b0b9ceaf73de472198c9804fb7ae61fa2a2e097 (patch) | |
tree | e6e86d9c5921fdd26a1393beffa3e1b7bc6f2504 | |
parent | 0cd84c86cac68600a74d84e50ad40c0c8b84822a (diff) | |
download | spark-0b0b9ceaf73de472198c9804fb7ae61fa2a2e097.tar.gz spark-0b0b9ceaf73de472198c9804fb7ae61fa2a2e097.tar.bz2 spark-0b0b9ceaf73de472198c9804fb7ae61fa2a2e097.zip |
[SPARK-8247] [SPARK-8249] [SPARK-8252] [SPARK-8254] [SPARK-8257] [SPARK-8258] [SPARK-8259] [SPARK-8261] [SPARK-8262] [SPARK-8253] [SPARK-8260] [SPARK-8267] [SQL] Add String Expressions
Author: Cheng Hao <hao.cheng@intel.com>
Closes #6762 from chenghao-intel/str_funcs and squashes the following commits:
b09a909 [Cheng Hao] update the code as feedback
7ebbf4c [Cheng Hao] Add more string expressions
7 files changed, 1202 insertions, 24 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5c25181e1c..f62d79f8ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -147,12 +147,24 @@ object FunctionRegistry { expression[Base64]("base64"), expression[Encode]("encode"), expression[Decode]("decode"), + expression[StringInstr]("instr"), expression[Lower]("lcase"), expression[Lower]("lower"), expression[StringLength]("length"), expression[Levenshtein]("levenshtein"), + expression[StringLocate]("locate"), + expression[StringLPad]("lpad"), + expression[StringTrimLeft]("ltrim"), + expression[StringFormat]("printf"), + expression[StringRPad]("rpad"), + expression[StringRepeat]("repeat"), + expression[StringReverse]("reverse"), + expression[StringTrimRight]("rtrim"), + expression[StringSpace]("space"), + expression[StringSplit]("split"), expression[Substring]("substr"), expression[Substring]("substring"), + expression[StringTrim]("trim"), expression[UnBase64]("unbase64"), expression[Upper]("ucase"), expression[Unhex]("unhex"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 57f436485b..f64899c1ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale import java.util.regex.Pattern import org.apache.commons.lang3.StringUtils @@ -104,7 +105,7 @@ case class RLike(left: Expression, right: Expression) override def toString: String = s"$left RLIKE $right" } -trait CaseConversionExpression extends ExpectsInputTypes { +trait String2StringExpression extends ExpectsInputTypes { self: UnaryExpression => def convert(v: UTF8String): UTF8String @@ -119,7 +120,7 @@ trait CaseConversionExpression extends ExpectsInputTypes { /** * A function that converts the characters of a string to uppercase. */ -case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { +case class Upper(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toUpperCase @@ -131,7 +132,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE /** * A function that converts the characters of a string to lowercase. */ -case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { +case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toLowerCase @@ -188,6 +189,301 @@ case class EndsWith(left: Expression, right: Expression) } /** + * A function that trim the spaces from both ends for the specified string. + */ +case class StringTrim(child: Expression) + extends UnaryExpression with String2StringExpression { + + def convert(v: UTF8String): UTF8String = v.trim() + + override def prettyName: String = "trim" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).trim()") + } +} + +/** + * A function that trim the spaces from left end for given string. + */ +case class StringTrimLeft(child: Expression) + extends UnaryExpression with String2StringExpression { + + def convert(v: UTF8String): UTF8String = v.trimLeft() + + override def prettyName: String = "ltrim" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).trimLeft()") + } +} + +/** + * A function that trim the spaces from right end for given string. + */ +case class StringTrimRight(child: Expression) + extends UnaryExpression with String2StringExpression { + + def convert(v: UTF8String): UTF8String = v.trimRight() + + override def prettyName: String = "rtrim" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).trimRight()") + } +} + +/** + * A function that returns the position of the first occurrence of substr in the given string. + * Returns null if either of the arguments are null and + * returns 0 if substr could not be found in str. + * + * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. + */ +case class StringInstr(str: Expression, substr: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = str + override def right: Expression = substr + override def dataType: DataType = IntegerType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + override def nullSafeEval(string: Any, sub: Any): Any = { + string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1 + } + + override def prettyName: String = "instr" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (l, r) => + s"($l).indexOf($r, 0) + 1") + } +} + +/** + * A function that returns the position of the first occurrence of substr + * in given string after position pos. + */ +case class StringLocate(substr: Expression, str: Expression, start: Expression) + extends Expression with ExpectsInputTypes { + + def this(substr: Expression, str: Expression) = { + this(substr, str, Literal(0)) + } + + override def children: Seq[Expression] = substr :: str :: start :: Nil + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = substr.nullable || str.nullable + override def dataType: DataType = IntegerType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + + override def eval(input: InternalRow): Any = { + val s = start.eval(input) + if (s == null) { + // if the start position is null, we need to return 0, (conform to Hive) + 0 + } else { + val r = substr.eval(input) + if (r == null) { + null + } else { + val l = str.eval(input) + if (l == null) { + null + } else { + l.asInstanceOf[UTF8String].indexOf( + r.asInstanceOf[UTF8String], + s.asInstanceOf[Int]) + 1 + } + } + } + } + + override def prettyName: String = "locate" +} + +/** + * Returns str, left-padded with pad to a length of len. + */ +case class StringLPad(str: Expression, len: Expression, pad: Expression) + extends Expression with ExpectsInputTypes { + + override def children: Seq[Expression] = str :: len :: pad :: Nil + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = children.exists(_.nullable) + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) + + override def eval(input: InternalRow): Any = { + val s = str.eval(input) + if (s == null) { + null + } else { + val l = len.eval(input) + if (l == null) { + null + } else { + val p = pad.eval(input) + if (p == null) { + null + } else { + val len = l.asInstanceOf[Int] + val str = s.asInstanceOf[UTF8String] + val pad = p.asInstanceOf[UTF8String] + + str.lpad(len, pad) + } + } + } + } + + override def prettyName: String = "lpad" +} + +/** + * Returns str, right-padded with pad to a length of len. + */ +case class StringRPad(str: Expression, len: Expression, pad: Expression) + extends Expression with ExpectsInputTypes { + + override def children: Seq[Expression] = str :: len :: pad :: Nil + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = children.exists(_.nullable) + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) + + override def eval(input: InternalRow): Any = { + val s = str.eval(input) + if (s == null) { + null + } else { + val l = len.eval(input) + if (l == null) { + null + } else { + val p = pad.eval(input) + if (p == null) { + null + } else { + val len = l.asInstanceOf[Int] + val str = s.asInstanceOf[UTF8String] + val pad = p.asInstanceOf[UTF8String] + + str.rpad(len, pad) + } + } + } + } + + override def prettyName: String = "rpad" +} + +/** + * Returns the input formatted according do printf-style format strings + */ +case class StringFormat(children: Expression*) extends Expression { + + require(children.length >=1, "printf() should take at least 1 argument") + + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = children(0).nullable + override def dataType: DataType = StringType + private def format: Expression = children(0) + private def args: Seq[Expression] = children.tail + + override def eval(input: InternalRow): Any = { + val pattern = format.eval(input) + if (pattern == null) { + null + } else { + val sb = new StringBuffer() + val formatter = new java.util.Formatter(sb, Locale.US) + + val arglist = args.map(_.eval(input).asInstanceOf[AnyRef]) + formatter.format(pattern.asInstanceOf[UTF8String].toString(), arglist: _*) + + UTF8String.fromString(sb.toString) + } + } + + override def prettyName: String = "printf" +} + +/** + * Returns the string which repeat the given string value n times. + */ +case class StringRepeat(str: Expression, times: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = str + override def right: Expression = times + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType) + + override def nullSafeEval(string: Any, n: Any): Any = { + string.asInstanceOf[UTF8String].repeat(n.asInstanceOf[Integer]) + } + + override def prettyName: String = "repeat" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)") + } +} + +/** + * Returns the reversed given string. + */ +case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { + override def convert(v: UTF8String): UTF8String = v.reverse() + + override def prettyName: String = "reverse" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).reverse()") + } +} + +/** + * Returns a n spaces string. + */ +case class StringSpace(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(IntegerType) + + override def nullSafeEval(s: Any): Any = { + val length = s.asInstanceOf[Integer] + + val spaces = new Array[Byte](if (length < 0) 0 else length) + java.util.Arrays.fill(spaces, ' '.asInstanceOf[Byte]) + UTF8String.fromBytes(spaces) + } + + override def prettyName: String = "space" +} + +/** + * Splits str around pat (pattern is a regular expression). + */ +case class StringSplit(str: Expression, pattern: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = str + override def right: Expression = pattern + override def dataType: DataType = ArrayType(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + override def nullSafeEval(string: Any, regex: Any): Any = { + val splits = + string.asInstanceOf[UTF8String].toString.split(regex.asInstanceOf[UTF8String].toString, -1) + splits.toSeq.map(UTF8String.fromString) + } + + override def prettyName: String = "split" +} + +/** * A function that takes a substring of its first argument starting at a given position. * Defined for String and Binary types. */ @@ -199,8 +495,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } override def foldable: Boolean = str.foldable && pos.foldable && len.foldable - - override def nullable: Boolean = str.nullable || pos.nullable || len.nullable + override def nullable: Boolean = str.nullable || pos.nullable || len.nullable override def dataType: DataType = { if (!resolved) { @@ -373,4 +668,3 @@ case class Encode(value: Expression, charset: Expression) } } - diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index 69bef1c63e..b19f4ee37a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -288,4 +288,142 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Levenshtein(Literal("世界千世"), Literal("大a界b")), 4) // scalastyle:on } + + test("TRIM/LTRIM/RTRIM") { + val s = 'a.string.at(0) + checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef ")) + checkEvaluation(StringTrim(s), "abdef", create_row(" abdef ")) + + checkEvaluation(StringTrimLeft(Literal(" aa ")), "aa ", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(s), "abdef ", create_row(" abdef ")) + + checkEvaluation(StringTrimRight(Literal(" aa ")), " aa", create_row(" abdef ")) + checkEvaluation(StringTrimRight(s), " abdef", create_row(" abdef ")) + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringTrimRight(s), " 花花世界", create_row(" 花花世界 ")) + checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 ")) + checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 ")) + // scalastyle:on + } + + test("FORMAT") { + val f = 'f.string.at(0) + val d1 = 'd.int.at(1) + val s1 = 's.int.at(2) + + val row1 = create_row("aa%d%s", 12, "cc") + val row2 = create_row(null, 12, "cc") + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + + checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1) + checkEvaluation(StringFormat(f, d1, s1), null, row2) + } + + test("INSTR") { + val s1 = 'a.string.at(0) + val s2 = 'b.string.at(1) + val s3 = 'c.string.at(2) + val row1 = create_row("aaads", "aa", "zz") + + checkEvaluation(StringInstr(Literal("aaads"), Literal("aa")), 1, row1) + checkEvaluation(StringInstr(Literal("aaads"), Literal("de")), 0, row1) + checkEvaluation(StringInstr(Literal.create(null, StringType), Literal("de")), null, row1) + checkEvaluation(StringInstr(Literal("aaads"), Literal.create(null, StringType)), null, row1) + + checkEvaluation(StringInstr(s1, s2), 1, row1) + checkEvaluation(StringInstr(s1, s3), 0, row1) + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringInstr(s1, s2), 3, create_row("花花世界", "世界")) + checkEvaluation(StringInstr(s1, s2), 1, create_row("花花世界", "花")) + checkEvaluation(StringInstr(s1, s2), 0, create_row("花花世界", "小")) + // scalastyle:on + } + + test("LOCATE") { + val s1 = 'a.string.at(0) + val s2 = 'b.string.at(1) + val s3 = 'c.string.at(2) + val s4 = 'd.int.at(3) + val row1 = create_row("aaads", "aa", "zz", 1) + + checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 0, row1) + checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0, row1) + checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0, row1) + + checkEvaluation(new StringLocate(s2, s1), 1, row1) + checkEvaluation(StringLocate(s2, s1, s4), 2, row1) + checkEvaluation(new StringLocate(s3, s1), 0, row1) + checkEvaluation(StringLocate(s3, s1, Literal.create(null, IntegerType)), 0, row1) + } + + test("LPAD/RPAD") { + val s1 = 'a.string.at(0) + val s2 = 'b.int.at(1) + val s3 = 'c.string.at(2) + val row1 = create_row("hi", 5, "??") + val row2 = create_row("hi", 1, "?") + val row3 = create_row(null, 1, "?") + + checkEvaluation(StringLPad(Literal("hi"), Literal(5), Literal("??")), "???hi", row1) + checkEvaluation(StringLPad(Literal("hi"), Literal(1), Literal("??")), "h", row1) + checkEvaluation(StringLPad(s1, s2, s3), "???hi", row1) + checkEvaluation(StringLPad(s1, s2, s3), "h", row2) + checkEvaluation(StringLPad(s1, s2, s3), null, row3) + + checkEvaluation(StringRPad(Literal("hi"), Literal(5), Literal("??")), "hi???", row1) + checkEvaluation(StringRPad(Literal("hi"), Literal(1), Literal("??")), "h", row1) + checkEvaluation(StringRPad(s1, s2, s3), "hi???", row1) + checkEvaluation(StringRPad(s1, s2, s3), "h", row2) + checkEvaluation(StringRPad(s1, s2, s3), null, row3) + } + + test("REPEAT") { + val s1 = 'a.string.at(0) + val s2 = 'b.int.at(1) + val row1 = create_row("hi", 2) + val row2 = create_row(null, 1) + + checkEvaluation(StringRepeat(Literal("hi"), Literal(2)), "hihi", row1) + checkEvaluation(StringRepeat(Literal("hi"), Literal(-1)), "", row1) + checkEvaluation(StringRepeat(s1, s2), "hihi", row1) + checkEvaluation(StringRepeat(s1, s2), null, row2) + } + + test("REVERSE") { + val s = 'a.string.at(0) + val row1 = create_row("abccc") + checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1) + checkEvaluation(StringReverse(s), "cccba", row1) + } + + test("SPACE") { + val s1 = 'b.int.at(0) + val row1 = create_row(2) + val row2 = create_row(null) + + checkEvaluation(StringSpace(Literal(2)), " ", row1) + checkEvaluation(StringSpace(Literal(-1)), "", row1) + checkEvaluation(StringSpace(Literal(0)), "", row1) + checkEvaluation(StringSpace(s1), " ", row1) + checkEvaluation(StringSpace(s1), null, row2) + } + + test("SPLIT") { + val s1 = 'a.string.at(0) + val s2 = 'b.string.at(1) + val row1 = create_row("aa2bb3cc", "[1-9]+") + + checkEvaluation( + StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1) + checkEvaluation( + StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4da9ffc495..08bf37a5c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1627,6 +1627,179 @@ object functions { def ascii(columnName: String): Column = ascii(Column(columnName)) /** + * Trim the spaces from both ends for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def trim(e: Column): Column = StringTrim(e.expr) + + /** + * Trim the spaces from both ends for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def trim(columnName: String): Column = trim(Column(columnName)) + + /** + * Trim the spaces from left end for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def ltrim(e: Column): Column = StringTrimLeft(e.expr) + + /** + * Trim the spaces from left end for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def ltrim(columnName: String): Column = ltrim(Column(columnName)) + + /** + * Trim the spaces from right end for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def rtrim(e: Column): Column = StringTrimRight(e.expr) + + /** + * Trim the spaces from right end for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def rtrim(columnName: String): Column = rtrim(Column(columnName)) + + /** + * Format strings in printf-style. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def formatString(format: Column, arguments: Column*): Column = { + StringFormat((format +: arguments).map(_.expr): _*) + } + + /** + * Format strings in printf-style. + * NOTE: `format` is the string value of the formatter, not column name. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def formatString(format: String, arguNames: String*): Column = { + StringFormat(lit(format).expr +: arguNames.map(Column(_).expr): _*) + } + + /** + * Locate the position of the first occurrence of substr value in the given string. + * Returns null if either of the arguments are null. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def instr(substr: String, sub: String): Column = instr(Column(substr), Column(sub)) + + /** + * Locate the position of the first occurrence of substr column in the given string. + * Returns null if either of the arguments are null. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def instr(substr: Column, sub: Column): Column = StringInstr(substr.expr, sub.expr) + + /** + * Locate the position of the first occurrence of substr. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: String): Column = { + locate(Column(substr), Column(str)) + } + + /** + * Locate the position of the first occurrence of substr. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: Column, str: Column): Column = { + new StringLocate(substr.expr, str.expr) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: String, pos: String): Column = { + locate(Column(substr), Column(str), Column(pos)) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: Column, str: Column, pos: Column): Column = { + StringLocate(substr.expr, str.expr, pos.expr) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: Column, str: Column, pos: Int): Column = { + StringLocate(substr.expr, str.expr, lit(pos).expr) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: String, pos: Int): Column = { + locate(Column(substr), Column(str), lit(pos)) + } + + /** * Computes the specified value from binary to a base64 string. * * @group string_funcs @@ -1659,6 +1832,46 @@ object functions { def unbase64(columnName: String): Column = unbase64(Column(columnName)) /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: String, len: String, pad: String): Column = { + lpad(Column(str), Column(len), Column(pad)) + } + + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: Column, len: Column, pad: Column): Column = { + StringLPad(str.expr, len.expr, pad.expr) + } + + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: Column, len: Int, pad: Column): Column = { + StringLPad(str.expr, lit(len).expr, pad.expr) + } + + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: String, len: Int, pad: String): Column = { + lpad(Column(str), len, Column(pad)) + } + + /** * Computes the first argument into a binary from a string using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). * If either argument is null, the result will also be null. @@ -1702,6 +1915,146 @@ object functions { def decode(columnName: String, charset: String): Column = decode(Column(columnName), charset) + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: String, len: String, pad: String): Column = { + rpad(Column(str), Column(len), Column(pad)) + } + + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: Column, len: Column, pad: Column): Column = { + StringRPad(str.expr, len.expr, pad.expr) + } + + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: String, len: Int, pad: String): Column = { + rpad(Column(str), len, Column(pad)) + } + + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: Column, len: Int, pad: Column): Column = { + StringRPad(str.expr, lit(len).expr, pad.expr) + } + + /** + * Repeat the string value of the specified column n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(strColumn: String, timesColumn: String): Column = { + repeat(Column(strColumn), Column(timesColumn)) + } + + /** + * Repeat the string expression value n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(str: Column, times: Column): Column = { + StringRepeat(str.expr, times.expr) + } + + /** + * Repeat the string value of the specified column n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(strColumn: String, times: Int): Column = { + repeat(Column(strColumn), times) + } + + /** + * Repeat the string expression value n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(str: Column, times: Int): Column = { + StringRepeat(str.expr, lit(times).expr) + } + + /** + * Splits str around pattern (pattern is a regular expression). + * + * @group string_funcs + * @since 1.5.0 + */ + def split(strColumnName: String, pattern: String): Column = { + split(Column(strColumnName), pattern) + } + + /** + * Splits str around pattern (pattern is a regular expression). + * NOTE: pattern is a string represent the regular expression. + * + * @group string_funcs + * @since 1.5.0 + */ + def split(str: Column, pattern: String): Column = { + StringSplit(str.expr, lit(pattern).expr) + } + + /** + * Reversed the string for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def reverse(str: String): Column = { + reverse(Column(str)) + } + + /** + * Reversed the string for the specified value. + * + * @group string_funcs + * @since 1.5.0 + */ + def reverse(str: Column): Column = { + StringReverse(str.expr) + } + + /** + * Make a n spaces of string. + * + * @group string_funcs + * @since 1.5.0 + */ + def space(n: String): Column = { + space(Column(n)) + } + + /** + * Make a n spaces of string. + * + * @group string_funcs + * @since 1.5.0 + */ + def space(n: Column): Column = { + StringSpace(n.expr) + } ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index afba28515e..173280375c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -209,21 +209,14 @@ class DataFrameFunctionsSuite extends QueryTest { } test("string length function") { + val df = Seq(("abc", "")).toDF("a", "b") checkAnswer( - nullStrings.select(strlen($"s"), strlen("s")), - nullStrings.collect().toSeq.map { r => - val v = r.getString(1) - val l = if (v == null) null else v.length - Row(l, l) - }) + df.select(strlen($"a"), strlen("b")), + Row(3, 0)) checkAnswer( - nullStrings.selectExpr("length(s)"), - nullStrings.collect().toSeq.map { r => - val v = r.getString(1) - val l = if (v == null) null else v.length - Row(l) - }) + df.selectExpr("length(a)", "length(b)"), + Row(3, 0)) } test("Levenshtein distance") { @@ -273,4 +266,119 @@ class DataFrameFunctionsSuite extends QueryTest { Row(bytes, "大千世界")) // scalastyle:on } + + test("string trim functions") { + val df = Seq((" example ", "")).toDF("a", "b") + + checkAnswer( + df.select(ltrim($"a"), rtrim($"a"), trim($"a")), + Row("example ", " example", "example")) + + checkAnswer( + df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), + Row("example ", " example", "example")) + } + + test("string formatString function") { + val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c") + + checkAnswer( + df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), + Row("aa123cc", "aa123cc")) + + checkAnswer( + df.selectExpr("printf(a, b, c)"), + Row("aa123cc")) + } + + test("string instr function") { + val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c") + + checkAnswer( + df.select(instr($"a", $"b"), instr("a", "b")), + Row(1, 1)) + + checkAnswer( + df.selectExpr("instr(a, b)"), + Row(1)) + } + + test("string locate function") { + val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") + + checkAnswer( + df.select( + locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1), + locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")), + Row(1, 1, 2, 2, 2, 2)) + + checkAnswer( + df.selectExpr("locate(b, a)", "locate(b, a, d)"), + Row(1, 2)) + } + + test("string padding functions") { + val df = Seq(("hi", 5, "??")).toDF("a", "b", "c") + + checkAnswer( + df.select( + lpad($"a", $"b", $"c"), rpad("a", "b", "c"), + lpad($"a", 1, $"c"), rpad("a", 1, "c")), + Row("???hi", "hi???", "h", "h")) + + checkAnswer( + df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"), + Row("???hi", "hi???", "h", "h")) + } + + test("string repeat function") { + val df = Seq(("hi", 2)).toDF("a", "b") + + checkAnswer( + df.select( + repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")), + Row("hihi", "hihi", "hihi", "hihi")) + + checkAnswer( + df.selectExpr("repeat(a, 2)", "repeat(a, b)"), + Row("hihi", "hihi")) + } + + test("string reverse function") { + val df = Seq(("hi", "hhhi")).toDF("a", "b") + + checkAnswer( + df.select(reverse($"a"), reverse("b")), + Row("ih", "ihhh")) + + checkAnswer( + df.selectExpr("reverse(b)"), + Row("ihhh")) + } + + test("string space function") { + val df = Seq((2, 3)).toDF("a", "b") + + checkAnswer( + df.select(space($"a"), space("b")), + Row(" ", " ")) + + checkAnswer( + df.selectExpr("space(b)"), + Row(" ")) + } + + test("string split function") { + val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") + + checkAnswer( + df.select( + split($"a", "[1-9]+"), + split("a", "[1-9]+")), + Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc"))) + + checkAnswer( + df.selectExpr("split(a, '[1-9]+')"), + Row(Seq("aa", "bb", "cc"))) + } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 847d80ad58..60d050b0a0 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -25,6 +25,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import static org.apache.spark.unsafe.PlatformDependent.*; + /** * A UTF-8 String for internal Spark use. * <p> @@ -204,6 +205,196 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { return fromString(toString().toLowerCase()); } + /** + * Copy the bytes from the current UTF8String, and make a new UTF8String. + * @param start the start position of the current UTF8String in bytes. + * @param end the end position of the current UTF8String in bytes. + * @return a new UTF8String in the position of [start, end] of current UTF8String bytes. + */ + private UTF8String copyUTF8String(int start, int end) { + int len = end - start + 1; + byte[] newBytes = new byte[len]; + copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); + return UTF8String.fromBytes(newBytes); + } + + public UTF8String trim() { + int s = 0; + int e = this.numBytes - 1; + // skip all of the space (0x20) in the left side + while (s < this.numBytes && getByte(s) == 0x20) s++; + // skip all of the space (0x20) in the right side + while (e >= 0 && getByte(e) == 0x20) e--; + + if (s > e) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(s, e); + } + } + + public UTF8String trimLeft() { + int s = 0; + // skip all of the space (0x20) in the left side + while (s < this.numBytes && getByte(s) == 0x20) s++; + if (s == this.numBytes) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(s, this.numBytes - 1); + } + } + + public UTF8String trimRight() { + int e = numBytes - 1; + // skip all of the space (0x20) in the right side + while (e >= 0 && getByte(e) == 0x20) e--; + + if (e < 0) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(0, e); + } + } + + public UTF8String reverse() { + byte[] bytes = getBytes(); + byte[] result = new byte[bytes.length]; + + int i = 0; // position in byte + while (i < numBytes) { + int len = numBytesForFirstByte(getByte(i)); + System.arraycopy(bytes, i, result, result.length - i - len, len); + + i += len; + } + + return UTF8String.fromBytes(result); + } + + public UTF8String repeat(int times) { + if (times <=0) { + return fromBytes(new byte[0]); + } + + byte[] newBytes = new byte[numBytes * times]; + System.arraycopy(getBytes(), 0, newBytes, 0, numBytes); + + int copied = 1; + while (copied < times) { + int toCopy = Math.min(copied, times - copied); + System.arraycopy(newBytes, 0, newBytes, copied * numBytes, numBytes * toCopy); + copied += toCopy; + } + + return UTF8String.fromBytes(newBytes); + } + + /** + * Returns the position of the first occurrence of substr in + * current string from the specified position (0-based index). + * + * @param v the string to be searched + * @param start the start position of the current string for searching + * @return the position of the first occurrence of substr, if not found, -1 returned. + */ + public int indexOf(UTF8String v, int start) { + if (v.numBytes() == 0) { + return 0; + } + + // locate to the start position. + int i = 0; // position in byte + int c = 0; // position in character + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + do { + if (i + v.numBytes > numBytes) { + return -1; + } + if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { + return c; + } + i += numBytesForFirstByte(getByte(i)); + c += 1; + } while(i < numBytes); + + return -1; + } + + /** + * Returns str, right-padded with pad to a length of len + * For example: + * ('hi', 5, '??') => 'hi???' + * ('hi', 1, '??') => 'h' + */ + public UTF8String rpad(int len, UTF8String pad) { + int spaces = len - this.numChars(); // number of char need to pad + if (spaces <= 0) { + // no padding at all, return the substring of the current string + return substring(0, len); + } else { + int padChars = pad.numChars(); + int count = spaces / padChars; // how many padding string needed + // the partial string of the padding + UTF8String remain = pad.substring(0, spaces - padChars * count); + + byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; + System.arraycopy(getBytes(), 0, data, 0, this.numBytes); + int offset = this.numBytes; + int idx = 0; + byte[] padBytes = pad.getBytes(); + while (idx < count) { + System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + ++idx; + offset += pad.numBytes; + } + System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + + return UTF8String.fromBytes(data); + } + } + + /** + * Returns str, left-padded with pad to a length of len. + * For example: + * ('hi', 5, '??') => '???hi' + * ('hi', 1, '??') => 'h' + */ + public UTF8String lpad(int len, UTF8String pad) { + int spaces = len - this.numChars(); // number of char need to pad + if (spaces <= 0) { + // no padding at all, return the substring of the current string + return substring(0, len); + } else { + int padChars = pad.numChars(); + int count = spaces / padChars; // how many padding string needed + // the partial string of the padding + UTF8String remain = pad.substring(0, spaces - padChars * count); + + byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; + + int offset = 0; + int idx = 0; + byte[] padBytes = pad.getBytes(); + while (idx < count) { + System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + ++idx; + offset += pad.numBytes; + } + System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + offset += remain.numBytes; + System.arraycopy(getBytes(), 0, data, offset, numBytes()); + + return UTF8String.fromBytes(data); + } + } + @Override public String toString() { try { diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index fb463ba17f..694bdc29f3 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -121,12 +121,94 @@ public class UTF8StringSuite { @Test public void substring() { - assertEquals(fromString("hello").substring(0, 0), fromString("")); - assertEquals(fromString("hello").substring(1, 3), fromString("el")); - assertEquals(fromString("数据砖头").substring(0, 1), fromString("数")); - assertEquals(fromString("数据砖头").substring(1, 3), fromString("据砖")); - assertEquals(fromString("数据砖头").substring(3, 5), fromString("头")); - assertEquals(fromString("ߵ梷").substring(0, 2), fromString("ߵ梷")); + assertEquals(fromString(""), fromString("hello").substring(0, 0)); + assertEquals(fromString("el"), fromString("hello").substring(1, 3)); + assertEquals(fromString("数"), fromString("数据砖头").substring(0, 1)); + assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3)); + assertEquals(fromString("头"), fromString("数据砖头").substring(3, 5)); + assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0, 2)); + } + + @Test + public void trims() { + assertEquals(fromString("hello"), fromString(" hello ").trim()); + assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); + assertEquals(fromString(" hello"), fromString(" hello ").trimRight()); + + assertEquals(fromString(""), fromString(" ").trim()); + assertEquals(fromString(""), fromString(" ").trimLeft()); + assertEquals(fromString(""), fromString(" ").trimRight()); + + assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); + assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft()); + assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight()); + + assertEquals(fromString("数据砖头"), fromString("数据砖头").trim()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimLeft()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimRight()); + } + + @Test + public void indexOf() { + assertEquals(0, fromString("").indexOf(fromString(""), 0)); + assertEquals(-1, fromString("").indexOf(fromString("l"), 0)); + assertEquals(0, fromString("hello").indexOf(fromString(""), 0)); + assertEquals(2, fromString("hello").indexOf(fromString("l"), 0)); + assertEquals(3, fromString("hello").indexOf(fromString("l"), 3)); + assertEquals(-1, fromString("hello").indexOf(fromString("a"), 0)); + assertEquals(2, fromString("hello").indexOf(fromString("ll"), 0)); + assertEquals(-1, fromString("hello").indexOf(fromString("ll"), 4)); + assertEquals(1, fromString("数据砖头").indexOf(fromString("据砖"), 0)); + assertEquals(-1, fromString("数据砖头").indexOf(fromString("数"), 3)); + assertEquals(0, fromString("数据砖头").indexOf(fromString("数"), 0)); + assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0)); + } + + @Test + public void reverse() { + assertEquals(fromString("olleh"), fromString("hello").reverse()); + assertEquals(fromString(""), fromString("").reverse()); + assertEquals(fromString("者行孙"), fromString("孙行者").reverse()); + assertEquals(fromString("者行孙 olleh"), fromString("hello 孙行者").reverse()); + } + + @Test + public void repeat() { + assertEquals(fromString("数d数d数d数d数d"), fromString("数d").repeat(5)); + assertEquals(fromString("数d"), fromString("数d").repeat(1)); + assertEquals(fromString(""), fromString("数d").repeat(-1)); + } + + @Test + public void pad() { + assertEquals(fromString("hel"), fromString("hello").lpad(3, fromString("????"))); + assertEquals(fromString("hello"), fromString("hello").lpad(5, fromString("????"))); + assertEquals(fromString("?hello"), fromString("hello").lpad(6, fromString("????"))); + assertEquals(fromString("???????hello"), fromString("hello").lpad(12, fromString("????"))); + assertEquals(fromString("?????hello"), fromString("hello").lpad(10, fromString("?????"))); + assertEquals(fromString("???????"), fromString("").lpad(7, fromString("?????"))); + + assertEquals(fromString("hel"), fromString("hello").rpad(3, fromString("????"))); + assertEquals(fromString("hello"), fromString("hello").rpad(5, fromString("????"))); + assertEquals(fromString("hello?"), fromString("hello").rpad(6, fromString("????"))); + assertEquals(fromString("hello???????"), fromString("hello").rpad(12, fromString("????"))); + assertEquals(fromString("hello?????"), fromString("hello").rpad(10, fromString("?????"))); + assertEquals(fromString("???????"), fromString("").rpad(7, fromString("?????"))); + + + assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, fromString("????"))); + assertEquals(fromString("?数据砖头"), fromString("数据砖头").lpad(5, fromString("????"))); + assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????"))); + assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者"))); + assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7, fromString("孙行者"))); + assertEquals(fromString("孙行者孙行者孙行数据砖头"), fromString("数据砖头").lpad(12, fromString("孙行者"))); + + assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????"))); + assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????"))); + assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????"))); + assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者"))); + assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7, fromString("孙行者"))); + assertEquals(fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12, fromString("孙行者"))); } @Test |