aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/test/java
diff options
context:
space:
mode:
authorCheng Lian <lian@databricks.com>2016-03-10 17:00:17 -0800
committerYin Huai <yhuai@databricks.com>2016-03-10 17:00:17 -0800
commit1d542785b9949e7f92025e6754973a779cc37c52 (patch)
treeceda7492e40c9d9a9231a5011c91e30bf0b1f390 /sql/core/src/test/java
parent27fe6bacc532184ef6e8a2a24cd07f2c9188004e (diff)
downloadspark-1d542785b9949e7f92025e6754973a779cc37c52.tar.gz
spark-1d542785b9949e7f92025e6754973a779cc37c52.tar.bz2
spark-1d542785b9949e7f92025e6754973a779cc37c52.zip
[SPARK-13244][SQL] Migrates DataFrame to Dataset
## What changes were proposed in this pull request? This PR unifies DataFrame and Dataset by migrating existing DataFrame operations to Dataset and make `DataFrame` a type alias of `Dataset[Row]`. Most Scala code changes are source compatible, but Java API is broken as Java knows nothing about Scala type alias (mostly replacing `DataFrame` with `Dataset<Row>`). There are several noticeable API changes related to those returning arrays: 1. `collect`/`take` - Old APIs in class `DataFrame`: ```scala def collect(): Array[Row] def take(n: Int): Array[Row] ``` - New APIs in class `Dataset[T]`: ```scala def collect(): Array[T] def take(n: Int): Array[T] def collectRows(): Array[Row] def takeRows(n: Int): Array[Row] ``` Two specialized methods `collectRows` and `takeRows` are added because Java doesn't support returning generic arrays. Thus, for example, `DataFrame.collect(): Array[T]` actually returns `Object` instead of `Array<T>` from Java side. Normally, Java users may fall back to `collectAsList` and `takeAsList`. The two new specialized versions are added to avoid performance regression in ML related code (but maybe I'm wrong and they are not necessary here). 1. `randomSplit` - Old APIs in class `DataFrame`: ```scala def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] def randomSplit(weights: Array[Double]): Array[DataFrame] ``` - New APIs in class `Dataset[T]`: ```scala def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] def randomSplit(weights: Array[Double]): Array[Dataset[T]] ``` Similar problem as above, but hasn't been addressed for Java API yet. We can probably add `randomSplitAsList` to fix this one. 1. `groupBy` Some original `DataFrame.groupBy` methods have conflicting signature with original `Dataset.groupBy` methods. To distinguish these two, typed `Dataset.groupBy` methods are renamed to `groupByKey`. Other noticeable changes: 1. Dataset always do eager analysis now We used to support disabling DataFrame eager analysis to help reporting partially analyzed malformed logical plan on analysis failure. However, Dataset encoders requires eager analysi during Dataset construction. To preserve the error reporting feature, `AnalysisException` now takes an extra `Option[LogicalPlan]` argument to hold the partially analyzed plan, so that we can check the plan tree when reporting test failures. This plan is passed by `QueryExecution.assertAnalyzed`. ## How was this patch tested? Existing tests do the work. ## TODO - [ ] Fix all tests - [ ] Re-enable MiMA check - [ ] Update ScalaDoc (`since`, `group`, and example code) Author: Cheng Lian <lian@databricks.com> Author: Yin Huai <yhuai@databricks.com> Author: Wenchen Fan <wenchen@databricks.com> Author: Cheng Lian <liancheng@users.noreply.github.com> Closes #11443 from liancheng/ds-to-df.
Diffstat (limited to 'sql/core/src/test/java')
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java12
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java60
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java14
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java8
4 files changed, 47 insertions, 47 deletions
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index 51f987fda9..42af813bc1 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -32,7 +32,7 @@ import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -107,9 +107,9 @@ public class JavaApplySchemaSuite implements Serializable {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
+ Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
df.registerTempTable("people");
- Row[] actual = sqlContext.sql("SELECT * FROM people").collect();
+ Row[] actual = sqlContext.sql("SELECT * FROM people").collectRows();
List<Row> expected = new ArrayList<>(2);
expected.add(RowFactory.create("Michael", 29));
@@ -143,7 +143,7 @@ public class JavaApplySchemaSuite implements Serializable {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
+ Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
df.registerTempTable("people");
List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function<Row, String>() {
@Override
@@ -198,14 +198,14 @@ public class JavaApplySchemaSuite implements Serializable {
null,
"this is another simple string."));
- DataFrame df1 = sqlContext.read().json(jsonRDD);
+ Dataset<Row> df1 = sqlContext.read().json(jsonRDD);
StructType actualSchema1 = df1.schema();
Assert.assertEquals(expectedSchema, actualSchema1);
df1.registerTempTable("jsonTable1");
List<Row> actual1 = sqlContext.sql("select * from jsonTable1").collectAsList();
Assert.assertEquals(expectedResult, actual1);
- DataFrame df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD);
+ Dataset<Row> df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD);
StructType actualSchema2 = df2.schema();
Assert.assertEquals(expectedSchema, actualSchema2);
df2.registerTempTable("jsonTable2");
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index ee85626435..47cc74dbc1 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -64,13 +64,13 @@ public class JavaDataFrameSuite {
@Test
public void testExecution() {
- DataFrame df = context.table("testData").filter("key = 1");
- Assert.assertEquals(1, df.select("key").collect()[0].get(0));
+ Dataset<Row> df = context.table("testData").filter("key = 1");
+ Assert.assertEquals(1, df.select("key").collectRows()[0].get(0));
}
@Test
public void testCollectAndTake() {
- DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
+ Dataset<Row> df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
Assert.assertEquals(3, df.select("key").collectAsList().size());
Assert.assertEquals(2, df.select("key").takeAsList(2).size());
}
@@ -80,7 +80,7 @@ public class JavaDataFrameSuite {
*/
@Test
public void testVarargMethods() {
- DataFrame df = context.table("testData");
+ Dataset<Row> df = context.table("testData");
df.toDF("key1", "value1");
@@ -109,7 +109,7 @@ public class JavaDataFrameSuite {
df.select(coalesce(col("key")));
// Varargs with mathfunctions
- DataFrame df2 = context.table("testData2");
+ Dataset<Row> df2 = context.table("testData2");
df2.select(exp("a"), exp("b"));
df2.select(exp(log("a")));
df2.select(pow("a", "a"), pow("b", 2.0));
@@ -123,7 +123,7 @@ public class JavaDataFrameSuite {
@Ignore
public void testShow() {
// This test case is intended ignored, but to make sure it compiles correctly
- DataFrame df = context.table("testData");
+ Dataset<Row> df = context.table("testData");
df.show();
df.show(1000);
}
@@ -151,7 +151,7 @@ public class JavaDataFrameSuite {
}
}
- void validateDataFrameWithBeans(Bean bean, DataFrame df) {
+ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
StructType schema = df.schema();
Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
schema.apply("a"));
@@ -191,7 +191,7 @@ public class JavaDataFrameSuite {
public void testCreateDataFrameFromLocalJavaBeans() {
Bean bean = new Bean();
List<Bean> data = Arrays.asList(bean);
- DataFrame df = context.createDataFrame(data, Bean.class);
+ Dataset<Row> df = context.createDataFrame(data, Bean.class);
validateDataFrameWithBeans(bean, df);
}
@@ -199,7 +199,7 @@ public class JavaDataFrameSuite {
public void testCreateDataFrameFromJavaBeans() {
Bean bean = new Bean();
JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
- DataFrame df = context.createDataFrame(rdd, Bean.class);
+ Dataset<Row> df = context.createDataFrame(rdd, Bean.class);
validateDataFrameWithBeans(bean, df);
}
@@ -207,8 +207,8 @@ public class JavaDataFrameSuite {
public void testCreateDataFromFromList() {
StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
List<Row> rows = Arrays.asList(RowFactory.create(0));
- DataFrame df = context.createDataFrame(rows, schema);
- Row[] result = df.collect();
+ Dataset<Row> df = context.createDataFrame(rows, schema);
+ Row[] result = df.collectRows();
Assert.assertEquals(1, result.length);
}
@@ -235,13 +235,13 @@ public class JavaDataFrameSuite {
@Test
public void testCrosstab() {
- DataFrame df = context.table("testData2");
- DataFrame crosstab = df.stat().crosstab("a", "b");
+ Dataset<Row> df = context.table("testData2");
+ Dataset<Row> crosstab = df.stat().crosstab("a", "b");
String[] columnNames = crosstab.schema().fieldNames();
Assert.assertEquals("a_b", columnNames[0]);
Assert.assertEquals("2", columnNames[1]);
Assert.assertEquals("1", columnNames[2]);
- Row[] rows = crosstab.collect();
+ Row[] rows = crosstab.collectRows();
Arrays.sort(rows, crosstabRowComparator);
Integer count = 1;
for (Row row : rows) {
@@ -254,31 +254,31 @@ public class JavaDataFrameSuite {
@Test
public void testFrequentItems() {
- DataFrame df = context.table("testData2");
+ Dataset<Row> df = context.table("testData2");
String[] cols = {"a"};
- DataFrame results = df.stat().freqItems(cols, 0.2);
- Assert.assertTrue(results.collect()[0].getSeq(0).contains(1));
+ Dataset<Row> results = df.stat().freqItems(cols, 0.2);
+ Assert.assertTrue(results.collectRows()[0].getSeq(0).contains(1));
}
@Test
public void testCorrelation() {
- DataFrame df = context.table("testData2");
+ Dataset<Row> df = context.table("testData2");
Double pearsonCorr = df.stat().corr("a", "b", "pearson");
Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6);
}
@Test
public void testCovariance() {
- DataFrame df = context.table("testData2");
+ Dataset<Row> df = context.table("testData2");
Double result = df.stat().cov("a", "b");
Assert.assertTrue(Math.abs(result) < 1.0e-6);
}
@Test
public void testSampleBy() {
- DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
- DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
- Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
+ Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
+ Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
+ Row[] actual = sampled.groupBy("key").count().orderBy("key").collectRows();
Assert.assertEquals(0, actual[0].getLong(0));
Assert.assertTrue(0 <= actual[0].getLong(1) && actual[0].getLong(1) <= 8);
Assert.assertEquals(1, actual[1].getLong(0));
@@ -287,10 +287,10 @@ public class JavaDataFrameSuite {
@Test
public void pivot() {
- DataFrame df = context.table("courseSales");
+ Dataset<Row> df = context.table("courseSales");
Row[] actual = df.groupBy("year")
.pivot("course", Arrays.<Object>asList("dotNET", "Java"))
- .agg(sum("earnings")).orderBy("year").collect();
+ .agg(sum("earnings")).orderBy("year").collectRows();
Assert.assertEquals(2012, actual[0].getInt(0));
Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01);
@@ -303,11 +303,11 @@ public class JavaDataFrameSuite {
@Test
public void testGenericLoad() {
- DataFrame df1 = context.read().format("text").load(
+ Dataset<Row> df1 = context.read().format("text").load(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
Assert.assertEquals(4L, df1.count());
- DataFrame df2 = context.read().format("text").load(
+ Dataset<Row> df2 = context.read().format("text").load(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
Assert.assertEquals(5L, df2.count());
@@ -315,11 +315,11 @@ public class JavaDataFrameSuite {
@Test
public void testTextLoad() {
- DataFrame df1 = context.read().text(
+ Dataset<Row> df1 = context.read().text(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
Assert.assertEquals(4L, df1.count());
- DataFrame df2 = context.read().text(
+ Dataset<Row> df2 = context.read().text(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
Assert.assertEquals(5L, df2.count());
@@ -327,7 +327,7 @@ public class JavaDataFrameSuite {
@Test
public void testCountMinSketch() {
- DataFrame df = context.range(1000);
+ Dataset<Row> df = context.range(1000);
CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
Assert.assertEquals(sketch1.totalCount(), 1000);
@@ -352,7 +352,7 @@ public class JavaDataFrameSuite {
@Test
public void testBloomFilter() {
- DataFrame df = context.range(1000);
+ Dataset<Row> df = context.range(1000);
BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3);
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index b054b1095b..79b6e61767 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -169,7 +169,7 @@ public class JavaDatasetSuite implements Serializable {
public void testGroupBy() {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
- GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
+ GroupedDataset<Integer, String> grouped = ds.groupByKey(new MapFunction<String, Integer>() {
@Override
public Integer call(String v) throws Exception {
return v.length();
@@ -217,7 +217,7 @@ public class JavaDatasetSuite implements Serializable {
List<Integer> data2 = Arrays.asList(2, 6, 10);
Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
- GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
+ GroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(new MapFunction<Integer, Integer>() {
@Override
public Integer call(Integer v) throws Exception {
return v / 2;
@@ -250,7 +250,7 @@ public class JavaDatasetSuite implements Serializable {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
GroupedDataset<Integer, String> grouped =
- ds.groupBy(length(col("value"))).keyAs(Encoders.INT());
+ ds.groupByKey(length(col("value"))).keyAs(Encoders.INT());
Dataset<String> mapped = grouped.mapGroups(
new MapGroupsFunction<Integer, String, String>() {
@@ -410,7 +410,7 @@ public class JavaDatasetSuite implements Serializable {
Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
- GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupBy(
+ GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey(
new MapFunction<Tuple2<String, Integer>, String>() {
@Override
public String call(Tuple2<String, Integer> value) throws Exception {
@@ -828,7 +828,7 @@ public class JavaDatasetSuite implements Serializable {
})
});
- DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
SmallBean smallBean = new SmallBean();
@@ -845,7 +845,7 @@ public class JavaDatasetSuite implements Serializable {
{
Row row = new GenericRow(new Object[] { null });
- DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
NestedSmallBean nestedSmallBean = new NestedSmallBean();
@@ -862,7 +862,7 @@ public class JavaDatasetSuite implements Serializable {
})
});
- DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
ds.collect();
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
index 9e241f2098..0f9e453d26 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -42,9 +42,9 @@ public class JavaSaveLoadSuite {
String originalDefaultSource;
File path;
- DataFrame df;
+ Dataset<Row> df;
- private static void checkAnswer(DataFrame actual, List<Row> expected) {
+ private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
if (errorMessage != null) {
Assert.fail(errorMessage);
@@ -85,7 +85,7 @@ public class JavaSaveLoadSuite {
Map<String, String> options = new HashMap<>();
options.put("path", path.toString());
df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save();
- DataFrame loadedDF = sqlContext.read().format("json").options(options).load();
+ Dataset<Row> loadedDF = sqlContext.read().format("json").options(options).load();
checkAnswer(loadedDF, df.collectAsList());
}
@@ -98,7 +98,7 @@ public class JavaSaveLoadSuite {
List<StructField> fields = new ArrayList<>();
fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
StructType schema = DataTypes.createStructType(fields);
- DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load();
+ Dataset<Row> loadedDF = sqlContext.read().format("json").schema(schema).options(options).load();
checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList());
}