author     Sun Rui <rui.sun@intel.com>    2015-08-25 13:14:10 -0700
committer  Shivaram Venkataraman <shivaram@cs.berkeley.edu>    2015-08-25 13:14:10 -0700
commit     71a138cd0e0a14e8426f97877e3b52a562bbd02c (patch)
tree       e0d2f675ec969b7a5c24c46414999d16c8fc759e
parent     16a2be1a84c0a274a60c0a584faaf58b55d4942b (diff)
[SPARK-10048] [SPARKR] Support arbitrary nested Java array in serde.
This PR: 1. supports transferring arbitrarily nested arrays from the JVM to the R side in SerDe; 2. building on 1, improves the collect() implementation so that it can collect data of complex types from a DataFrame (a usage sketch follows the file summary below). Author: Sun Rui <rui.sun@intel.com> Closes #8276 from sun-rui/SPARK-10048.
-rw-r--r--  R/pkg/R/DataFrame.R  55
-rw-r--r--  R/pkg/R/deserialize.R  72
-rw-r--r--  R/pkg/R/serialize.R  10
-rw-r--r--  R/pkg/inst/tests/test_Serde.R  77
-rw-r--r--  R/pkg/inst/worker/worker.R  4
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala  7
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/SerDe.scala  86
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala  32
8 files changed, 216 insertions, 127 deletions
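A minimal usage sketch (not part of the patch) of what the improved collect() enables: bringing a column of complex (array) type into a local data.frame. It assumes a Spark 1.5-era SparkR session; the sample data, column names, and the array() SQL expression are illustrative only.

library(SparkR)

sc <- sparkR.init()
sqlContext <- sparkRSQL.init(sc)

# A DataFrame with a primitive column and an array<int> column built via selectExpr.
df <- createDataFrame(sqlContext, data.frame(id = 1:3))
df2 <- selectExpr(df, "id", "array(id, id * 2) AS nums")

# With this patch, primitive columns are collected as atomic vectors while
# complex-typed columns become list columns (each cell is itself a list).
local <- collect(df2)
class(local$id)     # "integer"
class(local$nums)   # "list"
local$nums[[1]]     # the two array elements of the first row, as a list

sparkR.stop()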
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 10f3c4ea59..ae1d912cf6 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -652,18 +652,49 @@ setMethod("dim",
setMethod("collect",
signature(x = "DataFrame"),
function(x, stringsAsFactors = FALSE) {
- # listCols is a list of raw vectors, one per column
- listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
- cols <- lapply(listCols, function(col) {
- objRaw <- rawConnection(col)
- numRows <- readInt(objRaw)
- col <- readCol(objRaw, numRows)
- close(objRaw)
- col
- })
- names(cols) <- columns(x)
- do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors))
- })
+ names <- columns(x)
+ ncol <- length(names)
+ if (ncol <= 0) {
+ # empty data.frame with 0 columns and 0 rows
+ data.frame()
+ } else {
+ # listCols is a list of columns
+ listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
+ stopifnot(length(listCols) == ncol)
+
+ # An empty data.frame with 0 columns but as many rows as were collected
+ nrow <- length(listCols[[1]])
+ if (nrow <= 0) {
+ df <- data.frame()
+ } else {
+ df <- data.frame(row.names = 1 : nrow)
+ }
+
+ # Append columns one by one
+ for (colIndex in 1 : ncol) {
+ # Note: a column holding complex-typed data is appended as a list column,
+ # since a list can hold values of any type. However, indexing a cell in a
+ # list column returns a list rather than a vector, so columns of
+ # non-complex types are appended as atomic vectors instead.
+ col <- listCols[[colIndex]]
+ if (length(col) <= 0) {
+ df[[names[colIndex]]] <- col
+ } else {
+ # TODO: more robust check on column of primitive types
+ vec <- do.call(c, col)
+ if (class(vec) != "list") {
+ df[[names[colIndex]]] <- vec
+ } else {
+ # Columns of complex type need careful access:
+ # retrieving such a column returns a list, and
+ # retrieving a cell from it returns a list rather than a vector.
+ df[[names[colIndex]]] <- col
+ }
+ }
+ }
+ df
+ }
+ })
#' Limit
#'
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index 33bf13ec9e..6cf628e300 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -48,6 +48,7 @@ readTypedObject <- function(con, type) {
"r" = readRaw(con),
"D" = readDate(con),
"t" = readTime(con),
+ "a" = readArray(con),
"l" = readList(con),
"n" = NULL,
"j" = getJobj(readString(con)),
@@ -85,8 +86,7 @@ readTime <- function(con) {
as.POSIXct(t, origin = "1970-01-01")
}
-# We only support lists where all elements are of same type
-readList <- function(con) {
+readArray <- function(con) {
type <- readType(con)
len <- readInt(con)
if (len > 0) {
@@ -100,6 +100,25 @@ readList <- function(con) {
}
}
+# Read a list. Its elements may be of different types.
+# Null objects are read as NA.
+readList <- function(con) {
+ len <- readInt(con)
+ if (len > 0) {
+ l <- vector("list", len)
+ for (i in 1:len) {
+ elem <- readObject(con)
+ if (is.null(elem)) {
+ elem <- NA
+ }
+ l[[i]] <- elem
+ }
+ l
+ } else {
+ list()
+ }
+}
+
readRaw <- function(con) {
dataLen <- readInt(con)
readBin(con, raw(), as.integer(dataLen), endian = "big")
@@ -132,18 +151,19 @@ readDeserialize <- function(con) {
}
}
-readDeserializeRows <- function(inputCon) {
- # readDeserializeRows will deserialize a DataOutputStream composed of
- # a list of lists. Since the DOS is one continuous stream and
- # the number of rows varies, we put the readRow function in a while loop
- # that termintates when the next row is empty.
+readMultipleObjects <- function(inputCon) {
+ # readMultipleObjects reads multiple consecutive objects from
+ # a DataOutputStream. There is no preceding field giving the number
+ # of objects, so we keep reading objects in a loop until the end
+ # of the stream is reached.
data <- list()
while(TRUE) {
- row <- readRow(inputCon)
- if (length(row) == 0) {
+ # When the end of the stream is reached, the returned type is ""
+ type <- readType(inputCon)
+ if (type == "") {
break
}
- data[[length(data) + 1L]] <- row
+ data[[length(data) + 1L]] <- readTypedObject(inputCon, type)
}
data # this is a list of named lists now
}
@@ -155,35 +175,5 @@ readRowList <- function(obj) {
# deserialize the row.
rawObj <- rawConnection(obj, "r+")
on.exit(close(rawObj))
- readRow(rawObj)
-}
-
-readRow <- function(inputCon) {
- numCols <- readInt(inputCon)
- if (length(numCols) > 0 && numCols > 0) {
- lapply(1:numCols, function(x) {
- obj <- readObject(inputCon)
- if (is.null(obj)) {
- NA
- } else {
- obj
- }
- }) # each row is a list now
- } else {
- list()
- }
-}
-
-# Take a single column as Array[Byte] and deserialize it into an atomic vector
-readCol <- function(inputCon, numRows) {
- if (numRows > 0) {
- # sapply can not work with POSIXlt
- do.call(c, lapply(1:numRows, function(x) {
- value <- readObject(inputCon)
- # Replace NULL with NA so we can coerce to vectors
- if (is.null(value)) NA else value
- }))
- } else {
- vector()
- }
+ readObject(rawObj)
}
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 311021e5d8..e3676f57f9 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -110,18 +110,10 @@ writeRowSerialize <- function(outputCon, rows) {
serializeRow <- function(row) {
rawObj <- rawConnection(raw(0), "wb")
on.exit(close(rawObj))
- writeRow(rawObj, row)
+ writeGenericList(rawObj, row)
rawConnectionValue(rawObj)
}
-writeRow <- function(con, row) {
- numCols <- length(row)
- writeInt(con, numCols)
- for (i in 1:numCols) {
- writeObject(con, row[[i]])
- }
-}
-
writeRaw <- function(con, batch) {
writeInt(con, length(batch))
writeBin(batch, con, endian = "big")
diff --git a/R/pkg/inst/tests/test_Serde.R b/R/pkg/inst/tests/test_Serde.R
new file mode 100644
index 0000000000..009db85da2
--- /dev/null
+++ b/R/pkg/inst/tests/test_Serde.R
@@ -0,0 +1,77 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("SerDe functionality")
+
+sc <- sparkR.init()
+
+test_that("SerDe of primitive types", {
+ x <- callJStatic("SparkRHandler", "echo", 1L)
+ expect_equal(x, 1L)
+ expect_equal(class(x), "integer")
+
+ x <- callJStatic("SparkRHandler", "echo", 1)
+ expect_equal(x, 1)
+ expect_equal(class(x), "numeric")
+
+ x <- callJStatic("SparkRHandler", "echo", TRUE)
+ expect_true(x)
+ expect_equal(class(x), "logical")
+
+ x <- callJStatic("SparkRHandler", "echo", "abc")
+ expect_equal(x, "abc")
+ expect_equal(class(x), "character")
+})
+
+test_that("SerDe of list of primitive types", {
+ x <- list(1L, 2L, 3L)
+ y <- callJStatic("SparkRHandler", "echo", x)
+ expect_equal(x, y)
+ expect_equal(class(y[[1]]), "integer")
+
+ x <- list(1, 2, 3)
+ y <- callJStatic("SparkRHandler", "echo", x)
+ expect_equal(x, y)
+ expect_equal(class(y[[1]]), "numeric")
+
+ x <- list(TRUE, FALSE)
+ y <- callJStatic("SparkRHandler", "echo", x)
+ expect_equal(x, y)
+ expect_equal(class(y[[1]]), "logical")
+
+ x <- list("a", "b", "c")
+ y <- callJStatic("SparkRHandler", "echo", x)
+ expect_equal(x, y)
+ expect_equal(class(y[[1]]), "character")
+
+ # Empty list
+ x <- list()
+ y <- callJStatic("SparkRHandler", "echo", x)
+ expect_equal(x, y)
+})
+
+test_that("SerDe of list of lists", {
+ x <- list(list(1L, 2L, 3L), list(1, 2, 3),
+ list(TRUE, FALSE), list("a", "b", "c"))
+ y <- callJStatic("SparkRHandler", "echo", x)
+ expect_equal(x, y)
+
+ # List of empty lists
+ x <- list(list(), list())
+ y <- callJStatic("SparkRHandler", "echo", x)
+ expect_equal(x, y)
+})
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index 7e3b5fc403..0c3b0d1f4b 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -94,7 +94,7 @@ if (isEmpty != 0) {
} else if (deserializer == "string") {
data <- as.list(readLines(inputCon))
} else if (deserializer == "row") {
- data <- SparkR:::readDeserializeRows(inputCon)
+ data <- SparkR:::readMultipleObjects(inputCon)
}
# Timing reading input data for execution
inputElap <- elapsedSecs()
@@ -120,7 +120,7 @@ if (isEmpty != 0) {
} else if (deserializer == "string") {
data <- readLines(inputCon)
} else if (deserializer == "row") {
- data <- SparkR:::readDeserializeRows(inputCon)
+ data <- SparkR:::readMultipleObjects(inputCon)
}
# Timing reading input data for execution
inputElap <- elapsedSecs()
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 6ce02e2ea3..bb82f3285f 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -53,6 +53,13 @@ private[r] class RBackendHandler(server: RBackend)
if (objId == "SparkRHandler") {
methodName match {
+ // This method is for testing purposes only
+ case "echo" =>
+ val args = readArgs(numArgs, dis)
+ assert(numArgs == 1)
+
+ writeInt(dos, 0)
+ writeObject(dos, args(0))
case "stopBackend" =>
writeInt(dos, 0)
writeType(dos, "void")
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index dbbbcf40c1..26ad4f1d46 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -149,6 +149,10 @@ private[spark] object SerDe {
case 'b' => readBooleanArr(dis)
case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x))
case 'r' => readBytesArr(dis)
+ case 'l' => {
+ val len = readInt(dis)
+ (0 until len).map(_ => readList(dis)).toArray
+ }
case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
}
}
@@ -200,6 +204,9 @@ private[spark] object SerDe {
case "date" => dos.writeByte('D')
case "time" => dos.writeByte('t')
case "raw" => dos.writeByte('r')
+ // Array of primitive types
+ case "array" => dos.writeByte('a')
+ // Array of objects
case "list" => dos.writeByte('l')
case "jobj" => dos.writeByte('j')
case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
@@ -211,26 +218,35 @@ private[spark] object SerDe {
writeType(dos, "void")
} else {
value.getClass.getName match {
+ case "java.lang.Character" =>
+ writeType(dos, "character")
+ writeString(dos, value.asInstanceOf[Character].toString)
case "java.lang.String" =>
writeType(dos, "character")
writeString(dos, value.asInstanceOf[String])
- case "long" | "java.lang.Long" =>
+ case "java.lang.Long" =>
writeType(dos, "double")
writeDouble(dos, value.asInstanceOf[Long].toDouble)
- case "float" | "java.lang.Float" =>
+ case "java.lang.Float" =>
writeType(dos, "double")
writeDouble(dos, value.asInstanceOf[Float].toDouble)
- case "decimal" | "java.math.BigDecimal" =>
+ case "java.math.BigDecimal" =>
writeType(dos, "double")
val javaDecimal = value.asInstanceOf[java.math.BigDecimal]
writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble)
- case "double" | "java.lang.Double" =>
+ case "java.lang.Double" =>
writeType(dos, "double")
writeDouble(dos, value.asInstanceOf[Double])
- case "int" | "java.lang.Integer" =>
+ case "java.lang.Byte" =>
+ writeType(dos, "integer")
+ writeInt(dos, value.asInstanceOf[Byte].toInt)
+ case "java.lang.Short" =>
+ writeType(dos, "integer")
+ writeInt(dos, value.asInstanceOf[Short].toInt)
+ case "java.lang.Integer" =>
writeType(dos, "integer")
writeInt(dos, value.asInstanceOf[Int])
- case "boolean" | "java.lang.Boolean" =>
+ case "java.lang.Boolean" =>
writeType(dos, "logical")
writeBoolean(dos, value.asInstanceOf[Boolean])
case "java.sql.Date" =>
@@ -242,43 +258,48 @@ private[spark] object SerDe {
case "java.sql.Timestamp" =>
writeType(dos, "time")
writeTime(dos, value.asInstanceOf[Timestamp])
+
+ // Handle arrays
+
+ // Array of primitive types
+
+ // Special handling for byte array
case "[B" =>
writeType(dos, "raw")
writeBytes(dos, value.asInstanceOf[Array[Byte]])
- // TODO: Types not handled right now include
- // byte, char, short, float
- // Handle arrays
- case "[Ljava.lang.String;" =>
- writeType(dos, "list")
- writeStringArr(dos, value.asInstanceOf[Array[String]])
+ case "[C" =>
+ writeType(dos, "array")
+ writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString))
+ case "[S" =>
+ writeType(dos, "array")
+ writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt))
case "[I" =>
- writeType(dos, "list")
+ writeType(dos, "array")
writeIntArr(dos, value.asInstanceOf[Array[Int]])
case "[J" =>
- writeType(dos, "list")
+ writeType(dos, "array")
writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble))
+ case "[F" =>
+ writeType(dos, "array")
+ writeDoubleArr(dos, value.asInstanceOf[Array[Float]].map(_.toDouble))
case "[D" =>
- writeType(dos, "list")
+ writeType(dos, "array")
writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
case "[Z" =>
- writeType(dos, "list")
+ writeType(dos, "array")
writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
- case "[[B" =>
+
+ // Array of objects; null elements are written with the "void" type
+ case c if c.startsWith("[") =>
writeType(dos, "list")
- writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]])
- case otherName =>
- // Handle array of objects
- if (otherName.startsWith("[L")) {
- val objArr = value.asInstanceOf[Array[Object]]
- writeType(dos, "list")
- writeType(dos, "jobj")
- dos.writeInt(objArr.length)
- objArr.foreach(o => writeJObj(dos, o))
- } else {
- writeType(dos, "jobj")
- writeJObj(dos, value)
- }
+ val array = value.asInstanceOf[Array[Object]]
+ writeInt(dos, array.length)
+ array.foreach(elem => writeObject(dos, elem))
+
+ case _ =>
+ writeType(dos, "jobj")
+ writeJObj(dos, value)
}
}
}
@@ -350,11 +371,6 @@ private[spark] object SerDe {
value.foreach(v => writeString(out, v))
}
- def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = {
- writeType(out, "raw")
- out.writeInt(value.length)
- value.foreach(v => writeBytes(out, v))
- }
}
private[r] object SerializationFormats {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 92861ab038..7f3defec3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -98,27 +98,17 @@ private[r] object SQLUtils {
val bos = new ByteArrayOutputStream()
val dos = new DataOutputStream(bos)
- SerDe.writeInt(dos, row.length)
- (0 until row.length).map { idx =>
- val obj: Object = row(idx).asInstanceOf[Object]
- SerDe.writeObject(dos, obj)
- }
+ val cols = (0 until row.length).map(row(_).asInstanceOf[Object]).toArray
+ SerDe.writeObject(dos, cols)
bos.toByteArray()
}
- def dfToCols(df: DataFrame): Array[Array[Byte]] = {
+ def dfToCols(df: DataFrame): Array[Array[Any]] = {
// localDF is Array[Row]
val localDF = df.collect()
val numCols = df.columns.length
- // dfCols is Array[Array[Any]]
- val dfCols = convertRowsToColumns(localDF, numCols)
-
- dfCols.map { col =>
- colToRBytes(col)
- }
- }
- def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = {
+ // result is Array[Array[Any]]
(0 until numCols).map { colIdx =>
localDF.map { row =>
row(colIdx)
@@ -126,20 +116,6 @@ private[r] object SQLUtils {
}.toArray
}
- def colToRBytes(col: Array[Any]): Array[Byte] = {
- val numRows = col.length
- val bos = new ByteArrayOutputStream()
- val dos = new DataOutputStream(bos)
-
- SerDe.writeInt(dos, numRows)
-
- col.map { item =>
- val obj: Object = item.asInstanceOf[Object]
- SerDe.writeObject(dos, obj)
- }
- bos.toByteArray()
- }
-
def saveMode(mode: String): SaveMode = {
mode match {
case "append" => SaveMode.Append