aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core/src')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala17
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala12
2 files changed, 28 insertions, 1 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 223122300d..8e2e94669b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1777,6 +1777,23 @@ object functions {
def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) }
/**
+ * Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode.
+ *
+ * @group math_funcs
+ * @since 2.0.0
+ */
+ def bround(e: Column): Column = bround(e, 0)
+
+ /**
+ * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode
+ * if `scale` >= 0 or at integral part when `scale` < 0.
+ *
+ * @group math_funcs
+ * @since 2.0.0
+ */
+ def bround(e: Column, scale: Int): Column = withExpr { BRound(e.expr, Literal(scale)) }
+
+ /**
* Shift the given value numBits left. If the given value is a long value, this function
* will return a long value else it will return an integer value.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index f5a67fd782..0de7f2321f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -207,12 +207,16 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
testOneToOneMathFunction(rint, math.rint)
}
- test("round") {
+ test("round/bround") {
val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a")
checkAnswer(
df.select(round('a), round('a, -1), round('a, -2)),
Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600))
)
+ checkAnswer(
+ df.select(bround('a), bround('a, -1), bround('a, -2)),
+ Seq(Row(5, 0, 0), Row(55, 60, 100), Row(555, 560, 600))
+ )
val pi = "3.1415"
checkAnswer(
@@ -221,6 +225,12 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
)
+ checkAnswer(
+ sql(s"SELECT bround($pi, -3), bround($pi, -2), bround($pi, -1), " +
+ s"bround($pi, 0), bround($pi, 1), bround($pi, 2), bround($pi, 3)"),
+ Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
+ BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
+ )
}
test("exp") {