aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--R/pkg/R/functions.R15
-rw-r--r--R/pkg/R/generics.R2
-rw-r--r--R/pkg/R/stats.R9
-rw-r--r--R/pkg/inst/tests/test_sparkSQL.R2
4 files changed, 22 insertions, 6 deletions
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 7432cb8e7c..25231451df 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -259,6 +259,21 @@ setMethod("column",
function(x) {
col(x)
})
+#' corr
+#'
+#' Computes the Pearson Correlation Coefficient for two Columns.
+#'
+#' @rdname corr
+#' @name corr
+#' @family math_funcs
+#' @export
+#' @examples \dontrun{corr(df$c, df$d)}
+setMethod("corr", signature(x = "Column"),
+ function(x, col2) {
+ stopifnot(class(col2) == "Column")
+ jc <- callJStatic("org.apache.spark.sql.functions", "corr", x@jc, col2@jc)
+ column(jc)
+ })
#' cos
#'
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 4b5f786d39..acfd4841e1 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -411,7 +411,7 @@ setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") })
#' @rdname statfunctions
#' @export
-setGeneric("corr", function(x, col1, col2, method = "pearson") {standardGeneric("corr") })
+setGeneric("corr", function(x, ...) {standardGeneric("corr") })
#' @rdname summary
#' @export
diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R
index f79329b115..d17cce9c75 100644
--- a/R/pkg/R/stats.R
+++ b/R/pkg/R/stats.R
@@ -77,7 +77,7 @@ setMethod("cov",
#' Calculates the correlation of two columns of a DataFrame.
#' Currently only supports the Pearson Correlation Coefficient.
#' For Spearman Correlation, consider using RDD methods found in MLlib's Statistics.
-#'
+#'
#' @param x A SparkSQL DataFrame
#' @param col1 the name of the first column
#' @param col2 the name of the second column
@@ -95,8 +95,9 @@ setMethod("cov",
#' corr <- corr(df, "title", "gender", method = "pearson")
#' }
setMethod("corr",
- signature(x = "DataFrame", col1 = "character", col2 = "character"),
+ signature(x = "DataFrame"),
function(x, col1, col2, method = "pearson") {
+ stopifnot(class(col1) == "character" && class(col2) == "character")
statFunctions <- callJMethod(x@sdf, "stat")
callJMethod(statFunctions, "corr", col1, col2, method)
})
@@ -109,7 +110,7 @@ setMethod("corr",
#'
#' @param x A SparkSQL DataFrame.
#' @param cols A vector column names to search frequent items in.
-#' @param support (Optional) The minimum frequency for an item to be considered `frequent`.
+#' @param support (Optional) The minimum frequency for an item to be considered `frequent`.
#' Should be greater than 1e-4. Default support = 0.01.
#' @return a local R data.frame with the frequent items in each column
#'
@@ -131,7 +132,7 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"),
#' sampleBy
#'
#' Returns a stratified sample without replacement based on the fraction given on each stratum.
-#'
+#'
#' @param x A SparkSQL DataFrame
#' @param col column that defines strata
#' @param fractions A named list giving sampling fraction for each stratum. If a stratum is
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 2d26b92ac7..a5a234a02d 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -892,7 +892,7 @@ test_that("column functions", {
c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c)
c12 <- variance(c)
c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1)
- c14 <- cume_dist() + ntile(1)
+ c14 <- cume_dist() + ntile(1) + corr(c, c1)
c15 <- dense_rank() + percent_rank() + rank() + row_number()
# Test if base::rank() is exposed