diff options
Diffstat (limited to 'sql/core/src')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 17 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala | 12 |
2 files changed, 28 insertions, 1 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 223122300d..8e2e94669b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1777,6 +1777,23 @@ object functions { def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) } /** + * Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode. + * + * @group math_funcs + * @since 2.0.0 + */ + def bround(e: Column): Column = bround(e, 0) + + /** + * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode + * if `scale` >= 0 or at integral part when `scale` < 0. + * + * @group math_funcs + * @since 2.0.0 + */ + def bround(e: Column, scale: Int): Column = withExpr { BRound(e.expr, Literal(scale)) } + + /** * Shift the given value numBits left. If the given value is a long value, this function * will return a long value else it will return an integer value. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index f5a67fd782..0de7f2321f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -207,12 +207,16 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { testOneToOneMathFunction(rint, math.rint) } - test("round") { + test("round/bround") { val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") checkAnswer( df.select(round('a), round('a, -1), round('a, -2)), Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) ) + checkAnswer( + df.select(bround('a), bround('a, -1), bround('a, -2)), + Seq(Row(5, 0, 0), Row(55, 60, 100), Row(555, 560, 600)) + ) val pi = "3.1415" checkAnswer( @@ -221,6 +225,12 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) ) + checkAnswer( + sql(s"SELECT bround($pi, -3), bround($pi, -2), bround($pi, -1), " + + s"bround($pi, 0), bround($pi, 1), bround($pi, 2), bround($pi, 3)"), + Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), + BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) + ) } test("exp") { |