about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- R/pkg/R/DataFrame.R 60
-rw-r--r-- R/pkg/inst/tests/testthat/test_sparkSQL.R 18
2 files changed, 69 insertions, 9 deletions
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 48ac1b06f6..a741fdf709 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1431,11 +1431,11 @@ setMethod("withColumn",
#' Mutate
#'
-#' Return a new SparkDataFrame with the specified columns added.
+#' Return a new SparkDataFrame with the specified columns added or replaced.
#'
#' @param .data A SparkDataFrame
#' @param col a named argument of the form name = col
-#' @return A new SparkDataFrame with the new columns added.
+#' @return A new SparkDataFrame with the new columns added or replaced.
#' @family SparkDataFrame functions
#' @rdname mutate
#' @name mutate
@@ -1450,23 +1450,65 @@ setMethod("withColumn",
#' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2)
#' names(newDF) # Will contain newCol, newCol2
#' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2)
+#'
+#' df <- createDataFrame(sqlContext,
+#' list(list("Andy", 30L), list("Justin", 19L)), c("name", "age"))
+#' # Replace the "age" column
+#' df1 <- mutate(df, age = df$age + 1L)
#' }
setMethod("mutate",
signature(.data = "SparkDataFrame"),
function(.data, ...) {
x <- .data
cols <- list(...)
- stopifnot(length(cols) > 0)
- stopifnot(class(cols[[1]]) == "Column")
+ if (length(cols) <= 0) {
+ return(x)
+ }
+
+ lapply(cols, function(col) {
+ stopifnot(class(col) == "Column")
+ })
+
+ # Check if there is any duplicated column name in the DataFrame
+ dfCols <- columns(x)
+ if (length(unique(dfCols)) != length(dfCols)) {
+ stop("Error: found duplicated column name in the DataFrame")
+ }
+
+ # TODO: simplify the implementation of this method after SPARK-12225 is resolved.
+
+ # For named arguments, use the argument names as the column names
+ # For unnamed arguments, use the deparsed argument expressions as the column names
+ args <- sapply(substitute(list(...))[-1], deparse)
ns <- names(cols)
if (!is.null(ns)) {
- for (n in ns) {
- if (n != "") {
- cols[[n]] <- alias(cols[[n]], n)
+ lapply(seq_along(args), function(i) {
+ if (ns[[i]] != "") {
+ args[[i]] <<- ns[[i]]
}
- }
+ })
+ }
+ ns <- args
+
+ # The last column of the same name in the specified columns takes effect
+ deDupCols <- list()
+ for (i in 1:length(cols)) {
+ deDupCols[[ns[[i]]]] <- alias(cols[[i]], ns[[i]])
}
- do.call(select, c(x, x$"*", cols))
+
+ # Construct the column list for projection
+ colList <- lapply(dfCols, function(col) {
+ if (!is.null(deDupCols[[col]])) {
+ # Replace existing column
+ tmpCol <- deDupCols[[col]]
+ deDupCols[[col]] <<- NULL
+ tmpCol
+ } else {
+ col(col)
+ }
+ })
+
+ do.call(select, c(x, colList, deDupCols))
})
#' @export
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 95d6cb8875..7058265ea3 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1581,6 +1581,24 @@ test_that("mutate(), transform(), rename() and names()", {
expect_equal(columns(newDF)[3], "newAge")
expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32)
+ newDF <- mutate(df, age = df$age + 2, newAge = df$age + 3)
+ expect_equal(length(columns(newDF)), 3)
+ expect_equal(columns(newDF)[3], "newAge")
+ expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 33)
+ expect_equal(first(filter(newDF, df$name != "Michael"))$age, 32)
+
+ newDF <- mutate(df, age = df$age + 2, newAge = df$age + 3,
+ age = df$age + 4, newAge = df$age + 5)
+ expect_equal(length(columns(newDF)), 3)
+ expect_equal(columns(newDF)[3], "newAge")
+ expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 35)
+ expect_equal(first(filter(newDF, df$name != "Michael"))$age, 34)
+
+ newDF <- mutate(df, df$age + 3)
+ expect_equal(length(columns(newDF)), 3)
+ expect_equal(columns(newDF)[[3]], "df$age + 3")
+ expect_equal(first(filter(newDF, df$name != "Michael"))[[3]], 33)
+
newDF2 <- rename(df, newerAge = df$age)
expect_equal(length(columns(newDF2)), 2)
expect_equal(columns(newDF2)[1], "newerAge")