author    Dongjoon Hyun <dongjoon@apache.org>              2016-06-20 21:09:39 -0700
committer Shivaram Venkataraman <shivaram@cs.berkeley.edu> 2016-06-20 21:09:39 -0700
commit    217db56ba11fcdf9e3a81946667d1d99ad7344ee (patch)
tree      c126e953440cf3ac4e902ea259d4a9ec1b60bd6e /R
parent    a46553cbacf0e4012df89fe55385dec5beaa680a (diff)
[SPARK-15294][R] Add `pivot` to SparkR
## What changes were proposed in this pull request?

This PR adds the `pivot` function to SparkR for API parity. Since this PR is based on https://github.com/apache/spark/pull/13295, mhnatiuk should be credited for the work he did.

## How was this patch tested?

Pass the Jenkins tests (including a new testcase).

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13786 from dongjoon-hyun/SPARK-15294.
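For readers skimming the patch, a minimal sketch of how the new API reads once this change is applied; the data and column names are taken from the roxygen example added in `R/pkg/R/group.R` below, and it assumes an active SparkR session so that `createDataFrame()` works:

```r
# Sketch of the new SparkR pivot() introduced by this patch; data mirrors
# the documentation example in R/pkg/R/group.R.
df <- createDataFrame(data.frame(
  earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000),
  course = c("R", "Python", "R", "Python", "R", "Python", "R", "Python"),
  year = c(2015, 2015, 2015, 2015, 2016, 2016, 2016, 2016)
))
# Without explicit values, Spark first computes the distinct courses itself.
collect(sum(pivot(groupBy(df, "year"), "course"), "earnings"))
# Passing the distinct values up front avoids that extra internal pass.
collect(sum(pivot(groupBy(df, "year"), "course", c("Python", "R")), "earnings"))
```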
Diffstat (limited to 'R')
-rw-r--r--  R/pkg/NAMESPACE                             1
-rw-r--r--  R/pkg/R/generics.R                          4
-rw-r--r--  R/pkg/R/group.R                            43
-rw-r--r--  R/pkg/inst/tests/testthat/test_sparkSQL.R  25
4 files changed, 73 insertions, 0 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 45663f4c2c..ea42888eae 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -294,6 +294,7 @@ exportMethods("%in%",
 exportClasses("GroupedData")
 
 exportMethods("agg")
+exportMethods("pivot")
 
 export("as.DataFrame",
        "cacheTable",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 3fb6370497..c307de7c07 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -160,6 +160,10 @@ setGeneric("persist", function(x, newLevel) { standardGeneric("persist") })
 # @export
 setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")})
 
+# @rdname pivot
+# @export
+setGeneric("pivot", function(x, colname, values = list()) { standardGeneric("pivot") })
+
 # @rdname reduce
 # @export
 setGeneric("reduce", function(x, func) { standardGeneric("reduce") })
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 51e151623c..0687f14adf 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -134,6 +134,49 @@ methods <- c("avg", "max", "mean", "min", "sum")
 # These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", "stddev_pop",
 # "variance", "var_samp", "var_pop"
 
+#' Pivot a column of the GroupedData and perform the specified aggregation.
+#'
+#' Pivot a column of the GroupedData and perform the specified aggregation.
+#' There are two versions of pivot function: one that requires the caller to specify the list
+#' of distinct values to pivot on, and one that does not. The latter is more concise but less
+#' efficient, because Spark needs to first compute the list of distinct values internally.
+#'
+#' @param x a GroupedData object
+#' @param colname A column name
+#' @param values A value or a list/vector of distinct values for the output columns.
+#' @return GroupedData object
+#' @rdname pivot
+#' @name pivot
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(data.frame(
+#'   earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000),
+#'   course = c("R", "Python", "R", "Python", "R", "Python", "R", "Python"),
+#'   period = c("1H", "1H", "2H", "2H", "1H", "1H", "2H", "2H"),
+#'   year = c(2015, 2015, 2015, 2015, 2016, 2016, 2016, 2016)
+#' ))
+#' group_sum <- sum(pivot(groupBy(df, "year"), "course"), "earnings")
+#' group_min <- min(pivot(groupBy(df, "year"), "course", "R"), "earnings")
+#' group_max <- max(pivot(groupBy(df, "year"), "course", c("Python", "R")), "earnings")
+#' group_mean <- mean(pivot(groupBy(df, "year"), "course", list("Python", "R")), "earnings")
+#' }
+#' @note pivot since 2.0.0
+setMethod("pivot",
+          signature(x = "GroupedData", colname = "character"),
+          function(x, colname, values = list()) {
+            stopifnot(length(colname) == 1)
+            if (length(values) == 0) {
+              result <- callJMethod(x@sgd, "pivot", colname)
+            } else {
+              if (length(values) > length(unique(values))) {
+                stop("Values are not unique")
+              }
+              result <- callJMethod(x@sgd, "pivot", colname, as.list(values))
+            }
+            groupedData(result)
+          })
+
 createMethod <- function(name) {
   setMethod(name,
             signature(x = "GroupedData"),
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index d53c40d423..7c192fb5a0 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1398,6 +1398,31 @@ test_that("group by, agg functions", {
   unlink(jsonPath3)
 })
 
+test_that("pivot GroupedData column", {
+  df <- createDataFrame(data.frame(
+    earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000),
+    course = c("R", "Python", "R", "Python", "R", "Python", "R", "Python"),
+    year = c(2013, 2013, 2014, 2014, 2015, 2015, 2016, 2016)
+  ))
+  sum1 <- collect(sum(pivot(groupBy(df, "year"), "course"), "earnings"))
+  sum2 <- collect(sum(pivot(groupBy(df, "year"), "course", c("Python", "R")), "earnings"))
+  sum3 <- collect(sum(pivot(groupBy(df, "year"), "course", list("Python", "R")), "earnings"))
+  sum4 <- collect(sum(pivot(groupBy(df, "year"), "course", "R"), "earnings"))
+
+  correct_answer <- data.frame(
+    year = c(2013, 2014, 2015, 2016),
+    Python = c(10000, 15000, 20000, 22000),
+    R = c(10000, 11000, 12000, 21000)
+  )
+  expect_equal(sum1, correct_answer)
+  expect_equal(sum2, correct_answer)
+  expect_equal(sum3, correct_answer)
+  expect_equal(sum4, correct_answer[, c("year", "R")])
+
+  expect_error(collect(sum(pivot(groupBy(df, "year"), "course", c("R", "R")), "earnings")))
+  expect_error(collect(sum(pivot(groupBy(df, "year"), "course", list("R", "R")), "earnings")))
+})
+
 test_that("arrange() and orderBy() on a DataFrame", {
   df <- read.json(jsonPath)
   sorted <- arrange(df, df$age)