diff options
author | Wenchen Fan <wenchen@databricks.com> | 2015-11-09 15:16:47 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-11-09 15:16:47 -0800 |
commit | fcb57e9c7323e24b8563800deb035f94f616474e (patch) | |
tree | 5c368aa1f2feb5bf5a573c4397637d588046f1b9 /sql/core/src/test | |
parent | 8a2336893a7ff610a6c4629dd567b85078730616 (diff) | |
download | spark-fcb57e9c7323e24b8563800deb035f94f616474e.tar.gz spark-fcb57e9c7323e24b8563800deb035f94f616474e.tar.bz2 spark-fcb57e9c7323e24b8563800deb035f94f616474e.zip |
[SPARK-11564][SQL][FOLLOW-UP] improve java api for GroupedDataset
created `MapGroupFunction`, `FlatMapGroupFunction`, `CoGroupFunction`
Author: Wenchen Fan <wenchen@databricks.com>
Closes #9564 from cloud-fan/map.
Diffstat (limited to 'sql/core/src/test')
-rw-r--r-- | sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 36 |
1 files changed, 24 insertions, 12 deletions
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 0f90de774d..312cf33e4c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -29,7 +29,6 @@ import org.junit.*; import org.apache.spark.Accumulator; import org.apache.spark.SparkContext; import org.apache.spark.api.java.function.*; -import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.catalyst.encoders.Encoder; import org.apache.spark.sql.catalyst.encoders.Encoder$; @@ -170,20 +169,33 @@ public class JavaDatasetSuite implements Serializable { } }, e.INT()); - Dataset<String> mapped = grouped.map( - new Function2<Integer, Iterator<String>, String>() { + Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() { + @Override + public String call(Integer key, Iterator<String> values) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return sb.toString(); + } + }, e.STRING()); + + Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + + Dataset<String> flatMapped = grouped.flatMap( + new FlatMapGroupFunction<Integer, String, String>() { @Override - public String call(Integer key, Iterator<String> data) throws Exception { + public Iterable<String> call(Integer key, Iterator<String> values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); - while (data.hasNext()) { - sb.append(data.next()); + while (values.hasNext()) { + sb.append(values.next()); } - return sb.toString(); + return Collections.singletonList(sb.toString()); } }, e.STRING()); - Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList()); List<Integer> data2 = Arrays.asList(2, 6, 10); Dataset<Integer> ds2 = context.createDataset(data2, e.INT()); @@ -196,9 +208,9 @@ public class JavaDatasetSuite implements Serializable { Dataset<String> cogrouped = grouped.cogroup( grouped2, - new Function3<Integer, Iterator<String>, Iterator<Integer>, Iterator<String>>() { + new CoGroupFunction<Integer, String, Integer, String>() { @Override - public Iterator<String> call( + public Iterable<String> call( Integer key, Iterator<String> left, Iterator<Integer> right) throws Exception { @@ -210,7 +222,7 @@ public class JavaDatasetSuite implements Serializable { while (right.hasNext()) { sb.append(right.next()); } - return Collections.singletonList(sb.toString()).iterator(); + return Collections.singletonList(sb.toString()); } }, e.STRING()); @@ -225,7 +237,7 @@ public class JavaDatasetSuite implements Serializable { GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT()); Dataset<String> mapped = grouped.map( - new Function2<Integer, Iterator<String>, String>() { + new MapGroupFunction<Integer, String, String>() { @Override public String call(Integer key, Iterator<String> data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); |