diff options
author | Yijie Shen <henry.yijieshen@gmail.com> | 2015-07-14 23:30:41 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-07-14 23:30:41 -0700 |
commit | f0e129740dc2442a21dfa7fbd97360df87291095 (patch) | |
tree | 4090cc733796e42175668a0b639964b39fa5a69a /sql/core | |
parent | 3f6296fed4ee10f53e728eb1e02f13338839b94d (diff) | |
download | spark-f0e129740dc2442a21dfa7fbd97360df87291095.tar.gz spark-f0e129740dc2442a21dfa7fbd97360df87291095.tar.bz2 spark-f0e129740dc2442a21dfa7fbd97360df87291095.zip |
[SPARK-8279][SQL]Add math function round
JIRA: https://issues.apache.org/jira/browse/SPARK-8279
Author: Yijie Shen <henry.yijieshen@gmail.com>
Closes #6938 from yijieshen/udf_round_3 and squashes the following commits:
07a124c [Yijie Shen] remove useless def children
392b65b [Yijie Shen] add negative scale test in DecimalSuite
61760ee [Yijie Shen] address reviews
302a78a [Yijie Shen] Add dataframe function test
31dfe7c [Yijie Shen] refactor round to make it readable
8c7a949 [Yijie Shen] rebase & inputTypes update
9555e35 [Yijie Shen] tiny style fix
d10be4a [Yijie Shen] use TypeCollection to specify wanted input and implicit cast
c3b9839 [Yijie Shen] rely on implict cast to handle string input
b0bff79 [Yijie Shen] make round's inner method's name more meaningful
9bd6930 [Yijie Shen] revert accidental change
e6f44c4 [Yijie Shen] refactor eval and genCode
1b87540 [Yijie Shen] modify checkInputDataTypes using foldable
5486b2d [Yijie Shen] DataFrame API modification
2077888 [Yijie Shen] codegen versioned eval
6cd9a64 [Yijie Shen] refactor Round's constructor
9be894e [Yijie Shen] add round functions in o.a.s.sql.functions
7c83e13 [Yijie Shen] more tests on round
56db4bb [Yijie Shen] Add decimal support to Round
7e163ae [Yijie Shen] style fix
653d047 [Yijie Shen] Add math function round
Diffstat (limited to 'sql/core')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 32 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala | 15 |
2 files changed, 47 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d4e160ed8..5119ee31d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1390,6 +1390,38 @@ object functions { def rint(columnName: String): Column = rint(Column(columnName)) /** + * Returns the value of the column `e` rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column): Column = round(e.expr, 0) + + /** + * Returns the value of the given column rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(columnName: String): Column = round(Column(columnName), 0) + + /** + * Returns the value of `e` rounded to `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) + + /** + * Returns the value of the given column rounded to `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(columnName: String, scale: Int): Column = round(Column(columnName), scale) + + /** * Shift the the given value numBits left. If the given value is a long value, this function * will return a long value else it will return an integer value. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index b30b9f1225..087126bb2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -198,6 +198,21 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(rint, math.rint) } + test("round") { + val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") + checkAnswer( + df.select(round('a), round('a, -1), round('a, -2)), + Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) + ) + + val pi = 3.1415 + checkAnswer( + ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), + Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142)) + ) + } + test("exp") { testOneToOneMathFunction(exp, math.exp) } |