aboutsummaryrefslogtreecommitdiff
path: root/R/pkg
diff options
context:
space:
mode:
authorDongjoon Hyun <dongjoon@apache.org>2016-04-19 22:28:11 -0700
committerDavies Liu <davies.liu@gmail.com>2016-04-19 22:28:11 -0700
commit14869ae64eb27830179d4954a5dc3e0a1e1330b4 (patch)
treec294dde39b5d77c6086b3d08a726b1c8b401b95a /R/pkg
parent6f1ec1f2670cd55bc852a810ca9d5c6a2651a9f2 (diff)
downloadspark-14869ae64eb27830179d4954a5dc3e0a1e1330b4.tar.gz
spark-14869ae64eb27830179d4954a5dc3e0a1e1330b4.tar.bz2
spark-14869ae64eb27830179d4954a5dc3e0a1e1330b4.zip
[SPARK-14639] [PYTHON] [R] Add `bround` function in Python/R.
## What changes were proposed in this pull request? This issue aims to expose Scala `bround` function in Python/R API. `bround` function is implemented in SPARK-14614 by extending current `round` function. We used the following semantics from Hive. ```java public static double bround(double input, int scale) { if (Double.isNaN(input) || Double.isInfinite(input)) { return input; } return BigDecimal.valueOf(input).setScale(scale, RoundingMode.HALF_EVEN).doubleValue(); } ``` After this PR, `pyspark` and `sparkR` also support `bround` function. **PySpark** ```python >>> from pyspark.sql.functions import bround >>> sqlContext.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect() [Row(r=2.0)] ``` **SparkR** ```r > df = createDataFrame(sqlContext, data.frame(x = c(2.5, 3.5))) > head(collect(select(df, bround(df$x, 0)))) bround(x, 0) 1 2 2 4 ``` ## How was this patch tested? Pass the Jenkins tests (including new testcases). Author: Dongjoon Hyun <dongjoon@apache.org> Closes #12509 from dongjoon-hyun/SPARK-14639.
Diffstat (limited to 'R/pkg')
-rw-r--r--R/pkg/NAMESPACE1
-rw-r--r--R/pkg/R/functions.R22
-rw-r--r--R/pkg/R/generics.R4
-rw-r--r--R/pkg/inst/tests/testthat/test_sparkSQL.R5
4 files changed, 31 insertions, 1 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 10b9d16279..667fff7192 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -126,6 +126,7 @@ exportMethods("%in%",
"between",
"bin",
"bitwiseNOT",
+ "bround",
"cast",
"cbrt",
"ceil",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index db877b2d63..54234b0455 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -994,7 +994,7 @@ setMethod("rint",
#' round
#'
-#' Returns the value of the column `e` rounded to 0 decimal places.
+#' Returns the value of the column `e` rounded to 0 decimal places using HALF_UP rounding mode.
#'
#' @rdname round
#' @name round
@@ -1008,6 +1008,26 @@ setMethod("round",
column(jc)
})
+#' bround
+#'
+#' Returns the value of the column `e` rounded to `scale` decimal places using HALF_EVEN rounding
+#' mode if `scale` >= 0 or at integral part when `scale` < 0.
+#' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number.
+#' bround(2.5, 0) = 2, bround(3.5, 0) = 4.
+#'
+#' @rdname bround
+#' @name bround
+#' @family math_funcs
+#' @export
+#' @examples \dontrun{bround(df$c, 0)}
+setMethod("bround",
+ signature(x = "Column"),
+ function(x, scale = 0) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "bround", x@jc, as.integer(scale))
+ column(jc)
+ })
+
+
#' rtrim
#'
#' Trim the spaces from right end for the specified string value.
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index a71be55bca..6b67258d77 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -760,6 +760,10 @@ setGeneric("bin", function(x) { standardGeneric("bin") })
#' @export
setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") })
+#' @rdname bround
+#' @export
+setGeneric("bround", function(x, ...) { standardGeneric("bround") })
+
#' @rdname cbrt
#' @export
setGeneric("cbrt", function(x) { standardGeneric("cbrt") })
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 2f65484fcb..b923ccf6bb 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1087,6 +1087,11 @@ test_that("column functions", {
expect_equal(collect(select(df, last(df$age, TRUE)))[[1]], 19)
expect_equal(collect(select(df, last("age")))[[1]], 19)
expect_equal(collect(select(df, last("age", TRUE)))[[1]], 19)
+
+ # Test bround()
+ df <- createDataFrame(sqlContext, data.frame(x = c(2.5, 3.5)))
+ expect_equal(collect(select(df, bround(df$x, 0)))[[1]][1], 2)
+ expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4)
})
test_that("column binary mathfunctions", {