path: root/sql/core/src/test/java
author     Reynold Xin <rxin@databricks.com>    2015-11-08 20:57:09 -0800
committer  Reynold Xin <rxin@databricks.com>    2015-11-08 20:57:09 -0800
commit  97b7080cf2d2846c7257f8926f775f27d457fe7d (patch)
tree    28efd3ca15c2e96c0d4f0b5d08cabb9e602ef12e /sql/core/src/test/java
parent  b2d195e137fad88d567974659fa7023ff4da96cd (diff)
[SPARK-11564][SQL] Dataset Java API audit
A few changes:

1. Removed fold, since it can be confusing for distributed collections.
2. Created specific interfaces for each Dataset function (e.g. MapFunction, ReduceFunction, MapPartitionsFunction).
3. Added more documentation and test cases.

The other thing I'm considering doing is to have a "collector" interface for FlatMapFunction and MapPartitionsFunction, similar to MapReduce's map function.

Author: Reynold Xin <rxin@databricks.com>

Closes #9531 from rxin/SPARK-11564.
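The shape of the per-operation interfaces (change 2) can be read off the call signatures in the hunks below. A minimal sketch — the real definitions live in Spark's org.apache.spark.api.java.function package and may differ in detail:

  // One single-abstract-method interface per Dataset operation, so the
  // Java compiler can check the exact signature each operation expects.
  public interface MapFunction<T, U> extends java.io.Serializable {
    U call(T value) throws Exception;
  }

  public interface FilterFunction<T> extends java.io.Serializable {
    // Primitive boolean return, unlike the old Function<T, Boolean>,
    // rules out returning null.
    boolean call(T value) throws Exception;
  }

  public interface ReduceFunction<T> extends java.io.Serializable {
    T call(T v1, T v2) throws Exception;
  }

  public interface ForeachFunction<T> extends java.io.Serializable {
    void call(T t) throws Exception;
  }

  public interface MapPartitionsFunction<T, U> extends java.io.Serializable {
    // Consumes a whole partition and returns its transformed contents.
    Iterable<U> call(java.util.Iterator<T> input) throws Exception;
  }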
Diffstat (limited to 'sql/core/src/test/java')
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java |  7
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java   | 36
2 files changed, 25 insertions, 18 deletions
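On the "collector" interface floated in the commit message: nothing in this diff implements it, but a hypothetical shape, modeled on Hadoop MapReduce's context.write (both names below are invented here for illustration), could be:

  // Hypothetical: results are pushed into a collector instead of being
  // returned as an Iterable, so a partition's output need not be buffered.
  public interface Collector<U> {
    void collect(U value);
  }

  public interface FlatMapFunctionWithCollector<T, U> extends java.io.Serializable {
    void call(T input, Collector<U> out) throws Exception;
  }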
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 40bff57a17..d191b50fa8 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -65,6 +65,13 @@ public class JavaDataFrameSuite {
Assert.assertEquals(1, df.select("key").collect()[0].get(0));
}
+ @Test
+ public void testCollectAndTake() {
+ DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
+ Assert.assertEquals(3, df.select("key").collectAsList().size());
+ Assert.assertEquals(2, df.select("key").takeAsList(2).size());
+ }
+
/**
* See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java.
*/
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0d3b1a5af5..0f90de774d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -68,8 +68,16 @@ public class JavaDatasetSuite implements Serializable {
public void testCollect() {
List<String> data = Arrays.asList("hello", "world");
Dataset<String> ds = context.createDataset(data, e.STRING());
- String[] collected = (String[]) ds.collect();
- Assert.assertEquals(Arrays.asList("hello", "world"), Arrays.asList(collected));
+ List<String> collected = ds.collectAsList();
+ Assert.assertEquals(Arrays.asList("hello", "world"), collected);
+ }
+
+ @Test
+ public void testTake() {
+ List<String> data = Arrays.asList("hello", "world");
+ Dataset<String> ds = context.createDataset(data, e.STRING());
+ List<String> collected = ds.takeAsList(1);
+ Assert.assertEquals(Arrays.asList("hello"), collected);
}
@Test
@@ -78,16 +86,16 @@ public class JavaDatasetSuite implements Serializable {
Dataset<String> ds = context.createDataset(data, e.STRING());
Assert.assertEquals("hello", ds.first());
- Dataset<String> filtered = ds.filter(new Function<String, Boolean>() {
+ Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
@Override
- public Boolean call(String v) throws Exception {
+ public boolean call(String v) throws Exception {
return v.startsWith("h");
}
});
Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList());
- Dataset<Integer> mapped = ds.map(new Function<String, Integer>() {
+ Dataset<Integer> mapped = ds.map(new MapFunction<String, Integer>() {
@Override
public Integer call(String v) throws Exception {
return v.length();
@@ -95,7 +103,7 @@ public class JavaDatasetSuite implements Serializable {
}, e.INT());
Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
- Dataset<String> parMapped = ds.mapPartitions(new FlatMapFunction<Iterator<String>, String>() {
+ Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
@Override
public Iterable<String> call(Iterator<String> it) throws Exception {
List<String> ls = new LinkedList<String>();
@@ -128,7 +136,7 @@ public class JavaDatasetSuite implements Serializable {
List<String> data = Arrays.asList("a", "b", "c");
Dataset<String> ds = context.createDataset(data, e.STRING());
- ds.foreach(new VoidFunction<String>() {
+ ds.foreach(new ForeachFunction<String>() {
@Override
public void call(String s) throws Exception {
accum.add(1);
@@ -142,28 +150,20 @@ public class JavaDatasetSuite implements Serializable {
List<Integer> data = Arrays.asList(1, 2, 3);
Dataset<Integer> ds = context.createDataset(data, e.INT());
- int reduced = ds.reduce(new Function2<Integer, Integer, Integer>() {
+ int reduced = ds.reduce(new ReduceFunction<Integer>() {
@Override
public Integer call(Integer v1, Integer v2) throws Exception {
return v1 + v2;
}
});
Assert.assertEquals(6, reduced);
-
- int folded = ds.fold(1, new Function2<Integer, Integer, Integer>() {
- @Override
- public Integer call(Integer v1, Integer v2) throws Exception {
- return v1 * v2;
- }
- });
- Assert.assertEquals(6, folded);
}
@Test
public void testGroupBy() {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, e.STRING());
- GroupedDataset<Integer, String> grouped = ds.groupBy(new Function<String, Integer>() {
+ GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
@Override
public Integer call(String v) throws Exception {
return v.length();
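The hunk above also removes the fold test (change 1 in the commit message). The confusion it alludes to, assuming Dataset.fold followed the same contract as RDD.fold: the zero value is folded into every partition and once more when merging partition results, so any zero that is not an identity for the operation makes the answer depend on partitioning. A sketch, with sc a JavaSparkContext:

  // RDD.fold applies the zero element once per partition, then again on the driver:
  sc.parallelize(Arrays.asList(1, 2, 3), 1).fold(1, (a, b) -> a + b);  // 1+(1+1+2+3) = 8
  sc.parallelize(Arrays.asList(1, 2, 3), 3).fold(1, (a, b) -> a + b);  // 1+(1+1)+(1+2)+(1+3) = 10

The removed test used a zero of 1 with multiplication — an identity for the operation, which is exactly the case where the problem stays hidden.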
@@ -187,7 +187,7 @@ public class JavaDatasetSuite implements Serializable {
List<Integer> data2 = Arrays.asList(2, 6, 10);
Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
- GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new Function<Integer, Integer>() {
+ GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
@Override
public Integer call(Integer v) throws Exception {
return v / 2;