Diffstat (9 files changed, 421 insertions, 247 deletions)

 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala              |   1
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala           |  37
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala     |  24
   (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala)
 sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                           |  22
 sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala                             | 242
 sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala                                | 284
 sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala |   4
 unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java                                     |  40
 unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java                                |  14
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ce552a1d65..d1cda6bc27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -152,6 +152,7 @@ object FunctionRegistry {
     // string functions
     expression[Ascii]("ascii"),
     expression[Base64]("base64"),
+    expression[Concat]("concat"),
     expression[Encode]("encode"),
     expression[Decode]("decode"),
     expression[FormatNumber]("format_number"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index c64afe7b3f..b36354eff0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -27,6 +27,43 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// This file defines expressions for string operations.
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+/**
+ * An expression that concatenates multiple input strings into a single string.
+ * Input expressions that evaluate to null are skipped.
+ *
+ * For example, `concat("a", null, "b")` is evaluated to `"ab"`.
+ *
+ * Note that this is different from Hive, since Hive outputs null if any input is null.
+ * We never output null.
+ */
+case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
+  override def dataType: DataType = StringType
+
+  override def nullable: Boolean = false
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def eval(input: InternalRow): Any = {
+    val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
+    UTF8String.concat(inputs : _*)
+  }
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val evals = children.map(_.gen(ctx))
+    val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.primitive}" }.mkString(", ")
+    evals.map(_.code).mkString("\n") + s"""
+      boolean ${ev.isNull} = false;
+      UTF8String ${ev.primitive} = UTF8String.concat($inputs);
+    """
+  }
+}
+
 trait StringRegexExpression extends ImplicitCastInputTypes {
   self: BinaryExpression =>
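To make the null-skipping contract above concrete, here is a minimal sketch of constructing and evaluating the new expression by hand (a hypothetical REPL session; `EmptyRow` is the same placeholder row the test suite below uses):

    import org.apache.spark.sql.catalyst.expressions.{Concat, EmptyRow, Literal}
    import org.apache.spark.sql.types.StringType

    // Build the equivalent of SQL concat('a', NULL, 'b') as an expression tree.
    val expr = Concat(Seq(Literal("a"), Literal.create(null, StringType), Literal("b")))

    // Null children are skipped rather than propagated: the result is "ab", never null.
    expr.eval(EmptyRow)  // == UTF8String.fromString("ab")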
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 5d7763bedf..0ed567a90d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -22,7 +22,29 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
 
 
-class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  test("concat") {
+    def testConcat(inputs: String*): Unit = {
+      val expected = inputs.filter(_ != null).mkString
+      checkEvaluation(Concat(inputs.map(Literal.create(_, StringType))), expected, EmptyRow)
+    }
+
+    testConcat()
+    testConcat(null)
+    testConcat("")
+    testConcat("ab")
+    testConcat("a", "b")
+    testConcat("a", "b", "C")
+    testConcat("a", null, "C")
+    testConcat("a", null, null)
+    testConcat(null, null, null)
+
+    // scalastyle:off
+    // non-ASCII characters are not allowed in the code, so we disable scalastyle here.
+    testConcat("数据", null, "砖头")
+    // scalastyle:on
+  }
 
   test("StringComparison") {
     val row = create_row("abc", null)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index b56fd9a71b..c180407389 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1711,6 +1711,28 @@ object functions {
   //////////////////////////////////////////////////////////////////////////////////////////////
 
   /**
+   * Concatenates input strings together into a single string.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def concat(exprs: Column*): Column = Concat(exprs.map(_.expr))
+
+  /**
+   * Concatenates input strings together into a single string.
+   *
+   * This is the variant of concat that takes in column names.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def concat(columnName: String, columnNames: String*): Column = {
+    concat((columnName +: columnNames).map(Column.apply): _*)
+  }
+
+  /**
    * Computes the length of a given string / binary value.
    *
    * @group string_funcs
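With the two overloads above, concat is usable from the DataFrame DSL as well as SQL. A hedged usage sketch (column names and data are illustrative; a SQLContext's implicits are assumed to be in scope, as in the test suites below):

    import org.apache.spark.sql.functions._

    val people = Seq(("John", "Doe"), ("Jane", null)).toDF("first", "last")

    // Both overloads build the same Concat expression; the null `last` is skipped.
    people.select(concat($"first", $"last")).collect()  // Array([JohnDoe], [Jane])
    people.select(concat("first", "last")).collect()    // column-name variant, same result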
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 6dccdd857b..29f1197a85 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -208,169 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest {
       Row(2743272264L, 2180413220L))
   }
 
-  test("Levenshtein distance") {
-    val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
-    checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
-    checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1)))
-  }
-
-  test("string ascii function") {
-    val df = Seq(("abc", "")).toDF("a", "b")
-    checkAnswer(
-      df.select(ascii($"a"), ascii("b")),
-      Row(97, 0))
-
-    checkAnswer(
-      df.selectExpr("ascii(a)", "ascii(b)"),
-      Row(97, 0))
-  }
-
-  test("string base64/unbase64 function") {
-    val bytes = Array[Byte](1, 2, 3, 4)
-    val df = Seq((bytes, "AQIDBA==")).toDF("a", "b")
-    checkAnswer(
-      df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")),
-      Row("AQIDBA==", "AQIDBA==", bytes, bytes))
-
-    checkAnswer(
-      df.selectExpr("base64(a)", "unbase64(b)"),
-      Row("AQIDBA==", bytes))
-  }
-
-  test("string encode/decode function") {
-    val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
-    // scalastyle:off
-    // non ascii characters are not allowed in the code, so we disable the scalastyle here.
-    val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
-    checkAnswer(
-      df.select(
-        encode($"a", "utf-8"),
-        encode("a", "utf-8"),
-        decode($"c", "utf-8"),
-        decode("c", "utf-8")),
-      Row(bytes, bytes, "大千世界", "大千世界"))
-
-    checkAnswer(
-      df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"),
-      Row(bytes, "大千世界"))
-    // scalastyle:on
-  }
-
-  test("string trim functions") {
-    val df = Seq(("  example  ", "")).toDF("a", "b")
-
-    checkAnswer(
-      df.select(ltrim($"a"), rtrim($"a"), trim($"a")),
-      Row("example  ", "  example", "example"))
-
-    checkAnswer(
-      df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"),
-      Row("example  ", "  example", "example"))
-  }
-
-  test("string formatString function") {
-    val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")
-
-    checkAnswer(
-      df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
-      Row("aa123cc", "aa123cc"))
-
-    checkAnswer(
-      df.selectExpr("printf(a, b, c)"),
-      Row("aa123cc"))
-  }
-
-  test("string instr function") {
-    val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c")
-
-    checkAnswer(
-      df.select(instr($"a", $"b"), instr("a", "b")),
-      Row(1, 1))
-
-    checkAnswer(
-      df.selectExpr("instr(a, b)"),
-      Row(1))
-  }
-
-  test("string locate function") {
-    val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
-
-    checkAnswer(
-      df.select(
-        locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1),
-        locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")),
-      Row(1, 1, 2, 2, 2, 2))
-
-    checkAnswer(
-      df.selectExpr("locate(b, a)", "locate(b, a, d)"),
-      Row(1, 2))
-  }
-
-  test("string padding functions") {
-    val df = Seq(("hi", 5, "??")).toDF("a", "b", "c")
-
-    checkAnswer(
-      df.select(
-        lpad($"a", $"b", $"c"), rpad("a", "b", "c"),
-        lpad($"a", 1, $"c"), rpad("a", 1, "c")),
-      Row("???hi", "hi???", "h", "h"))
-
-    checkAnswer(
-      df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"),
-      Row("???hi", "hi???", "h", "h"))
-  }
-
function") { - val df = Seq(("hi", 2)).toDF("a", "b") - - checkAnswer( - df.select( - repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")), - Row("hihi", "hihi", "hihi", "hihi")) - - checkAnswer( - df.selectExpr("repeat(a, 2)", "repeat(a, b)"), - Row("hihi", "hihi")) - } - - test("string reverse function") { - val df = Seq(("hi", "hhhi")).toDF("a", "b") - - checkAnswer( - df.select(reverse($"a"), reverse("b")), - Row("ih", "ihhh")) - - checkAnswer( - df.selectExpr("reverse(b)"), - Row("ihhh")) - } - - test("string space function") { - val df = Seq((2, 3)).toDF("a", "b") - - checkAnswer( - df.select(space($"a"), space("b")), - Row(" ", " ")) - - checkAnswer( - df.selectExpr("space(b)"), - Row(" ")) - } - - test("string split function") { - val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") - - checkAnswer( - df.select( - split($"a", "[1-9]+"), - split("a", "[1-9]+")), - Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc"))) - - checkAnswer( - df.selectExpr("split(a, '[1-9]+')"), - Row(Seq("aa", "bb", "cc"))) - } - test("conditional function: least") { checkAnswer( testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1), @@ -430,83 +267,4 @@ class DataFrameFunctionsSuite extends QueryTest { ) } - test("string / binary length function") { - val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") - checkAnswer( - df.select(length($"a"), length("a"), length($"b"), length("b")), - Row(3, 3, 4, 4)) - - checkAnswer( - df.selectExpr("length(a)", "length(b)"), - Row(3, 4)) - - intercept[AnalysisException] { - checkAnswer( - df.selectExpr("length(c)"), // int type of the argument is unacceptable - Row("5.0000")) - } - } - - test("number format function") { - val tuple = - ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short], - 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381)) - val df = - Seq(tuple) - .toDF( - "a", // string "aa" - "b", // byte 1 - "c", // short 2 - "d", // float 3.13223f - "e", // integer 4 - "f", // long 5L - "g", // double 6.48173d - "h") // decimal 7.128381 - - checkAnswer( - df.select( - format_number($"f", 4), - format_number("f", 4)), - Row("5.0000", "5.0000")) - - checkAnswer( - df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer - Row("1.0000")) - - checkAnswer( - df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer - Row("2.0000")) - - checkAnswer( - df.selectExpr("format_number(d, e)"), // convert the 1st argument to double - Row("3.1322")) - - checkAnswer( - df.selectExpr("format_number(e, e)"), // not convert anything - Row("4.0000")) - - checkAnswer( - df.selectExpr("format_number(f, e)"), // not convert anything - Row("5.0000")) - - checkAnswer( - df.selectExpr("format_number(g, e)"), // not convert anything - Row("6.4817")) - - checkAnswer( - df.selectExpr("format_number(h, e)"), // not convert anything - Row("7.1284")) - - intercept[AnalysisException] { - checkAnswer( - df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable - Row("5.0000")) - } - - intercept[AnalysisException] { - checkAnswer( - df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable - Row("5.0000")) - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala new file mode 100644 index 0000000000..4eff33ed45 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -0,0 +1,284 @@ +/* + * Licensed to the 
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.Decimal
+
+
+class StringFunctionsSuite extends QueryTest {
+
+  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+  import ctx.implicits._
+
+  test("string concat") {
+    val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
+
+    checkAnswer(
+      df.select(concat($"a", $"b", $"c")),
+      Row("ab"))
+
+    checkAnswer(
+      df.selectExpr("concat(a, b, c)"),
+      Row("ab"))
+  }
+
+
+  test("string Levenshtein distance") {
+    val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
+    checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
+    checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1)))
+  }
+
+  test("string ascii function") {
+    val df = Seq(("abc", "")).toDF("a", "b")
+    checkAnswer(
+      df.select(ascii($"a"), ascii("b")),
+      Row(97, 0))
+
+    checkAnswer(
+      df.selectExpr("ascii(a)", "ascii(b)"),
+      Row(97, 0))
+  }
+
+  test("string base64/unbase64 function") {
+    val bytes = Array[Byte](1, 2, 3, 4)
+    val df = Seq((bytes, "AQIDBA==")).toDF("a", "b")
+    checkAnswer(
+      df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")),
+      Row("AQIDBA==", "AQIDBA==", bytes, bytes))
+
+    checkAnswer(
+      df.selectExpr("base64(a)", "unbase64(b)"),
+      Row("AQIDBA==", bytes))
+  }
+
+  test("string encode/decode function") {
+    val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
+    // scalastyle:off
+    // non-ASCII characters are not allowed in the code, so we disable scalastyle here.
+    val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
+    checkAnswer(
+      df.select(
+        encode($"a", "utf-8"),
+        encode("a", "utf-8"),
+        decode($"c", "utf-8"),
+        decode("c", "utf-8")),
+      Row(bytes, bytes, "大千世界", "大千世界"))
+
+    checkAnswer(
+      df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"),
+      Row(bytes, "大千世界"))
+    // scalastyle:on
+  }
+
+  test("string trim functions") {
+    val df = Seq(("  example  ", "")).toDF("a", "b")
+
+    checkAnswer(
+      df.select(ltrim($"a"), rtrim($"a"), trim($"a")),
+      Row("example  ", "  example", "example"))
+
+    checkAnswer(
+      df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"),
+      Row("example  ", "  example", "example"))
+  }
+
+  test("string formatString function") {
+    val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")
+
+    checkAnswer(
+      df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
+      Row("aa123cc", "aa123cc"))
+
+    checkAnswer(
+      df.selectExpr("printf(a, b, c)"),
+      Row("aa123cc"))
+  }
+
+  test("string instr function") {
+    val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c")
+
+    checkAnswer(
+      df.select(instr($"a", $"b"), instr("a", "b")),
+      Row(1, 1))
+
+    checkAnswer(
+      df.selectExpr("instr(a, b)"),
+      Row(1))
+  }
+
+  test("string locate function") {
+    val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
+
+    checkAnswer(
+      df.select(
+        locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1),
+        locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")),
+      Row(1, 1, 2, 2, 2, 2))
+
+    checkAnswer(
+      df.selectExpr("locate(b, a)", "locate(b, a, d)"),
+      Row(1, 2))
+  }
+
+  test("string padding functions") {
+    val df = Seq(("hi", 5, "??")).toDF("a", "b", "c")
+
+    checkAnswer(
+      df.select(
+        lpad($"a", $"b", $"c"), rpad("a", "b", "c"),
+        lpad($"a", 1, $"c"), rpad("a", 1, "c")),
+      Row("???hi", "hi???", "h", "h"))
+
+    checkAnswer(
+      df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"),
+      Row("???hi", "hi???", "h", "h"))
+  }
+
+  test("string repeat function") {
+    val df = Seq(("hi", 2)).toDF("a", "b")
+
+    checkAnswer(
+      df.select(
+        repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")),
+      Row("hihi", "hihi", "hihi", "hihi"))
+
+    checkAnswer(
+      df.selectExpr("repeat(a, 2)", "repeat(a, b)"),
+      Row("hihi", "hihi"))
+  }
+
+  test("string reverse function") {
+    val df = Seq(("hi", "hhhi")).toDF("a", "b")
+
+    checkAnswer(
+      df.select(reverse($"a"), reverse("b")),
+      Row("ih", "ihhh"))
+
+    checkAnswer(
+      df.selectExpr("reverse(b)"),
+      Row("ihhh"))
+  }
+
+  test("string space function") {
+    val df = Seq((2, 3)).toDF("a", "b")
+
+    checkAnswer(
+      df.select(space($"a"), space("b")),
+      Row("  ", "   "))
+
+    checkAnswer(
+      df.selectExpr("space(b)"),
+      Row("   "))
+  }
+
+  test("string split function") {
+    val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b")
+
+    checkAnswer(
+      df.select(
+        split($"a", "[1-9]+"),
+        split("a", "[1-9]+")),
+      Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc")))
+
+    checkAnswer(
+      df.selectExpr("split(a, '[1-9]+')"),
+      Row(Seq("aa", "bb", "cc")))
+  }
+
+  test("string / binary length function") {
+    val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c")
+    checkAnswer(
+      df.select(length($"a"), length("a"), length($"b"), length("b")),
+      Row(3, 3, 4, 4))
+
+    checkAnswer(
+      df.selectExpr("length(a)", "length(b)"),
+      Row(3, 4))
+
+    intercept[AnalysisException] {
+      checkAnswer(
+        df.selectExpr("length(c)"), // int type of the argument is unacceptable
+        Row("5.0000"))
+    }
+  }
+
+  test("number format function") {
+    val tuple =
+      ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
+        3.13223f, 4, 5L, 6.48173d, Decimal(7.128381))
+    val df =
+      Seq(tuple)
+        .toDF(
+          "a", // string "aa"
+          "b", // byte 1
+          "c", // short 2
+          "d", // float 3.13223f
+          "e", // integer 4
+          "f", // long 5L
+          "g", // double 6.48173d
+          "h") // decimal 7.128381
+
+    checkAnswer(
+      df.select(
+        format_number($"f", 4),
+        format_number("f", 4)),
+      Row("5.0000", "5.0000"))
+
+    checkAnswer(
+      df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer
+      Row("1.0000"))
+
+    checkAnswer(
+      df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer
+      Row("2.0000"))
+
+    checkAnswer(
+      df.selectExpr("format_number(d, e)"), // convert the 1st argument to double
+      Row("3.1322"))
+
+    checkAnswer(
+      df.selectExpr("format_number(e, e)"), // does not convert anything
+      Row("4.0000"))
+
+    checkAnswer(
+      df.selectExpr("format_number(f, e)"), // does not convert anything
+      Row("5.0000"))
+
+    checkAnswer(
+      df.selectExpr("format_number(g, e)"), // does not convert anything
+      Row("6.4817"))
+
+    checkAnswer(
+      df.selectExpr("format_number(h, e)"), // does not convert anything
+      Row("7.1284"))
+
+    intercept[AnalysisException] {
+      checkAnswer(
+        df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable
+        Row("5.0000"))
+    }
+
+    intercept[AnalysisException] {
+      checkAnswer(
+        df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable
+        Row("5.0000"))
+    }
+  }
+}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 6b8f2f6217..299cc599ff 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -256,6 +256,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "timestamp_2",
     "timestamp_udf",
 
+    // Hive outputs NULL if any concat input has null. We never output null for concat.
+    "udf_concat",
+
     // Unlike Hive, we do support log base in (0, 1.0], therefore disable this
     "udf7"
   )
@@ -846,7 +849,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "udf_case",
     "udf_ceil",
     "udf_ceiling",
-    "udf_concat",
     "udf_concat_insert1",
     "udf_concat_insert2",
     "udf_concat_ws",
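`udf_concat` is moved off the compatibility whitelist precisely because of the semantic difference called out in the comment above. A hedged sketch of the divergence at the SQL level (a hypothetical spark-shell session with a `SQLContext` named `sqlContext`; the printed shapes are illustrative):

    // Spark's concat skips null inputs:
    sqlContext.sql("SELECT concat('ab', null, 'cd')").collect()  // Array([abcd])

    // Hive's concat is null-propagating: the same query under Hive returns NULL.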
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index e7f9fbb2bc..9723b6e083 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -21,6 +21,7 @@ import javax.annotation.Nonnull;
 import java.io.Serializable;
 import java.io.UnsupportedEncodingException;
 
+import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 
 import static org.apache.spark.unsafe.PlatformDependent.*;
@@ -322,7 +323,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
       }
       i += numBytesForFirstByte(getByte(i));
       c += 1;
-    } while(i < numBytes);
+    } while (i < numBytes);
 
     return -1;
   }
@@ -395,6 +396,39 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
     }
   }
 
+  /**
+   * Concatenates input strings together into a single string. A null input is skipped.
+   * For example, concat("a", null, "c") would yield "ac".
+   */
+  public static UTF8String concat(UTF8String... inputs) {
+    if (inputs == null) {
+      return fromBytes(new byte[0]);
+    }
+
+    // Compute the total length of the result.
+    int totalLength = 0;
+    for (int i = 0; i < inputs.length; i++) {
+      if (inputs[i] != null) {
+        totalLength += inputs[i].numBytes;
+      }
+    }
+
+    // Allocate a new byte array, and copy the inputs one by one into it.
+    final byte[] result = new byte[totalLength];
+    int offset = 0;
+    for (int i = 0; i < inputs.length; i++) {
+      if (inputs[i] != null) {
+        int len = inputs[i].numBytes;
+        PlatformDependent.copyMemory(
+          inputs[i].base, inputs[i].offset,
+          result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
+          len);
+        offset += len;
+      }
+    }
+    return fromBytes(result);
+  }
+
   @Override
   public String toString() {
     try {
@@ -413,7 +447,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
   }
 
   @Override
-  public int compareTo(final UTF8String other) {
+  public int compareTo(@Nonnull final UTF8String other) {
     int len = Math.min(numBytes, other.numBytes);
     // TODO: compare 8 bytes as unsigned long
     for (int i = 0; i < len; i ++) {
@@ -434,7 +468,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
   public boolean equals(final Object other) {
     if (other instanceof UTF8String) {
       UTF8String o = (UTF8String) other;
-      if (numBytes != o.numBytes){
+      if (numBytes != o.numBytes) {
         return false;
       }
       return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes);
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 694bdc29f3..0db7522b50 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -87,6 +87,20 @@ public class UTF8StringSuite {
   }
 
   @Test
+  public void concatTest() {
+    assertEquals(concat(), fromString(""));
+    assertEquals(concat(null), fromString(""));
+    assertEquals(concat(fromString("")), fromString(""));
+    assertEquals(concat(fromString("ab")), fromString("ab"));
+    assertEquals(concat(fromString("a"), fromString("b")), fromString("ab"));
+    assertEquals(concat(fromString("a"), fromString("b"), fromString("c")), fromString("abc"));
+    assertEquals(concat(fromString("a"), null, fromString("c")), fromString("ac"));
+    assertEquals(concat(fromString("a"), null, null), fromString("a"));
+    assertEquals(concat(null, null, null), fromString(""));
+    assertEquals(concat(fromString("数据"), fromString("砖头")), fromString("数据砖头"));
+  }
+
+  @Test
   public void contains() {
     assertTrue(fromString("").contains(fromString("")));
     assertTrue(fromString("hello").contains(fromString("ello")));
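The Java implementation above deliberately makes two passes: the first computes the exact output size so the result buffer is allocated once, and the second copies raw bytes, so no intermediate strings are created. A minimal Scala sketch of the same null-skipping, two-pass shape over plain Strings (illustrative only, not the shipped code):

    def concatSkippingNulls(inputs: String*): String = {
      // First pass: size the buffer exactly, counting only non-null inputs.
      val totalLength = inputs.iterator.filter(_ != null).map(_.length).sum
      val sb = new java.lang.StringBuilder(totalLength)
      // Second pass: append each non-null input; nulls are skipped, never propagated.
      inputs.foreach(s => if (s != null) sb.append(s))
      sb.toString
    }

    concatSkippingNulls("a", null, "c")  // "ac", matching UTF8String.concat above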