From e7f9199e709c46a6b5ad6b03c9ecf12cc19e3a41 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 26 Jan 2016 19:29:47 -0800 Subject: [SPARK-12903][SPARKR] Add covar_samp and covar_pop for SparkR Add ```covar_samp``` and ```covar_pop``` for SparkR. Should we also provide ```cov``` alias for ```covar_samp```? There is ```cov``` implementation at stats.R which masks ```stats::cov``` already, but may bring to breaking API change. cc sun-rui felixcheung shivaram Author: Yanbo Liang Closes #10829 from yanboliang/spark-12903. --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/functions.R | 58 +++++++++++++++++++++++++++++++ R/pkg/R/generics.R | 10 +++++- R/pkg/R/stats.R | 3 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 ++ 5 files changed, 73 insertions(+), 2 deletions(-) (limited to 'R/pkg') diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 2cc1544bef..f194a46303 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -35,6 +35,8 @@ exportMethods("arrange", "count", "cov", "corr", + "covar_samp", + "covar_pop", "crosstab", "describe", "dim", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9bb7876b38..8f8651c295 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -275,6 +275,64 @@ setMethod("corr", signature(x = "Column"), column(jc) }) +#' cov +#' +#' Compute the sample covariance between two expressions. +#' +#' @rdname cov +#' @name cov +#' @family math_funcs +#' @export +#' @examples +#' \dontrun{ +#' cov(df$c, df$d) +#' cov("c", "d") +#' covar_samp(df$c, df$d) +#' covar_samp("c", "d") +#' } +setMethod("cov", signature(x = "characterOrColumn"), + function(x, col2) { + stopifnot(is(class(col2), "characterOrColumn")) + covar_samp(x, col2) + }) + +#' @rdname cov +#' @name covar_samp +setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), + function(col1, col2) { + stopifnot(class(col1) == class(col2)) + if (class(col1) == "Column") { + col1 <- col1@jc + col2 <- col2@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "covar_samp", col1, col2) + column(jc) + }) + +#' covar_pop +#' +#' Compute the population covariance between two expressions. +#' +#' @rdname covar_pop +#' @name covar_pop +#' @family math_funcs +#' @export +#' @examples +#' \dontrun{ +#' covar_pop(df$c, df$d) +#' covar_pop("c", "d") +#' } +setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), + function(col1, col2) { + stopifnot(class(col1) == class(col2)) + if (class(col1) == "Column") { + col1 <- col1@jc + col2 <- col2@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "covar_pop", col1, col2) + column(jc) + }) + #' cos #' #' Computes the cosine of the given value. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 04784d5156..2dba71abec 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -418,12 +418,20 @@ setGeneric("columns", function(x) {standardGeneric("columns") }) #' @rdname statfunctions #' @export -setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") }) +setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @rdname statfunctions #' @export setGeneric("corr", function(x, ...) {standardGeneric("corr") }) +#' @rdname statfunctions +#' @export +setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") }) + +#' @rdname statfunctions +#' @export +setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") }) + #' @rdname summary #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index d17cce9c75..2e8076843f 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -66,8 +66,9 @@ setMethod("crosstab", #' cov <- cov(df, "title", "gender") #' } setMethod("cov", - signature(x = "DataFrame", col1 = "character", col2 = "character"), + signature(x = "DataFrame"), function(x, col1, col2) { + stopifnot(class(col1) == "character" && class(col2) == "character") statFunctions <- callJMethod(x@sdf, "stat") callJMethod(statFunctions, "cov", col1, col2) }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index b52a11fb1a..7b5713720d 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -996,6 +996,8 @@ test_that("column functions", { c14 <- cume_dist() + ntile(1) + corr(c, c1) c15 <- dense_rank() + percent_rank() + rank() + row_number() c16 <- is.nan(c) + isnan(c) + isNaN(c) + c17 <- cov(c, c1) + cov("c", "c1") + covar_samp(c, c1) + covar_samp("c", "c1") + c18 <- covar_pop(c, c1) + covar_pop("c", "c1") # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) -- cgit v1.2.3