aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorSun Rui <rui.sun@intel.com>2015-12-05 15:49:51 -0800
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2015-12-05 15:49:51 -0800
commitc8d0e160dadf3b23c5caa379ba9ad5547794eaa0 (patch)
treeb2c3b6f2ebfa184019ea53c6ae21294f08d7a382 /R
parent7da674851928ed23eb651a3e2f8233e7a684ac41 (diff)
downloadspark-c8d0e160dadf3b23c5caa379ba9ad5547794eaa0.tar.gz
spark-c8d0e160dadf3b23c5caa379ba9ad5547794eaa0.tar.bz2
spark-c8d0e160dadf3b23c5caa379ba9ad5547794eaa0.zip
[SPARK-11774][SPARKR] Implement struct(), encode(), decode() functions in SparkR.
Author: Sun Rui <rui.sun@intel.com> Closes #9804 from sun-rui/SPARK-11774.
Diffstat (limited to 'R')
-rw-r--r--R/pkg/NAMESPACE3
-rw-r--r--R/pkg/R/functions.R59
-rw-r--r--R/pkg/R/generics.R12
-rw-r--r--R/pkg/inst/tests/test_sparkSQL.R37
4 files changed, 105 insertions, 6 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 43e5e0119e..565a2b1a68 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -134,8 +134,10 @@ exportMethods("%in%",
"datediff",
"dayofmonth",
"dayofyear",
+ "decode",
"dense_rank",
"desc",
+ "encode",
"endsWith",
"exp",
"explode",
@@ -225,6 +227,7 @@ exportMethods("%in%",
"stddev",
"stddev_pop",
"stddev_samp",
+ "struct",
"sqrt",
"startsWith",
"substr",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index b30331c61c..7432cb8e7c 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -357,6 +357,40 @@ setMethod("dayofyear",
column(jc)
})
+#' decode
+#'
+#' Computes the first argument into a string from a binary using the provided character set
+#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
+#'
+#' @rdname decode
+#' @name decode
+#' @family string_funcs
+#' @export
+#' @examples \dontrun{decode(df$c, "UTF-8")}
+setMethod("decode",
+ signature(x = "Column", charset = "character"),
+ function(x, charset) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "decode", x@jc, charset)
+ column(jc)
+ })
+
+#' encode
+#'
+#' Computes the first argument into a binary from a string using the provided character set
+#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
+#'
+#' @rdname encode
+#' @name encode
+#' @family string_funcs
+#' @export
+#' @examples \dontrun{encode(df$c, "UTF-8")}
+setMethod("encode",
+ signature(x = "Column", charset = "character"),
+ function(x, charset) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "encode", x@jc, charset)
+ column(jc)
+ })
+
#' exp
#'
#' Computes the exponential of the given value.
@@ -1039,6 +1073,31 @@ setMethod("stddev_samp",
column(jc)
})
+#' struct
+#'
+#' Creates a new struct column that composes multiple input columns.
+#'
+#' @rdname struct
+#' @name struct
+#' @family normal_funcs
+#' @export
+#' @examples
+#' \dontrun{
+#' struct(df$c, df$d)
+#' struct("col1", "col2")
+#' }
+setMethod("struct",
+ signature(x = "characterOrColumn"),
+ function(x, ...) {
+ if (class(x) == "Column") {
+ jcols <- lapply(list(x, ...), function(x) { x@jc })
+ jc <- callJStatic("org.apache.spark.sql.functions", "struct", jcols)
+ } else {
+ jc <- callJStatic("org.apache.spark.sql.functions", "struct", x, list(...))
+ }
+ column(jc)
+ })
+
#' sqrt
#'
#' Computes the square root of the specified float value.
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 711ce38f9e..4b5f786d39 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -744,10 +744,18 @@ setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") })
#' @export
setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") })
+#' @rdname decode
+#' @export
+setGeneric("decode", function(x, charset) { standardGeneric("decode") })
+
#' @rdname dense_rank
#' @export
setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") })
+#' @rdname encode
+#' @export
+setGeneric("encode", function(x, charset) { standardGeneric("encode") })
+
#' @rdname explode
#' @export
setGeneric("explode", function(x) { standardGeneric("explode") })
@@ -1001,6 +1009,10 @@ setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") })
#' @export
setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") })
+#' @rdname struct
+#' @export
+setGeneric("struct", function(x, ...) { standardGeneric("struct") })
+
#' @rdname substring_index
#' @export
setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") })
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 1e7cb54099..2d26b92ac7 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -27,6 +27,11 @@ checkStructField <- function(actual, expectedName, expectedType, expectedNullabl
expect_equal(actual$nullable(), expectedNullable)
}
+markUtf8 <- function(s) {
+ Encoding(s) <- "UTF-8"
+ s
+}
+
# Tests for SparkSQL functions in SparkR
sc <- sparkR.init()
@@ -551,11 +556,6 @@ test_that("collect() and take() on a DataFrame return the same number of rows an
})
test_that("collect() support Unicode characters", {
- markUtf8 <- function(s) {
- Encoding(s) <- "UTF-8"
- s
- }
-
lines <- c("{\"name\":\"안녕하세요\"}",
"{\"name\":\"您好\", \"age\":30}",
"{\"name\":\"こんにちは\", \"age\":19}",
@@ -933,8 +933,33 @@ test_that("column functions", {
# Test that stats::lag is working
expect_equal(length(lag(ldeaths, 12)), 72)
+
+ # Test struct()
+ df <- createDataFrame(sqlContext,
+ list(list(1L, 2L, 3L), list(4L, 5L, 6L)),
+ schema = c("a", "b", "c"))
+ result <- collect(select(df, struct("a", "c")))
+ expected <- data.frame(row.names = 1:2)
+ expected$"struct(a,c)" <- list(listToStruct(list(a = 1L, c = 3L)),
+ listToStruct(list(a = 4L, c = 6L)))
+ expect_equal(result, expected)
+
+ result <- collect(select(df, struct(df$a, df$b)))
+ expected <- data.frame(row.names = 1:2)
+ expected$"struct(a,b)" <- list(listToStruct(list(a = 1L, b = 2L)),
+ listToStruct(list(a = 4L, b = 5L)))
+ expect_equal(result, expected)
+
+ # Test encode(), decode()
+ bytes <- as.raw(c(0xe5, 0xa4, 0xa7, 0xe5, 0x8d, 0x83, 0xe4, 0xb8, 0x96, 0xe7, 0x95, 0x8c))
+ df <- createDataFrame(sqlContext,
+ list(list(markUtf8("大千世界"), "utf-8", bytes)),
+ schema = c("a", "b", "c"))
+ result <- collect(select(df, encode(df$a, "utf-8"), decode(df$c, "utf-8")))
+ expect_equal(result[[1]][[1]], bytes)
+ expect_equal(result[[2]], markUtf8("大千世界"))
})
-#
+
test_that("column binary mathfunctions", {
lines <- c("{\"a\":1, \"b\":5}",
"{\"a\":2, \"b\":6}",