aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorYijie Shen <henry.yijieshen@gmail.com>2015-07-14 23:30:41 -0700
committerReynold Xin <rxin@databricks.com>2015-07-14 23:30:41 -0700
commitf0e129740dc2442a21dfa7fbd97360df87291095 (patch)
tree4090cc733796e42175668a0b639964b39fa5a69a /sql/core
parent3f6296fed4ee10f53e728eb1e02f13338839b94d (diff)
downloadspark-f0e129740dc2442a21dfa7fbd97360df87291095.tar.gz
spark-f0e129740dc2442a21dfa7fbd97360df87291095.tar.bz2
spark-f0e129740dc2442a21dfa7fbd97360df87291095.zip
[SPARK-8279][SQL]Add math function round
JIRA: https://issues.apache.org/jira/browse/SPARK-8279 Author: Yijie Shen <henry.yijieshen@gmail.com> Closes #6938 from yijieshen/udf_round_3 and squashes the following commits: 07a124c [Yijie Shen] remove useless def children 392b65b [Yijie Shen] add negative scale test in DecimalSuite 61760ee [Yijie Shen] address reviews 302a78a [Yijie Shen] Add dataframe function test 31dfe7c [Yijie Shen] refactor round to make it readable 8c7a949 [Yijie Shen] rebase & inputTypes update 9555e35 [Yijie Shen] tiny style fix d10be4a [Yijie Shen] use TypeCollection to specify wanted input and implicit cast c3b9839 [Yijie Shen] rely on implict cast to handle string input b0bff79 [Yijie Shen] make round's inner method's name more meaningful 9bd6930 [Yijie Shen] revert accidental change e6f44c4 [Yijie Shen] refactor eval and genCode 1b87540 [Yijie Shen] modify checkInputDataTypes using foldable 5486b2d [Yijie Shen] DataFrame API modification 2077888 [Yijie Shen] codegen versioned eval 6cd9a64 [Yijie Shen] refactor Round's constructor 9be894e [Yijie Shen] add round functions in o.a.s.sql.functions 7c83e13 [Yijie Shen] more tests on round 56db4bb [Yijie Shen] Add decimal support to Round 7e163ae [Yijie Shen] style fix 653d047 [Yijie Shen] Add math function round
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala32
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala15
2 files changed, 47 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 0d4e160ed8..5119ee31d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1390,6 +1390,38 @@ object functions {
def rint(columnName: String): Column = rint(Column(columnName))
/**
+ * Returns the value of the column `e` rounded to 0 decimal places.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def round(e: Column): Column = round(e.expr, 0)
+
+ /**
+ * Returns the value of the given column rounded to 0 decimal places.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def round(columnName: String): Column = round(Column(columnName), 0)
+
+ /**
+ * Returns the value of `e` rounded to `scale` decimal places.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale))
+
+ /**
+ * Returns the value of the given column rounded to `scale` decimal places.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def round(columnName: String, scale: Int): Column = round(Column(columnName), scale)
+
+ /**
* Shift the the given value numBits left. If the given value is a long value, this function
* will return a long value else it will return an integer value.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index b30b9f1225..087126bb2e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -198,6 +198,21 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneMathFunction(rint, math.rint)
}
+ test("round") {
+ val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a")
+ checkAnswer(
+ df.select(round('a), round('a, -1), round('a, -2)),
+ Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600))
+ )
+
+ val pi = 3.1415
+ checkAnswer(
+ ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
+ s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
+ Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142))
+ )
+ }
+
test("exp") {
testOneToOneMathFunction(exp, math.exp)
}