aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorCheng Hao <hao.cheng@intel.com>2015-07-15 21:47:21 -0700
committerReynold Xin <rxin@databricks.com>2015-07-15 21:47:21 -0700
commit42dea3acf90ec506a0b79720b55ae1d753cc7544 (patch)
tree617b79a51b14a397fde87e412e329676566240ca /sql
parent9c64a75bfc5e2566d1b4cd0d9b4585a818086ca6 (diff)
downloadspark-42dea3acf90ec506a0b79720b55ae1d753cc7544.tar.gz
spark-42dea3acf90ec506a0b79720b55ae1d753cc7544.tar.bz2
spark-42dea3acf90ec506a0b79720b55ae1d753cc7544.zip
[SPARK-8245][SQL] FormatNumber/Length Support for Expression
- `BinaryType` for `Length` - `FormatNumber` Author: Cheng Hao <hao.cheng@intel.com> Closes #7034 from chenghao-intel/expression and squashes the following commits: e534b87 [Cheng Hao] python api style issue 601bbf5 [Cheng Hao] add python API support 3ebe288 [Cheng Hao] update as feedback 52274f7 [Cheng Hao] add support for udf_format_number and length for binary
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala94
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala53
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala32
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala93
5 files changed, 241 insertions, 36 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index d2678ce860..e0beafe710 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -152,11 +152,12 @@ object FunctionRegistry {
expression[Base64]("base64"),
expression[Encode]("encode"),
expression[Decode]("decode"),
- expression[StringInstr]("instr"),
+ expression[FormatNumber]("format_number"),
expression[Lower]("lcase"),
expression[Lower]("lower"),
- expression[StringLength]("length"),
+ expression[Length]("length"),
expression[Levenshtein]("levenshtein"),
+ expression[StringInstr]("instr"),
expression[StringLocate]("locate"),
expression[StringLPad]("lpad"),
expression[StringTrimLeft]("ltrim"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 03b55ce5fe..c64afe7b3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
+import java.text.DecimalFormat
import java.util.Locale
import java.util.regex.Pattern
-import org.apache.commons.lang3.StringUtils
-
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -553,17 +552,22 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
}
/**
- * A function that return the length of the given string expression.
+ * A function that returns the length of the given string or binary expression.
*/
-case class StringLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = IntegerType
- override def inputTypes: Seq[DataType] = Seq(StringType)
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
- protected override def nullSafeEval(string: Any): Any =
- string.asInstanceOf[UTF8String].numChars
+ protected override def nullSafeEval(value: Any): Any = child.dataType match {
+ case StringType => value.asInstanceOf[UTF8String].numChars
+ case BinaryType => value.asInstanceOf[Array[Byte]].length
+ }
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- defineCodeGen(ctx, ev, c => s"($c).numChars()")
+ child.dataType match {
+ case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()")
+ case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length")
+ }
}
override def prettyName: String = "length"
@@ -668,3 +672,77 @@ case class Encode(value: Expression, charset: Expression)
}
}
+/**
+ * Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
+ * and returns the result as a string. If D is 0, the result has no decimal point or
+ * fractional part.
+ */
+case class FormatNumber(x: Expression, d: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ override def left: Expression = x
+ override def right: Expression = d
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
+
+  // Tracks the last d value associated with the pattern; we rebuild the
+  // pattern (DecimalFormat) whenever a newly arriving d value differs from the last one.
+ @transient
+ private var lastDValue: Int = -100
+
+  // A cached DecimalFormat pattern buffer; for performance, we change it
+  // only when the d value changes.
+ @transient
+ private val pattern: StringBuffer = new StringBuffer()
+
+ @transient
+ private val numberFormat: DecimalFormat = new DecimalFormat("")
+
+ override def eval(input: InternalRow): Any = {
+ val xObject = x.eval(input)
+ if (xObject == null) {
+ return null
+ }
+
+ val dObject = d.eval(input)
+
+ if (dObject == null || dObject.asInstanceOf[Int] < 0) {
+ return null
+ }
+ val dValue = dObject.asInstanceOf[Int]
+
+ if (dValue != lastDValue) {
+ // construct a new DecimalFormat only if a new dValue
+ pattern.delete(0, pattern.length())
+ pattern.append("#,###,###,###,###,###,##0")
+
+ // decimal place
+ if (dValue > 0) {
+ pattern.append(".")
+
+ var i = 0
+ while (i < dValue) {
+ i += 1
+ pattern.append("0")
+ }
+ }
+ val dFormat = new DecimalFormat(pattern.toString())
+ lastDValue = dValue;
+ numberFormat.applyPattern(dFormat.toPattern())
+ }
+
+ x.dataType match {
+ case ByteType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Byte]))
+ case ShortType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Short]))
+ case FloatType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Float]))
+ case IntegerType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Int]))
+ case LongType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Long]))
+ case DoubleType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Double]))
+ case _: DecimalType =>
+ UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal))
+ }
+ }
+
+ override def prettyName: String = "format_number"
+}
+
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
index b19f4ee37a..5d7763bedf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType}
+import org.apache.spark.sql.types._
class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -216,15 +216,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
- test("length for string") {
- val a = 'a.string.at(0)
- checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef"))
- checkEvaluation(StringLength(a), 5, create_row("abdef"))
- checkEvaluation(StringLength(a), 0, create_row(""))
- checkEvaluation(StringLength(a), null, create_row(null))
- checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
- }
-
test("ascii for string") {
val a = 'a.string.at(0)
checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef"))
@@ -426,4 +417,46 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1)
}
+
+ test("length for string / binary") {
+ val a = 'a.string.at(0)
+ val b = 'b.binary.at(0)
+ val bytes = Array[Byte](1, 2, 3, 1, 2)
+ val string = "abdef"
+
+ // scalastyle:off
+ // non ascii characters are not allowed in the source code, so we disable the scalastyle.
+ checkEvaluation(Length(Literal("a花花c")), 4, create_row(string))
+ // scalastyle:on
+ checkEvaluation(Length(Literal(bytes)), 5, create_row(Array[Byte]()))
+
+ checkEvaluation(Length(a), 5, create_row(string))
+ checkEvaluation(Length(b), 5, create_row(bytes))
+
+ checkEvaluation(Length(a), 0, create_row(""))
+ checkEvaluation(Length(b), 0, create_row(Array[Byte]()))
+
+ checkEvaluation(Length(a), null, create_row(null))
+ checkEvaluation(Length(b), null, create_row(null))
+
+ checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string))
+ checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes))
+ }
+
+ test("number format") {
+ checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000")
+ checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000")
+ checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000")
+ checkEvaluation(FormatNumber(Literal(4), Literal(3)), "4.000")
+ checkEvaluation(FormatNumber(Literal(12831273.23481d), Literal(3)), "12,831,273.235")
+ checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal(0)), "12,831,274")
+ checkEvaluation(FormatNumber(Literal(123123324123L), Literal(3)), "123,123,324,123.000")
+ checkEvaluation(FormatNumber(Literal(123123324123L), Literal(-1)), null)
+ checkEvaluation(
+ FormatNumber(
+ Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)),
+ "15,159,339,180,002,773.2778")
+ checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
+ checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index c7deaca843..d6da284a4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1685,20 +1685,44 @@ object functions {
//////////////////////////////////////////////////////////////////////////////////////////////
/**
- * Computes the length of a given string value.
+ * Computes the length of a given string / binary value.
*
* @group string_funcs
* @since 1.5.0
*/
- def strlen(e: Column): Column = StringLength(e.expr)
+ def length(e: Column): Column = Length(e.expr)
/**
- * Computes the length of a given string column.
+ * Computes the length of a given string / binary column.
*
* @group string_funcs
* @since 1.5.0
*/
- def strlen(columnName: String): Column = strlen(Column(columnName))
+ def length(columnName: String): Column = length(Column(columnName))
+
+ /**
+ * Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
+ * and returns the result as a string.
+ * If d is 0, the result has no decimal point or fractional part.
+ * If d < 0, the result will be null.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr)
+
+ /**
+ * Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
+ * and returns the result as a string.
+ * If d is 0, the result has no decimal point or fractional part.
+ * If d < 0, the result will be null.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def format_number(columnXName: String, d: Int): Column = {
+ format_number(Column(columnXName), d)
+ }
/**
* Computes the Levenshtein distance of the two given strings.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 70bd78737f..6dccdd857b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -208,17 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(2743272264L, 2180413220L))
}
- test("string length function") {
- val df = Seq(("abc", "")).toDF("a", "b")
- checkAnswer(
- df.select(strlen($"a"), strlen("b")),
- Row(3, 0))
-
- checkAnswer(
- df.selectExpr("length(a)", "length(b)"),
- Row(3, 0))
- }
-
test("Levenshtein distance") {
val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
@@ -433,11 +422,91 @@ class DataFrameFunctionsSuite extends QueryTest {
val doubleData = Seq((7.2, 4.1)).toDF("a", "b")
checkAnswer(
doubleData.select(pmod('a, 'b)),
- Seq(Row(3.1000000000000005)) // same as hive
+ Seq(Row(3.1000000000000005)) // same as hive
)
checkAnswer(
doubleData.select(pmod(lit(2), lit(Int.MaxValue))),
Seq(Row(2))
)
}
+
+ test("string / binary length function") {
+ val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c")
+ checkAnswer(
+ df.select(length($"a"), length("a"), length($"b"), length("b")),
+ Row(3, 3, 4, 4))
+
+ checkAnswer(
+ df.selectExpr("length(a)", "length(b)"),
+ Row(3, 4))
+
+ intercept[AnalysisException] {
+ checkAnswer(
+ df.selectExpr("length(c)"), // int type of the argument is unacceptable
+ Row("5.0000"))
+ }
+ }
+
+ test("number format function") {
+ val tuple =
+ ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
+ 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381))
+ val df =
+ Seq(tuple)
+ .toDF(
+ "a", // string "aa"
+ "b", // byte 1
+ "c", // short 2
+ "d", // float 3.13223f
+ "e", // integer 4
+ "f", // long 5L
+ "g", // double 6.48173d
+ "h") // decimal 7.128381
+
+ checkAnswer(
+ df.select(
+ format_number($"f", 4),
+ format_number("f", 4)),
+ Row("5.0000", "5.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer
+ Row("1.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer
+ Row("2.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(d, e)"), // convert the 1st argument to double
+ Row("3.1322"))
+
+ checkAnswer(
+ df.selectExpr("format_number(e, e)"), // not convert anything
+ Row("4.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(f, e)"), // not convert anything
+ Row("5.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(g, e)"), // not convert anything
+ Row("6.4817"))
+
+ checkAnswer(
+ df.selectExpr("format_number(h, e)"), // not convert anything
+ Row("7.1284"))
+
+ intercept[AnalysisException] {
+ checkAnswer(
+ df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable
+ Row("5.0000"))
+ }
+
+ intercept[AnalysisException] {
+ checkAnswer(
+ df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable
+ Row("5.0000"))
+ }
+ }
}