aboutsummaryrefslogtreecommitdiff
path: root/R/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg')
-rw-r--r--R/pkg/NAMESPACE1
-rw-r--r--R/pkg/R/DataFrame.R28
-rw-r--r--R/pkg/R/generics.R4
-rw-r--r--R/pkg/inst/tests/test_sparkSQL.R13
4 files changed, 46 insertions, 0 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 5834813319..7f7a8a2e4d 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -26,6 +26,7 @@ exportMethods("arrange",
"collect",
"columns",
"count",
+ "crosstab",
"describe",
"distinct",
"dropna",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index a58433df3c..06dd6b75df 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1554,3 +1554,31 @@ setMethod("fillna",
}
dataFrame(sdf)
})
+
+#' crosstab
+#'
+#' Computes a pair-wise frequency table of the given columns. Also known as a contingency
+#' table. The number of distinct values for each column should be less than 1e4. At most 1e6
+#' non-zero pair frequencies will be returned.
+#'
+#' @param col1 name of the first column. Distinct items will make the first item of each row.
+#' @param col2 name of the second column. Distinct items will make the column names of the output.
+#' @return a local R data.frame representing the contingency table. The first column of each row
+#' will be the distinct values of `col1` and the column names will be the distinct values
+#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no
+#' occurrences will have `null` as their counts.
+#'
+#' @rdname statfunctions
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- jsonFile(sqlCtx, "/path/to/file.json")
+#' ct = crosstab(df, "title", "gender")
+#' }
+setMethod("crosstab",
+ signature(x = "DataFrame", col1 = "character", col2 = "character"),
+ function(x, col1, col2) {
+ statFunctions <- callJMethod(x@sdf, "stat")
+ sct <- callJMethod(statFunctions, "crosstab", col1, col2)
+ collect(dataFrame(sct))
+ })
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 39b5586f7c..836e0175c3 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -59,6 +59,10 @@ setGeneric("count", function(x) { standardGeneric("count") })
# @export
setGeneric("countByValue", function(x) { standardGeneric("countByValue") })
+# @rdname statfunctions
+# @export
+setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") })
+
# @rdname distinct
# @export
setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") })
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index a3039d36c9..62fe48a5d6 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -987,6 +987,19 @@ test_that("fillna() on a DataFrame", {
expect_identical(expected, actual)
})
+test_that("crosstab() on a DataFrame", {
+ rdd <- lapply(parallelize(sc, 0:3), function(x) {
+ list(paste0("a", x %% 3), paste0("b", x %% 2))
+ })
+ df <- toDF(rdd, list("a", "b"))
+ ct <- crosstab(df, "a", "b")
+ ordered <- ct[order(ct$a_b),]
+ row.names(ordered) <- NULL
+ expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0),
+ stringsAsFactors = FALSE, row.names = NULL)
+ expect_identical(expected, ordered)
+})
+
unlink(parquetPath)
unlink(jsonPath)
unlink(jsonPathNa)