aboutsummaryrefslogtreecommitdiff
path: root/R/pkg
diff options
context:
space:
mode:
authorShivaram Venkataraman <shivaram@cs.berkeley.edu>2015-05-08 18:29:57 -0700
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2015-05-08 18:29:57 -0700
commit0a901dd3a1eb3fd459d45b771ce4ad2cfef2a944 (patch)
treec7b2479550c1ebadca8f8b5f4caf6f63db0f57e1 /R/pkg
parentb6c797b08cbd08d7aab59ad0106af0f5f41ef186 (diff)
downloadspark-0a901dd3a1eb3fd459d45b771ce4ad2cfef2a944.tar.gz
spark-0a901dd3a1eb3fd459d45b771ce4ad2cfef2a944.tar.bz2
spark-0a901dd3a1eb3fd459d45b771ce4ad2cfef2a944.zip
[SPARK-7231] [SPARKR] Changes to make SparkR DataFrame dplyr friendly.
Changes include 1. Rename sortDF to arrange 2. Add new aliases `group_by` and `sample_frac`, `summarize` 3. Add more user friendly column addition (mutate), rename 4. Support mean as an alias for avg in Scala and also support n_distinct, n as in dplyr Using these changes we can pretty much run the examples as described in http://cran.rstudio.com/web/packages/dplyr/vignettes/introduction.html with the same syntax The only thing missing in SparkR is auto resolving column names when used in an expression i.e. making something like `select(flights, delay)` works in dply but we right now need `select(flights, flights$delay)` or `select(flights, "delay")`. But this is a complicated change and I'll file a new issue for it cc sun-rui rxin Author: Shivaram Venkataraman <shivaram@cs.berkeley.edu> Closes #6005 from shivaram/sparkr-df-api and squashes the following commits: 5e0716a [Shivaram Venkataraman] Fix some roxygen bugs 1254953 [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into sparkr-df-api 0521149 [Shivaram Venkataraman] Changes to make SparkR DataFrame dplyr friendly. Changes include 1. Rename sortDF to arrange 2. Add new aliases `group_by` and `sample_frac`, `summarize` 3. Add more user friendly column addition (mutate), rename 4. Support mean as an alias for avg in Scala and also support n_distinct, n as in dplyr
Diffstat (limited to 'R/pkg')
-rw-r--r--R/pkg/NAMESPACE11
-rw-r--r--R/pkg/R/DataFrame.R127
-rw-r--r--R/pkg/R/column.R32
-rw-r--r--R/pkg/R/generics.R41
-rw-r--r--R/pkg/R/group.R10
-rw-r--r--R/pkg/inst/tests/test_sparkSQL.R36
6 files changed, 228 insertions, 29 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 7611f479a6..819e9a24e5 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -9,7 +9,8 @@ export("print.jobj")
exportClasses("DataFrame")
-exportMethods("cache",
+exportMethods("arrange",
+ "cache",
"collect",
"columns",
"count",
@@ -20,6 +21,7 @@ exportMethods("cache",
"explain",
"filter",
"first",
+ "group_by",
"groupBy",
"head",
"insertInto",
@@ -28,12 +30,15 @@ exportMethods("cache",
"join",
"limit",
"orderBy",
+ "mutate",
"names",
"persist",
"printSchema",
"registerTempTable",
+ "rename",
"repartition",
"sampleDF",
+ "sample_frac",
"saveAsParquetFile",
"saveAsTable",
"saveDF",
@@ -42,7 +47,7 @@ exportMethods("cache",
"selectExpr",
"show",
"showDF",
- "sortDF",
+ "summarize",
"take",
"unionAll",
"unpersist",
@@ -72,6 +77,8 @@ exportMethods("abs",
"max",
"mean",
"min",
+ "n",
+ "n_distinct",
"rlike",
"sqrt",
"startsWith",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 354642e7bc..8a9d2dd45c 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -480,6 +480,7 @@ setMethod("distinct",
#' @param withReplacement Sampling with replacement or not
#' @param fraction The (rough) sample target fraction
#' @rdname sampleDF
+#' @aliases sample_frac
#' @export
#' @examples
#'\dontrun{
@@ -501,6 +502,15 @@ setMethod("sampleDF",
dataFrame(sdf)
})
+#' @rdname sampleDF
+#' @aliases sampleDF
+setMethod("sample_frac",
+ signature(x = "DataFrame", withReplacement = "logical",
+ fraction = "numeric"),
+ function(x, withReplacement, fraction) {
+ sampleDF(x, withReplacement, fraction)
+ })
+
#' Count
#'
#' Returns the number of rows in a DataFrame
@@ -682,7 +692,8 @@ setMethod("toRDD",
#' @param x a DataFrame
#' @return a GroupedData
#' @seealso GroupedData
-#' @rdname DataFrame
+#' @aliases group_by
+#' @rdname groupBy
#' @export
#' @examples
#' \dontrun{
@@ -705,12 +716,21 @@ setMethod("groupBy",
groupedData(sgd)
})
-#' Agg
+#' @rdname groupBy
+#' @aliases group_by
+setMethod("group_by",
+ signature(x = "DataFrame"),
+ function(x, ...) {
+ groupBy(x, ...)
+ })
+
+#' Summarize data across columns
#'
#' Compute aggregates by specifying a list of columns
#'
#' @param x a DataFrame
#' @rdname DataFrame
+#' @aliases summarize
#' @export
setMethod("agg",
signature(x = "DataFrame"),
@@ -718,6 +738,14 @@ setMethod("agg",
agg(groupBy(x), ...)
})
+#' @rdname DataFrame
+#' @aliases agg
+setMethod("summarize",
+ signature(x = "DataFrame"),
+ function(x, ...) {
+ agg(x, ...)
+ })
+
############################## RDD Map Functions ##################################
# All of the following functions mirror the existing RDD map functions, #
@@ -886,7 +914,7 @@ setMethod("select",
signature(x = "DataFrame", col = "list"),
function(x, col) {
cols <- lapply(col, function(c) {
- if (class(c)== "Column") {
+ if (class(c) == "Column") {
c@jc
} else {
col(c)@jc
@@ -946,6 +974,42 @@ setMethod("withColumn",
select(x, x$"*", alias(col, colName))
})
+#' Mutate
+#'
+#' Return a new DataFrame with the specified columns added.
+#'
+#' @param x A DataFrame
+#' @param col a named argument of the form name = col
+#' @return A new DataFrame with the new columns added.
+#' @rdname withColumn
+#' @aliases withColumn
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2)
+#' names(newDF) # Will contain newCol, newCol2
+#' }
+setMethod("mutate",
+ signature(x = "DataFrame"),
+ function(x, ...) {
+ cols <- list(...)
+ stopifnot(length(cols) > 0)
+ stopifnot(class(cols[[1]]) == "Column")
+ ns <- names(cols)
+ if (!is.null(ns)) {
+ for (n in ns) {
+ if (n != "") {
+ cols[[n]] <- alias(cols[[n]], n)
+ }
+ }
+ }
+ do.call(select, c(x, x$"*", cols))
+ })
+
#' WithColumnRenamed
#'
#' Rename an existing column in a DataFrame.
@@ -977,9 +1041,47 @@ setMethod("withColumnRenamed",
select(x, cols)
})
+#' Rename
+#'
+#' Rename an existing column in a DataFrame.
+#'
+#' @param x A DataFrame
+#' @param newCol A named pair of the form new_column_name = existing_column
+#' @return A DataFrame with the column name changed.
+#' @rdname withColumnRenamed
+#' @aliases withColumnRenamed
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' newDF <- rename(df, col1 = df$newCol1)
+#' }
+setMethod("rename",
+ signature(x = "DataFrame"),
+ function(x, ...) {
+ renameCols <- list(...)
+ stopifnot(length(renameCols) > 0)
+ stopifnot(class(renameCols[[1]]) == "Column")
+ newNames <- names(renameCols)
+ oldNames <- lapply(renameCols, function(col) {
+ callJMethod(col@jc, "toString")
+ })
+ cols <- lapply(columns(x), function(c) {
+ if (c %in% oldNames) {
+ alias(col(c), newNames[[match(c, oldNames)]])
+ } else {
+ col(c)
+ }
+ })
+ select(x, cols)
+ })
+
setClassUnion("characterOrColumn", c("character", "Column"))
-#' SortDF
+#' Arrange
#'
#' Sort a DataFrame by the specified column(s).
#'
@@ -987,7 +1089,7 @@ setClassUnion("characterOrColumn", c("character", "Column"))
#' @param col Either a Column object or character vector indicating the field to sort on
#' @param ... Additional sorting fields
#' @return A DataFrame where all elements are sorted.
-#' @rdname sortDF
+#' @rdname arrange
#' @export
#' @examples
#'\dontrun{
@@ -995,11 +1097,11 @@ setClassUnion("characterOrColumn", c("character", "Column"))
#' sqlCtx <- sparkRSQL.init(sc)
#' path <- "path/to/file.json"
#' df <- jsonFile(sqlCtx, path)
-#' sortDF(df, df$col1)
-#' sortDF(df, "col1")
-#' sortDF(df, asc(df$col1), desc(abs(df$col2)))
+#' arrange(df, df$col1)
+#' arrange(df, "col1")
+#' arrange(df, asc(df$col1), desc(abs(df$col2)))
#' }
-setMethod("sortDF",
+setMethod("arrange",
signature(x = "DataFrame", col = "characterOrColumn"),
function(x, col, ...) {
if (class(col) == "character") {
@@ -1013,12 +1115,12 @@ setMethod("sortDF",
dataFrame(sdf)
})
-#' @rdname sortDF
+#' @rdname arrange
#' @aliases orderBy,DataFrame,function-method
setMethod("orderBy",
signature(x = "DataFrame", col = "characterOrColumn"),
function(x, col) {
- sortDF(x, col)
+ arrange(x, col)
})
#' Filter
@@ -1026,7 +1128,7 @@ setMethod("orderBy",
#' Filter the rows of a DataFrame according to a given condition.
#'
#' @param x A DataFrame to be sorted.
-#' @param condition The condition to sort on. This may either be a Column expression
+#' @param condition The condition to filter on. This may either be a Column expression
#' or a string containing a SQL statement
#' @return A DataFrame containing only the rows that meet the condition.
#' @rdname filter
@@ -1106,6 +1208,7 @@ setMethod("join",
#'
#' Return a new DataFrame containing the union of rows in this DataFrame
#' and another DataFrame. This is equivalent to `UNION ALL` in SQL.
+#' Note that this does not remove duplicate rows across the two DataFrames.
#'
#' @param x A Spark DataFrame
#' @param y A Spark DataFrame
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 95fb9ff088..9a68445ab4 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -131,6 +131,8 @@ createMethods()
#' alias
#'
#' Set a new name for a column
+
+#' @rdname column
setMethod("alias",
signature(object = "Column"),
function(object, data) {
@@ -141,8 +143,12 @@ setMethod("alias",
}
})
+#' substr
+#'
#' An expression that returns a substring.
#'
+#' @rdname column
+#'
#' @param start starting position
#' @param stop ending position
setMethod("substr", signature(x = "Column"),
@@ -152,6 +158,9 @@ setMethod("substr", signature(x = "Column"),
})
#' Casts the column to a different data type.
+#'
+#' @rdname column
+#'
#' @examples
#' \dontrun{
#' cast(df$age, "string")
@@ -173,8 +182,8 @@ setMethod("cast",
#' Approx Count Distinct
#'
-#' Returns the approximate number of distinct items in a group.
-#'
+#' @rdname column
+#' @return the approximate number of distinct items in a group.
setMethod("approxCountDistinct",
signature(x = "Column"),
function(x, rsd = 0.95) {
@@ -184,8 +193,8 @@ setMethod("approxCountDistinct",
#' Count Distinct
#'
-#' returns the number of distinct items in a group.
-#'
+#' @rdname column
+#' @return the number of distinct items in a group.
setMethod("countDistinct",
signature(x = "Column"),
function(x, ...) {
@@ -197,3 +206,18 @@ setMethod("countDistinct",
column(jc)
})
+#' @rdname column
+#' @aliases countDistinct
+setMethod("n_distinct",
+ signature(x = "Column"),
+ function(x, ...) {
+ countDistinct(x, ...)
+ })
+
+#' @rdname column
+#' @aliases count
+setMethod("n",
+ signature(x = "Column"),
+ function(x) {
+ count(x)
+ })
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 380e8ebe8c..557128a419 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -380,6 +380,14 @@ setGeneric("value", function(bcast) { standardGeneric("value") })
#################### DataFrame Methods ########################
+#' @rdname agg
+#' @export
+setGeneric("agg", function (x, ...) { standardGeneric("agg") })
+
+#' @rdname arrange
+#' @export
+setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") })
+
#' @rdname schema
#' @export
setGeneric("columns", function(x) {standardGeneric("columns") })
@@ -404,6 +412,10 @@ setGeneric("except", function(x, y) { standardGeneric("except") })
#' @export
setGeneric("filter", function(x, condition) { standardGeneric("filter") })
+#' @rdname groupBy
+#' @export
+setGeneric("group_by", function(x, ...) { standardGeneric("group_by") })
+
#' @rdname DataFrame
#' @export
setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") })
@@ -424,7 +436,11 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") })
#' @export
setGeneric("limit", function(x, num) {standardGeneric("limit") })
-#' @rdname sortDF
+#' @rdname withColumn
+#' @export
+setGeneric("mutate", function(x, ...) {standardGeneric("mutate") })
+
+#' @rdname arrange
#' @export
setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") })
@@ -432,12 +448,23 @@ setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") })
#' @export
setGeneric("printSchema", function(x) { standardGeneric("printSchema") })
+#' @rdname withColumnRenamed
+#' @export
+setGeneric("rename", function(x, ...) { standardGeneric("rename") })
+
#' @rdname registerTempTable
#' @export
setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") })
#' @rdname sampleDF
#' @export
+setGeneric("sample_frac",
+ function(x, withReplacement, fraction, seed) {
+ standardGeneric("sample_frac")
+ })
+
+#' @rdname sampleDF
+#' @export
setGeneric("sampleDF",
function(x, withReplacement, fraction, seed) {
standardGeneric("sampleDF")
@@ -473,9 +500,9 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr")
#' @export
setGeneric("showDF", function(x,...) { standardGeneric("showDF") })
-#' @rdname sortDF
+#' @rdname agg
#' @export
-setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") })
+setGeneric("summarize", function(x,...) { standardGeneric("summarize") })
# @rdname tojson
# @export
@@ -566,6 +593,14 @@ setGeneric("lower", function(x) { standardGeneric("lower") })
#' @rdname column
#' @export
+setGeneric("n", function(x) { standardGeneric("n") })
+
+#' @rdname column
+#' @export
+setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
+
+#' @rdname column
+#' @export
setGeneric("rlike", function(x, ...) { standardGeneric("rlike") })
#' @rdname column
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 02237b3672..5a7a8a2cab 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -56,6 +56,7 @@ setMethod("show", "GroupedData",
#'
#' @param x a GroupedData
#' @return a DataFrame
+#' @rdname agg
#' @export
#' @examples
#' \dontrun{
@@ -83,8 +84,6 @@ setMethod("count",
#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)'
#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum
#' }
-setGeneric("agg", function (x, ...) { standardGeneric("agg") })
-
setMethod("agg",
signature(x = "GroupedData"),
function(x, ...) {
@@ -112,6 +111,13 @@ setMethod("agg",
dataFrame(sdf)
})
+#' @rdname agg
+#' @aliases agg
+setMethod("summarize",
+ signature(x = "GroupedData"),
+ function(x, ...) {
+ agg(x, ...)
+ })
# sum/mean/avg/min/max
methods <- c("sum", "mean", "avg", "min", "max")
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 7a42e289fc..dbb535e245 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -428,6 +428,10 @@ test_that("sampleDF on a DataFrame", {
expect_true(inherits(sampled, "DataFrame"))
sampled2 <- sampleDF(df, FALSE, 0.1)
expect_true(count(sampled2) < 3)
+
+ # Also test sample_frac
+ sampled3 <- sample_frac(df, FALSE, 0.1)
+ expect_true(count(sampled3) < 3)
})
test_that("select operators", {
@@ -533,6 +537,7 @@ test_that("column functions", {
c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c)
c3 <- lower(c) + upper(c) + first(c) + last(c)
c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string")
+ c5 <- n(c) + n_distinct(c)
})
test_that("string operators", {
@@ -557,6 +562,13 @@ test_that("group by", {
expect_true(inherits(df2, "DataFrame"))
expect_true(3 == count(df2))
+ # Also test group_by, summarize, mean
+ gd1 <- group_by(df, "name")
+ expect_true(inherits(gd1, "GroupedData"))
+ df_summarized <- summarize(gd, mean_age = mean(df$age))
+ expect_true(inherits(df_summarized, "DataFrame"))
+ expect_true(3 == count(df_summarized))
+
df3 <- agg(gd, age = "sum")
expect_true(inherits(df3, "DataFrame"))
expect_true(3 == count(df3))
@@ -573,12 +585,12 @@ test_that("group by", {
expect_true(3 == count(max(gd, "age")))
})
-test_that("sortDF() and orderBy() on a DataFrame", {
+test_that("arrange() and orderBy() on a DataFrame", {
df <- jsonFile(sqlCtx, jsonPath)
- sorted <- sortDF(df, df$age)
+ sorted <- arrange(df, df$age)
expect_true(collect(sorted)[1,2] == "Michael")
- sorted2 <- sortDF(df, "name")
+ sorted2 <- arrange(df, "name")
expect_true(collect(sorted2)[2,"age"] == 19)
sorted3 <- orderBy(df, asc(df$age))
@@ -659,17 +671,17 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", {
writeLines(lines, jsonPath2)
df2 <- loadDF(sqlCtx, jsonPath2, "json")
- unioned <- sortDF(unionAll(df, df2), df$age)
+ unioned <- arrange(unionAll(df, df2), df$age)
expect_true(inherits(unioned, "DataFrame"))
expect_true(count(unioned) == 6)
expect_true(first(unioned)$name == "Michael")
- excepted <- sortDF(except(df, df2), desc(df$age))
+ excepted <- arrange(except(df, df2), desc(df$age))
expect_true(inherits(unioned, "DataFrame"))
expect_true(count(excepted) == 2)
expect_true(first(excepted)$name == "Justin")
- intersected <- sortDF(intersect(df, df2), df$age)
+ intersected <- arrange(intersect(df, df2), df$age)
expect_true(inherits(unioned, "DataFrame"))
expect_true(count(intersected) == 1)
expect_true(first(intersected)$name == "Andy")
@@ -687,6 +699,18 @@ test_that("withColumn() and withColumnRenamed()", {
expect_true(columns(newDF2)[1] == "newerAge")
})
+test_that("mutate() and rename()", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ newDF <- mutate(df, newAge = df$age + 2)
+ expect_true(length(columns(newDF)) == 3)
+ expect_true(columns(newDF)[3] == "newAge")
+ expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32)
+
+ newDF2 <- rename(df, newerAge = df$age)
+ expect_true(length(columns(newDF2)) == 2)
+ expect_true(columns(newDF2)[1] == "newerAge")
+})
+
test_that("saveDF() on DataFrame and works with parquetFile", {
df <- jsonFile(sqlCtx, jsonPath)
saveDF(df, parquetPath, "parquet", mode="overwrite")