From d749c06677c2fd383733337f1c00f542da122b8d Mon Sep 17 00:00:00 2001
From: Felix Cheung
Date: Wed, 11 Jan 2017 08:29:09 -0800
Subject: [SPARK-19130][SPARKR] Support setting literal value as column implicitly

## What changes were proposed in this pull request?

```
df$foo <- 1
```

instead of

```
df$foo <- lit(1)
```

## How was this patch tested?

unit tests

Author: Felix Cheung

Closes #16510 from felixcheung/rlitcol.
---
 R/pkg/R/DataFrame.R                       | 22 +++++++++++++++++-----
 R/pkg/R/utils.R                           |  4 ++++
 R/pkg/inst/tests/testthat/test_sparkSQL.R | 18 ++++++++++++++++++
 3 files changed, 39 insertions(+), 5 deletions(-)

(limited to 'R')

diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index c56648a8c4..3d912c9fa3 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1727,14 +1727,21 @@ setMethod("$", signature(x = "SparkDataFrame"),
             getColumn(x, name)
           })
 
-#' @param value a Column or \code{NULL}. If \code{NULL}, the specified Column is dropped.
+#' @param value a Column or an atomic vector in the length of 1 as literal value, or \code{NULL}.
+#'              If \code{NULL}, the specified Column is dropped.
 #' @rdname select
 #' @name $<-
 #' @aliases $<-,SparkDataFrame-method
 #' @note $<- since 1.4.0
 setMethod("$<-", signature(x = "SparkDataFrame"),
           function(x, name, value) {
-            stopifnot(class(value) == "Column" || is.null(value))
+            if (class(value) != "Column" && !is.null(value)) {
+              if (isAtomicLengthOne(value)) {
+                value <- lit(value)
+              } else {
+                stop("value must be a Column, literal value as atomic in length of 1, or NULL")
+              }
+            }
 
             if (is.null(value)) {
               nx <- drop(x, name)
@@ -1947,10 +1954,10 @@ setMethod("selectExpr",
 #'
 #' @param x a SparkDataFrame.
 #' @param colName a column name.
-#' @param col a Column expression.
+#' @param col a Column expression, or an atomic vector in the length of 1 as literal value.
 #' @return A SparkDataFrame with the new column added or the existing column replaced.
 #' @family SparkDataFrame functions
-#' @aliases withColumn,SparkDataFrame,character,Column-method
+#' @aliases withColumn,SparkDataFrame,character-method
 #' @rdname withColumn
 #' @name withColumn
 #' @seealso \link{rename} \link{mutate}
@@ -1963,11 +1970,16 @@ setMethod("selectExpr",
 #' newDF <- withColumn(df, "newCol", df$col1 * 5)
 #' # Replace an existing column
 #' newDF2 <- withColumn(newDF, "newCol", newDF$col1)
+#' newDF3 <- withColumn(newDF, "newCol", 42)
 #' }
 #' @note withColumn since 1.4.0
 setMethod("withColumn",
-          signature(x = "SparkDataFrame", colName = "character", col = "Column"),
+          signature(x = "SparkDataFrame", colName = "character"),
           function(x, colName, col) {
+            if (class(col) != "Column") {
+              if (!isAtomicLengthOne(col)) stop("Literal value must be atomic in length of 1")
+              col <- lit(col)
+            }
             sdf <- callJMethod(x@sdf, "withColumn", colName, col@jc)
             dataFrame(sdf)
           })
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 1283449f35..74b3e502eb 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -863,3 +863,7 @@ basenameSansExtFromUrl <- function(url) {
   # then, strip extension by the last '.'
   sub("([^.]+)\\.[[:alnum:]]+$", "\\1", filename)
 }
+
+isAtomicLengthOne <- function(x) {
+  is.atomic(x) && length(x) == 1
+}
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index c3f0310c75..3e8b96a513 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1001,6 +1001,17 @@ test_that("select operators", {
   expect_equal(columns(df), c("name", "age", "age2"))
   expect_equal(count(where(df, df$age2 == df$age * 2)), 2)
 
+  df$age2 <- 21
+  expect_equal(columns(df), c("name", "age", "age2"))
+  expect_equal(count(where(df, df$age2 == 21)), 3)
+
+  df$age2 <- c(22)
+  expect_equal(columns(df), c("name", "age", "age2"))
+  expect_equal(count(where(df, df$age2 == 22)), 3)
+
+  expect_error(df$age3 <- c(22, NA),
+               "value must be a Column, literal value as atomic in length of 1, or NULL")
+
   # Test parameter drop
   expect_equal(class(df[, 1]) == "SparkDataFrame", T)
   expect_equal(class(df[, 1, drop = T]) == "Column", T)
@@ -1778,6 +1789,13 @@ test_that("withColumn() and withColumnRenamed()", {
   expect_equal(length(columns(newDF)), 2)
   expect_equal(first(filter(newDF, df$name != "Michael"))$age, 32)
 
+  newDF <- withColumn(df, "age", 18)
+  expect_equal(length(columns(newDF)), 2)
+  expect_equal(first(newDF)$age, 18)
+
+  expect_error(withColumn(df, "age", list("a")),
+               "Literal value must be atomic in length of 1")
+
   newDF2 <- withColumnRenamed(df, "age", "newerAge")
   expect_equal(length(columns(newDF2)), 2)
   expect_equal(columns(newDF2)[1], "newerAge")
-- 
cgit v1.2.3