aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDongjoon Hyun <dongjoon@apache.org>2016-04-18 10:44:51 -0700
committerDavies Liu <davies.liu@gmail.com>2016-04-18 10:44:51 -0700
commit432d1399cb6985893932088875b2f3be981c0b5f (patch)
treebdf579b03c758e0408a8b8fcf02edda173eac4d0
parentd6fb485de8b79054db08658d904a3148a04d4180 (diff)
downloadspark-432d1399cb6985893932088875b2f3be981c0b5f.tar.gz
spark-432d1399cb6985893932088875b2f3be981c0b5f.tar.bz2
spark-432d1399cb6985893932088875b2f3be981c0b5f.zip
[SPARK-14614] [SQL] Add `bround` function
## What changes were proposed in this pull request? This PR aims to add the `bround` function (aka banker's rounding) by extending the current `round` implementation. [Hive supports `bround` since 1.3.0.](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF) **Hive (1.3 ~ 2.0)** ``` hive> select round(2.5), bround(2.5); OK 3.0 2.0 ``` **After this PR** ```scala scala> sql("select round(2.5), bround(2.5)").head res0: org.apache.spark.sql.Row = [3,2] ``` ## How was this patch tested? Pass the Jenkins tests (with extended tests). Author: Dongjoon Hyun <dongjoon@apache.org> Closes #12376 from dongjoon-hyun/SPARK-14614.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala71
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala6
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala10
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala23
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala17
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala12
7 files changed, 113 insertions, 27 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 028463ed4f..ed19191b72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -179,6 +179,7 @@ object FunctionRegistry {
expression[Atan]("atan"),
expression[Atan2]("atan2"),
expression[Bin]("bin"),
+ expression[BRound]("bround"),
expression[Cbrt]("cbrt"),
expression[Ceil]("ceil"),
expression[Ceil]("ceiling"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index c8a28e8477..9e190289b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -779,7 +779,6 @@ case class Logarithm(left: Expression, right: Expression)
/**
* Round the `child`'s result to `scale` decimal place when `scale` >= 0
* or round at integral part when `scale` < 0.
- * For example, round(31.415, 2) = 31.42 and round(31.415, -1) = 30.
*
* Child of IntegralType would round to itself when `scale` >= 0.
* Child of FractionalType whose value is NaN or Infinite would always round to itself.
@@ -789,16 +788,12 @@ case class Logarithm(left: Expression, right: Expression)
*
* @param child expr to be round, all [[NumericType]] is allowed as Input
* @param scale new scale to be round to, this should be a constant int at runtime
+ * @param mode rounding mode (e.g. HALF_UP, HALF_EVEN)
+ * @param modeStr rounding mode string name (e.g. "ROUND_HALF_UP", "ROUND_HALF_EVEN")
*/
-@ExpressionDescription(
- usage = "_FUNC_(x, d) - Round x to d decimal places.",
- extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3")
-case class Round(child: Expression, scale: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
-
- import BigDecimal.RoundingMode.HALF_UP
-
- def this(child: Expression) = this(child, Literal(0))
+abstract class RoundBase(child: Expression, scale: Expression,
+ mode: BigDecimal.RoundingMode.Value, modeStr: String)
+ extends BinaryExpression with Serializable with ImplicitCastInputTypes {
override def left: Expression = child
override def right: Expression = scale
@@ -853,28 +848,28 @@ case class Round(child: Expression, scale: Expression)
child.dataType match {
case _: DecimalType =>
val decimal = input1.asInstanceOf[Decimal]
- if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
+ if (decimal.changePrecision(decimal.precision, _scale, mode)) decimal else null
case ByteType =>
- BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte
+ BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
case ShortType =>
- BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort
+ BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, mode).toShort
case IntegerType =>
- BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt
+ BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, mode).toInt
case LongType =>
- BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong
+ BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, mode).toLong
case FloatType =>
val f = input1.asInstanceOf[Float]
if (f.isNaN || f.isInfinite) {
f
} else {
- BigDecimal(f.toDouble).setScale(_scale, HALF_UP).toFloat
+ BigDecimal(f.toDouble).setScale(_scale, mode).toFloat
}
case DoubleType =>
val d = input1.asInstanceOf[Double]
if (d.isNaN || d.isInfinite) {
d
} else {
- BigDecimal(d).setScale(_scale, HALF_UP).toDouble
+ BigDecimal(d).setScale(_scale, mode).toDouble
}
}
}
@@ -885,7 +880,8 @@ case class Round(child: Expression, scale: Expression)
val evaluationCode = child.dataType match {
case _: DecimalType =>
s"""
- if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale})) {
+ if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale},
+ java.math.BigDecimal.${modeStr})) {
${ev.value} = ${ce.value};
} else {
${ev.isNull} = true;
@@ -894,7 +890,7 @@ case class Round(child: Expression, scale: Expression)
if (_scale < 0) {
s"""
${ev.value} = new java.math.BigDecimal(${ce.value}).
- setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();"""
+ setScale(${_scale}, java.math.BigDecimal.${modeStr}).byteValue();"""
} else {
s"${ev.value} = ${ce.value};"
}
@@ -902,7 +898,7 @@ case class Round(child: Expression, scale: Expression)
if (_scale < 0) {
s"""
${ev.value} = new java.math.BigDecimal(${ce.value}).
- setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();"""
+ setScale(${_scale}, java.math.BigDecimal.${modeStr}).shortValue();"""
} else {
s"${ev.value} = ${ce.value};"
}
@@ -910,7 +906,7 @@ case class Round(child: Expression, scale: Expression)
if (_scale < 0) {
s"""
${ev.value} = new java.math.BigDecimal(${ce.value}).
- setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();"""
+ setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue();"""
} else {
s"${ev.value} = ${ce.value};"
}
@@ -918,7 +914,7 @@ case class Round(child: Expression, scale: Expression)
if (_scale < 0) {
s"""
${ev.value} = new java.math.BigDecimal(${ce.value}).
- setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();"""
+ setScale(${_scale}, java.math.BigDecimal.${modeStr}).longValue();"""
} else {
s"${ev.value} = ${ce.value};"
}
@@ -928,7 +924,7 @@ case class Round(child: Expression, scale: Expression)
${ev.value} = ${ce.value};
} else {
${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
- setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
+ setScale(${_scale}, java.math.BigDecimal.${modeStr}).floatValue();
}"""
case DoubleType => // if child eval to NaN or Infinity, just return it.
s"""
@@ -936,7 +932,7 @@ case class Round(child: Expression, scale: Expression)
${ev.value} = ${ce.value};
} else {
${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
- setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
+ setScale(${_scale}, java.math.BigDecimal.${modeStr}).doubleValue();
}"""
}
@@ -957,3 +953,30 @@ case class Round(child: Expression, scale: Expression)
}
}
}
+
+/**
+ * Round an expression to d decimal places using HALF_UP rounding mode.
+ * round(2.5) == 3.0, round(3.5) == 4.0.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(x, d) - Round x to d decimal places using HALF_UP rounding mode.",
+ extended = "> SELECT _FUNC_(2.5, 0);\n 3.0")
+case class Round(child: Expression, scale: Expression)
+ extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP, "ROUND_HALF_UP")
+ with Serializable with ImplicitCastInputTypes {
+ def this(child: Expression) = this(child, Literal(0))
+}
+
+/**
+ * Round an expression to d decimal places using HALF_EVEN rounding mode,
+ * also known as Gaussian rounding or bankers' rounding.
+ * bround(2.5) = 2.0, bround(3.5) = 4.0.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(x, d) - Round x to d decimal places using HALF_EVEN rounding mode.",
+ extended = "> SELECT _FUNC_(2.5, 0);\n 2.0")
+case class BRound(child: Expression, scale: Expression)
+ extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN, "ROUND_HALF_EVEN")
+ with Serializable with ImplicitCastInputTypes {
+ def this(child: Expression) = this(child, Literal(0))
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index a30a3926bb..6f4ec6b701 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -201,6 +201,11 @@ final class Decimal extends Ordered[Decimal] with Serializable {
changePrecision(precision, scale, ROUND_HALF_UP)
}
+ def changePrecision(precision: Int, scale: Int, mode: Int): Boolean = mode match {
+ case java.math.BigDecimal.ROUND_HALF_UP => changePrecision(precision, scale, ROUND_HALF_UP)
+ case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN)
+ }
+
/**
* Update precision and scale while keeping our value the same, and return true if successful.
*
@@ -337,6 +342,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
object Decimal {
val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP
+ val ROUND_HALF_EVEN = BigDecimal.RoundingMode.HALF_EVEN
val ROUND_CEILING = BigDecimal.RoundingMode.CEILING
val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index ace6e10c6e..660dc86c3e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -192,7 +192,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
"values of function map should all be the same type")
}
- test("check types for ROUND") {
+ test("check types for ROUND/BROUND") {
assertSuccess(Round(Literal(null), Literal(null)))
assertSuccess(Round('intField, Literal(1)))
@@ -200,6 +200,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(Round('intField, 'booleanField), "requires int type")
assertError(Round('intField, 'mapField), "requires int type")
assertError(Round('booleanField, 'intField), "requires numeric type")
+
+ assertSuccess(BRound(Literal(null), Literal(null)))
+ assertSuccess(BRound('intField, Literal(1)))
+
+ assertError(BRound('intField, 'intField), "Only foldable Expression is allowed")
+ assertError(BRound('intField, 'booleanField), "requires int type")
+ assertError(BRound('intField, 'mapField), "requires int type")
+ assertError(BRound('booleanField, 'intField), "requires numeric type")
}
test("check types for Greatest/Least") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 452792d21c..1e5b657f1f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -508,7 +508,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Logarithm, DoubleType, DoubleType)
}
- test("round") {
+ test("round/bround") {
val scales = -6 to 6
val doublePi: Double = math.Pi
val shortPi: Short = 31415
@@ -529,11 +529,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++
Seq.fill(7)(31415926535897932L)
+ val intResultsB: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300,
+ 314159260) ++ Seq.fill(7)(314159265)
+
scales.zipWithIndex.foreach { case (scale, i) =>
checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow)
checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
+ checkEvaluation(BRound(doublePi, scale), doubleResults(i), EmptyRow)
+ checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow)
+ checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow)
+ checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow)
}
val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
@@ -543,19 +550,33 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
(0 to 7).foreach { i =>
checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
+ checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow)
}
(8 to 10).foreach { scale =>
checkEvaluation(Round(bdPi, scale), null, EmptyRow)
+ checkEvaluation(BRound(bdPi, scale), null, EmptyRow)
}
DataTypeTestUtils.numericTypes.foreach { dataType =>
checkEvaluation(Round(Literal.create(null, dataType), Literal(2)), null)
checkEvaluation(Round(Literal.create(null, dataType),
Literal.create(null, IntegerType)), null)
+ checkEvaluation(BRound(Literal.create(null, dataType), Literal(2)), null)
+ checkEvaluation(BRound(Literal.create(null, dataType),
+ Literal.create(null, IntegerType)), null)
}
+ checkEvaluation(Round(2.5, 0), 3.0)
+ checkEvaluation(Round(3.5, 0), 4.0)
+ checkEvaluation(Round(-2.5, 0), -3.0)
checkEvaluation(Round(-3.5, 0), -4.0)
checkEvaluation(Round(-0.35, 1), -0.4)
checkEvaluation(Round(-35, -1), -40)
+ checkEvaluation(BRound(2.5, 0), 2.0)
+ checkEvaluation(BRound(3.5, 0), 4.0)
+ checkEvaluation(BRound(-2.5, 0), -2.0)
+ checkEvaluation(BRound(-3.5, 0), -4.0)
+ checkEvaluation(BRound(-0.35, 1), -0.4)
+ checkEvaluation(BRound(-35, -1), -40)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 223122300d..8e2e94669b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1777,6 +1777,23 @@ object functions {
def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) }
/**
+ * Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode.
+ *
+ * @group math_funcs
+ * @since 2.0.0
+ */
+ def bround(e: Column): Column = bround(e, 0)
+
+ /**
+ * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode
+ * if `scale` >= 0 or at integral part when `scale` < 0.
+ *
+ * @group math_funcs
+ * @since 2.0.0
+ */
+ def bround(e: Column, scale: Int): Column = withExpr { BRound(e.expr, Literal(scale)) }
+
+ /**
* Shift the given value numBits left. If the given value is a long value, this function
* will return a long value else it will return an integer value.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index f5a67fd782..0de7f2321f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -207,12 +207,16 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
testOneToOneMathFunction(rint, math.rint)
}
- test("round") {
+ test("round/bround") {
val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a")
checkAnswer(
df.select(round('a), round('a, -1), round('a, -2)),
Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600))
)
+ checkAnswer(
+ df.select(bround('a), bround('a, -1), bround('a, -2)),
+ Seq(Row(5, 0, 0), Row(55, 60, 100), Row(555, 560, 600))
+ )
val pi = "3.1415"
checkAnswer(
@@ -221,6 +225,12 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
)
+ checkAnswer(
+ sql(s"SELECT bround($pi, -3), bround($pi, -2), bround($pi, -1), " +
+ s"bround($pi, 0), bround($pi, 1), bround($pi, 2), bround($pi, 3)"),
+ Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
+ BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
+ )
}
test("exp") {