path: root/sql
author     Narine Kokhlikyan <narine.kokhlikyan@gmail.com>   2016-06-15 21:42:05 -0700
committer  Shivaram Venkataraman <shivaram@cs.berkeley.edu>  2016-06-15 21:42:05 -0700
commit     7c6c6926376c93acc42dd56a399d816f4838f28c (patch)
tree       bbf8f9dc1d7a044b890b6c95fdd3a17aa76fea89 /sql
parent     b75f454f946714b93fe561055cd53b0686187d2e (diff)
[SPARK-12922][SPARKR][WIP] Implement gapply() on DataFrame in SparkR
## What changes were proposed in this pull request?

gapply() applies an R function to the groups of a DataFrame, grouped by one or more of its columns, and returns a DataFrame. It is analogous to GroupedDataSet.flatMapGroups() in the Dataset API.

Please let me know what you think and whether you have any ideas to improve it. Thank you!

## How was this patch tested?

Unit tests:
1. Primitive test with different column types
2. Add a boolean column
3. Compute the average by group

Author: Narine Kokhlikyan <narine.kokhlikyan@gmail.com>
Author: NarineK <narine.kokhlikyan@us.ibm.com>

Closes #12836 from NarineK/gapply2.
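To make the comparison in the commit message concrete, below is a minimal Dataset-API sketch of the flatMapGroups() contract that gapply() mirrors for R functions: for each key, the user function receives the key and an iterator over the group's rows and returns an iterator of output records. This example is illustrative only and not part of the patch; the `Record` case class and column names are hypothetical.

```scala
import org.apache.spark.sql.SparkSession

case class Record(key: String, value: Double)

object FlatMapGroupsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("example").master("local[*]").getOrCreate()
    import spark.implicits._

    val ds = Seq(Record("a", 1.0), Record("a", 2.0), Record("b", 3.0)).toDS()

    // For each key, receive the key and an iterator over the group's rows,
    // and return an iterator of output records -- the same per-group contract
    // that gapply() exposes to an R function.
    val averaged = ds
      .groupByKey(_.key)
      .flatMapGroups { (key, rows) =>
        val values = rows.map(_.value).toSeq
        Iterator((key, values.sum / values.size))
      }
      .toDF("key", "avg")

    averaged.show()
    spark.stop()
  }
}
```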
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala  | 49
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala           | 48
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala                     | 26
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala          |  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala                  | 72
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala  |  5
6 files changed, 190 insertions, 13 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 78e8822b64..7beeeb4f04 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -246,6 +246,55 @@ case class MapGroups(
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectProducer
+/** Factory for constructing new `FlatMapGroupsInR` nodes. */
+object FlatMapGroupsInR {
+ def apply(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ schema: StructType,
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ inputSchema: StructType,
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ child: LogicalPlan): LogicalPlan = {
+ val mapped = FlatMapGroupsInR(
+ func,
+ packageNames,
+ broadcastVars,
+ inputSchema,
+ schema,
+ UnresolvedDeserializer(keyDeserializer, groupingAttributes),
+ UnresolvedDeserializer(valueDeserializer, dataAttributes),
+ groupingAttributes,
+ dataAttributes,
+ CatalystSerde.generateObjAttr(RowEncoder(schema)),
+ child)
+ CatalystSerde.serialize(mapped)(RowEncoder(schema))
+ }
+}
+
+case class FlatMapGroupsInR(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ inputSchema: StructType,
+ outputSchema: StructType,
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ outputObjAttr: Attribute,
+ child: LogicalPlan) extends UnaryNode with ObjectProducer{
+
+ override lazy val schema = outputSchema
+
+ override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema,
+ keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr,
+ child)
+}
+
/** Factory for constructing new `CoGroup` nodes. */
object CoGroup {
def apply[K : Encoder, L : Encoder, R : Encoder, OUT : Encoder](
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 49b6eab8db..1aa5767038 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -20,14 +20,18 @@ package org.apache.spark.sql
import scala.collection.JavaConverters._
import scala.language.implicitConversions
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.api.r.SQLUtils._
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR, Pivot}
import org.apache.spark.sql.catalyst.util.usePrettyExpression
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.NumericType
+import org.apache.spark.sql.types.StructType
/**
* A set of methods for aggregations on a [[DataFrame]], created by [[Dataset.groupBy]].
@@ -381,6 +385,48 @@ class RelationalGroupedDataset protected[sql](
def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
pivot(pivotColumn, values.asScala)
}
+
+ /**
+ * Applies the given serialized R function `func` to each group of data. For each unique group,
+ * the function will be passed the group key and an iterator that contains all of the elements in
+ * the group. The function can return an iterator containing elements of an arbitrary type which
+ * will be returned as a new [[DataFrame]].
+ *
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an
+ * [[org.apache.spark.sql.expressions#Aggregator Aggregator]].
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+ * constraints of their cluster.
+ *
+ * @since 2.0.0
+ */
+ private[sql] def flatMapGroupsInR(
+ f: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ outputSchema: StructType): DataFrame = {
+ val groupingNamedExpressions = groupingExprs.map(alias)
+ val groupingCols = groupingNamedExpressions.map(Column(_))
+ val groupingDataFrame = df.select(groupingCols : _*)
+ val groupingAttributes = groupingNamedExpressions.map(_.toAttribute)
+ Dataset.ofRows(
+ df.sparkSession,
+ FlatMapGroupsInR(
+ f,
+ packageNames,
+ broadcastVars,
+ outputSchema,
+ groupingDataFrame.exprEnc.deserializer,
+ df.exprEnc.deserializer,
+ df.exprEnc.schema,
+ groupingAttributes,
+ df.logicalPlan.output,
+ df.logicalPlan))
+ }
}
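A side note on the scaladoc added above: when the goal is a plain per-key aggregation, it is usually cheaper to stay with Spark's built-in aggregates (or an Aggregator) than to ship each full group through an R worker, because built-in aggregates support partial aggregation. A minimal sketch, assuming a DataFrame `df` with hypothetical columns "key" and "value":

```scala
import org.apache.spark.sql.functions.avg

// Equivalent of "compute average by a group" without a user function;
// partial aggregation avoids materializing every group as one iterator.
val perKeyAvg = df.groupBy("key").agg(avg("value").as("avg"))
```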
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 486a440b6f..fe426fa3c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -26,7 +26,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.r.SerDe
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
+import org.apache.spark.sql.{DataFrame, RelationalGroupedDataset, Row, SaveMode, SQLContext}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
@@ -146,16 +146,26 @@ private[sql] object SQLUtils {
packageNames: Array[Byte],
broadcastVars: Array[Object],
schema: StructType): DataFrame = {
- val bv = broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])
- val realSchema =
- if (schema == null) {
- SERIALIZED_R_DATA_SCHEMA
- } else {
- schema
- }
+ val bv = broadcastVars.map(_.asInstanceOf[Broadcast[Object]])
+ val realSchema = if (schema == null) SERIALIZED_R_DATA_SCHEMA else schema
df.mapPartitionsInR(func, packageNames, bv, realSchema)
}
+ /**
+ * The helper function for gapply() on R side.
+ */
+ def gapply(
+ gd: RelationalGroupedDataset,
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Object],
+ schema: StructType): DataFrame = {
+ val bv = broadcastVars.map(_.asInstanceOf[Broadcast[Object]])
+ val realSchema = if (schema == null) SERIALIZED_R_DATA_SCHEMA else schema
+ gd.flatMapGroupsInR(func, packageNames, bv, realSchema)
+ }
+
+
def dfToCols(df: DataFrame): Array[Array[Any]] = {
val localDF: Array[Row] = df.collect()
val numCols = df.columns.length
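For context on the schema fallback in the gapply() helper above: the R frontend passes the expected output schema of the user function, and a null schema falls back to SERIALIZED_R_DATA_SCHEMA, in which case each returned row carries an opaque serialized R object. A hedged sketch of the kind of schema a caller might construct; the field names are illustrative, not taken from the patch's tests.

```scala
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

// Expected shape of the rows the R function returns, e.g. for a
// "compute average by a group" style of use mentioned in the commit message.
val gapplyOutputSchema = StructType(Seq(
  StructField("key", StringType),
  StructField("avg", DoubleType)))
```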
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 60466e2830..8e2f2ed4f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -337,6 +337,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) =>
execution.MapPartitionsExec(
execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil
+ case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
+ execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping,
+ data, objAttr, planLater(child)) :: Nil
case logical.MapElements(f, objAttr, child) =>
execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
case logical.AppendColumns(f, in, out, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 5fced940b3..c7e267152b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -20,13 +20,17 @@ package org.apache.spark.sql.execution
import scala.language.existentials
import org.apache.spark.api.java.function.MapFunction
+import org.apache.spark.api.r._
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.api.r.SQLUtils._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.types.{DataType, ObjectType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
/**
@@ -325,6 +329,72 @@ case class MapGroupsExec(
}
/**
+ * Groups the input rows together and calls the R function with each group and an iterator
+ * containing all elements in the group.
+ * The result of this function is flattened before being output.
+ */
+case class FlatMapGroupsInRExec(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ inputSchema: StructType,
+ outputSchema: StructType,
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ outputObjAttr: Attribute,
+ child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {
+
+ override def output: Seq[Attribute] = outputObjAttr :: Nil
+ override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(groupingAttributes) :: Nil
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val isSerializedRData =
+ if (outputSchema == SERIALIZED_R_DATA_SCHEMA) true else false
+ val serializerForR = if (!isSerializedRData) {
+ SerializationFormats.ROW
+ } else {
+ SerializationFormats.BYTE
+ }
+
+ child.execute().mapPartitionsInternal { iter =>
+ val grouped = GroupedIterator(iter, groupingAttributes, child.output)
+ val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+ val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
+ val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+ val runner = new RRunner[Array[Byte]](
+ func, SerializationFormats.ROW, serializerForR, packageNames, broadcastVars,
+ isDataFrame = true, colNames = inputSchema.fieldNames,
+ mode = RRunnerModes.DATAFRAME_GAPPLY)
+
+ val groupedRBytes = grouped.map { case (key, rowIter) =>
+ val deserializedIter = rowIter.map(getValue)
+ val newIter =
+ deserializedIter.asInstanceOf[Iterator[Row]].map { row => rowToRBytes(row) }
+ val newKey = rowToRBytes(getKey(key).asInstanceOf[Row])
+ (newKey, newIter)
+ }
+
+ val outputIter = runner.compute(groupedRBytes, -1)
+ if (!isSerializedRData) {
+ val result = outputIter.map { bytes => bytesToRow(bytes, outputSchema) }
+ result.map(outputObject)
+ } else {
+ val result = outputIter.map { bytes => Row.fromSeq(Seq(bytes)) }
+ result.map(outputObject)
+ }
+ }
+ }
+}
+
+/**
* Co-groups the data from left and right children, and calls the function with each group and 2
* iterators containing all elements in the group from left and right side.
* The result of this function is flattened before being output.
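The FlatMapGroupsInRExec operator added above relies on GroupedIterator to walk the child's key-sorted output once and hand each (key, rows) pair to the R runner. Below is a simplified, self-contained sketch of that grouping idea; it is not Spark's actual GroupedIterator, which operates on InternalRow with grouping expressions and keeps the per-group iterator lazy.

```scala
// Walk a key-sorted iterator once and emit (key, rows-of-that-key) pairs.
// This buffers each group for clarity; the real implementation does not.
def groupSorted[K, V](iter: Iterator[(K, V)]): Iterator[(K, Seq[V])] =
  new Iterator[(K, Seq[V])] {
    private val buffered = iter.buffered
    def hasNext: Boolean = buffered.hasNext
    def next(): (K, Seq[V]) = {
      val key = buffered.head._1
      val group = scala.collection.mutable.ArrayBuffer.empty[V]
      while (buffered.hasNext && buffered.head._1 == key) {
        group += buffered.next()._2
      }
      (key, group.toSeq)
    }
  }

// Example: input already sorted by key, as requiredChildOrdering guarantees.
val rows = Iterator(("a", 1), ("a", 2), ("b", 3))
groupSorted(rows).foreach { case (k, vs) => println(s"$k -> ${vs.mkString(",")}") }
```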
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
index 6c76328c74..70539da348 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.execution.r
-import org.apache.spark.api.r.RRunner
-import org.apache.spark.api.r.SerializationFormats
+import org.apache.spark.api.r._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.api.r.SQLUtils._
import org.apache.spark.sql.Row
@@ -55,7 +54,7 @@ private[sql] case class MapPartitionsRWrapper(
val runner = new RRunner[Array[Byte]](
func, deserializer, serializer, packageNames, broadcastVars,
- isDataFrame = true, colNames = colNames)
+ isDataFrame = true, colNames = colNames, mode = RRunnerModes.DATAFRAME_DAPPLY)
// Partition index is ignored. Dataset has no support for mapPartitionsWithIndex.
val outputIter = runner.compute(newIter, -1)