From dc3220ce11c7513b1452c82ee82cb86e908bcc2d Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Mon, 26 Oct 2015 20:58:18 -0700 Subject: [SPARK-11209][SPARKR] Add window functions into SparkR [step 1]. Author: Sun Rui Closes #9193 from sun-rui/SPARK-11209. --- R/pkg/NAMESPACE | 4 ++ R/pkg/R/functions.R | 98 ++++++++++++++++++++++++++++++++++++++++ R/pkg/R/generics.R | 16 +++++++ R/pkg/inst/tests/test_sparkSQL.R | 2 + 4 files changed, 120 insertions(+) (limited to 'R') diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 52f7a0106a..b73bed3128 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -119,6 +119,7 @@ exportMethods("%in%", "count", "countDistinct", "crc32", + "cumeDist", "date_add", "date_format", "date_sub", @@ -150,8 +151,10 @@ exportMethods("%in%", "isNaN", "isNotNull", "isNull", + "lag", "last", "last_day", + "lead", "least", "length", "levenshtein", @@ -177,6 +180,7 @@ exportMethods("%in%", "nanvl", "negate", "next_day", + "ntile", "otherwise", "pmod", "quarter", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index a72fb7bb42..366290fe66 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2013,3 +2013,101 @@ setMethod("ifelse", "otherwise", no) column(jc) }) + +###################### Window functions###################### + +#' cumeDist +#' +#' Window function: returns the cumulative distribution of values within a window partition, +#' i.e. the fraction of rows that are below the current row. +#' +#' N = total number of rows in the partition +#' cumeDist(x) = number of values before (and including) x / N +#' +#' This is equivalent to the CUME_DIST function in SQL. +#' +#' @rdname cumeDist +#' @name cumeDist +#' @family window_funcs +#' @export +#' @examples \dontrun{cumeDist()} +setMethod("cumeDist", + signature(x = "missing"), + function() { + jc <- callJStatic("org.apache.spark.sql.functions", "cumeDist") + column(jc) + }) + +#' lag +#' +#' Window function: returns the value that is `offset` rows before the current row, and +#' `defaultValue` if there is less than `offset` rows before the current row. For example, +#' an `offset` of one will return the previous row at any given point in the window partition. +#' +#' This is equivalent to the LAG function in SQL. +#' +#' @rdname lag +#' @name lag +#' @family window_funcs +#' @export +#' @examples \dontrun{lag(df$c)} +setMethod("lag", + signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), + function(x, offset, defaultValue = NULL) { + col <- if (class(x) == "Column") { + x@jc + } else { + x + } + + jc <- callJStatic("org.apache.spark.sql.functions", + "lag", col, as.integer(offset), defaultValue) + column(jc) + }) + +#' lead +#' +#' Window function: returns the value that is `offset` rows after the current row, and +#' `null` if there is less than `offset` rows after the current row. For example, +#' an `offset` of one will return the next row at any given point in the window partition. +#' +#' This is equivalent to the LEAD function in SQL. +#' +#' @rdname lead +#' @name lead +#' @family window_funcs +#' @export +#' @examples \dontrun{lead(df$c)} +setMethod("lead", + signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), + function(x, offset, defaultValue = NULL) { + col <- if (class(x) == "Column") { + x@jc + } else { + x + } + + jc <- callJStatic("org.apache.spark.sql.functions", + "lead", col, as.integer(offset), defaultValue) + column(jc) + }) + +#' ntile +#' +#' Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window +#' partition. Fow example, if `n` is 4, the first quarter of the rows will get value 1, the second +#' quarter will get 2, the third quarter will get 3, and the last quarter will get 4. +#' +#' This is equivalent to the NTILE function in SQL. +#' +#' @rdname ntile +#' @name ntile +#' @family window_funcs +#' @export +#' @examples \dontrun{ntile(1)} +setMethod("ntile", + signature(x = "numeric"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "ntile", as.integer(x)) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 4a419f785e..c11c3c8d3e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -714,6 +714,10 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) +#' @rdname cumeDist +#' @export +setGeneric("cumeDist", function(x) { standardGeneric("cumeDist") }) + #' @rdname datediff #' @export setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) @@ -790,6 +794,10 @@ setGeneric("instr", function(y, x) { standardGeneric("instr") }) #' @export setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) +#' @rdname lag +#' @export +setGeneric("lag", function(x, offset, defaultValue = NULL) { standardGeneric("lag") }) + #' @rdname last #' @export setGeneric("last", function(x) { standardGeneric("last") }) @@ -798,6 +806,10 @@ setGeneric("last", function(x) { standardGeneric("last") }) #' @export setGeneric("last_day", function(x) { standardGeneric("last_day") }) +#' @rdname lead +#' @export +setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) + #' @rdname least #' @export setGeneric("least", function(x, ...) { standardGeneric("least") }) @@ -858,6 +870,10 @@ setGeneric("negate", function(x) { standardGeneric("negate") }) #' @export setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) +#' @rdname ntile +#' @export +setGeneric("ntile", function(x) { standardGeneric("ntile") }) + #' @rdname countDistinct #' @export setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 540854d114..e1d4499925 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -829,6 +829,8 @@ test_that("column functions", { c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c) c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) + c12 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) + c13 <- cumeDist() + ntile(1) df <- jsonFile(sqlContext, jsonPath) df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) -- cgit v1.2.3