aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSun Rui <rui.sun@intel.com>2015-10-07 09:46:37 -0700
committerDavies Liu <davies.liu@gmail.com>2015-10-07 09:46:37 -0700
commitf57c63d4c30d092a320c72f8c7181f2fa711ec30 (patch)
tree6564d8b720e2b52241efe94593f22a59a3e2cd83
parent27cdde2ff87346fb54318532a476bf85f5837da7 (diff)
downloadspark-f57c63d4c30d092a320c72f8c7181f2fa711ec30.tar.gz
spark-f57c63d4c30d092a320c72f8c7181f2fa711ec30.tar.bz2
spark-f57c63d4c30d092a320c72f8c7181f2fa711ec30.zip
[SPARK-10752] [SPARKR] Implement corr() and cov in DataFrameStatFunctions.
Author: Sun Rui <rui.sun@intel.com> Closes #8869 from sun-rui/SPARK-10752.
-rw-r--r--R/pkg/DESCRIPTION1
-rw-r--r--R/pkg/NAMESPACE2
-rw-r--r--R/pkg/R/DataFrame.R33
-rw-r--r--R/pkg/R/generics.R10
-rw-r--r--R/pkg/R/stats.R102
-rw-r--r--R/pkg/inst/tests/test_sparkSQL.R12
6 files changed, 127 insertions, 33 deletions
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index a3a16c42a6..3d6edb70ec 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -33,4 +33,5 @@ Collate:
'mllib.R'
'serialize.R'
'sparkR.R'
+ 'stats.R'
'utils.R'
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index c28c47daea..9aad35469b 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -27,6 +27,8 @@ exportMethods("arrange",
"collect",
"columns",
"count",
+ "cov",
+ "corr",
"crosstab",
"describe",
"dim",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 14aea923fc..85db3a5ed3 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1828,36 +1828,6 @@ setMethod("fillna",
dataFrame(sdf)
})
-#' crosstab
-#'
-#' Computes a pair-wise frequency table of the given columns. Also known as a contingency
-#' table. The number of distinct values for each column should be less than 1e4. At most 1e6
-#' non-zero pair frequencies will be returned.
-#'
-#' @param col1 name of the first column. Distinct items will make the first item of each row.
-#' @param col2 name of the second column. Distinct items will make the column names of the output.
-#' @return a local R data.frame representing the contingency table. The first column of each row
-#' will be the distinct values of `col1` and the column names will be the distinct values
-#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no
-#' occurrences will have zero as their counts.
-#'
-#' @rdname statfunctions
-#' @name crosstab
-#' @export
-#' @examples
-#' \dontrun{
-#' df <- jsonFile(sqlCtx, "/path/to/file.json")
-#' ct = crosstab(df, "title", "gender")
-#' }
-setMethod("crosstab",
- signature(x = "DataFrame", col1 = "character", col2 = "character"),
- function(x, col1, col2) {
- statFunctions <- callJMethod(x@sdf, "stat")
- sct <- callJMethod(statFunctions, "crosstab", col1, col2)
- collect(dataFrame(sct))
- })
-
-
#' This function downloads the contents of a DataFrame into an R's data.frame.
#' Since data.frames are held in memory, ensure that you have enough memory
#' in your system to accommodate the contents.
@@ -1879,5 +1849,4 @@ setMethod("as.data.frame",
stop(paste("Unused argument(s): ", paste(list(...), collapse=", ")))
}
collect(x)
- }
-)
+ })
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 3db41e0fe2..e9086fdbd1 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -399,6 +399,14 @@ setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") })
#' @export
setGeneric("columns", function(x) {standardGeneric("columns") })
+#' @rdname statfunctions
+#' @export
+setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") })
+
+#' @rdname statfunctions
+#' @export
+setGeneric("corr", function(x, col1, col2, method = "pearson") {standardGeneric("corr") })
+
#' @rdname describe
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
@@ -986,4 +994,4 @@ setGeneric("rbind", signature = "...")
#' @rdname as.data.frame
#' @export
-setGeneric("as.data.frame") \ No newline at end of file
+setGeneric("as.data.frame")
diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R
new file mode 100644
index 0000000000..06382d55d0
--- /dev/null
+++ b/R/pkg/R/stats.R
@@ -0,0 +1,102 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# stats.R - Statistic functions for DataFrames.
+
+setOldClass("jobj")
+
+#' crosstab
+#'
+#' Computes a pair-wise frequency table of the given columns. Also known as a contingency
+#' table. The number of distinct values for each column should be less than 1e4. At most 1e6
+#' non-zero pair frequencies will be returned.
+#'
+#' @param col1 name of the first column. Distinct items will make the first item of each row.
+#' @param col2 name of the second column. Distinct items will make the column names of the output.
+#' @return a local R data.frame representing the contingency table. The first column of each row
+#' will be the distinct values of `col1` and the column names will be the distinct values
+#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no
+#' occurrences will have zero as their counts.
+#'
+#' @rdname statfunctions
+#' @name crosstab
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- jsonFile(sqlContext, "/path/to/file.json")
+#' ct <- crosstab(df, "title", "gender")
+#' }
+setMethod("crosstab",
+ signature(x = "DataFrame", col1 = "character", col2 = "character"),
+ function(x, col1, col2) {
+ statFunctions <- callJMethod(x@sdf, "stat")
+ sct <- callJMethod(statFunctions, "crosstab", col1, col2)
+ collect(dataFrame(sct))
+ })
+
+#' cov
+#'
+#' Calculate the sample covariance of two numerical columns of a DataFrame.
+#'
+#' @param x A SparkSQL DataFrame
+#' @param col1 the name of the first column
+#' @param col2 the name of the second column
+#' @return the covariance of the two columns.
+#'
+#' @rdname statfunctions
+#' @name cov
+#' @export
+#' @examples
+#'\dontrun{
+#' df <- jsonFile(sqlContext, "/path/to/file.json")
+#' cov <- cov(df, "title", "gender")
+#' }
+setMethod("cov",
+ signature(x = "DataFrame", col1 = "character", col2 = "character"),
+ function(x, col1, col2) {
+ statFunctions <- callJMethod(x@sdf, "stat")
+ callJMethod(statFunctions, "cov", col1, col2)
+ })
+
+#' corr
+#'
+#' Calculates the correlation of two columns of a DataFrame.
+#' Currently only supports the Pearson Correlation Coefficient.
+#' For Spearman Correlation, consider using RDD methods found in MLlib's Statistics.
+#'
+#' @param x A SparkSQL DataFrame
+#' @param col1 the name of the first column
+#' @param col2 the name of the second column
+#' @param method Optional. A character specifying the method for calculating the correlation.
+#' only "pearson" is allowed now.
+#' @return The Pearson Correlation Coefficient as a Double.
+#'
+#' @rdname statfunctions
+#' @name corr
+#' @export
+#' @examples
+#'\dontrun{
+#' df <- jsonFile(sqlContext, "/path/to/file.json")
+#' corr <- corr(df, "title", "gender")
+#' corr <- corr(df, "title", "gender", method = "pearson")
+#' }
+setMethod("corr",
+ signature(x = "DataFrame", col1 = "character", col2 = "character"),
+ function(x, col1, col2, method = "pearson") {
+ statFunctions <- callJMethod(x@sdf, "stat")
+ callJMethod(statFunctions, "corr", col1, col2, method)
+ })
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index faf42b7182..bcf52b8fa7 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -1329,6 +1329,18 @@ test_that("crosstab() on a DataFrame", {
expect_identical(expected, ordered)
})
+test_that("cov() and corr() on a DataFrame", {
+ l <- lapply(c(0:9), function(x) { list(x, x * 2.0) })
+ df <- createDataFrame(sqlContext, l, c("singles", "doubles"))
+ result <- cov(df, "singles", "doubles")
+ expect_true(abs(result - 55.0 / 3) < 1e-12)
+
+ result <- corr(df, "singles", "doubles")
+ expect_true(abs(result - 1.0) < 1e-12)
+ result <- corr(df, "singles", "doubles", "pearson")
+ expect_true(abs(result - 1.0) < 1e-12)
+})
+
test_that("SQL error message is returned from JVM", {
retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
expect_equal(grepl("Table Not Found: blah", retError), TRUE)