| author    | Sandeep Singh <sandeep@techaddict.me> | 2016-05-10 11:17:47 -0700 |
| committer | Andrew Or <andrew@databricks.com> | 2016-05-10 11:17:47 -0700 |
| commit    | ed0b4070fb50054b1ecf66ff6c32458a4967dfd3 (patch) | |
| tree      | 68b3ad1a3ca22f2e0b5966db517c9bc42da3d254 /sql/core/src/test/java | |
| parent    | bcfee153b1cacfe617e602f3b72c0877e0bdf1f7 (diff) | |
| download  | spark-ed0b4070fb50054b1ecf66ff6c32458a4967dfd3.tar.gz, spark-ed0b4070fb50054b1ecf66ff6c32458a4967dfd3.tar.bz2, spark-ed0b4070fb50054b1ecf66ff6c32458a4967dfd3.zip | |
[SPARK-15037][SQL][MLLIB] Use SparkSession instead of SQLContext in Scala/Java TestSuites
## What changes were proposed in this pull request?
Use SparkSession instead of SQLContext in the Scala/Java TestSuites.
Since this PR is already very large, the Python TestSuites will be updated in a separate PR.
## How was this patch tested?
Existing tests
Author: Sandeep Singh <sandeep@techaddict.me>
Closes #12907 from techaddict/SPARK-15037.
Diffstat (limited to 'sql/core/src/test/java')
6 files changed, 133 insertions, 148 deletions
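Every file in the diff below applies the same mechanical change: the JUnit `setUp()` no longer builds a `SparkContext` plus `SQLContext` (or `TestSQLContext`) pair, but instead creates a single `SparkSession` (or `TestSparkSession`), and the former `sqlContext.*` / `context.*` calls go through the session. The following is a minimal sketch of the resulting pattern for a standalone JUnit 4 suite; the class name and the toy assertion are illustrative only and are not part of this commit:

```java
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

public class ExampleSparkSessionSuite {
  private transient SparkSession spark;
  private transient JavaSparkContext jsc;

  @Before
  public void setUp() {
    // One SparkSession replaces the old "new SparkContext(...)" + "new SQLContext(...)" pair.
    spark = SparkSession.builder()
      .master("local[*]")
      .appName("testing")
      .getOrCreate();
    // A JavaSparkContext is still handy for parallelize(); it wraps the session's SparkContext.
    jsc = new JavaSparkContext(spark.sparkContext());
  }

  @After
  public void tearDown() {
    // stop() shuts down the underlying SparkContext too, so no separate sparkContext().stop().
    spark.stop();
    spark = null;
    jsc = null;
  }

  @Test
  public void sqlGoesThroughTheSession() {
    // Calls that used to go through sqlContext (sql, createDataFrame, read, range, ...)
    // now go through the session.
    Dataset<Row> df = spark.range(0, 3).toDF("id");
    Assert.assertEquals(3, df.collectAsList().size());
  }
}
```

Because `SparkSession.stop()` also stops the underlying `SparkContext`, the separate `sparkContext().stop()` calls disappear from the `tearDown()` methods in the suites below.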
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index 189cc3972c..f2ae40e644 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -28,14 +28,13 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -44,21 +43,22 @@ import org.apache.spark.sql.types.StructType; // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. public class JavaApplySchemaSuite implements Serializable { - private transient JavaSparkContext javaCtx; - private transient SQLContext sqlContext; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - SparkContext context = new SparkContext("local[*]", "testing"); - javaCtx = new JavaSparkContext(context); - sqlContext = new SQLContext(context); + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sqlContext.sparkContext().stop(); - sqlContext = null; - javaCtx = null; + spark.stop(); + spark = null; } public static class Person implements Serializable { @@ -94,7 +94,7 @@ public class JavaApplySchemaSuite implements Serializable { person2.setAge(28); personList.add(person2); - JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map( + JavaRDD<Row> rowRDD = jsc.parallelize(personList).map( new Function<Person, Row>() { @Override public Row call(Person person) throws Exception { @@ -107,9 +107,9 @@ public class JavaApplySchemaSuite implements Serializable { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema); + Dataset<Row> df = spark.createDataFrame(rowRDD, schema); df.registerTempTable("people"); - List<Row> actual = sqlContext.sql("SELECT * FROM people").collectAsList(); + List<Row> actual = spark.sql("SELECT * FROM people").collectAsList(); List<Row> expected = new ArrayList<>(2); expected.add(RowFactory.create("Michael", 29)); @@ -130,7 +130,7 @@ public class JavaApplySchemaSuite implements Serializable { person2.setAge(28); personList.add(person2); - JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map( + JavaRDD<Row> rowRDD = jsc.parallelize(personList).map( new Function<Person, Row>() { @Override public Row call(Person person) { @@ -143,9 +143,9 @@ public class JavaApplySchemaSuite implements Serializable { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema); + Dataset<Row> df = spark.createDataFrame(rowRDD, schema); 
df.registerTempTable("people"); - List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD() + List<String> actual = spark.sql("SELECT * FROM people").toJavaRDD() .map(new Function<Row, String>() { @Override public String call(Row row) { @@ -162,7 +162,7 @@ public class JavaApplySchemaSuite implements Serializable { @Test public void applySchemaToJSON() { - JavaRDD<String> jsonRDD = javaCtx.parallelize(Arrays.asList( + JavaRDD<String> jsonRDD = jsc.parallelize(Arrays.asList( "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " + "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " + "\"boolean\":true, \"null\":null}", @@ -199,18 +199,18 @@ public class JavaApplySchemaSuite implements Serializable { null, "this is another simple string.")); - Dataset<Row> df1 = sqlContext.read().json(jsonRDD); + Dataset<Row> df1 = spark.read().json(jsonRDD); StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); df1.registerTempTable("jsonTable1"); - List<Row> actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); + List<Row> actual1 = spark.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - Dataset<Row> df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD); + Dataset<Row> df2 = spark.read().schema(expectedSchema).json(jsonRDD); StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); df2.registerTempTable("jsonTable2"); - List<Row> actual2 = sqlContext.sql("select * from jsonTable2").collectAsList(); + List<Row> actual2 = spark.sql("select * from jsonTable2").collectAsList(); Assert.assertEquals(expectedResult, actual2); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 1eb680dc4c..324ebbae38 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -20,12 +20,7 @@ package test.org.apache.spark.sql; import java.io.Serializable; import java.net.URISyntaxException; import java.net.URL; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.ArrayList; +import java.util.*; import scala.collection.JavaConverters; import scala.collection.Seq; @@ -34,46 +29,45 @@ import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; import org.junit.*; -import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.*; -import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.*; +import org.apache.spark.util.sketch.BloomFilter; import org.apache.spark.util.sketch.CountMinSketch; import static org.apache.spark.sql.functions.*; import static org.apache.spark.sql.types.DataTypes.*; -import org.apache.spark.util.sketch.BloomFilter; public class JavaDataFrameSuite { + private transient TestSparkSession spark; private transient JavaSparkContext jsc; - private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - SparkContext sc = new 
SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); + spark = new TestSparkSession(); + jsc = new JavaSparkContext(spark.sparkContext()); + spark.loadTestData(); } @After public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; + spark.stop(); + spark = null; } @Test public void testExecution() { - Dataset<Row> df = context.table("testData").filter("key = 1"); + Dataset<Row> df = spark.table("testData").filter("key = 1"); Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0)); } @Test public void testCollectAndTake() { - Dataset<Row> df = context.table("testData").filter("key = 1 or key = 2 or key = 3"); + Dataset<Row> df = spark.table("testData").filter("key = 1 or key = 2 or key = 3"); Assert.assertEquals(3, df.select("key").collectAsList().size()); Assert.assertEquals(2, df.select("key").takeAsList(2).size()); } @@ -83,7 +77,7 @@ public class JavaDataFrameSuite { */ @Test public void testVarargMethods() { - Dataset<Row> df = context.table("testData"); + Dataset<Row> df = spark.table("testData"); df.toDF("key1", "value1"); @@ -112,7 +106,7 @@ public class JavaDataFrameSuite { df.select(coalesce(col("key"))); // Varargs with mathfunctions - Dataset<Row> df2 = context.table("testData2"); + Dataset<Row> df2 = spark.table("testData2"); df2.select(exp("a"), exp("b")); df2.select(exp(log("a"))); df2.select(pow("a", "a"), pow("b", 2.0)); @@ -126,7 +120,7 @@ public class JavaDataFrameSuite { @Ignore public void testShow() { // This test case is intended ignored, but to make sure it compiles correctly - Dataset<Row> df = context.table("testData"); + Dataset<Row> df = spark.table("testData"); df.show(); df.show(1000); } @@ -194,7 +188,7 @@ public class JavaDataFrameSuite { public void testCreateDataFrameFromLocalJavaBeans() { Bean bean = new Bean(); List<Bean> data = Arrays.asList(bean); - Dataset<Row> df = context.createDataFrame(data, Bean.class); + Dataset<Row> df = spark.createDataFrame(data, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -202,7 +196,7 @@ public class JavaDataFrameSuite { public void testCreateDataFrameFromJavaBeans() { Bean bean = new Bean(); JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean)); - Dataset<Row> df = context.createDataFrame(rdd, Bean.class); + Dataset<Row> df = spark.createDataFrame(rdd, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -210,7 +204,7 @@ public class JavaDataFrameSuite { public void testCreateDataFromFromList() { StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); List<Row> rows = Arrays.asList(RowFactory.create(0)); - Dataset<Row> df = context.createDataFrame(rows, schema); + Dataset<Row> df = spark.createDataFrame(rows, schema); List<Row> result = df.collectAsList(); Assert.assertEquals(1, result.size()); } @@ -239,7 +233,7 @@ public class JavaDataFrameSuite { @Test public void testCrosstab() { - Dataset<Row> df = context.table("testData2"); + Dataset<Row> df = spark.table("testData2"); Dataset<Row> crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); Assert.assertEquals("a_b", columnNames[0]); @@ -258,7 +252,7 @@ public class JavaDataFrameSuite { @Test public void testFrequentItems() { - Dataset<Row> df = context.table("testData2"); + Dataset<Row> df = spark.table("testData2"); String[] cols = {"a"}; Dataset<Row> results = df.stat().freqItems(cols, 0.2); 
Assert.assertTrue(results.collectAsList().get(0).getSeq(0).contains(1)); @@ -266,21 +260,21 @@ public class JavaDataFrameSuite { @Test public void testCorrelation() { - Dataset<Row> df = context.table("testData2"); + Dataset<Row> df = spark.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { - Dataset<Row> df = context.table("testData2"); + Dataset<Row> df = spark.table("testData2"); Double result = df.stat().cov("a", "b"); Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test public void testSampleBy() { - Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); + Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList(); Assert.assertEquals(0, actual.get(0).getLong(0)); @@ -291,7 +285,7 @@ public class JavaDataFrameSuite { @Test public void pivot() { - Dataset<Row> df = context.table("courseSales"); + Dataset<Row> df = spark.table("courseSales"); List<Row> actual = df.groupBy("year") .pivot("course", Arrays.<Object>asList("dotNET", "Java")) .agg(sum("earnings")).orderBy("year").collectAsList(); @@ -324,10 +318,10 @@ public class JavaDataFrameSuite { @Test public void testGenericLoad() { - Dataset<Row> df1 = context.read().format("text").load(getResource("text-suite.txt")); + Dataset<Row> df1 = spark.read().format("text").load(getResource("text-suite.txt")); Assert.assertEquals(4L, df1.count()); - Dataset<Row> df2 = context.read().format("text").load( + Dataset<Row> df2 = spark.read().format("text").load( getResource("text-suite.txt"), getResource("text-suite2.txt")); Assert.assertEquals(5L, df2.count()); @@ -335,10 +329,10 @@ public class JavaDataFrameSuite { @Test public void testTextLoad() { - Dataset<String> ds1 = context.read().text(getResource("text-suite.txt")); + Dataset<String> ds1 = spark.read().text(getResource("text-suite.txt")); Assert.assertEquals(4L, ds1.count()); - Dataset<String> ds2 = context.read().text( + Dataset<String> ds2 = spark.read().text( getResource("text-suite.txt"), getResource("text-suite2.txt")); Assert.assertEquals(5L, ds2.count()); @@ -346,7 +340,7 @@ public class JavaDataFrameSuite { @Test public void testCountMinSketch() { - Dataset<Long> df = context.range(1000); + Dataset<Long> df = spark.range(1000); CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42); Assert.assertEquals(sketch1.totalCount(), 1000); @@ -371,7 +365,7 @@ public class JavaDataFrameSuite { @Test public void testBloomFilter() { - Dataset<Long> df = context.range(1000); + Dataset<Long> df = spark.range(1000); BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03); Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f1b1c22e4a..8354a5bdac 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,46 +23,43 @@ import java.sql.Date; import java.sql.Timestamp; import java.util.*; -import com.google.common.base.Objects; -import org.junit.rules.ExpectedException; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; import scala.Tuple5; +import 
com.google.common.base.Objects; import org.junit.*; +import org.junit.rules.ExpectedException; import org.apache.spark.Accumulator; -import org.apache.spark.SparkContext; -import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; -import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.catalyst.encoders.OuterScopes; import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.StructType; - -import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.expr; import static org.apache.spark.sql.types.DataTypes.*; public class JavaDatasetSuite implements Serializable { + private transient TestSparkSession spark; private transient JavaSparkContext jsc; - private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - SparkContext sc = new SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); + spark = new TestSparkSession(); + jsc = new JavaSparkContext(spark.sparkContext()); + spark.loadTestData(); } @After public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; + spark.stop(); + spark = null; } private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) { @@ -72,7 +69,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testCollect() { List<String> data = Arrays.asList("hello", "world"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); List<String> collected = ds.collectAsList(); Assert.assertEquals(Arrays.asList("hello", "world"), collected); } @@ -80,7 +77,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testTake() { List<String> data = Arrays.asList("hello", "world"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); List<String> collected = ds.takeAsList(1); Assert.assertEquals(Arrays.asList("hello"), collected); } @@ -88,7 +85,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testToLocalIterator() { List<String> data = Arrays.asList("hello", "world"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); Iterator<String> iter = ds.toLocalIterator(); Assert.assertEquals("hello", iter.next()); Assert.assertEquals("world", iter.next()); @@ -98,7 +95,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testCommonOperation() { List<String> data = Arrays.asList("hello", "world"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); Assert.assertEquals("hello", ds.first()); Dataset<String> filtered = ds.filter(new FilterFunction<String>() { @@ -149,7 +146,7 @@ public class JavaDatasetSuite implements Serializable { public void testForeach() { final Accumulator<Integer> accum = jsc.accumulator(0); List<String> data = Arrays.asList("a", "b", "c"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); 
ds.foreach(new ForeachFunction<String>() { @Override @@ -163,7 +160,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testReduce() { List<Integer> data = Arrays.asList(1, 2, 3); - Dataset<Integer> ds = context.createDataset(data, Encoders.INT()); + Dataset<Integer> ds = spark.createDataset(data, Encoders.INT()); int reduced = ds.reduce(new ReduceFunction<Integer>() { @Override @@ -177,7 +174,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testGroupBy() { List<String> data = Arrays.asList("a", "foo", "bar"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); KeyValueGroupedDataset<Integer, String> grouped = ds.groupByKey( new MapFunction<String, Integer>() { @Override @@ -227,7 +224,7 @@ public class JavaDatasetSuite implements Serializable { toSet(reduced.collectAsList())); List<Integer> data2 = Arrays.asList(2, 6, 10); - Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()); + Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT()); KeyValueGroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey( new MapFunction<Integer, Integer>() { @Override @@ -261,7 +258,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testSelect() { List<Integer> data = Arrays.asList(2, 6); - Dataset<Integer> ds = context.createDataset(data, Encoders.INT()); + Dataset<Integer> ds = spark.createDataset(data, Encoders.INT()); Dataset<Tuple2<Integer, String>> selected = ds.select( expr("value + 1"), @@ -275,12 +272,12 @@ public class JavaDatasetSuite implements Serializable { @Test public void testSetOperation() { List<String> data = Arrays.asList("abc", "abc", "xyz"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); Assert.assertEquals(asSet("abc", "xyz"), toSet(ds.distinct().collectAsList())); List<String> data2 = Arrays.asList("xyz", "foo", "foo"); - Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING()); + Dataset<String> ds2 = spark.createDataset(data2, Encoders.STRING()); Dataset<String> intersected = ds.intersect(ds2); Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList()); @@ -307,9 +304,9 @@ public class JavaDatasetSuite implements Serializable { @Test public void testJoin() { List<Integer> data = Arrays.asList(1, 2, 3); - Dataset<Integer> ds = context.createDataset(data, Encoders.INT()).as("a"); + Dataset<Integer> ds = spark.createDataset(data, Encoders.INT()).as("a"); List<Integer> data2 = Arrays.asList(2, 3, 4); - Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()).as("b"); + Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT()).as("b"); Dataset<Tuple2<Integer, Integer>> joined = ds.joinWith(ds2, col("a.value").equalTo(col("b.value"))); @@ -322,21 +319,21 @@ public class JavaDatasetSuite implements Serializable { public void testTupleEncoder() { Encoder<Tuple2<Integer, String>> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING()); List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b")); - Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2); + Dataset<Tuple2<Integer, String>> ds2 = spark.createDataset(data2, encoder2); Assert.assertEquals(data2, ds2.collectAsList()); Encoder<Tuple3<Integer, Long, String>> encoder3 = Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING()); 
List<Tuple3<Integer, Long, String>> data3 = Arrays.asList(new Tuple3<>(1, 2L, "a")); - Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3); + Dataset<Tuple3<Integer, Long, String>> ds3 = spark.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); Encoder<Tuple4<Integer, String, Long, String>> encoder4 = Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING()); List<Tuple4<Integer, String, Long, String>> data4 = Arrays.asList(new Tuple4<>(1, "b", 2L, "a")); - Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4); + Dataset<Tuple4<Integer, String, Long, String>> ds4 = spark.createDataset(data4, encoder4); Assert.assertEquals(data4, ds4.collectAsList()); Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 = @@ -345,7 +342,7 @@ public class JavaDatasetSuite implements Serializable { List<Tuple5<Integer, String, Long, String, Boolean>> data5 = Arrays.asList(new Tuple5<>(1, "b", 2L, "a", true)); Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 = - context.createDataset(data5, encoder5); + spark.createDataset(data5, encoder5); Assert.assertEquals(data5, ds5.collectAsList()); } @@ -356,7 +353,7 @@ public class JavaDatasetSuite implements Serializable { Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING()); List<Tuple2<Tuple2<Integer, String>, String>> data = Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b")); - Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder); + Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); // test (int, (string, string, long)) @@ -366,7 +363,7 @@ public class JavaDatasetSuite implements Serializable { List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 = Arrays.asList(tuple2(1, new Tuple3<>("a", "b", 3L))); Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 = - context.createDataset(data2, encoder2); + spark.createDataset(data2, encoder2); Assert.assertEquals(data2, ds2.collectAsList()); // test (int, ((string, long), string)) @@ -376,7 +373,7 @@ public class JavaDatasetSuite implements Serializable { List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 = Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b"))); Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 = - context.createDataset(data3, encoder3); + spark.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); } @@ -390,7 +387,7 @@ public class JavaDatasetSuite implements Serializable { 1.7976931348623157E308, new BigDecimal("0.922337203685477589"), Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE)); Dataset<Tuple5<Double, BigDecimal, Date, Timestamp, Float>> ds = - context.createDataset(data, encoder); + spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @@ -441,7 +438,7 @@ public class JavaDatasetSuite implements Serializable { Encoder<KryoSerializable> encoder = Encoders.kryo(KryoSerializable.class); List<KryoSerializable> data = Arrays.asList( new KryoSerializable("hello"), new KryoSerializable("world")); - Dataset<KryoSerializable> ds = context.createDataset(data, encoder); + Dataset<KryoSerializable> ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @@ -450,14 +447,14 @@ public class JavaDatasetSuite implements 
Serializable { Encoder<JavaSerializable> encoder = Encoders.javaSerialization(JavaSerializable.class); List<JavaSerializable> data = Arrays.asList( new JavaSerializable("hello"), new JavaSerializable("world")); - Dataset<JavaSerializable> ds = context.createDataset(data, encoder); + Dataset<JavaSerializable> ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @Test public void testRandomSplit() { List<String> data = Arrays.asList("hello", "world", "from", "spark"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); double[] arraySplit = {1, 2, 3}; List<Dataset<String>> randomSplit = ds.randomSplitAsList(arraySplit, 1); @@ -647,14 +644,14 @@ public class JavaDatasetSuite implements Serializable { obj2.setF(Arrays.asList(300L, null, 400L)); List<SimpleJavaBean> data = Arrays.asList(obj1, obj2); - Dataset<SimpleJavaBean> ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class)); + Dataset<SimpleJavaBean> ds = spark.createDataset(data, Encoders.bean(SimpleJavaBean.class)); Assert.assertEquals(data, ds.collectAsList()); NestedJavaBean obj3 = new NestedJavaBean(); obj3.setA(obj1); List<NestedJavaBean> data2 = Arrays.asList(obj3); - Dataset<NestedJavaBean> ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class)); + Dataset<NestedJavaBean> ds2 = spark.createDataset(data2, Encoders.bean(NestedJavaBean.class)); Assert.assertEquals(data2, ds2.collectAsList()); Row row1 = new GenericRow(new Object[]{ @@ -678,7 +675,7 @@ public class JavaDatasetSuite implements Serializable { .add("d", createArrayType(StringType)) .add("e", createArrayType(StringType)) .add("f", createArrayType(LongType)); - Dataset<SimpleJavaBean> ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema) + Dataset<SimpleJavaBean> ds3 = spark.createDataFrame(Arrays.asList(row1, row2), schema) .as(Encoders.bean(SimpleJavaBean.class)); Assert.assertEquals(data, ds3.collectAsList()); } @@ -692,7 +689,7 @@ public class JavaDatasetSuite implements Serializable { obj.setB(new Date(0)); obj.setC(java.math.BigDecimal.valueOf(1)); Dataset<SimpleJavaBean2> ds = - context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); + spark.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); ds.collect(); } @@ -776,7 +773,7 @@ public class JavaDatasetSuite implements Serializable { }) }); - Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); SmallBean smallBean = new SmallBean(); @@ -793,7 +790,7 @@ public class JavaDatasetSuite implements Serializable { { Row row = new GenericRow(new Object[] { null }); - Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); NestedSmallBean nestedSmallBean = new NestedSmallBean(); @@ -810,7 +807,7 @@ public class JavaDatasetSuite implements Serializable { }) }); - Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); ds.collect(); diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 4a78dca7fe..2274912521 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -24,33 +24,30 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkContext; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.api.java.UDF2; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. public class JavaUDFSuite implements Serializable { - private transient JavaSparkContext sc; - private transient SQLContext sqlContext; + private transient SparkSession spark; @Before public void setUp() { - SparkContext _sc = new SparkContext("local[*]", "testing"); - sqlContext = new SQLContext(_sc); - sc = new JavaSparkContext(_sc); + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); } @After public void tearDown() { - sqlContext.sparkContext().stop(); - sqlContext = null; - sc = null; + spark.stop(); + spark = null; } @SuppressWarnings("unchecked") @@ -60,14 +57,14 @@ public class JavaUDFSuite implements Serializable { // sqlContext.registerFunction( // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType); - sqlContext.udf().register("stringLengthTest", new UDF1<String, Integer>() { + spark.udf().register("stringLengthTest", new UDF1<String, Integer>() { @Override public Integer call(String str) { return str.length(); } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); + Row result = spark.sql("SELECT stringLengthTest('test')").head(); Assert.assertEquals(4, result.getInt(0)); } @@ -80,14 +77,14 @@ public class JavaUDFSuite implements Serializable { // (String str1, String str2) -> str1.length() + str2.length, // DataType.IntegerType); - sqlContext.udf().register("stringLengthTest", new UDF2<String, String, Integer>() { + spark.udf().register("stringLengthTest", new UDF2<String, String, Integer>() { @Override public Integer call(String str1, String str2) { return str1.length() + str2.length(); } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); + Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); Assert.assertEquals(9, result.getInt(0)); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java index 7863177093..059c2d9f2c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java @@ -26,36 +26,30 @@ import scala.Tuple2; import org.junit.After; import org.junit.Before; -import org.apache.spark.SparkContext; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.MapFunction; import 
org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.KeyValueGroupedDataset; -import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.test.TestSparkSession; /** * Common test base shared across this and Java8DatasetAggregatorSuite. */ public class JavaDatasetAggregatorSuiteBase implements Serializable { - protected transient JavaSparkContext jsc; - protected transient TestSQLContext context; + private transient TestSparkSession spark; @Before public void setUp() { // Trigger static initializer of TestData - SparkContext sc = new SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); + spark = new TestSparkSession(); + spark.loadTestData(); } @After public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; + spark.stop(); + spark = null; } protected <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) { @@ -66,7 +60,7 @@ public class JavaDatasetAggregatorSuiteBase implements Serializable { Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); List<Tuple2<String, Integer>> data = Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); - Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder); + Dataset<Tuple2<String, Integer>> ds = spark.createDataset(data, encoder); return ds.groupByKey( new MapFunction<Tuple2<String, Integer>, String>() { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 9e65158eb0..d0435e4d43 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -19,14 +19,16 @@ package test.org.apache.spark.sql.sources; import java.io.File; import java.io.IOException; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; @@ -37,8 +39,8 @@ import org.apache.spark.util.Utils; public class JavaSaveLoadSuite { - private transient JavaSparkContext sc; - private transient SQLContext sqlContext; + private transient SparkSession spark; + private transient JavaSparkContext jsc; File path; Dataset<Row> df; @@ -52,9 +54,11 @@ public class JavaSaveLoadSuite { @Before public void setUp() throws IOException { - SparkContext _sc = new SparkContext("local[*]", "testing"); - sqlContext = new SQLContext(_sc); - sc = new JavaSparkContext(_sc); + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); @@ -66,16 +70,15 @@ public class JavaSaveLoadSuite { for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } - JavaRDD<String> rdd = sc.parallelize(jsonObjects); - df = sqlContext.read().json(rdd); + JavaRDD<String> rdd = jsc.parallelize(jsonObjects); + df = spark.read().json(rdd); df.registerTempTable("jsonTable"); } @After public void 
tearDown() { - sqlContext.sparkContext().stop(); - sqlContext = null; - sc = null; + spark.stop(); + spark = null; } @Test @@ -83,7 +86,7 @@ public class JavaSaveLoadSuite { Map<String, String> options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); - Dataset<Row> loadedDF = sqlContext.read().format("json").options(options).load(); + Dataset<Row> loadedDF = spark.read().format("json").options(options).load(); checkAnswer(loadedDF, df.collectAsList()); } @@ -96,8 +99,8 @@ public class JavaSaveLoadSuite { List<StructField> fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); + Dataset<Row> loadedDF = spark.read().format("json").schema(schema).options(options).load(); - checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); + checkAnswer(loadedDF, spark.sql("SELECT b FROM jsonTable").collectAsList()); } } |