diff options
Diffstat (limited to 'sql/core/src/test/java')
4 files changed, 47 insertions, 47 deletions
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index 51f987fda9..42af813bc1 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -32,7 +32,7 @@ import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -107,9 +107,9 @@ public class JavaApplySchemaSuite implements Serializable { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - DataFrame df = sqlContext.createDataFrame(rowRDD, schema); + Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema); df.registerTempTable("people"); - Row[] actual = sqlContext.sql("SELECT * FROM people").collect(); + Row[] actual = sqlContext.sql("SELECT * FROM people").collectRows(); List<Row> expected = new ArrayList<>(2); expected.add(RowFactory.create("Michael", 29)); @@ -143,7 +143,7 @@ public class JavaApplySchemaSuite implements Serializable { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - DataFrame df = sqlContext.createDataFrame(rowRDD, schema); + Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema); df.registerTempTable("people"); List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function<Row, String>() { @Override @@ -198,14 +198,14 @@ public class JavaApplySchemaSuite implements Serializable { null, "this is another simple string.")); - DataFrame df1 = sqlContext.read().json(jsonRDD); + Dataset<Row> df1 = sqlContext.read().json(jsonRDD); StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); df1.registerTempTable("jsonTable1"); List<Row> actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - DataFrame df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD); + Dataset<Row> df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD); StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); df2.registerTempTable("jsonTable2"); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index ee85626435..47cc74dbc1 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -64,13 +64,13 @@ public class JavaDataFrameSuite { @Test public void testExecution() { - DataFrame df = context.table("testData").filter("key = 1"); - Assert.assertEquals(1, df.select("key").collect()[0].get(0)); + Dataset<Row> df = context.table("testData").filter("key = 1"); + Assert.assertEquals(1, df.select("key").collectRows()[0].get(0)); } @Test public void testCollectAndTake() { - DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3"); + Dataset<Row> df = context.table("testData").filter("key = 1 or key = 2 or key = 3"); Assert.assertEquals(3, df.select("key").collectAsList().size()); Assert.assertEquals(2, df.select("key").takeAsList(2).size()); } @@ -80,7 +80,7 @@ public class JavaDataFrameSuite { */ @Test public void testVarargMethods() { - DataFrame df = context.table("testData"); + Dataset<Row> df = context.table("testData"); df.toDF("key1", "value1"); @@ -109,7 +109,7 @@ public class JavaDataFrameSuite { df.select(coalesce(col("key"))); // Varargs with mathfunctions - DataFrame df2 = context.table("testData2"); + Dataset<Row> df2 = context.table("testData2"); df2.select(exp("a"), exp("b")); df2.select(exp(log("a"))); df2.select(pow("a", "a"), pow("b", 2.0)); @@ -123,7 +123,7 @@ public class JavaDataFrameSuite { @Ignore public void testShow() { // This test case is intended ignored, but to make sure it compiles correctly - DataFrame df = context.table("testData"); + Dataset<Row> df = context.table("testData"); df.show(); df.show(1000); } @@ -151,7 +151,7 @@ public class JavaDataFrameSuite { } } - void validateDataFrameWithBeans(Bean bean, DataFrame df) { + void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) { StructType schema = df.schema(); Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()), schema.apply("a")); @@ -191,7 +191,7 @@ public class JavaDataFrameSuite { public void testCreateDataFrameFromLocalJavaBeans() { Bean bean = new Bean(); List<Bean> data = Arrays.asList(bean); - DataFrame df = context.createDataFrame(data, Bean.class); + Dataset<Row> df = context.createDataFrame(data, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -199,7 +199,7 @@ public class JavaDataFrameSuite { public void testCreateDataFrameFromJavaBeans() { Bean bean = new Bean(); JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean)); - DataFrame df = context.createDataFrame(rdd, Bean.class); + Dataset<Row> df = context.createDataFrame(rdd, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -207,8 +207,8 @@ public class JavaDataFrameSuite { public void testCreateDataFromFromList() { StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); List<Row> rows = Arrays.asList(RowFactory.create(0)); - DataFrame df = context.createDataFrame(rows, schema); - Row[] result = df.collect(); + Dataset<Row> df = context.createDataFrame(rows, schema); + Row[] result = df.collectRows(); Assert.assertEquals(1, result.length); } @@ -235,13 +235,13 @@ public class JavaDataFrameSuite { @Test public void testCrosstab() { - DataFrame df = context.table("testData2"); - DataFrame crosstab = df.stat().crosstab("a", "b"); + Dataset<Row> df = context.table("testData2"); + Dataset<Row> crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); Assert.assertEquals("a_b", columnNames[0]); Assert.assertEquals("2", columnNames[1]); Assert.assertEquals("1", columnNames[2]); - Row[] rows = crosstab.collect(); + Row[] rows = crosstab.collectRows(); Arrays.sort(rows, crosstabRowComparator); Integer count = 1; for (Row row : rows) { @@ -254,31 +254,31 @@ public class JavaDataFrameSuite { @Test public void testFrequentItems() { - DataFrame df = context.table("testData2"); + Dataset<Row> df = context.table("testData2"); String[] cols = {"a"}; - DataFrame results = df.stat().freqItems(cols, 0.2); - Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); + Dataset<Row> results = df.stat().freqItems(cols, 0.2); + Assert.assertTrue(results.collectRows()[0].getSeq(0).contains(1)); } @Test public void testCorrelation() { - DataFrame df = context.table("testData2"); + Dataset<Row> df = context.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { - DataFrame df = context.table("testData2"); + Dataset<Row> df = context.table("testData2"); Double result = df.stat().cov("a", "b"); Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test public void testSampleBy() { - DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); - DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); - Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); + Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); + Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); + Row[] actual = sampled.groupBy("key").count().orderBy("key").collectRows(); Assert.assertEquals(0, actual[0].getLong(0)); Assert.assertTrue(0 <= actual[0].getLong(1) && actual[0].getLong(1) <= 8); Assert.assertEquals(1, actual[1].getLong(0)); @@ -287,10 +287,10 @@ public class JavaDataFrameSuite { @Test public void pivot() { - DataFrame df = context.table("courseSales"); + Dataset<Row> df = context.table("courseSales"); Row[] actual = df.groupBy("year") .pivot("course", Arrays.<Object>asList("dotNET", "Java")) - .agg(sum("earnings")).orderBy("year").collect(); + .agg(sum("earnings")).orderBy("year").collectRows(); Assert.assertEquals(2012, actual[0].getInt(0)); Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01); @@ -303,11 +303,11 @@ public class JavaDataFrameSuite { @Test public void testGenericLoad() { - DataFrame df1 = context.read().format("text").load( + Dataset<Row> df1 = context.read().format("text").load( Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); Assert.assertEquals(4L, df1.count()); - DataFrame df2 = context.read().format("text").load( + Dataset<Row> df2 = context.read().format("text").load( Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(), Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); Assert.assertEquals(5L, df2.count()); @@ -315,11 +315,11 @@ public class JavaDataFrameSuite { @Test public void testTextLoad() { - DataFrame df1 = context.read().text( + Dataset<Row> df1 = context.read().text( Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); Assert.assertEquals(4L, df1.count()); - DataFrame df2 = context.read().text( + Dataset<Row> df2 = context.read().text( Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(), Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); Assert.assertEquals(5L, df2.count()); @@ -327,7 +327,7 @@ public class JavaDataFrameSuite { @Test public void testCountMinSketch() { - DataFrame df = context.range(1000); + Dataset<Row> df = context.range(1000); CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42); Assert.assertEquals(sketch1.totalCount(), 1000); @@ -352,7 +352,7 @@ public class JavaDataFrameSuite { @Test public void testBloomFilter() { - DataFrame df = context.range(1000); + Dataset<Row> df = context.range(1000); BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03); Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index b054b1095b..79b6e61767 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -169,7 +169,7 @@ public class JavaDatasetSuite implements Serializable { public void testGroupBy() { List<String> data = Arrays.asList("a", "foo", "bar"); Dataset<String> ds = context.createDataset(data, Encoders.STRING()); - GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() { + GroupedDataset<Integer, String> grouped = ds.groupByKey(new MapFunction<String, Integer>() { @Override public Integer call(String v) throws Exception { return v.length(); @@ -217,7 +217,7 @@ public class JavaDatasetSuite implements Serializable { List<Integer> data2 = Arrays.asList(2, 6, 10); Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()); - GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() { + GroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(new MapFunction<Integer, Integer>() { @Override public Integer call(Integer v) throws Exception { return v / 2; @@ -250,7 +250,7 @@ public class JavaDatasetSuite implements Serializable { List<String> data = Arrays.asList("a", "foo", "bar"); Dataset<String> ds = context.createDataset(data, Encoders.STRING()); GroupedDataset<Integer, String> grouped = - ds.groupBy(length(col("value"))).keyAs(Encoders.INT()); + ds.groupByKey(length(col("value"))).keyAs(Encoders.INT()); Dataset<String> mapped = grouped.mapGroups( new MapGroupsFunction<Integer, String, String>() { @@ -410,7 +410,7 @@ public class JavaDatasetSuite implements Serializable { Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder); - GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupBy( + GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey( new MapFunction<Tuple2<String, Integer>, String>() { @Override public String call(Tuple2<String, Integer> value) throws Exception { @@ -828,7 +828,7 @@ public class JavaDatasetSuite implements Serializable { }) }); - DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); SmallBean smallBean = new SmallBean(); @@ -845,7 +845,7 @@ public class JavaDatasetSuite implements Serializable { { Row row = new GenericRow(new Object[] { null }); - DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); NestedSmallBean nestedSmallBean = new NestedSmallBean(); @@ -862,7 +862,7 @@ public class JavaDatasetSuite implements Serializable { }) }); - DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); ds.collect(); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 9e241f2098..0f9e453d26 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -42,9 +42,9 @@ public class JavaSaveLoadSuite { String originalDefaultSource; File path; - DataFrame df; + Dataset<Row> df; - private static void checkAnswer(DataFrame actual, List<Row> expected) { + private static void checkAnswer(Dataset<Row> actual, List<Row> expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -85,7 +85,7 @@ public class JavaSaveLoadSuite { Map<String, String> options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); - DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); + Dataset<Row> loadedDF = sqlContext.read().format("json").options(options).load(); checkAnswer(loadedDF, df.collectAsList()); } @@ -98,7 +98,7 @@ public class JavaSaveLoadSuite { List<StructField> fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); - DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); + Dataset<Row> loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); } |