author    Reynold Xin <rxin@databricks.com>  2015-07-18 14:07:56 -0700
committer Reynold Xin <rxin@databricks.com>  2015-07-18 14:07:56 -0700
commit    6e1e2eba696e89ba57bf5450b9c72c4386e43dc8 (patch)
tree      ff09024d2e45656f412ba39f331508f1cf436eab
parent    3d2134fc0d90379b89da08de7614aef1ac674b1b (diff)
[SPARK-8240][SQL] string function: concat
Author: Reynold Xin <rxin@databricks.com>

Closes #7486 from rxin/concat and squashes the following commits:

5217d6e [Reynold Xin] Removed Hive's concat test.
f5cb7a3 [Reynold Xin] Concat is never nullable.
ae4e61f [Reynold Xin] Removed extra import.
fddcbbd [Reynold Xin] Fixed NPE.
22e831c [Reynold Xin] Added missing file.
57a2352 [Reynold Xin] [SPARK-8240][SQL] string function: concat
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala |   1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala |  37
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala) |  24
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala |  22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 242
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala | 284
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala |   4
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java |  40
-rw-r--r--  unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java |  14
9 files changed, 421 insertions(+), 247 deletions(-)
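For context, a minimal sketch of the user-facing behavior this commit adds. This is illustrative only: `ctx` stands in for a Spark 1.5-era SQLContext, and the column names are made up.

    import org.apache.spark.sql.functions._
    import ctx.implicits._

    val df = Seq[(String, String, String)](("a", "b", null)).toDF("x", "y", "z")

    // Null inputs are skipped rather than nulling out the whole result:
    df.select(concat($"x", $"y", $"z")).collect() // Array(Row("ab"))
    df.selectExpr("concat(x, y, z)").collect()    // Array(Row("ab")); Hive would return NULL here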
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ce552a1d65..d1cda6bc27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -152,6 +152,7 @@ object FunctionRegistry {
// string functions
expression[Ascii]("ascii"),
expression[Base64]("base64"),
+ expression[Concat]("concat"),
expression[Encode]("encode"),
expression[Decode]("decode"),
expression[FormatNumber]("format_number"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index c64afe7b3f..b36354eff0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -27,6 +27,43 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// This file defines expressions for string operations.
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+/**
+ * An expression that concatenates multiple input strings into a single string.
+ * Input expressions that evaluate to null are skipped.
+ *
+ * For example, `concat("a", null, "b")` is evaluated to `"ab"`.
+ *
+ * Note that this is different from Hive since Hive outputs null if any input is null.
+ * We never output null.
+ */
+case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
+ override def dataType: DataType = StringType
+
+ override def nullable: Boolean = false
+ override def foldable: Boolean = children.forall(_.foldable)
+
+ override def eval(input: InternalRow): Any = {
+ val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
+ UTF8String.concat(inputs : _*)
+ }
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val evals = children.map(_.gen(ctx))
+ val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.primitive}" }.mkString(", ")
+ evals.map(_.code).mkString("\n") + s"""
+ boolean ${ev.isNull} = false;
+ UTF8String ${ev.primitive} = UTF8String.concat($inputs);
+ """
+ }
+}
+
trait StringRegexExpression extends ImplicitCastInputTypes {
self: BinaryExpression =>
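To make the null-skipping semantics concrete, here is a plain-Scala model of what `eval` plus `UTF8String.concat` amount to (`String` stands in for `UTF8String`; this is a sketch, not the shipped code):

    // Null children are dropped and the result is never null:
    // an empty input list yields the empty string.
    def concatModel(inputs: Seq[String]): String =
      inputs.filter(_ != null).mkString

    assert(concatModel(Seq("a", null, "b")) == "ab")
    assert(concatModel(Seq(null, null)) == "")
    assert(concatModel(Nil) == "")

This is also why `nullable` is hard-wired to false above: no combination of inputs can produce a null result.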
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 5d7763bedf..0ed567a90d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -22,7 +22,29 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
-class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ test("concat") {
+ def testConcat(inputs: String*): Unit = {
+ val expected = inputs.filter(_ != null).mkString
+ checkEvaluation(Concat(inputs.map(Literal.create(_, StringType))), expected, EmptyRow)
+ }
+
+ testConcat()
+ testConcat(null)
+ testConcat("")
+ testConcat("ab")
+ testConcat("a", "b")
+ testConcat("a", "b", "C")
+ testConcat("a", null, "C")
+ testConcat("a", null, null)
+ testConcat(null, null, null)
+
+ // scalastyle:off
+ // non-ASCII characters are not allowed in the code, so we disable scalastyle here.
+ testConcat("数据", null, "砖头")
+ // scalastyle:on
+ }
test("StringComparison") {
val row = create_row("abc", null)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index b56fd9a71b..c180407389 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1711,6 +1711,28 @@ object functions {
//////////////////////////////////////////////////////////////////////////////////////////////
/**
+ * Concatenates input strings together into a single string.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ @scala.annotation.varargs
+ def concat(exprs: Column*): Column = Concat(exprs.map(_.expr))
+
+ /**
+ * Concatenates input strings together into a single string.
+ *
+ * This is the variant of concat that takes column names rather than Columns.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ @scala.annotation.varargs
+ def concat(columnName: String, columnNames: String*): Column = {
+ concat((columnName +: columnNames).map(Column.apply): _*)
+ }
+
+ /**
* Computes the length of a given string / binary value.
*
* @group string_funcs
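The two `concat` overloads above are interchangeable; the string-name variant simply wraps each name in a Column. A brief usage sketch (`df` assumed to have string columns "a" and "b"):

    // These produce the same result:
    df.select(concat($"a", $"b"))
    df.select(concat("a", "b"))
    // And the same function is reachable from SQL:
    df.selectExpr("concat(a, b)")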
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 6dccdd857b..29f1197a85 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -208,169 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(2743272264L, 2180413220L))
}
- test("Levenshtein distance") {
- val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
- checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
- checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1)))
- }
-
- test("string ascii function") {
- val df = Seq(("abc", "")).toDF("a", "b")
- checkAnswer(
- df.select(ascii($"a"), ascii("b")),
- Row(97, 0))
-
- checkAnswer(
- df.selectExpr("ascii(a)", "ascii(b)"),
- Row(97, 0))
- }
-
- test("string base64/unbase64 function") {
- val bytes = Array[Byte](1, 2, 3, 4)
- val df = Seq((bytes, "AQIDBA==")).toDF("a", "b")
- checkAnswer(
- df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")),
- Row("AQIDBA==", "AQIDBA==", bytes, bytes))
-
- checkAnswer(
- df.selectExpr("base64(a)", "unbase64(b)"),
- Row("AQIDBA==", bytes))
- }
-
- test("string encode/decode function") {
- val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
- // scalastyle:off
- // non ascii characters are not allowed in the code, so we disable the scalastyle here.
- val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
- checkAnswer(
- df.select(
- encode($"a", "utf-8"),
- encode("a", "utf-8"),
- decode($"c", "utf-8"),
- decode("c", "utf-8")),
- Row(bytes, bytes, "大千世界", "大千世界"))
-
- checkAnswer(
- df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"),
- Row(bytes, "大千世界"))
- // scalastyle:on
- }
-
- test("string trim functions") {
- val df = Seq((" example ", "")).toDF("a", "b")
-
- checkAnswer(
- df.select(ltrim($"a"), rtrim($"a"), trim($"a")),
- Row("example ", " example", "example"))
-
- checkAnswer(
- df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"),
- Row("example ", " example", "example"))
- }
-
- test("string formatString function") {
- val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")
-
- checkAnswer(
- df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
- Row("aa123cc", "aa123cc"))
-
- checkAnswer(
- df.selectExpr("printf(a, b, c)"),
- Row("aa123cc"))
- }
-
- test("string instr function") {
- val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c")
-
- checkAnswer(
- df.select(instr($"a", $"b"), instr("a", "b")),
- Row(1, 1))
-
- checkAnswer(
- df.selectExpr("instr(a, b)"),
- Row(1))
- }
-
- test("string locate function") {
- val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
-
- checkAnswer(
- df.select(
- locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1),
- locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")),
- Row(1, 1, 2, 2, 2, 2))
-
- checkAnswer(
- df.selectExpr("locate(b, a)", "locate(b, a, d)"),
- Row(1, 2))
- }
-
- test("string padding functions") {
- val df = Seq(("hi", 5, "??")).toDF("a", "b", "c")
-
- checkAnswer(
- df.select(
- lpad($"a", $"b", $"c"), rpad("a", "b", "c"),
- lpad($"a", 1, $"c"), rpad("a", 1, "c")),
- Row("???hi", "hi???", "h", "h"))
-
- checkAnswer(
- df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"),
- Row("???hi", "hi???", "h", "h"))
- }
-
- test("string repeat function") {
- val df = Seq(("hi", 2)).toDF("a", "b")
-
- checkAnswer(
- df.select(
- repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")),
- Row("hihi", "hihi", "hihi", "hihi"))
-
- checkAnswer(
- df.selectExpr("repeat(a, 2)", "repeat(a, b)"),
- Row("hihi", "hihi"))
- }
-
- test("string reverse function") {
- val df = Seq(("hi", "hhhi")).toDF("a", "b")
-
- checkAnswer(
- df.select(reverse($"a"), reverse("b")),
- Row("ih", "ihhh"))
-
- checkAnswer(
- df.selectExpr("reverse(b)"),
- Row("ihhh"))
- }
-
- test("string space function") {
- val df = Seq((2, 3)).toDF("a", "b")
-
- checkAnswer(
- df.select(space($"a"), space("b")),
- Row(" ", " "))
-
- checkAnswer(
- df.selectExpr("space(b)"),
- Row(" "))
- }
-
- test("string split function") {
- val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b")
-
- checkAnswer(
- df.select(
- split($"a", "[1-9]+"),
- split("a", "[1-9]+")),
- Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc")))
-
- checkAnswer(
- df.selectExpr("split(a, '[1-9]+')"),
- Row(Seq("aa", "bb", "cc")))
- }
-
test("conditional function: least") {
checkAnswer(
testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1),
@@ -430,83 +267,4 @@ class DataFrameFunctionsSuite extends QueryTest {
)
}
- test("string / binary length function") {
- val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c")
- checkAnswer(
- df.select(length($"a"), length("a"), length($"b"), length("b")),
- Row(3, 3, 4, 4))
-
- checkAnswer(
- df.selectExpr("length(a)", "length(b)"),
- Row(3, 4))
-
- intercept[AnalysisException] {
- checkAnswer(
- df.selectExpr("length(c)"), // int type of the argument is unacceptable
- Row("5.0000"))
- }
- }
-
- test("number format function") {
- val tuple =
- ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
- 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381))
- val df =
- Seq(tuple)
- .toDF(
- "a", // string "aa"
- "b", // byte 1
- "c", // short 2
- "d", // float 3.13223f
- "e", // integer 4
- "f", // long 5L
- "g", // double 6.48173d
- "h") // decimal 7.128381
-
- checkAnswer(
- df.select(
- format_number($"f", 4),
- format_number("f", 4)),
- Row("5.0000", "5.0000"))
-
- checkAnswer(
- df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer
- Row("1.0000"))
-
- checkAnswer(
- df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer
- Row("2.0000"))
-
- checkAnswer(
- df.selectExpr("format_number(d, e)"), // convert the 1st argument to double
- Row("3.1322"))
-
- checkAnswer(
- df.selectExpr("format_number(e, e)"), // not convert anything
- Row("4.0000"))
-
- checkAnswer(
- df.selectExpr("format_number(f, e)"), // not convert anything
- Row("5.0000"))
-
- checkAnswer(
- df.selectExpr("format_number(g, e)"), // not convert anything
- Row("6.4817"))
-
- checkAnswer(
- df.selectExpr("format_number(h, e)"), // not convert anything
- Row("7.1284"))
-
- intercept[AnalysisException] {
- checkAnswer(
- df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable
- Row("5.0000"))
- }
-
- intercept[AnalysisException] {
- checkAnswer(
- df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable
- Row("5.0000"))
- }
- }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
new file mode 100644
index 0000000000..4eff33ed45
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -0,0 +1,284 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.Decimal
+
+
+class StringFunctionsSuite extends QueryTest {
+
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+
+ test("string concat") {
+ val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
+
+ checkAnswer(
+ df.select(concat($"a", $"b", $"c")),
+ Row("ab"))
+
+ checkAnswer(
+ df.selectExpr("concat(a, b, c)"),
+ Row("ab"))
+ }
+
+
+ test("string Levenshtein distance") {
+ val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
+ checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
+ checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1)))
+ }
+
+ test("string ascii function") {
+ val df = Seq(("abc", "")).toDF("a", "b")
+ checkAnswer(
+ df.select(ascii($"a"), ascii("b")),
+ Row(97, 0))
+
+ checkAnswer(
+ df.selectExpr("ascii(a)", "ascii(b)"),
+ Row(97, 0))
+ }
+
+ test("string base64/unbase64 function") {
+ val bytes = Array[Byte](1, 2, 3, 4)
+ val df = Seq((bytes, "AQIDBA==")).toDF("a", "b")
+ checkAnswer(
+ df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")),
+ Row("AQIDBA==", "AQIDBA==", bytes, bytes))
+
+ checkAnswer(
+ df.selectExpr("base64(a)", "unbase64(b)"),
+ Row("AQIDBA==", bytes))
+ }
+
+ test("string encode/decode function") {
+ val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
+ // scalastyle:off
+ // non-ASCII characters are not allowed in the code, so we disable scalastyle here.
+ val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
+ checkAnswer(
+ df.select(
+ encode($"a", "utf-8"),
+ encode("a", "utf-8"),
+ decode($"c", "utf-8"),
+ decode("c", "utf-8")),
+ Row(bytes, bytes, "大千世界", "大千世界"))
+
+ checkAnswer(
+ df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"),
+ Row(bytes, "大千世界"))
+ // scalastyle:on
+ }
+
+ test("string trim functions") {
+ val df = Seq((" example ", "")).toDF("a", "b")
+
+ checkAnswer(
+ df.select(ltrim($"a"), rtrim($"a"), trim($"a")),
+ Row("example ", " example", "example"))
+
+ checkAnswer(
+ df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"),
+ Row("example ", " example", "example"))
+ }
+
+ test("string formatString function") {
+ val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")
+
+ checkAnswer(
+ df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
+ Row("aa123cc", "aa123cc"))
+
+ checkAnswer(
+ df.selectExpr("printf(a, b, c)"),
+ Row("aa123cc"))
+ }
+
+ test("string instr function") {
+ val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c")
+
+ checkAnswer(
+ df.select(instr($"a", $"b"), instr("a", "b")),
+ Row(1, 1))
+
+ checkAnswer(
+ df.selectExpr("instr(a, b)"),
+ Row(1))
+ }
+
+ test("string locate function") {
+ val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
+
+ checkAnswer(
+ df.select(
+ locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1),
+ locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")),
+ Row(1, 1, 2, 2, 2, 2))
+
+ checkAnswer(
+ df.selectExpr("locate(b, a)", "locate(b, a, d)"),
+ Row(1, 2))
+ }
+
+ test("string padding functions") {
+ val df = Seq(("hi", 5, "??")).toDF("a", "b", "c")
+
+ checkAnswer(
+ df.select(
+ lpad($"a", $"b", $"c"), rpad("a", "b", "c"),
+ lpad($"a", 1, $"c"), rpad("a", 1, "c")),
+ Row("???hi", "hi???", "h", "h"))
+
+ checkAnswer(
+ df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"),
+ Row("???hi", "hi???", "h", "h"))
+ }
+
+ test("string repeat function") {
+ val df = Seq(("hi", 2)).toDF("a", "b")
+
+ checkAnswer(
+ df.select(
+ repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")),
+ Row("hihi", "hihi", "hihi", "hihi"))
+
+ checkAnswer(
+ df.selectExpr("repeat(a, 2)", "repeat(a, b)"),
+ Row("hihi", "hihi"))
+ }
+
+ test("string reverse function") {
+ val df = Seq(("hi", "hhhi")).toDF("a", "b")
+
+ checkAnswer(
+ df.select(reverse($"a"), reverse("b")),
+ Row("ih", "ihhh"))
+
+ checkAnswer(
+ df.selectExpr("reverse(b)"),
+ Row("ihhh"))
+ }
+
+ test("string space function") {
+ val df = Seq((2, 3)).toDF("a", "b")
+
+ checkAnswer(
+ df.select(space($"a"), space("b")),
+ Row(" ", " "))
+
+ checkAnswer(
+ df.selectExpr("space(b)"),
+ Row(" "))
+ }
+
+ test("string split function") {
+ val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b")
+
+ checkAnswer(
+ df.select(
+ split($"a", "[1-9]+"),
+ split("a", "[1-9]+")),
+ Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc")))
+
+ checkAnswer(
+ df.selectExpr("split(a, '[1-9]+')"),
+ Row(Seq("aa", "bb", "cc")))
+ }
+
+ test("string / binary length function") {
+ val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c")
+ checkAnswer(
+ df.select(length($"a"), length("a"), length($"b"), length("b")),
+ Row(3, 3, 4, 4))
+
+ checkAnswer(
+ df.selectExpr("length(a)", "length(b)"),
+ Row(3, 4))
+
+ intercept[AnalysisException] {
+ checkAnswer(
+ df.selectExpr("length(c)"), // int type of the argument is unacceptable
+ Row("5.0000"))
+ }
+ }
+
+ test("number format function") {
+ val tuple =
+ ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
+ 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381))
+ val df =
+ Seq(tuple)
+ .toDF(
+ "a", // string "aa"
+ "b", // byte 1
+ "c", // short 2
+ "d", // float 3.13223f
+ "e", // integer 4
+ "f", // long 5L
+ "g", // double 6.48173d
+ "h") // decimal 7.128381
+
+ checkAnswer(
+ df.select(
+ format_number($"f", 4),
+ format_number("f", 4)),
+ Row("5.0000", "5.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer
+ Row("1.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer
+ Row("2.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(d, e)"), // convert the 1st argument to double
+ Row("3.1322"))
+
+ checkAnswer(
+ df.selectExpr("format_number(e, e)"), // not convert anything
+ Row("4.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(f, e)"), // not convert anything
+ Row("5.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(g, e)"), // not convert anything
+ Row("6.4817"))
+
+ checkAnswer(
+ df.selectExpr("format_number(h, e)"), // not convert anything
+ Row("7.1284"))
+
+ intercept[AnalysisException] {
+ checkAnswer(
+ df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable
+ Row("5.0000"))
+ }
+
+ intercept[AnalysisException] {
+ checkAnswer(
+ df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable
+ Row("5.0000"))
+ }
+ }
+}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 6b8f2f6217..299cc599ff 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -256,6 +256,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"timestamp_2",
"timestamp_udf",
+ // Hive outputs NULL if any concat input is null. We never output null for concat.
+ "udf_concat",
+
// Unlike Hive, we do support log base in (0, 1.0], therefore disable this
"udf7"
)
@@ -846,7 +849,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_case",
"udf_ceil",
"udf_ceiling",
- "udf_concat",
"udf_concat_insert1",
"udf_concat_insert2",
"udf_concat_ws",
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index e7f9fbb2bc..9723b6e083 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -21,6 +21,7 @@ import javax.annotation.Nonnull;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
+import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import static org.apache.spark.unsafe.PlatformDependent.*;
@@ -322,7 +323,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
}
i += numBytesForFirstByte(getByte(i));
c += 1;
- } while(i < numBytes);
+ } while (i < numBytes);
return -1;
}
@@ -395,6 +396,39 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
}
}
+ /**
+ * Concatenates input strings together into a single string. A null input is skipped.
+ * For example, concat("a", null, "c") would yield "ac".
+ */
+ public static UTF8String concat(UTF8String... inputs) {
+ if (inputs == null) {
+ return fromBytes(new byte[0]);
+ }
+
+ // Compute the total length of the result.
+ int totalLength = 0;
+ for (int i = 0; i < inputs.length; i++) {
+ if (inputs[i] != null) {
+ totalLength += inputs[i].numBytes;
+ }
+ }
+
+ // Allocate a new byte array, and copy the inputs one by one into it.
+ final byte[] result = new byte[totalLength];
+ int offset = 0;
+ for (int i = 0; i < inputs.length; i++) {
+ if (inputs[i] != null) {
+ int len = inputs[i].numBytes;
+ PlatformDependent.copyMemory(
+ inputs[i].base, inputs[i].offset,
+ result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
+ len);
+ offset += len;
+ }
+ }
+ return fromBytes(result);
+ }
+
@Override
public String toString() {
try {
@@ -413,7 +447,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
}
@Override
- public int compareTo(final UTF8String other) {
+ public int compareTo(@Nonnull final UTF8String other) {
int len = Math.min(numBytes, other.numBytes);
// TODO: compare 8 bytes as unsigned long
for (int i = 0; i < len; i ++) {
@@ -434,7 +468,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
public boolean equals(final Object other) {
if (other instanceof UTF8String) {
UTF8String o = (UTF8String) other;
- if (numBytes != o.numBytes){
+ if (numBytes != o.numBytes) {
return false;
}
return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes);
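The Java implementation above uses the classic two-pass pattern: first sum the byte lengths so the output array is sized exactly once, then copy each non-null input into place. A standalone Scala sketch of the same pattern over plain byte arrays (illustrative only; the real code copies via PlatformDependent, which also supports off-heap memory):

    // Pass 1: size the output exactly. Pass 2: copy, skipping nulls.
    def concatBytes(inputs: Array[Byte]*): Array[Byte] = {
      val total = inputs.iterator.filter(_ != null).map(_.length).sum
      val result = new Array[Byte](total)
      var offset = 0
      for (in <- inputs if in != null) {
        System.arraycopy(in, 0, result, offset, in.length)
        offset += in.length
      }
      result
    }

    // Concatenating UTF-8 bytes at string boundaries is safe because each
    // encoded character is a self-contained byte sequence.
    assert(new String(concatBytes("数据".getBytes("UTF-8"), "砖头".getBytes("UTF-8")), "UTF-8") == "数据砖头")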
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 694bdc29f3..0db7522b50 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -87,6 +87,20 @@ public class UTF8StringSuite {
}
@Test
+ public void concatTest() {
+ assertEquals(concat(), fromString(""));
+ assertEquals(concat(null), fromString(""));
+ assertEquals(concat(fromString("")), fromString(""));
+ assertEquals(concat(fromString("ab")), fromString("ab"));
+ assertEquals(concat(fromString("a"), fromString("b")), fromString("ab"));
+ assertEquals(concat(fromString("a"), fromString("b"), fromString("c")), fromString("abc"));
+ assertEquals(concat(fromString("a"), null, fromString("c")), fromString("ac"));
+ assertEquals(concat(fromString("a"), null, null), fromString("a"));
+ assertEquals(concat(null, null, null), fromString(""));
+ assertEquals(concat(fromString("数据"), fromString("砖头")), fromString("数据砖头"));
+ }
+
+ @Test
public void contains() {
assertTrue(fromString("").contains(fromString("")));
assertTrue(fromString("hello").contains(fromString("ello")));