author     Sun Rui <rui.sun@intel.com>                        2016-04-29 16:41:07 -0700
committer  Shivaram Venkataraman <shivaram@cs.berkeley.edu>   2016-04-29 16:41:07 -0700
commit     4ae9fe091c2cb8388c581093d62d3deaef40993e (patch)
tree       fd84ce605c0ea8bd9d0b2e307119bd5d8651c9f5 /sql
parent     d78fbcc3cc9c379b4a548ebc816c6f71cc71a16e (diff)
[SPARK-12919][SPARKR] Implement dapply() on DataFrame in SparkR.
## What changes were proposed in this pull request?

dapply() applies an R function to each partition of a DataFrame and returns a new DataFrame.

The function signature is:

    dapply(df, function(localDF) {}, schema = NULL)

R function input: a local data.frame built from the partition on the local node.
R function output: a local data.frame.

`schema` specifies the row format of the resulting DataFrame and must match the output of the R function. If `schema` is not specified, each partition of the result DataFrame is serialized in R into a single byte array; such a DataFrame can then be processed by further calls to dapply().

## How was this patch tested?

SparkR unit tests.

Author: Sun Rui <rui.sun@intel.com>
Author: Sun Rui <sunrui2016@gmail.com>

Closes #12493 from sun-rui/SPARK-12919.
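For illustration, here is a minimal Scala sketch of the JVM-side entry point this patch adds; the R-facing dapply() reaches it through the SparkR backend. The data, column names, and empty byte arrays are hypothetical placeholders, and because SQLUtils is private[sql] the snippet would have to live inside the org.apache.spark.sql package.

```scala
package org.apache.spark.sql

import org.apache.spark.sql.api.r.SQLUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object DapplySketch {
  def run(sqlContext: SQLContext): Unit = {
    import sqlContext.implicits._
    val df = Seq((1, "a"), (2, "b")).toDF("id", "value")

    // Placeholders: in SparkR these byte arrays carry the serialized R closure and the
    // names of the R packages to load on the workers.
    val func = Array.empty[Byte]
    val packageNames = Array.empty[Byte]
    val broadcastVars = Array.empty[Object]

    // With an explicit schema the result uses that row format; with schema = null each
    // partition comes back as a single serialized byte array (one binary column named "R").
    val outSchema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("value", StringType)))
    val withSchema = SQLUtils.dapply(df, func, packageNames, broadcastVars, outSchema)
    val opaque = SQLUtils.dapply(df, func, packageNames, broadcastVars, null)

    withSchema.printSchema()
    opaque.printSchema()
  }
}
```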
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala   |  13
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala  |  54
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                            |  18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala                     |  32
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala          |   3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala  |  68
6 files changed, 179 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 434c033c49..abbd8facd3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -159,10 +159,15 @@ object EliminateSerialization extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case d @ DeserializeToObject(_, _, s: SerializeFromObject)
if d.outputObjectType == s.inputObjectType =>
- // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
- val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
- Project(objAttr :: Nil, s.child)
-
+ // A workaround for SPARK-14803. Remove this after it is fixed.
+ if (d.outputObjectType.isInstanceOf[ObjectType] &&
+ d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) {
+ s.child
+ } else {
+ // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
+ val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
+ Project(objAttr :: Nil, s.child)
+ }
case a @ AppendColumns(_, _, _, s: SerializeFromObject)
if a.deserializer.dataType == s.inputObjectType =>
AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
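The effect of this rule is easier to see against a toy model. The sketch below is illustrative only; it does not use Catalyst types, and the names are made up.

```scala
object EliminateSerializationSketch {
  sealed trait Plan
  case class Leaf(name: String) extends Plan
  case class Serialize(objType: String, child: Plan) extends Plan
  case class Deserialize(objType: String, child: Plan) extends Plan
  case class Project(alias: String, child: Plan) extends Plan

  // A deserialize node sitting directly on a matching serialize node is redundant. For Row
  // objects the pair is dropped outright (the SPARK-14803 workaround above); otherwise a
  // Project is kept so the original output attribute id survives.
  def eliminate(plan: Plan): Plan = plan match {
    case Deserialize(t, Serialize(u, child)) if t == u =>
      if (t == "Row") child else Project("obj", child)
    case other => other
  }

  // eliminate(Deserialize("Row", Serialize("Row", Leaf("scan")))) == Leaf("scan")
}
```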
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 4a1bdb0b8a..84339f439a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -17,11 +17,12 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.Encoder
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types._
object CatalystSerde {
def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
@@ -29,13 +30,26 @@ object CatalystSerde {
DeserializeToObject(deserializer, generateObjAttr[T], child)
}
+ def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = {
+ val deserializer = UnresolvedDeserializer(encoder.deserializer)
+ DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child)
+ }
+
def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
SerializeFromObject(encoderFor[T].namedExpressions, child)
}
+ def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = {
+ SerializeFromObject(encoder.namedExpressions, child)
+ }
+
def generateObjAttr[T : Encoder]: Attribute = {
AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)()
}
+
+ def generateObjAttrForRow(encoder: ExpressionEncoder[Row]): Attribute = {
+ AttributeReference("obj", encoder.deserializer.dataType, nullable = false)()
+ }
}
/**
@@ -106,6 +120,42 @@ case class MapPartitions(
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
+object MapPartitionsInR {
+ def apply(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ schema: StructType,
+ encoder: ExpressionEncoder[Row],
+ child: LogicalPlan): LogicalPlan = {
+ val deserialized = CatalystSerde.deserialize(child, encoder)
+ val mapped = MapPartitionsInR(
+ func,
+ packageNames,
+ broadcastVars,
+ encoder.schema,
+ schema,
+ CatalystSerde.generateObjAttrForRow(RowEncoder(schema)),
+ deserialized)
+ CatalystSerde.serialize(mapped, RowEncoder(schema))
+ }
+}
+
+/**
+ * A relation produced by applying a serialized R function `func` to each partition of
+ * the `child`.
+ */
+case class MapPartitionsInR(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ inputSchema: StructType,
+ outputSchema: StructType,
+ outputObjAttr: Attribute,
+ child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer {
+ override lazy val schema = outputSchema
+}
+
object MapElements {
def apply[T : Encoder, U : Encoder](
func: AnyRef,
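For orientation, a sketch of the plan shape that the new MapPartitionsInR.apply builds: the R map is sandwiched between a DeserializeToObject on the child's row encoder and a SerializeFromObject on the output row encoder. The snippet assumes access to Catalyst internals (for example, from a test in the sql module); the empty byte arrays stand in for the serialized R closure and package names.

```scala
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, MapPartitionsInR}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object MapPartitionsInRSketch {
  val inputSchema = StructType(Seq(StructField("id", IntegerType)))
  val outputSchema = StructType(Seq(StructField("doubled", IntegerType)))
  val child = LocalRelation(inputSchema.toAttributes)

  // Roughly: SerializeFromObject(RowEncoder(outputSchema).namedExpressions,
  //            MapPartitionsInR(..., DeserializeToObject(..., child)))
  val plan = MapPartitionsInR(
    Array.empty[Byte],               // func: serialized R closure (placeholder)
    Array.empty[Byte],               // packageNames (placeholder)
    Array.empty[Broadcast[Object]],  // broadcastVars
    outputSchema,                    // schema of the R function's output
    RowEncoder(inputSchema),         // encoder for the child's rows
    child)
}
```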
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 860249c211..1439d14980 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function._
import org.apache.spark.api.python.PythonRDD
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
@@ -1981,6 +1982,23 @@ class Dataset[T] private[sql](
}
/**
+ * Returns a new [[DataFrame]] that contains the result of applying a serialized R function
+ * `func` to each partition.
+ *
+ * @group func
+ */
+ private[sql] def mapPartitionsInR(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ schema: StructType): DataFrame = {
+ val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]]
+ Dataset.ofRows(
+ sparkSession,
+ MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan))
+ }
+
+ /**
* :: Experimental ::
* (Scala-specific)
* Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 22ded7a4bf..36173a4925 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -23,12 +23,15 @@ import scala.util.matching.Regex
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.r.SerDe
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.Encoder
import org.apache.spark.sql.types._
-private[r] object SQLUtils {
+private[sql] object SQLUtils {
SerDe.registerSqlSerDe((readSqlObject, writeSqlObject))
def createSQLContext(jsc: JavaSparkContext): SQLContext = {
@@ -111,7 +114,7 @@ private[r] object SQLUtils {
}
}
- private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
+ private[sql] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
val bis = new ByteArrayInputStream(bytes)
val dis = new DataInputStream(bis)
val num = SerDe.readInt(dis)
@@ -120,7 +123,7 @@ private[r] object SQLUtils {
}.toSeq)
}
- private[this] def rowToRBytes(row: Row): Array[Byte] = {
+ private[sql] def rowToRBytes(row: Row): Array[Byte] = {
val bos = new ByteArrayOutputStream()
val dos = new DataOutputStream(bos)
@@ -129,6 +132,29 @@ private[r] object SQLUtils {
bos.toByteArray()
}
+ // Schema for DataFrame of serialized R data
+ // TODO: introduce a user defined type for serialized R data.
+ val SERIALIZED_R_DATA_SCHEMA = StructType(Seq(StructField("R", BinaryType)))
+
+ /**
+ * The helper function for dapply() on the R side.
+ */
+ def dapply(
+ df: DataFrame,
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Object],
+ schema: StructType): DataFrame = {
+ val bv = broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])
+ val realSchema =
+ if (schema == null) {
+ SERIALIZED_R_DATA_SCHEMA
+ } else {
+ schema
+ }
+ df.mapPartitionsInR(func, packageNames, bv, realSchema)
+ }
+
def dfToCols(df: DataFrame): Array[Array[Any]] = {
val localDF: Array[Row] = df.collect()
val numCols = df.columns.length
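As a quick check of the fallback above, a short sketch; like the earlier one, it assumes the code sits inside org.apache.spark.sql, since SQLUtils is private[sql].

```scala
package org.apache.spark.sql

import org.apache.spark.sql.api.r.SQLUtils.SERIALIZED_R_DATA_SCHEMA
import org.apache.spark.sql.types.{BinaryType, StructField, StructType}

object DapplySchemaSketch {
  def main(args: Array[String]): Unit = {
    // When the R caller passes no schema, the result carries one binary column named "R";
    // a later dapply() call can consume it and, given a real schema, expand it into rows.
    assert(SERIALIZED_R_DATA_SCHEMA == StructType(Seq(StructField("R", BinaryType))))
  }
}
```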
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 1eb1f8ef11..238334e26b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -307,6 +307,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
case logical.MapPartitions(f, objAttr, child) =>
execution.MapPartitionsExec(f, objAttr, planLater(child)) :: Nil
+ case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) =>
+ execution.MapPartitionsExec(
+ execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil
case logical.MapElements(f, objAttr, child) =>
execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
case logical.AppendColumns(f, in, out, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
new file mode 100644
index 0000000000..dc6f2ef371
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.r
+
+import org.apache.spark.api.r.RRunner
+import org.apache.spark.api.r.SerializationFormats
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.api.r.SQLUtils._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
+
+/**
+ * A function wrapper that applies the given R function to each partition.
+ */
+private[sql] case class MapPartitionsRWrapper(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ inputSchema: StructType,
+ outputSchema: StructType) extends (Iterator[Any] => Iterator[Any]) {
+ def apply(iter: Iterator[Any]): Iterator[Any] = {
+ // Whether the content of the current DataFrame is serialized R data.
+ val isSerializedRData = (inputSchema == SERIALIZED_R_DATA_SCHEMA)
+
+ val (newIter, deserializer, colNames) =
+ if (!isSerializedRData) {
+ // Serialize each row into a byte array that can be deserialized in the R worker
+ (iter.asInstanceOf[Iterator[Row]].map {row => rowToRBytes(row)},
+ SerializationFormats.ROW, inputSchema.fieldNames)
+ } else {
+ (iter.asInstanceOf[Iterator[Row]].map { row => row(0) }, SerializationFormats.BYTE, null)
+ }
+
+ val serializer = if (outputSchema != SERIALIZED_R_DATA_SCHEMA) {
+ SerializationFormats.ROW
+ } else {
+ SerializationFormats.BYTE
+ }
+
+ val runner = new RRunner[Array[Byte]](
+ func, deserializer, serializer, packageNames, broadcastVars,
+ isDataFrame = true, colNames = colNames)
+ // Partition index is ignored. Dataset has no support for mapPartitionsWithIndex.
+ val outputIter = runner.compute(newIter, -1)
+
+ if (serializer == SerializationFormats.ROW) {
+ outputIter.map { bytes => bytesToRow(bytes, outputSchema) }
+ } else {
+ outputIter.map { bytes => Row.fromSeq(Seq(bytes)) }
+ }
+ }
+}
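The wrapper's key choice is whether data crosses the JVM/R boundary as SerDe-encoded rows or as already-serialized R byte arrays. The sketch below restates that choice in isolation; it is not part of the patch, and RawBytes / SerDeRows merely stand in for the SerializationFormats.BYTE and SerializationFormats.ROW constants used by the real code.

```scala
import org.apache.spark.sql.types.{BinaryType, StructField, StructType}

object RFormatChoiceSketch {
  sealed trait Format
  case object RawBytes extends Format   // the partition already holds serialized R data
  case object SerDeRows extends Format  // rows must be SerDe-encoded for the R worker

  private val serializedRSchema = StructType(Seq(StructField("R", BinaryType)))

  // Input side: ship raw bytes only if the input schema is the serialized-R-data schema.
  // Output side: expect raw bytes back only if the declared output schema is that same schema.
  def formats(input: StructType, output: StructType): (Format, Format) = {
    val deserializer = if (input == serializedRSchema) RawBytes else SerDeRows
    val serializer = if (output == serializedRSchema) RawBytes else SerDeRows
    (deserializer, serializer)
  }
}
```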