From 45e3be5c138d983f40f619735d60bf7eb78c9bf0 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Thu, 10 Sep 2015 12:21:13 -0700 Subject: [SPARK-10049] [SPARKR] Support collecting data of ArraryType in DataFrame. this PR : 1. Enhance reflection in RBackend. Automatically matching a Java array to Scala Seq when finding methods. Util functions like seq(), listToSeq() in R side can be removed, as they will conflict with the Serde logic that transferrs a Scala seq to R side. 2. Enhance the SerDe to support transferring a Scala seq to R side. Data of ArrayType in DataFrame after collection is observed to be of Scala Seq type. 3. Support ArrayType in createDataFrame(). Author: Sun Rui Closes #8458 from sun-rui/SPARK-10049. --- R/pkg/R/DataFrame.R | 26 +++++++++---------- R/pkg/R/SQLContext.R | 4 +-- R/pkg/R/column.R | 3 +-- R/pkg/R/functions.R | 12 ++++----- R/pkg/R/group.R | 4 +-- R/pkg/R/schema.R | 54 +++++++++++++++++++++++++--------------- R/pkg/R/utils.R | 10 -------- R/pkg/inst/tests/test_sparkSQL.R | 44 ++++++++++++++++++++++++++------ 8 files changed, 95 insertions(+), 62 deletions(-) (limited to 'R') diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 8a00238b41..c3c1893487 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -271,7 +271,7 @@ setMethod("names<-", signature(x = "DataFrame"), function(x, value) { if (!is.null(value)) { - sdf <- callJMethod(x@sdf, "toDF", listToSeq(as.list(value))) + sdf <- callJMethod(x@sdf, "toDF", as.list(value)) dataFrame(sdf) } }) @@ -843,10 +843,10 @@ setMethod("groupBy", function(x, ...) { cols <- list(...) if (length(cols) >= 1 && class(cols[[1]]) == "character") { - sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1])) + sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], cols[-1]) } else { jcol <- lapply(cols, function(c) { c@jc }) - sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol)) + sgd <- callJMethod(x@sdf, "groupBy", jcol) } groupedData(sgd) }) @@ -1079,7 +1079,7 @@ setMethod("subset", signature(x = "DataFrame"), #' } setMethod("select", signature(x = "DataFrame", col = "character"), function(x, col, ...) { - sdf <- callJMethod(x@sdf, "select", col, toSeq(...)) + sdf <- callJMethod(x@sdf, "select", col, list(...)) dataFrame(sdf) }) @@ -1090,7 +1090,7 @@ setMethod("select", signature(x = "DataFrame", col = "Column"), jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "select", listToSeq(jcols)) + sdf <- callJMethod(x@sdf, "select", jcols) dataFrame(sdf) }) @@ -1106,7 +1106,7 @@ setMethod("select", col(c)@jc } }) - sdf <- callJMethod(x@sdf, "select", listToSeq(cols)) + sdf <- callJMethod(x@sdf, "select", cols) dataFrame(sdf) }) @@ -1133,7 +1133,7 @@ setMethod("selectExpr", signature(x = "DataFrame", expr = "character"), function(x, expr, ...) { exprList <- list(expr, ...) - sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList)) + sdf <- callJMethod(x@sdf, "selectExpr", exprList) dataFrame(sdf) }) @@ -1311,12 +1311,12 @@ setMethod("arrange", signature(x = "DataFrame", col = "characterOrColumn"), function(x, col, ...) { if (class(col) == "character") { - sdf <- callJMethod(x@sdf, "sort", col, toSeq(...)) + sdf <- callJMethod(x@sdf, "sort", col, list(...)) } else if (class(col) == "Column") { jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols)) + sdf <- callJMethod(x@sdf, "sort", jcols) } dataFrame(sdf) }) @@ -1664,7 +1664,7 @@ setMethod("describe", signature(x = "DataFrame", col = "character"), function(x, col, ...) { colList <- list(col, ...) - sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) + sdf <- callJMethod(x@sdf, "describe", colList) dataFrame(sdf) }) @@ -1674,7 +1674,7 @@ setMethod("describe", signature(x = "DataFrame"), function(x) { colList <- as.list(c(columns(x))) - sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) + sdf <- callJMethod(x@sdf, "describe", colList) dataFrame(sdf) }) @@ -1731,7 +1731,7 @@ setMethod("dropna", naFunctions <- callJMethod(x@sdf, "na") sdf <- callJMethod(naFunctions, "drop", - as.integer(minNonNulls), listToSeq(as.list(cols))) + as.integer(minNonNulls), as.list(cols)) dataFrame(sdf) }) @@ -1815,7 +1815,7 @@ setMethod("fillna", sdf <- if (length(cols) == 0) { callJMethod(naFunctions, "fill", value) } else { - callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols))) + callJMethod(naFunctions, "fill", value, as.list(cols)) } dataFrame(sdf) }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 1bc6445311..4ac057d0f2 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -49,7 +49,7 @@ infer_type <- function(x) { stopifnot(length(x) > 0) names <- names(x) if (is.null(names)) { - list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE) + paste0("array<", infer_type(x[[1]]), ">") } else { # StructType types <- lapply(x, infer_type) @@ -59,7 +59,7 @@ infer_type <- function(x) { do.call(structType, fields) } } else if (length(x) > 1) { - list(type = "array", elementType = type, containsNull = TRUE) + paste0("array<", infer_type(x[[1]]), ">") } else { type } diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 4805096f3f..42e9d12179 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -211,8 +211,7 @@ setMethod("cast", setMethod("%in%", signature(x = "Column"), function(x, table) { - table <- listToSeq(as.list(table)) - jc <- callJMethod(x@jc, "in", table) + jc <- callJMethod(x@jc, "in", as.list(table)) return(column(jc)) }) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index d848730e70..94687edb05 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1331,7 +1331,7 @@ setMethod("countDistinct", x@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, - listToSeq(jcol)) + jcol) column(jc) }) @@ -1348,7 +1348,7 @@ setMethod("concat", signature(x = "Column"), function(x, ...) { jcols <- lapply(list(x, ...), function(x) { x@jc }) - jc <- callJStatic("org.apache.spark.sql.functions", "concat", listToSeq(jcols)) + jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols) column(jc) }) @@ -1366,7 +1366,7 @@ setMethod("greatest", function(x, ...) { stopifnot(length(list(...)) > 0) jcols <- lapply(list(x, ...), function(x) { x@jc }) - jc <- callJStatic("org.apache.spark.sql.functions", "greatest", listToSeq(jcols)) + jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols) column(jc) }) @@ -1384,7 +1384,7 @@ setMethod("least", function(x, ...) { stopifnot(length(list(...)) > 0) jcols <- lapply(list(x, ...), function(x) { x@jc }) - jc <- callJStatic("org.apache.spark.sql.functions", "least", listToSeq(jcols)) + jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols) column(jc) }) @@ -1675,7 +1675,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @export setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { - jcols <- listToSeq(lapply(list(x, ...), function(x) { x@jc })) + jcols <- lapply(list(x, ...), function(x) { x@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "concat_ws", sep, jcols) column(jc) }) @@ -1723,7 +1723,7 @@ setMethod("expr", signature(x = "character"), #' @export setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { - jcols <- listToSeq(lapply(list(x, ...), function(arg) { arg@jc })) + jcols <- lapply(list(x, ...), function(arg) { arg@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "format_string", format, jcols) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 576ac72f40..4cab1a69f6 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -102,7 +102,7 @@ setMethod("agg", } } jcols <- lapply(cols, function(c) { c@jc }) - sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1])) + sdf <- callJMethod(x@sgd, "agg", jcols[[1]], jcols[-1]) } else { stop("agg can only support Column or character") } @@ -124,7 +124,7 @@ createMethod <- function(name) { setMethod(name, signature(x = "GroupedData"), function(x, ...) { - sdf <- callJMethod(x@sgd, name, toSeq(...)) + sdf <- callJMethod(x@sgd, name, list(...)) dataFrame(sdf) }) } diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 79c744ef29..62d4f73878 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -56,7 +56,7 @@ structType.structField <- function(x, ...) { }) stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createStructType", - listToSeq(sfObjList)) + sfObjList) structType(stObj) } @@ -114,6 +114,35 @@ structField.jobj <- function(x) { obj } +checkType <- function(type) { + primtiveTypes <- c("byte", + "integer", + "float", + "double", + "numeric", + "character", + "string", + "binary", + "raw", + "logical", + "boolean", + "timestamp", + "date") + if (type %in% primtiveTypes) { + return() + } else { + m <- regexec("^array<(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + elemType <- matchedStrings[[1]][2] + checkType(elemType) + return() + } + } + + stop(paste("Unsupported type for Dataframe:", type)) +} + structField.character <- function(x, type, nullable = TRUE) { if (class(x) != "character") { stop("Field name must be a string.") @@ -124,28 +153,13 @@ structField.character <- function(x, type, nullable = TRUE) { if (class(nullable) != "logical") { stop("nullable must be either TRUE or FALSE") } - options <- c("byte", - "integer", - "float", - "double", - "numeric", - "character", - "string", - "binary", - "raw", - "logical", - "boolean", - "timestamp", - "date") - dataType <- if (type %in% options) { - type - } else { - stop(paste("Unsupported type for Dataframe:", type)) - } + + checkType(type) + sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createStructField", x, - dataType, + type, nullable) structField(sfObj) } diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 3babcb5193..69a2bc728f 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -361,16 +361,6 @@ numToInt <- function(num) { as.integer(num) } -# create a Seq in JVM -toSeq <- function(...) { - callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...)) -} - -# create a Seq in JVM from a list -listToSeq <- function(l) { - callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l) -} - # Utility function to recursively traverse the Abstract Syntax Tree (AST) of a # user defined function (UDF), and to examine variables in the UDF to decide # if their values should be included in the new function environment. diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 6d331f9883..1ccfde5917 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -49,6 +49,14 @@ mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLinesNa, jsonPathNa) +# For test complex types in DataFrame +mockLinesComplexType <- + c("{\"c1\":[1, 2, 3], \"c2\":[\"a\", \"b\", \"c\"], \"c3\":[1.0, 2.0, 3.0]}", + "{\"c1\":[4, 5, 6], \"c2\":[\"d\", \"e\", \"f\"], \"c3\":[4.0, 5.0, 6.0]}", + "{\"c1\":[7, 8, 9], \"c2\":[\"g\", \"h\", \"i\"], \"c3\":[7.0, 8.0, 9.0]}") +complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesComplexType, complexTypeJsonPath) + test_that("infer types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") @@ -56,10 +64,8 @@ test_that("infer types", { expect_equal(infer_type(TRUE), "boolean") expect_equal(infer_type(as.Date("2015-03-11")), "date") expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") - expect_equal(infer_type(c(1L, 2L)), - list(type = "array", elementType = "integer", containsNull = TRUE)) - expect_equal(infer_type(list(1L, 2L)), - list(type = "array", elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(c(1L, 2L)), "array") + expect_equal(infer_type(list(1L, 2L)), "array") testStruct <- infer_type(list(a = 1L, b = "2")) expect_equal(class(testStruct), "structType") checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) @@ -236,8 +242,7 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) -# TODO: enable this test after fix serialization for nested object -#test_that("create DataFrame with nested array and struct", { +test_that("create DataFrame with nested array and struct", { # e <- new.env() # assign("n", 3L, envir = e) # l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) @@ -247,7 +252,32 @@ test_that("create DataFrame with different data types", { # expect_equal(count(df), 1) # ldf <- collect(df) # expect_equal(ldf[1,], l[[1]]) -#}) + + + # ArrayType only for now + l <- list(as.list(1:10), list("a", "b")) + df <- createDataFrame(sqlContext, list(l), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"))) + expect_equal(count(df), 1) + ldf <- collect(df) + expect_equal(names(ldf), c("a", "b")) + expect_equal(ldf[1, 1][[1]], l[[1]]) + expect_equal(ldf[1, 2][[1]], l[[2]]) +}) + +test_that("Collect DataFrame with complex types", { + # only ArrayType now + # TODO: tests for StructType and MapType after they are supported + df <- jsonFile(sqlContext, complexTypeJsonPath) + + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 3) + expect_equal(names(ldf), c("c1", "c2", "c3")) + expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) + expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) + expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) +}) test_that("jsonFile() on a local file returns a DataFrame", { df <- jsonFile(sqlContext, jsonPath) -- cgit v1.2.3