author     Sun Rui <rui.sun@intel.com>                        2015-09-10 12:21:13 -0700
committer  Shivaram Venkataraman <shivaram@cs.berkeley.edu>   2015-09-10 12:21:13 -0700
commit     45e3be5c138d983f40f619735d60bf7eb78c9bf0 (patch)
tree       30b7b90f53eadee901a56e0e2e84222e21cf6c44
parent     d88abb7e212fb55f9b0398a0f76a753c86b85cf1 (diff)
[SPARK-10049] [SPARKR] Support collecting data of ArrayType in DataFrame.

This PR:
1. Enhances reflection in RBackend: a Java array argument is automatically matched to a Scala Seq parameter when resolving methods. R-side utility functions such as seq() and listToSeq() can therefore be removed, as they would conflict with the SerDe logic that transfers a Scala Seq to the R side.
2. Enhances the SerDe to support transferring a Scala Seq to the R side. Data of ArrayType collected from a DataFrame is observed to be of Scala Seq type.
3. Supports ArrayType in createDataFrame().

Author: Sun Rui <rui.sun@intel.com>

Closes #8458 from sun-rui/SPARK-10049.
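To make the user-visible effect concrete, here is a hedged R sketch mirroring the tests added in this patch (it assumes a running SparkR session with a sqlContext; the file path and column names are illustrative only):

# Collecting an ArrayType column now yields list columns on the R side.
path <- tempfile(pattern = "sparkr-example", fileext = ".json")
writeLines(c('{"c1":[1, 2, 3]}', '{"c1":[4, 5, 6]}'), path)
df <- jsonFile(sqlContext, path)
ldf <- collect(df)
ldf$c1[[1]]    # list(1, 2, 3)

# createDataFrame() now accepts list columns, which become ArrayType columns.
rows <- list(list(as.list(1:10), list("a", "b")))
df2 <- createDataFrame(sqlContext, rows, c("a", "b"))
dtypes(df2)    # list(c("a", "array<int>"), c("b", "array<string>"))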
-rw-r--r--  R/pkg/R/DataFrame.R                                                  26
-rw-r--r--  R/pkg/R/SQLContext.R                                                  4
-rw-r--r--  R/pkg/R/column.R                                                      3
-rw-r--r--  R/pkg/R/functions.R                                                  12
-rw-r--r--  R/pkg/R/group.R                                                       4
-rw-r--r--  R/pkg/R/schema.R                                                     54
-rw-r--r--  R/pkg/R/utils.R                                                      10
-rw-r--r--  R/pkg/inst/tests/test_sparkSQL.R                                     44
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala   121
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/SerDe.scala             109
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala   14
11 files changed, 250 insertions(+), 151 deletions(-)
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 8a00238b41..c3c1893487 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -271,7 +271,7 @@ setMethod("names<-",
signature(x = "DataFrame"),
function(x, value) {
if (!is.null(value)) {
- sdf <- callJMethod(x@sdf, "toDF", listToSeq(as.list(value)))
+ sdf <- callJMethod(x@sdf, "toDF", as.list(value))
dataFrame(sdf)
}
})
@@ -843,10 +843,10 @@ setMethod("groupBy",
function(x, ...) {
cols <- list(...)
if (length(cols) >= 1 && class(cols[[1]]) == "character") {
- sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1]))
+ sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], cols[-1])
} else {
jcol <- lapply(cols, function(c) { c@jc })
- sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol))
+ sgd <- callJMethod(x@sdf, "groupBy", jcol)
}
groupedData(sgd)
})
@@ -1079,7 +1079,7 @@ setMethod("subset", signature(x = "DataFrame"),
#' }
setMethod("select", signature(x = "DataFrame", col = "character"),
function(x, col, ...) {
- sdf <- callJMethod(x@sdf, "select", col, toSeq(...))
+ sdf <- callJMethod(x@sdf, "select", col, list(...))
dataFrame(sdf)
})
@@ -1090,7 +1090,7 @@ setMethod("select", signature(x = "DataFrame", col = "Column"),
jcols <- lapply(list(col, ...), function(c) {
c@jc
})
- sdf <- callJMethod(x@sdf, "select", listToSeq(jcols))
+ sdf <- callJMethod(x@sdf, "select", jcols)
dataFrame(sdf)
})
@@ -1106,7 +1106,7 @@ setMethod("select",
col(c)@jc
}
})
- sdf <- callJMethod(x@sdf, "select", listToSeq(cols))
+ sdf <- callJMethod(x@sdf, "select", cols)
dataFrame(sdf)
})
@@ -1133,7 +1133,7 @@ setMethod("selectExpr",
signature(x = "DataFrame", expr = "character"),
function(x, expr, ...) {
exprList <- list(expr, ...)
- sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList))
+ sdf <- callJMethod(x@sdf, "selectExpr", exprList)
dataFrame(sdf)
})
@@ -1311,12 +1311,12 @@ setMethod("arrange",
signature(x = "DataFrame", col = "characterOrColumn"),
function(x, col, ...) {
if (class(col) == "character") {
- sdf <- callJMethod(x@sdf, "sort", col, toSeq(...))
+ sdf <- callJMethod(x@sdf, "sort", col, list(...))
} else if (class(col) == "Column") {
jcols <- lapply(list(col, ...), function(c) {
c@jc
})
- sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols))
+ sdf <- callJMethod(x@sdf, "sort", jcols)
}
dataFrame(sdf)
})
@@ -1664,7 +1664,7 @@ setMethod("describe",
signature(x = "DataFrame", col = "character"),
function(x, col, ...) {
colList <- list(col, ...)
- sdf <- callJMethod(x@sdf, "describe", listToSeq(colList))
+ sdf <- callJMethod(x@sdf, "describe", colList)
dataFrame(sdf)
})
@@ -1674,7 +1674,7 @@ setMethod("describe",
signature(x = "DataFrame"),
function(x) {
colList <- as.list(c(columns(x)))
- sdf <- callJMethod(x@sdf, "describe", listToSeq(colList))
+ sdf <- callJMethod(x@sdf, "describe", colList)
dataFrame(sdf)
})
@@ -1731,7 +1731,7 @@ setMethod("dropna",
naFunctions <- callJMethod(x@sdf, "na")
sdf <- callJMethod(naFunctions, "drop",
- as.integer(minNonNulls), listToSeq(as.list(cols)))
+ as.integer(minNonNulls), as.list(cols))
dataFrame(sdf)
})
@@ -1815,7 +1815,7 @@ setMethod("fillna",
sdf <- if (length(cols) == 0) {
callJMethod(naFunctions, "fill", value)
} else {
- callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols)))
+ callJMethod(naFunctions, "fill", value, as.list(cols))
}
dataFrame(sdf)
})
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 1bc6445311..4ac057d0f2 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -49,7 +49,7 @@ infer_type <- function(x) {
stopifnot(length(x) > 0)
names <- names(x)
if (is.null(names)) {
- list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE)
+ paste0("array<", infer_type(x[[1]]), ">")
} else {
# StructType
types <- lapply(x, infer_type)
@@ -59,7 +59,7 @@ infer_type <- function(x) {
do.call(structType, fields)
}
} else if (length(x) > 1) {
- list(type = "array", elementType = type, containsNull = TRUE)
+ paste0("array<", infer_type(x[[1]]), ">")
} else {
type
}
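A hedged note on the infer_type() change above: array types are now reported as type strings rather than nested lists, and unnamed nested lists recurse. The first two lines mirror the updated tests; the third is an assumption that follows from the recursion:

infer_type(c(1L, 2L))             # "array<integer>"
infer_type(list(1L, 2L))          # "array<integer>"
infer_type(list(list(1L, 2L)))    # "array<array<integer>>"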
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 4805096f3f..42e9d12179 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -211,8 +211,7 @@ setMethod("cast",
setMethod("%in%",
signature(x = "Column"),
function(x, table) {
- table <- listToSeq(as.list(table))
- jc <- callJMethod(x@jc, "in", table)
+ jc <- callJMethod(x@jc, "in", as.list(table))
return(column(jc))
})
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index d848730e70..94687edb05 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -1331,7 +1331,7 @@ setMethod("countDistinct",
x@jc
})
jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc,
- listToSeq(jcol))
+ jcol)
column(jc)
})
@@ -1348,7 +1348,7 @@ setMethod("concat",
signature(x = "Column"),
function(x, ...) {
jcols <- lapply(list(x, ...), function(x) { x@jc })
- jc <- callJStatic("org.apache.spark.sql.functions", "concat", listToSeq(jcols))
+ jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols)
column(jc)
})
@@ -1366,7 +1366,7 @@ setMethod("greatest",
function(x, ...) {
stopifnot(length(list(...)) > 0)
jcols <- lapply(list(x, ...), function(x) { x@jc })
- jc <- callJStatic("org.apache.spark.sql.functions", "greatest", listToSeq(jcols))
+ jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols)
column(jc)
})
@@ -1384,7 +1384,7 @@ setMethod("least",
function(x, ...) {
stopifnot(length(list(...)) > 0)
jcols <- lapply(list(x, ...), function(x) { x@jc })
- jc <- callJStatic("org.apache.spark.sql.functions", "least", listToSeq(jcols))
+ jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols)
column(jc)
})
@@ -1675,7 +1675,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"),
#' @export
setMethod("concat_ws", signature(sep = "character", x = "Column"),
function(sep, x, ...) {
- jcols <- listToSeq(lapply(list(x, ...), function(x) { x@jc }))
+ jcols <- lapply(list(x, ...), function(x) { x@jc })
jc <- callJStatic("org.apache.spark.sql.functions", "concat_ws", sep, jcols)
column(jc)
})
@@ -1723,7 +1723,7 @@ setMethod("expr", signature(x = "character"),
#' @export
setMethod("format_string", signature(format = "character", x = "Column"),
function(format, x, ...) {
- jcols <- listToSeq(lapply(list(x, ...), function(arg) { arg@jc }))
+ jcols <- lapply(list(x, ...), function(arg) { arg@jc })
jc <- callJStatic("org.apache.spark.sql.functions",
"format_string",
format, jcols)
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 576ac72f40..4cab1a69f6 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -102,7 +102,7 @@ setMethod("agg",
}
}
jcols <- lapply(cols, function(c) { c@jc })
- sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1]))
+ sdf <- callJMethod(x@sgd, "agg", jcols[[1]], jcols[-1])
} else {
stop("agg can only support Column or character")
}
@@ -124,7 +124,7 @@ createMethod <- function(name) {
setMethod(name,
signature(x = "GroupedData"),
function(x, ...) {
- sdf <- callJMethod(x@sgd, name, toSeq(...))
+ sdf <- callJMethod(x@sgd, name, list(...))
dataFrame(sdf)
})
}
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index 79c744ef29..62d4f73878 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -56,7 +56,7 @@ structType.structField <- function(x, ...) {
})
stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
"createStructType",
- listToSeq(sfObjList))
+ sfObjList)
structType(stObj)
}
@@ -114,6 +114,35 @@ structField.jobj <- function(x) {
obj
}
+checkType <- function(type) {
+ primtiveTypes <- c("byte",
+ "integer",
+ "float",
+ "double",
+ "numeric",
+ "character",
+ "string",
+ "binary",
+ "raw",
+ "logical",
+ "boolean",
+ "timestamp",
+ "date")
+ if (type %in% primtiveTypes) {
+ return()
+ } else {
+ m <- regexec("^array<(.*)>$", type)
+ matchedStrings <- regmatches(type, m)
+ if (length(matchedStrings[[1]]) >= 2) {
+ elemType <- matchedStrings[[1]][2]
+ checkType(elemType)
+ return()
+ }
+ }
+
+ stop(paste("Unsupported type for Dataframe:", type))
+}
+
structField.character <- function(x, type, nullable = TRUE) {
if (class(x) != "character") {
stop("Field name must be a string.")
@@ -124,28 +153,13 @@ structField.character <- function(x, type, nullable = TRUE) {
if (class(nullable) != "logical") {
stop("nullable must be either TRUE or FALSE")
}
- options <- c("byte",
- "integer",
- "float",
- "double",
- "numeric",
- "character",
- "string",
- "binary",
- "raw",
- "logical",
- "boolean",
- "timestamp",
- "date")
- dataType <- if (type %in% options) {
- type
- } else {
- stop(paste("Unsupported type for Dataframe:", type))
- }
+
+ checkType(type)
+
sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
"createStructField",
x,
- dataType,
+ type,
nullable)
structField(sfObj)
}
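With the checkType() helper added above, structField() accepts nested "array<...>" type strings in addition to the primitive type names. A hedged sketch (assumes a running SparkR session; the field names are illustrative):

# Array element types are validated recursively, so nesting is allowed.
f1 <- structField("scores", "array<double>", TRUE)
f2 <- structField("tags", "array<array<string>>", TRUE)
schema <- structType(f1, f2)
# An unsupported element type still fails fast:
# structField("bad", "array<foo>", TRUE)   # error: Unsupported type for Dataframe: foo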
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 3babcb5193..69a2bc728f 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -361,16 +361,6 @@ numToInt <- function(num) {
as.integer(num)
}
-# create a Seq in JVM
-toSeq <- function(...) {
- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...))
-}
-
-# create a Seq in JVM from a list
-listToSeq <- function(l) {
- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l)
-}
-
# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a
# user defined function (UDF), and to examine variables in the UDF to decide
# if their values should be included in the new function environment.
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 6d331f9883..1ccfde5917 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -49,6 +49,14 @@ mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}",
jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp")
writeLines(mockLinesNa, jsonPathNa)
+# For test complex types in DataFrame
+mockLinesComplexType <-
+ c("{\"c1\":[1, 2, 3], \"c2\":[\"a\", \"b\", \"c\"], \"c3\":[1.0, 2.0, 3.0]}",
+ "{\"c1\":[4, 5, 6], \"c2\":[\"d\", \"e\", \"f\"], \"c3\":[4.0, 5.0, 6.0]}",
+ "{\"c1\":[7, 8, 9], \"c2\":[\"g\", \"h\", \"i\"], \"c3\":[7.0, 8.0, 9.0]}")
+complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp")
+writeLines(mockLinesComplexType, complexTypeJsonPath)
+
test_that("infer types", {
expect_equal(infer_type(1L), "integer")
expect_equal(infer_type(1.0), "double")
@@ -56,10 +64,8 @@ test_that("infer types", {
expect_equal(infer_type(TRUE), "boolean")
expect_equal(infer_type(as.Date("2015-03-11")), "date")
expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp")
- expect_equal(infer_type(c(1L, 2L)),
- list(type = "array", elementType = "integer", containsNull = TRUE))
- expect_equal(infer_type(list(1L, 2L)),
- list(type = "array", elementType = "integer", containsNull = TRUE))
+ expect_equal(infer_type(c(1L, 2L)), "array<integer>")
+ expect_equal(infer_type(list(1L, 2L)), "array<integer>")
testStruct <- infer_type(list(a = 1L, b = "2"))
expect_equal(class(testStruct), "structType")
checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE)
@@ -236,8 +242,7 @@ test_that("create DataFrame with different data types", {
expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE))
})
-# TODO: enable this test after fix serialization for nested object
-#test_that("create DataFrame with nested array and struct", {
+test_that("create DataFrame with nested array and struct", {
# e <- new.env()
# assign("n", 3L, envir = e)
# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L))
@@ -247,7 +252,32 @@ test_that("create DataFrame with different data types", {
# expect_equal(count(df), 1)
# ldf <- collect(df)
# expect_equal(ldf[1,], l[[1]])
-#})
+
+
+ # ArrayType only for now
+ l <- list(as.list(1:10), list("a", "b"))
+ df <- createDataFrame(sqlContext, list(l), c("a", "b"))
+ expect_equal(dtypes(df), list(c("a", "array<int>"), c("b", "array<string>")))
+ expect_equal(count(df), 1)
+ ldf <- collect(df)
+ expect_equal(names(ldf), c("a", "b"))
+ expect_equal(ldf[1, 1][[1]], l[[1]])
+ expect_equal(ldf[1, 2][[1]], l[[2]])
+})
+
+test_that("Collect DataFrame with complex types", {
+ # only ArrayType now
+ # TODO: tests for StructType and MapType after they are supported
+ df <- jsonFile(sqlContext, complexTypeJsonPath)
+
+ ldf <- collect(df)
+ expect_equal(nrow(ldf), 3)
+ expect_equal(ncol(ldf), 3)
+ expect_equal(names(ldf), c("c1", "c2", "c3"))
+ expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9)))
+ expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i")))
+ expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0)))
+})
test_that("jsonFile() on a local file returns a DataFrame", {
df <- jsonFile(sqlContext, jsonPath)
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index bb82f3285f..2a792d8199 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -125,10 +125,11 @@ private[r] class RBackendHandler(server: RBackend)
val methods = cls.getMethods
val selectedMethods = methods.filter(m => m.getName == methodName)
if (selectedMethods.length > 0) {
- val methods = selectedMethods.filter { x =>
- matchMethod(numArgs, args, x.getParameterTypes)
- }
- if (methods.isEmpty) {
+ val index = findMatchedSignature(
+ selectedMethods.map(_.getParameterTypes),
+ args)
+
+ if (index.isEmpty) {
logWarning(s"cannot find matching method ${cls}.$methodName. "
+ s"Candidates are:")
selectedMethods.foreach { method =>
@@ -136,18 +137,29 @@ private[r] class RBackendHandler(server: RBackend)
}
throw new Exception(s"No matched method found for $cls.$methodName")
}
- val ret = methods.head.invoke(obj, args : _*)
+
+ val ret = selectedMethods(index.get).invoke(obj, args : _*)
// Write status bit
writeInt(dos, 0)
writeObject(dos, ret.asInstanceOf[AnyRef])
} else if (methodName == "<init>") {
// methodName should be "<init>" for constructor
- val ctor = cls.getConstructors.filter { x =>
- matchMethod(numArgs, args, x.getParameterTypes)
- }.head
+ val ctors = cls.getConstructors
+ val index = findMatchedSignature(
+ ctors.map(_.getParameterTypes),
+ args)
- val obj = ctor.newInstance(args : _*)
+ if (index.isEmpty) {
+ logWarning(s"cannot find matching constructor for ${cls}. "
+ + s"Candidates are:")
+ ctors.foreach { ctor =>
+ logWarning(s"$cls(${ctor.getParameterTypes.mkString(",")})")
+ }
+ throw new Exception(s"No matched constructor found for $cls")
+ }
+
+ val obj = ctors(index.get).newInstance(args : _*)
writeInt(dos, 0)
writeObject(dos, obj.asInstanceOf[AnyRef])
@@ -166,40 +178,79 @@ private[r] class RBackendHandler(server: RBackend)
// Read a number of arguments from the data input stream
def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = {
- (0 until numArgs).map { arg =>
+ (0 until numArgs).map { _ =>
readObject(dis)
}.toArray
}
- // Checks if the arguments passed in args matches the parameter types.
- // NOTE: Currently we do exact match. We may add type conversions later.
- def matchMethod(
- numArgs: Int,
- args: Array[java.lang.Object],
- parameterTypes: Array[Class[_]]): Boolean = {
- if (parameterTypes.length != numArgs) {
- return false
- }
+ // Find a matching method signature in an array of signatures of constructors
+ // or methods of the same name according to the passed arguments. Arguments
+ // may be converted in order to match a signature.
+ //
+ // Note that in Java reflection, constructors and normal methods are of different
+ // classes, and share no parent class that provides methods for reflection uses.
+ // There is no unified way to handle them in this function. So an array of signatures
+ // is passed in instead of an array of candidate constructors or methods.
+ //
+ // Returns an Option[Int] which is the index of the matched signature in the array.
+ def findMatchedSignature(
+ parameterTypesOfMethods: Array[Array[Class[_]]],
+ args: Array[Object]): Option[Int] = {
+ val numArgs = args.length
+
+ for (index <- 0 until parameterTypesOfMethods.length) {
+ val parameterTypes = parameterTypesOfMethods(index)
+
+ if (parameterTypes.length == numArgs) {
+ var argMatched = true
+ var i = 0
+ while (i < numArgs && argMatched) {
+ val parameterType = parameterTypes(i)
+
+ if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) {
+ // The case that the parameter type is a Scala Seq and the argument
+ // is a Java array is considered matching. The array will be converted
+ // to a Seq later if this method is matched.
+ } else {
+ var parameterWrapperType = parameterType
+
+ // Convert native parameters to Object types as args is Array[Object] here
+ if (parameterType.isPrimitive) {
+ parameterWrapperType = parameterType match {
+ case java.lang.Integer.TYPE => classOf[java.lang.Integer]
+ case java.lang.Long.TYPE => classOf[java.lang.Integer]
+ case java.lang.Double.TYPE => classOf[java.lang.Double]
+ case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
+ case _ => parameterType
+ }
+ }
+ if (!parameterWrapperType.isInstance(args(i))) {
+ argMatched = false
+ }
+ }
- for (i <- 0 to numArgs - 1) {
- val parameterType = parameterTypes(i)
- var parameterWrapperType = parameterType
-
- // Convert native parameters to Object types as args is Array[Object] here
- if (parameterType.isPrimitive) {
- parameterWrapperType = parameterType match {
- case java.lang.Integer.TYPE => classOf[java.lang.Integer]
- case java.lang.Long.TYPE => classOf[java.lang.Integer]
- case java.lang.Double.TYPE => classOf[java.lang.Double]
- case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
- case _ => parameterType
+ i = i + 1
+ }
+
+ if (argMatched) {
+ // For now, we return the first matching method.
+ // TODO: find best method in matching methods.
+
+ // Convert args if needed
+ val parameterTypes = parameterTypesOfMethods(index)
+
+ (0 until numArgs).map { i =>
+ if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) {
+ // Convert a Java array to scala Seq
+ args(i) = args(i).asInstanceOf[Array[_]].toSeq
+ }
+ }
+
+ return Some(index)
}
- }
- if (!parameterWrapperType.isInstance(args(i))) {
- return false
}
}
- true
+ None
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 190e193427..3c92bb7a1c 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream}
import java.sql.{Timestamp, Date, Time}
import scala.collection.JavaConverters._
+import scala.collection.mutable.WrappedArray
/**
* Utility functions to serialize, deserialize objects to / from R
@@ -213,89 +214,97 @@ private[spark] object SerDe {
}
}
- def writeObject(dos: DataOutputStream, value: Object): Unit = {
- if (value == null) {
+ def writeObject(dos: DataOutputStream, obj: Object): Unit = {
+ if (obj == null) {
writeType(dos, "void")
} else {
- value.getClass.getName match {
- case "java.lang.Character" =>
+ // Convert ArrayType collected from DataFrame to Java array
+ // Collected data of ArrayType from a DataFrame is observed to be of
+ // type "scala.collection.mutable.WrappedArray"
+ val value =
+ if (obj.isInstanceOf[WrappedArray[_]]) {
+ obj.asInstanceOf[WrappedArray[_]].toArray
+ } else {
+ obj
+ }
+
+ value match {
+ case v: java.lang.Character =>
writeType(dos, "character")
- writeString(dos, value.asInstanceOf[Character].toString)
- case "java.lang.String" =>
+ writeString(dos, v.toString)
+ case v: java.lang.String =>
writeType(dos, "character")
- writeString(dos, value.asInstanceOf[String])
- case "java.lang.Long" =>
+ writeString(dos, v)
+ case v: java.lang.Long =>
writeType(dos, "double")
- writeDouble(dos, value.asInstanceOf[Long].toDouble)
- case "java.lang.Float" =>
+ writeDouble(dos, v.toDouble)
+ case v: java.lang.Float =>
writeType(dos, "double")
- writeDouble(dos, value.asInstanceOf[Float].toDouble)
- case "java.math.BigDecimal" =>
+ writeDouble(dos, v.toDouble)
+ case v: java.math.BigDecimal =>
writeType(dos, "double")
- val javaDecimal = value.asInstanceOf[java.math.BigDecimal]
- writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble)
- case "java.lang.Double" =>
+ writeDouble(dos, scala.math.BigDecimal(v).toDouble)
+ case v: java.lang.Double =>
writeType(dos, "double")
- writeDouble(dos, value.asInstanceOf[Double])
- case "java.lang.Byte" =>
+ writeDouble(dos, v)
+ case v: java.lang.Byte =>
writeType(dos, "integer")
- writeInt(dos, value.asInstanceOf[Byte].toInt)
- case "java.lang.Short" =>
+ writeInt(dos, v.toInt)
+ case v: java.lang.Short =>
writeType(dos, "integer")
- writeInt(dos, value.asInstanceOf[Short].toInt)
- case "java.lang.Integer" =>
+ writeInt(dos, v.toInt)
+ case v: java.lang.Integer =>
writeType(dos, "integer")
- writeInt(dos, value.asInstanceOf[Int])
- case "java.lang.Boolean" =>
+ writeInt(dos, v)
+ case v: java.lang.Boolean =>
writeType(dos, "logical")
- writeBoolean(dos, value.asInstanceOf[Boolean])
- case "java.sql.Date" =>
+ writeBoolean(dos, v)
+ case v: java.sql.Date =>
writeType(dos, "date")
- writeDate(dos, value.asInstanceOf[Date])
- case "java.sql.Time" =>
+ writeDate(dos, v)
+ case v: java.sql.Time =>
writeType(dos, "time")
- writeTime(dos, value.asInstanceOf[Time])
- case "java.sql.Timestamp" =>
+ writeTime(dos, v)
+ case v: java.sql.Timestamp =>
writeType(dos, "time")
- writeTime(dos, value.asInstanceOf[Timestamp])
+ writeTime(dos, v)
// Handle arrays
// Array of primitive types
// Special handling for byte array
- case "[B" =>
+ case v: Array[Byte] =>
writeType(dos, "raw")
- writeBytes(dos, value.asInstanceOf[Array[Byte]])
+ writeBytes(dos, v)
- case "[C" =>
+ case v: Array[Char] =>
writeType(dos, "array")
- writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString))
- case "[S" =>
+ writeStringArr(dos, v.map(_.toString))
+ case v: Array[Short] =>
writeType(dos, "array")
- writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt))
- case "[I" =>
+ writeIntArr(dos, v.map(_.toInt))
+ case v: Array[Int] =>
writeType(dos, "array")
- writeIntArr(dos, value.asInstanceOf[Array[Int]])
- case "[J" =>
+ writeIntArr(dos, v)
+ case v: Array[Long] =>
writeType(dos, "array")
- writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble))
- case "[F" =>
+ writeDoubleArr(dos, v.map(_.toDouble))
+ case v: Array[Float] =>
writeType(dos, "array")
- writeDoubleArr(dos, value.asInstanceOf[Array[Float]].map(_.toDouble))
- case "[D" =>
+ writeDoubleArr(dos, v.map(_.toDouble))
+ case v: Array[Double] =>
writeType(dos, "array")
- writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
- case "[Z" =>
+ writeDoubleArr(dos, v)
+ case v: Array[Boolean] =>
writeType(dos, "array")
- writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
+ writeBooleanArr(dos, v)
// Array of objects, null objects use "void" type
- case c if c.startsWith("[") =>
+ case v: Array[Object] =>
writeType(dos, "list")
- val array = value.asInstanceOf[Array[Object]]
- writeInt(dos, array.length)
- array.foreach(elem => writeObject(dos, elem))
+ writeInt(dos, v.length)
+ v.foreach(elem => writeObject(dos, elem))
case _ =>
writeType(dos, "jobj")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 7f3defec3d..d4b834adb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpres
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode}
+import scala.util.matching.Regex
+
private[r] object SQLUtils {
def createSQLContext(jsc: JavaSparkContext): SQLContext = {
new SQLContext(jsc)
@@ -35,14 +37,15 @@ private[r] object SQLUtils {
new JavaSparkContext(sqlCtx.sparkContext)
}
- def toSeq[T](arr: Array[T]): Seq[T] = {
- arr.toSeq
- }
-
def createStructType(fields : Seq[StructField]): StructType = {
StructType(fields)
}
+ // Support using regex in string interpolation
+ private[this] implicit class RegexContext(sc: StringContext) {
+ def r: Regex = new Regex(sc.parts.mkString, sc.parts.tail.map(_ => "x"): _*)
+ }
+
def getSQLDataType(dataType: String): DataType = {
dataType match {
case "byte" => org.apache.spark.sql.types.ByteType
@@ -58,6 +61,9 @@ private[r] object SQLUtils {
case "boolean" => org.apache.spark.sql.types.BooleanType
case "timestamp" => org.apache.spark.sql.types.TimestampType
case "date" => org.apache.spark.sql.types.DateType
+ case r"\Aarray<(.*)${elemType}>\Z" => {
+ org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType))
+ }
case _ => throw new IllegalArgumentException(s"Invaid type $dataType")
}
}