aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src
diff options
context:
space:
mode:
authorDongjoon Hyun <dongjoon@apache.org>2016-04-18 10:44:51 -0700
committerDavies Liu <davies.liu@gmail.com>2016-04-18 10:44:51 -0700
commit432d1399cb6985893932088875b2f3be981c0b5f (patch)
treebdf579b03c758e0408a8b8fcf02edda173eac4d0 /sql/core/src
parentd6fb485de8b79054db08658d904a3148a04d4180 (diff)
downloadspark-432d1399cb6985893932088875b2f3be981c0b5f.tar.gz
spark-432d1399cb6985893932088875b2f3be981c0b5f.tar.bz2
spark-432d1399cb6985893932088875b2f3be981c0b5f.zip
[SPARK-14614] [SQL] Add `bround` function
## What changes were proposed in this pull request? This PR aims to add `bound` function (aka Banker's round) by extending current `round` implementation. [Hive supports `bround` since 1.3.0.](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF) **Hive (1.3 ~ 2.0)** ``` hive> select round(2.5), bround(2.5); OK 3.0 2.0 ``` **After this PR** ```scala scala> sql("select round(2.5), bround(2.5)").head res0: org.apache.spark.sql.Row = [3,2] ``` ## How was this patch tested? Pass the Jenkins tests (with extended tests). Author: Dongjoon Hyun <dongjoon@apache.org> Closes #12376 from dongjoon-hyun/SPARK-14614.
Diffstat (limited to 'sql/core/src')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala17
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala12
2 files changed, 28 insertions, 1 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 223122300d..8e2e94669b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1777,6 +1777,23 @@ object functions {
def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) }
/**
+ * Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode.
+ *
+ * @group math_funcs
+ * @since 2.0.0
+ */
+ def bround(e: Column): Column = bround(e, 0)
+
+ /**
+ * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode
+ * if `scale` >= 0 or at integral part when `scale` < 0.
+ *
+ * @group math_funcs
+ * @since 2.0.0
+ */
+ def bround(e: Column, scale: Int): Column = withExpr { BRound(e.expr, Literal(scale)) }
+
+ /**
* Shift the given value numBits left. If the given value is a long value, this function
* will return a long value else it will return an integer value.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index f5a67fd782..0de7f2321f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -207,12 +207,16 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
testOneToOneMathFunction(rint, math.rint)
}
- test("round") {
+ test("round/bround") {
val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a")
checkAnswer(
df.select(round('a), round('a, -1), round('a, -2)),
Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600))
)
+ checkAnswer(
+ df.select(bround('a), bround('a, -1), bround('a, -2)),
+ Seq(Row(5, 0, 0), Row(55, 60, 100), Row(555, 560, 600))
+ )
val pi = "3.1415"
checkAnswer(
@@ -221,6 +225,12 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
)
+ checkAnswer(
+ sql(s"SELECT bround($pi, -3), bround($pi, -2), bround($pi, -1), " +
+ s"bround($pi, 0), bround($pi, 1), bround($pi, 2), bround($pi, 3)"),
+ Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
+ BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
+ )
}
test("exp") {