author    Sun Rui <rui.sun@intel.com>  2015-08-25 13:14:10 -0700
committer Shivaram Venkataraman <shivaram@cs.berkeley.edu>  2015-08-25 13:14:10 -0700
commit    71a138cd0e0a14e8426f97877e3b52a562bbd02c (patch)
tree      e0d2f675ec969b7a5c24c46414999d16c8fc759e /core
parent    16a2be1a84c0a274a60c0a584faaf58b55d4942b (diff)
[SPARK-10048] [SPARKR] Support arbitrary nested Java array in serde.
This PR:
1. supports transferring arbitrary nested arrays from the JVM to the R side in SerDe;
2. based on 1, improves the collect() implementation so that it can collect data of complex types from a DataFrame.

Author: Sun Rui <rui.sun@intel.com>

Closes #8276 from sun-rui/SPARK-10048.
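The round trip this enables can be pictured with a toy codec (a minimal sketch, not Spark's actual SerDe or wire format): scalar values get a single type tag, while an object array is written as a length-prefixed list whose elements are encoded recursively, so arbitrarily nested arrays survive serialization in both directions.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object NestedSerDeSketch {
  // Encode: ints get tag 'i'; any array gets tag 'l' plus a length prefix,
  // and each element is encoded recursively (this is what allows nesting).
  def encode(dos: DataOutputStream, v: Any): Unit = v match {
    case i: Int      => dos.writeByte('i'); dos.writeInt(i)
    case a: Array[_] => dos.writeByte('l'); dos.writeInt(a.length); a.foreach(encode(dos, _))
  }

  // Decode mirrors encode: a 'l' tag means "read a length, then recurse".
  def decode(dis: DataInputStream): Any = dis.readByte().toChar match {
    case 'i' => dis.readInt()
    case 'l' => Array.fill(dis.readInt())(decode(dis))
  }

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    encode(new DataOutputStream(buf), Array(Array(1, 2), Array(3)))
    val out = decode(new DataInputStream(new ByteArrayInputStream(buf.toByteArray)))
      .asInstanceOf[Array[Any]]
    // Prints "1,2 | 3": the nested structure survived the round trip.
    println(out.map(_.asInstanceOf[Array[Any]].mkString(",")).mkString(" | "))
  }
}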
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala |  7
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/SerDe.scala           | 86
2 files changed, 58 insertions(+), 35 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 6ce02e2ea3..bb82f3285f 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -53,6 +53,13 @@ private[r] class RBackendHandler(server: RBackend)
     if (objId == "SparkRHandler") {
       methodName match {
+        // This function is for test purposes only.
+        case "echo" =>
+          val args = readArgs(numArgs, dis)
+          assert(numArgs == 1)
+
+          writeInt(dos, 0)
+          writeObject(dos, args(0))
         case "stopBackend" =>
           writeInt(dos, 0)
           writeType(dos, "void")
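The echo path frames its reply the same way as the other handler calls: a status integer first (0 for success), then the serialized value, so a client must consume the status before decoding the payload. A hedged sketch of that client-side read (hypothetical helper, not SparkR's actual client code):

import java.io.DataInputStream

object EchoClientSketch {
  // Read a backend reply: consume the status int before decoding the payload.
  def readReply(dis: DataInputStream, readValue: DataInputStream => Any): Any = {
    val status = dis.readInt()
    require(status == 0, s"backend returned error status $status")
    readValue(dis)
  }
}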
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index dbbbcf40c1..26ad4f1d46 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -149,6 +149,10 @@ private[spark] object SerDe {
       case 'b' => readBooleanArr(dis)
       case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x))
       case 'r' => readBytesArr(dis)
+      case 'l' => {
+        val len = readInt(dis)
+        (0 until len).map(_ => readList(dis)).toArray
+      }
       case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
     }
   }
@@ -200,6 +204,9 @@ private[spark] object SerDe {
       case "date" => dos.writeByte('D')
       case "time" => dos.writeByte('t')
       case "raw" => dos.writeByte('r')
+      // Array of primitive types
+      case "array" => dos.writeByte('a')
+      // Array of objects
       case "list" => dos.writeByte('l')
       case "jobj" => dos.writeByte('j')
       case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
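The two new tags encode different layouts: "array" ('a') carries one element-type tag for a whole homogeneous, flat vector, while "list" ('l') tags every element individually, which is what lets elements differ in type or nest. A schematic of the two layouts with hypothetical helpers (byte-level details simplified, not the exact wire format):

import java.io.DataOutputStream

object TagLayoutSketch {
  // "array" layout: 'a', element-type tag, length, e1, e2, ... (one tag for all)
  def writeIntArray(dos: DataOutputStream, a: Array[Int]): Unit = {
    dos.writeByte('a'); dos.writeByte('i'); dos.writeInt(a.length)
    a.foreach(dos.writeInt)
  }

  // "list" layout: 'l', length, (tag, value), (tag, value), ... (per-element tags)
  def writeMixedList(dos: DataOutputStream, xs: Seq[Any]): Unit = {
    dos.writeByte('l'); dos.writeInt(xs.length)
    xs.foreach {
      case i: Int    => dos.writeByte('i'); dos.writeInt(i)
      case d: Double => dos.writeByte('d'); dos.writeDouble(d)
      case other     => sys.error(s"unhandled type in sketch: $other")
    }
  }
}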
@@ -211,26 +218,35 @@ private[spark] object SerDe {
       writeType(dos, "void")
     } else {
       value.getClass.getName match {
+        case "java.lang.Character" =>
+          writeType(dos, "character")
+          writeString(dos, value.asInstanceOf[Character].toString)
         case "java.lang.String" =>
           writeType(dos, "character")
           writeString(dos, value.asInstanceOf[String])
-        case "long" | "java.lang.Long" =>
+        case "java.lang.Long" =>
           writeType(dos, "double")
           writeDouble(dos, value.asInstanceOf[Long].toDouble)
-        case "float" | "java.lang.Float" =>
+        case "java.lang.Float" =>
           writeType(dos, "double")
           writeDouble(dos, value.asInstanceOf[Float].toDouble)
-        case "decimal" | "java.math.BigDecimal" =>
+        case "java.math.BigDecimal" =>
           writeType(dos, "double")
           val javaDecimal = value.asInstanceOf[java.math.BigDecimal]
           writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble)
-        case "double" | "java.lang.Double" =>
+        case "java.lang.Double" =>
           writeType(dos, "double")
           writeDouble(dos, value.asInstanceOf[Double])
-        case "int" | "java.lang.Integer" =>
+        case "java.lang.Byte" =>
+          writeType(dos, "integer")
+          writeInt(dos, value.asInstanceOf[Byte].toInt)
+        case "java.lang.Short" =>
+          writeType(dos, "integer")
+          writeInt(dos, value.asInstanceOf[Short].toInt)
+        case "java.lang.Integer" =>
           writeType(dos, "integer")
           writeInt(dos, value.asInstanceOf[Int])
-        case "boolean" | "java.lang.Boolean" =>
+        case "java.lang.Boolean" =>
           writeType(dos, "logical")
           writeBoolean(dos, value.asInstanceOf[Boolean])
         case "java.sql.Date" =>
@@ -242,43 +258,48 @@ private[spark] object SerDe {
         case "java.sql.Timestamp" =>
           writeType(dos, "time")
           writeTime(dos, value.asInstanceOf[Timestamp])
+
+        // Handle arrays
+
+        // Array of primitive types
+
+        // Special handling for byte array
         case "[B" =>
           writeType(dos, "raw")
           writeBytes(dos, value.asInstanceOf[Array[Byte]])
-        // TODO: Types not handled right now include
-        // byte, char, short, float
-        // Handle arrays
-        case "[Ljava.lang.String;" =>
-          writeType(dos, "list")
-          writeStringArr(dos, value.asInstanceOf[Array[String]])
+        case "[C" =>
+          writeType(dos, "array")
+          writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString))
+        case "[S" =>
+          writeType(dos, "array")
+          writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt))
         case "[I" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
           writeIntArr(dos, value.asInstanceOf[Array[Int]])
         case "[J" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
           writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble))
+        case "[F" =>
+          writeType(dos, "array")
+          writeDoubleArr(dos, value.asInstanceOf[Array[Float]].map(_.toDouble))
         case "[D" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
           writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
         case "[Z" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
           writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
-        case "[[B" =>
+
+        // Array of objects, null objects use "void" type
+        case c if c.startsWith("[") =>
           writeType(dos, "list")
-          writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]])
-        case otherName =>
-          // Handle array of objects
-          if (otherName.startsWith("[L")) {
-            val objArr = value.asInstanceOf[Array[Object]]
-            writeType(dos, "list")
-            writeType(dos, "jobj")
-            dos.writeInt(objArr.length)
-            objArr.foreach(o => writeJObj(dos, o))
-          } else {
-            writeType(dos, "jobj")
-            writeJObj(dos, value)
-          }
+          val array = value.asInstanceOf[Array[Object]]
+          writeInt(dos, array.length)
+          array.foreach(elem => writeObject(dos, elem))
+
+        case _ =>
+          writeType(dos, "jobj")
+          writeJObj(dos, value)
       }
     }
   }
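The rewritten dispatch keys on JVM runtime class names: primitive arrays have the one-letter descriptors matched explicitly above, and any other name starting with "[" falls through to the generic per-element branch. That branch also subsumes the removed special cases: "[Ljava.lang.String;" and "[[B" both start with "[", and each inner Array[Byte] of a "[[B" still matches "[B" on recursion and is written as "raw", which is why the dedicated writeBytesArr helper is deleted in the next hunk. A quick check of the names involved:

object ArrayClassNames {
  def main(args: Array[String]): Unit = {
    println(Array(1, 2, 3).getClass.getName)               // [I   -> "array" of ints
    println(Array(1.0, 2.0).getClass.getName)              // [D   -> "array" of doubles
    println(Array("a", "b").getClass.getName)              // [Ljava.lang.String; -> generic "list"
    println(Array(Array(1), Array(2, 3)).getClass.getName) // [[I  -> "list", recurses per element
    println(Array(Array[Byte](1)).getClass.getName)        // [[B  -> "list"; inner [B written as "raw"
  }
}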
@@ -350,11 +371,6 @@ private[spark] object SerDe {
     value.foreach(v => writeString(out, v))
   }
 
-  def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = {
-    writeType(out, "raw")
-    out.writeInt(value.length)
-    value.foreach(v => writeBytes(out, v))
-  }
 }
 
 private[r] object SerializationFormats {