diff options
author | Narine Kokhlikyan <narine.kokhlikyan@gmail.com> | 2016-06-15 21:42:05 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@cs.berkeley.edu> | 2016-06-15 21:42:05 -0700 |
commit | 7c6c6926376c93acc42dd56a399d816f4838f28c (patch) | |
tree | bbf8f9dc1d7a044b890b6c95fdd3a17aa76fea89 /R/pkg/inst/tests | |
parent | b75f454f946714b93fe561055cd53b0686187d2e (diff) | |
download | spark-7c6c6926376c93acc42dd56a399d816f4838f28c.tar.gz spark-7c6c6926376c93acc42dd56a399d816f4838f28c.tar.bz2 spark-7c6c6926376c93acc42dd56a399d816f4838f28c.zip |
[SPARK-12922][SPARKR][WIP] Implement gapply() on DataFrame in SparkR
## What changes were proposed in this pull request?
gapply() applies an R function on groups grouped by one or more columns of a DataFrame, and returns a DataFrame. It is like GroupedDataSet.flatMapGroups() in the Dataset API.
Please, let me know what do you think and if you have any ideas to improve it.
Thank you!
## How was this patch tested?
Unit tests.
1. Primitive test with different column types
2. Add a boolean column
3. Compute average by a group
Author: Narine Kokhlikyan <narine.kokhlikyan@gmail.com>
Author: NarineK <narine.kokhlikyan@us.ibm.com>
Closes #12836 from NarineK/gapply2.
Diffstat (limited to 'R/pkg/inst/tests')
-rw-r--r-- | R/pkg/inst/tests/testthat/test_sparkSQL.R | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index d1ca3b726f..c11930ada6 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2146,6 +2146,71 @@ test_that("repartition by columns on DataFrame", { expect_equal(nrow(df1), 2) }) +test_that("gapply() on a DataFrame", { + df <- createDataFrame ( + list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), + c("a", "b", "c", "d")) + expected <- collect(df) + df1 <- gapply(df, list("a"), function(key, x) { x }, schema(df)) + actual <- collect(df1) + expect_identical(actual, expected) + + # Computes the sum of second column by grouping on the first and third columns + # and checks if the sum is larger than 2 + schema <- structType(structField("a", "integer"), structField("e", "boolean")) + df2 <- gapply( + df, + list(df$"a", df$"c"), + function(key, x) { + y <- data.frame(key[1], sum(x$b) > 2) + }, + schema) + actual <- collect(df2)$e + expected <- c(TRUE, TRUE) + expect_identical(actual, expected) + + # Computes the arithmetic mean of the second column by grouping + # on the first and third columns. Output the groupping value and the average. + schema <- structType(structField("a", "integer"), structField("c", "string"), + structField("avg", "double")) + df3 <- gapply( + df, + list("a", "c"), + function(key, x) { + y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) + }, + schema) + actual <- collect(df3) + actual <- actual[order(actual$a), ] + rownames(actual) <- NULL + expected <- collect(select(df, "a", "b", "c")) + expected <- data.frame(aggregate(expected$b, by = list(expected$a, expected$c), FUN = mean)) + colnames(expected) <- c("a", "c", "avg") + expected <- expected[order(expected$a), ] + rownames(expected) <- NULL + expect_identical(actual, expected) + + irisDF <- suppressWarnings(createDataFrame (iris)) + schema <- structType(structField("Sepal_Length", "double"), structField("Avg", "double")) + # Groups by `Sepal_Length` and computes the average for `Sepal_Width` + df4 <- gapply( + cols = list("Sepal_Length"), + irisDF, + function(key, x) { + y <- data.frame(key, mean(x$Sepal_Width), stringsAsFactors = FALSE) + }, + schema) + actual <- collect(df4) + actual <- actual[order(actual$Sepal_Length), ] + rownames(actual) <- NULL + agg_local_df <- data.frame(aggregate(iris$Sepal.Width, by = list(iris$Sepal.Length), FUN = mean), + stringsAsFactors = FALSE) + colnames(agg_local_df) <- c("Sepal_Length", "Avg") + expected <- agg_local_df[order(agg_local_df$Sepal_Length), ] + rownames(expected) <- NULL + expect_identical(actual, expected) +}) + test_that("Window functions on a DataFrame", { setHiveContext(sc) df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")), |