From 45e3be5c138d983f40f619735d60bf7eb78c9bf0 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Thu, 10 Sep 2015 12:21:13 -0700 Subject: [SPARK-10049] [SPARKR] Support collecting data of ArrayType in DataFrame. This PR: 1. Enhance reflection in RBackend. Automatically match a Java array to Scala Seq when finding methods. Util functions like seq(), listToSeq() in R side can be removed, as they will conflict with the SerDe logic that transfers a Scala seq to R side. 2. Enhance the SerDe to support transferring a Scala seq to R side. Data of ArrayType in DataFrame after collection is observed to be of Scala Seq type. 3. Support ArrayType in createDataFrame(). Author: Sun Rui Closes #8458 from sun-rui/SPARK-10049. --- .../org/apache/spark/api/r/RBackendHandler.scala | 121 +++++++++++++++------ .../main/scala/org/apache/spark/api/r/SerDe.scala | 109 ++++++++++--------- 2 files changed, 145 insertions(+), 85 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index bb82f3285f..2a792d8199 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -125,10 +125,11 @@ private[r] class RBackendHandler(server: RBackend) val methods = cls.getMethods val selectedMethods = methods.filter(m => m.getName == methodName) if (selectedMethods.length > 0) { - val methods = selectedMethods.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - } - if (methods.isEmpty) { + val index = findMatchedSignature( + selectedMethods.map(_.getParameterTypes), + args) + + if (index.isEmpty) { logWarning(s"cannot find matching method ${cls}.$methodName. 
" + s"Candidates are:") selectedMethods.foreach { method => @@ -136,18 +137,29 @@ private[r] class RBackendHandler(server: RBackend) } throw new Exception(s"No matched method found for $cls.$methodName") } - val ret = methods.head.invoke(obj, args : _*) + + val ret = selectedMethods(index.get).invoke(obj, args : _*) // Write status bit writeInt(dos, 0) writeObject(dos, ret.asInstanceOf[AnyRef]) } else if (methodName == "") { // methodName should be "" for constructor - val ctor = cls.getConstructors.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - }.head + val ctors = cls.getConstructors + val index = findMatchedSignature( + ctors.map(_.getParameterTypes), + args) - val obj = ctor.newInstance(args : _*) + if (index.isEmpty) { + logWarning(s"cannot find matching constructor for ${cls}. " + + s"Candidates are:") + ctors.foreach { ctor => + logWarning(s"$cls(${ctor.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched constructor found for $cls") + } + + val obj = ctors(index.get).newInstance(args : _*) writeInt(dos, 0) writeObject(dos, obj.asInstanceOf[AnyRef]) @@ -166,40 +178,79 @@ private[r] class RBackendHandler(server: RBackend) // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { - (0 until numArgs).map { arg => + (0 until numArgs).map { _ => readObject(dis) }.toArray } - // Checks if the arguments passed in args matches the parameter types. - // NOTE: Currently we do exact match. We may add type conversions later. - def matchMethod( - numArgs: Int, - args: Array[java.lang.Object], - parameterTypes: Array[Class[_]]): Boolean = { - if (parameterTypes.length != numArgs) { - return false - } + // Find a matching method signature in an array of signatures of constructors + // or methods of the same name according to the passed arguments. Arguments + // may be converted in order to match a signature. 
+ // + // Note that in Java reflection, constructors and normal methods are of different + // classes, and share no parent class that provides methods for reflection uses. + // There is no unified way to handle them in this function. So an array of signatures + // is passed in instead of an array of candidate constructors or methods. + // + // Returns an Option[Int] which is the index of the matched signature in the array. + def findMatchedSignature( + parameterTypesOfMethods: Array[Array[Class[_]]], + args: Array[Object]): Option[Int] = { + val numArgs = args.length + + for (index <- 0 until parameterTypesOfMethods.length) { + val parameterTypes = parameterTypesOfMethods(index) + + if (parameterTypes.length == numArgs) { + var argMatched = true + var i = 0 + while (i < numArgs && argMatched) { + val parameterType = parameterTypes(i) + + if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) { + // The case that the parameter type is a Scala Seq and the argument + // is a Java array is considered matching. The array will be converted + // to a Seq later if this method is matched. 
+ } else { + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Integer] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if (!parameterWrapperType.isInstance(args(i))) { + argMatched = false + } + } - for (i <- 0 to numArgs - 1) { - val parameterType = parameterTypes(i) - var parameterWrapperType = parameterType - - // Convert native parameters to Object types as args is Array[Object] here - if (parameterType.isPrimitive) { - parameterWrapperType = parameterType match { - case java.lang.Integer.TYPE => classOf[java.lang.Integer] - case java.lang.Long.TYPE => classOf[java.lang.Integer] - case java.lang.Double.TYPE => classOf[java.lang.Double] - case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] - case _ => parameterType + i = i + 1 + } + + if (argMatched) { + // For now, we return the first matching method. + // TODO: find best method in matching methods. 
+ + // Convert args if needed + val parameterTypes = parameterTypesOfMethods(index) + + (0 until numArgs).map { i => + if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) { + // Convert a Java array to scala Seq + args(i) = args(i).asInstanceOf[Array[_]].toSeq + } + } + + return Some(index) } - } - if (!parameterWrapperType.isInstance(args(i))) { - return false } } - true + None } } diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 190e193427..3c92bb7a1c 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream} import java.sql.{Timestamp, Date, Time} import scala.collection.JavaConverters._ +import scala.collection.mutable.WrappedArray /** * Utility functions to serialize, deserialize objects to / from R @@ -213,89 +214,97 @@ private[spark] object SerDe { } } - def writeObject(dos: DataOutputStream, value: Object): Unit = { - if (value == null) { + def writeObject(dos: DataOutputStream, obj: Object): Unit = { + if (obj == null) { writeType(dos, "void") } else { - value.getClass.getName match { - case "java.lang.Character" => + // Convert ArrayType collected from DataFrame to Java array + // Collected data of ArrayType from a DataFrame is observed to be of + // type "scala.collection.mutable.WrappedArray" + val value = + if (obj.isInstanceOf[WrappedArray[_]]) { + obj.asInstanceOf[WrappedArray[_]].toArray + } else { + obj + } + + value match { + case v: java.lang.Character => writeType(dos, "character") - writeString(dos, value.asInstanceOf[Character].toString) - case "java.lang.String" => + writeString(dos, v.toString) + case v: java.lang.String => writeType(dos, "character") - writeString(dos, value.asInstanceOf[String]) - case "java.lang.Long" => + writeString(dos, v) + case v: java.lang.Long => writeType(dos, "double") - 
writeDouble(dos, value.asInstanceOf[Long].toDouble) - case "java.lang.Float" => + writeDouble(dos, v.toDouble) + case v: java.lang.Float => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Float].toDouble) - case "java.math.BigDecimal" => + writeDouble(dos, v.toDouble) + case v: java.math.BigDecimal => writeType(dos, "double") - val javaDecimal = value.asInstanceOf[java.math.BigDecimal] - writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble) - case "java.lang.Double" => + writeDouble(dos, scala.math.BigDecimal(v).toDouble) + case v: java.lang.Double => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Double]) - case "java.lang.Byte" => + writeDouble(dos, v) + case v: java.lang.Byte => writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Byte].toInt) - case "java.lang.Short" => + writeInt(dos, v.toInt) + case v: java.lang.Short => writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Short].toInt) - case "java.lang.Integer" => + writeInt(dos, v.toInt) + case v: java.lang.Integer => writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Int]) - case "java.lang.Boolean" => + writeInt(dos, v) + case v: java.lang.Boolean => writeType(dos, "logical") - writeBoolean(dos, value.asInstanceOf[Boolean]) - case "java.sql.Date" => + writeBoolean(dos, v) + case v: java.sql.Date => writeType(dos, "date") - writeDate(dos, value.asInstanceOf[Date]) - case "java.sql.Time" => + writeDate(dos, v) + case v: java.sql.Time => writeType(dos, "time") - writeTime(dos, value.asInstanceOf[Time]) - case "java.sql.Timestamp" => + writeTime(dos, v) + case v: java.sql.Timestamp => writeType(dos, "time") - writeTime(dos, value.asInstanceOf[Timestamp]) + writeTime(dos, v) // Handle arrays // Array of primitive types // Special handling for byte array - case "[B" => + case v: Array[Byte] => writeType(dos, "raw") - writeBytes(dos, value.asInstanceOf[Array[Byte]]) + writeBytes(dos, v) - case "[C" => + case v: Array[Char] => writeType(dos, 
"array") - writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString)) - case "[S" => + writeStringArr(dos, v.map(_.toString)) + case v: Array[Short] => writeType(dos, "array") - writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt)) - case "[I" => + writeIntArr(dos, v.map(_.toInt)) + case v: Array[Int] => writeType(dos, "array") - writeIntArr(dos, value.asInstanceOf[Array[Int]]) - case "[J" => + writeIntArr(dos, v) + case v: Array[Long] => writeType(dos, "array") - writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) - case "[F" => + writeDoubleArr(dos, v.map(_.toDouble)) + case v: Array[Float] => writeType(dos, "array") - writeDoubleArr(dos, value.asInstanceOf[Array[Float]].map(_.toDouble)) - case "[D" => + writeDoubleArr(dos, v.map(_.toDouble)) + case v: Array[Double] => writeType(dos, "array") - writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) - case "[Z" => + writeDoubleArr(dos, v) + case v: Array[Boolean] => writeType(dos, "array") - writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) + writeBooleanArr(dos, v) // Array of objects, null objects use "void" type - case c if c.startsWith("[") => + case v: Array[Object] => writeType(dos, "list") - val array = value.asInstanceOf[Array[Object]] - writeInt(dos, array.length) - array.foreach(elem => writeObject(dos, elem)) + writeInt(dos, v.length) + v.foreach(elem => writeObject(dos, elem)) case _ => writeType(dos, "jobj") -- cgit v1.2.3