author     Narine Kokhlikyan <narine.kokhlikyan@gmail.com>   2016-06-15 21:42:05 -0700
committer  Shivaram Venkataraman <shivaram@cs.berkeley.edu>  2016-06-15 21:42:05 -0700
commit     7c6c6926376c93acc42dd56a399d816f4838f28c (patch)
tree       bbf8f9dc1d7a044b890b6c95fdd3a17aa76fea89 /core/src
parent     b75f454f946714b93fe561055cd53b0686187d2e (diff)
[SPARK-12922][SPARKR][WIP] Implement gapply() on DataFrame in SparkR
## What changes were proposed in this pull request?

gapply() applies an R function to groups of a DataFrame, grouped by one or more of its columns, and returns a DataFrame. It is analogous to GroupedDataset.flatMapGroups() in the Dataset API. Please let me know what you think and whether you have any ideas to improve it. Thank you!

## How was this patch tested?

Unit tests:
1. A primitive test with different column types
2. Adding a boolean column
3. Computing an average by group

Author: Narine Kokhlikyan <narine.kokhlikyan@gmail.com>
Author: NarineK <narine.kokhlikyan@us.ibm.com>

Closes #12836 from NarineK/gapply2.
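The flatMapGroups() analogy is the quickest way to picture gapply()'s semantics, so here is a minimal Scala sketch of that Dataset-API counterpart, mirroring the "compute average by a group" test above. The SparkSession setup, the Sale case class, and the column values are illustrative assumptions, not part of this patch.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative record type (an assumption for this sketch, not from the patch).
case class Sale(region: String, amount: Double)

object FlatMapGroupsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("gapply-analogy").master("local[2]").getOrCreate()
    import spark.implicits._

    val sales = Seq(Sale("east", 10.0), Sale("east", 20.0), Sale("west", 5.0)).toDS()

    // Apply a function to each group and flatten the results, as gapply()
    // does for an R function: here, the average amount per region.
    val avgByRegion = sales
      .groupByKey(_.region)
      .flatMapGroups { (region, rows) =>
        val amounts = rows.map(_.amount).toSeq
        Iterator((region, amounts.sum / amounts.size))
      }

    avgByRegion.show()
    spark.stop()
  }
}
```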
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RRunner.scala | 20
1 file changed, 17 insertions(+), 3 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 24ad689f83..496fdf851f 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -40,7 +40,8 @@ private[spark] class RRunner[U](
broadcastVars: Array[Broadcast[Object]],
numPartitions: Int = -1,
isDataFrame: Boolean = false,
- colNames: Array[String] = null)
+ colNames: Array[String] = null,
+ mode: Int = RRunnerModes.RDD)
extends Logging {
private var bootTime: Double = _
private var dataStream: DataInputStream = _
@@ -148,8 +149,7 @@ private[spark] class RRunner[U](
}
dataOut.writeInt(numPartitions)
-
- dataOut.writeInt(if (isDataFrame) 1 else 0)
+ dataOut.writeInt(mode)
if (isDataFrame) {
SerDe.writeObject(dataOut, colNames)
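Replacing the old boolean with an integer lets the R worker distinguish three execution paths instead of two. The sketch below is an assumption about how a consumer of this stream could dispatch on that integer; it is not the actual R worker code, and the constants simply mirror the RRunnerModes object added later in this patch.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object ModeDispatchSketch {
  // Mirrors the RRunnerModes constants introduced by this patch.
  val RDD = 0
  val DATAFRAME_DAPPLY = 1
  val DATAFRAME_GAPPLY = 2

  def main(args: Array[String]): Unit = {
    // Simulate the producer side: RRunner now writes a mode int, not a boolean.
    val buf = new ByteArrayOutputStream()
    val out = new DataOutputStream(buf)
    out.writeInt(DATAFRAME_GAPPLY)
    out.flush()

    // Simulate the consumer side: dispatch on the mode.
    val in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray))
    in.readInt() match {
      case RDD              => println("plain RDD mode")
      case DATAFRAME_DAPPLY => println("dapply: apply an R function per partition")
      case DATAFRAME_GAPPLY => println("gapply: apply an R function per group")
      case other            => sys.error(s"unknown mode: $other")
    }
  }
}
```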
@@ -180,6 +180,13 @@ private[spark] class RRunner[U](
for (elem <- iter) {
elem match {
+ case (key, innerIter: Iterator[_]) =>
+ for (innerElem <- innerIter) {
+ writeElem(innerElem)
+ }
+ // Writes the key, which can be used as a group boundary in group-aggregate
+ dataOut.writeByte('r')
+ writeElem(key)
case (key, value) =>
writeElem(key)
writeElem(value)
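The new case frames each group on the wire as its values, then a single 'r' marker byte, then the group key, so the reading side can detect where one group ends. Below is a minimal self-contained sketch of that framing; writeElemBytes is a hypothetical stand-in for RRunner's private writeElem, and the length-prefixed layout is an assumption for illustration.

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

object GroupFramingSketch {
  // Hypothetical stand-in for RRunner's private writeElem: length-prefixed bytes.
  def writeElemBytes(out: DataOutputStream, bytes: Array[Byte]): Unit = {
    out.writeInt(bytes.length)
    out.write(bytes)
  }

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    val out = new DataOutputStream(buf)

    val key = "east".getBytes("UTF-8")
    val values = Iterator("10.0".getBytes("UTF-8"), "20.0".getBytes("UTF-8"))

    for (v <- values) writeElemBytes(out, v) // group values first
    out.writeByte('r')                       // single-byte group boundary marker
    writeElemBytes(out, key)                 // then the key that closes the group
    out.flush()

    println(s"framed group in ${buf.size()} bytes")
  }
}
```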
@@ -187,6 +194,7 @@ private[spark] class RRunner[U](
writeElem(elem)
}
}
+
stream.flush()
} catch {
// TODO: We should propagate this error to the task thread
@@ -268,6 +276,12 @@ private object SpecialLengths {
val TIMING_DATA = -1
}
+private[spark] object RRunnerModes {
+ val RDD = 0
+ val DATAFRAME_DAPPLY = 1
+ val DATAFRAME_GAPPLY = 2
+}
+
private[r] class BufferedStreamThread(
in: InputStream,
name: String,
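For context on where the (key, Iterator) pairs matched above come from: a grouped pair RDD naturally produces that shape. The wiring below is an assumption for illustration, not code from this patch; it only shows that grouping keyed rows yields pairs matching the new case (key, innerIter: Iterator[_]).

```scala
import org.apache.spark.{SparkConf, SparkContext}

object GroupedPairsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("grouped-pairs").setMaster("local[2]"))

    // Hypothetical serialized rows, keyed by their grouping column's value.
    val rows = sc.parallelize(Seq(
      ("east", Array[Byte](1)), ("east", Array[Byte](2)), ("west", Array[Byte](3))))

    // groupByKey yields (key, Iterable[V]); exposing the values as an Iterator
    // produces exactly the (key, innerIter) shape the new match case handles.
    val grouped = rows.groupByKey().map { case (k, vs) => (k, vs.iterator) }

    grouped.collect().foreach { case (k, it) => println(s"$k -> ${it.size} values") }
    sc.stop()
  }
}
```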