aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-01-26 19:29:47 -0800
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2016-01-26 19:29:47 -0800
commite7f9199e709c46a6b5ad6b03c9ecf12cc19e3a41 (patch)
treec1721ec403a49b7aa6f7180616ac3d334a9c6cd6
parentb72611f20a03c790b6fd341b6ffdb3b5437609ee (diff)
downloadspark-e7f9199e709c46a6b5ad6b03c9ecf12cc19e3a41.tar.gz
spark-e7f9199e709c46a6b5ad6b03c9ecf12cc19e3a41.tar.bz2
spark-e7f9199e709c46a6b5ad6b03c9ecf12cc19e3a41.zip
[SPARK-12903][SPARKR] Add covar_samp and covar_pop for SparkR
Add ```covar_samp``` and ```covar_pop``` for SparkR. Should we also provide a ```cov``` alias for ```covar_samp```? There is already a ```cov``` implementation in stats.R which masks ```stats::cov```, but doing so may introduce a breaking API change. cc sun-rui felixcheung shivaram Author: Yanbo Liang <ybliang8@gmail.com> Closes #10829 from yanboliang/spark-12903.
-rw-r--r--R/pkg/NAMESPACE2
-rw-r--r--R/pkg/R/functions.R58
-rw-r--r--R/pkg/R/generics.R10
-rw-r--r--R/pkg/R/stats.R3
-rw-r--r--R/pkg/inst/tests/testthat/test_sparkSQL.R2
5 files changed, 73 insertions, 2 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 2cc1544bef..f194a46303 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -35,6 +35,8 @@ exportMethods("arrange",
"count",
"cov",
"corr",
+ "covar_samp",
+ "covar_pop",
"crosstab",
"describe",
"dim",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 9bb7876b38..8f8651c295 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -275,6 +275,64 @@ setMethod("corr", signature(x = "Column"),
column(jc)
})
+#' cov
+#'
+#' Compute the sample covariance between two expressions.
+#'
+#' @rdname cov
+#' @name cov
+#' @family math_funcs
+#' @export
+#' @examples
+#' \dontrun{
+#' cov(df$c, df$d)
+#' cov("c", "d")
+#' covar_samp(df$c, df$d)
+#' covar_samp("c", "d")
+#' }
+setMethod("cov", signature(x = "characterOrColumn"),
+          function(x, col2) {
+            # Validate the second argument itself. Testing `class(col2)` would always
+            # pass, because class() returns a character string and a character string
+            # satisfies characterOrColumn on its own.
+            stopifnot(is(col2, "characterOrColumn"))
+            # cov() on column names / Columns is an alias for the sample covariance.
+            covar_samp(x, col2)
+          })
+
+#' @rdname cov
+#' @name covar_samp
+setMethod("covar_samp",
+          signature(col1 = "characterOrColumn", col2 = "characterOrColumn"),
+          function(col1, col2) {
+            # Both arguments must come in the same form (same class on each side).
+            stopifnot(class(col1) == class(col2))
+            # Column arguments are replaced by their `jc` slot before the JVM call;
+            # anything else is forwarded unchanged.
+            bothColumns <- class(col1) == "Column"
+            lhs <- if (bothColumns) col1@jc else col1
+            rhs <- if (bothColumns) col2@jc else col2
+            column(callJStatic("org.apache.spark.sql.functions", "covar_samp", lhs, rhs))
+          })
+
+#' covar_pop
+#'
+#' Compute the population covariance between two expressions.
+#'
+#' @rdname covar_pop
+#' @name covar_pop
+#' @family math_funcs
+#' @export
+#' @examples
+#' \dontrun{
+#' covar_pop(df$c, df$d)
+#' covar_pop("c", "d")
+#' }
+setMethod("covar_pop",
+          signature(col1 = "characterOrColumn", col2 = "characterOrColumn"),
+          function(col1, col2) {
+            # Both arguments must come in the same form (same class on each side).
+            stopifnot(class(col1) == class(col2))
+            # Column arguments are replaced by their `jc` slot before the JVM call;
+            # anything else is forwarded unchanged.
+            bothColumns <- class(col1) == "Column"
+            lhs <- if (bothColumns) col1@jc else col1
+            rhs <- if (bothColumns) col2@jc else col2
+            column(callJStatic("org.apache.spark.sql.functions", "covar_pop", lhs, rhs))
+          })
+
#' cos
#'
#' Computes the cosine of the given value.
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 04784d5156..2dba71abec 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -418,12 +418,20 @@ setGeneric("columns", function(x) {standardGeneric("columns") })
#' @rdname statfunctions
#' @export
-setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") })
+setGeneric("cov", function(x, ...) {standardGeneric("cov") })
#' @rdname statfunctions
#' @export
setGeneric("corr", function(x, ...) {standardGeneric("corr") })
+#' @rdname statfunctions
+#' @export
+setGeneric("covar_samp",
+           function(col1, col2) {
+             standardGeneric("covar_samp")
+           })
+
+#' @rdname statfunctions
+#' @export
+setGeneric("covar_pop",
+           function(col1, col2) {
+             standardGeneric("covar_pop")
+           })
+
#' @rdname summary
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R
index d17cce9c75..2e8076843f 100644
--- a/R/pkg/R/stats.R
+++ b/R/pkg/R/stats.R
@@ -66,8 +66,9 @@ setMethod("crosstab",
#' cov <- cov(df, "title", "gender")
#' }
setMethod("cov",
- signature(x = "DataFrame", col1 = "character", col2 = "character"),
+ signature(x = "DataFrame"),
function(x, col1, col2) {
+ stopifnot(class(col1) == "character" && class(col2) == "character")
statFunctions <- callJMethod(x@sdf, "stat")
callJMethod(statFunctions, "cov", col1, col2)
})
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index b52a11fb1a..7b5713720d 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -996,6 +996,8 @@ test_that("column functions", {
c14 <- cume_dist() + ntile(1) + corr(c, c1)
c15 <- dense_rank() + percent_rank() + rank() + row_number()
c16 <- is.nan(c) + isnan(c) + isNaN(c)
+ c17 <- cov(c, c1) + cov("c", "c1") + covar_samp(c, c1) + covar_samp("c", "c1")
+ c18 <- covar_pop(c, c1) + covar_pop("c", "c1")
# Test if base::is.nan() is exposed
expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE))