 R/pkg/R/SQLContext.R                                               | 22
 R/pkg/R/deserialize.R                                              | 10
 R/pkg/R/schema.R                                                   | 28
 R/pkg/R/serialize.R                                                | 43
 R/pkg/R/sparkR.R                                                   |  4
 R/pkg/R/utils.R                                                    | 17
 R/pkg/inst/tests/test_sparkSQL.R                                   | 51
 core/src/main/scala/org/apache/spark/api/r/SerDe.scala             | 71
 sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala  | 47
 9 files changed, 224 insertions(+), 69 deletions(-)
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 1c58fd96d7..66c7e30721 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -32,6 +32,7 @@ infer_type <- function(x) {
          numeric = "double",
          raw = "binary",
          list = "array",
+         struct = "struct",
          environment = "map",
          Date = "date",
          POSIXlt = "timestamp",
@@ -44,17 +45,18 @@ infer_type <- function(x) {
     paste0("map<string,", infer_type(get(key, x)), ">")
   } else if (type == "array") {
     stopifnot(length(x) > 0)
+
+    paste0("array<", infer_type(x[[1]]), ">")
+  } else if (type == "struct") {
+    stopifnot(length(x) > 0)
     names <- names(x)
-    if (is.null(names)) {
-      paste0("array<", infer_type(x[[1]]), ">")
-    } else {
-      # StructType
-      types <- lapply(x, infer_type)
-      fields <- lapply(1:length(x), function(i) {
-        structField(names[[i]], types[[i]], TRUE)
-      })
-      do.call(structType, fields)
-    }
+    stopifnot(!is.null(names))
+
+    type <- lapply(seq_along(x), function(i) {
+      paste0(names[[i]], ":", infer_type(x[[i]]), ",")
+    })
+    type <- Reduce(paste0, type)
+    type <- paste0("struct<", substr(type, 1, nchar(type) - 1), ">")
   } else if (length(x) > 1) {
     paste0("array<", infer_type(x[[1]]), ">")
   } else {
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index ce88d0b071..f7e56e4301 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -51,6 +51,7 @@ readTypedObject <- function(con, type) {
     "a" = readArray(con),
     "l" = readList(con),
     "e" = readEnv(con),
+    "s" = readStruct(con),
     "n" = NULL,
     "j" = getJobj(readString(con)),
     stop(paste("Unsupported type for deserialization", type)))
@@ -135,6 +136,15 @@ readEnv <- function(con) {
   env
 }
 
+# Read a field of StructType from DataFrame
+# into a named list in R whose class is "struct"
+readStruct <- function(con) {
+  names <- readObject(con)
+  fields <- readObject(con)
+  names(fields) <- names
+  listToStruct(fields)
+}
+
 readRaw <- function(con) {
   dataLen <- readInt(con)
   readBin(con, raw(), as.integer(dataLen), endian = "big")
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index 8df1563f8e..6f0e9a94e9 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -136,7 +136,7 @@ checkType <- function(type) {
     switch (firstChar,
             a = {
               # Array type
-              m <- regexec("^array<(.*)>$", type)
+              m <- regexec("^array<(.+)>$", type)
               matchedStrings <- regmatches(type, m)
               if (length(matchedStrings[[1]]) >= 2) {
                 elemType <- matchedStrings[[1]][2]
@@ -146,7 +146,7 @@ checkType <- function(type) {
             },
             m = {
               # Map type
-              m <- regexec("^map<(.*),(.*)>$", type)
+              m <- regexec("^map<(.+),(.+)>$", type)
               matchedStrings <- regmatches(type, m)
               if (length(matchedStrings[[1]]) >= 3) {
                 keyType <- matchedStrings[[1]][2]
@@ -157,6 +157,30 @@ checkType <- function(type) {
                 checkType(valueType)
                 return()
               }
+            },
+            s = {
+              # Struct type
+              m <- regexec("^struct<(.+)>$", type)
+              matchedStrings <- regmatches(type, m)
+              if (length(matchedStrings[[1]]) >= 2) {
+                fieldsString <- matchedStrings[[1]][2]
+                # strsplit does not return the final empty string, so check if
+                # the final char is ","
+                if (substr(fieldsString, nchar(fieldsString), nchar(fieldsString)) != ",") {
+                  fields <- strsplit(fieldsString, ",")[[1]]
+                  for (field in fields) {
+                    m <- regexec("^(.+):(.+)$", field)
+                    matchedStrings <- regmatches(field, m)
+                    if (length(matchedStrings[[1]]) >= 3) {
+                      fieldType <- matchedStrings[[1]][3]
+                      checkType(fieldType)
+                    } else {
+                      break
+                    }
+                  }
+                  return()
+                }
+              }
             })
 }
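
A quick illustration of the R-side behavior above (hypothetical session, not part of the patch; listToStruct is the helper added in utils.R further down): unnamed lists still infer as arrays, while lists tagged as structs infer as struct type strings that the new checkType() branch accepts.

  infer_type(list(1L, 2L))                         # "array<integer>"
  infer_type(listToStruct(list(a = 1L, b = "2")))  # "struct<a:integer,b:string>"
  structField("d", "struct<a:integer,b:string>")   # validated by the new s = { ... } branch
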
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 91e6b3e560..17082b4e52 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -32,6 +32,21 @@
 # environment -> Map[String, T], where T is a native type
 # jobj -> Object, where jobj is an object created in the backend
 
+getSerdeType <- function(object) {
+  type <- class(object)[[1]]
+  if (type != "list") {
+    type
+  } else {
+    # Check if all elements are of same type
+    elemType <- unique(sapply(object, function(elem) { getSerdeType(elem) }))
+    if (length(elemType) <= 1) {
+      "array"
+    } else {
+      "list"
+    }
+  }
+}
+
 writeObject <- function(con, object, writeType = TRUE) {
   # NOTE: In R vectors have same type as objects. So we don't support
   # passing in vectors as arrays and instead require arrays to be passed
@@ -45,10 +60,12 @@ writeObject <- function(con, object, writeType = TRUE) {
       type <- "NULL"
     }
   }
+
+  serdeType <- getSerdeType(object)
   if (writeType) {
-    writeType(con, type)
+    writeType(con, serdeType)
   }
-  switch(type,
+  switch(serdeType,
          NULL = writeVoid(con),
          integer = writeInt(con, object),
          character = writeString(con, object),
@@ -56,7 +73,9 @@ writeObject <- function(con, object, writeType = TRUE) {
          double = writeDouble(con, object),
          numeric = writeDouble(con, object),
          raw = writeRaw(con, object),
+         array = writeArray(con, object),
          list = writeList(con, object),
+         struct = writeList(con, object),
          jobj = writeJobj(con, object),
          environment = writeEnv(con, object),
          Date = writeDate(con, object),
@@ -110,7 +129,7 @@ writeRowSerialize <- function(outputCon, rows) {
 serializeRow <- function(row) {
   rawObj <- rawConnection(raw(0), "wb")
   on.exit(close(rawObj))
-  writeGenericList(rawObj, row)
+  writeList(rawObj, row)
   rawConnectionValue(rawObj)
 }
 
@@ -128,7 +147,9 @@ writeType <- function(con, class) {
          double = "d",
          numeric = "d",
          raw = "r",
+         array = "a",
          list = "l",
+         struct = "s",
          jobj = "j",
          environment = "e",
          Date = "D",
@@ -139,15 +160,13 @@ writeType <- function(con, class) {
 }
 
 # Used to pass arrays where all the elements are of the same type
-writeList <- function(con, arr) {
-  # All elements should be of same type
-  elemType <- unique(sapply(arr, function(elem) { class(elem) }))
-  stopifnot(length(elemType) <= 1)
-
+writeArray <- function(con, arr) {
   # TODO: Empty lists are given type "character" right now.
   # This may not work if the Java side expects array of any other type.
-  if (length(elemType) == 0) {
+  if (length(arr) == 0) {
     elemType <- class("somestring")
+  } else {
+    elemType <- getSerdeType(arr[[1]])
   }
 
   writeType(con, elemType)
@@ -161,7 +180,7 @@ writeArray <- function(con, arr) {
 }
 
 # Used to pass arrays where the elements can be of different types
-writeGenericList <- function(con, list) {
+writeList <- function(con, list) {
   writeInt(con, length(list))
   for (elem in list) {
     writeObject(con, elem)
@@ -174,9 +193,9 @@ writeEnv <- function(con, env) {
   writeInt(con, len)
 
   if (len > 0) {
-    writeList(con, as.list(ls(env)))
+    writeArray(con, as.list(ls(env)))
     vals <- lapply(ls(env), function(x) { env[[x]] })
-    writeGenericList(con, as.list(vals))
+    writeList(con, as.list(vals))
   }
 }
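
A hedged sketch of the new serializer dispatch (values hypothetical; getSerdeType, writeArray, and writeList are defined in the hunks above): homogeneous lists now go over the wire as arrays, heterogeneous ones as generic lists, and structs reuse the generic list writer under their own "s" tag.

  getSerdeType(list(1L, 2L, 3L))            # "array"  -- one element type
  getSerdeType(list(1L, "a"))               # "list"   -- mixed element types
  getSerdeType(listToStruct(list(a = 1L)))  # "struct" -- class set by listToStruct
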
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 3c57a44db2..cc47110f54 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -178,7 +178,7 @@ sparkR.init <- function(
   }
 
   nonEmptyJars <- Filter(function(x) { x != "" }, jars)
-  localJarPaths <- sapply(nonEmptyJars,
+  localJarPaths <- lapply(nonEmptyJars,
                           function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) })
 
   # Set the start time to identify jobjs
@@ -193,7 +193,7 @@ sparkR.init <- function(
       master,
       appName,
       as.character(sparkHome),
-      as.list(localJarPaths),
+      localJarPaths,
       sparkEnvirMap,
       sparkExecutorEnvMap),
     envir = .sparkREnv
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 69a2bc728f..94f16c7ac5 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -588,3 +588,20 @@ mergePartitions <- function(rdd, zip) {
 
   PipelinedRDD(rdd, partitionFunc)
 }
+
+# Convert a named list to a struct so that
+# SerDe won't confuse a normal named list with a struct
+listToStruct <- function(list) {
+  stopifnot(class(list) == "list")
+  stopifnot(!is.null(names(list)))
+  class(list) <- "struct"
+  list
+}
+
+# Convert a struct to a named list
+structToList <- function(struct) {
+  stopifnot(class(struct) == "struct")
+
+  class(struct) <- "list"
+  struct
+}
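
A minimal usage sketch of the two new helpers (hypothetical session):

  s <- listToStruct(list(a = 1L, b = "2"))
  class(s)              # "struct"
  l <- structToList(s)
  class(l)              # "list"
  names(l)              # "a" "b"
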
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 3a04edbb4c..af6efa40fb 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -66,10 +66,7 @@ test_that("infer types and check types", {
   expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp")
   expect_equal(infer_type(c(1L, 2L)), "array<integer>")
   expect_equal(infer_type(list(1L, 2L)), "array<integer>")
-  testStruct <- infer_type(list(a = 1L, b = "2"))
-  expect_equal(class(testStruct), "structType")
-  checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE)
-  checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE)
+  expect_equal(infer_type(listToStruct(list(a = 1L, b = "2"))), "struct<a:integer,b:string>")
   e <- new.env()
   assign("a", 1L, envir = e)
   expect_equal(infer_type(e), "map<string,integer>")
@@ -242,38 +239,36 @@ test_that("create DataFrame with different data types", {
   expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE))
 })
 
-test_that("create DataFrame with nested array and map", {
-# e <- new.env()
-# assign("n", 3L, envir = e)
-# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L))
-# df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d"))
-# expect_equal(dtypes(df), list(c("a", "array<int>"), c("b", "array<string>"),
-#                               c("c", "map<string,int>"), c("d", "struct<a:string,b:int>")))
-# expect_equal(count(df), 1)
-# ldf <- collect(df)
-# expect_equal(ldf[1,], l[[1]])
-
-  # ArrayType and MapType
+test_that("create DataFrame with complex types", {
   e <- new.env()
   assign("n", 3L, envir = e)
 
-  l <- list(as.list(1:10), list("a", "b"), e)
-  df <- createDataFrame(sqlContext, list(l), c("a", "b", "c"))
+  s <- listToStruct(list(a = "aa", b = 3L))
+
+  l <- list(as.list(1:10), list("a", "b"), e, s)
+  df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d"))
   expect_equal(dtypes(df), list(c("a", "array<int>"), c("b", "array<string>"),
-                                c("c", "map<string,int>")))
+                                c("c", "map<string,int>"),
+                                c("d", "struct<a:string,b:int>")))
   expect_equal(count(df), 1)
   ldf <- collect(df)
-  expect_equal(names(ldf), c("a", "b", "c"))
+  expect_equal(names(ldf), c("a", "b", "c", "d"))
   expect_equal(ldf[1, 1][[1]], l[[1]])
   expect_equal(ldf[1, 2][[1]], l[[2]])
+
   e <- ldf$c[[1]]
   expect_equal(class(e), "environment")
   expect_equal(ls(e), "n")
   expect_equal(e$n, 3L)
+
+  s <- ldf$d[[1]]
+  expect_equal(class(s), "struct")
+  expect_equal(s$a, "aa")
+  expect_equal(s$b, 3L)
 })
 
-# For test map type in DataFrame
+# For test map type and struct type in DataFrame
 mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}",
                       "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}",
                       "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}")
@@ -308,7 +303,19 @@ test_that("Collect DataFrame with complex types", {
   expect_equal(bob$age, 16)
   expect_equal(bob$height, 176.5)
 
-  # TODO: tests for StructType after it is supported
+  # StructType
+  df <- jsonFile(sqlContext, mapTypeJsonPath)
+  expect_equal(dtypes(df), list(c("info", "struct<age:bigint,height:double>"),
+                                c("name", "string")))
+  ldf <- collect(df)
+  expect_equal(nrow(ldf), 3)
+  expect_equal(ncol(ldf), 2)
+  expect_equal(names(ldf), c("info", "name"))
+  expect_equal(ldf$name, c("Bob", "Alice", "David"))
+  bob <- ldf$info[[1]]
+  expect_equal(class(bob), "struct")
+  expect_equal(bob$age, 16)
+  expect_equal(bob$height, 176.5)
 })
 
 test_that("jsonFile() on a local file returns a DataFrame", {
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 0c78613e40..da126bac7a 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -27,6 +27,14 @@ import scala.collection.mutable.WrappedArray
  * Utility functions to serialize, deserialize objects to / from R
  */
 private[spark] object SerDe {
+  type ReadObject = (DataInputStream, Char) => Object
+  type WriteObject = (DataOutputStream, Object) => Boolean
+
+  var sqlSerDe: (ReadObject, WriteObject) = _
+
+  def registerSqlSerDe(sqlSerDe: (ReadObject, WriteObject)): Unit = {
+    this.sqlSerDe = sqlSerDe
+  }
 
   // Type mapping from R to Java
   //
@@ -63,11 +71,22 @@ private[spark] object SerDe {
       case 'c' => readString(dis)
       case 'e' => readMap(dis)
       case 'r' => readBytes(dis)
+      case 'a' => readArray(dis)
       case 'l' => readList(dis)
       case 'D' => readDate(dis)
       case 't' => readTime(dis)
       case 'j' => JVMObjectTracker.getObject(readString(dis))
-      case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
+      case _ =>
+        if (sqlSerDe == null || sqlSerDe._1 == null) {
+          throw new IllegalArgumentException(s"Invalid type $dataType")
+        } else {
+          val obj = (sqlSerDe._1)(dis, dataType)
+          if (obj == null) {
+            throw new IllegalArgumentException(s"Invalid type $dataType")
+          } else {
+            obj
+          }
+        }
     }
   }
 
@@ -141,7 +160,8 @@ private[spark] object SerDe {
     (0 until len).map(_ => readString(in)).toArray
   }
 
-  def readList(dis: DataInputStream): Array[_] = {
+  // All elements of an array must be of the same type
+  def readArray(dis: DataInputStream): Array[_] = {
     val arrType = readObjectType(dis)
     arrType match {
       case 'i' => readIntArr(dis)
@@ -150,26 +170,43 @@ private[spark] object SerDe {
       case 'b' => readBooleanArr(dis)
      case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x))
       case 'r' => readBytesArr(dis)
+      case 'a' =>
+        val len = readInt(dis)
+        (0 until len).map(_ => readArray(dis)).toArray
+      case 'l' =>
         val len = readInt(dis)
         (0 until len).map(_ => readList(dis)).toArray
-      case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
+      case _ =>
+        if (sqlSerDe == null || sqlSerDe._1 == null) {
+          throw new IllegalArgumentException(s"Invalid array type $arrType")
+        } else {
+          val len = readInt(dis)
+          (0 until len).map { _ =>
+            val obj = (sqlSerDe._1)(dis, arrType)
+            if (obj == null) {
+              throw new IllegalArgumentException(s"Invalid array type $arrType")
+            } else {
+              obj
+            }
+          }.toArray
+        }
     }
   }
 
+  // Each element of a list can be of different type. They are all represented
+  // as Object on JVM side
+  def readList(dis: DataInputStream): Array[Object] = {
+    val len = readInt(dis)
+    (0 until len).map(_ => readObject(dis)).toArray
+  }
+
   def readMap(in: DataInputStream): java.util.Map[Object, Object] = {
     val len = readInt(in)
     if (len > 0) {
-      val keysType = readObjectType(in)
-      val keysLen = readInt(in)
-      val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType))
-
-      val valuesLen = readInt(in)
-      val values = (0 until valuesLen).map(_ => {
-        val valueType = readObjectType(in)
-        readTypedObject(in, valueType)
-      })
+      // Keys is an array of String
+      val keys = readArray(in).asInstanceOf[Array[Object]]
+      val values = readList(in)
 
       keys.zip(values).toMap.asJava
     } else {
       new java.util.HashMap[Object, Object]()
@@ -338,8 +375,10 @@ private[spark] object SerDe {
         }
 
       case _ =>
-        writeType(dos, "jobj")
-        writeJObj(dos, value)
+        if (sqlSerDe == null || sqlSerDe._2 == null || !(sqlSerDe._2)(dos, value)) {
+          writeType(dos, "jobj")
+          writeJObj(dos, value)
+        }
     }
   }
 }
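
To make the reader split concrete, a rough R-side sketch of the tags now on the wire (writeObject is the SparkR internal from serialize.R above; the connection handling is only illustrative):

  con <- rawConnection(raw(0), "wb")
  writeObject(con, list(1L, 2L))                # tag 'a': decoded by readArray, single element type
  writeObject(con, list(1L, "a"))               # tag 'l': decoded by readList, per-element tags
  writeObject(con, listToStruct(list(a = 1L)))  # tag 's': routed to the registered sqlSerDe hook
  close(con)
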
keyType != "character") { throw new IllegalArgumentException("Key type of a map must be string or character") } org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType)) - } + case r"\Astruct<(.+)${fieldsStr}>\Z" => + if (fieldsStr(fieldsStr.length - 1) == ',') { + throw new IllegalArgumentException(s"Invaid type $dataType") + } + val fields = fieldsStr.split(",") + val structFields = fields.map { field => + field match { + case r"\A(.+)${fieldName}:(.+)${fieldType}\Z" => + createStructField(fieldName, fieldType, true) + + case _ => throw new IllegalArgumentException(s"Invaid type $dataType") + } + } + createStructType(structFields) case _ => throw new IllegalArgumentException(s"Invaid type $dataType") } } @@ -151,4 +165,27 @@ private[r] object SQLUtils { options: java.util.Map[String, String]): DataFrame = { sqlContext.read.format(source).schema(schema).options(options).load() } + + def readSqlObject(dis: DataInputStream, dataType: Char): Object = { + dataType match { + case 's' => + // Read StructType for DataFrame + val fields = SerDe.readList(dis).asInstanceOf[Array[Object]] + Row.fromSeq(fields) + case _ => null + } + } + + def writeSqlObject(dos: DataOutputStream, obj: Object): Boolean = { + obj match { + // Handle struct type in DataFrame + case v: GenericRowWithSchema => + dos.writeByte('s') + SerDe.writeObject(dos, v.schema.fieldNames) + SerDe.writeObject(dos, v.values) + true + case _ => + false + } + } } |