author     Sun Rui <rui.sun@intel.com>                        2016-04-29 16:41:07 -0700
committer  Shivaram Venkataraman <shivaram@cs.berkeley.edu>   2016-04-29 16:41:07 -0700
commit     4ae9fe091c2cb8388c581093d62d3deaef40993e (patch)
tree       fd84ce605c0ea8bd9d0b2e307119bd5d8651c9f5 /sql
parent     d78fbcc3cc9c379b4a548ebc816c6f71cc71a16e (diff)
[SPARK-12919][SPARKR] Implement dapply() on DataFrame in SparkR.
## What changes were proposed in this pull request?
dapply() applies an R function to each partition of a DataFrame and returns a new DataFrame.
The function signature is:
dapply(df, function(localDF) {}, schema = NULL)
R function input: a local data.frame built from the rows of one partition on the worker node
R function output: a local data.frame
The schema specifies the row format of the resulting DataFrame and must match the output of the R function.
If no schema is specified, each partition of the resulting DataFrame is serialized in R into a single byte array; such a DataFrame can then be processed by further dapply() calls.
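For illustration, a minimal SparkR sketch of the intended usage (the `faithful` dataset, the column names, and the derived `waiting_secs` column are illustrative assumptions, not part of this patch):

```r
# Assumes a running SparkR session with an SQLContext bound to `sqlContext`.
df <- createDataFrame(sqlContext, faithful)

# The schema must describe the rows returned by the R function.
schema <- structType(structField("eruptions", "double"),
                     structField("waiting", "double"),
                     structField("waiting_secs", "double"))

# The function receives each partition as a local data.frame and returns a local data.frame.
df1 <- dapply(df, function(localDF) {
  localDF$waiting_secs <- localDF$waiting * 60
  localDF
}, schema)

head(collect(df1))
```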
## How was this patch tested?
SparkR unit tests.
Author: Sun Rui <rui.sun@intel.com>
Author: Sun Rui <sunrui2016@gmail.com>
Closes #12493 from sun-rui/SPARK-12919.
Diffstat (limited to 'sql')
6 files changed, 179 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 434c033c49..abbd8facd3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -159,10 +159,15 @@ object EliminateSerialization extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case d @ DeserializeToObject(_, _, s: SerializeFromObject)
         if d.outputObjectType == s.inputObjectType =>
-      // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
-      val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
-      Project(objAttr :: Nil, s.child)
-
+      // A workaround for SPARK-14803. Remove this after it is fixed.
+      if (d.outputObjectType.isInstanceOf[ObjectType] &&
+          d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) {
+        s.child
+      } else {
+        // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
+        val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
+        Project(objAttr :: Nil, s.child)
+      }
     case a @ AppendColumns(_, _, _, s: SerializeFromObject)
         if a.deserializer.dataType == s.inputObjectType =>
       AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 4a1bdb0b8a..84339f439a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -17,11 +17,12 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import org.apache.spark.sql.Encoder
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.{Encoder, Row}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types._
 
 object CatalystSerde {
   def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
@@ -29,13 +30,26 @@ object CatalystSerde {
     DeserializeToObject(deserializer, generateObjAttr[T], child)
   }
 
+  def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = {
+    val deserializer = UnresolvedDeserializer(encoder.deserializer)
+    DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child)
+  }
+
   def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
     SerializeFromObject(encoderFor[T].namedExpressions, child)
   }
 
+  def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = {
+    SerializeFromObject(encoder.namedExpressions, child)
+  }
+
   def generateObjAttr[T : Encoder]: Attribute = {
     AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)()
   }
+
+  def generateObjAttrForRow(encoder: ExpressionEncoder[Row]): Attribute = {
+    AttributeReference("obj", encoder.deserializer.dataType, nullable = false)()
+  }
 }
 
 /**
@@ -106,6 +120,42 @@ case class MapPartitions(
     outputObjAttr: Attribute,
     child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
 
+object MapPartitionsInR {
+  def apply(
+      func: Array[Byte],
+      packageNames: Array[Byte],
+      broadcastVars: Array[Broadcast[Object]],
+      schema: StructType,
+      encoder: ExpressionEncoder[Row],
+      child: LogicalPlan): LogicalPlan = {
+    val deserialized = CatalystSerde.deserialize(child, encoder)
+    val mapped = MapPartitionsInR(
+      func,
+      packageNames,
+      broadcastVars,
+      encoder.schema,
+      schema,
+      CatalystSerde.generateObjAttrForRow(RowEncoder(schema)),
+      deserialized)
+    CatalystSerde.serialize(mapped, RowEncoder(schema))
+  }
+}
+
+/**
+ * A relation produced by applying a serialized R function `func` to each partition of the `child`.
+ *
+ */
+case class MapPartitionsInR(
+    func: Array[Byte],
+    packageNames: Array[Byte],
+    broadcastVars: Array[Broadcast[Object]],
+    inputSchema: StructType,
+    outputSchema: StructType,
+    outputObjAttr: Attribute,
+    child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer {
+  override lazy val schema = outputSchema
+}
+
 object MapElements {
   def apply[T : Encoder, U : Encoder](
       func: AnyRef,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 860249c211..1439d14980 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
 import org.apache.spark.api.python.PythonRDD
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.analysis._
@@ -1981,6 +1982,23 @@ class Dataset[T] private[sql](
   }
 
   /**
+   * Returns a new [[DataFrame]] that contains the result of applying a serialized R function
+   * `func` to each partition.
+   *
+   * @group func
+   */
+  private[sql] def mapPartitionsInR(
+      func: Array[Byte],
+      packageNames: Array[Byte],
+      broadcastVars: Array[Broadcast[Object]],
+      schema: StructType): DataFrame = {
+    val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]]
+    Dataset.ofRows(
+      sparkSession,
+      MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan))
+  }
+
+  /**
    * :: Experimental ::
    * (Scala-specific)
    * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 22ded7a4bf..36173a4925 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -23,12 +23,15 @@ import scala.util.matching.Regex
 
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.r.SerDe
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.types._
 
-private[r] object SQLUtils {
+private[sql] object SQLUtils {
   SerDe.registerSqlSerDe((readSqlObject, writeSqlObject))
 
   def createSQLContext(jsc: JavaSparkContext): SQLContext = {
@@ -111,7 +114,7 @@ private[r] object SQLUtils {
     }
   }
 
-  private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
+  private[sql] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
     val bis = new ByteArrayInputStream(bytes)
     val dis = new DataInputStream(bis)
     val num = SerDe.readInt(dis)
@@ -120,7 +123,7 @@ private[r] object SQLUtils {
     }.toSeq)
   }
 
-  private[this] def rowToRBytes(row: Row): Array[Byte] = {
+  private[sql] def rowToRBytes(row: Row): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
     val dos = new DataOutputStream(bos)
 
@@ -129,6 +132,29 @@ private[r] object SQLUtils {
     bos.toByteArray()
   }
 
+  // Schema for DataFrame of serialized R data
+  // TODO: introduce a user defined type for serialized R data.
+  val SERIALIZED_R_DATA_SCHEMA = StructType(Seq(StructField("R", BinaryType)))
+
+  /**
+   * The helper function for dapply() on R side.
+   */
+  def dapply(
+      df: DataFrame,
+      func: Array[Byte],
+      packageNames: Array[Byte],
+      broadcastVars: Array[Object],
+      schema: StructType): DataFrame = {
+    val bv = broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])
+    val realSchema =
+      if (schema == null) {
+        SERIALIZED_R_DATA_SCHEMA
+      } else {
+        schema
+      }
+    df.mapPartitionsInR(func, packageNames, bv, realSchema)
+  }
+
   def dfToCols(df: DataFrame): Array[Array[Any]] = {
     val localDF: Array[Row] = df.collect()
     val numCols = df.columns.length
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 1eb1f8ef11..238334e26b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -307,6 +309,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
       case logical.MapPartitions(f, objAttr, child) =>
         execution.MapPartitionsExec(f, objAttr, planLater(child)) :: Nil
+      case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) =>
+        execution.MapPartitionsExec(
+          execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil
       case logical.MapElements(f, objAttr, child) =>
         execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
       case logical.AppendColumns(f, in, out, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
new file mode 100644
index 0000000000..dc6f2ef371
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.r
+
+import org.apache.spark.api.r.RRunner
+import org.apache.spark.api.r.SerializationFormats
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.api.r.SQLUtils._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
+
+/**
+ * A function wrapper that applies the given R function to each partition.
+ */
+private[sql] case class MapPartitionsRWrapper(
+    func: Array[Byte],
+    packageNames: Array[Byte],
+    broadcastVars: Array[Broadcast[Object]],
+    inputSchema: StructType,
+    outputSchema: StructType) extends (Iterator[Any] => Iterator[Any]) {
+  def apply(iter: Iterator[Any]): Iterator[Any] = {
+    // If the content of current DataFrame is serialized R data?
+    val isSerializedRData =
+      if (inputSchema == SERIALIZED_R_DATA_SCHEMA) true else false
+
+    val (newIter, deserializer, colNames) =
+      if (!isSerializedRData) {
+        // Serialize each row into an byte array that can be deserialized in the R worker
+        (iter.asInstanceOf[Iterator[Row]].map {row => rowToRBytes(row)},
+         SerializationFormats.ROW, inputSchema.fieldNames)
+      } else {
+        (iter.asInstanceOf[Iterator[Row]].map { row => row(0) }, SerializationFormats.BYTE, null)
+      }
+
+    val serializer = if (outputSchema != SERIALIZED_R_DATA_SCHEMA) {
+      SerializationFormats.ROW
+    } else {
+      SerializationFormats.BYTE
+    }
+
+    val runner = new RRunner[Array[Byte]](
+      func, deserializer, serializer, packageNames, broadcastVars,
+      isDataFrame = true, colNames = colNames)
+    // Partition index is ignored. Dataset has no support for mapPartitionsWithIndex.
+    val outputIter = runner.compute(newIter, -1)
+
+    if (serializer == SerializationFormats.ROW) {
+      outputIter.map { bytes => bytesToRow(bytes, outputSchema) }
+    } else {
+      outputIter.map { bytes => Row.fromSeq(Seq(bytes)) }
+    }
+  }
+}
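To illustrate the serialized-R-data path handled by `SERIALIZED_R_DATA_SCHEMA` and `MapPartitionsRWrapper` above, a sketch of chained dapply() calls in SparkR (continuing the hypothetical `df` from the earlier example; not taken from the patch or its tests):

```r
# With schema = NULL, each output partition is kept as one serialized byte array,
# i.e. a DataFrame with the single binary column "R", meant to be consumed by
# further dapply() calls.
intermediate <- dapply(df, function(localDF) {
  localDF$waiting <- localDF$waiting + 1
  localDF
}, schema = NULL)

# A later dapply() deserializes those byte arrays back into local data.frames;
# supplying a schema here turns the result back into an ordinary DataFrame.
result <- dapply(intermediate, function(localDF) {
  localDF
}, schema = structType(structField("eruptions", "double"),
                       structField("waiting", "double")))

head(collect(result))
```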