From 432d1399cb6985893932088875b2f3be981c0b5f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 18 Apr 2016 10:44:51 -0700 Subject: [SPARK-14614] [SQL] Add `bround` function ## What changes were proposed in this pull request? This PR aims to add `bound` function (aka Banker's round) by extending current `round` implementation. [Hive supports `bround` since 1.3.0.](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF) **Hive (1.3 ~ 2.0)** ``` hive> select round(2.5), bround(2.5); OK 3.0 2.0 ``` **After this PR** ```scala scala> sql("select round(2.5), bround(2.5)").head res0: org.apache.spark.sql.Row = [3,2] ``` ## How was this patch tested? Pass the Jenkins tests (with extended tests). Author: Dongjoon Hyun Closes #12376 from dongjoon-hyun/SPARK-14614. --- .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/mathExpressions.scala | 71 ++++++++++++++-------- .../scala/org/apache/spark/sql/types/Decimal.scala | 6 ++ .../analysis/ExpressionTypeCheckingSuite.scala | 10 ++- .../catalyst/expressions/MathFunctionsSuite.scala | 23 ++++++- .../scala/org/apache/spark/sql/functions.scala | 17 ++++++ .../apache/spark/sql/MathExpressionsSuite.scala | 12 +++- 7 files changed, 113 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 028463ed4f..ed19191b72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -179,6 +179,7 @@ object FunctionRegistry { expression[Atan]("atan"), expression[Atan2]("atan2"), expression[Bin]("bin"), + expression[BRound]("bround"), expression[Cbrt]("cbrt"), expression[Ceil]("ceil"), expression[Ceil]("ceiling"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index c8a28e8477..9e190289b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -779,7 +779,6 @@ case class Logarithm(left: Expression, right: Expression) /** * Round the `child`'s result to `scale` decimal place when `scale` >= 0 * or round at integral part when `scale` < 0. - * For example, round(31.415, 2) = 31.42 and round(31.415, -1) = 30. * * Child of IntegralType would round to itself when `scale` >= 0. * Child of FractionalType whose value is NaN or Infinite would always round to itself. @@ -789,16 +788,12 @@ case class Logarithm(left: Expression, right: Expression) * * @param child expr to be round, all [[NumericType]] is allowed as Input * @param scale new scale to be round to, this should be a constant int at runtime + * @param mode rounding mode (e.g. HALF_UP, HALF_UP) + * @param modeStr rounding mode string name (e.g. "ROUND_HALF_UP", "ROUND_HALF_EVEN") */ -@ExpressionDescription( - usage = "_FUNC_(x, d) - Round x to d decimal places.", - extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3") -case class Round(child: Expression, scale: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - import BigDecimal.RoundingMode.HALF_UP - - def this(child: Expression) = this(child, Literal(0)) +abstract class RoundBase(child: Expression, scale: Expression, + mode: BigDecimal.RoundingMode.Value, modeStr: String) + extends BinaryExpression with Serializable with ImplicitCastInputTypes { override def left: Expression = child override def right: Expression = scale @@ -853,28 +848,28 @@ case class Round(child: Expression, scale: Expression) child.dataType match { case _: DecimalType => val decimal = input1.asInstanceOf[Decimal] - if (decimal.changePrecision(decimal.precision, _scale)) decimal else null + if (decimal.changePrecision(decimal.precision, _scale, mode)) decimal else null case ByteType => - BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte + BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => - BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort + BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, mode).toShort case IntegerType => - BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt + BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, mode).toInt case LongType => - BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong + BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, mode).toLong case FloatType => val f = input1.asInstanceOf[Float] if (f.isNaN || f.isInfinite) { f } else { - BigDecimal(f.toDouble).setScale(_scale, HALF_UP).toFloat + BigDecimal(f.toDouble).setScale(_scale, mode).toFloat } case DoubleType => val d = input1.asInstanceOf[Double] if (d.isNaN || d.isInfinite) { d } else { - BigDecimal(d).setScale(_scale, HALF_UP).toDouble + BigDecimal(d).setScale(_scale, mode).toDouble } } } @@ -885,7 +880,8 @@ case class Round(child: Expression, scale: Expression) val evaluationCode = child.dataType match { case _: DecimalType => s""" - if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale})) { + if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale}, + java.math.BigDecimal.${modeStr})) { ${ev.value} = ${ce.value}; } else { ${ev.isNull} = true; @@ -894,7 +890,7 @@ case class Round(child: Expression, scale: Expression) if (_scale < 0) { s""" ${ev.value} = new java.math.BigDecimal(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();""" + setScale(${_scale}, java.math.BigDecimal.${modeStr}).byteValue();""" } else { s"${ev.value} = ${ce.value};" } @@ -902,7 +898,7 @@ case class Round(child: Expression, scale: Expression) if (_scale < 0) { s""" ${ev.value} = new java.math.BigDecimal(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();""" + setScale(${_scale}, java.math.BigDecimal.${modeStr}).shortValue();""" } else { s"${ev.value} = ${ce.value};" } @@ -910,7 +906,7 @@ case class Round(child: Expression, scale: Expression) if (_scale < 0) { s""" ${ev.value} = new java.math.BigDecimal(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();""" + setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue();""" } else { s"${ev.value} = ${ce.value};" } @@ -918,7 +914,7 @@ case class Round(child: Expression, scale: Expression) if (_scale < 0) { s""" ${ev.value} = new java.math.BigDecimal(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();""" + setScale(${_scale}, java.math.BigDecimal.${modeStr}).longValue();""" } else { s"${ev.value} = ${ce.value};" } @@ -928,7 +924,7 @@ case class Round(child: Expression, scale: Expression) ${ev.value} = ${ce.value}; } else { ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); + setScale(${_scale}, java.math.BigDecimal.${modeStr}).floatValue(); }""" case DoubleType => // if child eval to NaN or Infinity, just return it. s""" @@ -936,7 +932,7 @@ case class Round(child: Expression, scale: Expression) ${ev.value} = ${ce.value}; } else { ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); + setScale(${_scale}, java.math.BigDecimal.${modeStr}).doubleValue(); }""" } @@ -957,3 +953,30 @@ case class Round(child: Expression, scale: Expression) } } } + +/** + * Round an expression to d decimal places using HALF_UP rounding mode. + * round(2.5) == 3.0, round(3.5) == 4.0. + */ +@ExpressionDescription( + usage = "_FUNC_(x, d) - Round x to d decimal places using HALF_UP rounding mode.", + extended = "> SELECT _FUNC_(2.5, 0);\n 3.0") +case class Round(child: Expression, scale: Expression) + extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP, "ROUND_HALF_UP") + with Serializable with ImplicitCastInputTypes { + def this(child: Expression) = this(child, Literal(0)) +} + +/** + * Round an expression to d decimal places using HALF_EVEN rounding mode, + * also known as Gaussian rounding or bankers' rounding. + * round(2.5) = 2.0, round(3.5) = 4.0. + */ +@ExpressionDescription( + usage = "_FUNC_(x, d) - Round x to d decimal places using HALF_EVEN rounding mode.", + extended = "> SELECT _FUNC_(2.5, 0);\n 2.0") +case class BRound(child: Expression, scale: Expression) + extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN, "ROUND_HALF_EVEN") + with Serializable with ImplicitCastInputTypes { + def this(child: Expression) = this(child, Literal(0)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index a30a3926bb..6f4ec6b701 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -201,6 +201,11 @@ final class Decimal extends Ordered[Decimal] with Serializable { changePrecision(precision, scale, ROUND_HALF_UP) } + def changePrecision(precision: Int, scale: Int, mode: Int): Boolean = mode match { + case java.math.BigDecimal.ROUND_HALF_UP => changePrecision(precision, scale, ROUND_HALF_UP) + case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN) + } + /** * Update precision and scale while keeping our value the same, and return true if successful. * @@ -337,6 +342,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { object Decimal { val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP + val ROUND_HALF_EVEN = BigDecimal.RoundingMode.HALF_EVEN val ROUND_CEILING = BigDecimal.RoundingMode.CEILING val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index ace6e10c6e..660dc86c3e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -192,7 +192,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { "values of function map should all be the same type") } - test("check types for ROUND") { + test("check types for ROUND/BROUND") { assertSuccess(Round(Literal(null), Literal(null))) assertSuccess(Round('intField, Literal(1))) @@ -200,6 +200,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Round('intField, 'booleanField), "requires int type") assertError(Round('intField, 'mapField), "requires int type") assertError(Round('booleanField, 'intField), "requires numeric type") + + assertSuccess(BRound(Literal(null), Literal(null))) + assertSuccess(BRound('intField, Literal(1))) + + assertError(BRound('intField, 'intField), "Only foldable Expression is allowed") + assertError(BRound('intField, 'booleanField), "requires int type") + assertError(BRound('intField, 'mapField), "requires int type") + assertError(BRound('booleanField, 'intField), "requires numeric type") } test("check types for Greatest/Least") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 452792d21c..1e5b657f1f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -508,7 +508,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Logarithm, DoubleType, DoubleType) } - test("round") { + test("round/bround") { val scales = -6 to 6 val doublePi: Double = math.Pi val shortPi: Short = 31415 @@ -529,11 +529,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++ Seq.fill(7)(31415926535897932L) + val intResultsB: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, + 314159260) ++ Seq.fill(7)(314159265) + scales.zipWithIndex.foreach { case (scale, i) => checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) + checkEvaluation(BRound(doublePi, scale), doubleResults(i), EmptyRow) + checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow) + checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow) + checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow) } val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), @@ -543,19 +550,33 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null (0 to 7).foreach { i => checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) + checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow) } (8 to 10).foreach { scale => checkEvaluation(Round(bdPi, scale), null, EmptyRow) + checkEvaluation(BRound(bdPi, scale), null, EmptyRow) } DataTypeTestUtils.numericTypes.foreach { dataType => checkEvaluation(Round(Literal.create(null, dataType), Literal(2)), null) checkEvaluation(Round(Literal.create(null, dataType), Literal.create(null, IntegerType)), null) + checkEvaluation(BRound(Literal.create(null, dataType), Literal(2)), null) + checkEvaluation(BRound(Literal.create(null, dataType), + Literal.create(null, IntegerType)), null) } + checkEvaluation(Round(2.5, 0), 3.0) + checkEvaluation(Round(3.5, 0), 4.0) + checkEvaluation(Round(-2.5, 0), -3.0) checkEvaluation(Round(-3.5, 0), -4.0) checkEvaluation(Round(-0.35, 1), -0.4) checkEvaluation(Round(-35, -1), -40) + checkEvaluation(BRound(2.5, 0), 2.0) + checkEvaluation(BRound(3.5, 0), 4.0) + checkEvaluation(BRound(-2.5, 0), -2.0) + checkEvaluation(BRound(-3.5, 0), -4.0) + checkEvaluation(BRound(-0.35, 1), -0.4) + checkEvaluation(BRound(-35, -1), -40) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 223122300d..8e2e94669b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1776,6 +1776,23 @@ object functions { */ def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) } + /** + * Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode. + * + * @group math_funcs + * @since 2.0.0 + */ + def bround(e: Column): Column = bround(e, 0) + + /** + * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode + * if `scale` >= 0 or at integral part when `scale` < 0. + * + * @group math_funcs + * @since 2.0.0 + */ + def bround(e: Column, scale: Int): Column = withExpr { BRound(e.expr, Literal(scale)) } + /** * Shift the given value numBits left. If the given value is a long value, this function * will return a long value else it will return an integer value. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index f5a67fd782..0de7f2321f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -207,12 +207,16 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { testOneToOneMathFunction(rint, math.rint) } - test("round") { + test("round/bround") { val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") checkAnswer( df.select(round('a), round('a, -1), round('a, -2)), Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) ) + checkAnswer( + df.select(bround('a), bround('a, -1), bround('a, -2)), + Seq(Row(5, 0, 0), Row(55, 60, 100), Row(555, 560, 600)) + ) val pi = "3.1415" checkAnswer( @@ -221,6 +225,12 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) ) + checkAnswer( + sql(s"SELECT bround($pi, -3), bround($pi, -2), bround($pi, -1), " + + s"bround($pi, 0), bround($pi, 1), bround($pi, 2), bround($pi, 3)"), + Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), + BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) + ) } test("exp") { -- cgit v1.2.3