Diffstat (limited to 'sql/core/src/test/java')
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java      | 12
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java        | 60
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java          | 14
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java |  8
4 files changed, 47 insertions(+), 47 deletions(-)
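
This diff tracks the unification of DataFrame and Dataset in the Java test suites: the standalone DataFrame type is replaced by Dataset<Row>, and the typed Dataset.groupBy(...) overloads are renamed to groupByKey(...). For orientation, a minimal self-contained sketch of the recurring pattern (the class name and table are illustrative, assuming the API as of this commit):

    import java.util.List;

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SQLContext;

    public class MigrationSketch {
      // Before: DataFrame df = sqlContext.table("testData").filter("key = 1");
      // After:  the same untyped query is a Dataset of Rows.
      static List<Row> query(SQLContext sqlContext) {
        Dataset<Row> df = sqlContext.table("testData").filter("key = 1");
        return df.collectAsList();
      }
    }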
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index 51f987fda9..42af813bc1 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -32,7 +32,7 @@ import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -107,9 +107,9 @@ public class JavaApplySchemaSuite implements Serializable {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
+ Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
df.registerTempTable("people");
- Row[] actual = sqlContext.sql("SELECT * FROM people").collect();
+ Row[] actual = sqlContext.sql("SELECT * FROM people").collectRows();
List<Row> expected = new ArrayList<>(2);
expected.add(RowFactory.create("Michael", 29));
@@ -143,7 +143,7 @@ public class JavaApplySchemaSuite implements Serializable {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
+ Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
df.registerTempTable("people");
List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function<Row, String>() {
@Override
@@ -198,14 +198,14 @@ public class JavaApplySchemaSuite implements Serializable {
null,
"this is another simple string."));
- DataFrame df1 = sqlContext.read().json(jsonRDD);
+ Dataset<Row> df1 = sqlContext.read().json(jsonRDD);
StructType actualSchema1 = df1.schema();
Assert.assertEquals(expectedSchema, actualSchema1);
df1.registerTempTable("jsonTable1");
List<Row> actual1 = sqlContext.sql("select * from jsonTable1").collectAsList();
Assert.assertEquals(expectedResult, actual1);
- DataFrame df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD);
+ Dataset<Row> df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD);
StructType actualSchema2 = df2.schema();
Assert.assertEquals(expectedSchema, actualSchema2);
df2.registerTempTable("jsonTable2");
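
Alongside the type change, call sites that relied on collect() returning Row[] move to collectRows(). A condensed sketch of the two collection styles these suites use (assuming the transitional API in this commit, where collectRows() appears as a Java-friendly bridge and collectAsList() sidesteps array erasure entirely):

    // given a SQLContext sqlContext with a registered "people" table
    Dataset<Row> people = sqlContext.sql("SELECT * FROM people");
    Row[] asArray = people.collectRows();       // transitional bridge used by these tests
    List<Row> asList = people.collectAsList();  // erasure-safe alternative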
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index ee85626435..47cc74dbc1 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -64,13 +64,13 @@ public class JavaDataFrameSuite {
@Test
public void testExecution() {
- DataFrame df = context.table("testData").filter("key = 1");
- Assert.assertEquals(1, df.select("key").collect()[0].get(0));
+ Dataset<Row> df = context.table("testData").filter("key = 1");
+ Assert.assertEquals(1, df.select("key").collectRows()[0].get(0));
}
@Test
public void testCollectAndTake() {
- DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
+ Dataset<Row> df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
Assert.assertEquals(3, df.select("key").collectAsList().size());
Assert.assertEquals(2, df.select("key").takeAsList(2).size());
}
@@ -80,7 +80,7 @@ public class JavaDataFrameSuite {
*/
@Test
public void testVarargMethods() {
- DataFrame df = context.table("testData");
+ Dataset<Row> df = context.table("testData");
df.toDF("key1", "value1");
@@ -109,7 +109,7 @@ public class JavaDataFrameSuite {
df.select(coalesce(col("key")));
// Varargs with mathfunctions
- DataFrame df2 = context.table("testData2");
+ Dataset<Row> df2 = context.table("testData2");
df2.select(exp("a"), exp("b"));
df2.select(exp(log("a")));
df2.select(pow("a", "a"), pow("b", 2.0));
@@ -123,7 +123,7 @@ public class JavaDataFrameSuite {
@Ignore
public void testShow() {
// This test case is intended ignored, but to make sure it compiles correctly
- DataFrame df = context.table("testData");
+ Dataset<Row> df = context.table("testData");
df.show();
df.show(1000);
}
@@ -151,7 +151,7 @@ public class JavaDataFrameSuite {
}
}
- void validateDataFrameWithBeans(Bean bean, DataFrame df) {
+ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
StructType schema = df.schema();
Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
schema.apply("a"));
@@ -191,7 +191,7 @@ public class JavaDataFrameSuite {
public void testCreateDataFrameFromLocalJavaBeans() {
Bean bean = new Bean();
List<Bean> data = Arrays.asList(bean);
- DataFrame df = context.createDataFrame(data, Bean.class);
+ Dataset<Row> df = context.createDataFrame(data, Bean.class);
validateDataFrameWithBeans(bean, df);
}
@@ -199,7 +199,7 @@ public class JavaDataFrameSuite {
public void testCreateDataFrameFromJavaBeans() {
Bean bean = new Bean();
JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
- DataFrame df = context.createDataFrame(rdd, Bean.class);
+ Dataset<Row> df = context.createDataFrame(rdd, Bean.class);
validateDataFrameWithBeans(bean, df);
}
@@ -207,8 +207,8 @@ public class JavaDataFrameSuite {
public void testCreateDataFromFromList() {
StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
List<Row> rows = Arrays.asList(RowFactory.create(0));
- DataFrame df = context.createDataFrame(rows, schema);
- Row[] result = df.collect();
+ Dataset<Row> df = context.createDataFrame(rows, schema);
+ Row[] result = df.collectRows();
Assert.assertEquals(1, result.length);
}
@@ -235,13 +235,13 @@ public class JavaDataFrameSuite {
@Test
public void testCrosstab() {
- DataFrame df = context.table("testData2");
- DataFrame crosstab = df.stat().crosstab("a", "b");
+ Dataset<Row> df = context.table("testData2");
+ Dataset<Row> crosstab = df.stat().crosstab("a", "b");
String[] columnNames = crosstab.schema().fieldNames();
Assert.assertEquals("a_b", columnNames[0]);
Assert.assertEquals("2", columnNames[1]);
Assert.assertEquals("1", columnNames[2]);
- Row[] rows = crosstab.collect();
+ Row[] rows = crosstab.collectRows();
Arrays.sort(rows, crosstabRowComparator);
Integer count = 1;
for (Row row : rows) {
@@ -254,31 +254,31 @@ public class JavaDataFrameSuite {
@Test
public void testFrequentItems() {
- DataFrame df = context.table("testData2");
+ Dataset<Row> df = context.table("testData2");
String[] cols = {"a"};
- DataFrame results = df.stat().freqItems(cols, 0.2);
- Assert.assertTrue(results.collect()[0].getSeq(0).contains(1));
+ Dataset<Row> results = df.stat().freqItems(cols, 0.2);
+ Assert.assertTrue(results.collectRows()[0].getSeq(0).contains(1));
}
@Test
public void testCorrelation() {
- DataFrame df = context.table("testData2");
+ Dataset<Row> df = context.table("testData2");
Double pearsonCorr = df.stat().corr("a", "b", "pearson");
Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6);
}
@Test
public void testCovariance() {
- DataFrame df = context.table("testData2");
+ Dataset<Row> df = context.table("testData2");
Double result = df.stat().cov("a", "b");
Assert.assertTrue(Math.abs(result) < 1.0e-6);
}
@Test
public void testSampleBy() {
- DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
- DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
- Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
+ Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
+ Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
+ Row[] actual = sampled.groupBy("key").count().orderBy("key").collectRows();
Assert.assertEquals(0, actual[0].getLong(0));
Assert.assertTrue(0 <= actual[0].getLong(1) && actual[0].getLong(1) <= 8);
Assert.assertEquals(1, actual[1].getLong(0));
@@ -287,10 +287,10 @@ public class JavaDataFrameSuite {
@Test
public void pivot() {
- DataFrame df = context.table("courseSales");
+ Dataset<Row> df = context.table("courseSales");
Row[] actual = df.groupBy("year")
.pivot("course", Arrays.<Object>asList("dotNET", "Java"))
- .agg(sum("earnings")).orderBy("year").collect();
+ .agg(sum("earnings")).orderBy("year").collectRows();
Assert.assertEquals(2012, actual[0].getInt(0));
Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01);
@@ -303,11 +303,11 @@ public class JavaDataFrameSuite {
@Test
public void testGenericLoad() {
- DataFrame df1 = context.read().format("text").load(
+ Dataset<Row> df1 = context.read().format("text").load(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
Assert.assertEquals(4L, df1.count());
- DataFrame df2 = context.read().format("text").load(
+ Dataset<Row> df2 = context.read().format("text").load(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
Assert.assertEquals(5L, df2.count());
@@ -315,11 +315,11 @@ public class JavaDataFrameSuite {
@Test
public void testTextLoad() {
- DataFrame df1 = context.read().text(
+ Dataset<Row> df1 = context.read().text(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
Assert.assertEquals(4L, df1.count());
- DataFrame df2 = context.read().text(
+ Dataset<Row> df2 = context.read().text(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
Assert.assertEquals(5L, df2.count());
@@ -327,7 +327,7 @@ public class JavaDataFrameSuite {
@Test
public void testCountMinSketch() {
- DataFrame df = context.range(1000);
+ Dataset<Row> df = context.range(1000);
CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
Assert.assertEquals(sketch1.totalCount(), 1000);
@@ -352,7 +352,7 @@ public class JavaDataFrameSuite {
@Test
public void testBloomFilter() {
- DataFrame df = context.range(1000);
+ Dataset<Row> df = context.range(1000);
BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3);
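
The last two tests exercise the sketching helpers on the stat() functions with the new return type. A minimal sketch of those calls, with the parameters as used in the hunks above (depth, width, and seed for the count-min sketch; expected item count and false-positive probability for the Bloom filter):

    import org.apache.spark.util.sketch.BloomFilter;
    import org.apache.spark.util.sketch.CountMinSketch;

    // given a SQLContext named context, as in the suite
    Dataset<Row> ids = context.range(1000);
    CountMinSketch sketch = ids.stat().countMinSketch("id", 10, 20, 42);
    BloomFilter filter = ids.stat().bloomFilter("id", 1000, 0.03);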
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index b054b1095b..79b6e61767 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -169,7 +169,7 @@ public class JavaDatasetSuite implements Serializable {
public void testGroupBy() {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
- GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
+ GroupedDataset<Integer, String> grouped = ds.groupByKey(new MapFunction<String, Integer>() {
@Override
public Integer call(String v) throws Exception {
return v.length();
@@ -217,7 +217,7 @@ public class JavaDatasetSuite implements Serializable {
List<Integer> data2 = Arrays.asList(2, 6, 10);
Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
- GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
+ GroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(new MapFunction<Integer, Integer>() {
@Override
public Integer call(Integer v) throws Exception {
return v / 2;
@@ -250,7 +250,7 @@ public class JavaDatasetSuite implements Serializable {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
GroupedDataset<Integer, String> grouped =
- ds.groupBy(length(col("value"))).keyAs(Encoders.INT());
+ ds.groupByKey(length(col("value"))).keyAs(Encoders.INT());
Dataset<String> mapped = grouped.mapGroups(
new MapGroupsFunction<Integer, String, String>() {
@@ -410,7 +410,7 @@ public class JavaDatasetSuite implements Serializable {
Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
- GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupBy(
+ GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey(
new MapFunction<Tuple2<String, Integer>, String>() {
@Override
public String call(Tuple2<String, Integer> value) throws Exception {
@@ -828,7 +828,7 @@ public class JavaDatasetSuite implements Serializable {
})
});
- DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
SmallBean smallBean = new SmallBean();
@@ -845,7 +845,7 @@ public class JavaDatasetSuite implements Serializable {
{
Row row = new GenericRow(new Object[] { null });
- DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
NestedSmallBean nestedSmallBean = new NestedSmallBean();
@@ -862,7 +862,7 @@ public class JavaDatasetSuite implements Serializable {
})
});
- DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
ds.collect();
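
In JavaDatasetSuite the substantive rename is groupBy to groupByKey, which frees groupBy for the untyped, column-based grouping. A condensed sketch of the renamed call as the suite uses it:

    // given a SQLContext named context, as in the suite
    List<String> data = Arrays.asList("a", "foo", "bar");
    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
    GroupedDataset<Integer, String> grouped = ds.groupByKey(
      new MapFunction<String, Integer>() {
        @Override
        public Integer call(String v) {
          return v.length();  // group strings by their length
        }
      });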
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
index 9e241f2098..0f9e453d26 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -42,9 +42,9 @@ public class JavaSaveLoadSuite {
String originalDefaultSource;
File path;
- DataFrame df;
+ Dataset<Row> df;
- private static void checkAnswer(DataFrame actual, List<Row> expected) {
+ private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
if (errorMessage != null) {
Assert.fail(errorMessage);
@@ -85,7 +85,7 @@ public class JavaSaveLoadSuite {
Map<String, String> options = new HashMap<>();
options.put("path", path.toString());
df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save();
- DataFrame loadedDF = sqlContext.read().format("json").options(options).load();
+ Dataset<Row> loadedDF = sqlContext.read().format("json").options(options).load();
checkAnswer(loadedDF, df.collectAsList());
}
@@ -98,7 +98,7 @@ public class JavaSaveLoadSuite {
List<StructField> fields = new ArrayList<>();
fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
StructType schema = DataTypes.createStructType(fields);
- DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load();
+ Dataset<Row> loadedDF = sqlContext.read().format("json").schema(schema).options(options).load();
checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList());
}