From f57c63d4c30d092a320c72f8c7181f2fa711ec30 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 7 Oct 2015 09:46:37 -0700 Subject: [SPARK-10752] [SPARKR] Implement corr() and cov in DataFrameStatFunctions. Author: Sun Rui Closes #8869 from sun-rui/SPARK-10752. --- R/pkg/DESCRIPTION | 1 + R/pkg/NAMESPACE | 2 + R/pkg/R/DataFrame.R | 33 +------------ R/pkg/R/generics.R | 10 +++- R/pkg/R/stats.R | 102 +++++++++++++++++++++++++++++++++++++++ R/pkg/inst/tests/test_sparkSQL.R | 12 +++++ 6 files changed, 127 insertions(+), 33 deletions(-) create mode 100644 R/pkg/R/stats.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index a3a16c42a6..3d6edb70ec 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -33,4 +33,5 @@ Collate: 'mllib.R' 'serialize.R' 'sparkR.R' + 'stats.R' 'utils.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c28c47daea..9aad35469b 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -27,6 +27,8 @@ exportMethods("arrange", "collect", "columns", "count", + "cov", + "corr", "crosstab", "describe", "dim", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 14aea923fc..85db3a5ed3 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1828,36 +1828,6 @@ setMethod("fillna", dataFrame(sdf) }) -#' crosstab -#' -#' Computes a pair-wise frequency table of the given columns. Also known as a contingency -#' table. The number of distinct values for each column should be less than 1e4. At most 1e6 -#' non-zero pair frequencies will be returned. -#' -#' @param col1 name of the first column. Distinct items will make the first item of each row. -#' @param col2 name of the second column. Distinct items will make the column names of the output. -#' @return a local R data.frame representing the contingency table. The first column of each row -#' will be the distinct values of `col1` and the column names will be the distinct values -#' of `col2`. The name of the first column will be `$col1_$col2`. 
Pairs that have no -#' occurrences will have zero as their counts. -#' -#' @rdname statfunctions -#' @name crosstab -#' @export -#' @examples -#' \dontrun{ -#' df <- jsonFile(sqlCtx, "/path/to/file.json") -#' ct = crosstab(df, "title", "gender") -#' } -setMethod("crosstab", - signature(x = "DataFrame", col1 = "character", col2 = "character"), - function(x, col1, col2) { - statFunctions <- callJMethod(x@sdf, "stat") - sct <- callJMethod(statFunctions, "crosstab", col1, col2) - collect(dataFrame(sct)) - }) - - #' This function downloads the contents of a DataFrame into an R's data.frame. #' Since data.frames are held in memory, ensure that you have enough memory #' in your system to accommodate the contents. @@ -1879,5 +1849,4 @@ setMethod("as.data.frame", stop(paste("Unused argument(s): ", paste(list(...), collapse=", "))) } collect(x) - } -) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 3db41e0fe2..e9086fdbd1 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -399,6 +399,14 @@ setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) #' @export setGeneric("columns", function(x) {standardGeneric("columns") }) +#' @rdname statfunctions +#' @export +setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") }) + +#' @rdname statfunctions +#' @export +setGeneric("corr", function(x, col1, col2, method = "pearson") {standardGeneric("corr") }) + #' @rdname describe #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) @@ -986,4 +994,4 @@ setGeneric("rbind", signature = "...") #' @rdname as.data.frame #' @export -setGeneric("as.data.frame") \ No newline at end of file +setGeneric("as.data.frame") diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R new file mode 100644 index 0000000000..06382d55d0 --- /dev/null +++ b/R/pkg/R/stats.R @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# stats.R - Statistical functions for DataFrames. + +setOldClass("jobj") + +#' crosstab +#' +#' Computes a pair-wise frequency table of the given columns. Also known as a contingency +#' table. The number of distinct values for each column should be less than 1e4. At most 1e6 +#' non-zero pair frequencies will be returned. +#' +#' @param col1 name of the first column. Distinct items will make the first item of each row. +#' @param col2 name of the second column. Distinct items will make the column names of the output. +#' @return a local R data.frame representing the contingency table. The first column of each row +#' will be the distinct values of `col1` and the column names will be the distinct values +#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no +#' occurrences will have zero as their counts. 
+#' +#' @rdname statfunctions +#' @name crosstab +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' ct <- crosstab(df, "title", "gender") +#' } +setMethod("crosstab", + signature(x = "DataFrame", col1 = "character", col2 = "character"), + function(x, col1, col2) { + statFunctions <- callJMethod(x@sdf, "stat") + sct <- callJMethod(statFunctions, "crosstab", col1, col2) + collect(dataFrame(sct)) + }) + +#' cov +#' +#' Calculates the sample covariance of two numerical columns of a DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @param col1 the name of the first column +#' @param col2 the name of the second column +#' @return the covariance of the two columns. +#' +#' @rdname statfunctions +#' @name cov +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' cov <- cov(df, "singles", "doubles") +#' } +setMethod("cov", + signature(x = "DataFrame", col1 = "character", col2 = "character"), + function(x, col1, col2) { + statFunctions <- callJMethod(x@sdf, "stat") + callJMethod(statFunctions, "cov", col1, col2) + }) + +#' corr +#' +#' Calculates the correlation of two columns of a DataFrame. +#' Currently only supports the Pearson Correlation Coefficient. +#' For Spearman Correlation, consider using RDD methods found in MLlib's Statistics. +#' +#' @param x A SparkSQL DataFrame +#' @param col1 the name of the first column +#' @param col2 the name of the second column +#' @param method Optional. A character specifying the method for calculating the correlation. +#' Only "pearson" is supported currently. +#' @return The Pearson Correlation Coefficient as a Double. 
+#' +#' @rdname statfunctions +#' @name corr +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' corr <- corr(df, "singles", "doubles") +#' corr <- corr(df, "singles", "doubles", method = "pearson") +#' } +setMethod("corr", + signature(x = "DataFrame", col1 = "character", col2 = "character"), + function(x, col1, col2, method = "pearson") { + statFunctions <- callJMethod(x@sdf, "stat") + callJMethod(statFunctions, "corr", col1, col2, method) + }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index faf42b7182..bcf52b8fa7 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1329,6 +1329,18 @@ test_that("crosstab() on a DataFrame", { expect_identical(expected, ordered) }) +test_that("cov() and corr() on a DataFrame", { + l <- lapply(c(0:9), function(x) { list(x, x * 2.0) }) + df <- createDataFrame(sqlContext, l, c("singles", "doubles")) + result <- cov(df, "singles", "doubles") + expect_true(abs(result - 55.0 / 3) < 1e-12) + + result <- corr(df, "singles", "doubles") + expect_true(abs(result - 1.0) < 1e-12) + result <- corr(df, "singles", "doubles", "pearson") + expect_true(abs(result - 1.0) < 1e-12) +}) + test_that("SQL error message is returned from JVM", { retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) expect_equal(grepl("Table Not Found: blah", retError), TRUE) -- cgit v1.2.3