From 217db56ba11fcdf9e3a81946667d1d99ad7344ee Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 20 Jun 2016 21:09:39 -0700 Subject: [SPARK-15294][R] Add `pivot` to SparkR ## What changes were proposed in this pull request? This PR adds `pivot` function to SparkR for API parity. Since this PR is based on https://github.com/apache/spark/pull/13295 , mhnatiuk should be credited for the work he did. ## How was this patch tested? Pass the Jenkins tests (including new testcase.) Author: Dongjoon Hyun Closes #13786 from dongjoon-hyun/SPARK-15294. --- R/pkg/NAMESPACE | 1 + R/pkg/R/generics.R | 4 +++ R/pkg/R/group.R | 43 +++++++++++++++++++++++++++++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 25 ++++++++++++++++++ 4 files changed, 73 insertions(+) (limited to 'R/pkg') diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 45663f4c2c..ea42888eae 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -294,6 +294,7 @@ exportMethods("%in%", exportClasses("GroupedData") exportMethods("agg") +exportMethods("pivot") export("as.DataFrame", "cacheTable", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 3fb6370497..c307de7c07 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -160,6 +160,10 @@ setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) # @export setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")}) +# @rdname pivot +# @export +setGeneric("pivot", function(x, colname, values = list()) { standardGeneric("pivot") }) + # @rdname reduce # @export setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 51e151623c..0687f14adf 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -134,6 +134,49 @@ methods <- c("avg", "max", "mean", "min", "sum") # These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", "stddev_pop", # "variance", "var_samp", "var_pop" +#' Pivot a column of the GroupedData and perform the specified aggregation. +#' +#' Pivot a column of the GroupedData and perform the specified aggregation. +#' There are two versions of pivot function: one that requires the caller to specify the list +#' of distinct values to pivot on, and one that does not. The latter is more concise but less +#' efficient, because Spark needs to first compute the list of distinct values internally. +#' +#' @param x a GroupedData object +#' @param colname A column name +#' @param values A value or a list/vector of distinct values for the output columns. +#' @return GroupedData object +#' @rdname pivot +#' @name pivot +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(data.frame( +#' earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000), +#' course = c("R", "Python", "R", "Python", "R", "Python", "R", "Python"), +#' period = c("1H", "1H", "2H", "2H", "1H", "1H", "2H", "2H"), +#' year = c(2015, 2015, 2015, 2015, 2016, 2016, 2016, 2016) +#' )) +#' group_sum <- sum(pivot(groupBy(df, "year"), "course"), "earnings") +#' group_min <- min(pivot(groupBy(df, "year"), "course", "R"), "earnings") +#' group_max <- max(pivot(groupBy(df, "year"), "course", c("Python", "R")), "earnings") +#' group_mean <- mean(pivot(groupBy(df, "year"), "course", list("Python", "R")), "earnings") +#' } +#' @note pivot since 2.0.0 +setMethod("pivot", + signature(x = "GroupedData", colname = "character"), + function(x, colname, values = list()){ + stopifnot(length(colname) == 1) + if (length(values) == 0) { + result <- callJMethod(x@sgd, "pivot", colname) + } else { + if (length(values) > length(unique(values))) { + stop("Values are not unique") + } + result <- callJMethod(x@sgd, "pivot", colname, as.list(values)) + } + groupedData(result) + }) + createMethod <- function(name) { setMethod(name, signature(x = "GroupedData"), diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index d53c40d423..7c192fb5a0 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1398,6 +1398,31 @@ test_that("group by, agg functions", { unlink(jsonPath3) }) +test_that("pivot GroupedData column", { + df <- createDataFrame(data.frame( + earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000), + course = c("R", "Python", "R", "Python", "R", "Python", "R", "Python"), + year = c(2013, 2013, 2014, 2014, 2015, 2015, 2016, 2016) + )) + sum1 <- collect(sum(pivot(groupBy(df, "year"), "course"), "earnings")) + sum2 <- collect(sum(pivot(groupBy(df, "year"), "course", c("Python", "R")), "earnings")) + sum3 <- collect(sum(pivot(groupBy(df, "year"), "course", list("Python", "R")), "earnings")) + sum4 <- collect(sum(pivot(groupBy(df, "year"), "course", "R"), "earnings")) + + correct_answer <- data.frame( + year = c(2013, 2014, 2015, 2016), + Python = c(10000, 15000, 20000, 22000), + R = c(10000, 11000, 12000, 21000) + ) + expect_equal(sum1, correct_answer) + expect_equal(sum2, correct_answer) + expect_equal(sum3, correct_answer) + expect_equal(sum4, correct_answer[, c("year", "R")]) + + expect_error(collect(sum(pivot(groupBy(df, "year"), "course", c("R", "R")), "earnings"))) + expect_error(collect(sum(pivot(groupBy(df, "year"), "course", list("R", "R")), "earnings"))) +}) + test_that("arrange() and orderBy() on a DataFrame", { df <- read.json(jsonPath) sorted <- arrange(df, df$age) -- cgit v1.2.3