From c8d0e160dadf3b23c5caa379ba9ad5547794eaa0 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Sat, 5 Dec 2015 15:49:51 -0800 Subject: [SPARK-11774][SPARKR] Implement struct(), encode(), decode() functions in SparkR. Author: Sun Rui Closes #9804 from sun-rui/SPARK-11774. --- R/pkg/NAMESPACE | 3 ++ R/pkg/R/functions.R | 59 ++++++++++++++++++++++++++++++++++++++++ R/pkg/R/generics.R | 12 ++++++++ R/pkg/inst/tests/test_sparkSQL.R | 37 +++++++++++++++++++++---- 4 files changed, 105 insertions(+), 6 deletions(-) (limited to 'R') diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 43e5e0119e..565a2b1a68 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -134,8 +134,10 @@ exportMethods("%in%", "datediff", "dayofmonth", "dayofyear", + "decode", "dense_rank", "desc", + "encode", "endsWith", "exp", "explode", @@ -225,6 +227,7 @@ exportMethods("%in%", "stddev", "stddev_pop", "stddev_samp", + "struct", "sqrt", "startsWith", "substr", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index b30331c61c..7432cb8e7c 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -357,6 +357,40 @@ setMethod("dayofyear", column(jc) }) +#' decode +#' +#' Computes the first argument into a string from a binary using the provided character set +#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +#' +#' @rdname decode +#' @name decode +#' @family string_funcs +#' @export +#' @examples \dontrun{decode(df$c, "UTF-8")} +setMethod("decode", + signature(x = "Column", charset = "character"), + function(x, charset) { + jc <- callJStatic("org.apache.spark.sql.functions", "decode", x@jc, charset) + column(jc) + }) + +#' encode +#' +#' Computes the first argument into a binary from a string using the provided character set +#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +#' +#' @rdname encode +#' @name encode +#' @family string_funcs +#' @export +#' @examples \dontrun{encode(df$c, "UTF-8")} +setMethod("encode", + signature(x = "Column", charset = "character"), + function(x, charset) { + jc <- callJStatic("org.apache.spark.sql.functions", "encode", x@jc, charset) + column(jc) + }) + #' exp #' #' Computes the exponential of the given value. @@ -1039,6 +1073,31 @@ setMethod("stddev_samp", column(jc) }) +#' struct +#' +#' Creates a new struct column that composes multiple input columns. +#' +#' @rdname struct +#' @name struct +#' @family normal_funcs +#' @export +#' @examples +#' \dontrun{ +#' struct(df$c, df$d) +#' struct("col1", "col2") +#' } +setMethod("struct", + signature(x = "characterOrColumn"), + function(x, ...) { + if (class(x) == "Column") { + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "struct", jcols) + } else { + jc <- callJStatic("org.apache.spark.sql.functions", "struct", x, list(...)) + } + column(jc) + }) + #' sqrt #' #' Computes the square root of the specified float value. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 711ce38f9e..4b5f786d39 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -744,10 +744,18 @@ setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @export setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) +#' @rdname decode +#' @export +setGeneric("decode", function(x, charset) { standardGeneric("decode") }) + #' @rdname dense_rank #' @export setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") }) +#' @rdname encode +#' @export +setGeneric("encode", function(x, charset) { standardGeneric("encode") }) + #' @rdname explode #' @export setGeneric("explode", function(x) { standardGeneric("explode") }) @@ -1001,6 +1009,10 @@ setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) #' @export setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) +#' @rdname struct +#' @export +setGeneric("struct", function(x, ...) { standardGeneric("struct") }) + #' @rdname substring_index #' @export setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 1e7cb54099..2d26b92ac7 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -27,6 +27,11 @@ checkStructField <- function(actual, expectedName, expectedType, expectedNullabl expect_equal(actual$nullable(), expectedNullable) } +markUtf8 <- function(s) { + Encoding(s) <- "UTF-8" + s +} + # Tests for SparkSQL functions in SparkR sc <- sparkR.init() @@ -551,11 +556,6 @@ test_that("collect() and take() on a DataFrame return the same number of rows an }) test_that("collect() support Unicode characters", { - markUtf8 <- function(s) { - Encoding(s) <- "UTF-8" - s - } - lines <- c("{\"name\":\"안녕하세요\"}", "{\"name\":\"您好\", \"age\":30}", "{\"name\":\"こんにちは\", \"age\":19}", @@ -933,8 +933,33 @@ test_that("column functions", { # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) + + # Test struct() + df <- createDataFrame(sqlContext, + list(list(1L, 2L, 3L), list(4L, 5L, 6L)), + schema = c("a", "b", "c")) + result <- collect(select(df, struct("a", "c"))) + expected <- data.frame(row.names = 1:2) + expected$"struct(a,c)" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) + expect_equal(result, expected) + + result <- collect(select(df, struct(df$a, df$b))) + expected <- data.frame(row.names = 1:2) + expected$"struct(a,b)" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) + expect_equal(result, expected) + + # Test encode(), decode() + bytes <- as.raw(c(0xe5, 0xa4, 0xa7, 0xe5, 0x8d, 0x83, 0xe4, 0xb8, 0x96, 0xe7, 0x95, 0x8c)) + df <- createDataFrame(sqlContext, + list(list(markUtf8("大千世界"), "utf-8", bytes)), + schema = c("a", "b", "c")) + result <- collect(select(df, encode(df$a, "utf-8"), decode(df$c, "utf-8"))) + expect_equal(result[[1]][[1]], bytes) + expect_equal(result[[2]], markUtf8("大千世界")) }) -# + test_that("column binary mathfunctions", { lines <- c("{\"a\":1, \"b\":5}", "{\"a\":2, \"b\":6}", -- cgit v1.2.3