diff options
author | Wenchen Fan <wenchen@databricks.com> | 2015-11-09 15:16:47 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-11-09 15:16:47 -0800 |
commit | fcb57e9c7323e24b8563800deb035f94f616474e (patch) | |
tree | 5c368aa1f2feb5bf5a573c4397637d588046f1b9 /sql/core/src/main | |
parent | 8a2336893a7ff610a6c4629dd567b85078730616 (diff) | |
download | spark-fcb57e9c7323e24b8563800deb035f94f616474e.tar.gz spark-fcb57e9c7323e24b8563800deb035f94f616474e.tar.bz2 spark-fcb57e9c7323e24b8563800deb035f94f616474e.zip |
[SPARK-11564][SQL][FOLLOW-UP] improve java api for GroupedDataset
created `MapGroupFunction`, `FlatMapGroupFunction`, `CoGroupFunction`
Author: Wenchen Fan <wenchen@databricks.com>
Closes #9564 from cloud-fan/map.
Diffstat (limited to 'sql/core/src/main')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala | 12 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 2 |
2 files changed, 5 insertions, 9 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 5c3f626545..850315e281 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -108,9 +108,7 @@ class GroupedDataset[K, T] private[sql]( MapGroups(f, groupingAttributes, logicalPlan)) } - def flatMap[U]( - f: JFunction2[K, JIterator[T], JIterator[U]], - encoder: Encoder[U]): Dataset[U] = { + def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder) } @@ -131,9 +129,7 @@ class GroupedDataset[K, T] private[sql]( MapGroups(func, groupingAttributes, logicalPlan)) } - def map[U]( - f: JFunction2[K, JIterator[T], U], - encoder: Encoder[U]): Dataset[U] = { + def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { map((key, data) => f.call(key, data.asJava))(encoder) } @@ -218,7 +214,7 @@ class GroupedDataset[K, T] private[sql]( */ def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( - f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = { + f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { implicit def uEnc: Encoder[U] = other.tEncoder new Dataset[R]( sqlContext, @@ -232,7 +228,7 @@ class GroupedDataset[K, T] private[sql]( def cogroup[U, R]( other: GroupedDataset[K, U], - f: JFunction3[K, JIterator[T], JIterator[U], JIterator[R]], + f: CoGroupFunction[K, T, U, R], encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 2593b16b1c..145de0db9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -391,7 +391,7 @@ case class MapGroups[K, T, U]( * The result of this function is encoded and flattened before being output. */ case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], kEncoder: ExpressionEncoder[K], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], |