author | Cheng Lian <lian@databricks.com> | 2016-03-10 17:00:17 -0800
---|---|---
committer | Yin Huai <yhuai@databricks.com> | 2016-03-10 17:00:17 -0800
commit | 1d542785b9949e7f92025e6754973a779cc37c52 |
tree | ceda7492e40c9d9a9231a5011c91e30bf0b1f390 /sql/core/src/test/java |
parent | 27fe6bacc532184ef6e8a2a24cd07f2c9188004e |
# [SPARK-13244][SQL] Migrates DataFrame to Dataset
## What changes were proposed in this pull request?
This PR unifies DataFrame and Dataset by migrating existing DataFrame operations to Dataset and making `DataFrame` a type alias of `Dataset[Row]`.
Most Scala code changes are source compatible, but the Java API breaks because Java knows nothing about Scala type aliases; Java code mostly needs `DataFrame` replaced with `Dataset<Row>`.
There are several noticeable API changes related to methods that return arrays:
1. `collect`/`take`
- Old APIs in class `DataFrame`:
```scala
def collect(): Array[Row]
def take(n: Int): Array[Row]
```
- New APIs in class `Dataset[T]`:
```scala
def collect(): Array[T]
def take(n: Int): Array[T]
def collectRows(): Array[Row]
def takeRows(n: Int): Array[Row]
```
Two specialized methods, `collectRows` and `takeRows`, are added because Java doesn't support returning generic arrays: due to type erasure, `Dataset.collect(): Array[T]` appears to return `Object` rather than `T[]` on the Java side.
Normally, Java users can fall back to `collectAsList` and `takeAsList`. The two new specialized versions are added to avoid performance regressions in ML-related code (but maybe I'm wrong and they are not necessary there); see the Java sketch after this list.
1. `randomSplit`
- Old APIs in class `DataFrame`:
```scala
def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame]
def randomSplit(weights: Array[Double]): Array[DataFrame]
```
- New APIs in class `Dataset[T]`:
```scala
def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]]
def randomSplit(weights: Array[Double]): Array[Dataset[T]]
```
This is the same generic-array problem as above, but it hasn't been addressed for the Java API yet. We can probably add a `randomSplitAsList` method to fix it.
1. `groupBy`
Some of the original `DataFrame.groupBy` methods have signatures that conflict with the original `Dataset.groupBy` methods. To distinguish the two, the typed `Dataset.groupBy` methods are renamed to `groupByKey` (illustrated in the Java sketch below).
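To make the Java-side friction above concrete, here is a minimal, hypothetical sketch of Java code written against the unified API; it is an illustration, not a test from this PR. It assumes a `SQLContext` named `sqlContext` and a registered `people` table with columns `(name: String, age: Int)`; `collectRows` is the Row-specialized method added here, `collectAsList`/`takeAsList` already existed, and the `groupByKey` overload taking a `MapFunction` plus an `Encoder` mirrors the one exercised in `JavaDatasetSuite` in the diff below.
```java
// A minimal, hypothetical sketch of Java code against the unified API.
import java.util.List;

import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.GroupedDataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

public class JavaMigrationSketch {
  static void demo(SQLContext sqlContext) {
    // What used to be DataFrame is now spelled Dataset<Row> in Java.
    Dataset<Row> people = sqlContext.sql("SELECT name, age FROM people");

    // collect(): Array[T] erases to Object from Java's point of view, so
    // Java callers use the specialized or List-returning variants instead.
    Row[] rows = people.collectRows();
    List<Row> firstTwo = people.takeAsList(2);

    // The untyped, column-based groupBy keeps its old DataFrame behavior...
    Dataset<Row> counts = people.groupBy("age").count();

    // ...while the typed grouping is now called groupByKey.
    Dataset<String> names = people.map(new MapFunction<Row, String>() {
      @Override
      public String call(Row row) throws Exception {
        return row.getString(0);
      }
    }, Encoders.STRING());
    GroupedDataset<Integer, String> byLength = names.groupByKey(
      new MapFunction<String, Integer>() {
        @Override
        public Integer call(String name) throws Exception {
          return name.length();
        }
      }, Encoders.INT());
  }
}
```
`randomSplit` is deliberately left out of the sketch: it suffers the same generic-array erasure, and the `randomSplitAsList` variant suggested above is only a proposal at this point.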
Other noticeable changes:
1. Datasets always do eager analysis now
We used to support disabling eager DataFrame analysis so that, when analysis failed, we could still report the partially analyzed, malformed logical plan. However, Dataset encoders require eager analysis during Dataset construction. To preserve this error-reporting feature, `AnalysisException` now takes an extra `Option[LogicalPlan]` argument holding the partially analyzed plan, so the plan tree can still be inspected when reporting test failures. This plan is passed in by `QueryExecution.assertAnalyzed`. A sketch of the caller-visible behavior follows.
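The following small, hypothetical Java sketch shows what eager analysis means for callers; the `people` table and the misspelled `agee` column are made up for the example.
```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

public class EagerAnalysisSketch {
  static void demo(SQLContext sqlContext) {
    try {
      // With eager analysis, the bad column reference fails right here,
      // when the Dataset is constructed -- not later when an action runs.
      Dataset<Row> bad = sqlContext.table("people").select("agee");
      bad.count();  // never reached
    } catch (Exception e) {
      // e is an org.apache.spark.sql.AnalysisException (caught as Exception
      // since, from Java's perspective, it is checked but undeclared). After
      // this PR it also carries the partially analyzed plan as an
      // Option[LogicalPlan], so Scala-side test helpers can still print the
      // malformed plan tree.
      System.out.println(e.getMessage());
    }
  }
}
```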
## How was this patch tested?
Existing tests do the work.
## TODO
- [ ] Fix all tests
- [ ] Re-enable MiMA check
- [ ] Update ScalaDoc (`since`, `group`, and example code)
Author: Cheng Lian <lian@databricks.com>
Author: Yin Huai <yhuai@databricks.com>
Author: Wenchen Fan <wenchen@databricks.com>
Author: Cheng Lian <liancheng@users.noreply.github.com>
Closes #11443 from liancheng/ds-to-df.
Diffstat (limited to 'sql/core/src/test/java')
4 files changed, 47 insertions, 47 deletions
```diff
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index 51f987fda9..42af813bc1 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -32,7 +32,7 @@ import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SQLContext;
@@ -107,9 +107,9 @@ public class JavaApplySchemaSuite implements Serializable {
     fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
     StructType schema = DataTypes.createStructType(fields);

-    DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
+    Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
     df.registerTempTable("people");
-    Row[] actual = sqlContext.sql("SELECT * FROM people").collect();
+    Row[] actual = sqlContext.sql("SELECT * FROM people").collectRows();

     List<Row> expected = new ArrayList<>(2);
     expected.add(RowFactory.create("Michael", 29));
@@ -143,7 +143,7 @@ public class JavaApplySchemaSuite implements Serializable {
     fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
     StructType schema = DataTypes.createStructType(fields);

-    DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
+    Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
     df.registerTempTable("people");
     List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function<Row, String>() {
       @Override
@@ -198,14 +198,14 @@ public class JavaApplySchemaSuite implements Serializable {
         null,
         "this is another simple string."));

-    DataFrame df1 = sqlContext.read().json(jsonRDD);
+    Dataset<Row> df1 = sqlContext.read().json(jsonRDD);
     StructType actualSchema1 = df1.schema();
     Assert.assertEquals(expectedSchema, actualSchema1);
     df1.registerTempTable("jsonTable1");
     List<Row> actual1 = sqlContext.sql("select * from jsonTable1").collectAsList();
     Assert.assertEquals(expectedResult, actual1);

-    DataFrame df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD);
+    Dataset<Row> df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD);
     StructType actualSchema2 = df2.schema();
     Assert.assertEquals(expectedSchema, actualSchema2);
     df2.registerTempTable("jsonTable2");
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index ee85626435..47cc74dbc1 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -64,13 +64,13 @@ public class JavaDataFrameSuite {

   @Test
   public void testExecution() {
-    DataFrame df = context.table("testData").filter("key = 1");
-    Assert.assertEquals(1, df.select("key").collect()[0].get(0));
+    Dataset<Row> df = context.table("testData").filter("key = 1");
+    Assert.assertEquals(1, df.select("key").collectRows()[0].get(0));
   }

   @Test
   public void testCollectAndTake() {
-    DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
+    Dataset<Row> df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
     Assert.assertEquals(3, df.select("key").collectAsList().size());
     Assert.assertEquals(2, df.select("key").takeAsList(2).size());
   }
@@ -80,7 +80,7 @@ public class JavaDataFrameSuite {
    */
   @Test
   public void testVarargMethods() {
-    DataFrame df = context.table("testData");
+    Dataset<Row> df = context.table("testData");

     df.toDF("key1", "value1");
@@ -109,7 +109,7 @@ public class JavaDataFrameSuite {
     df.select(coalesce(col("key")));

     // Varargs with mathfunctions
-    DataFrame df2 = context.table("testData2");
+    Dataset<Row> df2 = context.table("testData2");
     df2.select(exp("a"), exp("b"));
     df2.select(exp(log("a")));
     df2.select(pow("a", "a"), pow("b", 2.0));
@@ -123,7 +123,7 @@ public class JavaDataFrameSuite {
   @Ignore
   public void testShow() {
     // This test case is intended ignored, but to make sure it compiles correctly
-    DataFrame df = context.table("testData");
+    Dataset<Row> df = context.table("testData");
     df.show();
     df.show(1000);
   }
@@ -151,7 +151,7 @@ public class JavaDataFrameSuite {
     }
   }

-  void validateDataFrameWithBeans(Bean bean, DataFrame df) {
+  void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
     StructType schema = df.schema();
     Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
       schema.apply("a"));
@@ -191,7 +191,7 @@ public class JavaDataFrameSuite {
   public void testCreateDataFrameFromLocalJavaBeans() {
     Bean bean = new Bean();
     List<Bean> data = Arrays.asList(bean);
-    DataFrame df = context.createDataFrame(data, Bean.class);
+    Dataset<Row> df = context.createDataFrame(data, Bean.class);
     validateDataFrameWithBeans(bean, df);
   }

@@ -199,7 +199,7 @@ public class JavaDataFrameSuite {
   public void testCreateDataFrameFromJavaBeans() {
     Bean bean = new Bean();
     JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
-    DataFrame df = context.createDataFrame(rdd, Bean.class);
+    Dataset<Row> df = context.createDataFrame(rdd, Bean.class);
     validateDataFrameWithBeans(bean, df);
   }

@@ -207,8 +207,8 @@ public class JavaDataFrameSuite {
   public void testCreateDataFromFromList() {
     StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
     List<Row> rows = Arrays.asList(RowFactory.create(0));
-    DataFrame df = context.createDataFrame(rows, schema);
-    Row[] result = df.collect();
+    Dataset<Row> df = context.createDataFrame(rows, schema);
+    Row[] result = df.collectRows();
     Assert.assertEquals(1, result.length);
   }

@@ -235,13 +235,13 @@ public class JavaDataFrameSuite {

   @Test
   public void testCrosstab() {
-    DataFrame df = context.table("testData2");
-    DataFrame crosstab = df.stat().crosstab("a", "b");
+    Dataset<Row> df = context.table("testData2");
+    Dataset<Row> crosstab = df.stat().crosstab("a", "b");
     String[] columnNames = crosstab.schema().fieldNames();
     Assert.assertEquals("a_b", columnNames[0]);
     Assert.assertEquals("2", columnNames[1]);
     Assert.assertEquals("1", columnNames[2]);
-    Row[] rows = crosstab.collect();
+    Row[] rows = crosstab.collectRows();
     Arrays.sort(rows, crosstabRowComparator);
     Integer count = 1;
     for (Row row : rows) {
@@ -254,31 +254,31 @@ public class JavaDataFrameSuite {

   @Test
   public void testFrequentItems() {
-    DataFrame df = context.table("testData2");
+    Dataset<Row> df = context.table("testData2");
     String[] cols = {"a"};
-    DataFrame results = df.stat().freqItems(cols, 0.2);
-    Assert.assertTrue(results.collect()[0].getSeq(0).contains(1));
+    Dataset<Row> results = df.stat().freqItems(cols, 0.2);
+    Assert.assertTrue(results.collectRows()[0].getSeq(0).contains(1));
   }

   @Test
   public void testCorrelation() {
-    DataFrame df = context.table("testData2");
+    Dataset<Row> df = context.table("testData2");
     Double pearsonCorr = df.stat().corr("a", "b", "pearson");
     Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6);
   }

   @Test
   public void testCovariance() {
-    DataFrame df = context.table("testData2");
+    Dataset<Row> df = context.table("testData2");
     Double result = df.stat().cov("a", "b");
     Assert.assertTrue(Math.abs(result) < 1.0e-6);
   }

   @Test
   public void testSampleBy() {
-    DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
-    DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
-    Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
+    Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
+    Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
+    Row[] actual = sampled.groupBy("key").count().orderBy("key").collectRows();
     Assert.assertEquals(0, actual[0].getLong(0));
     Assert.assertTrue(0 <= actual[0].getLong(1) && actual[0].getLong(1) <= 8);
     Assert.assertEquals(1, actual[1].getLong(0));
@@ -287,10 +287,10 @@ public class JavaDataFrameSuite {

   @Test
   public void pivot() {
-    DataFrame df = context.table("courseSales");
+    Dataset<Row> df = context.table("courseSales");
     Row[] actual = df.groupBy("year")
       .pivot("course", Arrays.<Object>asList("dotNET", "Java"))
-      .agg(sum("earnings")).orderBy("year").collect();
+      .agg(sum("earnings")).orderBy("year").collectRows();

     Assert.assertEquals(2012, actual[0].getInt(0));
     Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01);
@@ -303,11 +303,11 @@ public class JavaDataFrameSuite {

   @Test
   public void testGenericLoad() {
-    DataFrame df1 = context.read().format("text").load(
+    Dataset<Row> df1 = context.read().format("text").load(
       Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
     Assert.assertEquals(4L, df1.count());

-    DataFrame df2 = context.read().format("text").load(
+    Dataset<Row> df2 = context.read().format("text").load(
       Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
       Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
     Assert.assertEquals(5L, df2.count());
@@ -315,11 +315,11 @@ public class JavaDataFrameSuite {

   @Test
   public void testTextLoad() {
-    DataFrame df1 = context.read().text(
+    Dataset<Row> df1 = context.read().text(
       Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
     Assert.assertEquals(4L, df1.count());

-    DataFrame df2 = context.read().text(
+    Dataset<Row> df2 = context.read().text(
       Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
       Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
     Assert.assertEquals(5L, df2.count());
@@ -327,7 +327,7 @@ public class JavaDataFrameSuite {

   @Test
   public void testCountMinSketch() {
-    DataFrame df = context.range(1000);
+    Dataset<Row> df = context.range(1000);
     CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
     Assert.assertEquals(sketch1.totalCount(), 1000);
@@ -352,7 +352,7 @@ public class JavaDataFrameSuite {

   @Test
   public void testBloomFilter() {
-    DataFrame df = context.range(1000);
+    Dataset<Row> df = context.range(1000);
     BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
     Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3);
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index b054b1095b..79b6e61767 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -169,7 +169,7 @@ public class JavaDatasetSuite implements Serializable {
   public void testGroupBy() {
     List<String> data = Arrays.asList("a", "foo", "bar");
     Dataset<String> ds = context.createDataset(data, Encoders.STRING());
-    GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
+    GroupedDataset<Integer, String> grouped = ds.groupByKey(new MapFunction<String, Integer>() {
       @Override
       public Integer call(String v) throws Exception {
         return v.length();
@@ -217,7 +217,7 @@ public class JavaDatasetSuite implements Serializable {
     List<Integer> data2 = Arrays.asList(2, 6, 10);
     Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());

-    GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
+    GroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(new MapFunction<Integer, Integer>() {
       @Override
       public Integer call(Integer v) throws Exception {
         return v / 2;
@@ -250,7 +250,7 @@ public class JavaDatasetSuite implements Serializable {
     List<String> data = Arrays.asList("a", "foo", "bar");
     Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     GroupedDataset<Integer, String> grouped =
-      ds.groupBy(length(col("value"))).keyAs(Encoders.INT());
+      ds.groupByKey(length(col("value"))).keyAs(Encoders.INT());

     Dataset<String> mapped = grouped.mapGroups(
       new MapGroupsFunction<Integer, String, String>() {
@@ -410,7 +410,7 @@ public class JavaDatasetSuite implements Serializable {
       Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
     Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);

-    GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupBy(
+    GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey(
       new MapFunction<Tuple2<String, Integer>, String>() {
         @Override
         public String call(Tuple2<String, Integer> value) throws Exception {
@@ -828,7 +828,7 @@ public class JavaDatasetSuite implements Serializable {
         })
       });

-      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
       Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));

       SmallBean smallBean = new SmallBean();
@@ -845,7 +845,7 @@ public class JavaDatasetSuite implements Serializable {
     {
       Row row = new GenericRow(new Object[] { null });

-      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
       Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));

       NestedSmallBean nestedSmallBean = new NestedSmallBean();
@@ -862,7 +862,7 @@ public class JavaDatasetSuite implements Serializable {
         })
       });

-      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
       Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));

       ds.collect();
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
index 9e241f2098..0f9e453d26 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -42,9 +42,9 @@ public class JavaSaveLoadSuite {
   String originalDefaultSource;
   File path;
-  DataFrame df;
+  Dataset<Row> df;

-  private static void checkAnswer(DataFrame actual, List<Row> expected) {
+  private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
     String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
     if (errorMessage != null) {
       Assert.fail(errorMessage);
@@ -85,7 +85,7 @@ public class JavaSaveLoadSuite {
     Map<String, String> options = new HashMap<>();
     options.put("path", path.toString());
     df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save();
-    DataFrame loadedDF = sqlContext.read().format("json").options(options).load();
+    Dataset<Row> loadedDF = sqlContext.read().format("json").options(options).load();
     checkAnswer(loadedDF, df.collectAsList());
   }

@@ -98,7 +98,7 @@ public class JavaSaveLoadSuite {
     List<StructField> fields = new ArrayList<>();
     fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
     StructType schema = DataTypes.createStructType(fields);
-    DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load();
+    Dataset<Row> loadedDF = sqlContext.read().format("json").schema(schema).options(options).load();
     checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList());
   }
```