author     felixcheung <felixcheung_m@hotmail.com>   2015-11-10 22:45:17 -0800
committer  Davies Liu <davies.liu@gmail.com>         2015-11-10 22:45:17 -0800
commit     1a8e0468a1c07e99ad395eb0e4dc072c5cf7393a (patch)
tree       7e34349fff4c37e201aae4f56fed24617b75da04 /R
parent     fac53d8ec015e27d034dfe30ed8ce7d83f07efa6 (diff)
[SPARK-11468] [SPARKR] add stddev/variance agg functions for Column
Checked names, none of them should conflict with anything in base.

shivaram davies rxin

Author: felixcheung <felixcheung_m@hotmail.com>

Closes #9489 from felixcheung/rstddev.
Diffstat (limited to 'R')
-rw-r--r--  R/pkg/NAMESPACE                     10
-rw-r--r--  R/pkg/R/functions.R                186
-rw-r--r--  R/pkg/R/generics.R                  40
-rw-r--r--  R/pkg/R/group.R                      8
-rw-r--r--  R/pkg/inst/tests/test_sparkSQL.R    83
5 files changed, 297 insertions, 30 deletions
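
For orientation before the patch body, here is a minimal SparkR usage sketch of the aggregate functions this commit adds. It is not part of the patch; it simply mirrors the roxygen examples and new tests shown below, and assumes an initialized sqlContext and a DataFrame `df` with a numeric `age` column and a character `name` column.

library(SparkR)

# Column aggregates added by this patch, used through select() / agg():
collect(select(df, stddev(df$age)))      # sample standard deviation
collect(select(df, stddev_pop(df$age)))  # population standard deviation
collect(agg(df, sd(df$age)))             # sd() is the R-style alias for stddev_samp()
collect(agg(df, var(df$age)))            # var() is the R-style alias for var_samp()
collect(select(df, var_pop(df$age)))     # population variance
collect(select(df, skewness(df$age), kurtosis(df$age)))

# The same expressions work on grouped data via agg():
gd <- groupBy(df, "name")
collect(agg(gd, age_sd = stddev(df$age)))

The string form used by the tests, e.g. agg(gd, age = "stddev"), works as well.
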
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 52fd6c9f76..2ee7d6f94f 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -155,6 +155,7 @@ exportMethods("%in%",
"isNaN",
"isNotNull",
"isNull",
+ "kurtosis",
"lag",
"last",
"last_day",
@@ -207,12 +208,17 @@ exportMethods("%in%",
"shiftLeft",
"shiftRight",
"shiftRightUnsigned",
+ "sd",
"sign",
"signum",
"sin",
"sinh",
"size",
+ "skewness",
"soundex",
+ "stddev",
+ "stddev_pop",
+ "stddev_samp",
"sqrt",
"startsWith",
"substr",
@@ -231,6 +237,10 @@ exportMethods("%in%",
"unhex",
"unix_timestamp",
"upper",
+ "var",
+ "variance",
+ "var_pop",
+ "var_samp",
"weekofyear",
"when",
"year")
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 0b28087029..3d0255a62f 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -520,6 +520,22 @@ setMethod("isNaN",
column(jc)
})
+#' kurtosis
+#'
+#' Aggregate function: returns the kurtosis of the values in a group.
+#'
+#' @rdname kurtosis
+#' @name kurtosis
+#' @family agg_funcs
+#' @export
+#' @examples \dontrun{kurtosis(df$c)}
+setMethod("kurtosis",
+ signature(x = "Column"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "kurtosis", x@jc)
+ column(jc)
+ })
+
#' last
#'
#' Aggregate function: returns the last value in a group.
@@ -861,6 +877,28 @@ setMethod("rtrim",
column(jc)
})
+#' sd
+#'
+#' Aggregate function: alias for \link{stddev_samp}
+#'
+#' @rdname sd
+#' @name sd
+#' @family agg_funcs
+#' @seealso \link{stddev_pop}, \link{stddev_samp}
+#' @export
+#' @examples
+#'\dontrun{
+#'stddev(df$c)
+#'select(df, stddev(df$age))
+#'agg(df, sd(df$age))
+#'}
+setMethod("sd",
+ signature(x = "Column"),
+ function(x, na.rm = FALSE) {
+ # In R, sample standard deviation is calculated with the sd() function.
+ stddev_samp(x)
+ })
+
#' second
#'
#' Extracts the seconds as an integer from a given date/timestamp/string.
@@ -958,6 +996,22 @@ setMethod("size",
column(jc)
})
+#' skewness
+#'
+#' Aggregate function: returns the skewness of the values in a group.
+#'
+#' @rdname skewness
+#' @name skewness
+#' @family agg_funcs
+#' @export
+#' @examples \dontrun{skewness(df$c)}
+setMethod("skewness",
+ signature(x = "Column"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "skewness", x@jc)
+ column(jc)
+ })
+
#' soundex
#'
#' Return the soundex code for the specified expression.
@@ -974,6 +1028,49 @@ setMethod("soundex",
column(jc)
})
+#' @rdname sd
+#' @name stddev
+setMethod("stddev",
+ signature(x = "Column"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "stddev", x@jc)
+ column(jc)
+ })
+
+#' stddev_pop
+#'
+#' Aggregate function: returns the population standard deviation of the expression in a group.
+#'
+#' @rdname stddev_pop
+#' @name stddev_pop
+#' @family agg_funcs
+#' @seealso \link{sd}, \link{stddev_samp}
+#' @export
+#' @examples \dontrun{stddev_pop(df$c)}
+setMethod("stddev_pop",
+ signature(x = "Column"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "stddev_pop", x@jc)
+ column(jc)
+ })
+
+#' stddev_samp
+#'
+#' Aggregate function: returns the unbiased sample standard deviation of the expression in a group.
+#'
+#' @rdname stddev_samp
+#' @name stddev_samp
+#' @family agg_funcs
+#' @seealso \link{stddev_pop}, \link{sd}
+#' @export
+#' @examples \dontrun{stddev_samp(df$c)}
+setMethod("stddev_samp",
+ signature(x = "Column"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "stddev_samp", x@jc)
+ column(jc)
+ })
+
#' sqrt
#'
#' Computes the square root of the specified float value.
@@ -1168,6 +1265,71 @@ setMethod("upper",
column(jc)
})
+#' var
+#'
+#' Aggregate function: alias for \link{var_samp}.
+#'
+#' @rdname var
+#' @name var
+#' @family agg_funcs
+#' @seealso \link{var_pop}, \link{var_samp}
+#' @export
+#' @examples
+#'\dontrun{
+#'variance(df$c)
+#'select(df, var_pop(df$age))
+#'agg(df, var(df$age))
+#'}
+setMethod("var",
+ signature(x = "Column"),
+ function(x, y = NULL, na.rm = FALSE, use) {
+ # In R, sample variance is calculated with the var() function.
+ var_samp(x)
+ })
+
+#' @rdname var
+#' @name variance
+setMethod("variance",
+ signature(x = "Column"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "variance", x@jc)
+ column(jc)
+ })
+
+#' var_pop
+#'
+#' Aggregate function: returns the population variance of the values in a group.
+#'
+#' @rdname var_pop
+#' @name var_pop
+#' @family agg_funcs
+#' @seealso \link{var}, \link{var_samp}
+#' @export
+#' @examples \dontrun{var_pop(df$c)}
+setMethod("var_pop",
+ signature(x = "Column"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "var_pop", x@jc)
+ column(jc)
+ })
+
+#' var_samp
+#'
+#' Aggregate function: returns the unbiased variance of the values in a group.
+#'
+#' @rdname var_samp
+#' @name var_samp
+#' @family agg_funcs
+#' @seealso \link{var_pop}, \link{var}
+#' @export
+#' @examples \dontrun{var_samp(df$c)}
+setMethod("var_samp",
+ signature(x = "Column"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "var_samp", x@jc)
+ column(jc)
+ })
+
#' weekofyear
#'
#' Extracts the week number as an integer from a given date/timestamp/string.
@@ -2020,10 +2182,10 @@ setMethod("ifelse",
#'
#' Window function: returns the cumulative distribution of values within a window partition,
#' i.e. the fraction of rows that are below the current row.
-#'
+#'
#' N = total number of rows in the partition
#' cumeDist(x) = number of values before (and including) x / N
-#'
+#'
#' This is equivalent to the CUME_DIST function in SQL.
#'
#' @rdname cumeDist
@@ -2039,13 +2201,13 @@ setMethod("cumeDist",
})
#' denseRank
-#'
+#'
#' Window function: returns the rank of rows within a window partition, without any gaps.
#' The difference between rank and denseRank is that denseRank leaves no gaps in ranking
#' sequence when there are ties. That is, if you were ranking a competition using denseRank
#' and had three people tie for second place, you would say that all three were in second
#' place and that the next person came in third.
-#'
+#'
#' This is equivalent to the DENSE_RANK function in SQL.
#'
#' @rdname denseRank
@@ -2065,7 +2227,7 @@ setMethod("denseRank",
#' Window function: returns the value that is `offset` rows before the current row, and
#' `defaultValue` if there is less than `offset` rows before the current row. For example,
#' an `offset` of one will return the previous row at any given point in the window partition.
-#'
+#'
#' This is equivalent to the LAG function in SQL.
#'
#' @rdname lag
@@ -2092,7 +2254,7 @@ setMethod("lag",
#' Window function: returns the value that is `offset` rows after the current row, and
#' `null` if there is less than `offset` rows after the current row. For example,
#' an `offset` of one will return the next row at any given point in the window partition.
-#'
+#'
#' This is equivalent to the LEAD function in SQL.
#'
#' @rdname lead
@@ -2119,7 +2281,7 @@ setMethod("lead",
#' Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window
#' partition. Fow example, if `n` is 4, the first quarter of the rows will get value 1, the second
#' quarter will get 2, the third quarter will get 3, and the last quarter will get 4.
-#'
+#'
#' This is equivalent to the NTILE function in SQL.
#'
#' @rdname ntile
@@ -2137,9 +2299,9 @@ setMethod("ntile",
#' percentRank
#'
#' Window function: returns the relative rank (i.e. percentile) of rows within a window partition.
-#'
+#'
#' This is computed by:
-#'
+#'
#' (rank of row in its partition - 1) / (number of rows in the partition - 1)
#'
#' This is equivalent to the PERCENT_RANK function in SQL.
@@ -2159,12 +2321,12 @@ setMethod("percentRank",
#' rank
#'
#' Window function: returns the rank of rows within a window partition.
-#'
+#'
#' The difference between rank and denseRank is that denseRank leaves no gaps in ranking
#' sequence when there are ties. That is, if you were ranking a competition using denseRank
#' and had three people tie for second place, you would say that all three were in second
#' place and that the next person came in third.
-#'
+#'
#' This is equivalent to the RANK function in SQL.
#'
#' @rdname rank
@@ -2189,7 +2351,7 @@ setMethod("rank",
#' rowNumber
#'
#' Window function: returns a sequential number starting at 1 within a window partition.
-#'
+#'
#' This is equivalent to the ROW_NUMBER function in SQL.
#'
#' @rdname rowNumber
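
To make the alias relationships in functions.R above concrete, a small editorial sketch (not part of the patch; same assumed `df`). Note that the Column methods for sd() and var() accept the base-R arguments na.rm, y, and use only for signature compatibility and ignore them, delegating straight to stddev_samp() / var_samp():

# These pairs should return the same values for a numeric column:
collect(select(df, sd(df$age), stddev_samp(df$age)))   # sd() delegates to stddev_samp()
collect(select(df, var(df$age), var_samp(df$age)))     # var() delegates to var_samp()

# stddev() / variance() call the Spark SQL functions of the same name,
# while the *_pop variants use the population formulas:
collect(select(df, stddev(df$age), stddev_pop(df$age), var_pop(df$age)))
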
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 89731affeb..92ad4ee868 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -798,6 +798,10 @@ setGeneric("instr", function(y, x) { standardGeneric("instr") })
#' @export
setGeneric("isNaN", function(x) { standardGeneric("isNaN") })
+#' @rdname kurtosis
+#' @export
+setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") })
+
#' @rdname lag
#' @export
setGeneric("lag", function(x, offset, defaultValue = NULL) { standardGeneric("lag") })
@@ -935,6 +939,10 @@ setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") })
#' @export
setGeneric("rtrim", function(x) { standardGeneric("rtrim") })
+#' @rdname sd
+#' @export
+setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") })
+
#' @rdname second
#' @export
setGeneric("second", function(x) { standardGeneric("second") })
@@ -967,10 +975,26 @@ setGeneric("signum", function(x) { standardGeneric("signum") })
#' @export
setGeneric("size", function(x) { standardGeneric("size") })
+#' @rdname skewness
+#' @export
+setGeneric("skewness", function(x) { standardGeneric("skewness") })
+
#' @rdname soundex
#' @export
setGeneric("soundex", function(x) { standardGeneric("soundex") })
+#' @rdname sd
+#' @export
+setGeneric("stddev", function(x) { standardGeneric("stddev") })
+
+#' @rdname stddev_pop
+#' @export
+setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") })
+
+#' @rdname stddev_samp
+#' @export
+setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") })
+
#' @rdname substring_index
#' @export
setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") })
@@ -1019,6 +1043,22 @@ setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timesta
#' @export
setGeneric("upper", function(x) { standardGeneric("upper") })
+#' @rdname var
+#' @export
+setGeneric("var", function(x, y = NULL, na.rm = FALSE, use) { standardGeneric("var") })
+
+#' @rdname var
+#' @export
+setGeneric("variance", function(x) { standardGeneric("variance") })
+
+#' @rdname var_pop
+#' @export
+setGeneric("var_pop", function(x) { standardGeneric("var_pop") })
+
+#' @rdname var_samp
+#' @export
+setGeneric("var_samp", function(x) { standardGeneric("var_samp") })
+
#' @rdname weekofyear
#' @export
setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") })
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 4cab1a69f6..e5f702faee 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -79,6 +79,7 @@ setMethod("count",
#' @param x a GroupedData
#' @return a DataFrame
#' @rdname agg
+#' @family agg_funcs
#' @examples
#' \dontrun{
#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)'
@@ -117,8 +118,11 @@ setMethod("summarize",
agg(x, ...)
})
-# sum/mean/avg/min/max
-methods <- c("sum", "mean", "avg", "min", "max")
+# Aggregate Functions by name
+methods <- c("avg", "max", "mean", "min", "sum")
+
+# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", "stddev_pop",
+# "variance", "var_samp", "var_pop"
createMethod <- function(name) {
setMethod(name,
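
As the updated comment in group.R notes, the new statistical aggregates are not generated as shortcut methods on GroupedData (only avg, max, mean, min, and sum are). A hedged sketch of the two ways the tests below reach them on grouped data, again assuming `df` with name/age columns:

gd <- groupBy(df, "name")

# Shortcut methods exist only for the basic aggregates generated by createMethod():
collect(max(gd))             # works: "max" is in the generated list
collect(sum(gd, "age"))      # works: "sum" is in the generated list

# The new functions go through agg(), either by column name or as Column expressions:
collect(agg(gd, age = "stddev"))   # string form, resolved by Spark SQL
collect(agg(gd, var(df$age)))      # Column-expression form
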
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 06f52d021c..9e453a1e7c 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -826,12 +826,13 @@ test_that("column functions", {
c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c)
c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c)
c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c)
- c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c)
+ c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c)
c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c)
c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c)
- c12 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1)
- c13 <- cumeDist() + ntile(1)
- c14 <- denseRank() + percentRank() + rank() + rowNumber()
+ c12 <- variance(c)
+ c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1)
+ c14 <- cumeDist() + ntile(1)
+ c15 <- denseRank() + percentRank() + rank() + rowNumber()
# Test if base::rank() is exposed
expect_equal(class(rank())[[1]], "Column")
@@ -849,6 +850,12 @@ test_that("column functions", {
expect_equal(collect(df3)[[2, 1]], FALSE)
expect_equal(collect(df3)[[3, 1]], TRUE)
+ expect_equal(collect(select(df, sum(df$age)))[1, 1], 49)
+
+ expect_true(abs(collect(select(df, stddev(df$age)))[1, 1] - 7.778175) < 1e-6)
+
+ expect_equal(collect(select(df, var_pop(df$age)))[1, 1], 30.25)
+
df4 <- createDataFrame(sqlContext, list(list(a = "010101")))
expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15")
})
@@ -976,7 +983,7 @@ test_that("when(), otherwise() and ifelse() on a DataFrame", {
expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, 0, 1)))[, 1], c(1, 0))
})
-test_that("group by", {
+test_that("group by, agg functions", {
df <- jsonFile(sqlContext, jsonPath)
df1 <- agg(df, name = "max", age = "sum")
expect_equal(1, count(df1))
@@ -997,20 +1004,64 @@ test_that("group by", {
expect_is(df_summarized, "DataFrame")
expect_equal(3, count(df_summarized))
- df3 <- agg(gd, age = "sum")
- expect_is(df3, "DataFrame")
- expect_equal(3, count(df3))
-
- df3 <- agg(gd, age = sum(df$age))
+ df3 <- agg(gd, age = "stddev")
expect_is(df3, "DataFrame")
- expect_equal(3, count(df3))
- expect_equal(columns(df3), c("name", "age"))
+ df3_local <- collect(df3)
+ expect_equal(0, df3_local[df3_local$name == "Andy",][1, 2])
- df4 <- sum(gd, "age")
+ df4 <- agg(gd, sumAge = sum(df$age))
expect_is(df4, "DataFrame")
expect_equal(3, count(df4))
- expect_equal(3, count(mean(gd, "age")))
- expect_equal(3, count(max(gd, "age")))
+ expect_equal(columns(df4), c("name", "sumAge"))
+
+ df5 <- sum(gd, "age")
+ expect_is(df5, "DataFrame")
+ expect_equal(3, count(df5))
+
+ expect_equal(3, count(mean(gd)))
+ expect_equal(3, count(max(gd)))
+ expect_equal(30, collect(max(gd))[1, 2])
+ expect_equal(1, collect(count(gd))[1, 2])
+
+ mockLines2 <- c("{\"name\":\"ID1\", \"value\": \"10\"}",
+ "{\"name\":\"ID1\", \"value\": \"10\"}",
+ "{\"name\":\"ID1\", \"value\": \"22\"}",
+ "{\"name\":\"ID2\", \"value\": \"-3\"}")
+ jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp")
+ writeLines(mockLines2, jsonPath2)
+ gd2 <- groupBy(jsonFile(sqlContext, jsonPath2), "name")
+ df6 <- agg(gd2, value = "sum")
+ df6_local <- collect(df6)
+ expect_equal(42, df6_local[df6_local$name == "ID1",][1, 2])
+ expect_equal(-3, df6_local[df6_local$name == "ID2",][1, 2])
+
+ df7 <- agg(gd2, value = "stddev")
+ df7_local <- collect(df7)
+ expect_true(abs(df7_local[df7_local$name == "ID1",][1, 2] - 6.928203) < 1e-6)
+ expect_equal(0, df7_local[df7_local$name == "ID2",][1, 2])
+
+ mockLines3 <- c("{\"name\":\"Andy\", \"age\":30}",
+ "{\"name\":\"Andy\", \"age\":30}",
+ "{\"name\":\"Justin\", \"age\":19}",
+ "{\"name\":\"Justin\", \"age\":1}")
+ jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp")
+ writeLines(mockLines3, jsonPath3)
+ df8 <- jsonFile(sqlContext, jsonPath3)
+ gd3 <- groupBy(df8, "name")
+ gd3_local <- collect(sum(gd3))
+ expect_equal(60, gd3_local[gd3_local$name == "Andy",][1, 2])
+ expect_equal(20, gd3_local[gd3_local$name == "Justin",][1, 2])
+
+ expect_true(abs(collect(agg(df, sd(df$age)))[1, 1] - 7.778175) < 1e-6)
+ gd3_local <- collect(agg(gd3, var(df8$age)))
+ expect_equal(162, gd3_local[gd3_local$name == "Justin",][1, 2])
+
+ # make sure base:: or stats::sd, var are working
+ expect_true(abs(sd(1:2) - 0.7071068) < 1e-6)
+ expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6)
+
+ unlink(jsonPath2)
+ unlink(jsonPath3)
})
test_that("arrange() and orderBy() on a DataFrame", {
@@ -1238,7 +1289,7 @@ test_that("mutate(), transform(), rename() and names()", {
expect_equal(columns(transformedDF)[4], "newAge2")
expect_equal(first(filter(transformedDF, transformedDF$name == "Andy"))$newAge, -30)
- # test if transform on local data frames works
+ # test if base::transform on local data frames works
# ensure the proper signature is used - otherwise this will fail to run
attach(airquality)
result <- transform(Ozone, logOzone = log(Ozone))