aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSun Rui <rui.sun@intel.com>2016-04-29 16:41:07 -0700
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2016-04-29 16:41:07 -0700
commit4ae9fe091c2cb8388c581093d62d3deaef40993e (patch)
treefd84ce605c0ea8bd9d0b2e307119bd5d8651c9f5
parentd78fbcc3cc9c379b4a548ebc816c6f71cc71a16e (diff)
downloadspark-4ae9fe091c2cb8388c581093d62d3deaef40993e.tar.gz
spark-4ae9fe091c2cb8388c581093d62d3deaef40993e.tar.bz2
spark-4ae9fe091c2cb8388c581093d62d3deaef40993e.zip
[SPARK-12919][SPARKR] Implement dapply() on DataFrame in SparkR.
## What changes were proposed in this pull request? dapply() applies an R function on each partition of a DataFrame and returns a new DataFrame. The function signature is: dapply(df, function(localDF) {}, schema = NULL) R function input: local data.frame from the partition on local node R function output: local data.frame Schema specifies the Row format of the resulting DataFrame. It must match the R function's output. If schema is not specified, each partition of the result DataFrame will be serialized in R into a single byte array. Such resulting DataFrame can be processed by successive calls to dapply(). ## How was this patch tested? SparkR unit tests. Author: Sun Rui <rui.sun@intel.com> Author: Sun Rui <sunrui2016@gmail.com> Closes #12493 from sun-rui/SPARK-12919.
-rw-r--r--R/pkg/NAMESPACE1
-rw-r--r--R/pkg/R/DataFrame.R61
-rw-r--r--R/pkg/R/generics.R4
-rw-r--r--R/pkg/inst/tests/testthat/test_sparkSQL.R40
-rw-r--r--R/pkg/inst/worker/worker.R36
-rw-r--r--core/src/main/scala/org/apache/spark/api/r/RRDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/api/r/RRunner.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/api/r/SerDe.scala2
-rw-r--r--docs/sql-programming-guide.md5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala13
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala54
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala18
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala32
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala68
15 files changed, 337 insertions, 15 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 002e469efb..647db22747 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -45,6 +45,7 @@ exportMethods("arrange",
"covar_samp",
"covar_pop",
"crosstab",
+ "dapply",
"describe",
"dim",
"distinct",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index a741fdf709..9e30fa0dbf 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -21,6 +21,7 @@
NULL
setOldClass("jobj")
+setOldClass("structType")
#' @title S4 class that represents a SparkDataFrame
#' @description DataFrames can be created using functions like \link{createDataFrame},
@@ -1125,6 +1126,66 @@ setMethod("summarize",
agg(x, ...)
})
+#' dapply
+#'
+#' Apply a function to each partition of a DataFrame.
+#'
+#' @param x A SparkDataFrame
+#' @param func A function to be applied to each partition of the SparkDataFrame.
+#' func should have only one parameter, to which a data.frame corresponds
+#' to each partition will be passed.
+#' The output of func should be a data.frame.
+#' @param schema The schema of the resulting DataFrame after the function is applied.
+#' It must match the output of func.
+#' @family SparkDataFrame functions
+#' @rdname dapply
+#' @name dapply
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame (sqlContext, iris)
+#' df1 <- dapply(df, function(x) { x }, schema(df))
+#' collect(df1)
+#'
+#' # filter and add a column
+#' df <- createDataFrame (
+#' sqlContext,
+#' list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
+#' c("a", "b", "c"))
+#' schema <- structType(structField("a", "integer"), structField("b", "double"),
+#' structField("c", "string"), structField("d", "integer"))
+#' df1 <- dapply(
+#' df,
+#' function(x) {
+#' y <- x[x[1] > 1, ]
+#' y <- cbind(y, y[1] + 1L)
+#' },
+#' schema)
+#' collect(df1)
+#' # the result
+#' # a b c d
+#' # 1 2 2 2 3
+#' # 2 3 3 3 4
+#' }
+setMethod("dapply",
+ signature(x = "SparkDataFrame", func = "function", schema = "structType"),
+ function(x, func, schema) {
+ packageNamesArr <- serialize(.sparkREnv[[".packages"]],
+ connection = NULL)
+
+ broadcastArr <- lapply(ls(.broadcastNames),
+ function(name) { get(name, .broadcastNames) })
+
+ sdf <- callJStatic(
+ "org.apache.spark.sql.api.r.SQLUtils",
+ "dapply",
+ x@sdf,
+ serialize(cleanClosure(func), connection = NULL),
+ packageNamesArr,
+ broadcastArr,
+ schema$jobj)
+ dataFrame(sdf)
+ })
############################## RDD Map Functions ##################################
# All of the following functions mirror the existing RDD map functions, #
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 62907118ef..3db8925730 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -446,6 +446,10 @@ setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") })
#' @export
setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") })
+#' @rdname dapply
+#' @export
+setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") })
+
#' @rdname summary
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 7058265ea3..5cf9dc405b 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2043,6 +2043,46 @@ test_that("Histogram", {
df <- as.DataFrame(sqlContext, data.frame(x = c(1, 2, 3, 4, 100)))
expect_equal(histogram(df, "x")$counts, c(4, 0, 0, 0, 0, 0, 0, 0, 0, 1))
})
+
+test_that("dapply() on a DataFrame", {
+ df <- createDataFrame (
+ sqlContext,
+ list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
+ c("a", "b", "c"))
+ ldf <- collect(df)
+ df1 <- dapply(df, function(x) { x }, schema(df))
+ result <- collect(df1)
+ expect_identical(ldf, result)
+
+
+ # Filter and add a column
+ schema <- structType(structField("a", "integer"), structField("b", "double"),
+ structField("c", "string"), structField("d", "integer"))
+ df1 <- dapply(
+ df,
+ function(x) {
+ y <- x[x$a > 1, ]
+ y <- cbind(y, y$a + 1L)
+ },
+ schema)
+ result <- collect(df1)
+ expected <- ldf[ldf$a > 1, ]
+ expected$d <- expected$a + 1L
+ rownames(expected) <- NULL
+ expect_identical(expected, result)
+
+ # Remove the added column
+ df2 <- dapply(
+ df1,
+ function(x) {
+ x[, c("a", "b", "c")]
+ },
+ schema(df))
+ result <- collect(df2)
+ expected <- expected[, c("a", "b", "c")]
+ expect_identical(expected, result)
+})
+
unlink(parquetPath)
unlink(jsonPath)
unlink(jsonPathNa)
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index b6784dbae3..40cda0c5ef 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -84,6 +84,13 @@ broadcastElap <- elapsedSecs()
# as number of partitions to create.
numPartitions <- SparkR:::readInt(inputCon)
+isDataFrame <- as.logical(SparkR:::readInt(inputCon))
+
+# If isDataFrame, then read column names
+if (isDataFrame) {
+ colNames <- SparkR:::readObject(inputCon)
+}
+
isEmpty <- SparkR:::readInt(inputCon)
if (isEmpty != 0) {
@@ -100,7 +107,34 @@ if (isEmpty != 0) {
# Timing reading input data for execution
inputElap <- elapsedSecs()
- output <- computeFunc(partition, data)
+ if (isDataFrame) {
+ if (deserializer == "row") {
+ # Transform the list of rows into a data.frame
+ # Note that the optional argument stringsAsFactors for rbind is
+ # available since R 3.2.4. So we set the global option here.
+ oldOpt <- getOption("stringsAsFactors")
+ options(stringsAsFactors = FALSE)
+ data <- do.call(rbind.data.frame, data)
+ options(stringsAsFactors = oldOpt)
+
+ names(data) <- colNames
+ } else {
+ # Check to see if data is a valid data.frame
+ stopifnot(deserializer == "byte")
+ stopifnot(class(data) == "data.frame")
+ }
+ output <- computeFunc(data)
+ if (serializer == "row") {
+ # Transform the result data.frame back to a list of rows
+ output <- split(output, seq(nrow(output)))
+ } else {
+ # Serialize the ouput to a byte array
+ stopifnot(serializer == "byte")
+ }
+ } else {
+ output <- computeFunc(partition, data)
+ }
+
# Timing computing
computeElap <- elapsedSecs()
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 606ba6ef86..59c8429c80 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -46,7 +46,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
// The parent may be also an RRDD, so we should launch it first.
val parentIterator = firstParent[T].iterator(partition, context)
- runner.compute(parentIterator, partition.index, context)
+ runner.compute(parentIterator, partition.index)
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 07d1fa2c4a..24ad689f83 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -38,7 +38,9 @@ private[spark] class RRunner[U](
serializer: String,
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
- numPartitions: Int = -1)
+ numPartitions: Int = -1,
+ isDataFrame: Boolean = false,
+ colNames: Array[String] = null)
extends Logging {
private var bootTime: Double = _
private var dataStream: DataInputStream = _
@@ -53,8 +55,7 @@ private[spark] class RRunner[U](
def compute(
inputIterator: Iterator[_],
- partitionIndex: Int,
- context: TaskContext): Iterator[U] = {
+ partitionIndex: Int): Iterator[U] = {
// Timing start
bootTime = System.currentTimeMillis / 1000.0
@@ -148,6 +149,12 @@ private[spark] class RRunner[U](
dataOut.writeInt(numPartitions)
+ dataOut.writeInt(if (isDataFrame) 1 else 0)
+
+ if (isDataFrame) {
+ SerDe.writeObject(dataOut, colNames)
+ }
+
if (!iter.hasNext) {
dataOut.writeInt(0)
} else {
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 8e4e80a24a..e4932a4192 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -459,7 +459,7 @@ private[spark] object SerDe {
}
-private[r] object SerializationFormats {
+private[spark] object SerializationFormats {
val BYTE = "byte"
val STRING = "string"
val ROW = "row"
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 9a3db9c3f9..a16a6bb1d9 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1147,6 +1147,11 @@ parquetFile <- read.parquet(sqlContext, "people.parquet")
# Parquet files can also be registered as tables and then used in SQL statements.
registerTempTable(parquetFile, "parquetFile")
teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19")
+schema <- structType(structField("name", "string"))
+teenNames <- dapply(df, function(p) { cbind(paste("Name:", p$name)) }, schema)
+for (teenName in collect(teenNames)$name) {
+ cat(teenName, "\n")
+}
{% endhighlight %}
</div>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 434c033c49..abbd8facd3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -159,10 +159,15 @@ object EliminateSerialization extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case d @ DeserializeToObject(_, _, s: SerializeFromObject)
if d.outputObjectType == s.inputObjectType =>
- // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
- val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
- Project(objAttr :: Nil, s.child)
-
+ // A workaround for SPARK-14803. Remove this after it is fixed.
+ if (d.outputObjectType.isInstanceOf[ObjectType] &&
+ d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) {
+ s.child
+ } else {
+ // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
+ val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
+ Project(objAttr :: Nil, s.child)
+ }
case a @ AppendColumns(_, _, _, s: SerializeFromObject)
if a.deserializer.dataType == s.inputObjectType =>
AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 4a1bdb0b8a..84339f439a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -17,11 +17,12 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.Encoder
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types._
object CatalystSerde {
def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
@@ -29,13 +30,26 @@ object CatalystSerde {
DeserializeToObject(deserializer, generateObjAttr[T], child)
}
+ def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = {
+ val deserializer = UnresolvedDeserializer(encoder.deserializer)
+ DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child)
+ }
+
def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
SerializeFromObject(encoderFor[T].namedExpressions, child)
}
+ def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = {
+ SerializeFromObject(encoder.namedExpressions, child)
+ }
+
def generateObjAttr[T : Encoder]: Attribute = {
AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)()
}
+
+ def generateObjAttrForRow(encoder: ExpressionEncoder[Row]): Attribute = {
+ AttributeReference("obj", encoder.deserializer.dataType, nullable = false)()
+ }
}
/**
@@ -106,6 +120,42 @@ case class MapPartitions(
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
+object MapPartitionsInR {
+ def apply(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ schema: StructType,
+ encoder: ExpressionEncoder[Row],
+ child: LogicalPlan): LogicalPlan = {
+ val deserialized = CatalystSerde.deserialize(child, encoder)
+ val mapped = MapPartitionsInR(
+ func,
+ packageNames,
+ broadcastVars,
+ encoder.schema,
+ schema,
+ CatalystSerde.generateObjAttrForRow(RowEncoder(schema)),
+ deserialized)
+ CatalystSerde.serialize(mapped, RowEncoder(schema))
+ }
+}
+
+/**
+ * A relation produced by applying a serialized R function `func` to each partition of the `child`.
+ *
+ */
+case class MapPartitionsInR(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ inputSchema: StructType,
+ outputSchema: StructType,
+ outputObjAttr: Attribute,
+ child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer {
+ override lazy val schema = outputSchema
+}
+
object MapElements {
def apply[T : Encoder, U : Encoder](
func: AnyRef,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 860249c211..1439d14980 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function._
import org.apache.spark.api.python.PythonRDD
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
@@ -1981,6 +1982,23 @@ class Dataset[T] private[sql](
}
/**
+ * Returns a new [[DataFrame]] that contains the result of applying a serialized R function
+ * `func` to each partition.
+ *
+ * @group func
+ */
+ private[sql] def mapPartitionsInR(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ schema: StructType): DataFrame = {
+ val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]]
+ Dataset.ofRows(
+ sparkSession,
+ MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan))
+ }
+
+ /**
* :: Experimental ::
* (Scala-specific)
* Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 22ded7a4bf..36173a4925 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -23,12 +23,15 @@ import scala.util.matching.Regex
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.r.SerDe
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.Encoder
import org.apache.spark.sql.types._
-private[r] object SQLUtils {
+private[sql] object SQLUtils {
SerDe.registerSqlSerDe((readSqlObject, writeSqlObject))
def createSQLContext(jsc: JavaSparkContext): SQLContext = {
@@ -111,7 +114,7 @@ private[r] object SQLUtils {
}
}
- private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
+ private[sql] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
val bis = new ByteArrayInputStream(bytes)
val dis = new DataInputStream(bis)
val num = SerDe.readInt(dis)
@@ -120,7 +123,7 @@ private[r] object SQLUtils {
}.toSeq)
}
- private[this] def rowToRBytes(row: Row): Array[Byte] = {
+ private[sql] def rowToRBytes(row: Row): Array[Byte] = {
val bos = new ByteArrayOutputStream()
val dos = new DataOutputStream(bos)
@@ -129,6 +132,29 @@ private[r] object SQLUtils {
bos.toByteArray()
}
+ // Schema for DataFrame of serialized R data
+ // TODO: introduce a user defined type for serialized R data.
+ val SERIALIZED_R_DATA_SCHEMA = StructType(Seq(StructField("R", BinaryType)))
+
+ /**
+ * The helper function for dapply() on R side.
+ */
+ def dapply(
+ df: DataFrame,
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Object],
+ schema: StructType): DataFrame = {
+ val bv = broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])
+ val realSchema =
+ if (schema == null) {
+ SERIALIZED_R_DATA_SCHEMA
+ } else {
+ schema
+ }
+ df.mapPartitionsInR(func, packageNames, bv, realSchema)
+ }
+
def dfToCols(df: DataFrame): Array[Array[Any]] = {
val localDF: Array[Row] = df.collect()
val numCols = df.columns.length
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 1eb1f8ef11..238334e26b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -307,6 +307,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
case logical.MapPartitions(f, objAttr, child) =>
execution.MapPartitionsExec(f, objAttr, planLater(child)) :: Nil
+ case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) =>
+ execution.MapPartitionsExec(
+ execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil
case logical.MapElements(f, objAttr, child) =>
execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
case logical.AppendColumns(f, in, out, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
new file mode 100644
index 0000000000..dc6f2ef371
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.r
+
+import org.apache.spark.api.r.RRunner
+import org.apache.spark.api.r.SerializationFormats
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.api.r.SQLUtils._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
+
+/**
+ * A function wrapper that applies the given R function to each partition.
+ */
+private[sql] case class MapPartitionsRWrapper(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ inputSchema: StructType,
+ outputSchema: StructType) extends (Iterator[Any] => Iterator[Any]) {
+ def apply(iter: Iterator[Any]): Iterator[Any] = {
+ // If the content of current DataFrame is serialized R data?
+ val isSerializedRData =
+ if (inputSchema == SERIALIZED_R_DATA_SCHEMA) true else false
+
+ val (newIter, deserializer, colNames) =
+ if (!isSerializedRData) {
+ // Serialize each row into an byte array that can be deserialized in the R worker
+ (iter.asInstanceOf[Iterator[Row]].map {row => rowToRBytes(row)},
+ SerializationFormats.ROW, inputSchema.fieldNames)
+ } else {
+ (iter.asInstanceOf[Iterator[Row]].map { row => row(0) }, SerializationFormats.BYTE, null)
+ }
+
+ val serializer = if (outputSchema != SERIALIZED_R_DATA_SCHEMA) {
+ SerializationFormats.ROW
+ } else {
+ SerializationFormats.BYTE
+ }
+
+ val runner = new RRunner[Array[Byte]](
+ func, deserializer, serializer, packageNames, broadcastVars,
+ isDataFrame = true, colNames = colNames)
+ // Partition index is ignored. Dataset has no support for mapPartitionsWithIndex.
+ val outputIter = runner.compute(newIter, -1)
+
+ if (serializer == SerializationFormats.ROW) {
+ outputIter.map { bytes => bytesToRow(bytes, outputSchema) }
+ } else {
+ outputIter.map { bytes => Row.fromSeq(Seq(bytes)) }
+ }
+ }
+}