diff options
Diffstat (limited to 'R')
-rw-r--r-- | R/pkg/R/DataFrame.R | 60 | ||||
-rw-r--r-- | R/pkg/inst/tests/testthat/test_sparkSQL.R | 18 |
2 files changed, 69 insertions, 9 deletions
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 48ac1b06f6..a741fdf709 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1431,11 +1431,11 @@ setMethod("withColumn", #' Mutate #' -#' Return a new SparkDataFrame with the specified columns added. +#' Return a new SparkDataFrame with the specified columns added or replaced. #' #' @param .data A SparkDataFrame #' @param col a named argument of the form name = col -#' @return A new SparkDataFrame with the new columns added. +#' @return A new SparkDataFrame with the new columns added or replaced. #' @family SparkDataFrame functions #' @rdname mutate #' @name mutate @@ -1450,23 +1450,65 @@ setMethod("withColumn", #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 #' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2) +#' +#' df <- createDataFrame(sqlContext, +#' list(list("Andy", 30L), list("Justin", 19L)), c("name", "age")) +#' # Replace the "age" column +#' df1 <- mutate(df, age = df$age + 1L) #' } setMethod("mutate", signature(.data = "SparkDataFrame"), function(.data, ...) { x <- .data cols <- list(...) - stopifnot(length(cols) > 0) - stopifnot(class(cols[[1]]) == "Column") + if (length(cols) <= 0) { + return(x) + } + + lapply(cols, function(col) { + stopifnot(class(col) == "Column") + }) + + # Check if there is any duplicated column name in the DataFrame + dfCols <- columns(x) + if (length(unique(dfCols)) != length(dfCols)) { + stop("Error: found duplicated column name in the DataFrame") + } + + # TODO: simplify the implementation of this method after SPARK-12225 is resolved. + + # For named arguments, use the names for arguments as the column names + # For unnamed arguments, use the argument symbols as the column names + args <- sapply(substitute(list(...))[-1], deparse) ns <- names(cols) if (!is.null(ns)) { - for (n in ns) { - if (n != "") { - cols[[n]] <- alias(cols[[n]], n) + lapply(seq_along(args), function(i) { + if (ns[[i]] != "") { + args[[i]] <<- ns[[i]] } - } + }) + } + ns <- args + + # The last column of the same name in the specific columns takes effect + deDupCols <- list() + for (i in 1:length(cols)) { + deDupCols[[ns[[i]]]] <- alias(cols[[i]], ns[[i]]) } - do.call(select, c(x, x$"*", cols)) + + # Construct the column list for projection + colList <- lapply(dfCols, function(col) { + if (!is.null(deDupCols[[col]])) { + # Replace existing column + tmpCol <- deDupCols[[col]] + deDupCols[[col]] <<- NULL + tmpCol + } else { + col(col) + } + }) + + do.call(select, c(x, colList, deDupCols)) }) #' @export diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 95d6cb8875..7058265ea3 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1581,6 +1581,24 @@ test_that("mutate(), transform(), rename() and names()", { expect_equal(columns(newDF)[3], "newAge") expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) + newDF <- mutate(df, age = df$age + 2, newAge = df$age + 3) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 33) + expect_equal(first(filter(newDF, df$name != "Michael"))$age, 32) + + newDF <- mutate(df, age = df$age + 2, newAge = df$age + 3, + age = df$age + 4, newAge = df$age + 5) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 35) + expect_equal(first(filter(newDF, df$name != "Michael"))$age, 34) + + newDF <- mutate(df, df$age + 3) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[[3]], "df$age + 3") + expect_equal(first(filter(newDF, df$name != "Michael"))[[3]], 33) + newDF2 <- rename(df, newerAge = df$age) expect_equal(length(columns(newDF2)), 2) expect_equal(columns(newDF2)[1], "newerAge") |