diff options
author | Reynold Xin <rxin@databricks.com> | 2015-11-08 20:57:09 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-11-08 20:57:09 -0800 |
commit | 97b7080cf2d2846c7257f8926f775f27d457fe7d (patch) | |
tree | 28efd3ca15c2e96c0d4f0b5d08cabb9e602ef12e /sql/core/src/test/java | |
parent | b2d195e137fad88d567974659fa7023ff4da96cd (diff) | |
download | spark-97b7080cf2d2846c7257f8926f775f27d457fe7d.tar.gz spark-97b7080cf2d2846c7257f8926f775f27d457fe7d.tar.bz2 spark-97b7080cf2d2846c7257f8926f775f27d457fe7d.zip |
[SPARK-11564][SQL] Dataset Java API audit
A few changes:
1. Removed fold, since it can be confusing for distributed collections.
2. Created specific interfaces for each Dataset function (e.g. MapFunction, ReduceFunction, MapPartitionsFunction).
3. Added more documentation and test cases.
The other thing I'm considering is adding a "collector" interface for FlatMapFunction and MapPartitionsFunction, similar to MapReduce's map function.
Author: Reynold Xin <rxin@databricks.com>
Closes #9531 from rxin/SPARK-11564.
Diffstat (limited to 'sql/core/src/test/java')
-rw-r--r-- | sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 7 | ||||
-rw-r--r-- | sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 36 |
2 files changed, 25 insertions, 18 deletions
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 40bff57a17..d191b50fa8 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -65,6 +65,13 @@ public class JavaDataFrameSuite { Assert.assertEquals(1, df.select("key").collect()[0].get(0)); } + @Test + public void testCollectAndTake() { + DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3"); + Assert.assertEquals(3, df.select("key").collectAsList().size()); + Assert.assertEquals(2, df.select("key").takeAsList(2).size()); + } + /** * See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java. */ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 0d3b1a5af5..0f90de774d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -68,8 +68,16 @@ public class JavaDatasetSuite implements Serializable { public void testCollect() { List<String> data = Arrays.asList("hello", "world"); Dataset<String> ds = context.createDataset(data, e.STRING()); - String[] collected = (String[]) ds.collect(); - Assert.assertEquals(Arrays.asList("hello", "world"), Arrays.asList(collected)); + List<String> collected = ds.collectAsList(); + Assert.assertEquals(Arrays.asList("hello", "world"), collected); + } + + @Test + public void testTake() { + List<String> data = Arrays.asList("hello", "world"); + Dataset<String> ds = context.createDataset(data, e.STRING()); + List<String> collected = ds.takeAsList(1); + Assert.assertEquals(Arrays.asList("hello"), collected); } @Test @@ -78,16 +86,16 @@ public class JavaDatasetSuite implements Serializable { Dataset<String> 
ds = context.createDataset(data, e.STRING()); Assert.assertEquals("hello", ds.first()); - Dataset<String> filtered = ds.filter(new Function<String, Boolean>() { + Dataset<String> filtered = ds.filter(new FilterFunction<String>() { @Override - public Boolean call(String v) throws Exception { + public boolean call(String v) throws Exception { return v.startsWith("h"); } }); Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList()); - Dataset<Integer> mapped = ds.map(new Function<String, Integer>() { + Dataset<Integer> mapped = ds.map(new MapFunction<String, Integer>() { @Override public Integer call(String v) throws Exception { return v.length(); @@ -95,7 +103,7 @@ public class JavaDatasetSuite implements Serializable { }, e.INT()); Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList()); - Dataset<String> parMapped = ds.mapPartitions(new FlatMapFunction<Iterator<String>, String>() { + Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() { @Override public Iterable<String> call(Iterator<String> it) throws Exception { List<String> ls = new LinkedList<String>(); @@ -128,7 +136,7 @@ public class JavaDatasetSuite implements Serializable { List<String> data = Arrays.asList("a", "b", "c"); Dataset<String> ds = context.createDataset(data, e.STRING()); - ds.foreach(new VoidFunction<String>() { + ds.foreach(new ForeachFunction<String>() { @Override public void call(String s) throws Exception { accum.add(1); @@ -142,28 +150,20 @@ public class JavaDatasetSuite implements Serializable { List<Integer> data = Arrays.asList(1, 2, 3); Dataset<Integer> ds = context.createDataset(data, e.INT()); - int reduced = ds.reduce(new Function2<Integer, Integer, Integer>() { + int reduced = ds.reduce(new ReduceFunction<Integer>() { @Override public Integer call(Integer v1, Integer v2) throws Exception { return v1 + v2; } }); Assert.assertEquals(6, reduced); - - int folded = ds.fold(1, new Function2<Integer, Integer, Integer>() { - 
@Override - public Integer call(Integer v1, Integer v2) throws Exception { - return v1 * v2; - } - }); - Assert.assertEquals(6, folded); } @Test public void testGroupBy() { List<String> data = Arrays.asList("a", "foo", "bar"); Dataset<String> ds = context.createDataset(data, e.STRING()); - GroupedDataset<Integer, String> grouped = ds.groupBy(new Function<String, Integer>() { + GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() { @Override public Integer call(String v) throws Exception { return v.length(); @@ -187,7 +187,7 @@ public class JavaDatasetSuite implements Serializable { List<Integer> data2 = Arrays.asList(2, 6, 10); Dataset<Integer> ds2 = context.createDataset(data2, e.INT()); - GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new Function<Integer, Integer>() { + GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() { @Override public Integer call(Integer v) throws Exception { return v / 2; |