author    Oscar D. Lara Yejas <odlaraye@oscars-mbp.usca.ibm.com>  2016-04-27 15:47:54 -0700
committer Shivaram Venkataraman <shivaram@cs.berkeley.edu>       2016-04-27 15:47:54 -0700
commit    e4bfb4aa7382cb9c5e4eb7e2211551d5da716a61
tree      58d4303824aca6fec6f9f6311f2dcf7a3cb1bd4e /R/pkg
parent    37575115b98fdc9ebadb2ebcbcd9907a3af1076c
[SPARK-13436][SPARKR] Added parameter drop to subsetting operator [
Added parameter `drop` to the subsetting operator `[`. This is useful for getting a Column from a DataFrame given its name, which base R supports. In R:

```
> name <- "Sepal_Length"
> class(iris[, name])
[1] "numeric"
```

Currently, in SparkR:

```
> name <- "Sepal_Length"
> class(irisDF[, name])
[1] "DataFrame"
```

The previous code returns a DataFrame, which is inconsistent with R's behavior; SparkR should return a Column instead. Currently, the only way for a user to obtain a Column from a column name stored in a character variable is `eval(parse(text = x))`, where x is the string `"irisDF$Sepal_Length"`. That is pretty hacky. `SparkR:::getColumn()` is another choice, but there is no good reason to externalize that method. Instead, following R's way of doing things, the proposed implementation allows this:

```
> name <- "Sepal_Length"
> class(irisDF[, name, drop=T])
[1] "Column"
> class(irisDF[, name, drop=F])
[1] "DataFrame"
```

This is consistent with R:

```
> name <- "Sepal_Length"
> class(iris[, name])
[1] "numeric"
> class(iris[, name, drop=F])
[1] "data.frame"
```

Author: Oscar D. Lara Yejas <odlaraye@oscars-mbp.usca.ibm.com>
Author: Oscar D. Lara Yejas <odlaraye@oscars-mbp.attlocal.net>

Closes #11318 from olarayej/SPARK-13436.
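As a quick illustration, here is a minimal SparkR session exercising the new parameter. This is a sketch only: it assumes the Spark 1.6/2.0-era setup where `sqlContext` is created via `sparkRSQL.init()`, and that `createDataFrame` sanitizes the dotted `iris` column names to underscores (e.g. `Sepal_Length`), as the commit message's examples suggest.

```r
library(SparkR)

# Assumed setup for this era of the API; context creation may differ
# in your environment.
sc <- sparkR.init(master = "local")
sqlContext <- sparkRSQL.init(sc)

# Column names containing dots are assumed to be sanitized to
# underscores (Sepal.Length -> Sepal_Length) on conversion to Spark.
irisDF <- createDataFrame(sqlContext, iris)

name <- "Sepal_Length"

# With this patch, drop controls whether a single selected column
# collapses to a Column or stays a one-column SparkDataFrame:
class(irisDF[, name, drop = TRUE])   # "Column"
class(irisDF[, name, drop = FALSE])  # "SparkDataFrame"

# drop composes with a filtering predicate in the first index...
class(irisDF[irisDF$Sepal_Length > 5, name, drop = TRUE])  # "Column"

# ...and with subset(), which now forwards drop to `[`:
class(subset(irisDF, irisDF$Sepal_Length > 5, name, drop = TRUE))  # "Column"
```

Note that, as the diff below shows, the SparkR default is `drop = FALSE`, so existing code that relies on `df[, j]` returning a SparkDataFrame keeps working; only callers that explicitly pass `drop = TRUE` get a Column. This is the opposite of base R's `[.data.frame` default, a deliberate trade of strict R consistency for backward compatibility.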
Diffstat (limited to 'R/pkg')
-rw-r--r--  R/pkg/R/DataFrame.R                        70
-rw-r--r--  R/pkg/R/utils.R                             2
-rw-r--r--  R/pkg/inst/tests/testthat/test_sparkSQL.R  24
3 files changed, 54 insertions(+), 42 deletions(-)
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 36aedfae86..48ac1b06f6 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1237,29 +1237,38 @@ setMethod("[[", signature(x = "SparkDataFrame", i = "numericOrcharacter"),
#' @rdname subset
#' @name [
-setMethod("[", signature(x = "SparkDataFrame", i = "missing"),
- function(x, i, j, ...) {
- if (is.numeric(j)) {
- cols <- columns(x)
- j <- cols[j]
- }
- if (length(j) > 1) {
- j <- as.list(j)
+setMethod("[", signature(x = "SparkDataFrame"),
+ function(x, i, j, ..., drop = F) {
+ # Perform filtering first if needed
+ filtered <- if (missing(i)) {
+ x
+ } else {
+ if (class(i) != "Column") {
+ stop(paste0("Expressions other than filtering predicates are not supported ",
+ "in the first parameter of extract operator [ or subset() method."))
+ }
+ filter(x, i)
}
- select(x, j)
- })
-#' @rdname subset
-#' @name [
-setMethod("[", signature(x = "SparkDataFrame", i = "Column"),
- function(x, i, j, ...) {
- # It could handle i as "character" but it seems confusing and not required
- # https://stat.ethz.ch/R-manual/R-devel/library/base/html/Extract.data.frame.html
- filtered <- filter(x, i)
- if (!missing(j)) {
- filtered[, j, ...]
- } else {
+ # If something is to be projected, then do so on the filtered SparkDataFrame
+ if (missing(j)) {
filtered
+ } else {
+ if (is.numeric(j)) {
+ cols <- columns(filtered)
+ j <- cols[j]
+ }
+ if (length(j) > 1) {
+ j <- as.list(j)
+ }
+ selected <- select(filtered, j)
+
+ # Acknowledge parameter drop. Return a Column or SparkDataFrame accordingly
+ if (ncol(selected) == 1 & drop == T) {
+ getColumn(selected, names(selected))
+ } else {
+ selected
+ }
}
})
@@ -1268,10 +1277,10 @@ setMethod("[", signature(x = "SparkDataFrame", i = "Column"),
#' Return subsets of SparkDataFrame according to given conditions
#' @param x A SparkDataFrame
#' @param subset (Optional) A logical expression to filter on rows
-#' @param select expression for the single Column or a list of columns to select from the
-#' SparkDataFrame
-#' @return A new SparkDataFrame containing only the rows that meet the condition with selected
-#' columns
+#' @param select expression for the single Column or a list of columns to select from the SparkDataFrame
+#' @param drop if TRUE, a Column will be returned if the resulting dataset has only one column.
+#' Otherwise, a SparkDataFrame will always be returned.
+#' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns
#' @export
#' @family SparkDataFrame functions
#' @rdname subset
@@ -1293,12 +1302,8 @@ setMethod("[", signature(x = "SparkDataFrame", i = "Column"),
#' subset(df, select = c(1,2))
#' }
setMethod("subset", signature(x = "SparkDataFrame"),
- function(x, subset, select, ...) {
- if (missing(subset)) {
- x[, select, ...]
- } else {
- x[subset, select, ...]
- }
+ function(x, subset, select, drop = F, ...) {
+ x[subset, select, drop = drop]
})
#' Select
@@ -2520,7 +2525,7 @@ setMethod("histogram",
}
# Filter NA values in the target column and remove all other columns
- df <- na.omit(df[, colname])
+ df <- na.omit(df[, colname, drop = F])
getColumn(df, colname)
} else if (class(col) == "Column") {
@@ -2552,8 +2557,7 @@ setMethod("histogram",
col
}
- # At this point, df only has one column: the one to compute the histogram from
- stats <- collect(describe(df[, colname]))
+ stats <- collect(describe(df[, colname, drop = F]))
min <- as.numeric(stats[4, 2])
max <- as.numeric(stats[5, 2])
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index ba7a611e64..bf67e231d5 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -632,7 +632,7 @@ assignNewEnv <- function(data) {
env <- new.env()
for (i in 1:length(cols)) {
- assign(x = cols[i], value = data[, cols[i]], envir = env)
+ assign(x = cols[i], value = data[, cols[i], drop = F], envir = env)
}
env
}
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 336068035e..95d6cb8875 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -822,9 +822,10 @@ test_that("select operators", {
expect_is(df[[2]], "Column")
expect_is(df[["age"]], "Column")
- expect_is(df[, 1], "SparkDataFrame")
- expect_equal(columns(df[, 1]), c("name"))
- expect_equal(columns(df[, "age"]), c("age"))
+ expect_is(df[, 1, drop = F], "SparkDataFrame")
+ expect_equal(columns(df[, 1, drop = F]), c("name"))
+ expect_equal(columns(df[, "age", drop = F]), c("age"))
+
df2 <- df[, c("age", "name")]
expect_is(df2, "SparkDataFrame")
expect_equal(columns(df2), c("age", "name"))
@@ -835,6 +836,13 @@ test_that("select operators", {
df$age2 <- df$age * 2
expect_equal(columns(df), c("name", "age", "age2"))
expect_equal(count(where(df, df$age2 == df$age * 2)), 2)
+
+ # Test parameter drop
+ expect_equal(class(df[, 1]) == "SparkDataFrame", T)
+ expect_equal(class(df[, 1, drop = T]) == "Column", T)
+ expect_equal(class(df[, 1, drop = F]) == "SparkDataFrame", T)
+ expect_equal(class(df[df$age > 4, 2, drop = T]) == "Column", T)
+ expect_equal(class(df[df$age > 4, 2, drop = F]) == "SparkDataFrame", T)
})
test_that("select with column", {
@@ -889,13 +897,13 @@ test_that("subsetting", {
expect_equal(columns(filtered), c("name", "age"))
expect_equal(collect(filtered)$name, "Andy")
- df2 <- df[df$age == 19, 1]
+ df2 <- df[df$age == 19, 1, drop = F]
expect_is(df2, "SparkDataFrame")
expect_equal(count(df2), 1)
expect_equal(columns(df2), c("name"))
expect_equal(collect(df2)$name, "Justin")
- df3 <- df[df$age > 20, 2]
+ df3 <- df[df$age > 20, 2, drop = F]
expect_equal(count(df3), 1)
expect_equal(columns(df3), c("age"))
@@ -911,7 +919,7 @@ test_that("subsetting", {
expect_equal(count(df6), 1)
expect_equal(columns(df6), c("name", "age"))
- df7 <- subset(df, select = "name")
+ df7 <- subset(df, select = "name", drop = F)
expect_equal(count(df7), 3)
expect_equal(columns(df7), c("name"))
@@ -1888,7 +1896,7 @@ test_that("attach() on a DataFrame", {
stat2 <- summary(age)
expect_equal(collect(stat2)[5, "age"], "30")
detach("df")
- stat3 <- summary(df[, "age"])
+ stat3 <- summary(df[, "age", drop = F])
expect_equal(collect(stat3)[5, "age"], "30")
expect_error(age)
})
@@ -1928,7 +1936,7 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", {
df1 <- select(df, cast(df$age, "integer"))
coltypes(df) <- c("character", "integer")
expect_equal(dtypes(df), list(c("name", "string"), c("age", "int")))
- value <- collect(df[, 2])[[3, 1]]
+ value <- collect(df[, 2, drop = F])[[3, 1]]
expect_equal(value, collect(df1)[[3, 1]])
expect_equal(value, 22)