From 4d535d1f1c19faa43f96433aee8760e37b1690ea Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 10 Mar 2016 17:31:19 -0800 Subject: [SPARK-13389][SPARKR] SparkR support first/last with ignore NAs ## What changes were proposed in this pull request? SparkR support first/last with ignore NAs cc sun-rui felixcheung shivaram ## How was this patch tested? unit tests Author: Yanbo Liang Closes #11267 from yanboliang/spark-13389. --- R/pkg/R/functions.R | 40 ++++++++++++++++++++++++------- R/pkg/R/generics.R | 4 ++-- R/pkg/inst/tests/testthat/test_sparkSQL.R | 11 +++++++++ 3 files changed, 45 insertions(+), 10 deletions(-) (limited to 'R') diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index e5521f3cff..d9c10b4a4b 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -536,15 +536,27 @@ setMethod("factorial", #' #' Aggregate function: returns the first value in a group. #' +#' The function by default returns the first values it sees. It will return the first non-missing +#' value it sees when na.rm is set to true. If all values are missing, then NA is returned. +#' #' @rdname first #' @name first #' @family agg_funcs #' @export -#' @examples \dontrun{first(df$c)} +#' @examples +#' \dontrun{ +#' first(df$c) +#' first(df$c, TRUE) +#' } setMethod("first", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "first", x@jc) + signature(x = "characterOrColumn"), + function(x, na.rm = FALSE) { + col <- if (class(x) == "Column") { + x@jc + } else { + x + } + jc <- callJStatic("org.apache.spark.sql.functions", "first", col, na.rm) column(jc) }) @@ -663,15 +675,27 @@ setMethod("kurtosis", #' #' Aggregate function: returns the last value in a group. #' +#' The function by default returns the last values it sees. It will return the last non-missing +#' value it sees when na.rm is set to true. If all values are missing, then NA is returned. 
+#' #' @rdname last #' @name last #' @family agg_funcs #' @export -#' @examples \dontrun{last(df$c)} +#' @examples +#' \dontrun{ +#' last(df$c) +#' last(df$c, TRUE) +#' } setMethod("last", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "last", x@jc) + signature(x = "characterOrColumn"), + function(x, na.rm = FALSE) { + col <- if (class(x) == "Column") { + x@jc + } else { + x + } + jc <- callJStatic("org.apache.spark.sql.functions", "last", col, na.rm) column(jc) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 3db72b5795..ddfa61717a 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -84,7 +84,7 @@ setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) # @rdname first # @export -setGeneric("first", function(x) { standardGeneric("first") }) +setGeneric("first", function(x, ...) { standardGeneric("first") }) # @rdname flatMap # @export @@ -889,7 +889,7 @@ setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @rdname last #' @export -setGeneric("last", function(x) { standardGeneric("last") }) +setGeneric("last", function(x, ...) 
{ standardGeneric("last") }) #' @rdname last_day #' @export diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index cad5766812..11a8f12fd5 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1076,6 +1076,17 @@ test_that("column functions", { result <- collect(select(df, encode(df$a, "utf-8"), decode(df$c, "utf-8"))) expect_equal(result[[1]][[1]], bytes) expect_equal(result[[2]], markUtf8("大千世界")) + + # Test first(), last() + df <- read.json(sqlContext, jsonPath) + expect_equal(collect(select(df, first(df$age)))[[1]], NA) + expect_equal(collect(select(df, first(df$age, TRUE)))[[1]], 30) + expect_equal(collect(select(df, first("age")))[[1]], NA) + expect_equal(collect(select(df, first("age", TRUE)))[[1]], 30) + expect_equal(collect(select(df, last(df$age)))[[1]], 19) + expect_equal(collect(select(df, last(df$age, TRUE)))[[1]], 19) + expect_equal(collect(select(df, last("age")))[[1]], 19) + expect_equal(collect(select(df, last("age", TRUE)))[[1]], 19) }) test_that("column binary mathfunctions", { -- cgit v1.2.3