aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorSun Rui <rui.sun@intel.com>2015-08-25 13:14:10 -0700
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2015-08-25 13:14:10 -0700
commit71a138cd0e0a14e8426f97877e3b52a562bbd02c (patch)
treee0d2f675ec969b7a5c24c46414999d16c8fc759e /R
parent16a2be1a84c0a274a60c0a584faaf58b55d4942b (diff)
downloadspark-71a138cd0e0a14e8426f97877e3b52a562bbd02c.tar.gz
spark-71a138cd0e0a14e8426f97877e3b52a562bbd02c.tar.bz2
spark-71a138cd0e0a14e8426f97877e3b52a562bbd02c.zip
[SPARK-10048] [SPARKR] Support arbitrary nested Java array in serde.
This PR: 1. supports transferring arbitrary nested array from JVM to R side in SerDe; 2. based on 1, collect() implemenation is improved. Now it can support collecting data of complex types from a DataFrame. Author: Sun Rui <rui.sun@intel.com> Closes #8276 from sun-rui/SPARK-10048.
Diffstat (limited to 'R')
-rw-r--r--R/pkg/R/DataFrame.R55
-rw-r--r--R/pkg/R/deserialize.R72
-rw-r--r--R/pkg/R/serialize.R10
-rw-r--r--R/pkg/inst/tests/test_Serde.R77
-rw-r--r--R/pkg/inst/worker/worker.R4
5 files changed, 154 insertions, 64 deletions
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 10f3c4ea59..ae1d912cf6 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -652,18 +652,49 @@ setMethod("dim",
setMethod("collect",
signature(x = "DataFrame"),
function(x, stringsAsFactors = FALSE) {
- # listCols is a list of raw vectors, one per column
- listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
- cols <- lapply(listCols, function(col) {
- objRaw <- rawConnection(col)
- numRows <- readInt(objRaw)
- col <- readCol(objRaw, numRows)
- close(objRaw)
- col
- })
- names(cols) <- columns(x)
- do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors))
- })
+ names <- columns(x)
+ ncol <- length(names)
+ if (ncol <= 0) {
+ # empty data.frame with 0 columns and 0 rows
+ data.frame()
+ } else {
+ # listCols is a list of columns
+ listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
+ stopifnot(length(listCols) == ncol)
+
+ # An empty data.frame with 0 columns and number of rows as collected
+ nrow <- length(listCols[[1]])
+ if (nrow <= 0) {
+ df <- data.frame()
+ } else {
+ df <- data.frame(row.names = 1 : nrow)
+ }
+
+ # Append columns one by one
+ for (colIndex in 1 : ncol) {
+ # Note: appending a column of list type into a data.frame so that
+ # data of complex type can be held. But getting a cell from a column
+ # of list type returns a list instead of a vector. So for columns of
+ # non-complex type, append them as vector.
+ col <- listCols[[colIndex]]
+ if (length(col) <= 0) {
+ df[[names[colIndex]]] <- col
+ } else {
+ # TODO: more robust check on column of primitive types
+ vec <- do.call(c, col)
+ if (class(vec) != "list") {
+ df[[names[colIndex]]] <- vec
+ } else {
+ # For columns of complex type, be careful to access them.
+ # Get a column of complex type returns a list.
+ # Get a cell from a column of complex type returns a list instead of a vector.
+ df[[names[colIndex]]] <- col
+ }
+ }
+ }
+ df
+ }
+ })
#' Limit
#'
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index 33bf13ec9e..6cf628e300 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -48,6 +48,7 @@ readTypedObject <- function(con, type) {
"r" = readRaw(con),
"D" = readDate(con),
"t" = readTime(con),
+ "a" = readArray(con),
"l" = readList(con),
"n" = NULL,
"j" = getJobj(readString(con)),
@@ -85,8 +86,7 @@ readTime <- function(con) {
as.POSIXct(t, origin = "1970-01-01")
}
-# We only support lists where all elements are of same type
-readList <- function(con) {
+readArray <- function(con) {
type <- readType(con)
len <- readInt(con)
if (len > 0) {
@@ -100,6 +100,25 @@ readList <- function(con) {
}
}
+# Read a list. Types of each element may be different.
+# Null objects are read as NA.
+readList <- function(con) {
+ len <- readInt(con)
+ if (len > 0) {
+ l <- vector("list", len)
+ for (i in 1:len) {
+ elem <- readObject(con)
+ if (is.null(elem)) {
+ elem <- NA
+ }
+ l[[i]] <- elem
+ }
+ l
+ } else {
+ list()
+ }
+}
+
readRaw <- function(con) {
dataLen <- readInt(con)
readBin(con, raw(), as.integer(dataLen), endian = "big")
@@ -132,18 +151,19 @@ readDeserialize <- function(con) {
}
}
-readDeserializeRows <- function(inputCon) {
- # readDeserializeRows will deserialize a DataOutputStream composed of
- # a list of lists. Since the DOS is one continuous stream and
- # the number of rows varies, we put the readRow function in a while loop
- # that termintates when the next row is empty.
+readMultipleObjects <- function(inputCon) {
+ # readMultipleObjects will read multiple continuous objects from
+ # a DataOutputStream. There is no preceding field telling the count
+ # of the objects, so the number of objects varies, we try to read
+ # all objects in a loop until the end of the stream.
data <- list()
while(TRUE) {
- row <- readRow(inputCon)
- if (length(row) == 0) {
+ # If reaching the end of the stream, type returned should be "".
+ type <- readType(inputCon)
+ if (type == "") {
break
}
- data[[length(data) + 1L]] <- row
+ data[[length(data) + 1L]] <- readTypedObject(inputCon, type)
}
data # this is a list of named lists now
}
@@ -155,35 +175,5 @@ readRowList <- function(obj) {
# deserialize the row.
rawObj <- rawConnection(obj, "r+")
on.exit(close(rawObj))
- readRow(rawObj)
-}
-
-readRow <- function(inputCon) {
- numCols <- readInt(inputCon)
- if (length(numCols) > 0 && numCols > 0) {
- lapply(1:numCols, function(x) {
- obj <- readObject(inputCon)
- if (is.null(obj)) {
- NA
- } else {
- obj
- }
- }) # each row is a list now
- } else {
- list()
- }
-}
-
-# Take a single column as Array[Byte] and deserialize it into an atomic vector
-readCol <- function(inputCon, numRows) {
- if (numRows > 0) {
- # sapply can not work with POSIXlt
- do.call(c, lapply(1:numRows, function(x) {
- value <- readObject(inputCon)
- # Replace NULL with NA so we can coerce to vectors
- if (is.null(value)) NA else value
- }))
- } else {
- vector()
- }
+ readObject(rawObj)
}
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 311021e5d8..e3676f57f9 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -110,18 +110,10 @@ writeRowSerialize <- function(outputCon, rows) {
serializeRow <- function(row) {
rawObj <- rawConnection(raw(0), "wb")
on.exit(close(rawObj))
- writeRow(rawObj, row)
+ writeGenericList(rawObj, row)
rawConnectionValue(rawObj)
}
-writeRow <- function(con, row) {
- numCols <- length(row)
- writeInt(con, numCols)
- for (i in 1:numCols) {
- writeObject(con, row[[i]])
- }
-}
-
writeRaw <- function(con, batch) {
writeInt(con, length(batch))
writeBin(batch, con, endian = "big")
diff --git a/R/pkg/inst/tests/test_Serde.R b/R/pkg/inst/tests/test_Serde.R
new file mode 100644
index 0000000000..009db85da2
--- /dev/null
+++ b/R/pkg/inst/tests/test_Serde.R
@@ -0,0 +1,77 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("SerDe functionality")
+
+sc <- sparkR.init()
+
+test_that("SerDe of primitive types", {
+ x <- callJStatic("SparkRHandler", "echo", 1L)
+ expect_equal(x, 1L)
+ expect_equal(class(x), "integer")
+
+ x <- callJStatic("SparkRHandler", "echo", 1)
+ expect_equal(x, 1)
+ expect_equal(class(x), "numeric")
+
+ x <- callJStatic("SparkRHandler", "echo", TRUE)
+ expect_true(x)
+ expect_equal(class(x), "logical")
+
+ x <- callJStatic("SparkRHandler", "echo", "abc")
+ expect_equal(x, "abc")
+ expect_equal(class(x), "character")
+})
+
+test_that("SerDe of list of primitive types", {
+ x <- list(1L, 2L, 3L)
+ y <- callJStatic("SparkRHandler", "echo", x)
+ expect_equal(x, y)
+ expect_equal(class(y[[1]]), "integer")
+
+ x <- list(1, 2, 3)
+ y <- callJStatic("SparkRHandler", "echo", x)
+ expect_equal(x, y)
+ expect_equal(class(y[[1]]), "numeric")
+
+ x <- list(TRUE, FALSE)
+ y <- callJStatic("SparkRHandler", "echo", x)
+ expect_equal(x, y)
+ expect_equal(class(y[[1]]), "logical")
+
+ x <- list("a", "b", "c")
+ y <- callJStatic("SparkRHandler", "echo", x)
+ expect_equal(x, y)
+ expect_equal(class(y[[1]]), "character")
+
+ # Empty list
+ x <- list()
+ y <- callJStatic("SparkRHandler", "echo", x)
+ expect_equal(x, y)
+})
+
+test_that("SerDe of list of lists", {
+ x <- list(list(1L, 2L, 3L), list(1, 2, 3),
+ list(TRUE, FALSE), list("a", "b", "c"))
+ y <- callJStatic("SparkRHandler", "echo", x)
+ expect_equal(x, y)
+
+ # List of empty lists
+ x <- list(list(), list())
+ y <- callJStatic("SparkRHandler", "echo", x)
+ expect_equal(x, y)
+})
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index 7e3b5fc403..0c3b0d1f4b 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -94,7 +94,7 @@ if (isEmpty != 0) {
} else if (deserializer == "string") {
data <- as.list(readLines(inputCon))
} else if (deserializer == "row") {
- data <- SparkR:::readDeserializeRows(inputCon)
+ data <- SparkR:::readMultipleObjects(inputCon)
}
# Timing reading input data for execution
inputElap <- elapsedSecs()
@@ -120,7 +120,7 @@ if (isEmpty != 0) {
} else if (deserializer == "string") {
data <- readLines(inputCon)
} else if (deserializer == "row") {
- data <- SparkR:::readDeserializeRows(inputCon)
+ data <- SparkR:::readMultipleObjects(inputCon)
}
# Timing reading input data for execution
inputElap <- elapsedSecs()