aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--R/pkg/NAMESPACE2
-rw-r--r--R/pkg/R/column.R14
-rw-r--r--R/pkg/R/functions.R14
-rw-r--r--R/pkg/R/generics.R8
-rw-r--r--R/pkg/inst/tests/test_sparkSQL.R7
5 files changed, 45 insertions, 0 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 607aef2611..8fa12d5ade 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -152,6 +152,7 @@ exportMethods("abs",
"n_distinct",
"nanvl",
"negate",
+ "otherwise",
"pmod",
"quarter",
"reverse",
@@ -182,6 +183,7 @@ exportMethods("abs",
"unhex",
"upper",
"weekofyear",
+ "when",
"year")
exportClasses("GroupedData")
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 328f595d08..5a07ebd308 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -203,3 +203,17 @@ setMethod("%in%",
jc <- callJMethod(x@jc, "in", table)
return(column(jc))
})
+
+#' otherwise
+#'
+#' Supplies the default value for a Column built with `when`: rows for which
+#' none of the `when` conditions held evaluate to `value` instead of null.
+#' `x` is the Column to extend; `value` is either a literal or another Column.
+#' Returns a new Column.
+#'
+#' @rdname column
+#' @examples
+#' \dontrun{
+#' select(df, otherwise(when(df$a > 1, 1), 0))
+#' }
+setMethod("otherwise",
+          signature(x = "Column", value = "ANY"),
+          function(x, value) {
+            # Unwrap a Column argument to its `jc` slot; anything else is passed
+            # through unchanged. Scalar if/else is used because ifelse() is a
+            # vectorised selector for atomic vectors and is not suited to
+            # returning a slot object.
+            if (inherits(value, "Column")) {
+              value <- value@jc
+            }
+            jc <- callJMethod(x@jc, "otherwise", value)
+            column(jc)
+          })
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index e606b20570..366c230e1e 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -165,3 +165,17 @@ setMethod("n", signature(x = "Column"),
function(x) {
count(x)
})
+
+#' when
+#'
+#' Evaluates a list of conditions and returns one of multiple possible result expressions.
+#' For unmatched expressions null is returned.
+#' `condition` is a Column holding the test; `value` is the result for rows
+#' where the test holds and is either a literal or another Column.
+#' Returns a new Column, which can be passed to `otherwise` to set a default.
+#'
+#' @rdname column
+#' @examples
+#' \dontrun{
+#' select(df, when(df$a > 1 & df$b > 2, 1))
+#' }
+setMethod("when", signature(condition = "Column", value = "ANY"),
+          function(condition, value) {
+            condition <- condition@jc
+            # Unwrap a Column argument to its `jc` slot; anything else is passed
+            # through unchanged. Scalar if/else is used because ifelse() is a
+            # vectorised selector for atomic vectors and is not suited to
+            # returning a slot object.
+            if (inherits(value, "Column")) {
+              value <- value@jc
+            }
+            jc <- callJStatic("org.apache.spark.sql.functions", "when", condition, value)
+            column(jc)
+          })
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 5c1cc98fd9..338b32e648 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -651,6 +651,14 @@ setGeneric("rlike", function(x, ...) { standardGeneric("rlike") })
#' @export
setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") })
+# S4 generic for when(condition, value); methods are registered with setMethod.
+#' @rdname column
+#' @export
+setGeneric("when",
+           function(condition, value) {
+             standardGeneric("when")
+           })
+
+# S4 generic for otherwise(x, value); methods are registered with setMethod.
+#' @rdname column
+#' @export
+setGeneric("otherwise",
+           function(x, value) {
+             standardGeneric("otherwise")
+           })
+
###################### Expression Function Methods ##########################
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 83caba8b5b..841de657df 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -727,6 +727,13 @@ test_that("greatest() and least() on a DataFrame", {
expect_equal(collect(select(df, least(df$a, df$b)))[, 1], c(1, 3))
})
+test_that("when() and otherwise() on a DataFrame", {
+  rows <- list(list(a = 1, b = 2), list(a = 3, b = 4))
+  df <- createDataFrame(sqlContext, rows)
+
+  # Without otherwise(), the row that fails the condition comes back as NA.
+  matched <- when(df$a > 1 & df$b > 2, 1)
+  expect_equal(collect(select(df, matched))[, 1], c(NA, 1))
+
+  # With otherwise(), the row that fails the condition gets the default.
+  defaulted <- otherwise(when(df$a > 1, 1), 0)
+  expect_equal(collect(select(df, defaulted))[, 1], c(0, 1))
+})
+
test_that("group by", {
df <- jsonFile(sqlContext, jsonPath)
df1 <- agg(df, name = "max", age = "sum")