aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorSun Rui <rui.sun@intel.com>2015-09-10 12:21:13 -0700
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2015-09-10 12:21:13 -0700
commit45e3be5c138d983f40f619735d60bf7eb78c9bf0 (patch)
tree30b7b90f53eadee901a56e0e2e84222e21cf6c44 /core
parentd88abb7e212fb55f9b0398a0f76a753c86b85cf1 (diff)
downloadspark-45e3be5c138d983f40f619735d60bf7eb78c9bf0.tar.gz
spark-45e3be5c138d983f40f619735d60bf7eb78c9bf0.tar.bz2
spark-45e3be5c138d983f40f619735d60bf7eb78c9bf0.zip
[SPARK-10049] [SPARKR] Support collecting data of ArraryType in DataFrame.
this PR : 1. Enhance reflection in RBackend. Automatically matching a Java array to Scala Seq when finding methods. Util functions like seq(), listToSeq() in R side can be removed, as they will conflict with the Serde logic that transferrs a Scala seq to R side. 2. Enhance the SerDe to support transferring a Scala seq to R side. Data of ArrayType in DataFrame after collection is observed to be of Scala Seq type. 3. Support ArrayType in createDataFrame(). Author: Sun Rui <rui.sun@intel.com> Closes #8458 from sun-rui/SPARK-10049.
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala121
-rw-r--r--core/src/main/scala/org/apache/spark/api/r/SerDe.scala109
2 files changed, 145 insertions, 85 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index bb82f3285f..2a792d8199 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -125,10 +125,11 @@ private[r] class RBackendHandler(server: RBackend)
val methods = cls.getMethods
val selectedMethods = methods.filter(m => m.getName == methodName)
if (selectedMethods.length > 0) {
- val methods = selectedMethods.filter { x =>
- matchMethod(numArgs, args, x.getParameterTypes)
- }
- if (methods.isEmpty) {
+ val index = findMatchedSignature(
+ selectedMethods.map(_.getParameterTypes),
+ args)
+
+ if (index.isEmpty) {
logWarning(s"cannot find matching method ${cls}.$methodName. "
+ s"Candidates are:")
selectedMethods.foreach { method =>
@@ -136,18 +137,29 @@ private[r] class RBackendHandler(server: RBackend)
}
throw new Exception(s"No matched method found for $cls.$methodName")
}
- val ret = methods.head.invoke(obj, args : _*)
+
+ val ret = selectedMethods(index.get).invoke(obj, args : _*)
// Write status bit
writeInt(dos, 0)
writeObject(dos, ret.asInstanceOf[AnyRef])
} else if (methodName == "<init>") {
// methodName should be "<init>" for constructor
- val ctor = cls.getConstructors.filter { x =>
- matchMethod(numArgs, args, x.getParameterTypes)
- }.head
+ val ctors = cls.getConstructors
+ val index = findMatchedSignature(
+ ctors.map(_.getParameterTypes),
+ args)
- val obj = ctor.newInstance(args : _*)
+ if (index.isEmpty) {
+ logWarning(s"cannot find matching constructor for ${cls}. "
+ + s"Candidates are:")
+ ctors.foreach { ctor =>
+ logWarning(s"$cls(${ctor.getParameterTypes.mkString(",")})")
+ }
+ throw new Exception(s"No matched constructor found for $cls")
+ }
+
+ val obj = ctors(index.get).newInstance(args : _*)
writeInt(dos, 0)
writeObject(dos, obj.asInstanceOf[AnyRef])
@@ -166,40 +178,79 @@ private[r] class RBackendHandler(server: RBackend)
// Read a number of arguments from the data input stream
def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = {
- (0 until numArgs).map { arg =>
+ (0 until numArgs).map { _ =>
readObject(dis)
}.toArray
}
- // Checks if the arguments passed in args matches the parameter types.
- // NOTE: Currently we do exact match. We may add type conversions later.
- def matchMethod(
- numArgs: Int,
- args: Array[java.lang.Object],
- parameterTypes: Array[Class[_]]): Boolean = {
- if (parameterTypes.length != numArgs) {
- return false
- }
+ // Find a matching method signature in an array of signatures of constructors
+ // or methods of the same name according to the passed arguments. Arguments
+ // may be converted in order to match a signature.
+ //
+ // Note that in Java reflection, constructors and normal methods are of different
+ // classes, and share no parent class that provides methods for reflection uses.
+ // There is no unified way to handle them in this function. So an array of signatures
+ // is passed in instead of an array of candidate constructors or methods.
+ //
+ // Returns an Option[Int] which is the index of the matched signature in the array.
+ def findMatchedSignature(
+ parameterTypesOfMethods: Array[Array[Class[_]]],
+ args: Array[Object]): Option[Int] = {
+ val numArgs = args.length
+
+ for (index <- 0 until parameterTypesOfMethods.length) {
+ val parameterTypes = parameterTypesOfMethods(index)
+
+ if (parameterTypes.length == numArgs) {
+ var argMatched = true
+ var i = 0
+ while (i < numArgs && argMatched) {
+ val parameterType = parameterTypes(i)
+
+ if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) {
+ // The case that the parameter type is a Scala Seq and the argument
+ // is a Java array is considered matching. The array will be converted
+ // to a Seq later if this method is matched.
+ } else {
+ var parameterWrapperType = parameterType
+
+ // Convert native parameters to Object types as args is Array[Object] here
+ if (parameterType.isPrimitive) {
+ parameterWrapperType = parameterType match {
+ case java.lang.Integer.TYPE => classOf[java.lang.Integer]
+ case java.lang.Long.TYPE => classOf[java.lang.Integer]
+ case java.lang.Double.TYPE => classOf[java.lang.Double]
+ case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
+ case _ => parameterType
+ }
+ }
+ if (!parameterWrapperType.isInstance(args(i))) {
+ argMatched = false
+ }
+ }
- for (i <- 0 to numArgs - 1) {
- val parameterType = parameterTypes(i)
- var parameterWrapperType = parameterType
-
- // Convert native parameters to Object types as args is Array[Object] here
- if (parameterType.isPrimitive) {
- parameterWrapperType = parameterType match {
- case java.lang.Integer.TYPE => classOf[java.lang.Integer]
- case java.lang.Long.TYPE => classOf[java.lang.Integer]
- case java.lang.Double.TYPE => classOf[java.lang.Double]
- case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
- case _ => parameterType
+ i = i + 1
+ }
+
+ if (argMatched) {
+ // For now, we return the first matching method.
+ // TODO: find best method in matching methods.
+
+ // Convert args if needed
+ val parameterTypes = parameterTypesOfMethods(index)
+
+ (0 until numArgs).map { i =>
+ if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) {
+ // Convert a Java array to scala Seq
+ args(i) = args(i).asInstanceOf[Array[_]].toSeq
+ }
+ }
+
+ return Some(index)
}
- }
- if (!parameterWrapperType.isInstance(args(i))) {
- return false
}
}
- true
+ None
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 190e193427..3c92bb7a1c 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream}
import java.sql.{Timestamp, Date, Time}
import scala.collection.JavaConverters._
+import scala.collection.mutable.WrappedArray
/**
* Utility functions to serialize, deserialize objects to / from R
@@ -213,89 +214,97 @@ private[spark] object SerDe {
}
}
- def writeObject(dos: DataOutputStream, value: Object): Unit = {
- if (value == null) {
+ def writeObject(dos: DataOutputStream, obj: Object): Unit = {
+ if (obj == null) {
writeType(dos, "void")
} else {
- value.getClass.getName match {
- case "java.lang.Character" =>
+ // Convert ArrayType collected from DataFrame to Java array
+ // Collected data of ArrayType from a DataFrame is observed to be of
+ // type "scala.collection.mutable.WrappedArray"
+ val value =
+ if (obj.isInstanceOf[WrappedArray[_]]) {
+ obj.asInstanceOf[WrappedArray[_]].toArray
+ } else {
+ obj
+ }
+
+ value match {
+ case v: java.lang.Character =>
writeType(dos, "character")
- writeString(dos, value.asInstanceOf[Character].toString)
- case "java.lang.String" =>
+ writeString(dos, v.toString)
+ case v: java.lang.String =>
writeType(dos, "character")
- writeString(dos, value.asInstanceOf[String])
- case "java.lang.Long" =>
+ writeString(dos, v)
+ case v: java.lang.Long =>
writeType(dos, "double")
- writeDouble(dos, value.asInstanceOf[Long].toDouble)
- case "java.lang.Float" =>
+ writeDouble(dos, v.toDouble)
+ case v: java.lang.Float =>
writeType(dos, "double")
- writeDouble(dos, value.asInstanceOf[Float].toDouble)
- case "java.math.BigDecimal" =>
+ writeDouble(dos, v.toDouble)
+ case v: java.math.BigDecimal =>
writeType(dos, "double")
- val javaDecimal = value.asInstanceOf[java.math.BigDecimal]
- writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble)
- case "java.lang.Double" =>
+ writeDouble(dos, scala.math.BigDecimal(v).toDouble)
+ case v: java.lang.Double =>
writeType(dos, "double")
- writeDouble(dos, value.asInstanceOf[Double])
- case "java.lang.Byte" =>
+ writeDouble(dos, v)
+ case v: java.lang.Byte =>
writeType(dos, "integer")
- writeInt(dos, value.asInstanceOf[Byte].toInt)
- case "java.lang.Short" =>
+ writeInt(dos, v.toInt)
+ case v: java.lang.Short =>
writeType(dos, "integer")
- writeInt(dos, value.asInstanceOf[Short].toInt)
- case "java.lang.Integer" =>
+ writeInt(dos, v.toInt)
+ case v: java.lang.Integer =>
writeType(dos, "integer")
- writeInt(dos, value.asInstanceOf[Int])
- case "java.lang.Boolean" =>
+ writeInt(dos, v)
+ case v: java.lang.Boolean =>
writeType(dos, "logical")
- writeBoolean(dos, value.asInstanceOf[Boolean])
- case "java.sql.Date" =>
+ writeBoolean(dos, v)
+ case v: java.sql.Date =>
writeType(dos, "date")
- writeDate(dos, value.asInstanceOf[Date])
- case "java.sql.Time" =>
+ writeDate(dos, v)
+ case v: java.sql.Time =>
writeType(dos, "time")
- writeTime(dos, value.asInstanceOf[Time])
- case "java.sql.Timestamp" =>
+ writeTime(dos, v)
+ case v: java.sql.Timestamp =>
writeType(dos, "time")
- writeTime(dos, value.asInstanceOf[Timestamp])
+ writeTime(dos, v)
// Handle arrays
// Array of primitive types
// Special handling for byte array
- case "[B" =>
+ case v: Array[Byte] =>
writeType(dos, "raw")
- writeBytes(dos, value.asInstanceOf[Array[Byte]])
+ writeBytes(dos, v)
- case "[C" =>
+ case v: Array[Char] =>
writeType(dos, "array")
- writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString))
- case "[S" =>
+ writeStringArr(dos, v.map(_.toString))
+ case v: Array[Short] =>
writeType(dos, "array")
- writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt))
- case "[I" =>
+ writeIntArr(dos, v.map(_.toInt))
+ case v: Array[Int] =>
writeType(dos, "array")
- writeIntArr(dos, value.asInstanceOf[Array[Int]])
- case "[J" =>
+ writeIntArr(dos, v)
+ case v: Array[Long] =>
writeType(dos, "array")
- writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble))
- case "[F" =>
+ writeDoubleArr(dos, v.map(_.toDouble))
+ case v: Array[Float] =>
writeType(dos, "array")
- writeDoubleArr(dos, value.asInstanceOf[Array[Float]].map(_.toDouble))
- case "[D" =>
+ writeDoubleArr(dos, v.map(_.toDouble))
+ case v: Array[Double] =>
writeType(dos, "array")
- writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
- case "[Z" =>
+ writeDoubleArr(dos, v)
+ case v: Array[Boolean] =>
writeType(dos, "array")
- writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
+ writeBooleanArr(dos, v)
// Array of objects, null objects use "void" type
- case c if c.startsWith("[") =>
+ case v: Array[Object] =>
writeType(dos, "list")
- val array = value.asInstanceOf[Array[Object]]
- writeInt(dos, array.length)
- array.foreach(elem => writeObject(dos, elem))
+ writeInt(dos, v.length)
+ v.foreach(elem => writeObject(dos, elem))
case _ =>
writeType(dos, "jobj")