aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorSun Rui <rui.sun@intel.com>2015-09-10 12:21:13 -0700
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2015-09-10 12:21:13 -0700
commit45e3be5c138d983f40f619735d60bf7eb78c9bf0 (patch)
tree30b7b90f53eadee901a56e0e2e84222e21cf6c44 /R
parentd88abb7e212fb55f9b0398a0f76a753c86b85cf1 (diff)
downloadspark-45e3be5c138d983f40f619735d60bf7eb78c9bf0.tar.gz
spark-45e3be5c138d983f40f619735d60bf7eb78c9bf0.tar.bz2
spark-45e3be5c138d983f40f619735d60bf7eb78c9bf0.zip
[SPARK-10049] [SPARKR] Support collecting data of ArrayType in DataFrame.
This PR: 1. Enhances reflection in RBackend by automatically matching a Java array to a Scala Seq when finding methods. Util functions like seq() and listToSeq() on the R side can be removed, as they would conflict with the SerDe logic that transfers a Scala Seq to the R side. 2. Enhances the SerDe to support transferring a Scala Seq to the R side. Data of ArrayType in a DataFrame is observed to be of Scala Seq type after collection. 3. Supports ArrayType in createDataFrame(). Author: Sun Rui <rui.sun@intel.com> Closes #8458 from sun-rui/SPARK-10049.
Diffstat (limited to 'R')
-rw-r--r--R/pkg/R/DataFrame.R26
-rw-r--r--R/pkg/R/SQLContext.R4
-rw-r--r--R/pkg/R/column.R3
-rw-r--r--R/pkg/R/functions.R12
-rw-r--r--R/pkg/R/group.R4
-rw-r--r--R/pkg/R/schema.R54
-rw-r--r--R/pkg/R/utils.R10
-rw-r--r--R/pkg/inst/tests/test_sparkSQL.R44
8 files changed, 95 insertions, 62 deletions
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 8a00238b41..c3c1893487 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -271,7 +271,7 @@ setMethod("names<-",
signature(x = "DataFrame"),
function(x, value) {
if (!is.null(value)) {
- sdf <- callJMethod(x@sdf, "toDF", listToSeq(as.list(value)))
+ sdf <- callJMethod(x@sdf, "toDF", as.list(value))
dataFrame(sdf)
}
})
@@ -843,10 +843,10 @@ setMethod("groupBy",
function(x, ...) {
cols <- list(...)
if (length(cols) >= 1 && class(cols[[1]]) == "character") {
- sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1]))
+ sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], cols[-1])
} else {
jcol <- lapply(cols, function(c) { c@jc })
- sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol))
+ sgd <- callJMethod(x@sdf, "groupBy", jcol)
}
groupedData(sgd)
})
@@ -1079,7 +1079,7 @@ setMethod("subset", signature(x = "DataFrame"),
#' }
setMethod("select", signature(x = "DataFrame", col = "character"),
function(x, col, ...) {
- sdf <- callJMethod(x@sdf, "select", col, toSeq(...))
+ sdf <- callJMethod(x@sdf, "select", col, list(...))
dataFrame(sdf)
})
@@ -1090,7 +1090,7 @@ setMethod("select", signature(x = "DataFrame", col = "Column"),
jcols <- lapply(list(col, ...), function(c) {
c@jc
})
- sdf <- callJMethod(x@sdf, "select", listToSeq(jcols))
+ sdf <- callJMethod(x@sdf, "select", jcols)
dataFrame(sdf)
})
@@ -1106,7 +1106,7 @@ setMethod("select",
col(c)@jc
}
})
- sdf <- callJMethod(x@sdf, "select", listToSeq(cols))
+ sdf <- callJMethod(x@sdf, "select", cols)
dataFrame(sdf)
})
@@ -1133,7 +1133,7 @@ setMethod("selectExpr",
signature(x = "DataFrame", expr = "character"),
function(x, expr, ...) {
exprList <- list(expr, ...)
- sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList))
+ sdf <- callJMethod(x@sdf, "selectExpr", exprList)
dataFrame(sdf)
})
@@ -1311,12 +1311,12 @@ setMethod("arrange",
signature(x = "DataFrame", col = "characterOrColumn"),
function(x, col, ...) {
if (class(col) == "character") {
- sdf <- callJMethod(x@sdf, "sort", col, toSeq(...))
+ sdf <- callJMethod(x@sdf, "sort", col, list(...))
} else if (class(col) == "Column") {
jcols <- lapply(list(col, ...), function(c) {
c@jc
})
- sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols))
+ sdf <- callJMethod(x@sdf, "sort", jcols)
}
dataFrame(sdf)
})
@@ -1664,7 +1664,7 @@ setMethod("describe",
signature(x = "DataFrame", col = "character"),
function(x, col, ...) {
colList <- list(col, ...)
- sdf <- callJMethod(x@sdf, "describe", listToSeq(colList))
+ sdf <- callJMethod(x@sdf, "describe", colList)
dataFrame(sdf)
})
@@ -1674,7 +1674,7 @@ setMethod("describe",
signature(x = "DataFrame"),
function(x) {
colList <- as.list(c(columns(x)))
- sdf <- callJMethod(x@sdf, "describe", listToSeq(colList))
+ sdf <- callJMethod(x@sdf, "describe", colList)
dataFrame(sdf)
})
@@ -1731,7 +1731,7 @@ setMethod("dropna",
naFunctions <- callJMethod(x@sdf, "na")
sdf <- callJMethod(naFunctions, "drop",
- as.integer(minNonNulls), listToSeq(as.list(cols)))
+ as.integer(minNonNulls), as.list(cols))
dataFrame(sdf)
})
@@ -1815,7 +1815,7 @@ setMethod("fillna",
sdf <- if (length(cols) == 0) {
callJMethod(naFunctions, "fill", value)
} else {
- callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols)))
+ callJMethod(naFunctions, "fill", value, as.list(cols))
}
dataFrame(sdf)
})
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 1bc6445311..4ac057d0f2 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -49,7 +49,7 @@ infer_type <- function(x) {
stopifnot(length(x) > 0)
names <- names(x)
if (is.null(names)) {
- list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE)
+ paste0("array<", infer_type(x[[1]]), ">")
} else {
# StructType
types <- lapply(x, infer_type)
@@ -59,7 +59,7 @@ infer_type <- function(x) {
do.call(structType, fields)
}
} else if (length(x) > 1) {
- list(type = "array", elementType = type, containsNull = TRUE)
+ paste0("array<", infer_type(x[[1]]), ">")
} else {
type
}
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 4805096f3f..42e9d12179 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -211,8 +211,7 @@ setMethod("cast",
setMethod("%in%",
signature(x = "Column"),
function(x, table) {
- table <- listToSeq(as.list(table))
- jc <- callJMethod(x@jc, "in", table)
+ jc <- callJMethod(x@jc, "in", as.list(table))
return(column(jc))
})
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index d848730e70..94687edb05 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -1331,7 +1331,7 @@ setMethod("countDistinct",
x@jc
})
jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc,
- listToSeq(jcol))
+ jcol)
column(jc)
})
@@ -1348,7 +1348,7 @@ setMethod("concat",
signature(x = "Column"),
function(x, ...) {
jcols <- lapply(list(x, ...), function(x) { x@jc })
- jc <- callJStatic("org.apache.spark.sql.functions", "concat", listToSeq(jcols))
+ jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols)
column(jc)
})
@@ -1366,7 +1366,7 @@ setMethod("greatest",
function(x, ...) {
stopifnot(length(list(...)) > 0)
jcols <- lapply(list(x, ...), function(x) { x@jc })
- jc <- callJStatic("org.apache.spark.sql.functions", "greatest", listToSeq(jcols))
+ jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols)
column(jc)
})
@@ -1384,7 +1384,7 @@ setMethod("least",
function(x, ...) {
stopifnot(length(list(...)) > 0)
jcols <- lapply(list(x, ...), function(x) { x@jc })
- jc <- callJStatic("org.apache.spark.sql.functions", "least", listToSeq(jcols))
+ jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols)
column(jc)
})
@@ -1675,7 +1675,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"),
#' @export
setMethod("concat_ws", signature(sep = "character", x = "Column"),
function(sep, x, ...) {
- jcols <- listToSeq(lapply(list(x, ...), function(x) { x@jc }))
+ jcols <- lapply(list(x, ...), function(x) { x@jc })
jc <- callJStatic("org.apache.spark.sql.functions", "concat_ws", sep, jcols)
column(jc)
})
@@ -1723,7 +1723,7 @@ setMethod("expr", signature(x = "character"),
#' @export
setMethod("format_string", signature(format = "character", x = "Column"),
function(format, x, ...) {
- jcols <- listToSeq(lapply(list(x, ...), function(arg) { arg@jc }))
+ jcols <- lapply(list(x, ...), function(arg) { arg@jc })
jc <- callJStatic("org.apache.spark.sql.functions",
"format_string",
format, jcols)
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 576ac72f40..4cab1a69f6 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -102,7 +102,7 @@ setMethod("agg",
}
}
jcols <- lapply(cols, function(c) { c@jc })
- sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1]))
+ sdf <- callJMethod(x@sgd, "agg", jcols[[1]], jcols[-1])
} else {
stop("agg can only support Column or character")
}
@@ -124,7 +124,7 @@ createMethod <- function(name) {
setMethod(name,
signature(x = "GroupedData"),
function(x, ...) {
- sdf <- callJMethod(x@sgd, name, toSeq(...))
+ sdf <- callJMethod(x@sgd, name, list(...))
dataFrame(sdf)
})
}
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index 79c744ef29..62d4f73878 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -56,7 +56,7 @@ structType.structField <- function(x, ...) {
})
stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
"createStructType",
- listToSeq(sfObjList))
+ sfObjList)
structType(stObj)
}
@@ -114,6 +114,35 @@ structField.jobj <- function(x) {
obj
}
+checkType <- function(type) {  # Validate a field type string: returns NULL if supported, otherwise stops with an error.
+ primtiveTypes <- c("byte",
+ "integer",
+ "float",
+ "double",
+ "numeric",
+ "character",
+ "string",
+ "binary",
+ "raw",
+ "logical",
+ "boolean",
+ "timestamp",
+ "date")
+ if (type %in% primtiveTypes) {  # a primitive type name is accepted as-is
+ return()
+ } else {
+ m <- regexec("^array<(.*)>$", type)  # greedy capture, so "array<array<T>>" yields "array<T>" as the element type
+ matchedStrings <- regmatches(type, m)
+ if (length(matchedStrings[[1]]) >= 2) {  # full match plus one capture group => type has the array<...> form
+ elemType <- matchedStrings[[1]][2]
+ checkType(elemType)  # recurse: an array type is supported only if its element type is supported
+ return()
+ }
+ }
+
+ stop(paste("Unsupported type for Dataframe:", type))  # reached when type is neither primitive nor array<...>
+}
+
structField.character <- function(x, type, nullable = TRUE) {
if (class(x) != "character") {
stop("Field name must be a string.")
@@ -124,28 +153,13 @@ structField.character <- function(x, type, nullable = TRUE) {
if (class(nullable) != "logical") {
stop("nullable must be either TRUE or FALSE")
}
- options <- c("byte",
- "integer",
- "float",
- "double",
- "numeric",
- "character",
- "string",
- "binary",
- "raw",
- "logical",
- "boolean",
- "timestamp",
- "date")
- dataType <- if (type %in% options) {
- type
- } else {
- stop(paste("Unsupported type for Dataframe:", type))
- }
+
+ checkType(type)
+
sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
"createStructField",
x,
- dataType,
+ type,
nullable)
structField(sfObj)
}
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 3babcb5193..69a2bc728f 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -361,16 +361,6 @@ numToInt <- function(num) {
as.integer(num)
}
-# create a Seq in JVM
-toSeq <- function(...) {
- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...))
-}
-
-# create a Seq in JVM from a list
-listToSeq <- function(l) {
- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l)
-}
-
# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a
# user defined function (UDF), and to examine variables in the UDF to decide
# if their values should be included in the new function environment.
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 6d331f9883..1ccfde5917 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -49,6 +49,14 @@ mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}",
jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp")
writeLines(mockLinesNa, jsonPathNa)
+# For testing complex types in DataFrame
+mockLinesComplexType <-
+ c("{\"c1\":[1, 2, 3], \"c2\":[\"a\", \"b\", \"c\"], \"c3\":[1.0, 2.0, 3.0]}",
+ "{\"c1\":[4, 5, 6], \"c2\":[\"d\", \"e\", \"f\"], \"c3\":[4.0, 5.0, 6.0]}",
+ "{\"c1\":[7, 8, 9], \"c2\":[\"g\", \"h\", \"i\"], \"c3\":[7.0, 8.0, 9.0]}")
+complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp")
+writeLines(mockLinesComplexType, complexTypeJsonPath)
+
test_that("infer types", {
expect_equal(infer_type(1L), "integer")
expect_equal(infer_type(1.0), "double")
@@ -56,10 +64,8 @@ test_that("infer types", {
expect_equal(infer_type(TRUE), "boolean")
expect_equal(infer_type(as.Date("2015-03-11")), "date")
expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp")
- expect_equal(infer_type(c(1L, 2L)),
- list(type = "array", elementType = "integer", containsNull = TRUE))
- expect_equal(infer_type(list(1L, 2L)),
- list(type = "array", elementType = "integer", containsNull = TRUE))
+ expect_equal(infer_type(c(1L, 2L)), "array<integer>")
+ expect_equal(infer_type(list(1L, 2L)), "array<integer>")
testStruct <- infer_type(list(a = 1L, b = "2"))
expect_equal(class(testStruct), "structType")
checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE)
@@ -236,8 +242,7 @@ test_that("create DataFrame with different data types", {
expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE))
})
-# TODO: enable this test after fix serialization for nested object
-#test_that("create DataFrame with nested array and struct", {
+test_that("create DataFrame with nested array and struct", {
# e <- new.env()
# assign("n", 3L, envir = e)
# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L))
@@ -247,7 +252,32 @@ test_that("create DataFrame with different data types", {
# expect_equal(count(df), 1)
# ldf <- collect(df)
# expect_equal(ldf[1,], l[[1]])
-#})
+
+
+ # ArrayType only for now
+ l <- list(as.list(1:10), list("a", "b"))
+ df <- createDataFrame(sqlContext, list(l), c("a", "b"))
+ expect_equal(dtypes(df), list(c("a", "array<int>"), c("b", "array<string>")))
+ expect_equal(count(df), 1)
+ ldf <- collect(df)
+ expect_equal(names(ldf), c("a", "b"))
+ expect_equal(ldf[1, 1][[1]], l[[1]])
+ expect_equal(ldf[1, 2][[1]], l[[2]])
+})
+
+test_that("Collect DataFrame with complex types", {
+ # only ArrayType now
+ # TODO: tests for StructType and MapType after they are supported
+ df <- jsonFile(sqlContext, complexTypeJsonPath)  # fixture file written earlier in this test file from mockLinesComplexType
+
+ ldf <- collect(df)
+ expect_equal(nrow(ldf), 3)  # one row per JSON line in the fixture
+ expect_equal(ncol(ldf), 3)
+ expect_equal(names(ldf), c("c1", "c2", "c3"))
+ expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9)))  # each array cell is expected to collect as an R list
+ expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i")))
+ expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0)))
+})
test_that("jsonFile() on a local file returns a DataFrame", {
df <- jsonFile(sqlContext, jsonPath)