-rw-r--r--  R/pkg/R/SQLContext.R                                                22
-rw-r--r--  R/pkg/R/deserialize.R                                               10
-rw-r--r--  R/pkg/R/schema.R                                                    28
-rw-r--r--  R/pkg/R/serialize.R                                                 43
-rw-r--r--  R/pkg/R/sparkR.R                                                     4
-rw-r--r--  R/pkg/R/utils.R                                                     17
-rw-r--r--  R/pkg/inst/tests/test_sparkSQL.R                                    51
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/SerDe.scala              71
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala   47
9 files changed, 224 insertions(+), 69 deletions(-)
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 1c58fd96d7..66c7e30721 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -32,6 +32,7 @@ infer_type <- function(x) {
numeric = "double",
raw = "binary",
list = "array",
+ struct = "struct",
environment = "map",
Date = "date",
POSIXlt = "timestamp",
@@ -44,17 +45,18 @@ infer_type <- function(x) {
paste0("map<string,", infer_type(get(key, x)), ">")
} else if (type == "array") {
stopifnot(length(x) > 0)
+
+ paste0("array<", infer_type(x[[1]]), ">")
+ } else if (type == "struct") {
+ stopifnot(length(x) > 0)
names <- names(x)
- if (is.null(names)) {
- paste0("array<", infer_type(x[[1]]), ">")
- } else {
- # StructType
- types <- lapply(x, infer_type)
- fields <- lapply(1:length(x), function(i) {
- structField(names[[i]], types[[i]], TRUE)
- })
- do.call(structType, fields)
- }
+ stopifnot(!is.null(names))
+
+ type <- lapply(seq_along(x), function(i) {
+ paste0(names[[i]], ":", infer_type(x[[i]]), ",")
+ })
+ type <- Reduce(paste0, type)
+ type <- paste0("struct<", substr(type, 1, nchar(type) - 1), ">")
} else if (length(x) > 1) {
paste0("array<", infer_type(x[[1]]), ">")
} else {
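
A note on the hunk above: infer_type no longer builds a structType object from a named list; it emits a flat type string instead, and named lists must be tagged explicitly with listToStruct() (added in R/pkg/R/utils.R below). A minimal R sketch of the new behavior, matching the updated tests:

    infer_type(list(1L, 2L))                          # "array<integer>"
    infer_type(listToStruct(list(a = 1L, b = "2")))   # "struct<a:integer,b:string>"
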
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index ce88d0b071..f7e56e4301 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -51,6 +51,7 @@ readTypedObject <- function(con, type) {
"a" = readArray(con),
"l" = readList(con),
"e" = readEnv(con),
+ "s" = readStruct(con),
"n" = NULL,
"j" = getJobj(readString(con)),
stop(paste("Unsupported type for deserialization", type)))
@@ -135,6 +136,15 @@ readEnv <- function(con) {
env
}
+# Read a field of StructType from DataFrame
+# into a named list in R whose class is "struct"
+readStruct <- function(con) {
+ names <- readObject(con)
+ fields <- readObject(con)
+ names(fields) <- names
+ listToStruct(fields)
+}
+
readRaw <- function(con) {
dataLen <- readInt(con)
readBin(con, raw(), as.integer(dataLen), endian = "big")
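
For orientation, readStruct() reads two objects from the connection (the field names, then the field values) and returns a named list tagged with class "struct". A rough equivalent of its result, using illustrative values:

    fields <- list(16, 176.5)
    names(fields) <- c("age", "height")
    s <- listToStruct(fields)
    class(s)    # "struct"
    s$age       # 16
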
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index 8df1563f8e..6f0e9a94e9 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -136,7 +136,7 @@ checkType <- function(type) {
switch (firstChar,
a = {
# Array type
- m <- regexec("^array<(.*)>$", type)
+ m <- regexec("^array<(.+)>$", type)
matchedStrings <- regmatches(type, m)
if (length(matchedStrings[[1]]) >= 2) {
elemType <- matchedStrings[[1]][2]
@@ -146,7 +146,7 @@ checkType <- function(type) {
},
m = {
# Map type
- m <- regexec("^map<(.*),(.*)>$", type)
+ m <- regexec("^map<(.+),(.+)>$", type)
matchedStrings <- regmatches(type, m)
if (length(matchedStrings[[1]]) >= 3) {
keyType <- matchedStrings[[1]][2]
@@ -157,6 +157,30 @@ checkType <- function(type) {
checkType(valueType)
return()
}
+ },
+ s = {
+ # Struct type
+ m <- regexec("^struct<(.+)>$", type)
+ matchedStrings <- regmatches(type, m)
+ if (length(matchedStrings[[1]]) >= 2) {
+ fieldsString <- matchedStrings[[1]][2]
+ # strsplit does not return the final empty string, so check if
+ # the final char is ","
+ if (substr(fieldsString, nchar(fieldsString), nchar(fieldsString)) != ",") {
+ fields <- strsplit(fieldsString, ",")[[1]]
+ for (field in fields) {
+ m <- regexec("^(.+):(.+)$", field)
+ matchedStrings <- regmatches(field, m)
+ if (length(matchedStrings[[1]]) >= 3) {
+ fieldType <- matchedStrings[[1]][3]
+ checkType(fieldType)
+ } else {
+ break
+ }
+ }
+ return()
+ }
+ }
})
}
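
A quick sketch of what the new struct branch of checkType() accepts and rejects, assuming (as elsewhere in the function, outside this hunk) that type strings which never reach a return() fall through to a stop():

    checkType("struct<a:string,b:integer>")   # accepted: every field is name:type
    checkType("struct<a:string,>")            # rejected: the trailing comma skips the return()
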
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 91e6b3e560..17082b4e52 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -32,6 +32,21 @@
# environment -> Map[String, T], where T is a native type
# jobj -> Object, where jobj is an object created in the backend
+getSerdeType <- function(object) {
+ type <- class(object)[[1]]
+ if (type != "list") {
+ type
+ } else {
+ # Check if all elements are of same type
+ elemType <- unique(sapply(object, function(elem) { getSerdeType(elem) }))
+ if (length(elemType) <= 1) {
+ "array"
+ } else {
+ "list"
+ }
+ }
+}
+
writeObject <- function(con, object, writeType = TRUE) {
# NOTE: In R vectors have same type as objects. So we don't support
# passing in vectors as arrays and instead require arrays to be passed
@@ -45,10 +60,12 @@ writeObject <- function(con, object, writeType = TRUE) {
type <- "NULL"
}
}
+
+ serdeType <- getSerdeType(object)
if (writeType) {
- writeType(con, type)
+ writeType(con, serdeType)
}
- switch(type,
+ switch(serdeType,
NULL = writeVoid(con),
integer = writeInt(con, object),
character = writeString(con, object),
@@ -56,7 +73,9 @@ writeObject <- function(con, object, writeType = TRUE) {
double = writeDouble(con, object),
numeric = writeDouble(con, object),
raw = writeRaw(con, object),
+ array = writeArray(con, object),
list = writeList(con, object),
+ struct = writeList(con, object),
jobj = writeJobj(con, object),
environment = writeEnv(con, object),
Date = writeDate(con, object),
@@ -110,7 +129,7 @@ writeRowSerialize <- function(outputCon, rows) {
serializeRow <- function(row) {
rawObj <- rawConnection(raw(0), "wb")
on.exit(close(rawObj))
- writeGenericList(rawObj, row)
+ writeList(rawObj, row)
rawConnectionValue(rawObj)
}
@@ -128,7 +147,9 @@ writeType <- function(con, class) {
double = "d",
numeric = "d",
raw = "r",
+ array = "a",
list = "l",
+ struct = "s",
jobj = "j",
environment = "e",
Date = "D",
@@ -139,15 +160,13 @@ writeType <- function(con, class) {
}
# Used to pass arrays where all the elements are of the same type
-writeList <- function(con, arr) {
- # All elements should be of same type
- elemType <- unique(sapply(arr, function(elem) { class(elem) }))
- stopifnot(length(elemType) <= 1)
-
+writeArray <- function(con, arr) {
# TODO: Empty lists are given type "character" right now.
# This may not work if the Java side expects array of any other type.
- if (length(elemType) == 0) {
+ if (length(arr) == 0) {
elemType <- class("somestring")
+ } else {
+ elemType <- getSerdeType(arr[[1]])
}
writeType(con, elemType)
@@ -161,7 +180,7 @@ writeList <- function(con, arr) {
}
# Used to pass arrays where the elements can be of different types
-writeGenericList <- function(con, list) {
+writeList <- function(con, list) {
writeInt(con, length(list))
for (elem in list) {
writeObject(con, elem)
@@ -174,9 +193,9 @@ writeEnv <- function(con, env) {
writeInt(con, len)
if (len > 0) {
- writeList(con, as.list(ls(env)))
+ writeArray(con, as.list(ls(env)))
vals <- lapply(ls(env), function(x) { env[[x]] })
- writeGenericList(con, as.list(vals))
+ writeList(con, as.list(vals))
}
}
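
The writeArray / writeList split is driven entirely by getSerdeType(), which keeps the homogeneous-array fast path while routing mixed lists and structs through the generic writer. A short sketch of how values are classified, assuming the helpers above are loaded:

    getSerdeType(1L)                                    # "integer"
    getSerdeType(list(1L, 2L, 3L))                      # "array"  (single element type)
    getSerdeType(list(1L, "a"))                         # "list"   (mixed element types)
    getSerdeType(listToStruct(list(a = 1L, b = "b")))   # "struct" (class set by listToStruct)
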
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 3c57a44db2..cc47110f54 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -178,7 +178,7 @@ sparkR.init <- function(
}
nonEmptyJars <- Filter(function(x) { x != "" }, jars)
- localJarPaths <- sapply(nonEmptyJars,
+ localJarPaths <- lapply(nonEmptyJars,
function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) })
# Set the start time to identify jobjs
@@ -193,7 +193,7 @@ sparkR.init <- function(
master,
appName,
as.character(sparkHome),
- as.list(localJarPaths),
+ localJarPaths,
sparkEnvirMap,
sparkExecutorEnvMap),
envir = .sparkREnv
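
A side note on the sapply-to-lapply switch: sapply simplifies its result to a (named) character vector, while lapply always returns a list, which is why the explicit as.list() call below is no longer needed. A minimal illustration, independent of Spark:

    jars <- c("a.jar", "b.jar")
    sapply(jars, function(j) paste0("file:", j))   # named character vector
    lapply(jars, function(j) paste0("file:", j))   # plain list of strings
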
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 69a2bc728f..94f16c7ac5 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -588,3 +588,20 @@ mergePartitions <- function(rdd, zip) {
PipelinedRDD(rdd, partitionFunc)
}
+
+# Convert a named list to a struct so that
+# SerDe does not confuse a normal named list with a struct
+listToStruct <- function(list) {
+ stopifnot(class(list) == "list")
+ stopifnot(!is.null(names(list)))
+ class(list) <- "struct"
+ list
+}
+
+# Convert a struct to a named list
+structToList <- function(struct) {
+  stopifnot(class(struct) == "struct")
+
+ class(struct) <- "list"
+ struct
+}
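
A usage sketch for the two helpers, assuming the corrected class check in structToList():

    s <- listToStruct(list(name = "Bob", age = 16L))
    class(s)             # "struct"
    s$name               # "Bob"
    l <- structToList(s)
    class(l)             # "list"
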
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 3a04edbb4c..af6efa40fb 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -66,10 +66,7 @@ test_that("infer types and check types", {
expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp")
expect_equal(infer_type(c(1L, 2L)), "array<integer>")
expect_equal(infer_type(list(1L, 2L)), "array<integer>")
- testStruct <- infer_type(list(a = 1L, b = "2"))
- expect_equal(class(testStruct), "structType")
- checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE)
- checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE)
+ expect_equal(infer_type(listToStruct(list(a = 1L, b = "2"))), "struct<a:integer,b:string>")
e <- new.env()
assign("a", 1L, envir = e)
expect_equal(infer_type(e), "map<string,integer>")
@@ -242,38 +239,36 @@ test_that("create DataFrame with different data types", {
expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE))
})
-test_that("create DataFrame with nested array and map", {
-# e <- new.env()
-# assign("n", 3L, envir = e)
-# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L))
-# df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d"))
-# expect_equal(dtypes(df), list(c("a", "array<int>"), c("b", "array<string>"),
-# c("c", "map<string,int>"), c("d", "struct<a:string,b:int>")))
-# expect_equal(count(df), 1)
-# ldf <- collect(df)
-# expect_equal(ldf[1,], l[[1]])
-
- # ArrayType and MapType
+test_that("create DataFrame with complex types", {
e <- new.env()
assign("n", 3L, envir = e)
- l <- list(as.list(1:10), list("a", "b"), e)
- df <- createDataFrame(sqlContext, list(l), c("a", "b", "c"))
+ s <- listToStruct(list(a = "aa", b = 3L))
+
+ l <- list(as.list(1:10), list("a", "b"), e, s)
+ df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d"))
expect_equal(dtypes(df), list(c("a", "array<int>"),
c("b", "array<string>"),
- c("c", "map<string,int>")))
+ c("c", "map<string,int>"),
+ c("d", "struct<a:string,b:int>")))
expect_equal(count(df), 1)
ldf <- collect(df)
- expect_equal(names(ldf), c("a", "b", "c"))
+ expect_equal(names(ldf), c("a", "b", "c", "d"))
expect_equal(ldf[1, 1][[1]], l[[1]])
expect_equal(ldf[1, 2][[1]], l[[2]])
+
e <- ldf$c[[1]]
expect_equal(class(e), "environment")
expect_equal(ls(e), "n")
expect_equal(e$n, 3L)
+
+ s <- ldf$d[[1]]
+ expect_equal(class(s), "struct")
+ expect_equal(s$a, "aa")
+ expect_equal(s$b, 3L)
})
-# For test map type in DataFrame
+# For testing map type and struct type in DataFrame
mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}",
"{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}",
"{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}")
@@ -308,7 +303,19 @@ test_that("Collect DataFrame with complex types", {
expect_equal(bob$age, 16)
expect_equal(bob$height, 176.5)
- # TODO: tests for StructType after it is supported
+ # StructType
+ df <- jsonFile(sqlContext, mapTypeJsonPath)
+ expect_equal(dtypes(df), list(c("info", "struct<age:bigint,height:double>"),
+ c("name", "string")))
+ ldf <- collect(df)
+ expect_equal(nrow(ldf), 3)
+ expect_equal(ncol(ldf), 2)
+ expect_equal(names(ldf), c("info", "name"))
+ expect_equal(ldf$name, c("Bob", "Alice", "David"))
+ bob <- ldf$info[[1]]
+ expect_equal(class(bob), "struct")
+ expect_equal(bob$age, 16)
+ expect_equal(bob$height, 176.5)
})
test_that("jsonFile() on a local file returns a DataFrame", {
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 0c78613e40..da126bac7a 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -27,6 +27,14 @@ import scala.collection.mutable.WrappedArray
* Utility functions to serialize, deserialize objects to / from R
*/
private[spark] object SerDe {
+ type ReadObject = (DataInputStream, Char) => Object
+ type WriteObject = (DataOutputStream, Object) => Boolean
+
+ var sqlSerDe: (ReadObject, WriteObject) = _
+
+ def registerSqlSerDe(sqlSerDe: (ReadObject, WriteObject)): Unit = {
+ this.sqlSerDe = sqlSerDe
+ }
// Type mapping from R to Java
//
@@ -63,11 +71,22 @@ private[spark] object SerDe {
case 'c' => readString(dis)
case 'e' => readMap(dis)
case 'r' => readBytes(dis)
+ case 'a' => readArray(dis)
case 'l' => readList(dis)
case 'D' => readDate(dis)
case 't' => readTime(dis)
case 'j' => JVMObjectTracker.getObject(readString(dis))
- case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
+ case _ =>
+ if (sqlSerDe == null || sqlSerDe._1 == null) {
+ throw new IllegalArgumentException (s"Invalid type $dataType")
+ } else {
+ val obj = (sqlSerDe._1)(dis, dataType)
+ if (obj == null) {
+ throw new IllegalArgumentException (s"Invalid type $dataType")
+ } else {
+ obj
+ }
+ }
}
}
@@ -141,7 +160,8 @@ private[spark] object SerDe {
(0 until len).map(_ => readString(in)).toArray
}
- def readList(dis: DataInputStream): Array[_] = {
+ // All elements of an array must be of the same type
+ def readArray(dis: DataInputStream): Array[_] = {
val arrType = readObjectType(dis)
arrType match {
case 'i' => readIntArr(dis)
@@ -150,26 +170,43 @@ private[spark] object SerDe {
case 'b' => readBooleanArr(dis)
case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x))
case 'r' => readBytesArr(dis)
- case 'l' => {
+ case 'a' =>
+ val len = readInt(dis)
+ (0 until len).map(_ => readArray(dis)).toArray
+ case 'l' =>
val len = readInt(dis)
(0 until len).map(_ => readList(dis)).toArray
- }
- case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
+ case _ =>
+ if (sqlSerDe == null || sqlSerDe._1 == null) {
+ throw new IllegalArgumentException (s"Invalid array type $arrType")
+ } else {
+ val len = readInt(dis)
+ (0 until len).map { _ =>
+ val obj = (sqlSerDe._1)(dis, arrType)
+ if (obj == null) {
+ throw new IllegalArgumentException (s"Invalid array type $arrType")
+ } else {
+ obj
+ }
+ }.toArray
+ }
}
}
+ // Each element of a list can be of different type. They are all represented
+ // as Object on JVM side
+ def readList(dis: DataInputStream): Array[Object] = {
+ val len = readInt(dis)
+ (0 until len).map(_ => readObject(dis)).toArray
+ }
+
def readMap(in: DataInputStream): java.util.Map[Object, Object] = {
val len = readInt(in)
if (len > 0) {
- val keysType = readObjectType(in)
- val keysLen = readInt(in)
- val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType))
-
- val valuesLen = readInt(in)
- val values = (0 until valuesLen).map(_ => {
- val valueType = readObjectType(in)
- readTypedObject(in, valueType)
- })
+ // Keys is an array of String
+ val keys = readArray(in).asInstanceOf[Array[Object]]
+ val values = readList(in)
+
keys.zip(values).toMap.asJava
} else {
new java.util.HashMap[Object, Object]()
@@ -338,8 +375,10 @@ private[spark] object SerDe {
}
case _ =>
- writeType(dos, "jobj")
- writeJObj(dos, value)
+ if (sqlSerDe == null || sqlSerDe._2 == null || !(sqlSerDe._2)(dos, value)) {
+ writeType(dos, "jobj")
+ writeJObj(dos, value)
+ }
}
}
}
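
The readMap rewrite mirrors writeEnv on the R side: keys cross the wire once as a string array, values as a generic list. A small R-side sketch of the stream this reader now consumes, assuming the serialize.R helpers above:

    e <- new.env()
    assign("n", 3L, envir = e)
    # writeEnv(con, e) emits the length (1), then writeArray(list("n")) for the keys,
    # then writeList(list(3L)) for the values; readMap() pairs them into a Java map.
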
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index f45d119c8c..b0120a8d0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -22,13 +22,15 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.r.SerDe
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression, GenericRowWithSchema}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode}
import scala.util.matching.Regex
private[r] object SQLUtils {
+ SerDe.registerSqlSerDe((readSqlObject, writeSqlObject))
+
def createSQLContext(jsc: JavaSparkContext): SQLContext = {
new SQLContext(jsc)
}
@@ -61,15 +63,27 @@ private[r] object SQLUtils {
case "boolean" => org.apache.spark.sql.types.BooleanType
case "timestamp" => org.apache.spark.sql.types.TimestampType
case "date" => org.apache.spark.sql.types.DateType
- case r"\Aarray<(.*)${elemType}>\Z" => {
+ case r"\Aarray<(.+)${elemType}>\Z" =>
org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType))
- }
- case r"\Amap<(.*)${keyType},(.*)${valueType}>\Z" => {
+ case r"\Amap<(.+)${keyType},(.+)${valueType}>\Z" =>
if (keyType != "string" && keyType != "character") {
throw new IllegalArgumentException("Key type of a map must be string or character")
}
org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType))
- }
+ case r"\Astruct<(.+)${fieldsStr}>\Z" =>
+ if (fieldsStr(fieldsStr.length - 1) == ',') {
+          throw new IllegalArgumentException(s"Invalid type $dataType")
+ }
+ val fields = fieldsStr.split(",")
+ val structFields = fields.map { field =>
+ field match {
+ case r"\A(.+)${fieldName}:(.+)${fieldType}\Z" =>
+ createStructField(fieldName, fieldType, true)
+
+          case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
+ }
+ }
+ createStructType(structFields)
case _ => throw new IllegalArgumentException(s"Invaid type $dataType")
}
}
@@ -151,4 +165,27 @@ private[r] object SQLUtils {
options: java.util.Map[String, String]): DataFrame = {
sqlContext.read.format(source).schema(schema).options(options).load()
}
+
+ def readSqlObject(dis: DataInputStream, dataType: Char): Object = {
+ dataType match {
+ case 's' =>
+ // Read StructType for DataFrame
+ val fields = SerDe.readList(dis).asInstanceOf[Array[Object]]
+ Row.fromSeq(fields)
+ case _ => null
+ }
+ }
+
+ def writeSqlObject(dos: DataOutputStream, obj: Object): Boolean = {
+ obj match {
+ // Handle struct type in DataFrame
+ case v: GenericRowWithSchema =>
+ dos.writeByte('s')
+ SerDe.writeObject(dos, v.schema.fieldNames)
+ SerDe.writeObject(dos, v.values)
+ true
+ case _ =>
+ false
+ }
+ }
}
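
Finally, the struct<...> branch added to getSQLDataType is what lets an R user declare struct fields by type string. A hedged sketch from the R API, assuming a running SparkR backend (structField() validates the string locally via checkType() and then calls createStructField on the JVM):

    inner  <- structField("d", "struct<a:string,b:integer>", TRUE)
    schema <- structType(structField("name", "string", TRUE), inner)
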