| author | Narine Kokhlikyan <narine.kokhlikyan@gmail.com> | 2016-06-15 21:42:05 -0700 |
| --- | --- | --- |
| committer | Shivaram Venkataraman <shivaram@cs.berkeley.edu> | 2016-06-15 21:42:05 -0700 |
| commit | 7c6c6926376c93acc42dd56a399d816f4838f28c (patch) | |
| tree | bbf8f9dc1d7a044b890b6c95fdd3a17aa76fea89 /core/src | |
| parent | b75f454f946714b93fe561055cd53b0686187d2e (diff) | |
| download | spark-7c6c6926376c93acc42dd56a399d816f4838f28c.tar.gz, spark-7c6c6926376c93acc42dd56a399d816f4838f28c.tar.bz2, spark-7c6c6926376c93acc42dd56a399d816f4838f28c.zip | |
[SPARK-12922][SPARKR][WIP] Implement gapply() on DataFrame in SparkR
## What changes were proposed in this pull request?
gapply() applies an R function on groups grouped by one or more columns of a DataFrame, and returns a DataFrame. It is like GroupedDataSet.flatMapGroups() in the Dataset API.
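For comparison, here is a minimal, self-contained Scala sketch of the flatMapGroups pattern the description refers to. It is not part of this patch: the `Sale` case class, the column names, and the local SparkSession setup are made up purely for illustration, and gapply() itself is invoked from R rather than Scala.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative only: Sale, the column names, and the local SparkSession are
// invented here to show the flatMapGroups shape that gapply() exposes to R users.
case class Sale(region: String, amount: Double)

object FlatMapGroupsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("flatMapGroups-sketch")
      .getOrCreate()
    import spark.implicits._

    val sales = Seq(Sale("east", 10.0), Sale("east", 20.0), Sale("west", 5.0)).toDS()

    // Group by a column, apply a function to each group, and flatten the per-group results.
    val avgPerRegion = sales
      .groupByKey(_.region)
      .flatMapGroups { (region, rows) =>
        val amounts = rows.map(_.amount).toSeq
        Iterator((region, amounts.sum / amounts.size))
      }
      .toDF("region", "avg_amount")

    avgPerRegion.show()
    spark.stop()
  }
}
```

The shape of the computation is the same as gapply(): rows are grouped by one or more columns, each group is handed to a user function, and the returned rows are collected into a new DataFrame.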
Please let me know what you think and whether you have any ideas to improve it.
Thank you!
## How was this patch tested?
Unit tests.
1. Primitive test with different column types
2. Add a boolean column
3. Compute average by a group
Author: Narine Kokhlikyan <narine.kokhlikyan@gmail.com>
Author: NarineK <narine.kokhlikyan@us.ibm.com>
Closes #12836 from NarineK/gapply2.
Diffstat (limited to 'core/src')
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/r/RRunner.scala | 20 |
1 file changed, 17 insertions, 3 deletions
```diff
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 24ad689f83..496fdf851f 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -40,7 +40,8 @@ private[spark] class RRunner[U](
     broadcastVars: Array[Broadcast[Object]],
     numPartitions: Int = -1,
     isDataFrame: Boolean = false,
-    colNames: Array[String] = null)
+    colNames: Array[String] = null,
+    mode: Int = RRunnerModes.RDD)
   extends Logging {
   private var bootTime: Double = _
   private var dataStream: DataInputStream = _
@@ -148,8 +149,7 @@ private[spark] class RRunner[U](
           }
 
           dataOut.writeInt(numPartitions)
-
-          dataOut.writeInt(if (isDataFrame) 1 else 0)
+          dataOut.writeInt(mode)
 
           if (isDataFrame) {
             SerDe.writeObject(dataOut, colNames)
@@ -180,6 +180,13 @@ private[spark] class RRunner[U](
 
           for (elem <- iter) {
             elem match {
+              case (key, innerIter: Iterator[_]) =>
+                for (innerElem <- innerIter) {
+                  writeElem(innerElem)
+                }
+                // Writes key which can be used as a boundary in group-aggregate
+                dataOut.writeByte('r')
+                writeElem(key)
               case (key, value) =>
                 writeElem(key)
                 writeElem(value)
@@ -187,6 +194,7 @@ private[spark] class RRunner[U](
                 writeElem(elem)
             }
           }
+          stream.flush()
 
         } catch {
           // TODO: We should propagate this error to the task thread
@@ -268,6 +276,12 @@ private object SpecialLengths {
   val TIMING_DATA = -1
 }
 
+private[spark] object RRunnerModes {
+  val RDD = 0
+  val DATAFRAME_DAPPLY = 1
+  val DATAFRAME_GAPPLY = 2
+}
+
 private[r] class BufferedStreamThread(
     in: InputStream,
     name: String,
```
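To make the wire framing added above concrete, here is a stand-alone sketch of the write side. It is not the actual RRunner/SerDe code: the per-element encoding (writeUTF), the GroupFramingSketch object, and the sample data are invented for illustration; only the ordering (mode int up front, then a group's rows, then the 'r' boundary byte, then the key) mirrors the patch.

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

// Stand-alone illustration of the group-boundary framing this patch adds:
// all rows of a group are written first, then the marker byte 'r', then the key.
// The element encoding below (writeUTF) is invented for the sketch; the real
// RRunner serializes elements through SerDe/writeElem.
object GroupFramingSketch {
  // Mirrors the constants introduced in RRunnerModes.
  val RDD = 0
  val DATAFRAME_DAPPLY = 1
  val DATAFRAME_GAPPLY = 2

  def main(args: Array[String]): Unit = {
    val groups: Seq[(String, Seq[String])] = Seq(
      "east" -> Seq("row1", "row2"),
      "west" -> Seq("row3"))

    val buffer = new ByteArrayOutputStream()
    val dataOut = new DataOutputStream(buffer)

    dataOut.writeInt(DATAFRAME_GAPPLY)        // mode, replacing the old 0/1 isDataFrame flag

    for ((key, rows) <- groups) {
      rows.foreach(r => dataOut.writeUTF(r))  // the group's rows come first
      dataOut.writeByte('r')                  // boundary marker, as in the patch
      dataOut.writeUTF(key)                   // the grouping key follows the marker
    }
    dataOut.flush()                           // flush once at the end, as the patch now does

    println(s"wrote ${buffer.size()} bytes for ${groups.size} groups")
  }
}
```

The boundary byte is what lets the consumer (the R worker, in RRunner's case) tell where one group's rows end and its key begins without knowing group sizes up front.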