author     Sun Rui <rui.sun@intel.com>    2015-10-26 20:58:18 -0700
committer  Shivaram Venkataraman <shivaram@cs.berkeley.edu>    2015-10-26 20:58:18 -0700
commit     dc3220ce11c7513b1452c82ee82cb86e908bcc2d (patch)
tree       bfdf594afdf3731919cccd131363973c50df98cd /R/pkg
parent     82464fb2e02ca4e4d425017815090497b79dc93f (diff)
[SPARK-11209][SPARKR] Add window functions into SparkR [step 1].
Author: Sun Rui <rui.sun@intel.com>
Closes #9193 from sun-rui/SPARK-11209.
Diffstat (limited to 'R/pkg')
-rw-r--r--  R/pkg/NAMESPACE                   |  4
-rw-r--r--  R/pkg/R/functions.R               | 98
-rw-r--r--  R/pkg/R/generics.R                | 16
-rw-r--r--  R/pkg/inst/tests/test_sparkSQL.R  |  2
4 files changed, 120 insertions, 0 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 52f7a0106a..b73bed3128 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -119,6 +119,7 @@ exportMethods("%in%",
"count",
"countDistinct",
"crc32",
+ "cumeDist",
"date_add",
"date_format",
"date_sub",
@@ -150,8 +151,10 @@ exportMethods("%in%",
"isNaN",
"isNotNull",
"isNull",
+ "lag",
"last",
"last_day",
+ "lead",
"least",
"length",
"levenshtein",
@@ -177,6 +180,7 @@ exportMethods("%in%",
"nanvl",
"negate",
"next_day",
+ "ntile",
"otherwise",
"pmod",
"quarter",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index a72fb7bb42..366290fe66 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -2013,3 +2013,101 @@ setMethod("ifelse",
"otherwise", no)
column(jc)
})
+
+###################### Window functions ######################
+
+#' cumeDist
+#'
+#' Window function: returns the cumulative distribution of values within a window partition,
+#' i.e. the fraction of rows that are below the current row.
+#'
+#' N = total number of rows in the partition
+#' cumeDist(x) = number of values before (and including) x / N
+#'
+#' This is equivalent to the CUME_DIST function in SQL.
+#'
+#' @rdname cumeDist
+#' @name cumeDist
+#' @family window_funcs
+#' @export
+#' @examples \dontrun{cumeDist()}
+setMethod("cumeDist",
+ signature(x = "missing"),
+ function() {
+ jc <- callJStatic("org.apache.spark.sql.functions", "cumeDist")
+ column(jc)
+ })
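
For illustration, a minimal sketch of how cumeDist() could be applied over a window, assuming a DataFrame `df` with columns `dept` and `salary`, and assuming the window-spec helpers (windowPartitionBy(), orderBy() on a WindowSpec, over()) that arrived in later SparkR work rather than in this step-1 patch:

ws <- orderBy(windowPartitionBy("dept"), "salary")        # partition by dept, order by salary
out <- withColumn(df, "dist", over(cumeDist(), ws))       # fraction of rows at or below each salary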
+
+#' lag
+#'
+#' Window function: returns the value that is `offset` rows before the current row, and
+#' `defaultValue` if there are fewer than `offset` rows before the current row. For example,
+#' an `offset` of one will return the previous row at any given point in the window partition.
+#'
+#' This is equivalent to the LAG function in SQL.
+#'
+#' @rdname lag
+#' @name lag
+#' @family window_funcs
+#' @export
+#' @examples \dontrun{lag(df$c)}
+setMethod("lag",
+ signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"),
+ function(x, offset, defaultValue = NULL) {
+ col <- if (class(x) == "Column") {
+ x@jc
+ } else {
+ x
+ }
+
+ jc <- callJStatic("org.apache.spark.sql.functions",
+ "lag", col, as.integer(offset), defaultValue)
+ column(jc)
+ })
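
A hedged usage sketch for lag(): pulling the previous row's value within a partition, under the same assumptions as above (hypothetical `df` with `dept` and `salary`; window-spec helpers from later SparkR releases, not this patch):

ws <- orderBy(windowPartitionBy("dept"), "salary")
out <- withColumn(df, "prev_salary", over(lag(df$salary, 1), ws))      # previous salary within dept
out0 <- withColumn(df, "prev_or_0", over(lag(df$salary, 1, 0), ws))    # same, defaulting to 0 on the first row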
+
+#' lead
+#'
+#' Window function: returns the value that is `offset` rows after the current row, and
+#' `null` if there are fewer than `offset` rows after the current row. For example,
+#' an `offset` of one will return the next row at any given point in the window partition.
+#'
+#' This is equivalent to the LEAD function in SQL.
+#'
+#' @rdname lead
+#' @name lead
+#' @family window_funcs
+#' @export
+#' @examples \dontrun{lead(df$c)}
+setMethod("lead",
+ signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"),
+ function(x, offset, defaultValue = NULL) {
+ col <- if (class(x) == "Column") {
+ x@jc
+ } else {
+ x
+ }
+
+ jc <- callJStatic("org.apache.spark.sql.functions",
+ "lead", col, as.integer(offset), defaultValue)
+ column(jc)
+ })
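
Similarly, a sketch of lead() looking one row ahead, same assumptions as above (hypothetical `df`; later window-spec API, not part of this patch):

ws <- orderBy(windowPartitionBy("dept"), "salary")
out <- withColumn(df, "next_salary", over(lead(df$salary, 1), ws))     # next salary within dept, null on the last row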
+
+#' ntile
+#'
+#' Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window
+#' partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the second
+#' quarter will get 2, the third quarter will get 3, and the last quarter will get 4.
+#'
+#' This is equivalent to the NTILE function in SQL.
+#'
+#' @rdname ntile
+#' @name ntile
+#' @family window_funcs
+#' @export
+#' @examples \dontrun{ntile(1)}
+setMethod("ntile",
+ signature(x = "numeric"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "ntile", as.integer(x))
+ column(jc)
+ })
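
And a sketch of ntile() assigning quartile ids, under the same assumptions (hypothetical `df`; window-spec helpers from later SparkR releases):

ws <- orderBy(windowPartitionBy("dept"), "salary")
out <- withColumn(df, "quartile", over(ntile(4), ws))     # 1..4, by salary within dept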
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 4a419f785e..c11c3c8d3e 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -714,6 +714,10 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct")
#' @export
setGeneric("crc32", function(x) { standardGeneric("crc32") })
+#' @rdname cumeDist
+#' @export
+setGeneric("cumeDist", function(x) { standardGeneric("cumeDist") })
+
#' @rdname datediff
#' @export
setGeneric("datediff", function(y, x) { standardGeneric("datediff") })
@@ -790,6 +794,10 @@ setGeneric("instr", function(y, x) { standardGeneric("instr") })
#' @export
setGeneric("isNaN", function(x) { standardGeneric("isNaN") })
+#' @rdname lag
+#' @export
+setGeneric("lag", function(x, offset, defaultValue = NULL) { standardGeneric("lag") })
+
#' @rdname last
#' @export
setGeneric("last", function(x) { standardGeneric("last") })
@@ -798,6 +806,10 @@ setGeneric("last", function(x) { standardGeneric("last") })
#' @export
setGeneric("last_day", function(x) { standardGeneric("last_day") })
+#' @rdname lead
+#' @export
+setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") })
+
#' @rdname least
#' @export
setGeneric("least", function(x, ...) { standardGeneric("least") })
@@ -858,6 +870,10 @@ setGeneric("negate", function(x) { standardGeneric("negate") })
#' @export
setGeneric("next_day", function(y, x) { standardGeneric("next_day") })
+#' @rdname ntile
+#' @export
+setGeneric("ntile", function(x) { standardGeneric("ntile") })
+
#' @rdname countDistinct
#' @export
setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 540854d114..e1d4499925 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -829,6 +829,8 @@ test_that("column functions", {
c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c)
c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c)
c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c)
+ c12 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1)
+ c13 <- cumeDist() + ntile(1)
df <- jsonFile(sqlContext, jsonPath)
df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20)))