diff options
author | Sun Rui <rui.sun@intel.com> | 2016-04-28 09:33:58 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@cs.berkeley.edu> | 2016-04-28 09:33:58 -0700 |
commit | 9e785079b6ed4ea691c3c14c762a7f73fb6254bf (patch) | |
tree | a0c4cdb225d176343cb3e83cb8b0c72da3a0e799 /R/pkg | |
parent | 23256be0d0846d4eb188a4d1cae6e3f261248153 (diff) | |
download | spark-9e785079b6ed4ea691c3c14c762a7f73fb6254bf.tar.gz spark-9e785079b6ed4ea691c3c14c762a7f73fb6254bf.tar.bz2 spark-9e785079b6ed4ea691c3c14c762a7f73fb6254bf.zip |
[SPARK-12235][SPARKR] Enhance mutate() to support replace existing columns.
Make the behavior of mutate more consistent with that in dplyr, besides support for replacing existing columns.
1. Throw error message when there are duplicated column names in the DataFrame being mutated.
2. when there are duplicated column names in specified columns by arguments, the last column of the same name takes effect.
Author: Sun Rui <rui.sun@intel.com>
Closes #10220 from sun-rui/SPARK-12235.
Diffstat (limited to 'R/pkg')
-rw-r--r-- | R/pkg/R/DataFrame.R | 60 | ||||
-rw-r--r-- | R/pkg/inst/tests/testthat/test_sparkSQL.R | 18 |
2 files changed, 69 insertions, 9 deletions
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 48ac1b06f6..a741fdf709 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1431,11 +1431,11 @@ setMethod("withColumn", #' Mutate #' -#' Return a new SparkDataFrame with the specified columns added. +#' Return a new SparkDataFrame with the specified columns added or replaced. #' #' @param .data A SparkDataFrame #' @param col a named argument of the form name = col -#' @return A new SparkDataFrame with the new columns added. +#' @return A new SparkDataFrame with the new columns added or replaced. #' @family SparkDataFrame functions #' @rdname mutate #' @name mutate @@ -1450,23 +1450,65 @@ setMethod("withColumn", #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 #' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2) +#' +#' df <- createDataFrame(sqlContext, +#' list(list("Andy", 30L), list("Justin", 19L)), c("name", "age")) +#' # Replace the "age" column +#' df1 <- mutate(df, age = df$age + 1L) #' } setMethod("mutate", signature(.data = "SparkDataFrame"), function(.data, ...) { x <- .data cols <- list(...) - stopifnot(length(cols) > 0) - stopifnot(class(cols[[1]]) == "Column") + if (length(cols) <= 0) { + return(x) + } + + lapply(cols, function(col) { + stopifnot(class(col) == "Column") + }) + + # Check if there is any duplicated column name in the DataFrame + dfCols <- columns(x) + if (length(unique(dfCols)) != length(dfCols)) { + stop("Error: found duplicated column name in the DataFrame") + } + + # TODO: simplify the implementation of this method after SPARK-12225 is resolved. + + # For named arguments, use the names for arguments as the column names + # For unnamed arguments, use the argument symbols as the column names + args <- sapply(substitute(list(...))[-1], deparse) ns <- names(cols) if (!is.null(ns)) { - for (n in ns) { - if (n != "") { - cols[[n]] <- alias(cols[[n]], n) + lapply(seq_along(args), function(i) { + if (ns[[i]] != "") { + args[[i]] <<- ns[[i]] } - } + }) + } + ns <- args + + # The last column of the same name in the specific columns takes effect + deDupCols <- list() + for (i in 1:length(cols)) { + deDupCols[[ns[[i]]]] <- alias(cols[[i]], ns[[i]]) } - do.call(select, c(x, x$"*", cols)) + + # Construct the column list for projection + colList <- lapply(dfCols, function(col) { + if (!is.null(deDupCols[[col]])) { + # Replace existing column + tmpCol <- deDupCols[[col]] + deDupCols[[col]] <<- NULL + tmpCol + } else { + col(col) + } + }) + + do.call(select, c(x, colList, deDupCols)) }) #' @export diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 95d6cb8875..7058265ea3 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1581,6 +1581,24 @@ test_that("mutate(), transform(), rename() and names()", { expect_equal(columns(newDF)[3], "newAge") expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) + newDF <- mutate(df, age = df$age + 2, newAge = df$age + 3) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 33) + expect_equal(first(filter(newDF, df$name != "Michael"))$age, 32) + + newDF <- mutate(df, age = df$age + 2, newAge = df$age + 3, + age = df$age + 4, newAge = df$age + 5) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 35) + expect_equal(first(filter(newDF, df$name != "Michael"))$age, 34) + + newDF <- mutate(df, df$age + 3) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[[3]], "df$age + 3") + expect_equal(first(filter(newDF, df$name != "Michael"))[[3]], 33) + newDF2 <- rename(df, newerAge = df$age) expect_equal(length(columns(newDF2)), 2) expect_equal(columns(newDF2)[1], "newerAge") |