aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorFelix Cheung <felixcheung_m@hotmail.com>2017-01-11 08:29:09 -0800
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2017-01-11 08:29:09 -0800
commitd749c06677c2fd383733337f1c00f542da122b8d (patch)
tree775d6750789d1a76988bca912de0b1bca8bb7ef4 /R
parent4239a1081ad96a503fbf9277e42b97422bb8af3e (diff)
downloadspark-d749c06677c2fd383733337f1c00f542da122b8d.tar.gz
spark-d749c06677c2fd383733337f1c00f542da122b8d.tar.bz2
spark-d749c06677c2fd383733337f1c00f542da122b8d.zip
[SPARK-19130][SPARKR] Support setting literal value as column implicitly
## What changes were proposed in this pull request? ``` df$foo <- 1 ``` instead of ``` df$foo <- lit(1) ``` ## How was this patch tested? unit tests Author: Felix Cheung <felixcheung_m@hotmail.com> Closes #16510 from felixcheung/rlitcol.
Diffstat (limited to 'R')
-rw-r--r--R/pkg/R/DataFrame.R22
-rw-r--r--R/pkg/R/utils.R4
-rw-r--r--R/pkg/inst/tests/testthat/test_sparkSQL.R18
3 files changed, 39 insertions, 5 deletions
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index c56648a8c4..3d912c9fa3 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1727,14 +1727,21 @@ setMethod("$", signature(x = "SparkDataFrame"),
getColumn(x, name)
})
-#' @param value a Column or \code{NULL}. If \code{NULL}, the specified Column is dropped.
+#' @param value a Column or an atomic vector in the length of 1 as literal value, or \code{NULL}.
+#' If \code{NULL}, the specified Column is dropped.
#' @rdname select
#' @name $<-
#' @aliases $<-,SparkDataFrame-method
#' @note $<- since 1.4.0
setMethod("$<-", signature(x = "SparkDataFrame"),
function(x, name, value) {
- stopifnot(class(value) == "Column" || is.null(value))
+ if (class(value) != "Column" && !is.null(value)) {
+ if (isAtomicLengthOne(value)) {
+ value <- lit(value)
+ } else {
+ stop("value must be a Column, literal value as atomic in length of 1, or NULL")
+ }
+ }
if (is.null(value)) {
nx <- drop(x, name)
@@ -1947,10 +1954,10 @@ setMethod("selectExpr",
#'
#' @param x a SparkDataFrame.
#' @param colName a column name.
-#' @param col a Column expression.
+#' @param col a Column expression, or an atomic vector in the length of 1 as literal value.
#' @return A SparkDataFrame with the new column added or the existing column replaced.
#' @family SparkDataFrame functions
-#' @aliases withColumn,SparkDataFrame,character,Column-method
+#' @aliases withColumn,SparkDataFrame,character-method
#' @rdname withColumn
#' @name withColumn
#' @seealso \link{rename} \link{mutate}
@@ -1963,11 +1970,16 @@ setMethod("selectExpr",
#' newDF <- withColumn(df, "newCol", df$col1 * 5)
#' # Replace an existing column
#' newDF2 <- withColumn(newDF, "newCol", newDF$col1)
+#' newDF3 <- withColumn(newDF, "newCol", 42)
#' }
#' @note withColumn since 1.4.0
setMethod("withColumn",
- signature(x = "SparkDataFrame", colName = "character", col = "Column"),
+ signature(x = "SparkDataFrame", colName = "character"),
function(x, colName, col) {
+ if (class(col) != "Column") {
+ if (!isAtomicLengthOne(col)) stop("Literal value must be atomic in length of 1")
+ col <- lit(col)
+ }
sdf <- callJMethod(x@sdf, "withColumn", colName, col@jc)
dataFrame(sdf)
})
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 1283449f35..74b3e502eb 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -863,3 +863,7 @@ basenameSansExtFromUrl <- function(url) {
# then, strip extension by the last '.'
sub("([^.]+)\\.[[:alnum:]]+$", "\\1", filename)
}
+
+isAtomicLengthOne <- function(x) {
+ is.atomic(x) && length(x) == 1
+}
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index c3f0310c75..3e8b96a513 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1001,6 +1001,17 @@ test_that("select operators", {
expect_equal(columns(df), c("name", "age", "age2"))
expect_equal(count(where(df, df$age2 == df$age * 2)), 2)
+ df$age2 <- 21
+ expect_equal(columns(df), c("name", "age", "age2"))
+ expect_equal(count(where(df, df$age2 == 21)), 3)
+
+ df$age2 <- c(22)
+ expect_equal(columns(df), c("name", "age", "age2"))
+ expect_equal(count(where(df, df$age2 == 22)), 3)
+
+ expect_error(df$age3 <- c(22, NA),
+ "value must be a Column, literal value as atomic in length of 1, or NULL")
+
# Test parameter drop
expect_equal(class(df[, 1]) == "SparkDataFrame", T)
expect_equal(class(df[, 1, drop = T]) == "Column", T)
@@ -1778,6 +1789,13 @@ test_that("withColumn() and withColumnRenamed()", {
expect_equal(length(columns(newDF)), 2)
expect_equal(first(filter(newDF, df$name != "Michael"))$age, 32)
+ newDF <- withColumn(df, "age", 18)
+ expect_equal(length(columns(newDF)), 2)
+ expect_equal(first(newDF)$age, 18)
+
+ expect_error(withColumn(df, "age", list("a")),
+ "Literal value must be atomic in length of 1")
+
newDF2 <- withColumnRenamed(df, "age", "newerAge")
expect_equal(length(columns(newDF2)), 2)
expect_equal(columns(newDF2)[1], "newerAge")