aboutsummaryrefslogtreecommitdiff
path: root/R/pkg
diff options
context:
space:
mode:
authorSun Rui <rui.sun@intel.com>2016-04-28 09:33:58 -0700
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2016-04-28 09:33:58 -0700
commit9e785079b6ed4ea691c3c14c762a7f73fb6254bf (patch)
treea0c4cdb225d176343cb3e83cb8b0c72da3a0e799 /R/pkg
parent23256be0d0846d4eb188a4d1cae6e3f261248153 (diff)
downloadspark-9e785079b6ed4ea691c3c14c762a7f73fb6254bf.tar.gz
spark-9e785079b6ed4ea691c3c14c762a7f73fb6254bf.tar.bz2
spark-9e785079b6ed4ea691c3c14c762a7f73fb6254bf.zip
[SPARK-12235][SPARKR] Enhance mutate() to support replace existing columns.
Make the behavior of mutate more consistent with that in dplyr, besides support for replacing existing columns. 1. Throw error message when there are duplicated column names in the DataFrame being mutated. 2. when there are duplicated column names in specified columns by arguments, the last column of the same name takes effect. Author: Sun Rui <rui.sun@intel.com> Closes #10220 from sun-rui/SPARK-12235.
Diffstat (limited to 'R/pkg')
-rw-r--r--R/pkg/R/DataFrame.R60
-rw-r--r--R/pkg/inst/tests/testthat/test_sparkSQL.R18
2 files changed, 69 insertions, 9 deletions
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 48ac1b06f6..a741fdf709 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1431,11 +1431,11 @@ setMethod("withColumn",
#' Mutate
#'
-#' Return a new SparkDataFrame with the specified columns added.
+#' Return a new SparkDataFrame with the specified columns added or replaced.
#'
#' @param .data A SparkDataFrame
#' @param col a named argument of the form name = col
-#' @return A new SparkDataFrame with the new columns added.
+#' @return A new SparkDataFrame with the new columns added or replaced.
#' @family SparkDataFrame functions
#' @rdname mutate
#' @name mutate
@@ -1450,23 +1450,65 @@ setMethod("withColumn",
#' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2)
#' names(newDF) # Will contain newCol, newCol2
#' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2)
+#'
+#' df <- createDataFrame(sqlContext,
+#' list(list("Andy", 30L), list("Justin", 19L)), c("name", "age"))
+#' # Replace the "age" column
+#' df1 <- mutate(df, age = df$age + 1L)
#' }
setMethod("mutate",
signature(.data = "SparkDataFrame"),
function(.data, ...) {
x <- .data
cols <- list(...)
- stopifnot(length(cols) > 0)
- stopifnot(class(cols[[1]]) == "Column")
+ if (length(cols) <= 0) {
+ return(x)
+ }
+
+ lapply(cols, function(col) {
+ stopifnot(class(col) == "Column")
+ })
+
+ # Check if there is any duplicated column name in the DataFrame
+ dfCols <- columns(x)
+ if (length(unique(dfCols)) != length(dfCols)) {
+ stop("Error: found duplicated column name in the DataFrame")
+ }
+
+ # TODO: simplify the implementation of this method after SPARK-12225 is resolved.
+
+ # For named arguments, use the names for arguments as the column names
+ # For unnamed arguments, use the argument symbols as the column names
+ args <- sapply(substitute(list(...))[-1], deparse)
ns <- names(cols)
if (!is.null(ns)) {
- for (n in ns) {
- if (n != "") {
- cols[[n]] <- alias(cols[[n]], n)
+ lapply(seq_along(args), function(i) {
+ if (ns[[i]] != "") {
+ args[[i]] <<- ns[[i]]
}
- }
+ })
+ }
+ ns <- args
+
+ # The last column of the same name in the specific columns takes effect
+ deDupCols <- list()
+ for (i in 1:length(cols)) {
+ deDupCols[[ns[[i]]]] <- alias(cols[[i]], ns[[i]])
}
- do.call(select, c(x, x$"*", cols))
+
+ # Construct the column list for projection
+ colList <- lapply(dfCols, function(col) {
+ if (!is.null(deDupCols[[col]])) {
+ # Replace existing column
+ tmpCol <- deDupCols[[col]]
+ deDupCols[[col]] <<- NULL
+ tmpCol
+ } else {
+ col(col)
+ }
+ })
+
+ do.call(select, c(x, colList, deDupCols))
})
#' @export
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 95d6cb8875..7058265ea3 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1581,6 +1581,24 @@ test_that("mutate(), transform(), rename() and names()", {
expect_equal(columns(newDF)[3], "newAge")
expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32)
+ newDF <- mutate(df, age = df$age + 2, newAge = df$age + 3)
+ expect_equal(length(columns(newDF)), 3)
+ expect_equal(columns(newDF)[3], "newAge")
+ expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 33)
+ expect_equal(first(filter(newDF, df$name != "Michael"))$age, 32)
+
+ newDF <- mutate(df, age = df$age + 2, newAge = df$age + 3,
+ age = df$age + 4, newAge = df$age + 5)
+ expect_equal(length(columns(newDF)), 3)
+ expect_equal(columns(newDF)[3], "newAge")
+ expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 35)
+ expect_equal(first(filter(newDF, df$name != "Michael"))$age, 34)
+
+ newDF <- mutate(df, df$age + 3)
+ expect_equal(length(columns(newDF)), 3)
+ expect_equal(columns(newDF)[[3]], "df$age + 3")
+ expect_equal(first(filter(newDF, df$name != "Michael"))[[3]], 33)
+
newDF2 <- rename(df, newerAge = df$age)
expect_equal(length(columns(newDF2)), 2)
expect_equal(columns(newDF2)[1], "newerAge")