aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
author adrian555 <wzhuang@us.ibm.com> 2015-11-05 14:47:38 -0800
committer Shivaram Venkataraman <shivaram@cs.berkeley.edu> 2015-11-05 14:47:38 -0800
commit b9455d1f1810e1e3f472014f665ad3ad3122bcc0 (patch)
tree d82b4ce2cbbe3833afcf77519a3470c5a5d98919 /R
parent 8a5314efd19fb8f8a194a373fd994b954cc1fd47 (diff)
download spark-b9455d1f1810e1e3f472014f665ad3ad3122bcc0.tar.gz
spark-b9455d1f1810e1e3f472014f665ad3ad3122bcc0.tar.bz2
spark-b9455d1f1810e1e3f472014f665ad3ad3122bcc0.zip
[SPARK-11260][SPARKR] with() function support
Author: adrian555 <wzhuang@us.ibm.com> Author: Adrian Zhuang <adrian555@users.noreply.github.com> Closes #9443 from adrian555/with.
Diffstat (limited to 'R')
-rw-r--r-- R/pkg/NAMESPACE 1
-rw-r--r-- R/pkg/R/DataFrame.R 30
-rw-r--r-- R/pkg/R/generics.R 4
-rw-r--r-- R/pkg/R/utils.R 13
-rw-r--r-- R/pkg/inst/tests/test_sparkSQL.R 9
5 files changed, 51 insertions, 6 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index cd9537a265..56b8ed0bf2 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -83,6 +83,7 @@ exportMethods("arrange",
"unique",
"unpersist",
"where",
+ "with",
"withColumn",
"withColumnRenamed",
"write.df")
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index df5bc81371..44ce9414da 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2126,11 +2126,29 @@ setMethod("as.data.frame",
setMethod("attach",
signature(what = "DataFrame"),
function(what, pos = 2, name = deparse(substitute(what)), warn.conflicts = TRUE) {
- cols <- columns(what)
- stopifnot(length(cols) > 0)
- newEnv <- new.env()
- for (i in 1:length(cols)) {
- assign(x = cols[i], value = what[, cols[i]], envir = newEnv)
- }
+ newEnv <- assignNewEnv(what) # environment holding one binding per column of `what`
attach(newEnv, pos = pos, name = name, warn.conflicts = warn.conflicts) # newEnv is an environment, not a DataFrame, so this call does not re-enter this method
})
+
+#' Evaluate an R expression in an environment constructed from a DataFrame
+#'
+#' with() allows access to columns of a DataFrame by simply referring to
+#' their name: each column is bound under its name, and every other name in
+#' the expression is looked up in the frame that called with().
+#'
+#' @rdname with
+#' @param data (DataFrame) DataFrame to use for constructing an environment.
+#' @param expr (expression) Expression to evaluate.
+#' @param ... arguments to be passed to future methods.
+#' @return The value of the evaluated expression.
+#' @examples
+#' \dontrun{
+#' with(irisDf, nrow(Sepal_Width))
+#' }
+#' @seealso \link{attach}
+setMethod("with", signature(data = "DataFrame"),
+          function(data, expr, ...) {
+            cols <- as.list(assignNewEnv(data), all.names = TRUE)
+            # eval() ignores enclos when envir is an environment, so pass the bindings as a list.
+            eval(substitute(expr), envir = cols, enclos = parent.frame())
+          })
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 0b35340e48..083d37fee2 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1043,3 +1043,7 @@ setGeneric("as.data.frame")
#' @rdname attach
#' @export
setGeneric("attach")
+
+#' @rdname with
+#' @export
+setGeneric("with") # turn the existing with() function into an S4 generic so a DataFrame method can be set
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 0b9e2957fe..db3b2c4bbd 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -623,3 +623,16 @@ convertNamedListToEnv <- function(namedList) {
}
env
}
+
+# Build the environment used by attach() and with(): one binding per DataFrame column.
+assignNewEnv <- function(data) {
+  stopifnot(inherits(data, "DataFrame")) # inherits() also accepts subclasses of DataFrame
+  cols <- columns(data)
+  stopifnot(length(cols) > 0) # a DataFrame without columns is rejected
+
+  env <- new.env()
+  for (i in seq_along(cols)) {
+    assign(x = cols[i], value = data[, cols[i]], envir = env) # bind the column under its own name
+  }
+  env
+} \ No newline at end of file
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index b4a4d03b26..816315b1e4 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -1494,6 +1494,15 @@ test_that("attach() on a DataFrame", {
expect_error(age)
})
+test_that("with() on a DataFrame", {
+  irisDF <- createDataFrame(sqlContext, iris)
+  expect_error(Sepal_Length) # the column name must not resolve outside with()
+  summaries <- with(irisDF, list(summary(Sepal_Length), summary(Sepal_Width)))
+  expect_equal(collect(summaries[[1]])[1, "Sepal_Length"], "150")
+  distinctLengths <- with(irisDF, distinct(Sepal_Length))
+  expect_equal(nrow(distinctLengths), 35)
+})
+
unlink(parquetPath)
unlink(jsonPath)
unlink(jsonPathNa)