author | Sun Rui <rui.sun@intel.com> | 2015-08-25 13:14:10 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@cs.berkeley.edu> | 2015-08-25 13:14:10 -0700 |
commit | 71a138cd0e0a14e8426f97877e3b52a562bbd02c (patch) | |
tree | e0d2f675ec969b7a5c24c46414999d16c8fc759e /core | |
parent | 16a2be1a84c0a274a60c0a584faaf58b55d4942b (diff) | |
download | spark-71a138cd0e0a14e8426f97877e3b52a562bbd02c.tar.gz spark-71a138cd0e0a14e8426f97877e3b52a562bbd02c.tar.bz2 spark-71a138cd0e0a14e8426f97877e3b52a562bbd02c.zip |
[SPARK-10048] [SPARKR] Support arbitrary nested Java array in serde.
This PR:
1. supports transferring arbitrarily nested arrays from the JVM to the R side in SerDe;
2. based on 1, the collect() implementation is improved; it can now collect data of complex types from a DataFrame (a minimal sketch of the resulting write-side recursion is given below).
Author: Sun Rui <rui.sun@intel.com>
Closes #8276 from sun-rui/SPARK-10048.
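For orientation, here is a minimal, self-contained sketch (not the patch itself) of the recursive write-side dispatch the change introduces: primitive arrays are framed once as an "array", while any other array is framed as a "list" whose elements are written recursively, so arbitrary nesting depth needs no extra cases. The helper name writeNested and the raw tag bytes are illustrative stand-ins for SerDe.writeObject/writeType, and only Array[Int] and Integer scalars are handled.

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

object NestedWriteSketch {
  // Illustrative stand-in for SerDe.writeObject: primitive arrays get a flat
  // "array" frame ('a'), any other array gets a "list" frame ('l') whose
  // elements are written by recursion.
  def writeNested(dos: DataOutputStream, value: Object): Unit = {
    value.getClass.getName match {
      case "[I" =>                         // Array[Int]: 'a' + 'i' + length + values
        val arr = value.asInstanceOf[Array[Int]]
        dos.writeByte('a'); dos.writeByte('i')
        dos.writeInt(arr.length)
        arr.foreach(v => dos.writeInt(v))
      case c if c.startsWith("[") =>       // any other array: 'l' + length + recurse per element
        val arr = value.asInstanceOf[Array[Object]]
        dos.writeByte('l')
        dos.writeInt(arr.length)
        arr.foreach(elem => writeNested(dos, elem))
      case _ =>                            // scalars (only Integer shown): 'i' + value
        dos.writeByte('i')
        dos.writeInt(value.asInstanceOf[java.lang.Integer].intValue())
    }
  }

  def main(args: Array[String]): Unit = {
    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)
    // Array[Array[Int]] has class name "[[I": the outer array takes the "list"
    // branch and each inner Array[Int] takes the "array" branch.
    writeNested(dos, Array(Array(1, 2, 3), Array(4, 5)))
    dos.flush()
    println(s"wrote ${bos.size()} bytes")
  }
}
```

The patch itself keeps the flat primitive-array frames ("array") distinct from the recursive object-array frames ("list"), presumably so the R side can map primitive arrays directly to atomic vectors instead of reading element by element.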
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala | 7
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/r/SerDe.scala | 86
2 files changed, 58 insertions, 35 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 6ce02e2ea3..bb82f3285f 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -53,6 +53,13 @@ private[r] class RBackendHandler(server: RBackend)
     if (objId == "SparkRHandler") {
       methodName match {
+        // This function is for test-purpose only
+        case "echo" =>
+          val args = readArgs(numArgs, dis)
+          assert(numArgs == 1)
+
+          writeInt(dos, 0)
+          writeObject(dos, args(0))
         case "stopBackend" =>
           writeInt(dos, 0)
           writeType(dos, "void")
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index dbbbcf40c1..26ad4f1d46 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -149,6 +149,10 @@ private[spark] object SerDe {
       case 'b' => readBooleanArr(dis)
       case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x))
       case 'r' => readBytesArr(dis)
+      case 'l' => {
+        val len = readInt(dis)
+        (0 until len).map(_ => readList(dis)).toArray
+      }
       case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
     }
   }
@@ -200,6 +204,9 @@ private[spark] object SerDe {
       case "date" => dos.writeByte('D')
       case "time" => dos.writeByte('t')
       case "raw" => dos.writeByte('r')
+      // Array of primitive types
+      case "array" => dos.writeByte('a')
+      // Array of objects
       case "list" => dos.writeByte('l')
       case "jobj" => dos.writeByte('j')
       case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
@@ -211,26 +218,35 @@ private[spark] object SerDe {
       writeType(dos, "void")
     } else {
       value.getClass.getName match {
+        case "java.lang.Character" =>
+          writeType(dos, "character")
+          writeString(dos, value.asInstanceOf[Character].toString)
         case "java.lang.String" =>
           writeType(dos, "character")
           writeString(dos, value.asInstanceOf[String])
-        case "long" | "java.lang.Long" =>
+        case "java.lang.Long" =>
           writeType(dos, "double")
           writeDouble(dos, value.asInstanceOf[Long].toDouble)
-        case "float" | "java.lang.Float" =>
+        case "java.lang.Float" =>
           writeType(dos, "double")
           writeDouble(dos, value.asInstanceOf[Float].toDouble)
-        case "decimal" | "java.math.BigDecimal" =>
+        case "java.math.BigDecimal" =>
           writeType(dos, "double")
           val javaDecimal = value.asInstanceOf[java.math.BigDecimal]
           writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble)
-        case "double" | "java.lang.Double" =>
+        case "java.lang.Double" =>
           writeType(dos, "double")
           writeDouble(dos, value.asInstanceOf[Double])
-        case "int" | "java.lang.Integer" =>
+        case "java.lang.Byte" =>
+          writeType(dos, "integer")
+          writeInt(dos, value.asInstanceOf[Byte].toInt)
+        case "java.lang.Short" =>
+          writeType(dos, "integer")
+          writeInt(dos, value.asInstanceOf[Short].toInt)
+        case "java.lang.Integer" =>
           writeType(dos, "integer")
           writeInt(dos, value.asInstanceOf[Int])
-        case "boolean" | "java.lang.Boolean" =>
+        case "java.lang.Boolean" =>
           writeType(dos, "logical")
           writeBoolean(dos, value.asInstanceOf[Boolean])
         case "java.sql.Date" =>
@@ -242,43 +258,48 @@ private[spark] object SerDe {
         case "java.sql.Timestamp" =>
           writeType(dos, "time")
           writeTime(dos, value.asInstanceOf[Timestamp])
+
+        // Handle arrays
+
+        // Array of primitive types
+
+        // Special handling for byte array
         case "[B" =>
           writeType(dos, "raw")
           writeBytes(dos, value.asInstanceOf[Array[Byte]])
-        // TODO: Types not handled right now include
-        // byte, char, short, float
-        // Handle arrays
-        case "[Ljava.lang.String;" =>
-          writeType(dos, "list")
-          writeStringArr(dos, value.asInstanceOf[Array[String]])
+        case "[C" =>
+          writeType(dos, "array")
+          writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString))
+        case "[S" =>
+          writeType(dos, "array")
+          writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt))
         case "[I" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
           writeIntArr(dos, value.asInstanceOf[Array[Int]])
         case "[J" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
           writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble))
+        case "[F" =>
+          writeType(dos, "array")
+          writeDoubleArr(dos, value.asInstanceOf[Array[Float]].map(_.toDouble))
         case "[D" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
          writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
         case "[Z" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
           writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
-        case "[[B" =>
+
+        // Array of objects, null objects use "void" type
+        case c if c.startsWith("[") =>
           writeType(dos, "list")
-          writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]])
-        case otherName =>
-          // Handle array of objects
-          if (otherName.startsWith("[L")) {
-            val objArr = value.asInstanceOf[Array[Object]]
-            writeType(dos, "list")
-            writeType(dos, "jobj")
-            dos.writeInt(objArr.length)
-            objArr.foreach(o => writeJObj(dos, o))
-          } else {
-            writeType(dos, "jobj")
-            writeJObj(dos, value)
-          }
+          val array = value.asInstanceOf[Array[Object]]
+          writeInt(dos, array.length)
+          array.foreach(elem => writeObject(dos, elem))
+
+        case _ =>
+          writeType(dos, "jobj")
+          writeJObj(dos, value)
       }
     }
   }
@@ -350,11 +371,6 @@ private[spark] object SerDe {
     value.foreach(v => writeString(out, v))
   }
 
-  def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = {
-    writeType(out, "raw")
-    out.writeInt(value.length)
-    value.foreach(v => writeBytes(out, v))
-  }
 }
 
 private[r] object SerializationFormats {
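As a rough usage check of the new object-array branch (hypothetical, not part of the patch): since SerDe is private[spark], the snippet below assumes it is compiled under org.apache.spark.api.r, for example in a test, and that writeObject keeps the signature used in the diff above.

```scala
// Hypothetical snippet, not part of the patch; placed in the same package so
// the private[spark] SerDe object is visible.
package org.apache.spark.api.r

import java.io.{ByteArrayOutputStream, DataOutputStream}

object NestedArrayFramingCheck {
  def main(args: Array[String]): Unit = {
    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)

    // "[[I" (Array[Array[Int]]) matches none of the primitive-array cases, so
    // the outer array is framed as "list" ('l') with its length, and each inner
    // Array[Int] recurses into writeObject and is framed as "array"/"integer".
    SerDe.writeObject(dos, Array(Array(1, 2, 3), Array(4, 5)))
    dos.flush()

    println(s"frame is ${bos.size()} bytes; first type byte = '${bos.toByteArray()(0).toChar}'")
  }
}
```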