From 1a8e0468a1c07e99ad395eb0e4dc072c5cf7393a Mon Sep 17 00:00:00 2001 From: felixcheung Date: Tue, 10 Nov 2015 22:45:17 -0800 Subject: [SPARK-11468] [SPARKR] add stddev/variance agg functions for Column Checked names, none of them should conflict with anything in base shivaram davies rxin Author: felixcheung Closes #9489 from felixcheung/rstddev. --- R/pkg/NAMESPACE | 10 +++ R/pkg/R/functions.R | 186 ++++++++++++++++++++++++++++++++++++--- R/pkg/R/generics.R | 40 +++++++++ R/pkg/R/group.R | 8 +- R/pkg/inst/tests/test_sparkSQL.R | 83 +++++++++++++---- 5 files changed, 297 insertions(+), 30 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 52fd6c9f76..2ee7d6f94f 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -155,6 +155,7 @@ exportMethods("%in%", "isNaN", "isNotNull", "isNull", + "kurtosis", "lag", "last", "last_day", @@ -207,12 +208,17 @@ exportMethods("%in%", "shiftLeft", "shiftRight", "shiftRightUnsigned", + "sd", "sign", "signum", "sin", "sinh", "size", + "skewness", "soundex", + "stddev", + "stddev_pop", + "stddev_samp", "sqrt", "startsWith", "substr", @@ -231,6 +237,10 @@ exportMethods("%in%", "unhex", "unix_timestamp", "upper", + "var", + "variance", + "var_pop", + "var_samp", "weekofyear", "when", "year") diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 0b28087029..3d0255a62f 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -520,6 +520,22 @@ setMethod("isNaN", column(jc) }) +#' kurtosis +#' +#' Aggregate function: returns the kurtosis of the values in a group. +#' +#' @rdname kurtosis +#' @name kurtosis +#' @family agg_funcs +#' @export +#' @examples \dontrun{kurtosis(df$c)} +setMethod("kurtosis", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "kurtosis", x@jc) + column(jc) + }) + #' last #' #' Aggregate function: returns the last value in a group. @@ -861,6 +877,28 @@ setMethod("rtrim", column(jc) }) +#' sd +#' +#' Aggregate function: alias for \link{stddev_samp} +#' +#' @rdname sd +#' @name sd +#' @family agg_funcs +#' @seealso \link{stddev_pop}, \link{stddev_samp} +#' @export +#' @examples +#'\dontrun{ +#'stddev(df$c) +#'select(df, stddev(df$age)) +#'agg(df, sd(df$age)) +#'} +setMethod("sd", + signature(x = "Column"), + function(x, na.rm = FALSE) { + # In R, sample standard deviation is calculated with the sd() function. + stddev_samp(x) + }) + #' second #' #' Extracts the seconds as an integer from a given date/timestamp/string. @@ -958,6 +996,22 @@ setMethod("size", column(jc) }) +#' skewness +#' +#' Aggregate function: returns the skewness of the values in a group. +#' +#' @rdname skewness +#' @name skewness +#' @family agg_funcs +#' @export +#' @examples \dontrun{skewness(df$c)} +setMethod("skewness", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "skewness", x@jc) + column(jc) + }) + #' soundex #' #' Return the soundex code for the specified expression. @@ -974,6 +1028,49 @@ setMethod("soundex", column(jc) }) +#' @rdname sd +#' @name stddev +setMethod("stddev", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "stddev", x@jc) + column(jc) + }) + +#' stddev_pop +#' +#' Aggregate function: returns the population standard deviation of the expression in a group. +#' +#' @rdname stddev_pop +#' @name stddev_pop +#' @family agg_funcs +#' @seealso \link{sd}, \link{stddev_samp} +#' @export +#' @examples \dontrun{stddev_pop(df$c)} +setMethod("stddev_pop", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "stddev_pop", x@jc) + column(jc) + }) + +#' stddev_samp +#' +#' Aggregate function: returns the unbiased sample standard deviation of the expression in a group. +#' +#' @rdname stddev_samp +#' @name stddev_samp +#' @family agg_funcs +#' @seealso \link{stddev_pop}, \link{sd} +#' @export +#' @examples \dontrun{stddev_samp(df$c)} +setMethod("stddev_samp", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "stddev_samp", x@jc) + column(jc) + }) + #' sqrt #' #' Computes the square root of the specified float value. @@ -1168,6 +1265,71 @@ setMethod("upper", column(jc) }) +#' var +#' +#' Aggregate function: alias for \link{var_samp}. +#' +#' @rdname var +#' @name var +#' @family agg_funcs +#' @seealso \link{var_pop}, \link{var_samp} +#' @export +#' @examples +#'\dontrun{ +#'variance(df$c) +#'select(df, var_pop(df$age)) +#'agg(df, var(df$age)) +#'} +setMethod("var", + signature(x = "Column"), + function(x, y = NULL, na.rm = FALSE, use) { + # In R, sample variance is calculated with the var() function. + var_samp(x) + }) + +#' @rdname var +#' @name variance +setMethod("variance", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "variance", x@jc) + column(jc) + }) + +#' var_pop +#' +#' Aggregate function: returns the population variance of the values in a group. +#' +#' @rdname var_pop +#' @name var_pop +#' @family agg_funcs +#' @seealso \link{var}, \link{var_samp} +#' @export +#' @examples \dontrun{var_pop(df$c)} +setMethod("var_pop", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "var_pop", x@jc) + column(jc) + }) + +#' var_samp +#' +#' Aggregate function: returns the unbiased variance of the values in a group. +#' +#' @rdname var_samp +#' @name var_samp +#' @family agg_funcs +#' @seealso \link{var_pop}, \link{var} +#' @export +#' @examples \dontrun{var_samp(df$c)} +setMethod("var_samp", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "var_samp", x@jc) + column(jc) + }) + #' weekofyear #' #' Extracts the week number as an integer from a given date/timestamp/string. @@ -2020,10 +2182,10 @@ setMethod("ifelse", #' #' Window function: returns the cumulative distribution of values within a window partition, #' i.e. the fraction of rows that are below the current row. -#' +#' #' N = total number of rows in the partition #' cumeDist(x) = number of values before (and including) x / N -#' +#' #' This is equivalent to the CUME_DIST function in SQL. #' #' @rdname cumeDist @@ -2039,13 +2201,13 @@ setMethod("cumeDist", }) #' denseRank -#' +#' #' Window function: returns the rank of rows within a window partition, without any gaps. #' The difference between rank and denseRank is that denseRank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using denseRank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. -#' +#' #' This is equivalent to the DENSE_RANK function in SQL. #' #' @rdname denseRank @@ -2065,7 +2227,7 @@ setMethod("denseRank", #' Window function: returns the value that is `offset` rows before the current row, and #' `defaultValue` if there is less than `offset` rows before the current row. For example, #' an `offset` of one will return the previous row at any given point in the window partition. -#' +#' #' This is equivalent to the LAG function in SQL. #' #' @rdname lag @@ -2092,7 +2254,7 @@ setMethod("lag", #' Window function: returns the value that is `offset` rows after the current row, and #' `null` if there is less than `offset` rows after the current row. For example, #' an `offset` of one will return the next row at any given point in the window partition. -#' +#' #' This is equivalent to the LEAD function in SQL. #' #' @rdname lead @@ -2119,7 +2281,7 @@ setMethod("lead", #' Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window #' partition. Fow example, if `n` is 4, the first quarter of the rows will get value 1, the second #' quarter will get 2, the third quarter will get 3, and the last quarter will get 4. -#' +#' #' This is equivalent to the NTILE function in SQL. #' #' @rdname ntile @@ -2137,9 +2299,9 @@ setMethod("ntile", #' percentRank #' #' Window function: returns the relative rank (i.e. percentile) of rows within a window partition. -#' +#' #' This is computed by: -#' +#' #' (rank of row in its partition - 1) / (number of rows in the partition - 1) #' #' This is equivalent to the PERCENT_RANK function in SQL. @@ -2159,12 +2321,12 @@ setMethod("percentRank", #' rank #' #' Window function: returns the rank of rows within a window partition. -#' +#' #' The difference between rank and denseRank is that denseRank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using denseRank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. -#' +#' #' This is equivalent to the RANK function in SQL. #' #' @rdname rank @@ -2189,7 +2351,7 @@ setMethod("rank", #' rowNumber #' #' Window function: returns a sequential number starting at 1 within a window partition. -#' +#' #' This is equivalent to the ROW_NUMBER function in SQL. #' #' @rdname rowNumber diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 89731affeb..92ad4ee868 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -798,6 +798,10 @@ setGeneric("instr", function(y, x) { standardGeneric("instr") }) #' @export setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) +#' @rdname kurtosis +#' @export +setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) + #' @rdname lag #' @export setGeneric("lag", function(x, offset, defaultValue = NULL) { standardGeneric("lag") }) @@ -935,6 +939,10 @@ setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) #' @export setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) +#' @rdname sd +#' @export +setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) + #' @rdname second #' @export setGeneric("second", function(x) { standardGeneric("second") }) @@ -967,10 +975,26 @@ setGeneric("signum", function(x) { standardGeneric("signum") }) #' @export setGeneric("size", function(x) { standardGeneric("size") }) +#' @rdname skewness +#' @export +setGeneric("skewness", function(x) { standardGeneric("skewness") }) + #' @rdname soundex #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) +#' @rdname sd +#' @export +setGeneric("stddev", function(x) { standardGeneric("stddev") }) + +#' @rdname stddev_pop +#' @export +setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) + +#' @rdname stddev_samp +#' @export +setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) + #' @rdname substring_index #' @export setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) @@ -1019,6 +1043,22 @@ setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timesta #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) +#' @rdname var +#' @export +setGeneric("var", function(x, y = NULL, na.rm = FALSE, use) { standardGeneric("var") }) + +#' @rdname var +#' @export +setGeneric("variance", function(x) { standardGeneric("variance") }) + +#' @rdname var_pop +#' @export +setGeneric("var_pop", function(x) { standardGeneric("var_pop") }) + +#' @rdname var_samp +#' @export +setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) + #' @rdname weekofyear #' @export setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 4cab1a69f6..e5f702faee 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -79,6 +79,7 @@ setMethod("count", #' @param x a GroupedData #' @return a DataFrame #' @rdname agg +#' @family agg_funcs #' @examples #' \dontrun{ #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' @@ -117,8 +118,11 @@ setMethod("summarize", agg(x, ...) }) -# sum/mean/avg/min/max -methods <- c("sum", "mean", "avg", "min", "max") +# Aggregate Functions by name +methods <- c("avg", "max", "mean", "min", "sum") + +# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", "stddev_pop", +# "variance", "var_samp", "var_pop" createMethod <- function(name) { setMethod(name, diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 06f52d021c..9e453a1e7c 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -826,12 +826,13 @@ test_that("column functions", { c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c) c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) - c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c) + c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c) c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) - c12 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) - c13 <- cumeDist() + ntile(1) - c14 <- denseRank() + percentRank() + rank() + rowNumber() + c12 <- variance(c) + c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) + c14 <- cumeDist() + ntile(1) + c15 <- denseRank() + percentRank() + rank() + rowNumber() # Test if base::rank() is exposed expect_equal(class(rank())[[1]], "Column") @@ -849,6 +850,12 @@ test_that("column functions", { expect_equal(collect(df3)[[2, 1]], FALSE) expect_equal(collect(df3)[[3, 1]], TRUE) + expect_equal(collect(select(df, sum(df$age)))[1, 1], 49) + + expect_true(abs(collect(select(df, stddev(df$age)))[1, 1] - 7.778175) < 1e-6) + + expect_equal(collect(select(df, var_pop(df$age)))[1, 1], 30.25) + df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") }) @@ -976,7 +983,7 @@ test_that("when(), otherwise() and ifelse() on a DataFrame", { expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, 0, 1)))[, 1], c(1, 0)) }) -test_that("group by", { +test_that("group by, agg functions", { df <- jsonFile(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") expect_equal(1, count(df1)) @@ -997,20 +1004,64 @@ test_that("group by", { expect_is(df_summarized, "DataFrame") expect_equal(3, count(df_summarized)) - df3 <- agg(gd, age = "sum") - expect_is(df3, "DataFrame") - expect_equal(3, count(df3)) - - df3 <- agg(gd, age = sum(df$age)) + df3 <- agg(gd, age = "stddev") expect_is(df3, "DataFrame") - expect_equal(3, count(df3)) - expect_equal(columns(df3), c("name", "age")) + df3_local <- collect(df3) + expect_equal(0, df3_local[df3_local$name == "Andy",][1, 2]) - df4 <- sum(gd, "age") + df4 <- agg(gd, sumAge = sum(df$age)) expect_is(df4, "DataFrame") expect_equal(3, count(df4)) - expect_equal(3, count(mean(gd, "age"))) - expect_equal(3, count(max(gd, "age"))) + expect_equal(columns(df4), c("name", "sumAge")) + + df5 <- sum(gd, "age") + expect_is(df5, "DataFrame") + expect_equal(3, count(df5)) + + expect_equal(3, count(mean(gd))) + expect_equal(3, count(max(gd))) + expect_equal(30, collect(max(gd))[1, 2]) + expect_equal(1, collect(count(gd))[1, 2]) + + mockLines2 <- c("{\"name\":\"ID1\", \"value\": \"10\"}", + "{\"name\":\"ID1\", \"value\": \"10\"}", + "{\"name\":\"ID1\", \"value\": \"22\"}", + "{\"name\":\"ID2\", \"value\": \"-3\"}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines2, jsonPath2) + gd2 <- groupBy(jsonFile(sqlContext, jsonPath2), "name") + df6 <- agg(gd2, value = "sum") + df6_local <- collect(df6) + expect_equal(42, df6_local[df6_local$name == "ID1",][1, 2]) + expect_equal(-3, df6_local[df6_local$name == "ID2",][1, 2]) + + df7 <- agg(gd2, value = "stddev") + df7_local <- collect(df7) + expect_true(abs(df7_local[df7_local$name == "ID1",][1, 2] - 6.928203) < 1e-6) + expect_equal(0, df7_local[df7_local$name == "ID2",][1, 2]) + + mockLines3 <- c("{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}", + "{\"name\":\"Justin\", \"age\":1}") + jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines3, jsonPath3) + df8 <- jsonFile(sqlContext, jsonPath3) + gd3 <- groupBy(df8, "name") + gd3_local <- collect(sum(gd3)) + expect_equal(60, gd3_local[gd3_local$name == "Andy",][1, 2]) + expect_equal(20, gd3_local[gd3_local$name == "Justin",][1, 2]) + + expect_true(abs(collect(agg(df, sd(df$age)))[1, 1] - 7.778175) < 1e-6) + gd3_local <- collect(agg(gd3, var(df8$age))) + expect_equal(162, gd3_local[gd3_local$name == "Justin",][1, 2]) + + # make sure base:: or stats::sd, var are working + expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) + expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6) + + unlink(jsonPath2) + unlink(jsonPath3) }) test_that("arrange() and orderBy() on a DataFrame", { @@ -1238,7 +1289,7 @@ test_that("mutate(), transform(), rename() and names()", { expect_equal(columns(transformedDF)[4], "newAge2") expect_equal(first(filter(transformedDF, transformedDF$name == "Andy"))$newAge, -30) - # test if transform on local data frames works + # test if base::transform on local data frames works # ensure the proper signature is used - otherwise this will fail to run attach(airquality) result <- transform(Ozone, logOzone = log(Ozone)) -- cgit v1.2.3