about summary refs log tree commit diff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2015-11-09 15:16:47 -0800
committerMichael Armbrust <michael@databricks.com>2015-11-09 15:16:47 -0800
commitfcb57e9c7323e24b8563800deb035f94f616474e (patch)
tree5c368aa1f2feb5bf5a573c4397637d588046f1b9 /sql
parent8a2336893a7ff610a6c4629dd567b85078730616 (diff)
downloadspark-fcb57e9c7323e24b8563800deb035f94f616474e.tar.gz
spark-fcb57e9c7323e24b8563800deb035f94f616474e.tar.bz2
spark-fcb57e9c7323e24b8563800deb035f94f616474e.zip
[SPARK-11564][SQL][FOLLOW-UP] improve java api for GroupedDataset
Created `MapGroupFunction`, `FlatMapGroupFunction`, and `CoGroupFunction`. Author: Wenchen Fan <wenchen@databricks.com> Closes #9564 from cloud-fan/map.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala12
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala2
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java36
4 files changed, 31 insertions, 23 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index e151ac04ed..d771088d69 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -527,7 +527,7 @@ case class MapGroups[K, T, U](
/** Factory for constructing new `CoGroup` nodes. */
object CoGroup {
def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder](
- func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
+ func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
left: LogicalPlan,
@@ -551,7 +551,7 @@ object CoGroup {
* right children.
*/
case class CoGroup[K, Left, Right, R](
- func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
+ func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
kEncoder: ExpressionEncoder[K],
leftEnc: ExpressionEncoder[Left],
rightEnc: ExpressionEncoder[Right],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 5c3f626545..850315e281 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -108,9 +108,7 @@ class GroupedDataset[K, T] private[sql](
MapGroups(f, groupingAttributes, logicalPlan))
}
- def flatMap[U](
- f: JFunction2[K, JIterator[T], JIterator[U]],
- encoder: Encoder[U]): Dataset[U] = {
+ def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder)
}
@@ -131,9 +129,7 @@ class GroupedDataset[K, T] private[sql](
MapGroups(func, groupingAttributes, logicalPlan))
}
- def map[U](
- f: JFunction2[K, JIterator[T], U],
- encoder: Encoder[U]): Dataset[U] = {
+ def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
map((key, data) => f.call(key, data.asJava))(encoder)
}
@@ -218,7 +214,7 @@ class GroupedDataset[K, T] private[sql](
*/
def cogroup[U, R : Encoder](
other: GroupedDataset[K, U])(
- f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = {
+ f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
implicit def uEnc: Encoder[U] = other.tEncoder
new Dataset[R](
sqlContext,
@@ -232,7 +228,7 @@ class GroupedDataset[K, T] private[sql](
def cogroup[U, R](
other: GroupedDataset[K, U],
- f: JFunction3[K, JIterator[T], JIterator[U], JIterator[R]],
+ f: CoGroupFunction[K, T, U, R],
encoder: Encoder[R]): Dataset[R] = {
cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 2593b16b1c..145de0db9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -391,7 +391,7 @@ case class MapGroups[K, T, U](
* The result of this function is encoded and flattened before being output.
*/
case class CoGroup[K, Left, Right, R](
- func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
+ func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
kEncoder: ExpressionEncoder[K],
leftEnc: ExpressionEncoder[Left],
rightEnc: ExpressionEncoder[Right],
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0f90de774d..312cf33e4c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -29,7 +29,6 @@ import org.junit.*;
import org.apache.spark.Accumulator;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.function.*;
-import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.catalyst.encoders.Encoder;
import org.apache.spark.sql.catalyst.encoders.Encoder$;
@@ -170,20 +169,33 @@ public class JavaDatasetSuite implements Serializable {
}
}, e.INT());
- Dataset<String> mapped = grouped.map(
- new Function2<Integer, Iterator<String>, String>() {
+ Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() {
+ @Override
+ public String call(Integer key, Iterator<String> values) throws Exception {
+ StringBuilder sb = new StringBuilder(key.toString());
+ while (values.hasNext()) {
+ sb.append(values.next());
+ }
+ return sb.toString();
+ }
+ }, e.STRING());
+
+ Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
+
+ Dataset<String> flatMapped = grouped.flatMap(
+ new FlatMapGroupFunction<Integer, String, String>() {
@Override
- public String call(Integer key, Iterator<String> data) throws Exception {
+ public Iterable<String> call(Integer key, Iterator<String> values) throws Exception {
StringBuilder sb = new StringBuilder(key.toString());
- while (data.hasNext()) {
- sb.append(data.next());
+ while (values.hasNext()) {
+ sb.append(values.next());
}
- return sb.toString();
+ return Collections.singletonList(sb.toString());
}
},
e.STRING());
- Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
+ Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());
List<Integer> data2 = Arrays.asList(2, 6, 10);
Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
@@ -196,9 +208,9 @@ public class JavaDatasetSuite implements Serializable {
Dataset<String> cogrouped = grouped.cogroup(
grouped2,
- new Function3<Integer, Iterator<String>, Iterator<Integer>, Iterator<String>>() {
+ new CoGroupFunction<Integer, String, Integer, String>() {
@Override
- public Iterator<String> call(
+ public Iterable<String> call(
Integer key,
Iterator<String> left,
Iterator<Integer> right) throws Exception {
@@ -210,7 +222,7 @@ public class JavaDatasetSuite implements Serializable {
while (right.hasNext()) {
sb.append(right.next());
}
- return Collections.singletonList(sb.toString()).iterator();
+ return Collections.singletonList(sb.toString());
}
},
e.STRING());
@@ -225,7 +237,7 @@ public class JavaDatasetSuite implements Serializable {
GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
Dataset<String> mapped = grouped.map(
- new Function2<Integer, Iterator<String>, String>() {
+ new MapGroupFunction<Integer, String, String>() {
@Override
public String call(Integer key, Iterator<String> data) throws Exception {
StringBuilder sb = new StringBuilder(key.toString());